Fixing tensor.numpy on wrapped tensors #627
base: main
Conversation
Fixes pytorch#626

Description:
- Fixing tensor.numpy on wrapped tensors
level = _C.maybe_get_level(tensor)
if level == -1:
    return _old_numpy(tensor)

if _C.is_functionaltensor(tensor):
    # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
    # that it's up to date first
    torch._sync(tensor)

value = _C.get_unwrapped(tensor)
dl_enabled = _C.tls_set_is_included()
try:
    # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
    if dl_enabled:
        _C._set_dynamic_layer_keys_included(False)
    return value.numpy()
finally:
    # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
    if dl_enabled:
        _C._set_dynamic_layer_keys_included(True)
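For context, a minimal usage sketch of what this patch is meant to enable (my sketch, not part of the PR; it assumes the snippet above is installed as the monkey patch for Tensor.numpy, so that calling .numpy() on a grad-wrapped tensor unwraps it instead of failing):

import torch
from functorch import grad

def f(x):
    # Under grad, x is a wrapped tensor; with the patch, x.numpy() unwraps it
    # and returns a plain numpy array instead of raising.
    print(x.numpy())
    return (x * x).sum()

grad(f)(torch.randn(3))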
Hmm, so this is a little more complicated than that, I think.
When someone calls .numpy() under vmap, we probably want to error out. Otherwise some weird things might happen:
def f(x):
    return torch.tensor(x.numpy())

x = torch.randn(B)
vmap(f)(x)  # returns a Tensor of size B, B -- is that what we want?
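One way the error-out could look (my sketch, not the PR's code; it assumes an _C.is_batchedtensor helper is available alongside the other _C functions used in the patch) is an early guard before unwrapping:

# Hypothetical check near the top of the patched numpy():
# refuse the call when the tensor is a BatchedTensor, i.e. we are under vmap.
if _C.is_batchedtensor(tensor):
    raise RuntimeError(
        "tensor.numpy() is not supported under vmap: the returned array would "
        "expose the batch dimension"
    )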
When someone calls .numpy() under the grad transform then we should support this (as long as there are no vmaps involved). I'm not sure what the best way to support this is... one thing we can do is keep unwrapping the Tensor and check that no BatchedTensors are involved.
In the long term we want a better fix for this, perhaps one that involves making the PyTorch dispatcher recognize .numpy() as an operation.
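A rough sketch of that "keep unwrapping" check (again an assumption on my part, not the PR's code; it reuses the _C helpers from the patch and assumes _C.is_batchedtensor exists): walk down the wrapper stack and only allow .numpy() if no level is a BatchedTensor.

def _assert_no_batched_tensors(tensor):
    # Walk through every wrapper level; grad wrappers are fine, vmap wrappers are not.
    t = tensor
    while _C.maybe_get_level(t) != -1:
        if _C.is_batchedtensor(t):
            raise RuntimeError(
                "tensor.numpy() is only supported under grad-style transforms; "
                "found a BatchedTensor in the wrapper stack (vmap is involved)"
            )
        t = _C.get_unwrapped(t)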