Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #21 and add export functionality. #22

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions pdm-plugin-torch/pdm_plugin_torch/config.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,77 @@
"""Plugin configuration."""
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(frozen=True)
class Configuration:
"""
Plugin configuration.

Attributes:
dependencies: list of top level dependencies.
enable_cpu: CPU feature flag.
enable_cuda: CUDA feature flag.
enable_rocm: ROCm feature flag.
cuda_versions: list of versions for CUDA to support.
rocm_versions: list of ROCm versions to support.
lockfile: path to the lock file to use.
"""

# Dependency list.
dependencies: list[str]
# Feature flags.
enable_cpu: bool = False

enable_cuda: bool = False
cuda_versions: list[str] = field(default_factory=list)

enable_rocm: bool = False
# Version identifiers for the different possible versioned dependencies.
cuda_versions: list[str] = field(default_factory=list)
rocm_versions: list[str] = field(default_factory=list)

# Lockfile configuration.
lockfile: str = "torch.lock"

@staticmethod
def from_toml(data: dict[str, str | list[str] | bool]) -> "Configuration":
"""
Create a configuration object from a pyproject.toml configuration file.

Args:
data: parsed TOML of the pyproject file.

Returns:
Configuration object.
"""
fixed_dashes = {k.replace("-", "_"): v for (k, v) in data.items()}

return Configuration(**fixed_dashes)

@property
def variants(self):
resolves = {}
def variants(self) -> dict[str, tuple[str, str]]:
"""
Get resolution URL and build identifier for all configured variants for the plugin.

Returns:
A dictionary of torch build alternatives to a tuple of
(resolution URL, build identifier).
"""
resolves = {}
if self.enable_cpu:
# We can omit the build identifier for the CPU only versions
# since the resolution at the CPU URL works correctly for all
# versions only without a tag (see the MacOS builds at
# https://download.pytorch.org/whl/cpu).
resolves["cpu"] = ("https://download.pytorch.org/whl/cpu", "")
if self.enable_cuda:
for cuda_version in self.cuda_versions:
resolves[cuda_version] = (
f"https://download.pytorch.org/whl/{cuda_version}/",
f"+{cuda_version}",
)

if self.enable_rocm:
for rocm_version in self.rocm_versions:
resolves[f"rocm{rocm_version}"] = (
"https://download.pytorch.org/whl/",
f"+rocm{rocm_version}",
)

if self.enable_cpu:
resolves["cpu"] = ("https://download.pytorch.org/whl/cpu", "+cpu")

return resolves
Loading