Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VirtualVariable: Add some accessors for type safety. #277

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 50 additions & 21 deletions ailment/expression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint:disable=arguments-renamed,isinstance-second-argument-not-valid-type,missing-class-docstring
# pylint:disable=arguments-renamed,isinstance-second-argument-not-valid-type,missing-class-docstring,too-many-boolean-expressions
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from collections.abc import Sequence
Expand Down Expand Up @@ -127,13 +127,16 @@ def likes(self, other):
)

matches = likes
__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash((self.value, self.bits))

@property
def sign_bit(self):
if not self.is_int:
raise TypeError("Sign bit is only available for int constants.")
assert isinstance(self.value, int)
return self.value >> (self.bits - 1)

def copy(self) -> Const:
Expand Down Expand Up @@ -167,7 +170,7 @@ def likes(self, other):
return type(self) is type(other) and self.tmp_idx == other.tmp_idx and self.bits == other.bits

matches = likes
__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(("tmp", self.tmp_idx, self.bits))
Expand Down Expand Up @@ -204,7 +207,7 @@ def __str__(self):
return "%s" % str(self.variable.name)

matches = likes
__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(("reg", self.reg_offset, self.bits, self.idx))
Expand Down Expand Up @@ -269,18 +272,44 @@ def was_tmp(self) -> bool:
@property
def reg_offset(self) -> int:
if self.was_reg:
assert isinstance(self.oident, int)
return self.oident
raise TypeError("Is not a register")

@property
def stack_offset(self) -> int:
if self.was_stack:
assert isinstance(self.oident, int)
return self.oident
raise TypeError("Is not a stack variable")

@property
def tmp_idx(self) -> int | None:
return self.oident if self.was_tmp else None
if self.was_tmp:
assert isinstance(self.oident, int)
return self.oident
return None

@property
def parameter_category(self) -> VirtualVariableCategory | None:
if self.was_parameter:
assert isinstance(self.oident, tuple)
return self.oident[0]
return None

@property
def parameter_reg_offset(self) -> int | None:
if self.was_parameter and self.parameter_category == VirtualVariableCategory.REGISTER:
assert isinstance(self.oident, tuple)
return self.oident[1]
return None

@property
def parameter_stack_offset(self) -> int | None:
if self.was_parameter and self.parameter_category == VirtualVariableCategory.STACK:
assert isinstance(self.oident, tuple)
return self.oident[1]
return None

def likes(self, other):
return (
Expand Down Expand Up @@ -308,7 +337,7 @@ def __repr__(self):
ori_str = f"{{stack {self.oident}}}"
return f"vvar_{self.varid}{ori_str}"

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(("var", self.varid, self.bits, self.category, self.oident))
Expand Down Expand Up @@ -379,7 +408,7 @@ def matches(self, other) -> bool:
and other_vvar is not None
or self_vvar is not None
and other_vvar is None
or not self_vvar.matches(other_vvar)
or (self_vvar is not None and other_vvar is not None and not self_vvar.matches(other_vvar))
):
return False
return True
Expand All @@ -388,7 +417,7 @@ def matches(self, other) -> bool:
def __repr__(self):
return f"𝜙@{self.bits}b {self.src_and_vvars}"

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(("phi", self.bits, tuple(sorted(self.src_and_vvars, key=self._src_and_vvar_filter))))
Expand Down Expand Up @@ -432,7 +461,7 @@ def _src_and_vvar_filter(
if src[1] is None:
src = src[0], -1
vvar_id = vvar.varid if vvar is not None else -1
return src, vvar_id
return src, vvar_id # type: ignore


class Op(Expression):
Expand Down Expand Up @@ -493,7 +522,7 @@ def matches(self, other):
and self.operand.matches(other.operand)
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash((self.op, self.operand, self.bits))
Expand Down Expand Up @@ -608,7 +637,7 @@ def matches(self, other):
and self.rounding_mode == other.rounding_mode
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(
Expand Down Expand Up @@ -716,7 +745,7 @@ def matches(self, other):
and self.operand.matches(other.operand)
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(
Expand Down Expand Up @@ -900,7 +929,7 @@ def matches(self, other):
and self.rounding_mode == other.rounding_mode
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(
Expand Down Expand Up @@ -1087,7 +1116,7 @@ def matches(self, other):
and self.alt == other.alt
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(("Load", self.addr, self.size, self.endness))
Expand Down Expand Up @@ -1166,7 +1195,7 @@ def matches(self, other):
and self.bits == other.bits
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash((ITE, self.cond, self.iffalse, self.iftrue, self.bits))
Expand Down Expand Up @@ -1236,7 +1265,7 @@ def __init__(
maddr: Expression | None = None,
msize: int | None = None,
# TODO: fxstate (guest state effects) is not modeled yet
bits=None,
bits: int,
**kwargs,
):
super().__init__(idx, 1, **kwargs)
Expand Down Expand Up @@ -1283,7 +1312,7 @@ def matches(self, other):
and self.bits == other.bits
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash(
Expand Down Expand Up @@ -1393,7 +1422,7 @@ def matches(self, other):
and all(op1.matches(op2) for op1, op2 in zip(other.operands, self.operands))
)

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash((VEXCCallExpression, self.callee, self.bits, tuple(self.operands)))
Expand Down Expand Up @@ -1424,7 +1453,7 @@ def replace(self, old_expr, new_expr):
new_operands.append(operand)

if replaced:
return True, VEXCCallExpression(self.idx, self.callee, list(new_operands), bits=self.bits, **self.tags)
return True, VEXCCallExpression(self.idx, self.callee, tuple(new_operands), bits=self.bits, **self.tags)
else:
return False, self

Expand All @@ -1451,7 +1480,7 @@ def __init__(self, idx: int | None, stmts: list[Statement], expr: Expression, **
self.expr = expr
self.bits = self.expr.bits

__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash((MultiStatementExpression,) + tuple(self.stmts) + (self.expr,))
Expand Down Expand Up @@ -1568,7 +1597,7 @@ def likes(self, other):
)

matches = likes
__hash__ = TaggedObject.__hash__
__hash__ = TaggedObject.__hash__ # type: ignore

def _hash_core(self):
return stable_hash((self.bits, self.base, self.offset))
Expand Down
Loading