Skip to content

Commit

Permalink
Merge pull request #109 from flatironinstitute/108-dask-slurmrunner-u…
Browse files Browse the repository at this point in the history
…pdate

108 dask slurmrunner update
  • Loading branch information
geoffwoollard authored Jan 17, 2025
2 parents 526c549 + 2705edd commit 2e5fc12
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dependencies = [
"ecos",
"dask",
"dask[distributed]",
"dask_hpc_runner @ git+https://github.com/jacobtomlinson/dask-hpc-runner.git@main",
"dask-jobqueue",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask import delayed, compute
from dask.distributed import Client
from dask.diagnostics import ProgressBar
from dask_hpc_runner import SlurmRunner
from dask_jobqueue.slurm import SLURMRunner

from cryo_challenge._preprocessing.fourier_utils import downsample_volume

Expand Down Expand Up @@ -259,6 +259,9 @@ def parse_args():
parser.add_argument(
"--n_i", type=int, default=80, help="Number of volumes in set i"
)
parser.add_argument(
"--n_j", type=int, default=80, help="Number of volumes in set j"
)
parser.add_argument(
"--n_downsample_pix", type=int, default=20, help="Number of downsample pixels"
)
Expand Down Expand Up @@ -378,7 +381,7 @@ def main(args):
submission = torch.load(fname, weights_only=False)
volumes = submission["volumes"].to(torch_dtype)
volumes_i = volumes[: args.n_i]
volumes_j = volumes
volumes_j = volumes[: args.n_j]
n_downsample_pix = args.n_downsample_pix
top_k = args.top_k
exponent = args.exponent
Expand Down Expand Up @@ -411,7 +414,7 @@ def main(args):
args = parse_args()
if args.slurm:
job_id = os.environ["SLURM_JOB_ID"]
with SlurmRunner(
with SLURMRunner(
scheduler_file=args.scheduler_file,
) as runner:
# The runner object contains the scheduler address and can be passed directly to a client
Expand Down
4 changes: 2 additions & 2 deletions src/cryo_challenge/_map_to_map/map_to_map_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import mrcfile
import numpy as np
from dask.distributed import Client
from dask_hpc_runner import SlurmRunner
from dask_jobqueue.slurm import SLURMRunner

from .gromov_wasserstein.gw_weighted_voxels import get_distance_matrix_dask_gw

Expand Down Expand Up @@ -503,7 +503,7 @@ def get_distance_matrix(self, maps1, maps2, global_store_of_running_results):
scheduler_file = os.path.join(
extra_params["scheduler_file_dir"], f"scheduler-{job_id}.json"
)
with SlurmRunner(
with SLURMRunner(
scheduler_file=scheduler_file,
) as runner:
# The runner object contains the scheduler address and can be passed directly to a client
Expand Down

0 comments on commit 2e5fc12

Please sign in to comment.