Hi all! I'm a PhD student working on ML/PL (automatic differentiation and probabilistic programming). I've read the other discussion on AD in Mojo, so I'm hoping to probe a few further details here.
I've used JAX to design and implement program transformations, written as interpreters that implement the semantics I'm interested in and are then staged out at trace time. Some of these transformations are quite advanced: they mix continuation-passing style (CPS) transformations with custom AD logic and probabilistic programming interpreters. Crucially, JAX makes it easy both to write these transformations so they run at trace time and to compose them when required.
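To make the staging model concrete, here's a minimal sketch using the standard JAX API (the toy `model` function is mine, just for illustration):

```python
import jax
import jax.numpy as jnp

def model(x):
    return jnp.sin(x) * x

# make_jaxpr runs `model` with abstract tracer values and records a
# typed IR (a jaxpr) -- the interpreter is "staged out" at trace time,
# leaving a data structure that further transformations can consume.
print(jax.make_jaxpr(model)(1.0))
```

Every transformation (grad, vmap, jit, and user-written ones) operates on this traced representation rather than on Python source.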
I'd love to understand:
Is anyone thinking about program transformations as a first-class construct in Mojo (like JAX, lightweight modular staging, or staged meta-circular evaluation)? A pithy way to say it: composable transformations on typed IR. JAX is one model of this.
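By "composable transformations" I mean compositions like the following (standard JAX, nothing hypothetical):

```python
import jax
import jax.numpy as jnp

# Transformations compose as ordinary higher-order functions:
# differentiate, then vectorize the derivative, then JIT-compile.
g = jax.jit(jax.vmap(jax.grad(jnp.tanh)))
print(g(jnp.array([0.0, 1.0])))  # [1.0, 0.41997...]
```

Each layer consumes and produces the same typed IR, which is what makes arbitrary reordering and nesting of transformations possible.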
Do any stdlib developers / experts think a "JAX-like" thing is possible? I've come to really appreciate JAX's model of composing transformations.
One issue that immediately arises: Python's typing is permissive enough to support JAX's symbolic tracer values (the entire edifice of JAX rests on them). Mojo has a static type system, and when static checking is introduced (as in Julia), "the easy JAX route" becomes harder: if a function's arguments are typed and those types are checked, tracing with symbolic values can simply fail.
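A tiny example of the permissiveness I mean: Python annotations are advisory, so JAX's tracers flow straight through an "annotated" function, whereas a statically checked language would reject the substitution:

```python
import jax

def f(x: float) -> float:
    # The annotation says float, but under jax.grad `x` is actually a
    # symbolic Tracer -- Python never enforces the annotation, so
    # tracing (and hence differentiation) works anyway.
    return x * x + 1.0

print(jax.grad(f)(3.0))  # 6.0
```

In a language that type-checks `f` against `float`, the tracer substitution would need explicit support in the type system (e.g. an abstract/staged value type).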
There are certain classes of models / inference processes (think: non-parametrics, grammars, etc.) that I'd like to write in JAX, but of course it's hard, because (a) JAX is designed for array programming first and foremost, and (b) I have to conform to the static shape guarantees that JAX imposes on arrays (to ensure XLA can be aggressive).
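To illustrate constraint (b) concretely, here's a sketch of a data-dependent-shape operation that JAX rejects under `jit` (the helper name is mine):

```python
import jax
import jax.numpy as jnp

@jax.jit
def keep_positive(x):
    # Boolean masking produces an output whose shape depends on the
    # *values* in x -- XLA's static-shape model can't compile this.
    return x[x > 0]

try:
    keep_positive(jnp.array([1.0, -2.0, 3.0]))
except Exception as e:
    print(type(e).__name__)
```

Models over grammars or non-parametric structures hit this wall constantly, which is why "JAX with dynamic shapes" is such an interesting design point.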
Can anyone comment on the AD design implications of supporting dynamic shapes?
In JAX, the compute graph is also static (like Mojo's, as far as I can tell -- experts, please weigh in).
JAX also has a great model for AD (https://arxiv.org/abs/2204.10923): writing forward-mode rules is easy, and pullbacks are then derived automatically via a composition of transformations.
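Concretely, here's that pattern as surfaced in JAX's API (this `log1pexp` example follows the JAX `custom_jvp` documentation): you define only the forward-mode (JVP) rule, and reverse mode falls out of JAX's partial-evaluation + transposition pipeline.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    # Only the forward-mode rule is written by hand; no VJP needed.
    x, = primals
    t, = tangents
    return log1pexp(x), t * jax.nn.sigmoid(x)

# Reverse mode is derived automatically from the JVP rule.
print(jax.grad(log1pexp)(0.0))  # sigmoid(0) = 0.5
```

This "JVPs first, transpose for VJPs" factoring is the part of JAX's AD design I'd most love to see carried over.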
I'm in no place to make suggestions, but if Mojo's AD were "JAX, but with dynamic shapes," that would be a pretty incredible value proposition with a well-investigated engineering path. Even more so if JAX's program transformation model migrated over as well...
Sorry for the dump! Excited to see where the project goes. Thanks for any insight.