Skip to content

Commit

Permalink
Finish removing getattr_from from AttrName
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-dark committed Jan 16, 2025
1 parent f5a4c07 commit 7c934f8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 49 deletions.
40 changes: 5 additions & 35 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
Callable,
Dict,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
overload,
)

import attrs
Expand Down Expand Up @@ -87,30 +85,7 @@ class AxisName(enum.Enum):

@property
def value(self) -> Literal["obs", "var"]:
return super().value # type: ignore[no-any-return]

@overload
def getattr_from(self, __source: _HasObsVar[_T]) -> _T: ...

@overload
def getattr_from(
self, __source: Any, *, pre: Literal[""], suf: Literal[""]
) -> object: ...

@overload
def getattr_from(
self, __source: Any, *, pre: str = ..., suf: str = ...
) -> object: ...

def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object:
"""Equivalent to ``something.<pre><obs/var><suf>``."""
return getattr(__source, pre + self.value + suf)

def getitem_from(
self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
) -> _T:
"""Equivalent to ``something[pre + "obs"/"var" + suf]``."""
return __source[pre + self.value + suf]
return super().value


@attrs.define
Expand Down Expand Up @@ -476,6 +451,8 @@ def to_anndata(
obs_table, var_table = tp.map(
self._read_axis_dataframe,
(AxisName.OBS, AxisName.VAR),
(self._obs_df, self._var_df),
(self._matrix_axis_query.obs, self._matrix_axis_query.var),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
Expand Down Expand Up @@ -801,20 +778,13 @@ def __exit__(self, *_: Any) -> None:
def _read_axis_dataframe(
self,
axis: AxisName,
axis_df: DataFrame,
axis_query: AxisQuery,
axis_column_names: AxisColumnNames,
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
column_names = axis_column_names.get(axis.value)

if axis.value == "obs":
axis_df = self._obs_df
axis_query = self._matrix_axis_query.obs
else:
axis_df = self._var_df
axis_query = self._matrix_axis_query.var

assert isinstance(axis_df, DataFrame)

# If we can cache join IDs, prepare to add them to the cache.
joinids_cached = self._joinids._is_cached(axis)
query_columns = column_names
Expand Down
14 changes: 0 additions & 14 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from tiledbsoma._collection import CollectionBase
from tiledbsoma._experiment import Experiment
from tiledbsoma._query import AxisName
from tiledbsoma.experiment_query import X_as_series

from tests._util import raises_no_typeguard
Expand Down Expand Up @@ -965,16 +964,3 @@ class IHaveObsVarStuff:
var: int
the_obs_suf: str
the_var_suf: str


def test_axis_helpers() -> None:
thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
assert 1 == AxisName.OBS.getattr_from(thing)
assert 2 == AxisName.VAR.getattr_from(thing)
assert "observe" == AxisName.OBS.getattr_from(thing, pre="the_", suf="_suf")
assert "vary" == AxisName.VAR.getattr_from(thing, pre="the_", suf="_suf")
ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
assert "erve" == AxisName.OBS.getitem_from(ovdict)
assert "y" == AxisName.VAR.getitem_from(ovdict)
assert "hide" == AxisName.OBS.getitem_from(ovdict, pre="i_", suf="cure")
assert "???" == AxisName.VAR.getitem_from(ovdict, pre="i_", suf="cure")

0 comments on commit 7c934f8

Please sign in to comment.