Skip to content

Commit

Permalink
dialects: Add UB dialect
Browse files Browse the repository at this point in the history
The goal of this dialect is to expose an MLIR type that represents
T + UB, where `UB` is a special value that represents undefined behavior.
This can be both used for the global state (where UB is the C++ definition
of undefined behavior), or for poison values (where UB is poison).
  • Loading branch information
math-fehr committed Jan 16, 2025
1 parent b085d1f commit d9670e3
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/filecheck/dialects/ub.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: xdsl-smt "%s" | xdsl-smt | filecheck "%s"

%value = "smt.declare_const"() : () -> i32
%ub = ub.ub : !ub.ub_or<i32>
%non_ub = ub.from %value : !ub.ub_or<i32>
%res = ub.match %ub : !ub.ub_or<i32> -> i64 {
^bb0(%val: i32):
%x = "smt.declare_const"() : () -> i64
ub.yield %x : i64
} {
%y = "smt.declare_const"() : () -> i64
ub.yield %y : i64
}


// CHECK: builtin.module {
// CHECK-NEXT: %value = "smt.declare_const"() : () -> i32
// CHECK-NEXT: %ub = ub.ub : !ub.ub_or<i32>
// CHECK-NEXT: %non_ub = ub.from %value : !ub.ub_or<i32>
// CHECK-NEXT: %res = ub.match %ub : !ub.ub_or<i32> -> i64 {
// CHECK-NEXT: ^0(%val : i32):
// CHECK-NEXT: %x = "smt.declare_const"() : () -> i64
// CHECK-NEXT: ub.yield %x : i64
// CHECK-NEXT: } {
// CHECK-NEXT: %y = "smt.declare_const"() : () -> i64
// CHECK-NEXT: ub.yield %y : i64
// CHECK-NEXT: }
// CHECK-NEXT: }
2 changes: 2 additions & 0 deletions xdsl_smt/cli/xdsl_smt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from xdsl_smt.dialects.hw_dialect import HW
from xdsl_smt.dialects.llvm_dialect import LLVM
from xdsl_smt.dialects.tv_dialect import TVDialect
from xdsl_smt.dialects.ub import UBDialect

from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination
from xdsl_smt.passes.lower_pairs import LowerPairs
Expand Down Expand Up @@ -85,6 +86,7 @@ def register_all_dialects(self):
self.ctx.register_dialect(LLVM.name, lambda: LLVM)
self.ctx.register_dialect(Test.name, lambda: Test)
self.ctx.register_dialect(MemRef.name, lambda: MemRef)
self.ctx.register_dialect(UBDialect.name, lambda: UBDialect)
self.ctx.load_registered_dialect(SMTDialect.name)
self.ctx.load_registered_dialect(Transfer.name)
self.ctx.load_registered_dialect(SMTIntDialect.name)
Expand Down
129 changes: 129 additions & 0 deletions xdsl_smt/dialects/ub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from __future__ import annotations
from typing import Annotated, Generic, TypeVar

from xdsl.irdl import (
irdl_attr_definition,
irdl_op_definition,
operand_def,
result_def,
IRDLOperation,
ParameterDef,
ConstraintVar,
region_def,
var_result_def,
var_operand_def,
)
from xdsl.ir import (
ParametrizedAttribute,
Attribute,
SSAValue,
Region,
Block,
Dialect,
TypeAttribute,
)
from xdsl.traits import IsTerminator
from xdsl.utils.isattr import isattr

_UBAttrParameter = TypeVar("_UBAttrParameter", bound=Attribute)


@irdl_attr_definition
class UBOrType(Generic[_UBAttrParameter], ParametrizedAttribute, TypeAttribute):
"""A tagged union between a type and a UB singleton."""

name = "ub.ub_or"

type: ParameterDef[_UBAttrParameter]

def __init__(self, type: _UBAttrParameter):
super().__init__([type])


@irdl_op_definition
class UBOp(IRDLOperation):
"""Create an UB value."""

name = "ub.ub"

new_ub = result_def(UBOrType)

assembly_format = "attr-dict `:` type($new_ub)"

def __init__(self, type: Attribute):
"""Create an UB value for the given type."""
super().__init__(
operands=[],
result_types=[UBOrType(type)],
)


@irdl_op_definition
class FromOp(IRDLOperation):
"""Convert a value to a value + UB type."""

name = "ub.from"

T = Annotated[Attribute, ConstraintVar("T")]

value = operand_def(T)
result = result_def(UBOrType[T])

assembly_format = "$value attr-dict `:` type($result)"

def __init__(self, value: SSAValue):
super().__init__(
operands=[value],
result_types=[UBOrType(value.type)],
)


@irdl_op_definition
class MatchOp(IRDLOperation):
"""Pattern match on a tagged union between a value and UB."""

name = "ub.match"

T = Annotated[Attribute, ConstraintVar("T")]

value = operand_def(UBOrType[T])

value_region = region_def(single_block="single_block")
ub_region = region_def(single_block="single_block")

res = var_result_def()

assembly_format = "$value attr-dict-with-keyword `:` type($value) `->` type($res) $value_region $ub_region"

def __init__(self, value: SSAValue):
if not isattr(value.type, UBOrType[Attribute]):
raise ValueError(f"Expected a '{UBOrType.name}' type, got {value.type}")
value_region = Region(Block((), arg_types=[value.type.type]))
ub_region = Region(Block((), arg_types=[]))
super().__init__(
operands=[value],
result_types=[],
regions=[value_region, ub_region],
)


@irdl_op_definition
class YieldOp(IRDLOperation):
"""Yield a value inside an `ub.match` region."""

name = "ub.yield"

rets = var_operand_def()

assembly_format = "$rets attr-dict `:` type($rets)"

traits = frozenset([IsTerminator()])

def __init__(self, *rets: SSAValue):
super().__init__(
operands=list(rets),
result_types=[],
)


UBDialect = Dialect("ub", [UBOp, FromOp, MatchOp, YieldOp], [UBOrType])

0 comments on commit d9670e3

Please sign in to comment.