Skip to content

Commit

Permalink
Fix ffill
Browse files (browse the repository at this point in the history)
  • Loading branch information
dcherian committed Jul 23, 2024
1 parent 8301a84 commit 7656911
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 72 deletions.
1 change: 0 additions & 1 deletion flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ def nanmean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None

def ffill(group_idx, array, *, axis, **kwargs):
group_idx, array, perm = _prepare_for_flox(group_idx, array)

shape = array.shape
ndim = array.ndim
assert axis == (ndim - 1), (axis, ndim - 1)
Expand Down
73 changes: 14 additions & 59 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

import numpy as np
import pandas as pd
from numpy.typing import ArrayLike, DTypeLike

from . import aggregate_flox, aggregate_npg, xrutils
Expand Down Expand Up @@ -575,7 +574,7 @@ class Scan:
# between reductions and scans
name: str
# binary operation (e.g. add)
binary_op: Callable
# binary_op: Callable
# in-memory grouped scan function (e.g. cumsum)
scan: str
# Grouped reduction that yields the last result of the scan (e.g. sum)
Expand All @@ -597,66 +596,22 @@ def __post_init__(self):
assert self.array.shape[-1] == self.group_idx.size


def scan_binary_op(
left: AlignedArrays, right: AlignedArrays, *, op: Callable, fill_value: Any
) -> AlignedArrays:
from .core import reindex_
@dataclass
class ScanState:
"""Dataclass representing intermediates for scan."""

reindexed = reindex_(
left.array,
from_=pd.Index(left.group_idx),
# TODO: `right.group_idx` instead?
to=pd.RangeIndex(right.group_idx.max() + 1),
fill_value=fill_value,
axis=-1,
)
return AlignedArrays(
array=op(reindexed[..., right.group_idx], right.array), group_idx=right.group_idx
)
# last value of each group seen so far
state: AlignedArrays | None
# intermediate result
result: AlignedArrays | None

def __post_init__(self):
assert (self.state is not None) or (self.result is not None)


def _fill_with_last_one(
left: AlignedArrays, right: AlignedArrays, *, fill_value: Any
) -> AlignedArrays:
from .aggregate_flox import ffill

if right.group_idx[0] not in left.group_idx:
return right

# from .core import reindex_
# reindexed = reindex_(
# left.array,
# from_=pd.Index(left.group_idx),
# to=pd.Index(right.group_idx),
# fill_value=fill_value,
# axis=-1,
# )

new = ffill(
np.concatenate([left.group_idx, right.group_idx], axis=-1),
np.concatenate([left.array, right.array], axis=-1),
axis=right.array.ndim - 1,
)[..., left.group_idx.size :]
return AlignedArrays(array=new, group_idx=right.group_idx)


cumsum = Scan(
"cumsum",
binary_op=partial(scan_binary_op, op=np.add),
reduction="sum",
scan="cumsum",
identity=0,
)
nancumsum = Scan(
"nancumsum",
binary_op=partial(scan_binary_op, op=np.add),
reduction="nansum",
scan="nancumsum",
identity=0,
)
ffill = Scan(
"ffill", binary_op=_fill_with_last_one, reduction="nanlast", scan="ffill", identity=np.nan
)
cumsum = Scan("cumsum", reduction="sum", scan="cumsum", identity=0)
nancumsum = Scan("nancumsum", reduction="nansum", scan="nancumsum", identity=0)
ffill = Scan("ffill", reduction="nanlast", scan="ffill", identity=np.nan)
# cumprod = Scan("cumprod", binary_op=np.multiply, preop="prod", scan="cumprod")


Expand Down
50 changes: 40 additions & 10 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Aggregation,
AlignedArrays,
Scan,
ScanState,
_atleast_1d,
_initialize_aggregation,
generic_aggregate,
Expand Down Expand Up @@ -2717,17 +2718,17 @@ def groupby_scan(

if not has_dask:
(single_axis,) = axis_
result = grouped_scan(
final_state = grouped_scan(
AlignedArrays(array=array, group_idx=by_), axis=single_axis, agg=agg, dtype=agg.dtype
)
return result.array
return extract_array(final_state)
else:
return dask_groupby_scan(array, by_, axes=axis_, agg=agg)


def grouped_scan(
inp: AlignedArrays, *, axis: int, agg: Scan, dtype=None, keepdims=None
) -> AlignedArrays:
) -> ScanState:
assert axis == inp.array.ndim - 1

# TODO: factorize here (maybe?)
Expand All @@ -2740,10 +2741,10 @@ def grouped_scan(
dtype=dtype,
fill_value=agg.identity,
)
return AlignedArrays(array=accumulated, group_idx=inp.group_idx)
return ScanState(result=AlignedArrays(array=accumulated, group_idx=inp.group_idx), state=None)


def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -> AlignedArrays:
def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -> ScanState:
assert axis == inp.array.ndim - 1
reduced = chunk_reduce(
inp.array,
Expand All @@ -2755,15 +2756,42 @@ def grouped_reduce(inp: AlignedArrays, *, agg: Scan, axis: int, keepdims=None) -
fill_value=agg.identity,
expected_groups=None,
)
return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
return ScanState(
state=AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"]),
result=None,
)


def scan_binary_op(
left_state: ScanState, right_state: ScanState, agg: Scan, *, fill_value: Any
) -> ScanState:

assert left_state.state is not None
left = left_state.state
right = right_state.result if right_state.result is not None else right_state.state

def _zip(group_idx, array):
new_group_idx = np.concatenate([left.group_idx, right.group_idx], axis=-1)
new_array = np.concatenate([left.array, right.array], axis=-1)

new = generic_aggregate(
new_group_idx, new_array, func=agg.scan, axis=right.array.ndim - 1, engine="flox"
)[..., left.group_idx.size :]
# This is quite important. We need to update the state seen so far and propagate that.
lasts = grouped_reduce(
AlignedArrays(group_idx=new_group_idx, array=new_array), agg=agg, axis=right.array.ndim - 1
)
return ScanState(
state=lasts.state,
result=AlignedArrays(array=new, group_idx=right.group_idx),
)


def _zip(group_idx, array) -> AlignedArrays:
return AlignedArrays(group_idx=group_idx, array=array)


def extract_array(block: AlignedArrays):
return block.array
def extract_array(block: ScanState) -> np.ndarray:
return block.result.array


def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
Expand All @@ -2785,10 +2813,12 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
# dask tokenizing error workaround
scan_.__name__ = scan_.func.__name__

binop = partial(scan_binary_op, agg=agg, fill_value=agg.identity)

# 2. Run the scan
accumulated = scan(
func=scan_,
binop=partial(agg.binary_op, fill_value=agg.identity),
binop=binop,
ident=agg.identity,
x=zipped,
axis=axis,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,3 +1804,12 @@ def test_nanlen_string(dtype, engine):
expected = np.array([3, 2, 1], dtype=np.intp)
actual, *_ = groupby_reduce(array, by, func="count", engine=engine)
assert_equal(expected, actual)


# from numpy import nan

# array = np.array([nan, 0., nan, nan, 0.], dtype=np.float32)
# group_idx = np.array([0, 0, 1, 0, 0])
# ffill.dtype = array.dtype
# dask_groupby_scan(dask.array.from_array(array, chunks=(1, 1, 1, 2)), group_idx, axes=(0,), agg=ffill).compute()
# dask_groupby_scan(dask.array.from_array(array, chunks=(2, 1, 2)), group_idx, axes=(0,), agg=ffill).compute()
4 changes: 2 additions & 2 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def test_simple_scans(data, array):
@given(
data=st.data(),
array=chunked_arrays(),
# func=st.sampled_from(tuple(NUMPY_SCAN_FUNCS))
func=st.just("ffill"),
func=st.sampled_from(tuple(NUMPY_SCAN_FUNCS)),
# func=st.just("ffill"),
)
def test_scans(data, array, func):
by = data.draw(by_arrays(shape=(array.shape[-1],)))
Expand Down

0 comments on commit 7656911

Please sign in to comment.