Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use yaml.dump instead of the old dump command, strip the doubled .cwl extension, and fix import issues #287

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 79 additions & 37 deletions cwl_utils/graph_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,39 @@
import json
import os
import sys
from typing import IO, TYPE_CHECKING, Any, List, MutableMapping, Set, Union, cast
from io import TextIOWrapper
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
List,
MutableMapping,
Set,
Union,
cast,
Optional, TextIO,
)
import logging
import re

from cwlformat.formatter import stringify_dict
from ruamel.yaml.dumper import RoundTripDumper
from ruamel.yaml.main import YAML, dump
from ruamel.yaml.comments import Format
from ruamel.yaml.main import YAML
from ruamel.yaml.representer import RoundTripRepresenter
from schema_salad.sourceline import SourceLine, add_lc_filename

from cwl_utils.loghandler import _logger as _cwlutilslogger

if TYPE_CHECKING:
from _typeshed import StrPath

_logger = logging.getLogger("cwl-graph-split") # pylint: disable=invalid-name
defaultStreamHandler = logging.StreamHandler() # pylint: disable=invalid-name
_logger.addHandler(defaultStreamHandler)
_logger.setLevel(logging.INFO)
_cwlutilslogger.setLevel(100)


def arg_parser() -> argparse.ArgumentParser:
"""Build the argument parser."""
Expand Down Expand Up @@ -81,11 +103,11 @@ def run(args: List[str]) -> int:


