Skip to content

Commit

Permalink
Type check.
Browse files Browse the repository at this point in the history
  • Loading branch information
ltfish committed Jan 23, 2025
1 parent 3d91bb4 commit 6c70467
Showing 1 changed file with 42 additions and 26 deletions.
68 changes: 42 additions & 26 deletions ailment/block_walker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint:disable=unused-argument,no-self-use
# pyright: reportIncompatibleMethodOverride=false
from typing import Any
from collections.abc import Callable

Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(self, stmt_handlers=None, expr_handlers=None):
self.stmt_handlers: dict[type, Callable] = stmt_handlers if stmt_handlers else _default_stmt_handlers
self.expr_handlers: dict[type, Callable] = expr_handlers if expr_handlers else _default_expr_handlers

def walk(self, block: Block):
def walk(self, block: Block) -> None:
i = 0
while i < len(block.statements):
stmt = block.statements[i]
Expand Down Expand Up @@ -110,7 +111,8 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non
self._handle_expr(1, stmt.src, stmt_idx, stmt, block)

def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if not isinstance(stmt.target, str):
self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if stmt.args:
for i, arg in enumerate(stmt.args):
self._handle_expr(i, arg, stmt_idx, stmt, block)
Expand All @@ -126,8 +128,10 @@ def _handle_Jump(self, stmt_idx: int, stmt: Jump, block: Block | None):

def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: Block | None):
self._handle_expr(0, stmt.condition, stmt_idx, stmt, block)
self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)
if stmt.true_target is not None:
self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
if stmt.false_target is not None:
self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)

def _handle_Return(self, stmt_idx: int, stmt: Return, block: Block | None):
if stmt.ret_exprs:
Expand All @@ -141,7 +145,8 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement
self._handle_expr(0, expr.addr, stmt_idx, stmt, block)

def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None):
self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if not isinstance(expr.target, str):
self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if expr.args:
for i, arg in enumerate(expr.args):
self._handle_expr(i, arg, stmt_idx, stmt, block)
Expand Down Expand Up @@ -182,7 +187,8 @@ def _handle_VirtualVariable(

def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, block: Block | None):
for idx, (_, vvar) in enumerate(expr.src_and_vvars):
self._handle_expr(idx, vvar, stmt_idx, stmt, block)
if vvar is not None:
self._handle_expr(idx, vvar, stmt_idx, stmt, block)

def _handle_MultiStatementExpression(
self, expr_idx, expr: MultiStatementExpression, stmt_idx: int, stmt: Statement, block: Block | None
Expand Down Expand Up @@ -285,7 +291,7 @@ def _handle_expr(
# Default handlers
#

def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | None):
def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | None) -> Assignment | None:
changed = False

dst = self._handle_expr(0, stmt.dst, stmt_idx, stmt, block)
Expand All @@ -311,9 +317,12 @@ def _handle_Assignment(self, stmt_idx: int, stmt: Assignment, block: Block | Non
def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
changed = False

new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not stmt.target:
changed = True
if isinstance(stmt.target, str):
new_target = None
else:
new_target = self._handle_expr(-1, stmt.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not stmt.target:
changed = True

new_args = None
if stmt.args is not None:
Expand All @@ -326,7 +335,7 @@ def _handle_Call(self, stmt_idx: int, stmt: Call, block: Block | None):
if new_arg is not None and new_arg is not arg:
if not changed:
# initialize new_args
new_args = stmt.args[:i]
new_args = list(stmt.args[:i])
new_args.append(new_arg)
changed = True
else:
Expand Down Expand Up @@ -418,17 +427,21 @@ def _handle_ConditionalJump(self, stmt_idx: int, stmt: ConditionalJump, block: B
else:
condition = stmt.condition

true_target = self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
if true_target is not None and true_target is not stmt.true_target:
changed = True
else:
true_target = stmt.true_target
true_target = None
if stmt.true_target is not None:
true_target = self._handle_expr(1, stmt.true_target, stmt_idx, stmt, block)
if true_target is not None and true_target is not stmt.true_target:
changed = True
else:
true_target = stmt.true_target

false_target = self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)
if false_target is not None and false_target is not stmt.false_target:
changed = True
else:
false_target = stmt.false_target
false_target = None
if stmt.false_target is not None:
false_target = self._handle_expr(2, stmt.false_target, stmt_idx, stmt, block)
if false_target is not None and false_target is not stmt.false_target:
changed = True
else:
false_target = stmt.false_target

if changed:
new_stmt = ConditionalJump(
Expand Down Expand Up @@ -498,9 +511,12 @@ def _handle_Load(self, expr_idx: int, expr: Load, stmt_idx: int, stmt: Statement
def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: Statement, block: Block | None):
changed = False

new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not expr.target:
changed = True
if isinstance(expr.target, str):
new_target = None
else:
new_target = self._handle_expr(-1, expr.target, stmt_idx, stmt, block)
if new_target is not None and new_target is not expr.target:
changed = True

new_args = None
if expr.args is not None:
Expand All @@ -512,7 +528,7 @@ def _handle_CallExpr(self, expr_idx: int, expr: Call, stmt_idx: int, stmt: State
if new_arg is not None and new_arg is not arg:
if not changed:
# initialize new_args
new_args = expr.args[:i]
new_args = list(expr.args[:i])
new_args.append(new_arg)
changed = True
else:
Expand Down Expand Up @@ -610,7 +626,7 @@ def _handle_Phi(self, expr_id: int, expr: Phi, stmt_idx: int, stmt: Statement, b

changed = False

src_and_vvars = None
src_and_vvars = []
for idx, (src, vvar) in enumerate(expr.src_and_vvars):
if vvar is None:
if src_and_vvars is not None:
Expand Down

0 comments on commit 6c70467

Please sign in to comment.