From 013448a2e11d139d35cbfc0aa10ba4a3d09a6a65 Mon Sep 17 00:00:00 2001 From: Fish Date: Thu, 23 Jan 2025 13:54:51 -0600 Subject: [PATCH] AILBlockWalker: Handle Reinterpret. (#275) * AILBlockWalker: Handle Reinterpret. * Type check. --- ailment/block_walker.py | 85 ++++++++++++++++++++++++++++------------- 1 file changed, 59 insertions(+), 26 deletions(-) diff --git a/ailment/block_walker.py b/ailment/block_walker.py index 8c12e17..92af66a 100644 --- a/ailment/block_walker.py +++ b/ailment/block_walker.py @@ -1,4 +1,5 @@ # pylint:disable=unused-argument,no-self-use +# pyright: reportIncompatibleMethodOverride=false from typing import Any from collections.abc import Callable @@ -16,6 +17,7 @@ Tmp, Register, Const, + Reinterpret, MultiStatementExpression, VirtualVariable, Phi, @@ -49,6 +51,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None): VEXCCallExpression: self._handle_VEXCCallExpression, Tmp: self._handle_Tmp, Register: self._handle_Register, + Reinterpret: self._handle_Reinterpret, Const: self._handle_Const, MultiStatementExpression: self._handle_MultiStatementExpression, VirtualVariable: self._handle_VirtualVariable, @@ -58,7 +61,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None): self.stmt_handlers: dict[type, Callable] = stmt_handlers if stmt_handlers else _default_stmt_handlers self.expr_handlers: dict[type, Callable] = expr_handlers if expr_handlers else _default_expr_handlers - def walk(self, block: Block): + def walk(self, block: Block) -> None: i = 0 while i < len(block.statements): stmt = block.statements[i] @@ -108,7 +111,8 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non self._handle_expr(1, stmt.src, stmt_idx, stmt, block) def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None): - self._handle_expr(-1, stmt.target, stmt_idx, stmt, block) + if not isinstance(stmt.target, str): + self._handle_expr(-1, stmt.target, stmt_idx, stmt, block) if stmt.args: for i, arg in enumerate(stmt.args): self._handle_expr(i, arg, stmt_idx, stmt, block) @@ -124,8 +128,10 @@ def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None): def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: Block | None): self._handle_expr(0, stmt.condition, stmt_idx, stmt, block) - self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block) - self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block) + if stmt.true_target is not None: + self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block) + if stmt.false_target is not None: + self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block) def _handle_Return(self, stmt_idx: int, stmt: Return, block: Block | None): if stmt.ret_exprs: @@ -139,7 +145,8 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement self._handle_expr(0, expr.addr, stmt_idx, stmt, block) def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None): - self._handle_expr(-1, expr.target, stmt_idx, stmt, block) + if not isinstance(expr.target, str): + self._handle_expr(-1, expr.target, stmt_idx, stmt, block) if expr.args: for i, arg in enumerate(expr.args): self._handle_expr(i, arg, stmt_idx, stmt, block) @@ -154,6 +161,11 @@ def _handle_UnaryOp(self, expr_idx: int, expr: UnaryOp, stmt_idx: int, stmt: Sta def _handle_Convert(self, expr_idx: int, expr: Convert, stmt_idx: int, stmt: Statement, block: Block | None): self._handle_expr(expr_idx, expr.operand, stmt_idx, stmt, block) + def _handle_Reinterpret( + self, expr_idx: int, expr: Reinterpret, stmt_idx: int, stmt: Statement, block: Block | None + ): + self._handle_expr(expr_idx, expr.operand, stmt_idx, stmt, block) + def _handle_ITE(self, expr_idx: int, expr: ITE, stmt_idx: int, stmt: Statement, block: Block | None): self._handle_expr(0, expr.cond, stmt_idx, stmt, block) self._handle_expr(1, expr.iftrue, stmt_idx, stmt, block) @@ -175,7 +187,8 @@ def _handle_VirtualVariable( def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, block: Block | None): for idx, (_, vvar) in enumerate(expr.src_and_vvars): - self._handle_expr(idx, vvar, stmt_idx, stmt, block) + if vvar is not None: + self._handle_expr(idx, vvar, stmt_idx, stmt, block) def _handle_MultiStatementExpression( self, expr_idx, expr: MultiStatementExpression, stmt_idx: int, stmt: Statement, block: Block | None @@ -278,7 +291,7 @@ def _handle_expr( # Default handlers # - def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | None): + def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | None) -> Assignment | None: changed = False dst = self._handle_expr(0, stmt.dst, stmt_idx, stmt, block) @@ -304,9 +317,12 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None): changed = False - new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block) - if new_target is not None and new_target is not stmt.target: - changed = True + if isinstance(stmt.target, str): + new_target = None + else: + new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block) + if new_target is not None and new_target is not stmt.target: + changed = True new_args = None if stmt.args is not None: @@ -319,7 +335,7 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None): if new_arg is not None and new_arg is not arg: if not changed: # initialize new_args - new_args = stmt.args[:i] + new_args = list(stmt.args[:i]) new_args.append(new_arg) changed = True else: @@ -411,17 +427,21 @@ def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: B else: condition = stmt.condition - true_target = self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block) - if true_target is not None and true_target is not stmt.true_target: - changed = True - else: - true_target = stmt.true_target + true_target = None + if stmt.true_target is not None: + true_target = self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block) + if true_target is not None and true_target is not stmt.true_target: + changed = True + else: + true_target = stmt.true_target - false_target = self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block) - if false_target is not None and false_target is not stmt.false_target: - changed = True - else: - false_target = stmt.false_target + false_target = None + if stmt.false_target is not None: + false_target = self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block) + if false_target is not None and false_target is not stmt.false_target: + changed = True + else: + false_target = stmt.false_target if changed: new_stmt = ConditionalJump( @@ -491,9 +511,12 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None): changed = False - new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block) - if new_target is not None and new_target is not expr.target: - changed = True + if isinstance(expr.target, str): + new_target = None + else: + new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block) + if new_target is not None and new_target is not expr.target: + changed = True new_args = None if expr.args is not None: @@ -505,7 +528,7 @@ def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: State if new_arg is not None and new_arg is not arg: if not changed: # initialize new_args - new_args = expr.args[:i] + new_args = list(expr.args[:i]) new_args.append(new_arg) changed = True else: @@ -557,6 +580,16 @@ def _handle_Convert(self, expr_idx: int, expr: Convert, stmt_idx: int, stmt: Sta return Convert(expr.idx, expr.from_bits, expr.to_bits, expr.is_signed, new_operand, **expr.tags) return None + def _handle_Reinterpret( + self, expr_idx: int, expr: Reinterpret, stmt_idx: int, stmt: Statement, block: Block | None + ): + new_operand = self._handle_expr(expr_idx, expr.operand, stmt_idx, stmt, block) + if new_operand is not None and new_operand is not expr.operand: + return Reinterpret( + expr.idx, expr.from_bits, expr.from_type, expr.to_bits, expr.to_type, new_operand, **expr.tags + ) + return None + def _handle_ITE(self, expr_idx: int, expr: ITE, stmt_idx: int, stmt: Statement, block: Block | None): changed = False @@ -593,7 +626,7 @@ def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, b changed = False - src_and_vvars = None + src_and_vvars = [] for idx, (src, vvar) in enumerate(expr.src_and_vvars): if vvar is None: if src_and_vvars is not None: