Skip to content

Commit

Permalink
add converter for square (#954)
Browse files Browse the repository at this point in the history
Summary:

as titled

Reviewed By: jfix71, khabinov

Differential Revision: D50425718
  • Loading branch information
Cyrus Daruwala authored and facebook-github-bot committed Oct 26, 2023
1 parent dd89780 commit 910629a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
12 changes: 12 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ def acc_ops_mul(
return create_binary_op(FuncEnum.MUL, args, kwargs, name)


@ait_converter(acc_ops.square)
def acc_ops_square(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
new_kwargs = dict(kwargs.copy())
new_kwargs["other"] = new_kwargs["input"]
return create_binary_op(FuncEnum.MUL, args, new_kwargs, name)


@ait_converter(acc_ops.div)
def acc_ops_div(
target: Target,
Expand Down
29 changes: 29 additions & 0 deletions fx2ait/fx2ait/test/converters/test_ait_square.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from fx2ait.acc_tracer import acc_ops
from fx2ait.tools.common_fx2ait import AITTestCase


class TestSquareConverter(AITTestCase):
def test_square(self):
class TestModule(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.square(x)

inputs = [torch.randn(3, 10, 20).cuda().half()]
model = TestModule().cuda().half()

self.run_test(model, inputs, expected_ops={acc_ops.square})

0 comments on commit 910629a

Please sign in to comment.