[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients #2259
Comments
I like this suggestion and have labeled it as a good one for someone to take on externally. If no one does, we'll eventually implement it ourselves. If someone would like to try it, I'd recommend briefly proposing (in this issue) how to modify the API to expose this functionality, and then, if we all agree, open a PR.
@erikfrey are you still looking for a volunteer to tackle this? I'd like to give it a shot.
Hi @erikfrey, I previously tried to contribute to this project. That PR is still pending (I can complete it now, as I have since gained experience with a large codebase, LibreOffice).

Problem understanding: `jax.lax.while_loop` in the solver prevents gradient computation during backpropagation, so we need to replace the dynamic loop with a static one when a fixed iteration count is specified.

Proposed solution: introduce a new boolean option in `Model.opt`, named something like `static_iterations`. Based on that flag, the solver would use a static loop with a fixed number of iterations, enabling gradient computation.

Code modifications:
- Check the `static_iterations` flag: in the solver's solve function, use `jax.lax.fori_loop` instead of `jax.lax.while_loop` when `static_iterations` is enabled.
- Run fixed iterations: when `static_iterations` is True, execute the solver loop exactly `m.opt.iterations` times, bypassing the convergence checks.

Essentially this means replacing `jax.lax.while_loop` with `jax.lax.scan` or `jax.lax.fori_loop` when the flag is enabled; `jax.lax.fori_loop` seems the more appropriate choice. A minimal sketch of the control flow is below.
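To make the proposal concrete, here is a rough sketch, not the actual MJX solver code: `run_solver_loop`, `ctx`, `cond_fn`, and `body_fn` are hypothetical stand-ins for the solver's real loop state and loop functions, and `static_iterations` is the new flag proposed above (assumed here to be a plain Python bool, not a traced value).

```python
from jax import lax

def run_solver_loop(m, ctx, cond_fn, body_fn):
    """Sketch: dispatch between a dynamic and a static solver loop."""
    if m.opt.static_iterations:  # hypothetical new flag on Model.opt
        # Static trip count: compatible with reverse-mode autodiff.
        # fori_loop's body takes (index, carry); the index is unused here,
        # and the convergence check is skipped entirely.
        return lax.fori_loop(0, m.opt.iterations,
                             lambda _, c: body_fn(c), ctx)
    # Dynamic trip count: stops early once the tolerance is met, but
    # reverse-mode differentiation through lax.while_loop is unsupported.
    return lax.while_loop(cond_fn, body_fn, ctx)
```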
In C MuJoCo there is a trivial way to fix the number of iterations: set `mjModel.opt.tolerance = 0`, so the convergence check never triggers and the solver always runs for the full `opt.iterations`. But I'll let @erikfrey comment on the correct way to do this in JAX.
I tried to understand the effect of setting `mjModel.opt.tolerance = 0`: it effectively disables the convergence check, which means the solver keeps iterating until it reaches the maximum number of iterations specified, regardless of whether the solution has already converged. This may waste some computation, since the solver continues even after reaching a satisfactory solution. But in JAX we are currently using a dynamic while loop that stops when the tolerance criterion is met, and a fixed number of iterations would provide deterministic behavior, just like C MuJoCo with the tolerance zeroed. A sketch of that trick via the Python bindings is below.
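For reference, the C-side trick carries over directly to the MuJoCo Python bindings. A minimal sketch (the XML path is a placeholder):

```python
import mujoco

model = mujoco.MjModel.from_xml_path("model.xml")  # placeholder path
model.opt.tolerance = 0.0  # disable the convergence check
model.opt.iterations = 4   # solver now always runs exactly 4 iterations
data = mujoco.MjData(model)
mujoco.mj_step(model, data)  # each step runs the full, fixed solver loop
```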
The feature, motivation and pitch
Problem
The solver's `jax.lax.while_loop` implementation prevents gradient computation through the environment step during gradient-based trajectory optimization. This occurs in the solver implementation when `iterations > 1`: a `jax.jit`-compiled grad function fails, because JAX does not support reverse-mode differentiation through `lax.while_loop` with a dynamic trip count. The current workaround of setting `opt.iterations = 1` leads to potentially inaccurate simulation and gradients.
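Not from the original report, but for illustration, a minimal standalone repro of the underlying JAX limitation (independent of MJX) looks like this:

```python
import jax
import jax.numpy as jnp
from jax import lax

def dynamic_solve(x):
    # The trip count depends on the running value, i.e. it is dynamic.
    return lax.while_loop(lambda v: v > 1e-3, lambda v: 0.5 * v, x)

# Fails: JAX cannot reverse-differentiate a while_loop with a
# dynamic trip count.
try:
    jax.grad(dynamic_solve)(jnp.float32(1.0))
except Exception as e:
    print(e)
```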
Proposed Solution
Add an option to set a fixed iteration count (e.g., 4) that would be compatible with reverse-mode differentiation, using either `lax.scan` or `lax.fori_loop` with static bounds.
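As a sanity check (again a sketch, not MJX code), `lax.fori_loop` with Python-integer bounds has a static trip count and differentiates cleanly in reverse mode:

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_ITERS = 4  # fixed iteration count, e.g. 4 as suggested above

def static_solve(x):
    # Same update rule as the dynamic version above, run exactly
    # NUM_ITERS times; the trip count is known at trace time.
    return lax.fori_loop(0, NUM_ITERS, lambda i, v: 0.5 * v, x)

print(jax.grad(static_solve)(jnp.float32(1.0)))  # 0.5**4 = 0.0625
```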
No response
Additional context
No response