Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AILBlockWalker: Call and CallExpr handle .target. #264

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions ailment/block_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,14 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non
return None

def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
if stmt.args:
changed = False
changed = False

new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not stmt.target:
changed = True

new_args = None
if stmt.args is not None:
new_args = []

i = 0
Expand All @@ -321,19 +327,19 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
new_args.append(arg)
i += 1

if changed:
new_stmt = Call(
stmt.idx,
stmt.target,
calling_convention=stmt.calling_convention,
prototype=stmt.prototype,
args=new_args,
ret_expr=stmt.ret_expr,
**stmt.tags,
)
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
if changed:
new_stmt = Call(
stmt.idx,
new_target if new_target is not None else stmt.target,
calling_convention=stmt.calling_convention,
prototype=stmt.prototype,
args=new_args,
ret_expr=stmt.ret_expr,
**stmt.tags,
)
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

def _handle_Store(self, stmt_idx: int, stmt: Store, block: Block | None):
Expand Down Expand Up @@ -485,7 +491,12 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement
def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None):
changed = False

if expr.args:
new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not expr.target:
changed = True

new_args = None
if expr.args is not None:
i = 0
new_args = []
while i < len(expr.args):
Expand All @@ -502,11 +513,12 @@ def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: State
new_args.append(arg)
i += 1

if changed:
expr = expr.copy()
expr.args = new_args
return expr

if changed:
expr = expr.copy()
if new_target is not None:
expr.target = new_target
expr.args = new_args
return expr
return None

def _handle_BinaryOp(self, expr_idx: int, expr: BinaryOp, stmt_idx: int, stmt: Statement, block: Block | None):
Expand Down
Loading