Skip to content

Commit

Permalink
Make a local reference to patterns_pb2.Condition in structured_writer.
Browse files Browse the repository at this point in the history
To make pytyping of downstream use cases easier.

PiperOrigin-RevId: 557928776
Change-Id: I521cc1b6af42e803ef1889bca1975eeccb5033b5
  • Loading branch information
pwohlhart authored and copybara-github committed Aug 17, 2023
1 parent 33aba4e commit 58f5f01
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions reverb/structured_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

# TODO(b/204423296): Expose Python abstractions rather than the raw protos.
Config = patterns_pb2.StructuredWriterConfig
ConditionProto = patterns_pb2.Condition

Pattern = tree.Structure[patterns_pb2.PatternNode]
ReferenceStep = NewType('ReferenceStep', Any)
Expand Down Expand Up @@ -288,7 +289,7 @@ def my_transform(step):

def create_config(pattern: Pattern,
table: str,
conditions: Sequence[patterns_pb2.Condition] = (),
conditions: Sequence[ConditionProto] = (),
priority: Optional[patterns_pb2.Priority] = None):
structure = tree.map_structure(lambda _: None, pattern)
if priority is None:
Expand Down Expand Up @@ -387,7 +388,7 @@ def _validate_and_convert_to_spec(path, *nodes):
class _ConditionBuilder:
"""Helper class to make it easier to build conditions."""

def __init__(self, incomplete_condition: patterns_pb2.Condition):
def __init__(self, incomplete_condition: ConditionProto):
self._incomplete_condition = incomplete_condition

def __mod__(self, cmp: int) -> '_ConditionBuilder':
Expand All @@ -396,31 +397,31 @@ def __mod__(self, cmp: int) -> '_ConditionBuilder':
return _ConditionBuilder(incomplete_condition)

# pytype: disable=signature-mismatch # overriding-return-type-checks
def __eq__(self, cmp: int) -> patterns_pb2.Condition:
def __eq__(self, cmp: int) -> ConditionProto:
condition = copy.deepcopy(self._incomplete_condition)
if condition.mod_eq.mod:
condition.mod_eq.eq = cmp
else:
condition.eq = cmp
return condition

def __ne__(self, cmp: int) -> patterns_pb2.Condition:
def __ne__(self, cmp: int) -> ConditionProto:
condition = self == cmp
condition.inverse = True
return condition

def __gt__(self, cmp: int) -> patterns_pb2.Condition:
def __gt__(self, cmp: int) -> ConditionProto:
return self >= cmp + 1

def __ge__(self, cmp: int) -> patterns_pb2.Condition:
def __ge__(self, cmp: int) -> ConditionProto:
condition = copy.deepcopy(self._incomplete_condition)
condition.ge = cmp
return condition

def __lt__(self, cmp: int) -> patterns_pb2.Condition:
def __lt__(self, cmp: int) -> ConditionProto:
return self <= cmp - 1

def __le__(self, cmp: int) -> patterns_pb2.Condition:
def __le__(self, cmp: int) -> ConditionProto:
condition = self > cmp
condition.inverse = True
return condition
Expand All @@ -434,23 +435,23 @@ class Condition:
@staticmethod
def step_index():
"""(Zero) index of the most recent appended step within the episode."""
return _ConditionBuilder(patterns_pb2.Condition(step_index=True))
return _ConditionBuilder(ConditionProto(step_index=True))

@staticmethod
def steps_since_applied():
"""Number of added steps since an item was created for this config."""
return _ConditionBuilder(patterns_pb2.Condition(steps_since_applied=True))
return _ConditionBuilder(ConditionProto(steps_since_applied=True))

@staticmethod
def is_end_episode():
"""True only when end_episode is called on the writer."""
return patterns_pb2.Condition(is_end_episode=True, eq=1)
return ConditionProto(is_end_episode=True, eq=1)

@staticmethod
def data(step_structure: tree.Structure[Any]):
"""Value of a scalar integer or bool in the source data."""
flat = [
_ConditionBuilder(patterns_pb2.Condition(flat_source_index=i))
_ConditionBuilder(ConditionProto(flat_source_index=i))
for i in range(len(tree.flatten(step_structure)))
]
return tree.unflatten_as(step_structure, flat)
Expand Down

0 comments on commit 58f5f01

Please sign in to comment.