ERROR: Nan in state inside plant #5

Open
XinChen-stars opened this issue Sep 19, 2024 · 3 comments

Comments

@XinChen-stars

Hi, thank you for sharing the code!

Description

I followed MPC Usage Listings 2 and 3. Here is the result of running cartpole_plant_example:
[Image: cartpole]
However, when I tried to run a quadrotor example, it returned the following error:
[Image: quadrotor_error]
Quadrotor plant code:

#pragma once
#include <mppi/core/base_plant.hpp>
#include <mppi/dynamics/quadrotor/quadrotor_dynamics.cuh>
template <class CONTROLLER_T>
class SimpleQuadrotorPlant : public BasePlant<CONTROLLER_T>
{
public:
 using control_array = typename QuadrotorDynamics::control_array;
 using state_array = typename QuadrotorDynamics::state_array;
 using output_array = typename QuadrotorDynamics::output_array;

 SimpleQuadrotorPlant(std::shared_ptr<CONTROLLER_T> controller, int hz, int optimization_stride)
   : BasePlant<CONTROLLER_T>(controller, hz, optimization_stride)
 {
   system_dynamics_ = std::make_shared<QuadrotorDynamics>();
 }

 void pubControl(const control_array& u)
 {
   state_array state_derivative;
   output_array dynamics_output;
   state_array prev_state = current_state_;
   float t = this->state_time_;
   float dt = this->controller_->getDt();
   system_dynamics_->step(prev_state, current_state_, state_derivative, u, dynamics_output, t, dt);
   current_time_ += dt;
 }

 void pubNominalState(const state_array& s)
 {
 }

 void pubFreeEnergyStatistics(MPPIFreeEnergyStatistics& fe_stats)
 {
 }

 int checkStatus()
 {
   return 0;
 }

 double getCurrentTime()
 {
   return current_time_;
 }

 double getPoseTime()
 {
   return this->state_time_;
 }

 double getAvgLoopTime() const
 {
   return this->avg_loop_time_ms_;
 }

 double getLastOptimizationTime() const
 {
   return this->optimization_duration_;
 }

 state_array current_state_ = state_array::Zero();

protected:
 std::shared_ptr<QuadrotorDynamics> system_dynamics_;
 double current_time_ = 0.0;
};

Quadrotor example code:

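#include <atomic>    // std::atomic used in the control loop
#include <iostream>  // std::cout for the timing printouts
#include <mppi/instantiations/quadrotor_mppi/quadrotor_mppi.cuh>
#include <quadrotor_plant.hpp>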

const int NUM_TIMESTEPS = 100;
const int NUM_ROLLOUTS = 1024;
const int DYN_BLOCK_X = 32;
using DYN_T = QuadrotorDynamics;
const int DYN_BLOCK_Y = DYN_T::STATE_DIM;
using COST_T = QuadrotorQuadraticCost;
using FB_T = DDPFeedback<DYN_T, NUM_TIMESTEPS>;
using SAMPLING_T = mppi::sampling_distributions::GaussianDistribution<DYN_T::DYN_PARAMS_T>;
using CONTROLLER_T = VanillaMPPIController<DYN_T, COST_T, FB_T, NUM_TIMESTEPS, NUM_ROLLOUTS, SAMPLING_T>;
using CONTROLLER_PARAMS_T = CONTROLLER_T::TEMPLATED_PARAMS;

using PLANT_T = SimpleQuadrotorPlant<CONTROLLER_T>;

int main(int argc, char** argv)
{
 float dt = 0.02;
 DYN_T dynamics;                     // set up dynamics
 COST_T cost;                        // set up cost
 FB_T fb_controller(&dynamics, dt);  // set up feedback controller
 // set up sampling distribution
 SAMPLING_T sampler;
 auto sampler_params = sampler.getParams();
 std::fill(sampler_params.std_dev, sampler_params.std_dev + DYN_T::CONTROL_DIM, 10.0);
 sampler.setParams(sampler_params);

 // set up MPPI Controller
 CONTROLLER_PARAMS_T controller_params;
 controller_params.dt_ = dt;
 controller_params.lambda_ = 1.0;
 controller_params.dynamics_rollout_dim_ = dim3(DYN_BLOCK_X, DYN_BLOCK_Y, 1);
 controller_params.cost_rollout_dim_ = dim3(96, 1, 1);
 std::shared_ptr<CONTROLLER_T> controller =
     std::make_shared<CONTROLLER_T>(&dynamics, &cost, &fb_controller, &sampler, controller_params);

 // Create plant
 PLANT_T plant(controller, (1.0 / dt), 1);

 std::atomic<bool> alive(true);
 for (int t = 0; t < 10000; t++)
 {
   plant.updateState(plant.current_state_, (t + 1) * dt);
   plant.runControlIteration(&alive);
 }

 std::cout << "Avg Optimization time: " << plant.getAvgOptimizationTime() << " ms" << std::endl;
 std::cout << "Last Optimization time: " << plant.getLastOptimizationTime() << " ms" << std::endl;
 std::cout << "Avg Loop time: " << plant.getAvgLoopTime() << " ms" << std::endl;
 std::cout << "Avg Optimization Hz: " << 1.0 / (plant.getAvgOptimizationTime() * 1e-3) << " Hz" << std::endl;

 auto control_sequence = controller->getControlSeq();
 std::cout << "State: \n" << plant.current_state_.transpose() << std::endl;
 std::cout << "Control Sequence:\n" << control_sequence << std::endl;
 return 0;
}

Is there anything I am overlooking when coding a quadrotor example?

@JasonGibson274
Member

Hello,

Yes. The quadrotor dynamics include quaternions in the state. If you initialize the whole state to zero, the state update here will produce NaNs: https://github.com/ACDSLab/MPPI-Generic/blob/main/include/mppi/dynamics/quadrotor/quadrotor_dynamics.cu#L182C24-L182C77
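To illustrate why an all-zero quaternion leads to NaNs, here is a minimal sketch that assumes the update step normalizes the quaternion by dividing by its norm. The `Quat` type, `normalize`, and `hasNaN` are hypothetical helpers, not MPPI-Generic code. With a zero quaternion the norm is 0, so every component becomes 0/0, which is NaN in IEEE 754 floating point, while the identity quaternion (w = 1) normalizes cleanly.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical quaternion type (w, x, y, z); not the library's representation.
struct Quat
{
  float w, x, y, z;
};

// Hand-rolled normalization that divides each component by the norm.
// For an all-zero quaternion the norm is 0, so each component is 0/0 = NaN.
Quat normalize(const Quat& q)
{
  const float n = std::sqrt(q.w * q.w + q.x * q.x + q.y * q.y + q.z * q.z);
  return { q.w / n, q.x / n, q.y / n, q.z / n };
}

// True if any component is NaN.
bool hasNaN(const Quat& q)
{
  return std::isnan(q.w) || std::isnan(q.x) || std::isnan(q.y) || std::isnan(q.z);
}
```

Here `hasNaN(normalize(Quat{0.f, 0.f, 0.f, 0.f}))` is true, while `normalize(Quat{1.f, 0.f, 0.f, 0.f})` returns the identity unchanged. Compiling with fast-math flags can change how NaNs are produced and detected.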

Just use the getZeroState method (https://github.com/ACDSLab/MPPI-Generic/blob/main/include/mppi/dynamics/quadrotor/quadrotor_dynamics.cu#L212) to get a valid initial state.
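Before handing a state to the plant, it can help to check that it is usable. The sketch below assumes a 13-element state with the quaternion (w, x, y, z) stored at indices 6 to 9; that layout, `State`, `isValidInitialState`, and `makeZeroStateSketch` are hypothetical illustrations, so check `QuadrotorDynamics` for the real indices and prefer `getZeroState()` in actual code. A state counts as valid here if every entry is finite and the quaternion has unit norm, which is exactly what an all-zero state fails.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical layout (an assumption, not the library's): position 0-2,
// velocity 3-5, quaternion w,x,y,z 6-9, body rates 10-12.
constexpr std::size_t kStateDim = 13;
constexpr std::size_t kQuatW = 6;
using State = std::array<float, kStateDim>;

// True if every entry is finite and the quaternion has (approximately) unit norm.
bool isValidInitialState(const State& s, float tol = 1e-4f)
{
  for (float v : s)
  {
    if (!std::isfinite(v))
      return false;
  }
  float sq = 0.f;
  for (std::size_t i = kQuatW; i < kQuatW + 4; ++i)
    sq += s[i] * s[i];
  return std::fabs(std::sqrt(sq) - 1.f) < tol;
}

// Hypothetical stand-in for a "zero state": all zeros except an identity quaternion.
State makeZeroStateSketch()
{
  State s{};        // value-initialized: all zeros
  s[kQuatW] = 1.f;  // identity rotation (w = 1)
  return s;
}
```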

@JasonGibson274
Member

I will make a note to update the examples to use the getZeroState method to prevent this issue in the future.

@XinChen-stars
Author

Thank you for your reply!
I solved this problem by using the method getZeroState:

state_array current_state_ = system_dynamics_->getZeroState();

[Image: success]
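One caveat about the line above: as a default member initializer it runs before the constructor body assigns `system_dynamics_`, and because `system_dynamics_` is declared after `current_state_`, the `shared_ptr` has not even been constructed at that point. Dereferencing it there is undefined behavior, even if it may appear to work. A safer option is to set the state in the constructor right after creating the dynamics, i.e. add `current_state_ = system_dynamics_->getZeroState();` after `system_dynamics_ = std::make_shared<QuadrotorDynamics>();`. The self-contained sketch below shows that ordering with hypothetical stand-ins (`MockDynamics`, `PlantSketch`, and the quaternion index 6 are assumptions, not MPPI-Generic code).

```cpp
#include <array>
#include <cassert>
#include <memory>

// Hypothetical stand-in for QuadrotorDynamics; only getZeroState() matters here.
struct MockDynamics
{
  using state_array = std::array<float, 13>;
  state_array getZeroState() const
  {
    state_array s{};
    s[6] = 1.f;  // assumed quaternion w index (illustrative only)
    return s;
  }
};

class PlantSketch
{
public:
  using state_array = MockDynamics::state_array;

  PlantSketch() : system_dynamics_(std::make_shared<MockDynamics>())
  {
    // system_dynamics_ is fully constructed by now, so dereferencing it is safe.
    current_state_ = system_dynamics_->getZeroState();
  }

  state_array current_state_{};  // zero-filled first, overwritten in the constructor body

protected:
  std::shared_ptr<MockDynamics> system_dynamics_;
};
```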
