From 58f5f018082860caa4057d24d75d725709dcd2bb Mon Sep 17 00:00:00 2001 From: Paul Wohlhart Date: Thu, 17 Aug 2023 13:58:58 -0700 Subject: [PATCH] Make a local reference to patterns_pb2.Condition in structured_writer. To make pytyping of downstream use cases easier. PiperOrigin-RevId: 557928776 Change-Id: I521cc1b6af42e803ef1889bca1975eeccb5033b5 --- reverb/structured_writer.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/reverb/structured_writer.py b/reverb/structured_writer.py index 3b8d95b9..b6e66d7c 100644 --- a/reverb/structured_writer.py +++ b/reverb/structured_writer.py @@ -37,6 +37,7 @@ # TODO(b/204423296): Expose Python abstractions rather than the raw protos. Config = patterns_pb2.StructuredWriterConfig +ConditionProto = patterns_pb2.Condition Pattern = tree.Structure[patterns_pb2.PatternNode] ReferenceStep = NewType('ReferenceStep', Any) @@ -288,7 +289,7 @@ def my_transform(step): def create_config(pattern: Pattern, table: str, - conditions: Sequence[patterns_pb2.Condition] = (), + conditions: Sequence[ConditionProto] = (), priority: Optional[patterns_pb2.Priority] = None): structure = tree.map_structure(lambda _: None, pattern) if priority is None: @@ -387,7 +388,7 @@ def _validate_and_convert_to_spec(path, *nodes): class _ConditionBuilder: """Helper class to make it easier to build conditions.""" - def __init__(self, incomplete_condition: patterns_pb2.Condition): + def __init__(self, incomplete_condition: ConditionProto): self._incomplete_condition = incomplete_condition def __mod__(self, cmp: int) -> '_ConditionBuilder': @@ -396,7 +397,7 @@ def __mod__(self, cmp: int) -> '_ConditionBuilder': return _ConditionBuilder(incomplete_condition) # pytype: disable=signature-mismatch # overriding-return-type-checks - def __eq__(self, cmp: int) -> patterns_pb2.Condition: + def __eq__(self, cmp: int) -> ConditionProto: condition = copy.deepcopy(self._incomplete_condition) if condition.mod_eq.mod: condition.mod_eq.eq = cmp @@ -404,23 +405,23 @@ def __eq__(self, cmp: int) -> patterns_pb2.Condition: condition.eq = cmp return condition - def __ne__(self, cmp: int) -> patterns_pb2.Condition: + def __ne__(self, cmp: int) -> ConditionProto: condition = self == cmp condition.inverse = True return condition - def __gt__(self, cmp: int) -> patterns_pb2.Condition: + def __gt__(self, cmp: int) -> ConditionProto: return self >= cmp + 1 - def __ge__(self, cmp: int) -> patterns_pb2.Condition: + def __ge__(self, cmp: int) -> ConditionProto: condition = copy.deepcopy(self._incomplete_condition) condition.ge = cmp return condition - def __lt__(self, cmp: int) -> patterns_pb2.Condition: + def __lt__(self, cmp: int) -> ConditionProto: return self <= cmp - 1 - def __le__(self, cmp: int) -> patterns_pb2.Condition: + def __le__(self, cmp: int) -> ConditionProto: condition = self > cmp condition.inverse = True return condition @@ -434,23 +435,23 @@ class Condition: @staticmethod def step_index(): """(Zero) index of the most recent appended step within the episode.""" - return _ConditionBuilder(patterns_pb2.Condition(step_index=True)) + return _ConditionBuilder(ConditionProto(step_index=True)) @staticmethod def steps_since_applied(): """Number of added steps since an item was created for this config.""" - return _ConditionBuilder(patterns_pb2.Condition(steps_since_applied=True)) + return _ConditionBuilder(ConditionProto(steps_since_applied=True)) @staticmethod def is_end_episode(): """True only when end_episode is called on the writer.""" - return patterns_pb2.Condition(is_end_episode=True, eq=1) + return ConditionProto(is_end_episode=True, eq=1) @staticmethod def data(step_structure: tree.Structure[Any]): """Value of a scalar integer or bool in the source data.""" flat = [ - _ConditionBuilder(patterns_pb2.Condition(flat_source_index=i)) + _ConditionBuilder(ConditionProto(flat_source_index=i)) for i in range(len(tree.flatten(step_structure))) ] return tree.unflatten_as(step_structure, flat)