diff --git a/ailment/block_walker.py b/ailment/block_walker.py index 52490ba..e0a6203 100644 --- a/ailment/block_walker.py +++ b/ailment/block_walker.py @@ -3,7 +3,7 @@ from collections.abc import Callable from . import Block -from .statement import Call, Statement, ConditionalJump, Assignment, Store, Return +from .statement import Call, Statement, ConditionalJump, Assignment, Store, Return, Jump from .expression import ( Load, Expression, @@ -33,6 +33,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None): Call: self._handle_Call, Store: self._handle_Store, ConditionalJump: self._handle_ConditionalJump, + Jump: self._handle_Jump, Return: self._handle_Return, } @@ -115,6 +116,9 @@ def _handle_Store(self, stmt_idx: int, stmt: Store, block: Block | None): self._handle_expr(0, stmt.addr, stmt_idx, stmt, block) self._handle_expr(1, stmt.data, stmt_idx, stmt, block) + def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None): + self._handle_expr(0, stmt.target, stmt_idx, stmt, block) + def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: Block | None): self._handle_expr(0, stmt.condition, stmt_idx, stmt, block) self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block) @@ -191,10 +195,44 @@ class AILBlockWalker(AILBlockWalkerBase): Walks all statements and expressions of an AIL node, and rebuilds expressions, statements, or blocks if needed. If you need a pure walker without rebuilding, use AILBlockWalkerBase instead. + + :ivar update_block: True if the block should be updated in place, False if a new block should be created and + returned as the result of walk(). """ - def __init__(self, stmt_handlers=None, expr_handlers=None): + def __init__(self, stmt_handlers=None, expr_handlers=None, update_block: bool = True): super().__init__(stmt_handlers=stmt_handlers, expr_handlers=expr_handlers) + self._update_block = update_block + + def walk(self, block: Block) -> Block | None: + """ + Walk the block and rebuild it if necessary. The block will be rebuilt in-place (by updating statements in the + original block when self._update_block is set to True), or a new block will be created and returned. + + :param block: The block to walk. + :return: The new block that is rebuilt, or None if the block is not changed or when self._update_block + is set to True. + """ + + changed = False + new_block: Block | None = None + + i = 0 + while i < len(block.statements): + stmt = block.statements[i] + new_stmt = self._handle_stmt(i, stmt, block) + if new_stmt is not None: + changed = True + if not self._update_block: + if new_block is None: + new_block = block.copy(statements=block.statements[:i]) + new_block.statements.append(new_stmt) + else: + if new_block is not None: + new_block.statements.append(stmt) + i += 1 + + return new_block if changed else None def _handle_stmt(self, stmt_idx: int, stmt: Statement, block: Block | None) -> Any: try: @@ -243,7 +281,7 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non if changed: # update the statement directly in the block new_stmt = Assignment(stmt.idx, dst, src, **stmt.tags) - if block is not None: + if self._update_block and block is not None: block.statements[stmt_idx] = new_stmt return new_stmt return None @@ -278,7 +316,7 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None): ret_expr=stmt.ret_expr, **stmt.tags, ) - if block is not None: + if self._update_block and block is not None: block.statements[stmt_idx] = new_stmt return new_stmt return None @@ -311,7 +349,28 @@ def _handle_Store(self, stmt_idx: int, stmt: Store, block: Block | None): offset=stmt.offset, **stmt.tags, ) - if block is not None: + if self._update_block and block is not None: + block.statements[stmt_idx] = new_stmt + return new_stmt + return None + + def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None): + changed = False + + target = self._handle_expr(0, stmt.target, stmt_idx, stmt, block) + if target is not None and target is not stmt.target: + changed = True + else: + target = stmt.target + + if changed: + new_stmt = Jump( + stmt.idx, + target, + target_idx=stmt.target_idx, + **stmt.tags, + ) + if self._update_block and block is not None: block.statements[stmt_idx] = new_stmt return new_stmt return None @@ -347,7 +406,7 @@ def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: B false_target_idx=stmt.false_target_idx, **stmt.tags, ) - if block is not None: + if self._update_block and block is not None: block.statements[stmt_idx] = new_stmt return new_stmt return None @@ -368,7 +427,7 @@ def _handle_Return(self, stmt_idx: int, stmt: Return, block: Block | None): if changed: new_stmt = Return(stmt.idx, new_ret_exprs, **stmt.tags) - if block is not None: + if self._update_block and block is not None: block.statements[stmt_idx] = new_stmt return new_stmt return None