Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Remove AxisName.getattr_from from ExperimentAxisQuery #3557

Merged
merged 6 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 51 additions & 87 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@
Callable,
Dict,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
cast,
overload,
)

import attrs
Expand Down Expand Up @@ -88,30 +85,7 @@ class AxisName(enum.Enum):

@property
def value(self) -> Literal["obs", "var"]:
return super().value # type: ignore[no-any-return]

    @overload
    def getattr_from(self, __source: _HasObsVar[_T]) -> _T: ...

    @overload
    def getattr_from(
        self, __source: Any, *, pre: Literal[""], suf: Literal[""]
    ) -> object: ...

    @overload
    def getattr_from(
        self, __source: Any, *, pre: str = ..., suf: str = ...
    ) -> object: ...

    def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object:
        """Equivalent to ``something.<pre><obs/var><suf>``."""
        # Dynamic attribute dispatch keyed on the axis: self.value is
        # "obs" or "var", so e.g. pre="by_" resolves by_obs / by_var.
        return getattr(__source, pre + self.value + suf)

    def getitem_from(
        self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
    ) -> _T:
        """Equivalent to ``something[pre + "obs"/"var" + suf]``."""
        # Mapping analogue of getattr_from: builds the key from the axis
        # name ("obs"/"var") wrapped by the optional pre/suf fragments.
        return __source[pre + self.value + suf]
return super().value


@attrs.define
Expand Down Expand Up @@ -389,7 +363,7 @@ def obs_scene_ids(self) -> pa.Array:
)

full_table = obs_scene.read(
coords=((AxisName.OBS.getattr_from(self._joinids), slice(None))),
coords=(self._joinids.obs, slice(None)),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
Expand All @@ -416,7 +390,7 @@ def var_scene_ids(self) -> pa.Array:
)

full_table = var_scene.read(
coords=((AxisName.VAR.getattr_from(self._joinids), slice(None))),
coords=(self._joinids.var, slice(None)),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
Expand Down Expand Up @@ -477,6 +451,8 @@ def to_anndata(
obs_table, var_table = tp.map(
self._read_axis_dataframe,
(AxisName.OBS, AxisName.VAR),
(self._obs_df, self._var_df),
(self._matrix_axis_query.obs, self._matrix_axis_query.var),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
Expand All @@ -496,19 +472,43 @@ def to_anndata(
x_future = x_matrices.pop(X_name)

obsm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key)
key: tp.submit(
_read_inner_ndarray,
self._get_annotation_layer("obsm", key),
obs_joinids,
self.indexer.by_obs,
)
for key in obsm_layers
}
varm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key)
key: tp.submit(
_read_inner_ndarray,
self._get_annotation_layer("varm", key),
var_joinids,
self.indexer.by_var,
)
for key in varm_layers
}
obsp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key)
key: tp.submit(
_read_as_csr,
self._get_annotation_layer("obsp", key),
obs_joinids,
obs_joinids,
self.indexer.by_obs,
self.indexer.by_obs,
)
for key in obsp_layers
}
varp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key)
key: tp.submit(
_read_as_csr,
self._get_annotation_layer("varp", key),
var_joinids,
var_joinids,
self.indexer.by_var,
self.indexer.by_var,
)
for key in varp_layers
}

Expand Down Expand Up @@ -778,15 +778,13 @@ def __exit__(self, *_: Any) -> None:
def _read_axis_dataframe(
self,
axis: AxisName,
axis_df: DataFrame,
axis_query: AxisQuery,
axis_column_names: AxisColumnNames,
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
column_names = axis_column_names.get(axis.value)

axis_df = axis.getattr_from(self, pre="_", suf="_df")
assert isinstance(axis_df, DataFrame)
axis_query = axis.getattr_from(self._matrix_axis_query)

# If we can cache join IDs, prepare to add them to the cache.
joinids_cached = self._joinids._is_cached(axis)
query_columns = column_names
Expand Down Expand Up @@ -859,56 +857,6 @@ def _get_annotation_layer(
)
return layer

    def _convert_to_ndarray(
        self, axis: AxisName, table: pa.Table, n_row: int, n_col: int
    ) -> npt.NDArray[np.float32]:
        """Densify a COO-style Arrow table into an ``(n_row, n_col)`` float32 array.

        ``table`` carries ``soma_dim_0``/``soma_dim_1`` coordinates and
        ``soma_data`` values; ``axis`` selects which indexer maps the
        ``soma_dim_0`` joinids onto 0-based output rows.
        """
        # Dynamically resolve self.indexer.by_obs / self.indexer.by_var;
        # the cast gives the untyped getattr_from result a callable type.
        indexer = cast(
            Callable[[Numpyable], npt.NDArray[np.intp]],
            axis.getattr_from(self.indexer, pre="by_"),
        )
        idx = indexer(table["soma_dim_0"])
        # Scatter values into a flat buffer then reshape to 2-D:
        # destination offset = row * n_col + column.
        z: npt.NDArray[np.float32] = np.zeros(n_row * n_col, dtype=np.float32)
        np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
        return z.reshape(n_row, n_col)

    def _axisp_inner_sparray(
        self,
        axis: AxisName,
        layer: str,
    ) -> sp.csr_matrix:
        """Read a pairwise (``obsp``/``varp``) layer as a CSR matrix.

        Both dimensions are sliced and re-indexed by the same axis
        joinids, since pairwise layers are square over a single axis.
        """
        # Axis joinids: self._joinids.obs or self._joinids.var.
        joinids = axis.getattr_from(self._joinids)
        # Resolve self.indexer.by_obs / self.indexer.by_var; the cast
        # gives the untyped getattr_from result a callable type.
        indexer = cast(
            Callable[[Numpyable], npt.NDArray[np.intp]],
            axis.getattr_from(self.indexer, pre="by_"),
        )
        # Pairwise layers live in the "obsp"/"varp" collection.
        annotation_name = f"{axis.value}p"
        return _read_as_csr(
            self._get_annotation_layer(annotation_name, layer),
            joinids,
            joinids,
            indexer,
            indexer,
        )

    def _axism_inner_ndarray(
        self,
        axis: AxisName,
        layer: str,
    ) -> npt.NDArray[np.float32]:
        """Read an ``obsm``/``varm`` layer and densify it to a float32 ndarray."""
        # Axis joinids: self._joinids.obs or self._joinids.var.
        joinids = axis.getattr_from(self._joinids)
        # Per-axis matrix layers live in the "obsm"/"varm" collection.
        annotation_name = f"{axis.value}m"
        table = (
            self._get_annotation_layer(annotation_name, layer)
            .read((joinids, slice(None)))
            .tables()
            .concat()
        )

        n_row = len(joinids)
        # NOTE(review): column count is inferred from the distinct dim-1
        # values present in the result — assumes every column of the layer
        # has at least one stored entry for the queried rows; confirm.
        n_col = len(table["soma_dim_1"].unique())

        return self._convert_to_ndarray(axis, table, n_row, n_col)

@property
def _obs_df(self) -> DataFrame:
return self.experiment.obs
Expand Down Expand Up @@ -995,6 +943,22 @@ def load_joinids(df: DataFrame, axq: AxisQuery) -> pa.IntegerArray:
return tbl.column("soma_joinid").combine_chunks()


def _read_inner_ndarray(
matrix: SparseNDArray,
joinids: pa.IntegerArray,
indexer: Callable[[Numpyable], npt.NDArray[np.intp]],
) -> npt.NDArray[np.float32]:
table = matrix.read((joinids, slice(None))).tables().concat()

n_row = len(joinids)
n_col = len(table["soma_dim_1"].unique())

idx = indexer(table["soma_dim_0"])
z: npt.NDArray[np.float32] = np.zeros(n_row * n_col, dtype=np.float32)
np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
return z.reshape(n_row, n_col)


def _read_as_csr(
matrix: SparseNDArray,
d0_joinids_arr: pa.IntegerArray,
Expand Down
14 changes: 0 additions & 14 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from tiledbsoma._collection import CollectionBase
from tiledbsoma._experiment import Experiment
from tiledbsoma._query import AxisName
from tiledbsoma.experiment_query import X_as_series

from tests._util import raises_no_typeguard
Expand Down Expand Up @@ -965,16 +964,3 @@ class IHaveObsVarStuff:
var: int
the_obs_suf: str
the_var_suf: str


def test_axis_helpers() -> None:
    """Exercise AxisName.getattr_from / getitem_from lookup helpers."""
    obj = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
    # Bare lookup resolves to the .obs / .var attributes.
    assert AxisName.OBS.getattr_from(obj) == 1
    assert AxisName.VAR.getattr_from(obj) == 2
    # pre/suf wrap the axis name: the_obs_suf / the_var_suf.
    assert AxisName.OBS.getattr_from(obj, pre="the_", suf="_suf") == "observe"
    assert AxisName.VAR.getattr_from(obj, pre="the_", suf="_suf") == "vary"
    mapping = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
    assert AxisName.OBS.getitem_from(mapping) == "erve"
    assert AxisName.VAR.getitem_from(mapping) == "y"
    assert AxisName.OBS.getitem_from(mapping, pre="i_", suf="cure") == "hide"
    assert AxisName.VAR.getitem_from(mapping, pre="i_", suf="cure") == "???"