
mx.vmap throws error "[tree_flatten] The argument should contain only arrays" #1824

Answered by awni
dasayan05 asked this question in Q&A

You can't `vmap` a function that takes non-`mx.array` inputs. In your case, that input is the callable `S`. Instead, bind it in advance with a closure or `functools.partial`. For example, with a closure:

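```python
import mlx.core as mx

def _ism(x, score):
    ...  # body elided; `score` is the non-array callable (S)

def vmappable_ism(x):
    # `score` is captured from the enclosing scope, so vmap only sees the array x
    return _ism(x, score)

mx.vmap(vmappable_ism, ...)  # remaining vmap arguments elided
```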

awni Feb 3, 2025
Maintainer

Answer selected by dasayan05