Skip to content

Commit

Permalink
Merge pull request #33647 Support relative paths for yaml includes an…
Browse files Browse the repository at this point in the history
…d resources.
  • Loading branch information
robertwb authored Jan 21, 2025
2 parents 89cee24 + b2e1eb7 commit 9e0950c
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 68 deletions.
5 changes: 4 additions & 1 deletion sdks/python/apache_beam/yaml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def run(argv=None):
known_args.jinja_variables)}) as p:
print("Building pipeline...")
yaml_transform.expand_pipeline(
p, pipeline_spec, validate_schema=known_args.json_schema_validation)
p,
pipeline_spec,
validate_schema=known_args.json_schema_validation,
pipeline_path=known_args.yaml_pipeline_file)
print("Running pipeline...")


Expand Down
5 changes: 2 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from typing import Optional

import fastavro
import yaml

import apache_beam as beam
import apache_beam.io as beam_io
Expand Down Expand Up @@ -573,5 +572,5 @@ def write_to_iceberg(


def io_providers():
with open(os.path.join(os.path.dirname(__file__), 'standard_io.yaml')) as fin:
return yaml_provider.parse_providers(yaml.load(fin, Loader=yaml.SafeLoader))
return yaml_provider.load_providers(
os.path.join(os.path.dirname(__file__), 'standard_io.yaml'))
125 changes: 84 additions & 41 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
various PTransforms."""

import collections
import functools
import hashlib
import inspect
import json
Expand All @@ -33,14 +34,12 @@
import warnings
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from typing import Any
from typing import Optional

import docstring_parser
import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
import apache_beam.dataframe.io
Expand Down Expand Up @@ -205,7 +204,7 @@ def create_external_transform(self, urn, args):
self._service)

@classmethod
def provider_from_spec(cls, spec):
def provider_from_spec(cls, source_path, spec):
from apache_beam.yaml.yaml_transform import SafeLineLoader
for required in ('type', 'transforms'):
if required not in spec:
Expand All @@ -226,7 +225,10 @@ def provider_from_spec(cls, spec):
config['version'] = beam_version
if type in cls._provider_types:
try:
result = cls._provider_types[type](urns, **config)
constructor = cls._provider_types[type]
if 'provider_base_path' in inspect.signature(constructor).parameters:
config['provider_base_path'] = source_path
result = constructor(urns, **config)
if not hasattr(result, 'to_json'):
result.to_json = lambda: spec
return result
Expand All @@ -249,12 +251,13 @@ def apply(constructor):


@ExternalProvider.register_provider_type('javaJar')
def java_jar(urns, jar: str):
def java_jar(urns, provider_base_path, jar: str):
if not os.path.exists(jar):
parsed = urllib.parse.urlparse(jar)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f'Invalid path or url: {jar}')
return ExternalJavaProvider(urns, lambda: jar)
return ExternalJavaProvider(
urns, lambda: _join_url_or_filepath(provider_base_path, jar))


@ExternalProvider.register_provider_type('mavenJar')
Expand Down Expand Up @@ -335,9 +338,9 @@ def cache_artifacts(self):


@ExternalProvider.register_provider_type('python')
def python(urns, packages=()):
def python(urns, provider_base_path, packages=()):
if packages:
return ExternalPythonProvider(urns, packages)
return ExternalPythonProvider(urns, provider_base_path, packages)
else:
return InlineProvider({
name:
Expand All @@ -348,8 +351,18 @@ def python(urns, packages=()):

@ExternalProvider.register_provider_type('pythonPackage')
class ExternalPythonProvider(ExternalProvider):
def __init__(self, urns, packages: Iterable[str]):
super().__init__(urns, PypiExpansionService(packages))
def __init__(self, urns, provider_base_path, packages: Iterable[str]):
def is_path_or_urn(package):
return (
'/' in package or urllib.parse.urlparse(package).scheme or
os.path.exists(package))

super().__init__(
urns,
PypiExpansionService([
_join_url_or_filepath(provider_base_path, package)
if is_path_or_urn(package) else package for package in packages
]))

def available(self):
return True # If we're running this script, we have Python installed.
Expand Down Expand Up @@ -1119,10 +1132,16 @@ def __exit__(self, *args):

@ExternalProvider.register_provider_type('renaming')
class RenamingProvider(Provider):
def __init__(self, transforms, mappings, underlying_provider, defaults=None):
def __init__(
self,
transforms,
provider_base_path,
mappings,
underlying_provider,
defaults=None):
if isinstance(underlying_provider, dict):
underlying_provider = ExternalProvider.provider_from_spec(
underlying_provider)
provider_base_path, underlying_provider)
self._transforms = transforms
self._underlying_provider = underlying_provider
for transform in transforms.keys():
Expand Down Expand Up @@ -1225,41 +1244,67 @@ def cache_artifacts(self):
self._underlying_provider.cache_artifacts()


def flatten_included_provider_specs(
provider_specs: Iterable[Mapping]) -> Iterator[Mapping]:
def _as_list(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return list(func(*args, **kwargs))

return wrapper


def _join_url_or_filepath(base, path):
base_scheme = urllib.parse.urlparse(base, '').scheme
path_scheme = urllib.parse.urlparse(path, base_scheme).scheme
if path_scheme != base_scheme:
return path
elif base_scheme and base_scheme in urllib.parse.uses_relative:
return urllib.parse.urljoin(base, path)
else:
return FileSystems.join(FileSystems.split(base)[0], path)


def _read_url_or_filepath(path):
scheme = urllib.parse.urlparse(path, '').scheme
if scheme and scheme in urllib.parse.uses_netloc:
with urllib.request.urlopen(path) as response:
return response.read()
else:
with FileSystems.open(path) as fin:
return fin.read()


def load_providers(source_path: str) -> Iterable[Provider]:
from apache_beam.yaml.yaml_transform import SafeLineLoader
provider_specs = yaml.load(
_read_url_or_filepath(source_path), Loader=SafeLineLoader)
if not isinstance(provider_specs, list):
raise ValueError(f"Provider file {source_path} must be a list of Providers")
return parse_providers(source_path, provider_specs)


@_as_list
def parse_providers(source_path,
provider_specs: Iterable[Mapping]) -> Iterable[Provider]:
from apache_beam.yaml.yaml_transform import SafeLineLoader
for provider_spec in provider_specs:
if 'include' in provider_spec:
if len(SafeLineLoader.strip_metadata(provider_spec)) != 1:
raise ValueError(
f"When using include, it must be the only parameter: "
f"{provider_spec} "
f"at line {{SafeLineLoader.get_line(provider_spec)}}")
include_uri = provider_spec['include']
f"at {source_path}:{SafeLineLoader.get_line(provider_spec)}")
include_path = _join_url_or_filepath(
source_path, provider_spec['include'])
try:
with urllib.request.urlopen(include_uri) as response:
content = response.read()
except (ValueError, urllib.error.URLError) as exn:
if 'unknown url type' in str(exn):
with FileSystems.open(include_uri) as fin:
content = fin.read()
else:
raise
included_providers = yaml.load(content, Loader=SafeLineLoader)
if not isinstance(included_providers, list):
yield from load_providers(include_path)

except Exception as exn:
raise ValueError(
f"Included file {include_uri} must be a list of Providers "
f"at line {{SafeLineLoader.get_line(provider_spec)}}")
yield from flatten_included_provider_specs(included_providers)
f"Error loading providers from {include_path} included at "
f"{source_path}:{SafeLineLoader.get_line(provider_spec)}\n" +
str(exn)) from exn
else:
yield provider_spec


def parse_providers(provider_specs: Iterable[Mapping]) -> Iterable[Provider]:
return [
ExternalProvider.provider_from_spec(provider_spec)
for provider_spec in flatten_included_provider_specs(provider_specs)
]
yield ExternalProvider.provider_from_spec(source_path, provider_spec)


def merge_providers(*provider_sets) -> Mapping[str, Iterable[Provider]]:
Expand All @@ -1283,9 +1328,6 @@ def standard_providers():
from apache_beam.yaml.yaml_mapping import create_mapping_providers
from apache_beam.yaml.yaml_join import create_join_providers
from apache_beam.yaml.yaml_io import io_providers
with open(os.path.join(os.path.dirname(__file__),
'standard_providers.yaml')) as fin:
standard_providers = yaml.load(fin, Loader=SafeLoader)

return merge_providers(
YamlProviders.create_builtin_provider(),
Expand All @@ -1294,4 +1336,5 @@ def standard_providers():
create_combine_providers(),
create_join_providers(),
io_providers(),
parse_providers(standard_providers))
load_providers(
os.path.join(os.path.dirname(__file__), 'standard_providers.yaml')))
56 changes: 36 additions & 20 deletions sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tempfile
import unittest

import mock
import yaml

import apache_beam as beam
Expand Down Expand Up @@ -88,21 +89,26 @@ def setUpClass(cls):
cls.to_include_nested = os.path.join(
cls.tempdir.name, 'nested_providers.yaml')
with open(cls.to_include_nested, 'w') as fout:
yaml.dump([{'include': cls.to_include}, cls.EXTRA_PROVIDER], fout)
yaml.dump([{'include': './providers.yaml'}, cls.EXTRA_PROVIDER], fout)

@classmethod
def tearDownClass(cls):
cls.tempdir.cleanup()

@mock.patch(
'apache_beam.yaml.yaml_provider.ExternalProvider.provider_from_spec',
lambda _,
x: x)
def test_include_file(self):
flattened = [
SafeLineLoader.strip_metadata(spec)
for spec in yaml_provider.flatten_included_provider_specs([
self.INLINE_PROVIDER,
{
'include': self.to_include
},
])
for spec in yaml_provider.parse_providers(
'', [
self.INLINE_PROVIDER,
{
'include': self.to_include
},
])
]

self.assertEqual([
Expand All @@ -111,15 +117,20 @@ def test_include_file(self):
],
flattened)

@mock.patch(
'apache_beam.yaml.yaml_provider.ExternalProvider.provider_from_spec',
lambda _,
x: x)
def test_include_url(self):
flattened = [
SafeLineLoader.strip_metadata(spec)
for spec in yaml_provider.flatten_included_provider_specs([
self.INLINE_PROVIDER,
{
'include': 'file:///' + self.to_include
},
])
for spec in yaml_provider.parse_providers(
'', [
self.INLINE_PROVIDER,
{
'include': 'file:///' + self.to_include
},
])
]

self.assertEqual([
Expand All @@ -128,15 +139,20 @@ def test_include_url(self):
],
flattened)

@mock.patch(
'apache_beam.yaml.yaml_provider.ExternalProvider.provider_from_spec',
lambda _,
x: x)
def test_nested_include(self):
flattened = [
SafeLineLoader.strip_metadata(spec)
for spec in yaml_provider.flatten_included_provider_specs([
self.INLINE_PROVIDER,
{
'include': self.to_include_nested
},
])
for spec in yaml_provider.parse_providers(
'', [
self.INLINE_PROVIDER,
{
'include': self.to_include_nested
},
])
]

self.assertEqual([
Expand Down Expand Up @@ -195,7 +211,7 @@ def test_yaml_define_provider(self):
result = p | YamlTransform(
pipeline,
providers=yaml_provider.parse_providers(
yaml.load(providers, Loader=SafeLineLoader)))
'', yaml.load(providers, Loader=SafeLineLoader)))
assert_that(
result | beam.Map(lambda x: (x.element, x.power)),
equal_to([(0, 0), (1, 1), (2, 4), (3, 9)]))
Expand Down
9 changes: 6 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,9 @@ def expand_composite_transform(spec, scope):
for (key, value) in empty_if_explicitly_empty(spec['input']).items()
},
spec['transforms'],
# TODO(robertwb): Are scoped providers ever used? Worth supporting?
yaml_provider.merge_providers(
yaml_provider.parse_providers(spec.get('providers', [])),
yaml_provider.parse_providers('', spec.get('providers', [])),
scope.providers),
scope.input_providers)

Expand Down Expand Up @@ -1027,7 +1028,8 @@ def expand_pipeline(
pipeline,
pipeline_spec,
providers=None,
validate_schema='generic' if jsonschema is not None else None):
validate_schema='generic' if jsonschema is not None else None,
pipeline_path=''):
if isinstance(pipeline_spec, str):
pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
# TODO(robertwb): It's unclear whether this gives as good of errors, but
Expand All @@ -1038,5 +1040,6 @@ def expand_pipeline(
return YamlTransform(
pipeline_as_composite(pipeline_spec['pipeline']),
yaml_provider.merge_providers(
yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
yaml_provider.parse_providers(
pipeline_path, pipeline_spec.get('providers', [])),
providers or {})).expand(beam.pvalue.PBegin(pipeline))

0 comments on commit 9e0950c

Please sign in to comment.