From a42f5c54734effc962398ea5a2e846eb94da2734 Mon Sep 17 00:00:00 2001 From: discord9 Date: Tue, 14 Jan 2025 16:40:30 +0800 Subject: [PATCH 01/18] todo: refactor accumulator --- src/flow/src/expr/relation/accum.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs index 252913de56f6..9e048e0ad68b 100644 --- a/src/flow/src/expr/relation/accum.rs +++ b/src/flow/src/expr/relation/accum.rs @@ -36,6 +36,7 @@ use crate::expr::{AggregateFunc, EvalError}; use crate::repr::Diff; /// Accumulates values for the various types of accumulable aggregations. +/// TODO(discord9): refactor it to be more like datafusion's Accumulator #[enum_dispatch] pub trait Accumulator: Sized { fn into_state(self) -> Vec; From 83b35f33b2955f75fb0a0f41fe74338c21fd5cba Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 15 Jan 2025 14:31:17 +0800 Subject: [PATCH 02/18] WIP: accum v2(will replace v1) --- src/flow/src/expr/relation/accum.rs | 3 ++ src/flow/src/expr/relation/accum/accum_v2.rs | 56 ++++++++++++++++++++ src/flow/src/expr/relation/accum/min_max.rs | 13 +++++ 3 files changed, 72 insertions(+) create mode 100644 src/flow/src/expr/relation/accum/accum_v2.rs create mode 100644 src/flow/src/expr/relation/accum/min_max.rs diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs index 9e048e0ad68b..9c5dfae6f859 100644 --- a/src/flow/src/expr/relation/accum.rs +++ b/src/flow/src/expr/relation/accum.rs @@ -20,6 +20,9 @@ //! Currently support sum, count, any, all and min/max(with one caveat that min/max can't support delete with aggregate). //! TODO: think of better ways to not ser/de every time a accum needed to be updated, since it's in a tight loop +mod accum_v2; +mod min_max; + use std::any::type_name; use std::fmt::Display; diff --git a/src/flow/src/expr/relation/accum/accum_v2.rs b/src/flow/src/expr/relation/accum/accum_v2.rs new file mode 100644 index 000000000000..8e72effcf9ee --- /dev/null +++ b/src/flow/src/expr/relation/accum/accum_v2.rs @@ -0,0 +1,56 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! new accumulator trait that is more flexible and can be used in the future for more complex accumulators + +use datatypes::value::Value; +use datatypes::vectors::VectorRef; + +use crate::expr::error::InternalSnafu; +use crate::expr::EvalError; + +/// Basically a copy of datafusion's Accumulator, but with a few modifications +/// to accomodate our needs in flow and keep the upgradability of datafusion +pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { + /// Updates the accumulator’s state from its input. + fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError>; + + /// Returns the current aggregate value, NOT consuming the internal state, so it can be called multiple times. + fn evaluate(&self) -> Result; + + /// Returns the allocated size required for this accumulator, in bytes, including Self. 
+ fn size(&self) -> usize; + + fn into_state(self) -> Result, EvalError>; + + /// Merges the states of multiple accumulators into this accumulator. + /// The states array passed was formed by concatenating the results of calling `Self::into_state` on zero or more other Accumulator instances. + fn merge_batch(&mut self, states: &[VectorRef]) -> Result<(), EvalError>; + + /// Retracts (removed) an update (caused by the given inputs) to accumulator’s state. + fn retract_batch(&mut self, _values: &[VectorRef]) -> Result<(), EvalError> { + InternalSnafu { + reason: format!( + "retract_batch not implemented for this accumulator {:?}", + self + ), + } + .fail() + } + + /// Does the accumulator support incrementally updating its value by removing values. + fn supports_retract_batch(&self) -> bool { + false + } +} diff --git a/src/flow/src/expr/relation/accum/min_max.rs b/src/flow/src/expr/relation/accum/min_max.rs new file mode 100644 index 000000000000..59f3388c4861 --- /dev/null +++ b/src/flow/src/expr/relation/accum/min_max.rs @@ -0,0 +1,13 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. From 0b9315112182bfe21adaf9e647ba9d3c11358986 Mon Sep 17 00:00:00 2001 From: discord9 Date: Mon, 10 Feb 2025 11:46:25 +0800 Subject: [PATCH 03/18] WIP --- src/flow/src/expr/relation/accum.rs | 1 - src/flow/src/expr/relation/accum/accum_v2.rs | 70 ++++++++++++++++++++ src/flow/src/expr/relation/accum/min_max.rs | 13 ---- 3 files changed, 70 insertions(+), 14 deletions(-) delete mode 100644 src/flow/src/expr/relation/accum/min_max.rs diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs index 9c5dfae6f859..49df95bc1e4d 100644 --- a/src/flow/src/expr/relation/accum.rs +++ b/src/flow/src/expr/relation/accum.rs @@ -21,7 +21,6 @@ //! TODO: think of better ways to not ser/de every time a accum needed to be updated, since it's in a tight loop mod accum_v2; -mod min_max; use std::any::type_name; use std::fmt::Display; diff --git a/src/flow/src/expr/relation/accum/accum_v2.rs b/src/flow/src/expr/relation/accum/accum_v2.rs index 8e72effcf9ee..cbc54d828222 100644 --- a/src/flow/src/expr/relation/accum/accum_v2.rs +++ b/src/flow/src/expr/relation/accum/accum_v2.rs @@ -14,11 +14,16 @@ //! new accumulator trait that is more flexible and can be used in the future for more complex accumulators +use std::any::type_name; + use datatypes::value::Value; use datatypes::vectors::VectorRef; +use serde::{Deserialize, Serialize}; +use snafu::ensure; use crate::expr::error::InternalSnafu; use crate::expr::EvalError; +use crate::repr::Diff as FlowDiff; /// Basically a copy of datafusion's Accumulator, but with a few modifications /// to accomodate our needs in flow and keep the upgradability of datafusion @@ -32,13 +37,21 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { /// Returns the allocated size required for this accumulator, in bytes, including Self. 
fn size(&self) -> usize; + /// Returns the intermediate state of the accumulator, consuming the intermediate state. fn into_state(self) -> Result, EvalError>; + /// Creates an accumulator from its intermediate state. + fn from_state(values: &[Value]) -> Result + where + Self: Sized; + /// Merges the states of multiple accumulators into this accumulator. /// The states array passed was formed by concatenating the results of calling `Self::into_state` on zero or more other Accumulator instances. fn merge_batch(&mut self, states: &[VectorRef]) -> Result<(), EvalError>; /// Retracts (removed) an update (caused by the given inputs) to accumulator’s state. + /// + /// currently unused, but will be used in the future for i.e. windowed aggregates fn retract_batch(&mut self, _values: &[VectorRef]) -> Result<(), EvalError> { InternalSnafu { reason: format!( @@ -54,3 +67,60 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { false } } + +/// Bool accumulator, used for `Any` `All` `Max/MinBool` +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Bool { + /// The number of `true` values observed. + trues: FlowDiff, + /// The number of `false` values observed. + falses: FlowDiff, +} + +impl AccumulatorV2 for Bool { + fn from_state(values: &[Value]) -> Result + where + Self: Sized, + { + let mut iter = values.iter(); + Ok(Self { + trues: FlowDiff::try_from(iter.next().ok_or_else(fail_accum::)?) + .map_err(err_try_from_val)?, + falses: FlowDiff::try_from(iter.next().ok_or_else(fail_accum::)?) + .map_err(err_try_from_val)?, + }) + } + + fn into_state(self) -> Result, EvalError> { + Ok(vec![self.trues.into(), self.falses.into()]) + } + + fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError> { + ensure!( + values.len() == 1, + InternalSnafu { + reason: format!("Bool accumulator expects 1 column, got {}", values.len()) + } + ); + let values = &values[0]; + todo!(); + Ok(()) + } +} + +fn fail_accum() -> EvalError { + InternalSnafu { + reason: format!( + "list of values exhausted before a accum of type {} can be build from it", + type_name::() + ), + } + .build() +} + +fn err_try_from_val(reason: T) -> EvalError { + TryFromValueSnafu { + msg: reason.to_string(), + } + .build() +} diff --git a/src/flow/src/expr/relation/accum/min_max.rs b/src/flow/src/expr/relation/accum/min_max.rs deleted file mode 100644 index 59f3388c4861..000000000000 --- a/src/flow/src/expr/relation/accum/min_max.rs +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
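The `Bool` accumulator introduced above still carries a `todo!()` in `update_batch`. A minimal sketch of how that body could look, assuming the single input column downcasts to a `BooleanVector` (the downcast-and-count logic is illustrative, not part of the patch):

```rust
use datatypes::vectors::BooleanVector;
use snafu::{ensure, OptionExt};

// Sketch of the missing method body; `self` is the `Bool` accumulator above.
fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError> {
    ensure!(
        values.len() == 1,
        InternalSnafu {
            reason: format!("Bool accumulator expects 1 column, got {}", values.len())
        }
    );
    let bools = values[0]
        .as_any()
        .downcast_ref::<BooleanVector>()
        .with_context(|| InternalSnafu {
            reason: "Bool accumulator expects a boolean vector".to_string(),
        })?;
    // Counting trues/falses instead of storing values keeps the state
    // commutative, so updates and merges can arrive in any order.
    for b in bools.iter_data().flatten() {
        if b {
            self.trues += 1;
        } else {
            self.falses += 1;
        }
    }
    Ok(())
}
```

An `evaluate` for `Any`/`All` on top of this state then reduces to `self.trues > 0` or `self.falses == 0` respectively, mirroring the v1 accumulator.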
From a9e2aef9b9a3edc108c905cdb5133be8327e9e6c Mon Sep 17 00:00:00 2001 From: discord9 Date: Mon, 10 Feb 2025 16:04:04 +0800 Subject: [PATCH 04/18] WIP more&more --- src/flow/src/expr/relation.rs | 2 + src/flow/src/expr/relation/accum.rs | 2 - .../src/expr/relation/{accum => }/accum_v2.rs | 49 +------------------ src/flow/src/expr/relation/udaf.rs | 13 +++++ 4 files changed, 17 insertions(+), 49 deletions(-) rename src/flow/src/expr/relation/{accum => }/accum_v2.rs (67%) create mode 100644 src/flow/src/expr/relation/udaf.rs diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index b5d7e4ef2078..b3f26edb8750 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -20,7 +20,9 @@ pub(crate) use func::AggregateFunc; use crate::expr::ScalarExpr; mod accum; +mod accum_v2; mod func; +mod udaf; /// Describes an aggregation expression. #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs index 49df95bc1e4d..9e048e0ad68b 100644 --- a/src/flow/src/expr/relation/accum.rs +++ b/src/flow/src/expr/relation/accum.rs @@ -20,8 +20,6 @@ //! Currently support sum, count, any, all and min/max(with one caveat that min/max can't support delete with aggregate). //! TODO: think of better ways to not ser/de every time a accum needed to be updated, since it's in a tight loop -mod accum_v2; - use std::any::type_name; use std::fmt::Display; diff --git a/src/flow/src/expr/relation/accum/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs similarity index 67% rename from src/flow/src/expr/relation/accum/accum_v2.rs rename to src/flow/src/expr/relation/accum_v2.rs index cbc54d828222..58f61746dc7d 100644 --- a/src/flow/src/expr/relation/accum/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -15,15 +15,15 @@ //! new accumulator trait that is more flexible and can be used in the future for more complex accumulators use std::any::type_name; +use std::fmt::Display; use datatypes::value::Value; use datatypes::vectors::VectorRef; use serde::{Deserialize, Serialize}; use snafu::ensure; -use crate::expr::error::InternalSnafu; +use crate::expr::error::{InternalSnafu, TryFromValueSnafu}; use crate::expr::EvalError; -use crate::repr::Diff as FlowDiff; /// Basically a copy of datafusion's Accumulator, but with a few modifications /// to accomodate our needs in flow and keep the upgradability of datafusion @@ -40,11 +40,6 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { /// Returns the intermediate state of the accumulator, consuming the intermediate state. fn into_state(self) -> Result, EvalError>; - /// Creates an accumulator from its intermediate state. - fn from_state(values: &[Value]) -> Result - where - Self: Sized; - /// Merges the states of multiple accumulators into this accumulator. /// The states array passed was formed by concatenating the results of calling `Self::into_state` on zero or more other Accumulator instances. fn merge_batch(&mut self, states: &[VectorRef]) -> Result<(), EvalError>; @@ -68,46 +63,6 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { } } -/// Bool accumulator, used for `Any` `All` `Max/MinBool` -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub struct Bool { - /// The number of `true` values observed. - trues: FlowDiff, - /// The number of `false` values observed. 
- falses: FlowDiff, -} - -impl AccumulatorV2 for Bool { - fn from_state(values: &[Value]) -> Result - where - Self: Sized, - { - let mut iter = values.iter(); - Ok(Self { - trues: FlowDiff::try_from(iter.next().ok_or_else(fail_accum::)?) - .map_err(err_try_from_val)?, - falses: FlowDiff::try_from(iter.next().ok_or_else(fail_accum::)?) - .map_err(err_try_from_val)?, - }) - } - - fn into_state(self) -> Result, EvalError> { - Ok(vec![self.trues.into(), self.falses.into()]) - } - - fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError> { - ensure!( - values.len() == 1, - InternalSnafu { - reason: format!("Bool accumulator expects 1 column, got {}", values.len()) - } - ); - let values = &values[0]; - todo!(); - Ok(()) - } -} - fn fail_accum() -> EvalError { InternalSnafu { reason: format!( diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs new file mode 100644 index 000000000000..59f3388c4861 --- /dev/null +++ b/src/flow/src/expr/relation/udaf.rs @@ -0,0 +1,13 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. From f5f27af944e7ad0e1ef08aa17e52e16a40aa9aa8 Mon Sep 17 00:00:00 2001 From: discord9 Date: Mon, 10 Feb 2025 19:42:02 +0800 Subject: [PATCH 05/18] feat: binding datafusion aggr func --- src/flow/src/expr/relation/accum_v2.rs | 91 +++++++++++++++++++++++++- src/flow/src/expr/relation/udaf.rs | 64 ++++++++++++++++++ 2 files changed, 152 insertions(+), 3 deletions(-) diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index 58f61746dc7d..fc00fe7f0756 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -14,15 +14,18 @@ //! new accumulator trait that is more flexible and can be used in the future for more complex accumulators -use std::any::type_name; +use std::any::{type_name, Any}; use std::fmt::Display; +use std::sync::{Mutex, MutexGuard}; +use datafusion::logical_expr::Accumulator as DfAccumulator; +use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; use datatypes::vectors::VectorRef; use serde::{Deserialize, Serialize}; -use snafu::ensure; +use snafu::{ensure, ResultExt}; -use crate::expr::error::{InternalSnafu, TryFromValueSnafu}; +use crate::expr::error::{DataTypeSnafu, DatafusionSnafu, InternalSnafu, TryFromValueSnafu}; use crate::expr::EvalError; /// Basically a copy of datafusion's Accumulator, but with a few modifications @@ -63,6 +66,88 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { } } +/// Adapter for several hand-picked datafusion accumulators that can be used in flow +/// +/// i.e: can call evaluate multiple times. 
+#[derive(Debug)] +pub struct DfAccumulatorAdapter { + inner: Mutex>, +} + +pub trait AcceptDfAccumulator: DfAccumulator {} + +impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MaxAccumulator {} + +impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MinAccumulator {} + +impl DfAccumulatorAdapter { + // TODO(discord9): find a way to whitelist only certain type of accumulators + fn new_unchecked(acc: Box) -> Self { + Self { + inner: Mutex::new(acc), + } + } + + fn new(acc: T) -> Self { + Self::new_unchecked(Box::new(acc)) + } + + fn acc(&self) -> MutexGuard> { + self.inner.lock().expect("lock poisoned") + } +} + +impl AccumulatorV2 for DfAccumulatorAdapter { + fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError> { + let values = values + .iter() + .map(|v| v.to_arrow_array().clone()) + .collect::>(); + self.acc().update_batch(&values).context(DatafusionSnafu { + context: "failed to update batch: {}", + }) + } + + fn evaluate(&self) -> Result { + // TODO(discord9): find a way to confirm internal state is not consumed + let value = self.acc().evaluate().context(DatafusionSnafu { + context: "failed to evaluate accumulator: {}", + })?; + let value = Value::try_from(value).context(DataTypeSnafu { + msg: "failed to convert evaluate result from `ScalarValue` to `Value`", + })?; + Ok(value) + } + + fn size(&self) -> usize { + self.acc().size() + } + + fn into_state(self) -> Result, EvalError> { + let state = self.acc().state().context(DatafusionSnafu { + context: "failed to get state: {}", + })?; + let state = state + .into_iter() + .map(Value::try_from) + .collect::, _>>() + .with_context(|_| DataTypeSnafu { + msg: "failed to convert `ScalarValue` state to `Value`", + })?; + Ok(state) + } + + fn merge_batch(&mut self, states: &[VectorRef]) -> Result<(), EvalError> { + let states = states + .iter() + .map(|v| v.to_arrow_array().clone()) + .collect::>(); + self.acc().merge_batch(&states).context(DatafusionSnafu { + context: "failed to merge batch", + }) + } +} + fn fail_accum() -> EvalError { InternalSnafu { reason: format!( diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs index 59f3388c4861..4df2726400c6 100644 --- a/src/flow/src/expr/relation/udaf.rs +++ b/src/flow/src/expr/relation/udaf.rs @@ -11,3 +11,67 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +use std::fmt::Debug; + +use common_query::prelude::Signature; +use datatypes::prelude::ConcreteDataType; + +use crate::expr::relation::accum_v2::AccumulatorV2; +use crate::expr::ScalarExpr; +use crate::repr::RelationDesc; +use crate::Result; + +/// User-defined aggregate function (UDAF) implementation. +/// All built-in UDAFs for flow is also impl by this trait. +pub trait AggrUDFImpl: Debug + Send + Sync { + fn as_any(&self) -> &dyn std::any::Any; + + fn name(&self) -> &str; + + fn signature(&self) -> &Signature; + /// What ConcreteDataType will be returned by this function, given the types of the arguments + /// + /// Keep the return_type's error type the same as `Function`'s return_type + fn return_type(&self, arg_type: &[ConcreteDataType]) -> Result; + + fn accumulator(&self, acc_args: AccumulatorArgs<'_>) -> Result>; +} + +/// contains information about how an aggregate function was called, +/// including the types of its arguments and any optional ordering expressions. 
+/// +/// Should be created from AggregateExpr +pub struct AccumulatorArgs<'a> { + pub return_type: &'a ConcreteDataType, + pub schema: &'a RelationDesc, + pub ignore_nulls: bool, + /// The expressions in the `ORDER BY` clause passed to this aggregator. + /// + /// SQL allows the user to specify the ordering of arguments to the + /// aggregate using an `ORDER BY`. For example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; + /// ``` + /// + /// If no `ORDER BY` is specified, `ordering_req` will be empty. + pub ordering_req: &'a OrderingReq, + pub is_reversed: bool, + pub name: &'a str, + pub is_distinct: bool, + pub exprs: &'a [ScalarExpr], +} + +pub struct OrderingReq { + pub exprs: Vec, +} + +pub struct SortExpr { + /// expression representing the column to sort + pub expr: ScalarExpr, + /// Whether to sort in descending order + pub descending: bool, + /// Whether to sort nulls first + pub nulls_first: bool, +} From a3d5509c609f369d37f4a857c65517063d82cc1b Mon Sep 17 00:00:00 2001 From: discord9 Date: Tue, 11 Feb 2025 15:13:46 +0800 Subject: [PATCH 06/18] feat: new accum --- src/flow/src/expr/relation/accum_v2.rs | 90 +++++++++++++++++++------- 1 file changed, 66 insertions(+), 24 deletions(-) diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index fc00fe7f0756..266394f43bae 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -14,16 +14,14 @@ //! new accumulator trait that is more flexible and can be used in the future for more complex accumulators -use std::any::{type_name, Any}; +use std::any::type_name; use std::fmt::Display; -use std::sync::{Mutex, MutexGuard}; use datafusion::logical_expr::Accumulator as DfAccumulator; -use datatypes::prelude::ConcreteDataType; +use datatypes::prelude::{ConcreteDataType as CDT, DataType}; use datatypes::value::Value; use datatypes::vectors::VectorRef; -use serde::{Deserialize, Serialize}; -use snafu::{ensure, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; use crate::expr::error::{DataTypeSnafu, DatafusionSnafu, InternalSnafu, TryFromValueSnafu}; use crate::expr::EvalError; @@ -35,13 +33,15 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError>; /// Returns the current aggregate value, NOT consuming the internal state, so it can be called multiple times. - fn evaluate(&self) -> Result; + fn evaluate(&mut self) -> Result; /// Returns the allocated size required for this accumulator, in bytes, including Self. fn size(&self) -> usize; /// Returns the intermediate state of the accumulator, consuming the intermediate state. - fn into_state(self) -> Result, EvalError>; + /// + /// note that Value::Null's type is unknown, so (Value, ConcreteDataType) is used instead of just Value + fn into_state(self) -> Result, EvalError>; /// Merges the states of multiple accumulators into this accumulator. /// The states array passed was formed by concatenating the results of calling `Self::into_state` on zero or more other Accumulator instances. @@ -71,7 +71,8 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { /// i.e: can call evaluate multiple times. 
#[derive(Debug)] pub struct DfAccumulatorAdapter { - inner: Mutex>, + /// accumulator that is wrapped in a mutex to allow for evaluation + inner: Box, } pub trait AcceptDfAccumulator: DfAccumulator {} @@ -81,20 +82,14 @@ impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MaxAccumu impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MinAccumulator {} impl DfAccumulatorAdapter { - // TODO(discord9): find a way to whitelist only certain type of accumulators + /// create a new accumulator from a datafusion accumulator without checking if it is supported in flow fn new_unchecked(acc: Box) -> Self { - Self { - inner: Mutex::new(acc), - } + Self { inner: acc } } fn new(acc: T) -> Self { Self::new_unchecked(Box::new(acc)) } - - fn acc(&self) -> MutexGuard> { - self.inner.lock().expect("lock poisoned") - } } impl AccumulatorV2 for DfAccumulatorAdapter { @@ -103,14 +98,14 @@ impl AccumulatorV2 for DfAccumulatorAdapter { .iter() .map(|v| v.to_arrow_array().clone()) .collect::>(); - self.acc().update_batch(&values).context(DatafusionSnafu { + self.inner.update_batch(&values).context(DatafusionSnafu { context: "failed to update batch: {}", }) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { // TODO(discord9): find a way to confirm internal state is not consumed - let value = self.acc().evaluate().context(DatafusionSnafu { + let value = self.inner.evaluate().context(DatafusionSnafu { context: "failed to evaluate accumulator: {}", })?; let value = Value::try_from(value).context(DataTypeSnafu { @@ -120,16 +115,20 @@ impl AccumulatorV2 for DfAccumulatorAdapter { } fn size(&self) -> usize { - self.acc().size() + self.inner.size() } - fn into_state(self) -> Result, EvalError> { - let state = self.acc().state().context(DatafusionSnafu { + fn into_state(mut self) -> Result, EvalError> { + let state = self.inner.state().context(DatafusionSnafu { context: "failed to get state: {}", })?; let state = state .into_iter() - .map(Value::try_from) + .map(|v| -> Result<_, _> { + let dt = CDT::try_from(&v.data_type())?; + let val = Value::try_from(v)?; + Ok((val, dt)) + }) .collect::, _>>() .with_context(|_| DataTypeSnafu { msg: "failed to convert `ScalarValue` state to `Value`", @@ -142,12 +141,55 @@ impl AccumulatorV2 for DfAccumulatorAdapter { .iter() .map(|v| v.to_arrow_array().clone()) .collect::>(); - self.acc().merge_batch(&states).context(DatafusionSnafu { + self.inner.merge_batch(&states).context(DatafusionSnafu { context: "failed to merge batch", }) } } +/// Convert a list of states(from `Accumulator::into_state`) +/// to a batch of vectors(that can be feed to `Accumulator::merge_batch`) +fn states_to_batch(states: Vec>, dts: Vec) -> Result, EvalError> { + if states.is_empty() || states[0].is_empty() { + return Ok(vec![]); + } + ensure!( + states.iter().map(|v| v.len()).all(|l| l == states[0].len()), + InternalSnafu { + reason: "states have different lengths" + } + ); + let state_cnt = states.len(); + let state_len = states[0].len(); + ensure!( + state_len == dts.len(), + InternalSnafu { + reason: format!( + "states and data types have different lengths: {} != {}", + state_len, + dts.len() + ) + } + ); + let mut ret = dts + .into_iter() + .map(|dt| dt.create_mutable_vector(state_len)) + .collect::>(); + for i in 0..state_len { + for j in 0..state_cnt { + // the j-th state's i-th value + let val = &states[j][i]; + ret.get_mut(i) + .with_context(|| InternalSnafu { + reason: format!("failed to get mutable vector at index {}", i), + })? 
+ .push_value_ref(val.as_value_ref()); + } + } + let ret = ret.into_iter().map(|mut v| v.to_vector()).collect(); + Ok(ret) +} + fn fail_accum() -> EvalError { InternalSnafu { reason: format!( From 8155d04429380f74808a7bf7a5a85c93d5a14ccc Mon Sep 17 00:00:00 2001 From: discord9 Date: Tue, 11 Feb 2025 20:28:21 +0800 Subject: [PATCH 07/18] feat: from_substrait --- src/flow/Cargo.toml | 1 + src/flow/src/adapter/node_context.rs | 32 ++++- src/flow/src/error.rs | 4 +- src/flow/src/expr/relation.rs | 31 +++++ src/flow/src/expr/relation/accum_v2.rs | 5 +- src/flow/src/expr/relation/udaf.rs | 2 + src/flow/src/transform/aggr.rs | 158 ++++++++++++++++++++++++- src/flow/src/transform/expr.rs | 2 +- 8 files changed, 224 insertions(+), 11 deletions(-) diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index b4545e1a899a..91e5945ee2db 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -39,6 +39,7 @@ datafusion-expr.workspace = true datafusion-physical-expr.workspace = true datafusion-substrait.workspace = true datatypes.workspace = true +derive-where = "1.2.7" enum-as-inner = "0.6.0" enum_dispatch = "0.3" futures = "0.3" diff --git a/src/flow/src/adapter/node_context.rs b/src/flow/src/adapter/node_context.rs index 7983b396fedc..c42b63a49989 100644 --- a/src/flow/src/adapter/node_context.rs +++ b/src/flow/src/adapter/node_context.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use common_recordbatch::RecordBatch; use common_telemetry::trace; +use datafusion_expr::AggregateUDF; use datatypes::prelude::ConcreteDataType; use session::context::QueryContext; use snafu::{OptionExt, ResultExt}; @@ -59,11 +60,13 @@ pub struct FlownodeContext { /// All the tables that have been registered in the worker pub table_repr: IdToNameMap, pub query_context: Option>, + /// Aggregate functions registered in the context + pub aggregate_functions: HashMap>, } impl FlownodeContext { pub fn new(table_source: Box) -> Self { - Self { + let mut ret = Self { source_to_tasks: Default::default(), flow_to_sink: Default::default(), flow_plans: Default::default(), @@ -73,7 +76,32 @@ impl FlownodeContext { table_source, table_repr: Default::default(), query_context: Default::default(), - } + aggregate_functions: Default::default(), + }; + ret.register_all_built_in_aggr_fns(); + ret + } + + pub fn register_all_built_in_aggr_fns(&mut self) { + // sum/min/max/count are built-in aggregate functions + use datafusion::functions_aggregate::count::count_udaf; + use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf}; + use datafusion::functions_aggregate::sum::sum_udaf; + let res = HashMap::from([ + ("sum".to_string(), sum_udaf()), + ("min".to_string(), min_udaf()), + ("max".to_string(), max_udaf()), + ("count".to_string(), count_udaf()), + ]); + self.aggregate_functions.extend(res); + } + + pub fn register_aggr_fn( + &mut self, + name: String, + aggr_fn: Arc, + ) -> Option> { + self.aggregate_functions.insert(name, aggr_fn) } pub fn get_flow_ids(&self, table_id: TableId) -> Option<&BTreeSet> { diff --git a/src/flow/src/error.rs b/src/flow/src/error.rs index 703b47641c5c..bf44c12fba0a 100644 --- a/src/flow/src/error.rs +++ b/src/flow/src/error.rs @@ -148,10 +148,10 @@ pub enum Error { location: Location, }, - #[snafu(display("Datatypes error: {source} with extra message: {extra}"))] + #[snafu(display("Datatypes error: {source} with extra message: {context}"))] Datatypes { source: datatypes::Error, - extra: String, + context: String, #[snafu(implicit)] location: Location, }, diff --git a/src/flow/src/expr/relation.rs 
b/src/flow/src/expr/relation.rs index b3f26edb8750..851bf7430735 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -15,9 +15,15 @@ //! Describes an aggregation function and it's input expression. pub(crate) use accum::{Accum, Accumulator}; +use datafusion_expr::AggregateUDF; +use datatypes::prelude::ConcreteDataType; pub(crate) use func::AggregateFunc; +pub use udaf::{OrderingReq, SortExpr}; +use crate::expr::relation::accum_v2::AccumulatorV2; use crate::expr::ScalarExpr; +use crate::repr::RelationDesc; +use crate::Error; mod accum; mod accum_v2; @@ -36,3 +42,28 @@ pub struct AggregateExpr { /// Should the aggregation be applied only to distinct results in each group. pub distinct: bool, } + +#[derive(Clone, Debug)] +pub struct AggregateExprV2 { + pub func: AggregateUDF, + pub args: Vec, + /// Output / return type of this aggregate + pub return_type: ConcreteDataType, + pub name: String, + /// The schema of the input relation to this aggregate + pub schema: RelationDesc, + // i.e. FIRST_VALUE(a ORDER BY b) + pub ordering_req: OrderingReq, + pub ignore_nulls: bool, + pub is_distinct: bool, + pub is_reversed: bool, + /// The types of the arguments to this aggregate + pub input_types: Vec, + pub is_nullable: bool, +} + +impl AggregateExprV2 { + pub fn create_accumulator(&self) -> Result, Error> { + todo!("create_accumulator") + } +} diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index 266394f43bae..485dd19a6354 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -159,7 +159,6 @@ fn states_to_batch(states: Vec>, dts: Vec) -> Result>, dts: Vec) -> Result>(); for i in 0..state_len { - for j in 0..state_cnt { + for state in states.iter() { // the j-th state's i-th value - let val = &states[j][i]; + let val = &state[i]; ret.get_mut(i) .with_context(|| InternalSnafu { reason: format!("failed to get mutable vector at index {}", i), diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs index 4df2726400c6..c87afca70a58 100644 --- a/src/flow/src/expr/relation/udaf.rs +++ b/src/flow/src/expr/relation/udaf.rs @@ -63,10 +63,12 @@ pub struct AccumulatorArgs<'a> { pub exprs: &'a [ScalarExpr], } +#[derive(Debug, Clone)] pub struct OrderingReq { pub exprs: Vec, } +#[derive(Debug, Clone)] pub struct SortExpr { /// expression representing the column to sort pub expr: ScalarExpr, diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 27c7acfb1da9..e1bde528b4c0 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -12,14 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. 
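Patch 07 above seeds `FlownodeContext` with DataFusion's stock UDAFs so that the substrait transform can resolve aggregate names to implementations, and `register_aggr_fn` leaves room for more. A rough sketch of both lookup and registration, assuming a `ctx: FlownodeContext` is already constructed (`avg` is an illustrative addition):

```rust
use datafusion::functions_aggregate::average::avg_udaf;

// sum/min/max/count are pre-registered by register_all_built_in_aggr_fns():
let sum = ctx
    .aggregate_functions
    .get("sum")
    .expect("sum is a built-in aggregate");
assert_eq!(sum.name(), "sum");

// further UDAFs can be layered on; the returned Option is any shadowed entry:
let previous = ctx.register_aggr_fn("avg".to_string(), avg_udaf());
assert!(previous.is_none());
```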
+use datatypes::prelude::{ConcreteDataType, DataType}; use itertools::Itertools; -use snafu::OptionExt; +use snafu::{OptionExt, ResultExt}; use substrait_proto::proto; use substrait_proto::proto::aggregate_function::AggregationInvocation; use substrait_proto::proto::aggregate_rel::{Grouping, Measure}; use substrait_proto::proto::function_argument::ArgType; +use substrait_proto::proto::sort_field::{SortDirection, SortKind}; -use crate::error::{Error, NotImplementedSnafu, PlanSnafu}; +use crate::error::{DatafusionSnafu, DatatypesSnafu, Error, NotImplementedSnafu, PlanSnafu}; +use crate::expr::relation::{AggregateExprV2, OrderingReq, SortExpr}; use crate::expr::{ AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc, }; @@ -86,7 +89,7 @@ impl TypedExpr { impl AggregateExpr { /// Convert list of `Measure` into Flow's AggregateExpr /// - /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function + /// Return the AggregateExpr List that is the final output of the aggregate function async fn from_substrait_agg_measures( ctx: &mut FlownodeContext, measures: &[Measure], @@ -197,6 +200,155 @@ impl AggregateExpr { } } +impl AggregateExprV2 { + /// Convert list of `Measure` into Flow's AggregateExpr + /// + /// Return the AggregateExpr List that is the final output of the aggregate function + async fn from_substrait_agg_measures( + ctx: &mut FlownodeContext, + measures: &[Measure], + typ: &RelationDesc, + extensions: &FunctionExtensions, + ) -> Result, Error> { + let mut all_aggr_exprs = vec![]; + + for m in measures { + let filter = match &m.filter { + Some(fil) => Some(TypedExpr::from_substrait_rex(fil, typ, extensions).await?), + None => None, + }; + + let Some(f) = &m.measure else { + not_impl_err!("Expect aggregate function")? 
+ }; + + let aggr_expr = Self::from_substrait_agg_func(ctx, f, typ, extensions, &filter).await?; + all_aggr_exprs.push(aggr_expr); + } + Ok(all_aggr_exprs) + } + + /// Convert AggregateFunction into Flow's AggregateExpr + /// + /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function + /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)` + pub async fn from_substrait_agg_func( + ctx: &mut FlownodeContext, + f: &proto::AggregateFunction, + input_schema: &RelationDesc, + extensions: &FunctionExtensions, + filter: &Option, + ) -> Result { + // TODO(discord9): impl filter + let _ = filter; + + let mut args = vec![]; + for arg in &f.arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => { + TypedExpr::from_substrait_rex(e, input_schema, extensions).await + } + _ => not_impl_err!("Aggregated function argument non-Value type not supported"), + }?; + args.push(arg_expr); + } + let args = args; + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => true, + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, + }; + + let fn_name = extensions + .get(&f.function_reference) + .cloned() + .with_context(|| PlanSnafu { + reason: format!( + "Aggregated function not found: function anchor = {:?}", + f.function_reference + ), + })?; + + let fn_impl = ctx + .aggregate_functions + .get(&fn_name) + .with_context(|| PlanSnafu { + reason: format!("Aggregate function not found: {:?}", fn_name), + })?; + + let input_types = args + .iter() + .map(|a| a.typ.scalar_type().clone()) + .collect_vec(); + + let return_type = { + let arrow_input_types = input_types.iter().map(|t| t.as_arrow_type()).collect_vec(); + let ret = fn_impl + .return_type(&arrow_input_types) + .context(DatafusionSnafu { + context: "failed to get return type of aggregate function", + })?; + ConcreteDataType::try_from(&ret).context(DatatypesSnafu { + context: "failed to convert return type to ConcreteDataType", + })? 
+ }; + + let ordering_req = { + let mut ret = Vec::with_capacity(f.sorts.len()); + for sort in &f.sorts { + let Some(raw_expr) = sort.expr.as_ref() else { + return not_impl_err!("Sort expression not found in sort"); + }; + let expr = + TypedExpr::from_substrait_rex(raw_expr, input_schema, extensions).await?; + let sort_dir = sort.sort_kind; + let Some(SortKind::Direction(dir)) = sort_dir else { + return not_impl_err!("Sort direction not found in sort"); + }; + let dir = SortDirection::try_from(dir).map_err(|e| { + PlanSnafu { + reason: format!("{} is not a valid direction", e.0), + } + .build() + })?; + + let (descending, nulls_first) = match dir { + // align with default datafusion option + SortDirection::Unspecified => (false, true), + SortDirection::AscNullsFirst => (false, true), + SortDirection::AscNullsLast => (false, false), + SortDirection::DescNullsFirst => (true, true), + SortDirection::DescNullsLast => (true, false), + SortDirection::Clustered => not_impl_err!("Clustered sort not supported")?, + }; + + let sort_expr = SortExpr { + expr: expr.expr, + descending, + nulls_first, + }; + ret.push(sort_expr); + } + OrderingReq { exprs: ret } + }; + + // TODO(discord9): determine other options from substrait too instead of default + Ok(Self { + func: fn_impl.as_ref().clone(), + args: args.into_iter().map(|a| a.expr).collect(), + return_type, + name: fn_name, + schema: input_schema.clone(), + ordering_req, + ignore_nulls: false, + is_distinct: distinct, + is_reversed: false, + input_types, + is_nullable: true, + }) + } +} + impl KeyValPlan { /// Generate KeyValPlan from AggregateExpr and group_exprs /// diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index ed75252ee21c..b6c750979290 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -347,7 +347,7 @@ impl TypedExpr { datatypes::types::cast(val.clone(), &dest_type) .with_context(|_| DatatypesSnafu{ - extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}") + context: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}") })? } else { val.clone() From 41c79520b9e2c21bd9976fa12ffa33d6cb100506 Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 12 Feb 2025 12:33:40 +0800 Subject: [PATCH 08/18] feat: create acc --- src/flow/src/expr/relation.rs | 34 ++++++++++-- src/flow/src/expr/relation/accum_v2.rs | 4 +- src/flow/src/expr/relation/udaf.rs | 72 ++++++++++---------------- src/flow/src/expr/scalar.rs | 23 +++++++- 4 files changed, 82 insertions(+), 51 deletions(-) diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index 851bf7430735..e7460723784d 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -15,12 +15,15 @@ //! Describes an aggregation function and it's input expression. 
pub(crate) use accum::{Accum, Accumulator}; +use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::AggregateUDF; -use datatypes::prelude::ConcreteDataType; +use datatypes::prelude::{ConcreteDataType, DataType}; pub(crate) use func::AggregateFunc; +use snafu::ResultExt; pub use udaf::{OrderingReq, SortExpr}; -use crate::expr::relation::accum_v2::AccumulatorV2; +use crate::error::DatafusionSnafu; +use crate::expr::relation::accum_v2::{AccumulatorV2, DfAccumulatorAdapter}; use crate::expr::ScalarExpr; use crate::repr::RelationDesc; use crate::Error; @@ -64,6 +67,31 @@ pub struct AggregateExprV2 { impl AggregateExprV2 { pub fn create_accumulator(&self) -> Result, Error> { - todo!("create_accumulator") + let data_type = self.return_type.as_arrow_type(); + let schema = self.schema.to_df_schema()?; + let ordering_req = self.ordering_req.to_lex_ordering(&schema)?; + let exprs = self + .args + .iter() + .map(|e| e.as_physical_expr(&schema)) + .collect::, _>>()?; + let accum_args = AccumulatorArgs { + return_type: &data_type, + schema: schema.as_arrow(), + ignore_nulls: self.ignore_nulls, + ordering_req: &ordering_req, + is_reversed: self.is_reversed, + name: &self.name, + is_distinct: self.is_distinct, + exprs: &exprs, + }; + let acc = self + .func + .accumulator(accum_args) + .with_context(|_| DatafusionSnafu { + context: "Fail to build accumulator", + })?; + let acc = DfAccumulatorAdapter::new_unchecked(acc); + Ok(Box::new(acc)) } } diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index 485dd19a6354..6b7a3acb81f7 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -83,11 +83,11 @@ impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MinAccumu impl DfAccumulatorAdapter { /// create a new accumulator from a datafusion accumulator without checking if it is supported in flow - fn new_unchecked(acc: Box) -> Self { + pub fn new_unchecked(acc: Box) -> Self { Self { inner: acc } } - fn new(acc: T) -> Self { + pub fn new(acc: T) -> Self { Self::new_unchecked(Box::new(acc)) } } diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs index c87afca70a58..e8aa013574b7 100644 --- a/src/flow/src/expr/relation/udaf.rs +++ b/src/flow/src/expr/relation/udaf.rs @@ -14,60 +14,28 @@ use std::fmt::Debug; -use common_query::prelude::Signature; -use datatypes::prelude::ConcreteDataType; +use datafusion_common::DFSchema; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; -use crate::expr::relation::accum_v2::AccumulatorV2; use crate::expr::ScalarExpr; -use crate::repr::RelationDesc; use crate::Result; -/// User-defined aggregate function (UDAF) implementation. -/// All built-in UDAFs for flow is also impl by this trait. -pub trait AggrUDFImpl: Debug + Send + Sync { - fn as_any(&self) -> &dyn std::any::Any; - - fn name(&self) -> &str; - - fn signature(&self) -> &Signature; - /// What ConcreteDataType will be returned by this function, given the types of the arguments - /// - /// Keep the return_type's error type the same as `Function`'s return_type - fn return_type(&self, arg_type: &[ConcreteDataType]) -> Result; - - fn accumulator(&self, acc_args: AccumulatorArgs<'_>) -> Result>; -} - -/// contains information about how an aggregate function was called, -/// including the types of its arguments and any optional ordering expressions. 
-/// -/// Should be created from AggregateExpr -pub struct AccumulatorArgs<'a> { - pub return_type: &'a ConcreteDataType, - pub schema: &'a RelationDesc, - pub ignore_nulls: bool, - /// The expressions in the `ORDER BY` clause passed to this aggregator. - /// - /// SQL allows the user to specify the ordering of arguments to the - /// aggregate using an `ORDER BY`. For example: - /// - /// ```sql - /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; - /// ``` - /// - /// If no `ORDER BY` is specified, `ordering_req` will be empty. - pub ordering_req: &'a OrderingReq, - pub is_reversed: bool, - pub name: &'a str, - pub is_distinct: bool, - pub exprs: &'a [ScalarExpr], -} - #[derive(Debug, Clone)] pub struct OrderingReq { pub exprs: Vec, } +impl OrderingReq { + pub fn to_lex_ordering(&self, schema: &DFSchema) -> Result { + Ok(LexOrdering::new( + self.exprs + .iter() + .map(|e| e.to_sort_phy_expr(schema)) + .collect::>()?, + )) + } +} + #[derive(Debug, Clone)] pub struct SortExpr { /// expression representing the column to sort @@ -77,3 +45,17 @@ pub struct SortExpr { /// Whether to sort nulls first pub nulls_first: bool, } + +impl SortExpr { + pub fn to_sort_phy_expr(&self, schema: &DFSchema) -> Result { + let phy = self.expr.as_physical_expr(schema)?; + let sort_options = datafusion_common::arrow::compute::SortOptions { + descending: self.descending, + nulls_first: self.nulls_first, + }; + Ok(PhysicalSortExpr { + expr: phy, + options: sort_options, + }) + } +} diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 94af5e0e425a..3061981fa6c5 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -15,9 +15,12 @@ //! Scalar expressions. use std::collections::{BTreeMap, BTreeSet}; +use std::sync::Arc; use arrow::array::{make_array, ArrayData, ArrayRef}; use common_error::ext::BoxedError; +use datafusion_common::DFSchema; +use datafusion_physical_expr::PhysicalExpr; use datatypes::prelude::{ConcreteDataType, DataType}; use datatypes::value::Value; use datatypes::vectors::{BooleanVector, Helper, VectorRef}; @@ -26,7 +29,8 @@ use itertools::Itertools; use snafu::{ensure, OptionExt, ResultExt}; use crate::error::{ - DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu, + DatafusionSnafu, Error, InvalidQuerySnafu, NotImplementedSnafu, UnexpectedSnafu, + UnsupportedTemporalFilterSnafu, }; use crate::expr::error::{ ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu, @@ -95,6 +99,23 @@ pub enum ScalarExpr { } impl ScalarExpr { + // TODO(discord9): impl more convert? 
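With `create_accumulator` assembling DataFusion's `AccumulatorArgs` and wrapping the result, a whitelisted accumulator can be driven entirely through the flow-side trait. A minimal sketch exercising the adapter directly (the input values and expected minimum are illustrative):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::functions_aggregate::min_max::MinAccumulator;
use datatypes::value::Value;
use datatypes::vectors::{UInt32Vector, VectorRef};

// MinAccumulator implements AcceptDfAccumulator, so it passes the whitelist.
let inner = MinAccumulator::try_new(&DataType::UInt32).unwrap();
let mut acc = DfAccumulatorAdapter::new(inner);

let input: VectorRef = Arc::new(UInt32Vector::from_vec(vec![3, 1, 2]));
acc.update_batch(&[input]).unwrap();

// evaluate() does not consume the state, so it can be polled on every tick.
assert_eq!(acc.evaluate().unwrap(), Value::from(1u32));
```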
+ pub fn as_physical_expr(&self, df_schema: &DFSchema) -> Result, Error> { + match self { + Self::Column(i) => { + let field = df_schema.field(*i); + datafusion::physical_expr::expressions::col(field.name(), df_schema.as_arrow()) + .with_context(|_| DatafusionSnafu { + context: "Failed to create datafusion column expression", + }) + } + _ => NotImplementedSnafu { + reason: "Not implemented yet".to_string(), + } + .fail()?, + } + } + pub fn with_type(self, typ: ColumnType) -> TypedExpr { TypedExpr::new(self, typ) } From e7b453ad3dccdd2785ca0c5510e910fcf81d1cb7 Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 12 Feb 2025 16:14:18 +0800 Subject: [PATCH 09/18] feat: switch transform to v2(need test) --- src/flow/src/compute/render/reduce.rs | 2 + src/flow/src/expr/relation.rs | 7 +- src/flow/src/expr/relation/udaf.rs | 4 +- src/flow/src/plan.rs | 4 +- src/flow/src/plan/reduce.rs | 34 +++++++++ src/flow/src/transform/aggr.rs | 100 +++++++++++++------------- 6 files changed, 99 insertions(+), 52 deletions(-) diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs index 0bbc613260d7..ff45f8202f7d 100644 --- a/src/flow/src/compute/render/reduce.rs +++ b/src/flow/src/compute/render/reduce.rs @@ -252,6 +252,7 @@ impl Context<'_, '_> { ) -> Option> { match reduce_plan { ReducePlan::Distinct => None, + ReducePlan::AccumulableV2(_) => None, ReducePlan::Accumulable(AccumulablePlan { distinct_aggrs, .. }) => { (!distinct_aggrs.is_empty()).then(|| { std::iter::repeat_with(|| { @@ -679,6 +680,7 @@ fn reduce_subgraph( send, }, ), + ReducePlan::AccumulableV2(_) => unimplemented!(), }; } diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index e7460723784d..be0befa6bfe2 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -18,6 +18,7 @@ pub(crate) use accum::{Accum, Accumulator}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::AggregateUDF; use datatypes::prelude::{ConcreteDataType, DataType}; +use derive_where::derive_where; pub(crate) use func::AggregateFunc; use snafu::ResultExt; pub use udaf::{OrderingReq, SortExpr}; @@ -46,9 +47,13 @@ pub struct AggregateExpr { pub distinct: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)] +#[derive_where(Ord)] pub struct AggregateExprV2 { + /// skipping `Ord` impl for func for convenience + #[derive_where(skip)] pub func: AggregateUDF, + /// should only be a simple column ref list pub args: Vec, /// Output / return type of this aggregate pub return_type: ConcreteDataType, diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs index e8aa013574b7..f3e07642a8eb 100644 --- a/src/flow/src/expr/relation/udaf.rs +++ b/src/flow/src/expr/relation/udaf.rs @@ -20,7 +20,7 @@ use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use crate::expr::ScalarExpr; use crate::Result; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct OrderingReq { pub exprs: Vec, } @@ -36,7 +36,7 @@ impl OrderingReq { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct SortExpr { /// expression representing the column to sort pub expr: ScalarExpr, diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs index b2c91015e065..e7aa8e2ed5f8 100644 --- a/src/flow/src/plan.rs +++ b/src/flow/src/plan.rs @@ -23,7 +23,9 @@ use std::collections::BTreeSet; use crate::error::Error; use crate::expr::{GlobalId, Id, LocalId, MapFilterProject, SafeMfpPlan, 
ScalarExpr, TypedExpr}; use crate::plan::join::JoinPlan; -pub(crate) use crate::plan::reduce::{AccumulablePlan, AggrWithIndex, KeyValPlan, ReducePlan}; +pub(crate) use crate::plan::reduce::{ + AccumulablePlan, AccumulablePlanV2, AggrWithIndex, AggrWithIndexV2, KeyValPlan, ReducePlan, +}; use crate::repr::{DiffRow, RelationDesc}; /// A plan for a dataflow component. But with type to indicate the output type of the relation. diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs index 65a83756b57f..a6fa7e2cbd77 100644 --- a/src/flow/src/plan/reduce.rs +++ b/src/flow/src/plan/reduce.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::expr::relation::AggregateExprV2; use crate::expr::{AggregateExpr, SafeMfpPlan, ScalarExpr}; /// Describe how to extract key-value pair from a `Row` @@ -43,6 +44,8 @@ pub enum ReducePlan { /// Plan for computing only accumulable aggregations. /// Including simple functions like `sum`, `count`, `min/max`(without deletion) Accumulable(AccumulablePlan), + /// Calling AggregateExprV2 + AccumulableV2(AccumulablePlanV2), } /// Accumulable plan for the execution of a reduction. @@ -85,3 +88,34 @@ impl AggrWithIndex { } } } + +/// Accumulable plan for the execution of a reduction. +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +pub struct AccumulablePlanV2 { + /// All of the aggregations we were asked to compute, stored + /// in order. + pub full_aggrs: Vec, +} + +/// Invariant: the output index is the index of the aggregation in `full_aggrs` +/// which means output index is always smaller than the length of `full_aggrs` +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +pub struct AggrWithIndexV2 { + /// aggregation expression + pub expr: AggregateExprV2, + /// index of aggr input among input row, get from `self.expr.args` + pub input_idxs: Vec, + /// index of aggr output among output row + pub output_idx: usize, +} + +impl AggrWithIndexV2 { + /// Create a new `AggrWithIndex` + pub fn new(expr: AggregateExprV2, input_idxs: Vec, output_idx: usize) -> Self { + Self { + expr, + input_idxs, + output_idx, + } + } +} diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index e1bde528b4c0..d34744140985 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -21,12 +21,14 @@ use substrait_proto::proto::aggregate_rel::{Grouping, Measure}; use substrait_proto::proto::function_argument::ArgType; use substrait_proto::proto::sort_field::{SortDirection, SortKind}; -use crate::error::{DatafusionSnafu, DatatypesSnafu, Error, NotImplementedSnafu, PlanSnafu}; +use crate::error::{ + DatafusionSnafu, DatatypesSnafu, Error, NotImplementedSnafu, PlanSnafu, UnexpectedSnafu, +}; use crate::expr::relation::{AggregateExprV2, OrderingReq, SortExpr}; use crate::expr::{ AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc, }; -use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan}; +use crate::plan::{AccumulablePlanV2, AggrWithIndexV2, KeyValPlan, Plan, ReducePlan, TypedPlan}; use crate::repr::{ColumnType, RelationDesc, RelationType}; use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions}; @@ -354,7 +356,7 @@ impl KeyValPlan { /// /// will also change aggregate expr to use column ref if necessary fn from_substrait_gen_key_val_plan( - aggr_exprs: &mut [AggregateExpr], + aggr_exprs: &mut [AggregateExprV2], group_exprs: &[TypedExpr], input_arity: 
usize, ) -> Result { @@ -371,23 +373,33 @@ impl KeyValPlan { // val_plan is extracted from aggr_exprs to give aggr function it's necessary input // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref let val_plan = { - let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none()); + let need_mfp = aggr_exprs + .iter() + .any(|agg| agg.args.iter().any(|e| e.as_column().is_none())); if need_mfp { // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp - let input_exprs = aggr_exprs - .iter_mut() - .enumerate() - .map(|(idx, aggr)| { - let ret = aggr.expr.clone(); - aggr.expr = ScalarExpr::Column(idx); - ret - }) - .collect_vec(); - let aggr_arity = aggr_exprs.len(); - + let mut input_exprs = Vec::new(); + for aggr_expr in aggr_exprs.iter_mut() { + for arg in aggr_expr.args.iter_mut() { + match arg.as_column() { + Some(idx) => { + // directly refer to column in mfp + *arg = ScalarExpr::Column(input_exprs.len()); + input_exprs.push(ScalarExpr::Column(idx)); + } + None => { + // create a new expr and let arg ref to that expr's column instead + let ret = arg.clone(); + *arg = ScalarExpr::Column(input_exprs.len()); + input_exprs.push(ret); + } + } + } + } + let new_input_len = input_exprs.len(); MapFilterProject::new(input_arity) .map(input_exprs)? - .project(input_arity..input_arity + aggr_arity)? + .project(input_arity..input_arity + new_input_len)? } else { // simply take all inputs as value MapFilterProject::new(input_arity) @@ -444,7 +456,7 @@ impl TypedPlan { let time_index = find_time_index_in_group_exprs(&group_exprs); - let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures( + let mut aggr_exprs = AggregateExprV2::from_substrait_agg_measures( ctx, &agg.measures, &input.schema, @@ -476,9 +488,7 @@ impl TypedPlan { } for aggr in &aggr_exprs { - output_types.push(ColumnType::new_nullable( - aggr.func.signature().output.clone(), - )); + output_types.push(ColumnType::new_nullable(aggr.return_type.clone())); // TODO(discord9): find a clever way to name them? 
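The value-plan rewrite above guarantees that every aggregate argument is a plain column reference before the reduce operator runs: non-column arguments are materialized by a prefix MFP, and the argument list is repointed at the MFP's outputs. A minimal sketch of the resulting shape for `sum(a + b)` over a two-column input (the typed `AddUInt32` variant is an illustrative choice):

```rust
use crate::expr::{BinaryFunc, MapFilterProject, ScalarExpr};

// `a + b` is not a column ref, so the prefix MFP computes it first:
let arg = ScalarExpr::Column(0).call_binary(ScalarExpr::Column(1), BinaryFunc::AddUInt32);
let val_plan = MapFilterProject::new(2)
    .map(vec![arg])     // materialize a + b as column index 2
    .unwrap()
    .project(2..3)      // expose only the computed column
    .unwrap();
// the aggregate's args are then rewritten to [ScalarExpr::Column(0)],
// i.e. column 0 of val_plan's output row
```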
output_names.push(None); } @@ -494,36 +504,30 @@ impl TypedPlan { // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs // also set them input/output column - let full_aggrs = aggr_exprs; - let mut simple_aggrs = Vec::new(); - let mut distinct_aggrs = Vec::new(); - for (output_column, aggr_expr) in full_aggrs.iter().enumerate() { - let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu { - reason: "Expect aggregate argument to be transformed into a column at this point", - })?; - if aggr_expr.distinct { - distinct_aggrs.push(AggrWithIndex::new( - aggr_expr.clone(), - input_column, - output_column, - )); - } else { - simple_aggrs.push(AggrWithIndex::new( - aggr_expr.clone(), - input_column, - output_column, - )); - } - } - let accum_plan = AccumulablePlan { - full_aggrs, - simple_aggrs, - distinct_aggrs, - }; + let full_aggrs = aggr_exprs + .into_iter() + .enumerate() + .map(|(idx, aggr)| -> Result<_, Error> { + Ok(AggrWithIndexV2 { + output_idx: idx, + input_idxs: aggr + .args + .iter() + .map(|a| { + a.as_column().with_context(|| UnexpectedSnafu { + reason: format!("Expect {:?} to be a column", a), + }) + }) + .collect::, _>>()?, + expr: aggr, + }) + }) + .collect::, _>>()?; + let accum_plan = AccumulablePlanV2 { full_aggrs }; let plan = Plan::Reduce { input: Box::new(input), key_val_plan, - reduce_plan: ReducePlan::Accumulable(accum_plan), + reduce_plan: ReducePlan::AccumulableV2(accum_plan), }; // FIX(discord9): deal with key first return Ok(TypedPlan { @@ -545,7 +549,7 @@ mod test { use super::*; use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn}; - use crate::plan::{Plan, TypedPlan}; + use crate::plan::{AccumulablePlan, AggrWithIndex, Plan, TypedPlan}; use crate::repr::{ColumnType, RelationType}; use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; use crate::transform::CDT; From 681dd685914a300a5b6d895fb87eb65ea8a8c1a9 Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 12 Feb 2025 16:45:26 +0800 Subject: [PATCH 10/18] tests: fix one test, 69 more to fix --- src/flow/src/expr/relation.rs | 2 ++ src/flow/src/expr/relation/udaf.rs | 3 +++ src/flow/src/transform/aggr.rs | 31 ++++++++++++++++++++++-------- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index be0befa6bfe2..8345095031de 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -70,6 +70,8 @@ pub struct AggregateExprV2 { pub is_nullable: bool, } +impl AggregateExprV2 {} + impl AggregateExprV2 { pub fn create_accumulator(&self) -> Result, Error> { let data_type = self.return_type.as_arrow_type(); diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs index f3e07642a8eb..ee638f5fe07d 100644 --- a/src/flow/src/expr/relation/udaf.rs +++ b/src/flow/src/expr/relation/udaf.rs @@ -26,6 +26,9 @@ pub struct OrderingReq { } impl OrderingReq { + pub fn empty() -> Self { + Self { exprs: vec![] } + } pub fn to_lex_ordering(&self, schema: &DFSchema) -> Result { Ok(LexOrdering::new( self.exprs diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index d34744140985..67313e8c2ac3 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -543,6 +543,7 @@ mod test { use bytes::BytesMut; use common_time::{IntervalMonthDayNano, Timestamp}; + use datafusion::functions_aggregate::sum::sum_udaf; use datatypes::prelude::ConcreteDataType; use 
datatypes::value::Value; use pretty_assertions::assert_eq; @@ -1454,10 +1455,22 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0)], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) @@ -1490,10 +1503,12 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2 { + expr: aggr_expr, + input_idxs: vec![0], + output_idx: 0, + }], }), }, }; From e8d4a9f534ab4fc90f4f4eed659f4b74db7576a5 Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 12 Feb 2025 18:30:10 +0800 Subject: [PATCH 11/18] feat: execute v2 accum(need test) --- src/flow/src/compute/render/reduce.rs | 63 ++++++++++++++------------ src/flow/src/expr.rs | 2 +- src/flow/src/expr/relation.rs | 1 - src/flow/src/expr/relation/accum_v2.rs | 13 ++++-- src/flow/src/plan/reduce.rs | 19 ++++++-- src/flow/src/transform/aggr.rs | 38 ++++++---------- 6 files changed, 76 insertions(+), 60 deletions(-) diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs index ff45f8202f7d..2dd6aba58e74 100644 --- a/src/flow/src/compute/render/reduce.rs +++ b/src/flow/src/compute/render/reduce.rs @@ -17,6 +17,7 @@ use std::ops::Range; use std::sync::Arc; use arrow::array::new_null_array; +use common_error::ext::BoxedError; use common_telemetry::trace; use datatypes::data_type::ConcreteDataType; use datatypes::prelude::DataType; @@ -29,9 +30,14 @@ use snafu::{ensure, OptionExt, ResultExt}; use crate::compute::render::{Context, SubgraphArg}; use crate::compute::types::{Arranged, Collection, CollectionBundle, ErrCollector, Toff}; use crate::error::{Error, NotImplementedSnafu, PlanSnafu}; -use crate::expr::error::{ArrowSnafu, DataAlreadyExpiredSnafu, DataTypeSnafu, InternalSnafu}; -use crate::expr::{Accum, Accumulator, Batch, EvalError, ScalarExpr, VectorDiff}; -use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, ReducePlan, TypedPlan}; +use crate::expr::error::{ + ArrowSnafu, DataAlreadyExpiredSnafu, DataTypeSnafu, ExternalSnafu, InternalSnafu, +}; +use crate::expr::{Batch, EvalError, ScalarExpr}; +use crate::plan::{ + AccumulablePlan, AccumulablePlanV2, AggrWithIndex, AggrWithIndexV2, KeyValPlan, ReducePlan, + TypedPlan, +}; use crate::repr::{self, DiffRow, KeyValDiffRow, RelationType, Row}; use crate::utils::{ArrangeHandler, ArrangeReader, ArrangeWriter, KeyExpiryManager}; @@ -48,13 +54,7 @@ impl Context<'_, '_> { reduce_plan: &ReducePlan, output_type: &RelationType, ) -> Result, Error> { - let accum_plan = if let ReducePlan::Accumulable(accum_plan) = reduce_plan { - if !accum_plan.distinct_aggrs.is_empty() { - 
NotImplementedSnafu { - reason: "Distinct aggregation is not supported in batch mode", - } - .fail()? - } + let accum_plan = if let ReducePlan::AccumulableV2(accum_plan) = reduce_plan { accum_plan.clone() } else { NotImplementedSnafu { @@ -358,7 +358,7 @@ fn reduce_batch_subgraph( arrange: &ArrangeHandler, src_data: impl IntoIterator, key_val_plan: &KeyValPlan, - accum_plan: &AccumulablePlan, + accum_plan: &AccumulablePlanV2, output_type: &RelationType, SubgraphArg { now, @@ -530,39 +530,46 @@ fn reduce_batch_subgraph( err_collector.run(|| -> Result<(), _> { let (accums, _, _) = arrange.get(now, &key).unwrap_or_default(); let accum_list = - from_accum_values_to_live_accums(accums.unpack(), accum_plan.simple_aggrs.len())?; + from_accum_values_to_live_accums(accums.unpack(), accum_plan.full_aggrs.len())?; let mut accum_output = AccumOutput::new(); - for AggrWithIndex { + for AggrWithIndexV2 { expr, - input_idx, + input_idxs, output_idx, - } in accum_plan.simple_aggrs.iter() + state_types, + } in accum_plan.full_aggrs.iter() { let cur_accum_value = accum_list.get(*output_idx).cloned().unwrap_or_default(); - let mut cur_accum = if cur_accum_value.is_empty() { - Accum::new_accum(&expr.func.clone())? - } else { - Accum::try_into_accum(&expr.func, cur_accum_value)? - }; + let mut cur_accum = expr + .create_accumulator() + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + cur_accum.merge_states(&[cur_accum_value], state_types)?; for val_batch in val_batches.iter() { // if batch is empty, input null instead - let cur_input = val_batch - .batch() - .get(*input_idx) - .cloned() - .unwrap_or_else(|| Arc::new(NullVector::new(val_batch.row_count()))); + let batch = val_batch.batch(); + + let cur_input = input_idxs + .iter() + .map(|idx| { + batch + .get(*idx) + .cloned() + .unwrap_or_else(|| Arc::new(NullVector::new(val_batch.row_count()))) + }) + .collect_vec(); let len = cur_input.len(); - cur_accum.update_batch(&expr.func, VectorDiff::from(cur_input))?; + cur_accum.update_batch(&cur_input)?; trace!("Reduce accum after take {} rows: {:?}", len, cur_accum); } - let final_output = cur_accum.eval(&expr.func)?; + let final_output = cur_accum.evaluate()?; trace!("Reduce accum final output: {:?}", final_output); accum_output.insert_output(*output_idx, final_output); - let cur_accum_value = cur_accum.into_state(); + let cur_accum_value = cur_accum.state()?.into_iter().map(|(v, _)| v).collect(); accum_output.insert_accum(*output_idx, cur_accum_value); } diff --git a/src/flow/src/expr.rs b/src/flow/src/expr.rs index a3c12a974247..86055195326b 100644 --- a/src/flow/src/expr.rs +++ b/src/flow/src/expr.rs @@ -34,7 +34,7 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc} pub(crate) use id::{GlobalId, Id, LocalId}; use itertools::Itertools; pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan}; -pub(crate) use relation::{Accum, Accumulator, AggregateExpr, AggregateFunc}; +pub(crate) use relation::{AggregateExpr, AggregateFunc}; pub(crate) use scalar::{ScalarExpr, TypedExpr}; use snafu::{ensure, ResultExt}; diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index 8345095031de..039f9d64b0cd 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -14,7 +14,6 @@ //! Describes an aggregation function and it's input expression. 
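For reference, the per-key loop in reduce_batch_subgraph above follows a fixed lifecycle: restore the persisted state, fold in the new value batches, emit the aggregate, then write the state back. A trimmed-down sketch of that lifecycle follows (hypothetical stand-in trait; the real AccumulatorV2 works over VectorRef and threads EvalError and the arrangement read/write):

    trait Accum {
        fn merge_state(&mut self, state: &[i64]); // restore state read from the arrangement
        fn update_batch(&mut self, values: &[i64]); // fold in one batch of input values
        fn evaluate(&self) -> i64; // current aggregate value, state kept intact
        fn state(&mut self) -> Vec<i64>; // state row to persist back
    }

    #[derive(Default)]
    struct SumAccum(i64);

    impl Accum for SumAccum {
        fn merge_state(&mut self, state: &[i64]) {
            self.0 += state.iter().sum::<i64>();
        }
        fn update_batch(&mut self, values: &[i64]) {
            self.0 += values.iter().sum::<i64>();
        }
        fn evaluate(&self) -> i64 {
            self.0
        }
        fn state(&mut self) -> Vec<i64> {
            vec![self.0]
        }
    }

    fn main() {
        let mut accum = SumAccum::default();
        accum.merge_state(&[10]);            // 1. restore this key's stored state
        accum.update_batch(&[1, 2, 3]);      // 2. apply the new value batches
        assert_eq!(accum.evaluate(), 16);    // 3. emit the aggregate output
        assert_eq!(accum.state(), vec![16]); // 4. persist the updated state
    }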
-pub(crate) use accum::{Accum, Accumulator}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::AggregateUDF; use datatypes::prelude::{ConcreteDataType, DataType}; diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index 6b7a3acb81f7..55d4ddcf0cc2 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -41,7 +41,7 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { /// Returns the intermediate state of the accumulator, consuming the intermediate state. /// /// note that Value::Null's type is unknown, so (Value, ConcreteDataType) is used instead of just Value - fn into_state(self) -> Result, EvalError>; + fn state(&mut self) -> Result, EvalError>; /// Merges the states of multiple accumulators into this accumulator. /// The states array passed was formed by concatenating the results of calling `Self::into_state` on zero or more other Accumulator instances. @@ -64,6 +64,11 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { fn supports_retract_batch(&self) -> bool { false } + + fn merge_states(&mut self, states: &[Vec], typs: &[CDT]) -> Result<(), EvalError> { + let states = states_to_batch(states, typs)?; + self.merge_batch(&states) + } } /// Adapter for several hand-picked datafusion accumulators that can be used in flow @@ -118,7 +123,7 @@ impl AccumulatorV2 for DfAccumulatorAdapter { self.inner.size() } - fn into_state(mut self) -> Result, EvalError> { + fn state(&mut self) -> Result, EvalError> { let state = self.inner.state().context(DatafusionSnafu { context: "failed to get state: {}", })?; @@ -149,7 +154,7 @@ impl AccumulatorV2 for DfAccumulatorAdapter { /// Convert a list of states(from `Accumulator::into_state`) /// to a batch of vectors(that can be feed to `Accumulator::merge_batch`) -fn states_to_batch(states: Vec>, dts: Vec) -> Result, EvalError> { +fn states_to_batch(states: &[Vec], dts: &[CDT]) -> Result, EvalError> { if states.is_empty() || states[0].is_empty() { return Ok(vec![]); } @@ -171,7 +176,7 @@ fn states_to_batch(states: Vec>, dts: Vec) -> Result>(); for i in 0..state_len { diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs index a6fa7e2cbd77..4b93fc7bccd7 100644 --- a/src/flow/src/plan/reduce.rs +++ b/src/flow/src/plan/reduce.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
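The states_to_batch helper reworked in this patch is essentially a row-to-column transpose: each Vec<Value> is one accumulator's state row, and the output is one vector per state field so it can be fed to merge_batch. A plain-Rust sketch of that shape transformation, with i64 standing in for Value and Vec<i64> for the mutable vectors being built (assumes all state rows have the same number of fields):

    // N state rows of K fields each become K columns of N values each.
    fn states_to_columns(states: &[Vec<i64>]) -> Vec<Vec<i64>> {
        if states.is_empty() || states[0].is_empty() {
            return vec![];
        }
        let field_count = states[0].len();
        let mut columns = vec![Vec::with_capacity(states.len()); field_count];
        for state in states {
            for (i, field) in state.iter().enumerate() {
                columns[i].push(*field);
            }
        }
        columns
    }

    fn main() {
        // two accumulators, each with a (sum, count) state row
        let states = vec![vec![10, 3], vec![5, 2]];
        // after the transpose: sums in column 0, counts in column 1
        assert_eq!(states_to_columns(&states), vec![vec![10, 5], vec![3, 2]]);
    }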
+use datatypes::prelude::ConcreteDataType; + use crate::expr::relation::AggregateExprV2; use crate::expr::{AggregateExpr, SafeMfpPlan, ScalarExpr}; @@ -97,6 +99,8 @@ pub struct AccumulablePlanV2 { pub full_aggrs: Vec, } +/// This struct basically get useful info from `expr` and store it so no need +/// to get it repeatly /// Invariant: the output index is the index of the aggregation in `full_aggrs` /// which means output index is always smaller than the length of `full_aggrs` #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] @@ -107,15 +111,24 @@ pub struct AggrWithIndexV2 { pub input_idxs: Vec, /// index of aggr output among output row pub output_idx: usize, + /// The types of intermidate state field + pub state_types: Vec, } impl AggrWithIndexV2 { /// Create a new `AggrWithIndex` - pub fn new(expr: AggregateExprV2, input_idxs: Vec, output_idx: usize) -> Self { - Self { + pub fn new( + expr: AggregateExprV2, + input_idxs: Vec, + output_idx: usize, + ) -> Result { + let mut test_accum = expr.create_accumulator()?; + let states = test_accum.state()?; + Ok(Self { expr, input_idxs, output_idx, - } + state_types: states.into_iter().map(|(_, ty)| ty).collect(), + }) } } diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 67313e8c2ac3..a3f041eac314 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -504,25 +504,21 @@ impl TypedPlan { // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs // also set them input/output column - let full_aggrs = aggr_exprs - .into_iter() - .enumerate() - .map(|(idx, aggr)| -> Result<_, Error> { - Ok(AggrWithIndexV2 { - output_idx: idx, - input_idxs: aggr - .args - .iter() - .map(|a| { - a.as_column().with_context(|| UnexpectedSnafu { - reason: format!("Expect {:?} to be a column", a), - }) - }) - .collect::, _>>()?, - expr: aggr, + let mut full_aggrs = vec![]; + for (idx, aggr) in aggr_exprs.into_iter().enumerate() { + let input_idxs = aggr + .args + .iter() + .map(|a| { + a.as_column().with_context(|| UnexpectedSnafu { + reason: format!("Expect {:?} to be a column", a), + }) }) - }) - .collect::, _>>()?; + .collect::, _>>()?; + let aggr = AggrWithIndexV2::new(aggr, input_idxs, idx)?; + full_aggrs.push(aggr); + } + let accum_plan = AccumulablePlanV2 { full_aggrs }; let plan = Plan::Reduce { input: Box::new(input), @@ -1504,11 +1500,7 @@ mod test { .into_safe(), }, reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { - full_aggrs: vec![AggrWithIndexV2 { - expr: aggr_expr, - input_idxs: vec![0], - output_idx: 0, - }], + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), }, }; From 75bef7b4596a45dc4c3f5fa8d02e39a9d8e142c9 Mon Sep 17 00:00:00 2001 From: discord9 Date: Wed, 12 Feb 2025 19:02:44 +0800 Subject: [PATCH 12/18] tests: one fixed --- src/flow/src/compute/render/reduce.rs | 45 ++++++++++++++++----------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs index 2dd6aba58e74..2bdf47882c10 100644 --- a/src/flow/src/compute/render/reduce.rs +++ b/src/flow/src/compute/render/reduce.rs @@ -545,7 +545,9 @@ fn reduce_batch_subgraph( .create_accumulator() .map_err(BoxedError::new) .context(ExternalSnafu)?; - cur_accum.merge_states(&[cur_accum_value], state_types)?; + if !cur_accum_value.is_empty() { + cur_accum.merge_states(&[cur_accum_value], state_types)?; + } for val_batch in val_batches.iter() { // if batch is empty, input null instead @@ -1220,12 
+1222,14 @@ mod test { use std::time::Duration; use common_time::Timestamp; + use datafusion::functions_aggregate::sum::sum_udaf; use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT}; use hydroflow::scheduled::graph::Hydroflow; use super::*; use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check}; use crate::compute::state::DataflowState; + use crate::expr::relation::{AggregateExprV2, OrderingReq}; use crate::expr::{ self, AggregateExpr, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc, }; @@ -1590,26 +1594,29 @@ mod test { val_plan: MapFilterProject::new(1).project([0]).unwrap().into_safe(), }; - let simple_aggrs = vec![AggrWithIndex::new( - AggregateExpr { - func: AggregateFunc::SumInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - 0, - 0, - )]; - let accum_plan = AccumulablePlan { - full_aggrs: vec![AggregateExpr { - func: AggregateFunc::SumInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }], - simple_aggrs, - distinct_aggrs: vec![], + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0)], + return_type: CDT::int64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::int32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::int64_datatype()], + is_nullable: true, }; - let reduce_plan = ReducePlan::Accumulable(accum_plan); + let accum_plan = AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], + }; + + let reduce_plan = ReducePlan::AccumulableV2(accum_plan); let bundle = ctx .render_reduce_batch( Box::new(input_plan.with_types(typ.into_unnamed())), From 3e672019eaeb6dd78d00c89607a9d9371eae99b2 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 13 Feb 2025 14:59:24 +0800 Subject: [PATCH 13/18] WIP --- Cargo.lock | 12 ++++++++++++ src/flow/src/expr/scalar.rs | 11 +++++++++-- src/flow/src/transform/aggr.rs | 18 ++++++++++++++---- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e4091bcf5384..2c6cd82cc601 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3517,6 +3517,17 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "derive-where" +version = "1.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62d671cc41a825ebabc75757b62d3d168c577f9149b2d49ece1dad1f72119d25" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "derive_arbitrary" version = "1.3.2" @@ -4184,6 +4195,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-substrait", "datatypes", + "derive-where", "enum-as-inner", "enum_dispatch", "futures", diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 3061981fa6c5..9df6f849143a 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -29,8 +29,8 @@ use itertools::Itertools; use snafu::{ensure, OptionExt, ResultExt}; use crate::error::{ - DatafusionSnafu, Error, InvalidQuerySnafu, NotImplementedSnafu, UnexpectedSnafu, - UnsupportedTemporalFilterSnafu, + DatafusionSnafu, DatatypesSnafu, Error, InvalidQuerySnafu, NotImplementedSnafu, + UnexpectedSnafu, UnsupportedTemporalFilterSnafu, }; use crate::expr::error::{ ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu, @@ -109,6 +109,13 @@ impl ScalarExpr { 
context: "Failed to create datafusion column expression", }) } + Self::Literal(val, typ) => { + let val = val.try_to_scalar_value(typ).context(DatatypesSnafu { + context: format!("Failed to convert val=:{:?} to literal", val), + })?; + let ret = datafusion::physical_expr::expressions::lit(val); + Ok(ret) + } _ => NotImplementedSnafu { reason: "Not implemented yet".to_string(), } diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index a3f041eac314..c5b41aa3943e 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -335,6 +335,8 @@ impl AggregateExprV2 { }; // TODO(discord9): determine other options from substrait too instead of default + dbg!(&args); + dbg!(&input_schema); Ok(Self { func: fn_impl.as_ref().clone(), args: args.into_iter().map(|a| a.expr).collect(), @@ -371,7 +373,7 @@ impl KeyValPlan { .project(input_arity..input_arity + output_arity)?; // val_plan is extracted from aggr_exprs to give aggr function it's necessary input - // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref + // and since aggr func need inputs that is column ref(or literal), we just add a prefix mfp to transform any expr that is not into a column ref let val_plan = { let need_mfp = aggr_exprs .iter() @@ -380,14 +382,22 @@ impl KeyValPlan { // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp let mut input_exprs = Vec::new(); for aggr_expr in aggr_exprs.iter_mut() { + // FIX: also modify input_schema to fit input_exprs, a `new_input_schema` is needed + // so we can separate all scalar compute to a mfp before aggr for arg in aggr_expr.args.iter_mut() { - match arg.as_column() { - Some(idx) => { + match arg.clone() { + ScalarExpr::Column(idx) => { // directly refer to column in mfp *arg = ScalarExpr::Column(input_exprs.len()); input_exprs.push(ScalarExpr::Column(idx)); } - None => { + ScalarExpr::Literal(_, _) => { + // already literal, but still need to make it ref + let ret = arg.clone(); + *arg = ScalarExpr::Column(input_exprs.len()); + input_exprs.push(ret); + } + _ => { // create a new expr and let arg ref to that expr's column instead let ret = arg.clone(); *arg = ScalarExpr::Column(input_exprs.len()); From 9362893732c676d6900a868dfe2ecbffd231ad08 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 13 Feb 2025 17:06:18 +0800 Subject: [PATCH 14/18] fix: literal as column ref --- src/flow/src/compute/render/reduce.rs | 4 ++- src/flow/src/expr/relation.rs | 6 ++--- src/flow/src/expr/relation/accum_v2.rs | 2 +- src/flow/src/expr/scalar.rs | 11 ++++++++ src/flow/src/plan/reduce.rs | 2 +- src/flow/src/transform/aggr.rs | 37 ++++++++++++++++---------- src/flow/src/transform/expr.rs | 1 + 7 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs index 2bdf47882c10..393ad24ce0e9 100644 --- a/src/flow/src/compute/render/reduce.rs +++ b/src/flow/src/compute/render/reduce.rs @@ -1596,7 +1596,9 @@ mod test { let aggr_expr = AggregateExprV2 { func: sum_udaf().as_ref().clone(), - args: vec![ScalarExpr::Column(0)], + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::int64_datatype(), false)) + ], return_type: CDT::int64_datatype(), name: "sum".to_string(), schema: RelationType::new(vec![ColumnType::new( diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index 039f9d64b0cd..2e5bd75f9de2 100644 --- a/src/flow/src/expr/relation.rs +++ 
b/src/flow/src/expr/relation.rs @@ -24,7 +24,7 @@ pub use udaf::{OrderingReq, SortExpr}; use crate::error::DatafusionSnafu; use crate::expr::relation::accum_v2::{AccumulatorV2, DfAccumulatorAdapter}; -use crate::expr::ScalarExpr; +use crate::expr::{ScalarExpr, TypedExpr}; use crate::repr::RelationDesc; use crate::Error; @@ -53,7 +53,7 @@ pub struct AggregateExprV2 { #[derive_where(skip)] pub func: AggregateUDF, /// should only be a simple column ref list - pub args: Vec, + pub args: Vec, /// Output / return type of this aggregate pub return_type: ConcreteDataType, pub name: String, @@ -79,7 +79,7 @@ impl AggregateExprV2 { let exprs = self .args .iter() - .map(|e| e.as_physical_expr(&schema)) + .map(|e| e.expr.as_physical_expr(&schema)) .collect::, _>>()?; let accum_args = AccumulatorArgs { return_type: &data_type, diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index 55d4ddcf0cc2..cd5af6d8d5c5 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -27,7 +27,7 @@ use crate::expr::error::{DataTypeSnafu, DatafusionSnafu, InternalSnafu, TryFromV use crate::expr::EvalError; /// Basically a copy of datafusion's Accumulator, but with a few modifications -/// to accomodate our needs in flow and keep the upgradability of datafusion +/// to accommodate our needs in flow and keep the upgradability of datafusion pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { /// Updates the accumulator’s state from its input. fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError>; diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 9df6f849143a..7288146bb511 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -103,6 +103,17 @@ impl ScalarExpr { pub fn as_physical_expr(&self, df_schema: &DFSchema) -> Result, Error> { match self { Self::Column(i) => { + if *i >= df_schema.fields().len() { + return InvalidQuerySnafu { + reason: format!( + "column index {} out of range of len={} in df_schema={:?}", + i, + df_schema.fields().len(), + df_schema + ), + } + .fail(); + } let field = df_schema.field(*i); datafusion::physical_expr::expressions::col(field.name(), df_schema.as_arrow()) .with_context(|_| DatafusionSnafu { diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs index 4b93fc7bccd7..af7f40213abe 100644 --- a/src/flow/src/plan/reduce.rs +++ b/src/flow/src/plan/reduce.rs @@ -100,7 +100,7 @@ pub struct AccumulablePlanV2 { } /// This struct basically get useful info from `expr` and store it so no need -/// to get it repeatly +/// to get it repeatedly /// Invariant: the output index is the index of the aggregation in `full_aggrs` /// which means output index is always smaller than the length of `full_aggrs` #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index c5b41aa3943e..f1b8b7238451 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -335,11 +335,9 @@ impl AggregateExprV2 { }; // TODO(discord9): determine other options from substrait too instead of default - dbg!(&args); - dbg!(&input_schema); Ok(Self { func: fn_impl.as_ref().clone(), - args: args.into_iter().map(|a| a.expr).collect(), + args, return_type, name: fn_name, schema: input_schema.clone(), @@ -377,7 +375,7 @@ impl KeyValPlan { let val_plan = { let need_mfp = aggr_exprs .iter() - .any(|agg| agg.args.iter().any(|e| e.as_column().is_none())); + .any(|agg| 
agg.args.iter().any(|e| e.expr.as_column().is_none())); if need_mfp { // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp let mut input_exprs = Vec::new(); @@ -385,31 +383,40 @@ impl KeyValPlan { // FIX: also modify input_schema to fit input_exprs, a `new_input_schema` is needed // so we can separate all scalar compute to a mfp before aggr for arg in aggr_expr.args.iter_mut() { - match arg.clone() { + match arg.expr.clone() { ScalarExpr::Column(idx) => { // directly refer to column in mfp - *arg = ScalarExpr::Column(input_exprs.len()); + arg.expr = ScalarExpr::Column(input_exprs.len()); input_exprs.push(ScalarExpr::Column(idx)); } ScalarExpr::Literal(_, _) => { // already literal, but still need to make it ref - let ret = arg.clone(); - *arg = ScalarExpr::Column(input_exprs.len()); + let ret = arg.expr.clone(); + arg.expr = ScalarExpr::Column(input_exprs.len()); input_exprs.push(ret); } _ => { // create a new expr and let arg ref to that expr's column instead - let ret = arg.clone(); - *arg = ScalarExpr::Column(input_exprs.len()); + let ret = arg.expr.clone(); + arg.expr = ScalarExpr::Column(input_exprs.len()); input_exprs.push(ret); } } } } let new_input_len = input_exprs.len(); - MapFilterProject::new(input_arity) + let pre_mfp = MapFilterProject::new(input_arity) .map(input_exprs)? - .project(input_arity..input_arity + new_input_len)? + .project(input_arity..input_arity + new_input_len)?; + // adjust input schema according to pre_mfp + if let Some(first) = aggr_exprs.first() { + let new_input_schema = first.schema.apply_mfp(&pre_mfp.clone().into_safe())?; + + for aggr in aggr_exprs.iter_mut() { + aggr.schema = new_input_schema.clone(); + } + } + pre_mfp } else { // simply take all inputs as value MapFilterProject::new(input_arity) @@ -520,7 +527,7 @@ impl TypedPlan { .args .iter() .map(|a| { - a.as_column().with_context(|| UnexpectedSnafu { + a.expr.as_column().with_context(|| UnexpectedSnafu { reason: format!("Expect {:?} to be a column", a), }) }) @@ -1463,7 +1470,9 @@ mod test { let aggr_expr = AggregateExprV2 { func: sum_udaf().as_ref().clone(), - args: vec![ScalarExpr::Column(0)], + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], return_type: CDT::uint64_datatype(), name: "sum".to_string(), schema: RelationType::new(vec![ColumnType::new( diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index b6c750979290..981748267f65 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -292,6 +292,7 @@ impl TypedExpr { 1 if UnaryFunc::is_valid_func_name(fn_name) => { let func = UnaryFunc::from_str_and_type(fn_name, None)?; let arg = arg_exprs[0].clone(); + // TODO(discord9); forward nullable to return type too? 
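Note: the new ScalarExpr::Literal branch of as_physical_expr leans on DataFusion's physical expression constructors, the same ones the Column branch already uses. A small standalone sketch of the two constructors (the schema and values here are made up for illustration):

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::{col, lit};
    use datafusion::scalar::ScalarValue;

    fn main() -> datafusion::error::Result<()> {
        // hypothetical one-column input schema
        let schema = Schema::new(vec![Field::new("number", DataType::UInt32, false)]);

        // column reference resolved by name against the arrow schema,
        // as done for ScalarExpr::Column
        let column_expr = col("number", &schema)?;

        // constant physical expression wrapping a ScalarValue,
        // as done for ScalarExpr::Literal after try_to_scalar_value
        let literal_expr = lit(ScalarValue::UInt32(Some(42)));

        println!("{:?} / {:?}", column_expr, literal_expr);
        Ok(())
    }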
let ret_type = ColumnType::new_nullable(func.signature().output.clone()); Ok(TypedExpr::new(arg.call_unary(func), ret_type)) From dbf88da86b4dc9460352f955a796478228bdc884 Mon Sep 17 00:00:00 2001 From: discord9 Date: Thu, 13 Feb 2025 19:54:38 +0800 Subject: [PATCH 15/18] tests: fix all tests --- src/flow/src/transform/aggr.rs | 442 ++++++++++++++++++++++++--------- 1 file changed, 320 insertions(+), 122 deletions(-) diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index f1b8b7238451..3ec8d97c1f0c 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -556,6 +556,8 @@ mod test { use bytes::BytesMut; use common_time::{IntervalMonthDayNano, Timestamp}; + use datafusion::functions_aggregate::count::count_udaf; + use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion::functions_aggregate::sum::sum_udaf; use datatypes::prelude::ConcreteDataType; use datatypes::value::Value; @@ -579,10 +581,24 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -678,10 +694,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -720,10 +734,24 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -792,10 +820,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -858,16 +884,52 @@ mod test { .unwrap(); let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::Count, - expr: 
ScalarExpr::Column(1), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: count_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![1], + 1, + ) + .unwrap(), ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(4).call_binary( @@ -939,13 +1001,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![ - AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), - ], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( @@ -1009,11 +1066,26 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; + let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) @@ -1077,10 +1149,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -1119,11 +1189,26 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: 
RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; + let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) @@ -1191,10 +1276,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -1232,16 +1315,52 @@ mod test { let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::Count, - expr: ScalarExpr::Column(1), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: count_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![1], + 1, + ) + .unwrap(), ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(2).call_binary( @@ -1307,13 +1426,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![ - AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), - ], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( @@ -1357,16 +1471,52 @@ mod test { .unwrap(); let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::Count, - expr: ScalarExpr::Column(1), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: 
RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: count_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![1], + 1, + ) + .unwrap(), ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(1).call_binary( @@ -1429,13 +1579,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![ - AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), - ], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( @@ -1476,10 +1621,10 @@ mod test { return_type: CDT::uint64_datatype(), name: "sum".to_string(), schema: RelationType::new(vec![ColumnType::new( - ConcreteDataType::uint32_datatype(), - false, + ConcreteDataType::uint64_datatype(), + true, )]) - .into_named(vec![Some("number".to_string())]), + .into_named(vec![None]), ordering_req: OrderingReq::empty(), ignore_nulls: false, is_distinct: false, @@ -1570,11 +1715,7 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![], - simple_aggrs: vec![], - distinct_aggrs: vec![], - }), + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { full_aggrs: vec![] }), }, }; @@ -1592,10 +1733,24 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -1639,10 +1794,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -1674,10 +1827,24 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = 
TypedPlan::from_substrait_plan(&mut ctx, &plan).await; - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) @@ -1723,10 +1890,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), }, }; @@ -1743,16 +1908,52 @@ mod test { let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::MaxUInt32, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::MinUInt32, - expr: ScalarExpr::Column(0), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: max_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "max".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: min_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "min".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![0], + 1, + ) + .unwrap(), ]; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -1830,11 +2031,8 @@ mod test { val_plan: MapFilterProject::new(2) .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( From c5bab3585641d59bd9c769f5416b85adce4e1c96 Mon Sep 17 00:00:00 2001 From: discord9 
Date: Thu, 13 Feb 2025 19:55:00 +0800 Subject: [PATCH 16/18] chore: clippy --- src/flow/src/transform/aggr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 3ec8d97c1f0c..7e7972e55ea9 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -565,7 +565,7 @@ mod test { use super::*; use crate::expr::{BinaryFunc, DfScalarFunction, GlobalId, RawDfScalarFn}; - use crate::plan::{AccumulablePlan, AggrWithIndex, Plan, TypedPlan}; + use crate::plan::{Plan, TypedPlan}; use crate::repr::{ColumnType, RelationType}; use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait}; use crate::transform::CDT; From 448c41c64d32ad88b957c51371412b06d92acc30 Mon Sep 17 00:00:00 2001 From: discord9 Date: Fri, 14 Feb 2025 19:45:17 +0800 Subject: [PATCH 17/18] test: state&eval --- src/flow/src/adapter/node_context.rs | 25 ++--- src/flow/src/expr/relation/accum_v2.rs | 121 +++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 11 deletions(-) diff --git a/src/flow/src/adapter/node_context.rs b/src/flow/src/adapter/node_context.rs index c42b63a49989..7b961a4f3f93 100644 --- a/src/flow/src/adapter/node_context.rs +++ b/src/flow/src/adapter/node_context.rs @@ -64,6 +64,19 @@ pub struct FlownodeContext { pub aggregate_functions: HashMap>, } +pub fn all_built_in_udaf() -> HashMap> { + // sum/min/max/count are built-in aggregate functions + use datafusion::functions_aggregate::count::count_udaf; + use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf}; + use datafusion::functions_aggregate::sum::sum_udaf; + HashMap::from([ + ("sum".to_string(), sum_udaf()), + ("min".to_string(), min_udaf()), + ("max".to_string(), max_udaf()), + ("count".to_string(), count_udaf()), + ]) +} + impl FlownodeContext { pub fn new(table_source: Box) -> Self { let mut ret = Self { @@ -83,17 +96,7 @@ impl FlownodeContext { } pub fn register_all_built_in_aggr_fns(&mut self) { - // sum/min/max/count are built-in aggregate functions - use datafusion::functions_aggregate::count::count_udaf; - use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf}; - use datafusion::functions_aggregate::sum::sum_udaf; - let res = HashMap::from([ - ("sum".to_string(), sum_udaf()), - ("min".to_string(), min_udaf()), - ("max".to_string(), max_udaf()), - ("count".to_string(), count_udaf()), - ]); - self.aggregate_functions.extend(res); + self.aggregate_functions.extend(all_built_in_udaf()); } pub fn register_aggr_fn( diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index cd5af6d8d5c5..9da1dc36460e 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -210,3 +210,124 @@ fn err_try_from_val(reason: T) -> EvalError { } .build() } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datatypes::prelude::ConcreteDataType as CDT; + use datatypes::vectors::{UInt32Vector, UInt64Vector, VectorRef}; + + use crate::adapter::node_context::all_built_in_udaf; + use crate::expr::relation::{AggregateExprV2, OrderingReq}; + use crate::expr::ScalarExpr; + use crate::repr::{ColumnType, RelationType}; + + #[test] + pub fn test_can_get_state_after_eval() { + let udaf_list = all_built_in_udaf(); + let test_cases = [ + ( + AggregateExprV2 { + func: udaf_list.get("sum").unwrap().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), 
+ name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![Arc::new(UInt64Vector::from_slice([1, 2, 3])) as VectorRef], + ), + ( + AggregateExprV2 { + func: udaf_list.get("max").unwrap().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "max".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(CDT::uint32_datatype(), false), + ColumnType::new(CDT::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef], + ), + ( + AggregateExprV2 { + func: udaf_list.get("count").unwrap().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(CDT::uint64_datatype(), true), + ColumnType::new(CDT::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef], + ), + ( + AggregateExprV2 { + func: udaf_list.get("min").unwrap().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "min".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(CDT::uint32_datatype(), false), + ColumnType::new(CDT::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef], + ), + ]; + + for (aggr, input) in test_cases { + let mut accum = aggr.create_accumulator().unwrap(); + accum.update_batch(&input).unwrap(); + let state1 = accum.state().unwrap(); + let (state, dt): (Vec<_>, Vec<_>) = state1.clone().into_iter().unzip(); + + // merge_states() & state() works in pair as expected + let mut accum = aggr.create_accumulator().unwrap(); + accum.merge_states(&[state.clone()], &dt).unwrap(); + let state2 = accum.state().unwrap(); + assert_eq!(state1, state2); + + // call state() after evaluate() works as expected(although this is undefined behavior by datafusion?) 
+ let mut accum = aggr.create_accumulator().unwrap(); + accum.merge_states(&[state.clone()], &dt).unwrap(); + accum.evaluate().unwrap(); + let state3 = accum.state().unwrap(); + assert_eq!(state1, state3); + } + } +} From 2007117bf73f5b71f9cce8d6830d131dfd889d12 Mon Sep 17 00:00:00 2001 From: discord9 Date: Fri, 14 Feb 2025 19:54:22 +0800 Subject: [PATCH 18/18] chore: per review --- src/flow/src/expr/relation.rs | 9 ++---- src/flow/src/expr/relation/accum_v2.rs | 39 +++++--------------------- src/flow/src/expr/scalar.rs | 8 +++--- src/flow/src/transform/aggr.rs | 8 +++--- 4 files changed, 18 insertions(+), 46 deletions(-) diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index 2e5bd75f9de2..a2f20719bae6 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -91,12 +91,9 @@ impl AggregateExprV2 { is_distinct: self.is_distinct, exprs: &exprs, }; - let acc = self - .func - .accumulator(accum_args) - .with_context(|_| DatafusionSnafu { - context: "Fail to build accumulator", - })?; + let acc = self.func.accumulator(accum_args).context(DatafusionSnafu { + context: "Fail to build accumulator", + })?; let acc = DfAccumulatorAdapter::new_unchecked(acc); Ok(Box::new(acc)) } diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs index 9da1dc36460e..835ec5949347 100644 --- a/src/flow/src/expr/relation/accum_v2.rs +++ b/src/flow/src/expr/relation/accum_v2.rs @@ -14,16 +14,13 @@ //! new accumulator trait that is more flexible and can be used in the future for more complex accumulators -use std::any::type_name; -use std::fmt::Display; - use datafusion::logical_expr::Accumulator as DfAccumulator; use datatypes::prelude::{ConcreteDataType as CDT, DataType}; use datatypes::value::Value; use datatypes::vectors::VectorRef; -use snafu::{ensure, OptionExt, ResultExt}; +use snafu::{ensure, ResultExt}; -use crate::expr::error::{DataTypeSnafu, DatafusionSnafu, InternalSnafu, TryFromValueSnafu}; +use crate::expr::error::{DataTypeSnafu, DatafusionSnafu, InternalSnafu}; use crate::expr::EvalError; /// Basically a copy of datafusion's Accumulator, but with a few modifications @@ -65,6 +62,7 @@ pub trait AccumulatorV2: Send + Sync + std::fmt::Debug { false } + /// Merge states using `&[Vec]` instead of `&[VectorRef]` fn merge_states(&mut self, states: &[Vec], typs: &[CDT]) -> Result<(), EvalError> { let states = states_to_batch(states, typs)?; self.merge_batch(&states) @@ -135,7 +133,7 @@ impl AccumulatorV2 for DfAccumulatorAdapter { Ok((val, dt)) }) .collect::, _>>() - .with_context(|_| DataTypeSnafu { + .context(DataTypeSnafu { msg: "failed to convert `ScalarValue` state to `Value`", })?; Ok(state) @@ -179,38 +177,15 @@ fn states_to_batch(states: &[Vec], dts: &[CDT]) -> Result, .iter() .map(|dt| dt.create_mutable_vector(state_len)) .collect::>(); - for i in 0..state_len { - for state in states.iter() { - // the j-th state's i-th value - let val = &state[i]; - ret.get_mut(i) - .with_context(|| InternalSnafu { - reason: format!("failed to get mutable vector at index {}", i), - })? 
- .push_value_ref(val.as_value_ref()); + for (i, vectors) in ret.iter_mut().enumerate() { + for state in states { + vectors.push_value_ref(state[i].as_value_ref()); } } let ret = ret.into_iter().map(|mut v| v.to_vector()).collect(); Ok(ret) } -fn fail_accum() -> EvalError { - InternalSnafu { - reason: format!( - "list of values exhausted before a accum of type {} can be build from it", - type_name::() - ), - } - .build() -} - -fn err_try_from_val(reason: T) -> EvalError { - TryFromValueSnafu { - msg: reason.to_string(), - } - .build() -} - #[cfg(test)] mod test { use std::sync::Arc; diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 7288146bb511..59a2dced63ce 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -103,8 +103,9 @@ impl ScalarExpr { pub fn as_physical_expr(&self, df_schema: &DFSchema) -> Result, Error> { match self { Self::Column(i) => { - if *i >= df_schema.fields().len() { - return InvalidQuerySnafu { + ensure!( + *i < df_schema.fields().len(), + InvalidQuerySnafu { reason: format!( "column index {} out of range of len={} in df_schema={:?}", i, @@ -112,8 +113,7 @@ impl ScalarExpr { df_schema ), } - .fail(); - } + ); let field = df_schema.field(*i); datafusion::physical_expr::expressions::col(field.name(), df_schema.as_arrow()) .with_context(|_| DatafusionSnafu { diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs index 7e7972e55ea9..35c7a9a0c3a3 100644 --- a/src/flow/src/transform/aggr.rs +++ b/src/flow/src/transform/aggr.rs @@ -212,7 +212,7 @@ impl AggregateExprV2 { typ: &RelationDesc, extensions: &FunctionExtensions, ) -> Result, Error> { - let mut all_aggr_exprs = vec![]; + let mut all_aggr_exprs = Vec::with_capacity(measures.len()); for m in measures { let filter = match &m.filter { @@ -244,7 +244,7 @@ impl AggregateExprV2 { // TODO(discord9): impl filter let _ = filter; - let mut args = vec![]; + let mut args = Vec::with_capacity(f.arguments.len()); for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { @@ -378,7 +378,7 @@ impl KeyValPlan { .any(|agg| agg.args.iter().any(|e| e.expr.as_column().is_none())); if need_mfp { // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp - let mut input_exprs = Vec::new(); + let mut input_exprs = Vec::with_capacity(aggr_exprs.len()); for aggr_expr in aggr_exprs.iter_mut() { // FIX: also modify input_schema to fit input_exprs, a `new_input_schema` is needed // so we can separate all scalar compute to a mfp before aggr @@ -521,7 +521,7 @@ impl TypedPlan { // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs // also set them input/output column - let mut full_aggrs = vec![]; + let mut full_aggrs = Vec::with_capacity(aggr_exprs.len()); for (idx, aggr) in aggr_exprs.into_iter().enumerate() { let input_idxs = aggr .args