Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AILBlockWalker: Handle Reinterpret. #275

Merged
merged 2 commits into from
Jan 23, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 59 additions & 26 deletions ailment/block_walker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint:disable=unused-argument,no-self-use
# pyright: reportIncompatibleMethodOverride=false
from typing import Any
from collections.abc import Callable

Expand All @@ -16,6 +17,7 @@
Tmp,
Register,
Const,
Reinterpret,
MultiStatementExpression,
VirtualVariable,
Phi,
Expand Down Expand Up @@ -49,6 +51,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None):
VEXCCallExpression: self._handle_VEXCCallExpression,
Tmp: self._handle_Tmp,
Register: self._handle_Register,
Reinterpret: self._handle_Reinterpret,
Const: self._handle_Const,
MultiStatementExpression: self._handle_MultiStatementExpression,
VirtualVariable: self._handle_VirtualVariable,
Expand All @@ -58,7 +61,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None):
self.stmt_handlers: dict[type, Callable] = stmt_handlers if stmt_handlers else _default_stmt_handlers
self.expr_handlers: dict[type, Callable] = expr_handlers if expr_handlers else _default_expr_handlers

def walk(self, block: Block):
def walk(self, block: Block) -> None:
i = 0
while i < len(block.statements):
stmt = block.statements[i]
Expand Down Expand Up @@ -108,7 +111,8 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non
self._handle_expr(1, stmt.src, stmt_idx, stmt, block)

def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if not isinstance(stmt.target, str):
self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if stmt.args:
for i, arg in enumerate(stmt.args):
self._handle_expr(i, arg, stmt_idx, stmt, block)
Expand All @@ -124,8 +128,10 @@ def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None):

def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: Block | None):
self._handle_expr(0, stmt.condition, stmt_idx, stmt, block)
self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)
if stmt.true_target is not None:
self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
if stmt.false_target is not None:
self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)

def _handle_Return(self, stmt_idx: int, stmt: Return, block: Block | None):
if stmt.ret_exprs:
Expand All @@ -139,7 +145,8 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement
self._handle_expr(0, expr.addr, stmt_idx, stmt, block)

def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None):
self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if not isinstance(expr.target, str):
self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if expr.args:
for i, arg in enumerate(expr.args):
self._handle_expr(i, arg, stmt_idx, stmt, block)
Expand All @@ -154,6 +161,11 @@ def _handle_UnaryOp(self, expr_idx: int, expr: UnaryOp, stmt_idx: int, stmt: Sta
def _handle_Convert(self, expr_idx: int, expr: Convert, stmt_idx: int, stmt: Statement, block: Block | None):
self._handle_expr(expr_idx, expr.operand, stmt_idx, stmt, block)

def _handle_Reinterpret(
self, expr_idx: int, expr: Reinterpret, stmt_idx: int, stmt: Statement, block: Block | None
):
self._handle_expr(expr_idx, expr.operand, stmt_idx, stmt, block)

def _handle_ITE(self, expr_idx: int, expr: ITE, stmt_idx: int, stmt: Statement, block: Block | None):
self._handle_expr(0, expr.cond, stmt_idx, stmt, block)
self._handle_expr(1, expr.iftrue, stmt_idx, stmt, block)
Expand All @@ -175,7 +187,8 @@ def _handle_VirtualVariable(

def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, block: Block | None):
for idx, (_, vvar) in enumerate(expr.src_and_vvars):
self._handle_expr(idx, vvar, stmt_idx, stmt, block)
if vvar is not None:
self._handle_expr(idx, vvar, stmt_idx, stmt, block)

def _handle_MultiStatementExpression(
self, expr_idx, expr: MultiStatementExpression, stmt_idx: int, stmt: Statement, block: Block | None
Expand Down Expand Up @@ -278,7 +291,7 @@ def _handle_expr(
# Default handlers
#

def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | None):
def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | None) -> Assignment | None:
changed = False

dst = self._handle_expr(0, stmt.dst, stmt_idx, stmt, block)
Expand All @@ -304,9 +317,12 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non
def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
changed = False

new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not stmt.target:
changed = True
if isinstance(stmt.target, str):
new_target = None
else:
new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not stmt.target:
changed = True

new_args = None
if stmt.args is not None:
Expand All @@ -319,7 +335,7 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
if new_arg is not None and new_arg is not arg:
if not changed:
# initialize new_args
new_args = stmt.args[:i]
new_args = list(stmt.args[:i])
new_args.append(new_arg)
changed = True
else:
Expand Down Expand Up @@ -411,17 +427,21 @@ def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: B
else:
condition = stmt.condition

true_target = self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
if true_target is not None and true_target is not stmt.true_target:
changed = True
else:
true_target = stmt.true_target
true_target = None
if stmt.true_target is not None:
true_target = self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
if true_target is not None and true_target is not stmt.true_target:
changed = True
else:
true_target = stmt.true_target

false_target = self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)
if false_target is not None and false_target is not stmt.false_target:
changed = True
else:
false_target = stmt.false_target
false_target = None
if stmt.false_target is not None:
false_target = self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)
if false_target is not None and false_target is not stmt.false_target:
changed = True
else:
false_target = stmt.false_target

if changed:
new_stmt = ConditionalJump(
Expand Down Expand Up @@ -491,9 +511,12 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement
def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None):
changed = False

new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not expr.target:
changed = True
if isinstance(expr.target, str):
new_target = None
else:
new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not expr.target:
changed = True

new_args = None
if expr.args is not None:
Expand All @@ -505,7 +528,7 @@ def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: State
if new_arg is not None and new_arg is not arg:
if not changed:
# initialize new_args
new_args = expr.args[:i]
new_args = list(expr.args[:i])
new_args.append(new_arg)
changed = True
else:
Expand Down Expand Up @@ -557,6 +580,16 @@ def _handle_Convert(self, expr_idx: int, expr: Convert, stmt_idx: int, stmt: Sta
return Convert(expr.idx, expr.from_bits, expr.to_bits, expr.is_signed, new_operand, **expr.tags)
return None

def _handle_Reinterpret(
self, expr_idx: int, expr: Reinterpret, stmt_idx: int, stmt: Statement, block: Block | None
):
new_operand = self._handle_expr(expr_idx, expr.operand, stmt_idx, stmt, block)
if new_operand is not None and new_operand is not expr.operand:
return Reinterpret(
expr.idx, expr.from_bits, expr.from_type, expr.to_bits, expr.to_type, new_operand, **expr.tags
)
return None

def _handle_ITE(self, expr_idx: int, expr: ITE, stmt_idx: int, stmt: Statement, block: Block | None):
changed = False

Expand Down Expand Up @@ -593,7 +626,7 @@ def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, b

changed = False

src_and_vvars = None
src_and_vvars = []
for idx, (src, vvar) in enumerate(expr.src_and_vvars):
if vvar is None:
if src_and_vvars is not None:
Expand Down
Loading