Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add runtime check for Gather Op #3069

Merged
merged 5 commits into from
Feb 8, 2025
Merged

Conversation

chentong319
Copy link
Collaborator

@chentong319 chentong319 commented Feb 7, 2025

Some onnx op, such as GatherOp and GatherElementsOp, requires the input value to be within certain range. For efficiency, onnx-mlir did not check that with the assumption that everything is correct.
This PR adds runtime check for the purpose of Debugging. The implementation is not the optimized. When the bug in Runtime bound check in llvm-project is fixed, we can use that for this purpose.
I tested this PR with RunONNXModel.py with certain inputs. Not sure how to add these test cases into standard test easily.

Future work: any other runtime check is needed?

Copy link
Collaborator

@AlexandreEichenberger AlexandreEichenberger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, since you have to merge current changes anyway, please && the two conditions for a slightly faster test.

Thanks for the quick turnaround on this.

"indices of GatherOp is larger than the upper bound");
Value compareLowerBound =
create.math.sge(index.getValue(), zeroIE.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chentong319 I know it is not optimized for speed, "anding" both condition and calling assert only once would speed the check up a bit. create.math.andi(compareUpperBound, compareLowerBound).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate the check to provide more accurate error message.

LiteralIndexExpr zero(0);
Value compareLowerBound =
create.math.sge(index.getValue(), zero.getValue());
rewriter.create<cf::AssertOp>(loc, compareLowerBound,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, can use andi.

Signed-off-by: Chen Tong <[email protected]>
Signed-off-by: Chen Tong <[email protected]>
@chentong319 chentong319 merged commit 584ee43 into onnx:main Feb 8, 2025
7 checks passed
@chentong319 chentong319 deleted the gather-check branch February 8, 2025 02:41
@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #16269 [push] Add runtime check for Ga... started at 21:42

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #15296 [push] Add runtime check for Ga... started at 22:02

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #16267 [push] Add runtime check for Ga... started at 20:42

@jenkins-droid
Copy link
Collaborator

Jenkins Linux amd64 Build #16267 [push] Add runtime check for Ga... passed after 1 hr 24 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux s390x Build #16269 [push] Add runtime check for Ga... passed after 1 hr 27 min

@jenkins-droid
Copy link
Collaborator

Jenkins Linux ppc64le Build #15296 [push] Add runtime check for Ga... passed after 2 hr 27 min

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants