Skip to content

Commit

Permalink
AILBlockWalker: Support off-band block updates. (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
ltfish authored Oct 9, 2024
1 parent bf3ddc1 commit f5fec81
Showing 1 changed file with 66 additions and 7 deletions.
73 changes: 66 additions & 7 deletions ailment/block_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable

from . import Block
from .statement import Call, Statement, ConditionalJump, Assignment, Store, Return
from .statement import Call, Statement, ConditionalJump, Assignment, Store, Return, Jump
from .expression import (
Load,
Expression,
Expand Down Expand Up @@ -33,6 +33,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None):
Call: self._handle_Call,
Store: self._handle_Store,
ConditionalJump: self._handle_ConditionalJump,
Jump: self._handle_Jump,
Return: self._handle_Return,
}

Expand Down Expand Up @@ -115,6 +116,9 @@ def _handle_Store(self, stmt_idx: int, stmt: Store, block: Block | None):
self._handle_expr(0, stmt.addr, stmt_idx, stmt, block)
self._handle_expr(1, stmt.data, stmt_idx, stmt, block)

def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None):
self._handle_expr(0, stmt.target, stmt_idx, stmt, block)

def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: Block | None):
self._handle_expr(0, stmt.condition, stmt_idx, stmt, block)
self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
Expand Down Expand Up @@ -191,10 +195,44 @@ class AILBlockWalker(AILBlockWalkerBase):
Walks all statements and expressions of an AIL node, and rebuilds expressions, statements, or blocks if needed.
If you need a pure walker without rebuilding, use AILBlockWalkerBase instead.
:ivar update_block: True if the block should be updated in place, False if a new block should be created and
returned as the result of walk().
"""

def __init__(self, stmt_handlers=None, expr_handlers=None):
def __init__(self, stmt_handlers=None, expr_handlers=None, update_block: bool = True):
super().__init__(stmt_handlers=stmt_handlers, expr_handlers=expr_handlers)
self._update_block = update_block

def walk(self, block: Block) -> Block | None:
"""
Walk the block and rebuild it if necessary. The block will be rebuilt in-place (by updating statements in the
original block when self._update_block is set to True), or a new block will be created and returned.
:param block: The block to walk.
:return: The new block that is rebuilt, or None if the block is not changed or when self._update_block
is set to True.
"""

changed = False
new_block: Block | None = None

i = 0
while i < len(block.statements):
stmt = block.statements[i]
new_stmt = self._handle_stmt(i, stmt, block)
if new_stmt is not None:
changed = True
if not self._update_block:
if new_block is None:
new_block = block.copy(statements=block.statements[:i])
new_block.statements.append(new_stmt)
else:
if new_block is not None:
new_block.statements.append(stmt)
i += 1

return new_block if changed else None

def _handle_stmt(self, stmt_idx: int, stmt: Statement, block: Block | None) -> Any:
try:
Expand Down Expand Up @@ -243,7 +281,7 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non
if changed:
# update the statement directly in the block
new_stmt = Assignment(stmt.idx, dst, src, **stmt.tags)
if block is not None:
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None
Expand Down Expand Up @@ -278,7 +316,7 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
ret_expr=stmt.ret_expr,
**stmt.tags,
)
if block is not None:
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None
Expand Down Expand Up @@ -311,7 +349,28 @@ def _handle_Store(self, stmt_idx: int, stmt: Store, block: Block | None):
offset=stmt.offset,
**stmt.tags,
)
if block is not None:
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None

def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None):
changed = False

target = self._handle_expr(0, stmt.target, stmt_idx, stmt, block)
if target is not None and target is not stmt.target:
changed = True
else:
target = stmt.target

if changed:
new_stmt = Jump(
stmt.idx,
target,
target_idx=stmt.target_idx,
**stmt.tags,
)
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None
Expand Down Expand Up @@ -347,7 +406,7 @@ def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: B
false_target_idx=stmt.false_target_idx,
**stmt.tags,
)
if block is not None:
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None
Expand All @@ -368,7 +427,7 @@ def _handle_Return(self, stmt_idx: int, stmt: Return, block: Block | None):

if changed:
new_stmt = Return(stmt.idx, new_ret_exprs, **stmt.tags)
if block is not None:
if self._update_block and block is not None:
block.statements[stmt_idx] = new_stmt
return new_stmt
return None
Expand Down

0 comments on commit f5fec81

Please sign in to comment.