
[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients #2259

Open · EGalahad opened this issue Nov 29, 2024 · 5 comments
Labels: enhancement (New feature or request), good first issue (Good for newcomers)

@EGalahad
The feature, motivation and pitch

Problem

The solver's jax.lax.while_loop implementation prevents gradient computation through the environment step during gradient-based trajectory optimization. This occurs in the solver implementation when opt.iterations > 1.

Error encountered with a jax.jit-compiled grad function:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.

The current workaround of setting opt.iterations = 1 leads to potentially inaccurate simulation and gradients.
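For concreteness, a minimal sketch of the failure mode (the XML path and the loss function are placeholders; the essential part is taking jax.grad through mjx.step):

```python
import jax
import mujoco
from mujoco import mjx

# Placeholder model; any model that exercises the constraint solver will do.
model = mujoco.MjModel.from_xml_path("model.xml")
mx = mjx.put_model(model)

def loss(qpos0):
    dx = mjx.make_data(mx)
    dx = dx.replace(qpos=qpos0)
    dx = mjx.step(mx, dx)  # the solver's jax.lax.while_loop runs in here
    return (dx.qpos ** 2).sum()

qpos0 = mjx.make_data(mx).qpos
# With opt.iterations > 1 this raises:
# ValueError: Reverse-mode differentiation does not work for lax.while_loop ...
grads = jax.jit(jax.grad(loss))(qpos0)
```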

Proposed Solution

Add an option to set a fixed iteration count (e.g., 4) that would be compatible with reverse-mode differentiation using either lax.scan or lax.fori_loop with static bounds.
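As a self-contained illustration of the difference (plain JAX, not MJX code): reverse-mode differentiation fails through lax.while_loop, but works through lax.fori_loop with static Python-int bounds (or equivalently lax.scan):

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve_dynamic(x):
    # Early exit on a convergence test: not reverse-mode differentiable.
    cond = lambda carry: jnp.abs(carry[1]) > 1e-6
    body = lambda carry: (carry[0] + 1, 0.5 * carry[1])
    return lax.while_loop(cond, body, (0, x))[1]

def solve_static(x, iterations=4):
    # Fixed trip count known at trace time: reverse-mode works.
    return lax.fori_loop(0, iterations, lambda i, v: 0.5 * v, x)

print(jax.grad(solve_static)(1.0))  # 0.0625 == 0.5**4
# jax.grad(solve_dynamic)(1.0)      # ValueError: Reverse-mode differentiation ...
```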

Alternatives

No response

Additional context

No response

@EGalahad EGalahad added the enhancement New feature or request label Nov 29, 2024
@erikfrey erikfrey added the good first issue Good for newcomers label Dec 3, 2024
@erikfrey (Collaborator) commented Dec 3, 2024

I like this suggestion and have labeled it as a good one for someone to take on externally. If no one does, we'll eventually implement it ourselves.

If someone would like to try it, I'd recommend briefly proposing (in this issue) how to modify the API to expose this functionality, and then, if we all agree, opening a PR.

@jaraujo98

@erikfrey are you still looking for a volunteer to tackle this? I'd like to give it a shot.

@varshneydevansh

Hi @erikfrey

I previously tried to contribute to this project, but that PR is still pending. (I can now see it through, as I've since gained experience working with a large codebase, LibreOffice.)

Problem understanding:

jax.lax.while_loop in the solver prevents gradient computation during backpropagation; we need to replace the dynamic loop with a static one when a fixed iteration count is specified.

Solution proposed:

I'm thinking of introducing a new boolean option in Model.opt, named something like static_iterations. Based on that flag, the solver would use a static loop with a fixed number of iterations, enabling gradient computation.

Code modifications:

Check static_iterations Flag: In the solver's solve function, we can use jax.lax.fori_loop instead of jax.lax.while_loop when static_iterations is enabled.

Run Fixed Iterations: When static_iterations is True, execute the solver loop exactly m.opt.iterations times, bypassing the convergence checks.

Basically, this replaces jax.lax.while_loop with jax.lax.scan or jax.lax.fori_loop when the flag is enabled. Of the two, jax.lax.fori_loop seems the more appropriate choice; a rough sketch follows.
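A rough sketch of how that dispatch could look (hypothetical: static_iterations is the proposed flag; body and cond stand in for the real solver iteration and convergence check in solver.py):

```python
from jax import lax

def _run_iterations(m, ctx, body, cond):
    """Hypothetical solver-loop dispatch.

    body: one solver iteration (ctx -> ctx)
    cond: convergence check (ctx -> bool)
    """
    if m.opt.static_iterations:  # proposed flag, a plain Python bool
        # Fixed trip count, no convergence check: reverse-mode friendly.
        return lax.fori_loop(0, m.opt.iterations, lambda i, c: body(c), ctx)
    # Current behavior: dynamic early exit on convergence.
    return lax.while_loop(cond, body, ctx)
```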

@yuvaltassa (Collaborator)

In C MuJoCo there is a trivial way to fix the number of iterations: set mjModel.opt.tolerance = 0.
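For reference, via the Python bindings that would be (the model loading is a placeholder):

```python
import mujoco

model = mujoco.MjModel.from_xml_path("model.xml")  # placeholder model
# Improvement is never below tolerance = 0, so the solver always runs
# the full model.opt.iterations.
model.opt.tolerance = 0
```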

But I'll let @erikfrey comment on the correct way to do this in JAX.

@varshneydevansh

If I understand correctly, setting mjModel.opt.tolerance = 0 effectively disables the convergence check, which means the solver keeps iterating until it reaches the maximum number of iterations, regardless of whether the solution has converged.

That can waste some compute, because the solver continues even after reaching a satisfactory solution.

But in JAX we are using a dynamic while loop that stops when the tolerance criterion is met, and a fixed number of iterations would give deterministic behavior, just like C MuJoCo with tolerance = 0.
