Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize umap with process pools #221

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/fibad/data_sets/inference_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
from collections.abc import Generator
from multiprocessing import Pool
from pathlib import Path
from typing import Optional, Union

Expand Down Expand Up @@ -197,6 +198,7 @@

self.all_ids = np.array([], dtype=np.int64)
self.all_batch_nums = np.array([], dtype=np.int64)
self.writer_pool = Pool()

Check warning on line 201 in src/fibad/data_sets/inference_dataset.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/data_sets/inference_dataset.py#L201

Added line #L201 was not covered by tests

def write_batch(self, ids: np.ndarray, tensors: list[np.ndarray]):
"""Write a batch of tensors into the dataset. This writes the whole batch immediately.
Expand Down Expand Up @@ -226,7 +228,10 @@
if savepath.exists():
RuntimeError(f"Writing objects in batch {self.batch_index} but {filename} already exists.")

np.save(savepath, structured_batch, allow_pickle=False)
self.writer_pool.apply_async(

Check warning on line 231 in src/fibad/data_sets/inference_dataset.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/data_sets/inference_dataset.py#L231

Added line #L231 was not covered by tests
func=np.save, args=(savepath, structured_batch), kwds={"allow_pickle": False}
)

self.all_ids = np.append(self.all_ids, ids)
self.all_batch_nums = np.append(self.all_batch_nums, np.full(batch_len, self.batch_index))

Expand All @@ -236,6 +241,11 @@
"""Writes out the batch index built up by this object over multiple write_batch calls.
See save_batch_index for details.
"""
# First ensure we are done writing out all batches
self.writer_pool.close()
self.writer_pool.join()

Check warning on line 246 in src/fibad/data_sets/inference_dataset.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/data_sets/inference_dataset.py#L245-L246

Added lines #L245 - L246 were not covered by tests

# Then write out the batch index.
InferenceDataSetWriter.save_batch_index(self.result_dir, self.all_ids, self.all_batch_nums)

@staticmethod
Expand Down
76 changes: 61 additions & 15 deletions src/fibad/verbs/umap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import pickle
import warnings
from argparse import ArgumentParser, Namespace
from multiprocessing import cpu_count
from pathlib import Path
from typing import Optional, Union

Expand Down Expand Up @@ -60,13 +62,21 @@
None
The method does not return anything but saves the UMAP representations to disk.
"""
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
return self._run(input_dir)

Check warning on line 67 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L65-L67

Added lines #L65 - L67 were not covered by tests

def _run(self, input_dir: Optional[Union[Path, str]] = None):
"""See run()"""
from multiprocessing import Pool

Check warning on line 71 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L71

Added line #L71 was not covered by tests

import umap
from tqdm.auto import tqdm

from fibad.config_utils import create_results_dir
from fibad.data_sets.inference_dataset import InferenceDataSet, InferenceDataSetWriter

reducer = umap.UMAP(**self.config["umap.UMAP"])
self.reducer = umap.UMAP(**self.config["umap.UMAP"])

Check warning on line 79 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L79

Added line #L79 was not covered by tests

# Set up the results directory where we will store our umapped output
results_dir = create_results_dir(self.config, "umap")
Expand All @@ -87,29 +97,65 @@
data_sample = inference_results[index_choices].numpy().reshape((sample_size, -1))

# Fit a single reducer on the sampled data
reducer.fit(data_sample)
self.reducer.fit(data_sample)

Check warning on line 100 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L100

Added line #L100 was not covered by tests

# Save the reducer to our results directory
with open(results_dir / "umap.pickle", "wb") as f:
pickle.dump(reducer, f)
pickle.dump(self.reducer, f)

Check warning on line 104 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L104

Added line #L104 was not covered by tests

# Run all data through the reducer in batches, writing it out as we go.
batch_size = self.config["data_loader"]["batch_size"]
num_batches = int(np.ceil(total_length / batch_size))

all_indexes = np.arange(0, total_length)
all_ids = np.array([int(i) for i in inference_results.ids()])
for batch_indexes in tqdm(
np.array_split(all_indexes, num_batches),
desc="Creating Lower Dimensional Representation using UMAP",
total=num_batches,
):
# We flatten all dimensions of the input array except the dimension
# corresponding to batch elements. This ensures that all inputs to
# the UMAP algorithm are flattend per input item in the batch
batch = inference_results[batch_indexes].reshape(len(batch_indexes), -1)
batch_ids = all_ids[batch_indexes]
transformed_batch = reducer.transform(batch)
umap_results.write_batch(batch_ids, transformed_batch)

# Process pool to do all the transforms
with Pool(processes=cpu_count()) as pool:

Check warning on line 114 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L114

Added line #L114 was not covered by tests
# Generator expression that gives a batch tuple composed of:
# batch ids, inference results
args = (

Check warning on line 117 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L117

Added line #L117 was not covered by tests
(
all_ids[batch_indexes],
# We flatten all dimensions of the input array except the dimension
# corresponding to batch elements. This ensures that all inputs to
# the UMAP algorithm are flattend per input item in the batch
inference_results[batch_indexes].reshape(len(batch_indexes), -1),
)
for batch_indexes in np.array_split(all_indexes, num_batches)
)

# iterate over the mapped results to write out the umapped points
# imap returns results as they complete so writing should complete in parallel for large datasets
for batch_ids, transformed_batch in tqdm(

Check warning on line 130 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L130

Added line #L130 was not covered by tests
pool.imap(self._transform_batch, args),
desc="Creating LowerDimensional Representation using UMAP:",
total=num_batches,
):
logger.debug("Writing a batch out async...")
umap_results.write_batch(batch_ids, transformed_batch)

Check warning on line 136 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L135-L136

Added lines #L135 - L136 were not covered by tests

umap_results.write_index()

def _transform_batch(self, batch_tuple: tuple):
"""Private helper to transform a single batch

Parameters
----------
batch_tuple : tuple()
first element is the IDs of the batch as a numpy array
second element is the inference results to transform as a numpy array with shape (batch_len, N)
where N is the total number of dimensions in the inference result. Caller flattens all inference
result axes for us.

Returns
-------
tuple
first element is the ids of the batch as a numpy array
second element is the results of running the umap transform on the input as a numpy array.
"""
batch_ids, batch = batch_tuple
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
logger.debug("Transforming a batch ...")
return (batch_ids, self.reducer.transform(batch))

Check warning on line 161 in src/fibad/verbs/umap.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/verbs/umap.py#L157-L161

Added lines #L157 - L161 were not covered by tests
Loading