def graph_split(
sourceIO: IO[str],
output_dir: "StrPath",
output_format: str,
mainfile: str,
pretty: bool,
sourceIO: IO[str],
output_dir: "StrPath",
output_format: str,
mainfile: str,
pretty: bool,
) -> None:
"""Loop over the provided packed CWL document and split it up."""
yaml = YAML(typ="rt")
Expand All @@ -99,8 +121,15 @@ def graph_split(

version = source.pop("cwlVersion")

# Check outdir parent exists
if not Path(output_dir).parent.is_dir():
raise NotADirectoryError(f"Parent directory of {output_dir} does not exist")
# If output_dir is not a directory, create it
if not Path(output_dir).is_dir():
os.mkdir(output_dir)

def my_represent_none(
self: Any, data: Any
self: Any, data: Any
) -> Any: # pylint: disable=unused-argument
"""Force clean representation of 'null'."""
return self.represent_scalar("tag:yaml.org,2002:null", "null")
Expand All @@ -110,7 +139,7 @@ def my_represent_none(
for entry in source["$graph"]:
entry_id = entry.pop("id").lstrip("#")
entry["cwlVersion"] = version
imports = rewrite(entry, entry_id)
imports = rewrite(entry, entry_id, Path(output_dir))
if imports:
for import_name in imports:
rewrite_types(entry, f"#{import_name}", False)
Expand All @@ -120,46 +149,46 @@ def my_represent_none(
else:
entry_id = mainfile

output_file = os.path.join(output_dir, entry_id + ".cwl")
output_file = Path(output_dir) / (re.sub(".cwl$", "", entry_id) + ".cwl")
if output_format == "json":
json_dump(entry, output_file)
elif output_format == "yaml":
yaml_dump(entry, output_file, pretty)


def rewrite(document: Any, doc_id: str) -> Set[str]:
def rewrite(document: Any, doc_id: str, output_dir: Path, pretty: Optional[bool] = False) -> Set[str]:
"""Rewrite the given element from the CWL $graph."""
imports = set()
if isinstance(document, list) and not isinstance(document, str):
for entry in document:
imports.update(rewrite(entry, doc_id))
imports.update(rewrite(entry, doc_id, output_dir, pretty))
elif isinstance(document, dict):
this_id = document["id"] if "id" in document else None
for key, value in document.items():
with SourceLine(document, key, Exception):
if key == "run" and isinstance(value, str) and value[0] == "#":
document[key] = f"{value[1:]}.cwl"
document[key] = f"{re.sub('.cwl$', '', value[1:])}.cwl"
elif key in ("id", "outputSource") and value.startswith("#" + doc_id):
document[key] = value[len(doc_id) + 2 :]
document[key] = value[len(doc_id) + 2:]
elif key == "out" and isinstance(value, list):

def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
if isinstance(entry, MutableMapping):
if entry["id"].startswith(this_id):
entry["id"] = cast(str, entry["id"])[len(this_id) + 1 :]
entry["id"] = cast(str, entry["id"])[len(this_id) + 1:]
return entry
elif isinstance(entry, str):
if this_id and entry.startswith(this_id):
return entry[len(this_id) + 1 :]
return entry[len(this_id) + 1:]
return entry
raise Exception(f"{entry} is neither a dictionary nor string.")

document[key][:] = [rewrite_id(entry) for entry in value]
elif key in ("source", "scatter", "items", "format"):
if (
isinstance(value, str)
and value.startswith("#")
and "/" in value
isinstance(value, str)
and value.startswith("#")
and "/" in value
):
referrant_file, sub = value[1:].split("/", 1)
if referrant_file == doc_id:
Expand All @@ -170,22 +199,22 @@ def rewrite_id(entry: Any) -> Union[MutableMapping[Any, Any], str]:
new_sources = list()
for entry in value:
if entry.startswith("#" + doc_id):
new_sources.append(entry[len(doc_id) + 2 :])
new_sources.append(entry[len(doc_id) + 2:])
else:
new_sources.append(entry)
document[key] = new_sources
elif key == "$import":
rewrite_import(document)
elif key == "class" and value == "SchemaDefRequirement":
return rewrite_schemadef(document)
return rewrite_schemadef(document, output_dir, pretty)
else:
imports.update(rewrite(value, doc_id))
imports.update(rewrite(value, doc_id, output_dir, pretty))
return imports


def rewrite_import(document: MutableMapping[str, Any]) -> None:
"""Adjust the $import directive."""
external_file = document["$import"].split("/")[0][1:]
external_file = document["$import"].split("/")[0].lstrip("#")
document["$import"] = external_file


Expand All @@ -201,7 +230,7 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
if key == name:
if isinstance(value, str) and value.startswith(entry_file):
if sameself:
field[key] = value[len(entry_file) + 1 :]
field[key] = value[len(entry_file) + 1:]
else:
field[key] = "{d[0]}#{d[1]}".format(
d=value[1:].split("/", 1)
Expand All @@ -213,19 +242,19 @@ def rewrite_types(field: Any, entry_file: str, sameself: bool) -> None:
rewrite_types(entry, entry_file, sameself)


def rewrite_schemadef(document: MutableMapping[str, Any]) -> Set[str]:
def rewrite_schemadef(document: MutableMapping[str, Any], output_dir: Path, pretty: Optional[bool] = False) -> Set[str]:
"""Dump the schemadefs to their own file."""
for entry in document["types"]:
if "$import" in entry:
rewrite_import(entry)
elif "name" in entry and "/" in entry["name"]:
entry_file, entry["name"] = entry["name"].split("/")
entry_file, entry["name"] = entry["name"].lstrip("#").split("/")
for field in entry["fields"]:
field["name"] = field["name"].split("/")[2]
rewrite_types(field, entry_file, True)
with open(entry_file[1:], "a", encoding="utf-8") as entry_handle:
dump([entry], entry_handle, Dumper=RoundTripDumper)
entry["$import"] = entry_file[1:]
with open(output_dir / entry_file, "a", encoding="utf-8") as entry_handle:
yaml_dump(entry, entry_handle, pretty)
entry["$import"] = entry_file
del entry["name"]
del entry["type"]
del entry["fields"]
Expand All @@ -251,20 +280,33 @@ def json_dump(entry: Any, output_file: str) -> None:
json.dump(entry, result_handle, indent=4)


def yaml_dump(entry: Any, output_file: str, pretty: bool) -> None:
def yaml_dump(entry: Any, output_file_or_handle: Optional[Union[str, Path, TextIOWrapper, TextIO]], pretty: bool) -> None:
"""Output object as YAML."""
yaml = YAML(typ="rt")
yaml = YAML(typ="rt", pure=True)
yaml.default_flow_style = False
yaml.map_indent = 4
yaml.sequence_indent = 2
with open(output_file, "w", encoding="utf-8") as result_handle:
yaml.indent = 4
yaml.block_seq_indent = 2

if isinstance(output_file_or_handle, (str, Path)):
with open(output_file_or_handle, "w", encoding="utf-8") as result_handle:
if pretty:
result_handle.write(stringify_dict(entry))
else:
yaml.dump(
entry,
result_handle
)
elif isinstance(output_file_or_handle, (TextIOWrapper, TextIO)):
if pretty:
result_handle.write(stringify_dict(entry))
output_file_or_handle.write(stringify_dict(entry))
else:
yaml.dump(
entry,
result_handle,
output_file_or_handle
)
else:
raise ValueError(
f"output_file_or_handle must be a string or a file handle but got {type(output_file_or_handle)}")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ requests
schema-salad >= 8.5, < 9
ruamel.yaml >= 0.17.6, < 0.19
importlib_resources;python_version<'3.9'
cwlformat >= 2022.2.18
Loading