Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make explicit arg for pip args #2403

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def create_job(
env_vars: Optional[
Union[Dict[str, str], List[Union[models.V1EnvVar, models.V1EnvVar]]]
] = None,
pip_args: Optional[List[str]] = None,
):
"""Create the Training Job.
Job can be created using one of the following options:
Expand Down Expand Up @@ -418,7 +419,9 @@ def create_job(
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
or a kubernetes.client.models.V1EnvFromSource (documented here:
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
pip_args: List of args to pass to pip install that applies to all packages specified in
packages_to_install. For a full list of args, see the pip documentation
https://pip.pypa.io/en/stable/cli/pip_install/
Raises:
ValueError: Invalid input parameters.
TimeoutError: Timeout to create Job.
Expand Down Expand Up @@ -486,6 +489,7 @@ def create_job(
train_func_parameters=parameters,
packages_to_install=packages_to_install,
pip_index_url=pip_index_url,
pip_args=pip_args,
)

# Get Training Container template.
Expand Down
9 changes: 5 additions & 4 deletions sdk/python/kubeflow/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,21 @@ def has_condition(conditions: List[models.V1JobCondition], condition_type: str)


def get_script_for_python_packages(
packages_to_install: List[str], pip_index_url: str
packages_to_install: List[str], pip_index_url: str, pip_args: Optional[List[str]]
) -> str:
"""
Get init script to install Python packages from the given pip index URL.
"""
packages_str = " ".join([str(package) for package in packages_to_install])

pip_args_str = " ".join(pip_args) if pip_args is not None else ""
script_for_python_packages = textwrap.dedent(
f"""
if ! [ -x "$(command -v pip)" ]; then
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
fi
PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
--no-warn-script-location --index-url {pip_index_url} {packages_str}
--no-warn-script-location --index-url {pip_index_url} {pip_args_str} {packages_str}
"""
)

Expand All @@ -137,6 +137,7 @@ def get_command_using_train_func(
train_func_parameters: Optional[Dict[str, Any]] = None,
packages_to_install: Optional[List[str]] = None,
pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL,
pip_args: Optional[List[str]] = None
) -> Tuple[List[str], List[str]]:
"""
Get container args and command from the given training function and parameters.
Expand Down Expand Up @@ -180,7 +181,7 @@ def get_command_using_train_func(
# Install Python packages if that is required.
if packages_to_install is not None:
exec_script = (
get_script_for_python_packages(packages_to_install, pip_index_url)
get_script_for_python_packages(packages_to_install, pip_index_url, pip_args)
+ exec_script
)

Expand Down