Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom coders in Reshuffle #33932

Merged
merged 7 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@

## New Features / Improvements

* Support custom coders in Reshuffle ([#29908](https://github.com/apache/beam/issues/29908), [#33356](https://github.com/apache/beam/issues/33356)).

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Breaking Changes
* [Python] Reshuffle is now more faithful to user-specified typehints and coders. Previously, Reshuffle could incorrectly use FastPrimitivesCoder in some pipelines. This update corrects that behavior. However, it may be a breaking change for pipelines with incorrect typehints for Reshuffle. If you encounter issues after upgrading to this Beam version, you can temporarily specify update_compatibility_version to an older Beam version (e.g. 2.63.0) in your pipeline options as a workaround. The recommended long-term solution is to correct the inaccurate typehints in your pipeline. ([#33932](https://github.com/apache/beam/pull/33932))

* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).

Expand Down
11 changes: 11 additions & 0 deletions sdks/python/apache_beam/coders/coders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,17 @@ def __hash__(self):
return hash(
(self.wrapped_value_coder, self.timestamp_coder, self.window_coder))

@classmethod
def from_type_hint(cls, typehint, registry):
# type: (Any, CoderRegistry) -> WindowedValueCoder
# Ideally this'd take two parameters so that one could hint at
# the window type as well instead of falling back to the
# pickle coders.
return cls(registry.get_coder(typehint.inner_type))

def to_type_hint(self):
return typehints.WindowedValue[self.wrapped_value_coder.to_type_hint()]


Coder.register_structured_urn(
common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)
Expand Down
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/coders/coders_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ def test_numpy_int(self):
_ = indata | "CombinePerKey" >> beam.CombinePerKey(sum)


class WindowedValueCoderTest(unittest.TestCase):
def test_to_type_hint(self):
coder = coders.WindowedValueCoder(coders.VarIntCoder())
self.assertEqual(coder.to_type_hint(), typehints.WindowedValue[int]) # type: ignore[misc]


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
2 changes: 2 additions & 0 deletions sdks/python/apache_beam/coders/typecoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def register_standard_coders(self, fallback_coder):
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
self._register_coder_internal(
typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
# Default fallback coders applied in that order until the first matching
# coder found.
default_fallback_coders = [
Expand Down
52 changes: 50 additions & 2 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING
from typing import Any
from typing import Optional
from typing import TypeVar
from typing import Union

Expand All @@ -40,6 +41,7 @@
from apache_beam import pvalue
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.options import pipeline_options
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsSideInput
Expand Down Expand Up @@ -71,11 +73,13 @@
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import shared
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
from apache_beam.utils.sharded_key import ShardedKey
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.runners.pipeline_context import PipelineContext
Expand All @@ -102,6 +106,8 @@
V = TypeVar('V')
T = TypeVar('T')

RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION = "2.64.0"


class CoGroupByKey(PTransform):
"""Groups results across several PCollections by key.
Expand Down Expand Up @@ -922,6 +928,25 @@ def get_window_coder(self):
return self._window_coder


def is_compat_version_prior_to_breaking_change(
shunping marked this conversation as resolved.
Show resolved Hide resolved
update_compatibility_version, breaking_change_version):
# This function is used in a branch statement to determine whether we should
# keep the old behavior prior to a breaking change or use the new behavior.
# - If update_compatibility_version < breaking_change_version, we will return
# True and keep the old behavior.
# - If update_compatibility_version is None or >= breaking_change_version, we
# will return False and use the behavior from the breaking change.
if update_compatibility_version is None:
return False

compat_version = tuple(map(int, update_compatibility_version.split('.')[0:3]))
change_version = tuple(map(int, breaking_change_version.split('.')[0:3]))
for i in range(min(len(compat_version), len(change_version))):
if compat_version[i] < change_version[i]:
return True
return False


@typehints.with_input_types(tuple[K, V])
@typehints.with_output_types(tuple[K, V])
class ReshufflePerKey(PTransform):
Expand All @@ -931,6 +956,8 @@ class ReshufflePerKey(PTransform):
transforms.
"""
def expand(self, pcoll):
compat_version = pcoll.pipeline.options.view_as(
pipeline_options.StreamingOptions).update_compatibility_version
windowing_saved = pcoll.windowing
if windowing_saved.is_default():
# In this (common) case we can use a trivial trigger driver
Expand All @@ -951,6 +978,14 @@ def restore_timestamps(element):
window.GlobalWindows.windowed_value((key, value), timestamp)
for (value, timestamp) in values
]

if is_compat_version_prior_to_breaking_change(
compat_version, RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
else:
ungrouped = pcoll | Map(reify_timestamps).with_input_types(
tuple[K, V]).with_output_types(
tuple[K, tuple[V, Optional[Timestamp]]])
else:

# typing: All conditional function variants must have identical signatures
Expand All @@ -964,7 +999,12 @@ def restore_timestamps(element):
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]

ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
if is_compat_version_prior_to_breaking_change(
compat_version, RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
shunping marked this conversation as resolved.
Show resolved Hide resolved
else:
ungrouped = pcoll | Map(reify_timestamps).with_input_types(
tuple[K, V]).with_output_types(tuple[K, TypedWindowedValue[V]])

# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
Expand Down Expand Up @@ -1012,11 +1052,19 @@ def __init__(self, num_buckets=None):

def expand(self, pcoll):
# type: (pvalue.PValue) -> pvalue.PCollection
compat_version = pcoll.pipeline.options.view_as(
pipeline_options.StreamingOptions).update_compatibility_version
if is_compat_version_prior_to_breaking_change(
compat_version, RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
reshuffle_step = ReshufflePerKey()
else:
reshuffle_step = ReshufflePerKey().with_input_types(
tuple[int, T]).with_output_types(tuple[int, T])
return (
pcoll | 'AddRandomKeys' >>
Map(lambda t: (random.randrange(0, self.num_buckets), t)
).with_input_types(T).with_output_types(tuple[int, T])
| ReshufflePerKey()
| reshuffle_step
| 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types(
tuple[int, T]).with_output_types(T))

Expand Down
76 changes: 76 additions & 0 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,82 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
equal_to(expected_data),
label="formatted_after_reshuffle")

global _Unpicklable
global _UnpicklableCoder

class _Unpicklable(object):
def __init__(self, value):
self.value = value

def __getstate__(self):
raise NotImplementedError()

def __setstate__(self, state):
raise NotImplementedError()

class _UnpicklableCoder(beam.coders.Coder):
def encode(self, value):
return str(value.value).encode()

def decode(self, encoded):
return _Unpicklable(int(encoded.decode()))

def to_type_hint(self):
return _Unpicklable

def is_deterministic(self):
return True

def reshuffle_unpicklable_in_global_window_helper(
self, update_compatibility_version=None):
with TestPipeline(options=PipelineOptions(
update_compatibility_version=update_compatibility_version)) as pipeline:
data = [_Unpicklable(i) for i in range(5)]
expected_data = [0, 10, 20, 30, 40]
result = (
pipeline
| beam.Create(data)
| beam.WindowInto(GlobalWindows())
| beam.Reshuffle()
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))

def test_reshuffle_unpicklable_in_global_window(self):
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

self.reshuffle_unpicklable_in_global_window_helper()
# An exception is raised when running reshuffle on unpicklable objects
# prior to 2.64.0
self.assertRaises(
RuntimeError,
self.reshuffle_unpicklable_in_global_window_helper,
"2.63.0")

def reshuffle_unpicklable_in_non_global_window_helper(
self, update_compatibility_version=None):
with TestPipeline(options=PipelineOptions(
update_compatibility_version=update_compatibility_version)) as pipeline:
data = [_Unpicklable(i) for i in range(5)]
expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40]
result = (
pipeline
| beam.Create(data)
| beam.WindowInto(window.SlidingWindows(size=3, period=1))
| beam.Reshuffle()
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))

def test_reshuffle_unpicklable_in_non_global_window(self):
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

self.reshuffle_unpicklable_in_non_global_window_helper()
# An exception is raised when running reshuffle on unpicklable objects
# prior to 2.64.0
self.assertRaises(
RuntimeError,
self.reshuffle_unpicklable_in_non_global_window_helper,
"2.63.0")


class WithKeysTest(unittest.TestCase):
def setUp(self):
Expand Down
20 changes: 20 additions & 0 deletions sdks/python/apache_beam/typehints/native_type_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
import sys
import types
import typing
from typing import Generic
from typing import TypeVar

from apache_beam.typehints import typehints

T = TypeVar('T')

_LOGGER = logging.getLogger(__name__)

# Describes an entry in the type map in convert_to_beam_type.
Expand Down Expand Up @@ -277,6 +281,18 @@ def is_builtin(typ):
return getattr(typ, '__origin__', None) in _BUILTINS


# During type inference of WindowedValue, we need to pass in the inner value
# type. This cannot be achieved immediately with WindowedValue class because it
# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
# could work in theory. However, the class is cythonized and it seems that
# cython does not handle generic classes well.
# The workaround here is to create a separate class solely for the type
# inference purpose. This class should never be used for creating instances.
class TypedWindowedValue(Generic[T]):
def __init__(self, *args, **kwargs):
raise NotImplementedError("This class is solely for type inference")


def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.

Expand Down Expand Up @@ -385,6 +401,10 @@ def convert_to_beam_type(typ):
match=_match_is_exactly_collection,
arity=1,
beam_type=typehints.Collection),
_TypeMapEntry(
match=_match_issubclass(TypedWindowedValue),
arity=1,
beam_type=typehints.WindowedValue),
]

# Find the first matching entry.
Expand Down
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/typehints/typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,15 @@ def type_check(self, instance):
repr(self.inner_type),
instance.value.__class__.__name__))

def bind_type_variables(self, bindings):
bound_inner_type = bind_type_variables(self.inner_type, bindings)
if bound_inner_type == self.inner_type:
return self
return WindowedValue[bound_inner_type]

def __repr__(self):
return 'WindowedValue[%s]' % repr(self.inner_type)


class GeneratorHint(IteratorHint):
"""A Generator type hint.
Expand Down
Loading