import onnxscript
from onnxscript import opset18 as op

common_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib.common", version=1)
torchlib_opset = onnxscript.values.Opset(domain="pkg.onnxscript.torch_lib", version=1)


@onnxscript.script(common_opset)
def IsScalar(input):
    """Return whether the input has rank 0, or is a scalar."""
    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))


@onnxscript.script(torchlib_opset)
def aten_clamp_max(self, max_):
    """clamp_max(Tensor self, Tensor max) -> Tensor"""
    self_size = op.Size(self)
    max_shape = op.Shape(max_)
    if self_size == 0:
        result = op.Expand(self, max_shape)
    else:
        # IsScalar (custom domain) is only called inside this else-branch subgraph.
        if IsScalar(max_):
            max_ = op.CastLike(max_, self)
            result = op.Clip(self, None, max_)
        else:
            result = op.Min(self, max_)
    return result
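Here `IsScalar` is an OnnxFunction from a custom opset (`pkg.onnxscript.torch_lib.common`), but the generated `aten_clamp_max` function does not import that opset: its `opset_import` lists only `"" : 18`. `IsScalar` is called inside an `if` branch (a subgraph), so that may be the cause.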
Generated model:
E <
E ir_version: 8,
E opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1],
E producer_name: "pytorch",
E producer_version: "2.2.0"
E >
E main_graph (float16[5] input_0, float16 input_1) => (float16[5] _val_2) {
E _val_2 = pkg.onnxscript.torch_lib.aten_clamp_max (input_0, input_1)
E }
E <
E domain: "pkg.onnxscript.torch_lib",
E opset_import: ["" : 18]
E >
E aten_clamp_max (self, max_) => (result_5)
E {
E self_size = Size (self)
E max_shape = Shape (max_)
E int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()
E int64_0_cast = CastLike (int64_0, self_size)
E cond = Equal (self_size, int64_0_cast)
E result_5 = If (cond) <then_branch: graph = thenGraph_7 () => ( result) {
E result = Expand (self, max_shape)
E }, else_branch: graph = elseGraph_7 () => ( result_4) {
E cond_0 = pkg.onnxscript.torch_lib.common.IsScalar (max_)
E result_4 = If (cond_0) <then_branch: graph = thenGraph_10 () => ( result_2) {
E max__1 = CastLike (max_, self)
E result_2 = Clip (self, , max__1)
E }, else_branch: graph = elseGraph_10 () => ( result_3) {
E result_3 = Min (self, max_)
E }>
E }>
E }
E <
E domain: "pkg.onnxscript.torch_lib.common",
E opset_import: ["" : 18]
E >
E Rank (input) => (return_val)
E {
E tmp = Shape (input)
E return_val = Size (tmp)
E }
E <
E domain: "pkg.onnxscript.torch_lib.common",
E opset_import: ["" : 18]
E >
E IsScalar (input) => (return_val)
E {
E tmp = Shape (input)
E tmp_0 = Size (tmp)
E tmp_1 = Constant <value_int: int = 0> ()
E return_val = Equal (tmp_0, tmp_1)
E }
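In the dump above, `aten_clamp_max` declares only the default domain (`"" : 18`), while the `IsScalar` call appears only inside the `else_branch` of the outer `If`. The following is a minimal sketch of the suspected failure mode: if the set of used domains were collected from a function's top-level nodes only, the nested `IsScalar` call would never be seen. The `Node` and `Function` classes and the `used_domains` / `missing_imports` helpers are hypothetical, simplified stand-ins for onnx's `NodeProto` and `FunctionProto`; they are not onnxscript's actual implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, field


# Hypothetical, simplified stand-ins for onnx NodeProto / FunctionProto.
# Only the fields needed to track operator domains are modeled.
@dataclass
class Node:
    op_type: str
    domain: str = ""  # "" is the default ONNX domain
    subgraphs: list[list[Node]] = field(default_factory=list)  # e.g. If branches


@dataclass
class Function:
    name: str
    nodes: list[Node]
    opset_import: dict[str, int]


def used_domains(nodes: list[Node], recurse: bool = True) -> set[str]:
    """Collect every operator domain used by `nodes`, optionally descending into subgraphs."""
    domains: set[str] = set()
    for node in nodes:
        domains.add(node.domain)
        if recurse:
            for graph in node.subgraphs:
                domains |= used_domains(graph, recurse=True)
    return domains


def missing_imports(fn: Function, recurse: bool = True) -> set[str]:
    """Domains that `fn` uses but does not declare in its opset_import."""
    return used_domains(fn.nodes, recurse) - set(fn.opset_import)


# Mirror of the generated aten_clamp_max: IsScalar lives only in the outer else-branch.
inner_if = Node(
    "If",
    subgraphs=[
        [Node("CastLike"), Node("Clip")],  # then_branch
        [Node("Min")],  # else_branch
    ],
)
outer_if = Node(
    "If",
    subgraphs=[
        [Node("Expand")],  # then_branch
        [Node("IsScalar", domain="pkg.onnxscript.torch_lib.common"), inner_if],  # else_branch
    ],
)
aten_clamp_max = Function(
    "aten_clamp_max",
    nodes=[Node("Size"), Node("Shape"), Node("Constant"), Node("CastLike"), Node("Equal"), outer_if],
    opset_import={"": 18},
)

print(missing_imports(aten_clamp_max, recurse=False))  # set()
print(missing_imports(aten_clamp_max, recurse=True))   # {'pkg.onnxscript.torch_lib.common'}
```

A top-level-only scan reports nothing missing, which matches the generated `opset_import`, while the recursive scan finds the custom domain. Descending into every graph-valued attribute (the branches of `If`, the bodies of `Loop` and `Scan`) is one way a converter or a post-hoc check could catch this; whether onnxscript's converter actually scans only top-level nodes is an assumption here.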
Original issue onnx/onnx#5701
cc @gramalingam