Skip to content

Commit

Permalink
Let local jar and package file be relative.
Browse files: browse the repository at this point in the history.
  • Loading branch information
robertwb committed Jan 18, 2025
1 parent 5dfb9a0 commit b2e1eb7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
42 changes: 31 additions & 11 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def create_external_transform(self, urn, args):
self._service)

@classmethod
def provider_from_spec(cls, spec):
def provider_from_spec(cls, source_path, spec):
from apache_beam.yaml.yaml_transform import SafeLineLoader
for required in ('type', 'transforms'):
if required not in spec:
Expand All @@ -225,7 +225,10 @@ def provider_from_spec(cls, spec):
config['version'] = beam_version
if type in cls._provider_types:
try:
result = cls._provider_types[type](urns, **config)
constructor = cls._provider_types[type]
if 'provider_base_path' in inspect.signature(constructor).parameters:
config['provider_base_path'] = source_path
result = constructor(urns, **config)
if not hasattr(result, 'to_json'):
result.to_json = lambda: spec
return result
Expand All @@ -248,12 +251,13 @@ def apply(constructor):


@ExternalProvider.register_provider_type('javaJar')
def java_jar(urns, jar: str):
def java_jar(urns, provider_base_path, jar: str):
if not os.path.exists(jar):
parsed = urllib.parse.urlparse(jar)
if not parsed.scheme or not parsed.netloc:
raise ValueError(f'Invalid path or url: {jar}')
return ExternalJavaProvider(urns, lambda: jar)
return ExternalJavaProvider(
urns, lambda: _join_url_or_filepath(provider_base_path, jar))


@ExternalProvider.register_provider_type('mavenJar')
Expand Down Expand Up @@ -334,9 +338,9 @@ def cache_artifacts(self):


@ExternalProvider.register_provider_type('python')
def python(urns, packages=()):
def python(urns, provider_base_path, packages=()):
if packages:
return ExternalPythonProvider(urns, packages)
return ExternalPythonProvider(urns, provider_base_path, packages)
else:
return InlineProvider({
name:
Expand All @@ -347,8 +351,18 @@ def python(urns, packages=()):

@ExternalProvider.register_provider_type('pythonPackage')
class ExternalPythonProvider(ExternalProvider):
def __init__(self, urns, packages: Iterable[str]):
super().__init__(urns, PypiExpansionService(packages))
def __init__(self, urns, provider_base_path, packages: Iterable[str]):
def is_path_or_urn(package):
return (
'/' in package or urllib.parse.urlparse(package).scheme or
os.path.exists(package))

super().__init__(
urns,
PypiExpansionService([
_join_url_or_filepath(provider_base_path, package)
if is_path_or_urn(package) else package for package in packages
]))

def available(self):
return True # If we're running this script, we have Python installed.
Expand Down Expand Up @@ -1118,10 +1132,16 @@ def __exit__(self, *args):

@ExternalProvider.register_provider_type('renaming')
class RenamingProvider(Provider):
def __init__(self, transforms, mappings, underlying_provider, defaults=None):
def __init__(
self,
transforms,
provider_base_path,
mappings,
underlying_provider,
defaults=None):
if isinstance(underlying_provider, dict):
underlying_provider = ExternalProvider.provider_from_spec(
underlying_provider)
provider_base_path, underlying_provider)
self._transforms = transforms
self._underlying_provider = underlying_provider
for transform in transforms.keys():
Expand Down Expand Up @@ -1284,7 +1304,7 @@ def parse_providers(source_path,
f"{source_path}:{SafeLineLoader.get_line(provider_spec)}\n" +
str(exn)) from exn
else:
yield ExternalProvider.provider_from_spec(provider_spec)
yield ExternalProvider.provider_from_spec(source_path, provider_spec)


def merge_providers(*provider_sets) -> Mapping[str, Iterable[Provider]]:
Expand Down
9 changes: 6 additions & 3 deletions sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def tearDownClass(cls):

@mock.patch(
'apache_beam.yaml.yaml_provider.ExternalProvider.provider_from_spec',
lambda x: x)
lambda _,
x: x)
def test_include_file(self):
flattened = [
SafeLineLoader.strip_metadata(spec)
Expand All @@ -118,7 +119,8 @@ def test_include_file(self):

@mock.patch(
'apache_beam.yaml.yaml_provider.ExternalProvider.provider_from_spec',
lambda x: x)
lambda _,
x: x)
def test_include_url(self):
flattened = [
SafeLineLoader.strip_metadata(spec)
Expand All @@ -139,7 +141,8 @@ def test_include_url(self):

@mock.patch(
'apache_beam.yaml.yaml_provider.ExternalProvider.provider_from_spec',
lambda x: x)
lambda _,
x: x)
def test_nested_include(self):
flattened = [
SafeLineLoader.strip_metadata(spec)
Expand Down

0 comments on commit b2e1eb7

Please sign in to comment.