From 82501daeb0bc9c18c08a748978fc85fe70c28f70 Mon Sep 17 00:00:00 2001 From: Audrey Dutcher Date: Thu, 12 Dec 2024 01:06:25 -0700 Subject: [PATCH] Typecheck more things (#265) --------- Co-authored-by: Fish --- ailment/__init__.py | 18 +- ailment/block.py | 2 +- ailment/block_walker.py | 4 +- ailment/converter_vex.py | 77 +++++-- ailment/expression.py | 424 +++++++++++++++------------------------ ailment/py.typed | 0 ailment/statement.py | 85 +++++--- ailment/tagged_object.py | 4 +- ailment/utils.py | 33 ++- setup.cfg | 6 + 10 files changed, 324 insertions(+), 329 deletions(-) create mode 100644 ailment/py.typed diff --git a/ailment/__init__.py b/ailment/__init__.py index 73fdf4e..ee7edae 100644 --- a/ailment/__init__.py +++ b/ailment/__init__.py @@ -3,9 +3,9 @@ import logging from .block import Block -from . import statement as Stmt -from . import expression as Expr -from .statement import Assignment +from . import statement +from . import expression +from .statement import Assignment, Statement from .expression import Expression, Const, Tmp, Register, UnaryOp, BinaryOp from .converter_common import Converter from .manager import Manager @@ -13,6 +13,9 @@ log = logging.getLogger(__name__) +# REALLY BAD +Expr = expression +Stmt = statement available_converters: set[str] = set() @@ -24,6 +27,7 @@ except ImportError as e: log.debug("Could not import VEXIRSBConverter") log.debug(e) + VEXIRSBConverter = None try: from .converter_pcode import PCodeIRSBConverter @@ -33,6 +37,7 @@ except ImportError as e: log.debug("Could not import PCodeIRSBConverter") log.debug(e) + PCodeIRSBConverter = None class IRSBConverter(Converter): @@ -57,8 +62,11 @@ def convert(irsb, manager): # pylint:disable=arguments-differ __all__ = [ "available_converters", "Block", + "expression", + "statement", "Stmt", "Expr", + "Statement", "Assignment", "Expression", "Const", @@ -70,6 +78,6 @@ def convert(irsb, manager): # pylint:disable=arguments-differ "IRSBConverter", "AILBlockWalkerBase", "AILBlockWalker", - *(["PCodeIRSBConverter"] if "pcode" in available_converters else []), - *(["VEXIRSBConverter"] if "vex" in available_converters else []), + "PCodeIRSBConverter", + "VEXIRSBConverter", ] diff --git a/ailment/block.py b/ailment/block.py index 892ad56..a5fc57e 100644 --- a/ailment/block.py +++ b/ailment/block.py @@ -16,7 +16,7 @@ class Block: "idx", ) - def __init__(self, addr, original_size, statements=None, idx=None): + def __init__(self, addr: int, original_size, statements=None, idx=None): self.addr = addr self.original_size = original_size self.statements: list["Statement"] = [] if statements is None else statements diff --git a/ailment/block_walker.py b/ailment/block_walker.py index cdc3651..8c12e17 100644 --- a/ailment/block_walker.py +++ b/ailment/block_walker.py @@ -72,10 +72,10 @@ def walk_expression( self, expr: Expression, stmt_idx: int | None = None, - stmt: int | None = None, + stmt: Statement | None = None, block: Block | None = None, ): - return self._handle_expr(0, expr, stmt_idx, stmt, block) + return self._handle_expr(0, expr, stmt_idx or 0, stmt, block) def _handle_stmt(self, stmt_idx: int, stmt: Statement, block: Block | None) -> Any: try: diff --git a/ailment/converter_vex.py b/ailment/converter_vex.py index d20573b..dd473be 100644 --- a/ailment/converter_vex.py +++ b/ailment/converter_vex.py @@ -20,7 +20,6 @@ ITE, Reinterpret, VEXCCallExpression, - TernaryOp, ) from .converter_common import SkipConversionNotice, Converter @@ -288,10 +287,68 @@ def Binop(expr, manager): bits = op._output_size_bits - extra_kwargs = {} if op_name == "DivMod": - extra_kwargs["from_bits"] = op._from_size if op._from_size is not None else operands[1].bits - extra_kwargs["to_bits"] = op._to_size if op._to_size is not None else operands[1].bits + op1_size = op._from_size if op._from_size is not None else operands[0].bits + op2_size = op._to_size if op._to_size is not None else operands[1].bits + + if op2_size < op1_size: + # e.g., DivModU64to32 + operands[1] = Convert( + manager.next_atom(), + op2_size, + op1_size, + op._from_signed != "U", + operands[1], + ins_addr=manager.ins_addr, + vex_block_addr=manager.block_addr, + vex_stmt_idx=manager.vex_stmt_idx, + ) + chunk_bits = bits // 2 + + div = BinaryOp( + manager.next_atom(), + "Div", + operands, + signed, + ins_addr=manager.ins_addr, + vex_block_addr=manager.block_addr, + vex_stmt_idx=manager.vex_stmt_idx, + bits=op1_size, + ) + truncated_div = Convert( + manager.next_atom(), + op1_size, + chunk_bits, + signed, + div, + ins_addr=manager.ins_addr, + vex_block_addr=manager.block_addr, + vex_stmt_idx=manager.vex_stmt_idx, + ) + mod = BinaryOp( + manager.next_atom(), + "Mod", + operands, + signed, + ins_addr=manager.ins_addr, + vex_block_addr=manager.block_addr, + vex_stmt_idx=manager.vex_stmt_idx, + bits=op1_size, + ) + truncated_mod = Convert( + manager.next_atom(), + op1_size, + chunk_bits, + signed, + mod, + ins_addr=manager.ins_addr, + vex_block_addr=manager.block_addr, + vex_stmt_idx=manager.vex_stmt_idx, + ) + + operands = [truncated_mod, truncated_div] + op_name = "Concat" + signed = False return BinaryOp( manager.next_atom(), @@ -304,7 +361,6 @@ def Binop(expr, manager): bits=bits, vector_count=vector_count, vector_size=vector_size, - **extra_kwargs, ) @staticmethod @@ -332,14 +388,9 @@ def Triop(expr, manager): bits=bits, ) - return TernaryOp( - manager.next_atom(), - op_name, - operands, - ins_addr=manager.ins_addr, - vex_block_addr=manager.block_addr, - vex_stmt_idx=manager.vex_stmt_idx, - bits=bits, + raise TypeError( + "Please figure out what kind of operation this is (smart money says fused multiply) and convert it into " + "multiple binops" ) @staticmethod diff --git a/ailment/expression.py b/ailment/expression.py index 4feefa1..e49af80 100644 --- a/ailment/expression.py +++ b/ailment/expression.py @@ -1,8 +1,11 @@ # pylint:disable=arguments-renamed,isinstance-second-argument-not-valid-type,missing-class-docstring from __future__ import annotations -from enum import IntEnum +from typing import TYPE_CHECKING, cast +from collections.abc import Sequence +from enum import Enum, IntEnum +from abc import abstractmethod +from typing_extensions import Self -from typing import TYPE_CHECKING try: import claripy @@ -21,12 +24,18 @@ class Expression(TaggedObject): The base class of all AIL expressions. """ - __slots__ = ("depth",) + bits: int + + __slots__ = ( + "bits", + "depth", + ) def __init__(self, idx, depth, **kwargs): super().__init__(idx, **kwargs) self.depth = depth + @abstractmethod def __repr__(self): raise NotImplementedError() @@ -41,16 +50,18 @@ def __eq__(self, other): return True return type(self) is type(other) and self.likes(other) and self.idx == other.idx - def likes(self, atom): # pylint:disable=unused-argument,no-self-use + @abstractmethod + def likes(self, other): # pylint:disable=unused-argument,no-self-use raise NotImplementedError() - def matches(self, atom): # pylint:disable=unused-argument,no-self-use - return NotImplementedError() + @abstractmethod + def matches(self, other): # pylint:disable=unused-argument,no-self-use + raise NotImplementedError() - def replace(self, old_expr, new_expr): + def replace(self, old_expr: Expression, new_expr: Expression) -> tuple[bool, Self]: if self is old_expr: r = True - replaced = new_expr + replaced = cast(Self, new_expr) elif not isinstance(self, Atom): r, replaced = self.replace(old_expr, new_expr) else: @@ -59,10 +70,10 @@ def replace(self, old_expr, new_expr): return r, replaced def __add__(self, other): - return BinaryOp(None, "Add", [self, other], False, **self.tags) + return BinaryOp(None, "Add", [self, other], signed=False, **self.tags) def __sub__(self, other): - return BinaryOp(None, "Sub", [self, other], False, **self.tags) + return BinaryOp(None, "Sub", [self, other], signed=False, **self.tags) class Atom(Expression): @@ -71,25 +82,22 @@ class Atom(Expression): "variable_offset", ) - def __init__(self, idx, variable=None, variable_offset=0, **kwargs): + def __init__(self, idx: int | None, variable=None, variable_offset=0, **kwargs): super().__init__(idx, 0, **kwargs) self.variable = variable self.variable_offset = variable_offset - def __repr__(self): + def __repr__(self) -> str: return "Atom (%d)" % self.idx - def copy(self): # pylint:disable=no-self-use - return NotImplementedError() + def copy(self) -> Self: # pylint:disable=no-self-use + raise NotImplementedError() class Const(Atom): - __slots__ = ( - "value", - "bits", - ) + __slots__ = ("value",) - def __init__(self, idx, variable, value, bits, **kwargs): + def __init__(self, idx: int | None, variable, value: int | float, bits: int, **kwargs): super().__init__(idx, variable, **kwargs) self.value = value @@ -133,12 +141,9 @@ def copy(self) -> Const: class Tmp(Atom): - __slots__ = ( - "tmp_idx", - "bits", - ) + __slots__ = ("tmp_idx",) - def __init__(self, idx, variable, tmp_idx, bits, **kwargs): + def __init__(self, idx: int | None, variable, tmp_idx: int, bits, **kwargs): super().__init__(idx, variable, **kwargs) self.tmp_idx = tmp_idx @@ -168,12 +173,9 @@ def copy(self) -> Tmp: class Register(Atom): - __slots__ = ( - "reg_offset", - "bits", - ) + __slots__ = ("reg_offset",) - def __init__(self, idx, variable, reg_offset, bits, **kwargs): + def __init__(self, idx: int | None, variable, reg_offset: int, bits: int, **kwargs): super().__init__(idx, variable, **kwargs) self.reg_offset = reg_offset @@ -183,10 +185,8 @@ def __init__(self, idx, variable, reg_offset, bits, **kwargs): def size(self): return self.bits // 8 - def likes(self, atom): - return type(self) is type(atom) and self.reg_offset == atom.reg_offset and self.bits == atom.bits - - matches = likes + def likes(self, other): + return type(self) is type(other) and self.reg_offset == other.reg_offset and self.bits == other.bits def __repr__(self): return str(self) @@ -199,6 +199,7 @@ def __str__(self): else: return "%s" % str(self.variable.name) + matches = likes __hash__ = TaggedObject.__hash__ def _hash_core(self): @@ -220,7 +221,6 @@ class VirtualVariableCategory(IntEnum): class VirtualVariable(Atom): __slots__ = ( - "bits", "varid", "category", "oident", @@ -263,36 +263,36 @@ def was_tmp(self) -> bool: return self.category == VirtualVariableCategory.TMP @property - def reg_offset(self) -> int | None: + def reg_offset(self) -> int: if self.was_reg: return self.oident - return None + raise TypeError("Is not a register") @property - def stack_offset(self) -> int | None: + def stack_offset(self) -> int: if self.was_stack: return self.oident - return None + raise TypeError("Is not a stack variable") @property def tmp_idx(self) -> int | None: return self.oident if self.was_tmp else None - def likes(self, atom): + def likes(self, other): return ( - isinstance(atom, VirtualVariable) - and self.varid == atom.varid - and self.bits == atom.bits - and self.category == atom.category - and self.oident == atom.oident + isinstance(other, VirtualVariable) + and self.varid == other.varid + and self.bits == other.bits + and self.category == other.category + and self.oident == other.oident ) - def matches(self, atom): + def matches(self, other): return ( - isinstance(atom, VirtualVariable) - and self.bits == atom.bits - and self.category == atom.category - and self.oident == atom.oident + isinstance(other, VirtualVariable) + and self.bits == other.bits + and self.category == other.category + and self.oident == other.oident ) def __repr__(self): @@ -324,10 +324,7 @@ def copy(self) -> VirtualVariable: class Phi(Atom): - __slots__ = ( - "bits", - "src_and_vvars", - ) + __slots__ = ("src_and_vvars",) def __init__( self, @@ -352,21 +349,21 @@ def op(self) -> str: def verbose_op(self) -> str: return "Phi" - def likes(self, atom) -> bool: - if isinstance(atom, Phi) and self.bits == atom.bits: + def likes(self, other) -> bool: + if isinstance(other, Phi) and self.bits == other.bits: self_src_and_vvarids = {(src, vvar.varid if vvar is not None else None) for src, vvar in self.src_and_vvars} other_src_and_vvarids = { - (src, vvar.varid if vvar is not None else None) for src, vvar in atom.src_and_vvars + (src, vvar.varid if vvar is not None else None) for src, vvar in other.src_and_vvars } return self_src_and_vvarids == other_src_and_vvarids return False - def matches(self, atom) -> bool: - if isinstance(atom, Phi) and self.bits == atom.bits: - if len(self.src_and_vvars) != len(atom.src_and_vvars): + def matches(self, other) -> bool: + if isinstance(other, Phi) and self.bits == other.bits: + if len(self.src_and_vvars) != len(other.src_and_vvars): return False self_src_and_vvars = dict(self.src_and_vvars) - other_src_and_vvars = dict(atom.src_and_vvars) + other_src_and_vvars = dict(other.src_and_vvars) for src, self_vvar in self_src_and_vvars.items(): if src not in other_src_and_vvars: return False @@ -449,12 +446,20 @@ def verbose_op(self): class UnaryOp(Op): __slots__ = ( "operand", - "bits", "variable", "variable_offset", ) - def __init__(self, idx, op, operand, variable=None, variable_offset=None, bits: int | None = None, **kwargs): + def __init__( + self, + idx: int | None, + op: str, + operand: Expression, + variable=None, + variable_offset: int | None = None, + bits=None, + **kwargs, + ): super().__init__(idx, (operand.depth if isinstance(operand, Expression) else 0) + 1, op, **kwargs) self.operand = operand @@ -476,12 +481,12 @@ def likes(self, other): and self.operand.likes(other.operand) ) - def matches(self, atom): + def matches(self, other): return ( - type(atom) is UnaryOp - and self.op == atom.op - and self.bits == atom.bits - and self.operand.matches(atom.operand) + type(other) is UnaryOp + and self.op == other.op + and self.bits == other.bits + and self.operand.matches(other.operand) ) __hash__ = TaggedObject.__hash__ @@ -526,10 +531,15 @@ def has_atom(self, atom, identity=True): return self.operand.has_atom(atom, identity=identity) -class Convert(UnaryOp): +class ConvertType(Enum): TYPE_INT = 0 TYPE_FP = 1 + +class Convert(UnaryOp): + TYPE_INT = ConvertType.TYPE_INT + TYPE_FP = ConvertType.TYPE_FP + __slots__ = ( "from_bits", "to_bits", @@ -541,13 +551,13 @@ class Convert(UnaryOp): def __init__( self, - idx, - from_bits, - to_bits, - is_signed, - operand, - from_type=TYPE_INT, - to_type=TYPE_INT, + idx: int | None, + from_bits: int, + to_bits: int, + is_signed: bool, + operand: Expression, + from_type: ConvertType = TYPE_INT, + to_type: ConvertType = TYPE_INT, rounding_mode=None, **kwargs, ): @@ -738,14 +748,11 @@ def copy(self) -> Reinterpret: class BinaryOp(Op): __slots__ = ( "operands", - "bits", - "signed", "variable", "variable_offset", "floating_point", "rounding_mode", - "from_bits", # for divmod - "to_bits", # for divmod + "signed", "vector_count", "vector_size", ) @@ -761,7 +768,6 @@ class BinaryOp(Op): "MulV": "*", "Div": "/", "DivF": "/", - "DivMod": "/m", "Mod": "%", "Xor": "^", "And": "&", @@ -778,10 +784,10 @@ class BinaryOp(Op): "CmpLE": "<=", "CmpGT": ">", "CmpGE": ">=", - "CmpLTs": "s", - "CmpGEs": ">=s", + "CmpLT (signed)": "s", + "CmpGE (signed)": ">=s", "Concat": "CONCAT", "Ror": "ROR", "Rol": "ROL", @@ -797,25 +803,20 @@ class BinaryOp(Op): "CmpGE": "CmpLT", "CmpLE": "CmpGT", "CmpGT": "CmpLE", - "CmpLTs": "CmpGEs", - "CmpGEs": "CmpLTs", - "CmpLEs": "CmpGTs", - "CmpGTs": "CmpLEs", } def __init__( self, - idx, - op, - operands, - signed, + idx: int | None, + op: str, + operands: Sequence[Expression], + signed: bool = False, + *, variable=None, variable_offset=None, bits=None, - floating_point: bool = False, - rounding_mode: str | None = None, - from_bits: int | None = None, - to_bits: int | None = None, + floating_point=False, + rounding_mode=None, vector_count: int | None = None, vector_size: int | None = None, **kwargs, @@ -828,11 +829,6 @@ def __init__( + 1 ) - # special handling of initialization with signed op names - if op and op.endswith("s"): - op = op[:-1] - signed = True - super().__init__(idx, depth, op, **kwargs) assert len(operands) == 2 @@ -868,9 +864,6 @@ def __init__( self.vector_count = vector_count self.vector_size = vector_size - self.from_bits = from_bits - self.to_bits = to_bits - # TODO: sanity check of operands' sizes for some ops # assert self.bits == operands[1].bits @@ -932,14 +925,14 @@ def has_atom(self, atom, identity=True): return False - def replace(self, old_expr, new_expr): + def replace(self, old_expr: Expression, new_expr: Expression) -> tuple[bool, BinaryOp]: if self.operands[0] == old_expr: r0 = True replaced_operand_0 = new_expr elif isinstance(self.operands[0], Expression): r0, replaced_operand_0 = self.operands[0].replace(old_expr, new_expr) else: - r0, replaced_operand_0 = False, None + r0, replaced_operand_0 = False, new_expr if self.operands[1] == old_expr: r1 = True @@ -947,7 +940,7 @@ def replace(self, old_expr, new_expr): elif isinstance(self.operands[1], Expression): r1, replaced_operand_1 = self.operands[1].replace(old_expr, new_expr) else: - r1, replaced_operand_1 = False, None + r1, replaced_operand_1 = False, new_expr r2, replaced_rm = False, None if self.rounding_mode is not None: @@ -960,12 +953,10 @@ def replace(self, old_expr, new_expr): self.idx, self.op, [replaced_operand_0 if r0 else self.operands[0], replaced_operand_1 if r1 else self.operands[1]], - self.signed, + signed=self.signed, bits=self.bits, floating_point=self.floating_point, rounding_mode=replaced_rm if r2 else self.rounding_mode, - from_bits=self.from_bits, - to_bits=self.to_bits, **self.tags, ) else: @@ -975,10 +966,10 @@ def replace(self, old_expr, new_expr): def verbose_op(self): op = self.op if self.floating_point: - op += "F" + op += " (float)" else: if self.signed: - op += "s" + op += " (signed)" return op @property @@ -990,129 +981,16 @@ def copy(self) -> BinaryOp: self.idx, self.op, self.operands[::], - self.signed, variable=self.variable, + signed=self.signed, variable_offset=self.variable_offset, bits=self.bits, floating_point=self.floating_point, rounding_mode=self.rounding_mode, - from_bits=self.from_bits, - to_bits=self.to_bits, **self.tags, ) -class TernaryOp(Op): - OPSTR_MAP = {} - - __slots__ = ( - "operands", - "bits", - ) - - def __init__(self, idx, op, operands, bits=None, **kwargs): - depth = ( - max( - operands[0].depth if isinstance(operands[0], Expression) else 0, - operands[1].depth if isinstance(operands[1], Expression) else 0, - operands[2].depth if isinstance(operands[1], Expression) else 0, - ) - + 1 - ) - super().__init__(idx, depth, op, **kwargs) - - assert len(operands) == 3 - self.operands = operands - self.bits = bits - - def __str__(self): - return f"{self.verbose_op}({self.operands[0]}, {self.operands[1]}, {self.operands[2]})" - - def __repr__(self): - return f"{self.verbose_op}({self.operands[0]}, {self.operands[1]}, {self.operands[2]})" - - def likes(self, other): - return ( - type(other) is TernaryOp - and self.op == other.op - and self.bits == other.bits - and is_none_or_likeable(self.operands, other.operands, is_list=True) - ) - - def matches(self, other): - return ( - type(other) is TernaryOp - and self.op == other.op - and self.bits == other.bits - and is_none_or_matchable(self.operands, other.operands, is_list=True) - ) - - __hash__ = TaggedObject.__hash__ - - def _hash_core(self): - return stable_hash((self.op, tuple(self.operands), self.bits)) - - def has_atom(self, atom, identity=True): - if super().has_atom(atom, identity=identity): - return True - - for op in self.operands: - if identity and op == atom: - return True - if not identity and isinstance(op, Atom) and op.likes(atom): - return True - if isinstance(op, Atom) and op.has_atom(atom, identity=identity): - return True - return False - - def replace(self, old_expr, new_expr): - if self.operands[0] == old_expr: - r0 = True - replaced_operand_0 = new_expr - elif isinstance(self.operands[0], Expression): - r0, replaced_operand_0 = self.operands[0].replace(old_expr, new_expr) - else: - r0, replaced_operand_0 = False, None - - if self.operands[1] == old_expr: - r1 = True - replaced_operand_1 = new_expr - elif isinstance(self.operands[1], Expression): - r1, replaced_operand_1 = self.operands[1].replace(old_expr, new_expr) - else: - r1, replaced_operand_1 = False, None - - if self.operands[2] == old_expr: - r2 = True - replaced_operand_2 = new_expr - elif isinstance(self.operands[2], Expression): - r2, replaced_operand_2 = self.operands[2].replace(old_expr, new_expr) - else: - r2, replaced_operand_2 = False, None - - if r0 or r1 or r2: - return True, TernaryOp( - self.idx, - self.op, - [replaced_operand_0, replaced_operand_1, replaced_operand_2], - bits=self.bits, - **self.tags, - ) - else: - return False, self - - @property - def verbose_op(self): - return self.op - - @property - def size(self): - return self.bits // 8 - - def copy(self) -> TernaryOp: - return TernaryOp(self.idx, self.op, self.operands[::], bits=self.bits, **self.tags) - - class Load(Expression): __slots__ = ( "addr", @@ -1124,7 +1002,18 @@ class Load(Expression): "alt", ) - def __init__(self, idx, addr, size, endness, variable=None, variable_offset=None, guard=None, alt=None, **kwargs): + def __init__( + self, + idx: int | None, + addr: Expression, + size: int, + endness: str, + variable=None, + variable_offset=None, + guard=None, + alt=None, + **kwargs, + ): depth = max(addr.depth, size.depth if isinstance(size, Expression) else 0) + 1 super().__init__(idx, depth, **kwargs) @@ -1135,10 +1024,7 @@ def __init__(self, idx, addr, size, endness, variable=None, variable_offset=None self.alt = alt self.variable = variable self.variable_offset = variable_offset - - @property - def bits(self): - return self.size * 8 + self.bits = self.size * 8 def __repr__(self): return str(self) @@ -1169,6 +1055,7 @@ def replace(self, old_expr, new_expr): def _likes_addr(self, other_addr): if hasattr(self.addr, "likes") and hasattr(other_addr, "likes"): return self.addr.likes(other_addr) + return self.addr == other_addr def likes(self, other): @@ -1220,12 +1107,20 @@ class ITE(Expression): "cond", "iffalse", "iftrue", - "bits", "variable", "variable_offset", ) - def __init__(self, idx, cond, iffalse, iftrue, variable=None, variable_offset=None, **kwargs): + def __init__( + self, + idx: int | None, + cond: Expression, + iffalse: Expression, + iftrue: Expression, + variable=None, + variable_offset=None, + **kwargs, + ): depth = ( max( cond.depth if isinstance(cond, Expression) else 0, @@ -1249,22 +1144,22 @@ def __repr__(self): def __str__(self): return f"(({self.cond}) ? ({self.iftrue}) : ({self.iffalse}))" - def likes(self, atom): + def likes(self, other): return ( - type(atom) is ITE - and self.cond.likes(atom.cond) - and self.iffalse.likes(atom.iffalse) - and self.iftrue.likes(atom.iftrue) - and self.bits == atom.bits + type(other) is ITE + and self.cond.likes(other.cond) + and self.iffalse == other.iffalse + and self.iftrue == other.iftrue + and self.bits == other.bits ) - def matches(self, atom): + def matches(self, other): return ( - type(atom) is ITE - and self.cond.matches(atom.cond) - and self.iffalse.matches(atom.iffalse) - and self.iftrue.matches(atom.iftrue) - and self.bits == atom.bits + type(other) is ITE + and self.cond.matches(other.cond) + and self.iffalse == other.iffalse + and self.iftrue == other.iftrue + and self.bits == other.bits ) __hash__ = TaggedObject.__hash__ @@ -1324,7 +1219,6 @@ class DirtyExpression(Expression): "mfx", "maddr", "msize", - "bits", ) def __init__( @@ -1452,6 +1346,8 @@ def replace(self, old_expr: Expression, new_expr: Expression): @property def size(self): + if self.bits is None: + return None return self.bits // 8 @@ -1459,10 +1355,9 @@ class VEXCCallExpression(Expression): __slots__ = ( "callee", "operands", - "bits", ) - def __init__(self, idx, callee: str, operands: list[Expression], bits=None, **kwargs): + def __init__(self, idx: int | None, callee: str, operands: tuple[Expression, ...], bits: int, **kwargs): super().__init__(idx, max(operand.depth for operand in operands), **kwargs) self.callee = callee self.operands = operands @@ -1531,6 +1426,8 @@ def replace(self, old_expr, new_expr): @property def size(self): + if self.bits is None: + return None return self.bits // 8 @@ -1548,6 +1445,7 @@ def __init__(self, idx: int | None, stmts: list[Statement], expr: Expression, ** super().__init__(idx, expr.depth + 1, **kwargs) self.stmts = stmts self.expr = expr + self.bits = self.expr.bits __hash__ = TaggedObject.__hash__ @@ -1562,12 +1460,12 @@ def likes(self, other): and self.expr.likes(other.expr) ) - def matches(self, atom): + def matches(self, other): return ( - type(self) is type(atom) - and len(self.stmts) == len(atom.stmts) - and all(s_stmt.matches(o_stmt) for s_stmt, o_stmt in zip(self.stmts, atom.stmts)) - and self.expr.matches(atom.expr) + type(self) is type(other) + and len(self.stmts) == len(other.stmts) + and all(s_stmt.matches(o_stmt) for s_stmt, o_stmt in zip(self.stmts, other.stmts)) + and self.expr.matches(other.expr) ) def __repr__(self): @@ -1579,10 +1477,6 @@ def __str__(self): concatenated_str = ", ".join(stmts_str + [expr_str]) return f"({concatenated_str})" - @property - def bits(self): - return self.expr.bits - @property def size(self): return self.expr.size @@ -1620,14 +1514,22 @@ def copy(self) -> MultiStatementExpression: class BasePointerOffset(Expression): __slots__ = ( - "bits", "base", "offset", "variable", "variable_offset", ) - def __init__(self, idx, bits, base, offset, variable=None, variable_offset=None, **kwargs): + def __init__( + self, + idx: int | None, + bits: int, + base: Expression | str, + offset: int, + variable=None, + variable_offset=None, + **kwargs, + ): super().__init__(idx, (offset.depth if isinstance(offset, Expression) else 0) + 1, **kwargs) self.bits = bits self.base = base @@ -1688,7 +1590,7 @@ def copy(self) -> BasePointerOffset: class StackBaseOffset(BasePointerOffset): __slots__ = () - def __init__(self, idx, bits, offset, **kwargs): + def __init__(self, idx: int | None, bits: int, offset: int, **kwargs): # stack base offset is always signed if offset >= (1 << (bits - 1)): offset -= 1 << bits @@ -1707,7 +1609,7 @@ def negate(expr: Expression) -> Expression: expr.idx, BinaryOp.COMPARISON_NEGATION[expr.op], expr.operands, - expr.signed, + signed=expr.signed, bits=expr.bits, floating_point=expr.floating_point, rounding_mode=expr.rounding_mode, diff --git a/ailment/py.typed b/ailment/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ailment/statement.py b/ailment/statement.py index 4edb5d9..0e436c4 100644 --- a/ailment/statement.py +++ b/ailment/statement.py @@ -1,5 +1,9 @@ # pylint:disable=isinstance-second-argument-not-valid-type,no-self-use,arguments-renamed -from typing import Optional, TYPE_CHECKING +from __future__ import annotations +from typing import TYPE_CHECKING +from collections.abc import Sequence +from abc import ABC, abstractmethod +from typing_extensions import Self try: import claripy @@ -8,26 +12,29 @@ from .utils import stable_hash, is_none_or_likeable, is_none_or_matchable from .tagged_object import TaggedObject -from .expression import Expression, DirtyExpression +from .expression import Atom, Expression, DirtyExpression if TYPE_CHECKING: from angr.calling_conventions import SimCC -class Statement(TaggedObject): +class Statement(TaggedObject, ABC): """ The base class of all AIL statements. """ __slots__ = () + @abstractmethod def __repr__(self): raise NotImplementedError() + @abstractmethod def __str__(self): raise NotImplementedError() - def replace(self, old_expr, new_expr): + @abstractmethod + def replace(self, old_expr: Expression, new_expr: Expression) -> tuple[bool, Self]: raise NotImplementedError() def eq(self, expr0, expr1): # pylint:disable=no-self-use @@ -35,11 +42,13 @@ def eq(self, expr0, expr1): # pylint:disable=no-self-use return expr0 is expr1 return expr0 == expr1 - def likes(self, atom): # pylint:disable=unused-argument,no-self-use + @abstractmethod + def likes(self, other) -> bool: # pylint:disable=unused-argument,no-self-use raise NotImplementedError() - def matches(self, atom): # pylint:disable=unused-argument,no-self-use - return NotImplementedError() + @abstractmethod + def matches(self, other) -> bool: # pylint:disable=unused-argument,no-self-use + raise NotImplementedError() class Assignment(Statement): @@ -52,7 +61,7 @@ class Assignment(Statement): "src", ) - def __init__(self, idx, dst, src, **kwargs): + def __init__(self, idx: int | None, dst: Atom, src: Expression, **kwargs): super().__init__(idx, **kwargs) self.dst = dst @@ -78,9 +87,10 @@ def __repr__(self): def __str__(self): return f"{str(self.dst)} = {str(self.src)}" - def replace(self, old_expr, new_expr): + def replace(self, old_expr: Expression, new_expr: Expression): if self.dst == old_expr: r_dst = True + assert isinstance(new_expr, Atom) replaced_dst = new_expr else: r_dst, replaced_dst = self.dst.replace(old_expr, new_expr) @@ -96,7 +106,7 @@ def replace(self, old_expr, new_expr): else: return False, self - def copy(self) -> "Assignment": + def copy(self) -> Assignment: return Assignment(self.idx, self.dst, self.src, **self.tags) @@ -115,7 +125,18 @@ class Store(Statement): "guard", ) - def __init__(self, idx, addr, data, size, endness, guard=None, variable=None, offset=None, **kwargs): + def __init__( + self, + idx: int | None, + addr: Expression, + data: Expression, + size: int, + endness: str, + guard: Expression | None = None, + variable=None, + offset=None, + **kwargs, + ): super().__init__(idx, **kwargs) self.addr = addr @@ -219,7 +240,7 @@ def replace(self, old_expr, new_expr): else: return False, self - def copy(self) -> "Store": + def copy(self) -> Store: return Store( self.idx, self.addr, @@ -424,7 +445,7 @@ def replace(self, old_expr, new_expr): else: return False, self - def copy(self) -> "ConditionalJump": + def copy(self) -> ConditionalJump: return ConditionalJump( self.idx, self.condition, @@ -452,18 +473,17 @@ class Call(Expression, Statement): "args", "ret_expr", "fp_ret_expr", - "bits", ) def __init__( self, idx, target, - calling_convention: Optional["SimCC"] = None, + calling_convention: SimCC | None = None, prototype=None, - args=None, - ret_expr=None, - fp_ret_expr=None, + args: Sequence[Expression] | None = None, + ret_expr: Expression | None = None, + fp_ret_expr: Expression | None = None, bits: int | None = None, **kwargs, ): @@ -475,7 +495,14 @@ def __init__( self.args = args self.ret_expr = ret_expr self.fp_ret_expr = fp_ret_expr - self.bits = bits if bits is not None else ret_expr.bits if ret_expr is not None else None + if bits is not None: + self.bits = bits + elif ret_expr is not None: + self.bits = ret_expr.bits + elif fp_ret_expr is not None: + self.bits = fp_ret_expr.bits + else: + self.bits = 0 # uhhhhhhhhhhhhhhhhhhh def likes(self, other): return ( @@ -544,7 +571,7 @@ def verbose_op(self): def op(self): return "call" - def replace(self, old_expr, new_expr): + def replace(self, old_expr: Expression, new_expr: Expression): if isinstance(self.target, Expression): r0, replaced_target = self.target.replace(old_expr, new_expr) else: @@ -692,7 +719,7 @@ class DirtyStatement(Statement): __slots__ = ("dirty",) - def __init__(self, idx, dirty: DirtyExpression, **kwargs): + def __init__(self, idx: int | None, dirty: DirtyExpression, **kwargs): super().__init__(idx, **kwargs) self.dirty = dirty @@ -707,15 +734,22 @@ def __str__(self): def replace(self, old_expr, new_expr): if self.dirty == old_expr: + assert isinstance(new_expr, DirtyExpression) return True, DirtyStatement(self.idx, new_expr, **self.tags) r, new_dirty = self.dirty.replace(old_expr, new_expr) if r: return True, DirtyStatement(self.idx, new_dirty, **self.tags) return False, self - def copy(self) -> "DirtyStatement": + def copy(self) -> DirtyStatement: return DirtyStatement(self.idx, self.dirty, **self.tags) + def likes(self, other): + return type(other) is DirtyStatement and self.dirty.likes(other.dirty) + + def matches(self, other): + return type(other) is DirtyStatement and self.dirty.matches(other.dirty) + class Label(Statement): """ @@ -734,9 +768,12 @@ def __init__(self, idx, name: str, ins_addr: int, block_idx: int | None = None, self.ins_addr = ins_addr self.block_idx = block_idx - def likes(self, other: "Label"): + def likes(self, other: Label): return isinstance(other, Label) + def replace(self, old_expr, new_expr): + return False, self + matches = likes def _hash_core(self): @@ -755,5 +792,5 @@ def __repr__(self): def __str__(self): return f"{self.name}:" - def copy(self) -> "Label": + def copy(self) -> Label: return Label(self.idx, self.name, self.ins_addr, self.block_idx, **self.tags) diff --git a/ailment/tagged_object.py b/ailment/tagged_object.py index b7dc446..f174497 100644 --- a/ailment/tagged_object.py +++ b/ailment/tagged_object.py @@ -9,7 +9,7 @@ class TaggedObject: "_hash", ) - def __init__(self, idx, **kwargs): + def __init__(self, idx: int | None, **kwargs): self._tags = None self.idx = idx self._hash = None @@ -43,7 +43,7 @@ def __new__(cls, *args, **kwargs): # pylint:disable=unused-argument self._tags = None return self - def __hash__(self): + def __hash__(self) -> int: if self._hash is None: self._hash = self._hash_core() return self._hash diff --git a/ailment/utils.py b/ailment/utils.py index 01cbf4e..a441807 100644 --- a/ailment/utils.py +++ b/ailment/utils.py @@ -1,37 +1,29 @@ -from typing import Union, TYPE_CHECKING +# pylint:disable=ungrouped-imports,wrong-import-position +from __future__ import annotations +from typing import TypeAlias import struct try: - import claripy + from claripy.ast import Bits except ImportError: - claripy = None + from typing_extensions import Never as Bits try: import _md5 as md5lib except ImportError: import hashlib as md5lib -if TYPE_CHECKING: - from .expression import Expression +GetBitsTypeParams: TypeAlias = "Bits | Expression" -get_bits_type_params = Union[int, "Expression"] -if claripy: - get_bits_type_params = Union[int, claripy.ast.Bits, "Expression"] - - -def get_bits(expr: get_bits_type_params) -> int | None: - # delayed import - from .expression import Expression # pylint:disable=import-outside-toplevel +def get_bits(expr: GetBitsTypeParams) -> int: if isinstance(expr, Expression): return expr.bits - elif isinstance(expr, claripy.ast.Bits): + elif isinstance(expr, Bits): return expr.size() - elif hasattr(expr, "bits"): - return expr.bits else: - return None + raise TypeError(type(expr)) md5_unpacker = struct.Struct("4I") @@ -95,8 +87,6 @@ def is_none_or_likeable(arg1, arg2, is_list=False): """ Returns whether two things are both None or can like each other """ - from .expression import Expression # pylint:disable=import-outside-toplevel - if arg1 is None or arg2 is None: if arg1 == arg2: return True @@ -114,8 +104,6 @@ def is_none_or_matchable(arg1, arg2, is_list=False): """ Returns whether two things are both None or can match each other """ - from .expression import Expression # pylint:disable=import-outside-toplevel - if arg1 is None or arg2 is None: if arg1 == arg2: return True @@ -127,3 +115,6 @@ def is_none_or_matchable(arg1, arg2, is_list=False): if isinstance(arg1, Expression): return arg1.matches(arg2) return arg1 == arg2 + + +from .expression import Expression # noqa: E402 diff --git a/setup.cfg b/setup.cfg index de5280f..eb8fbc3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,6 +18,8 @@ classifiers = [options] packages = find: +install_requires = + typing-extensions python_requires = >=3.10 [options.extras_require] @@ -29,3 +31,7 @@ docs = testing = pytest pytest-xdist + +[options.package_data] +ailment = + py.typed