diff --git a/Cargo.lock b/Cargo.lock
index e4091bcf5384..2c6cd82cc601 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3517,6 +3517,17 @@ dependencies = [
  "syn 2.0.96",
 ]
 
+[[package]]
+name = "derive-where"
+version = "1.2.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "62d671cc41a825ebabc75757b62d3d168c577f9149b2d49ece1dad1f72119d25"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.96",
+]
+
 [[package]]
 name = "derive_arbitrary"
 version = "1.3.2"
@@ -4184,6 +4195,7 @@ dependencies = [
  "datafusion-physical-expr",
  "datafusion-substrait",
  "datatypes",
+ "derive-where",
  "enum-as-inner",
  "enum_dispatch",
  "futures",
diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml
index b4545e1a899a..91e5945ee2db 100644
--- a/src/flow/Cargo.toml
+++ b/src/flow/Cargo.toml
@@ -39,6 +39,7 @@ datafusion-expr.workspace = true
 datafusion-physical-expr.workspace = true
 datafusion-substrait.workspace = true
 datatypes.workspace = true
+derive-where = "1.2.7"
 enum-as-inner = "0.6.0"
 enum_dispatch = "0.3"
 futures = "0.3"
diff --git a/src/flow/src/adapter/node_context.rs b/src/flow/src/adapter/node_context.rs
index 7983b396fedc..7b961a4f3f93 100644
--- a/src/flow/src/adapter/node_context.rs
+++ b/src/flow/src/adapter/node_context.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
 use common_recordbatch::RecordBatch;
 use common_telemetry::trace;
+use datafusion_expr::AggregateUDF;
 use datatypes::prelude::ConcreteDataType;
 use session::context::QueryContext;
 use snafu::{OptionExt, ResultExt};
@@ -59,11 +60,26 @@ pub struct FlownodeContext {
     /// All the tables that have been registered in the worker
     pub table_repr: IdToNameMap,
     pub query_context: Option<Arc<QueryContext>>,
+    /// Aggregate functions registered in the context
+    pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+}
+
+pub fn all_built_in_udaf() -> HashMap<String, Arc<AggregateUDF>> {
+    // sum/min/max/count are built-in aggregate functions
+    use datafusion::functions_aggregate::count::count_udaf;
+    use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf};
+    use datafusion::functions_aggregate::sum::sum_udaf;
+    HashMap::from([
+        ("sum".to_string(), sum_udaf()),
+        ("min".to_string(), min_udaf()),
+        ("max".to_string(), max_udaf()),
+        ("count".to_string(), count_udaf()),
+    ])
 }
 
 impl FlownodeContext {
     pub fn new(table_source: Box) -> Self {
-        Self {
+        let mut ret = Self {
             source_to_tasks: Default::default(),
             flow_to_sink: Default::default(),
             flow_plans: Default::default(),
@@ -73,7 +89,22 @@ impl FlownodeContext {
             table_source,
             table_repr: Default::default(),
             query_context: Default::default(),
-        }
+            aggregate_functions: Default::default(),
+        };
+        ret.register_all_built_in_aggr_fns();
+        ret
+    }
+
+    pub fn register_all_built_in_aggr_fns(&mut self) {
+        self.aggregate_functions.extend(all_built_in_udaf());
+    }
+
+    pub fn register_aggr_fn(
+        &mut self,
+        name: String,
+        aggr_fn: Arc<AggregateUDF>,
+    ) -> Option<Arc<AggregateUDF>> {
+        self.aggregate_functions.insert(name, aggr_fn)
     }
 
     pub fn get_flow_ids(&self, table_id: TableId) -> Option<&BTreeSet<FlowId>> {
diff --git a/src/flow/src/compute/render/reduce.rs b/src/flow/src/compute/render/reduce.rs
index 0bbc613260d7..393ad24ce0e9 100644
--- a/src/flow/src/compute/render/reduce.rs
+++ b/src/flow/src/compute/render/reduce.rs
@@ -17,6 +17,7 @@ use std::ops::Range;
 use std::sync::Arc;
 
 use arrow::array::new_null_array;
+use common_error::ext::BoxedError;
 use common_telemetry::trace;
 use datatypes::data_type::ConcreteDataType;
 use datatypes::prelude::DataType;
@@ -29,9 +30,14 @@ use snafu::{ensure, OptionExt, ResultExt};
 
 use crate::compute::render::{Context, SubgraphArg};
 use crate::compute::types::{Arranged, Collection, CollectionBundle, ErrCollector, Toff};
 use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
-use crate::expr::error::{ArrowSnafu, DataAlreadyExpiredSnafu, DataTypeSnafu, InternalSnafu};
-use crate::expr::{Accum, Accumulator, Batch, EvalError, ScalarExpr, VectorDiff};
-use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, ReducePlan, TypedPlan};
+use crate::expr::error::{
+    ArrowSnafu, DataAlreadyExpiredSnafu, DataTypeSnafu, ExternalSnafu, InternalSnafu,
+};
+use crate::expr::{Batch, EvalError, ScalarExpr};
+use crate::plan::{
+    AccumulablePlan, AccumulablePlanV2, AggrWithIndex, AggrWithIndexV2, KeyValPlan, ReducePlan,
+    TypedPlan,
+};
 use crate::repr::{self, DiffRow, KeyValDiffRow, RelationType, Row};
 use crate::utils::{ArrangeHandler, ArrangeReader, ArrangeWriter, KeyExpiryManager};
 
@@ -48,13 +54,7 @@ impl Context<'_, '_> {
         reduce_plan: &ReducePlan,
         output_type: &RelationType,
     ) -> Result<CollectionBundle<Batch>, Error> {
-        let accum_plan = if let ReducePlan::Accumulable(accum_plan) = reduce_plan {
-            if !accum_plan.distinct_aggrs.is_empty() {
-                NotImplementedSnafu {
-                    reason: "Distinct aggregation is not supported in batch mode",
-                }
-                .fail()?
-            }
+        let accum_plan = if let ReducePlan::AccumulableV2(accum_plan) = reduce_plan {
             accum_plan.clone()
         } else {
             NotImplementedSnafu {
@@ -252,6 +252,7 @@ impl Context<'_, '_> {
     ) -> Option<Vec<ArrangeHandler>> {
         match reduce_plan {
             ReducePlan::Distinct => None,
+            ReducePlan::AccumulableV2(_) => None,
             ReducePlan::Accumulable(AccumulablePlan { distinct_aggrs, .. }) => {
                 (!distinct_aggrs.is_empty()).then(|| {
                     std::iter::repeat_with(|| {
@@ -357,7 +358,7 @@ fn reduce_batch_subgraph(
     arrange: &ArrangeHandler,
     src_data: impl IntoIterator<Item = Batch>,
     key_val_plan: &KeyValPlan,
-    accum_plan: &AccumulablePlan,
+    accum_plan: &AccumulablePlanV2,
     output_type: &RelationType,
     SubgraphArg {
         now,
@@ -529,39 +530,48 @@ fn reduce_batch_subgraph(
         err_collector.run(|| -> Result<(), _> {
             let (accums, _, _) = arrange.get(now, &key).unwrap_or_default();
             let accum_list =
-                from_accum_values_to_live_accums(accums.unpack(), accum_plan.simple_aggrs.len())?;
+                from_accum_values_to_live_accums(accums.unpack(), accum_plan.full_aggrs.len())?;
 
             let mut accum_output = AccumOutput::new();
-            for AggrWithIndex {
+            for AggrWithIndexV2 {
                 expr,
-                input_idx,
+                input_idxs,
                 output_idx,
-            } in accum_plan.simple_aggrs.iter()
+                state_types,
+            } in accum_plan.full_aggrs.iter()
             {
                 let cur_accum_value = accum_list.get(*output_idx).cloned().unwrap_or_default();
-                let mut cur_accum = if cur_accum_value.is_empty() {
-                    Accum::new_accum(&expr.func.clone())?
-                } else {
-                    Accum::try_into_accum(&expr.func, cur_accum_value)?
- }; + let mut cur_accum = expr + .create_accumulator() + .map_err(BoxedError::new) + .context(ExternalSnafu)?; + if !cur_accum_value.is_empty() { + cur_accum.merge_states(&[cur_accum_value], state_types)?; + } for val_batch in val_batches.iter() { // if batch is empty, input null instead - let cur_input = val_batch - .batch() - .get(*input_idx) - .cloned() - .unwrap_or_else(|| Arc::new(NullVector::new(val_batch.row_count()))); + let batch = val_batch.batch(); + + let cur_input = input_idxs + .iter() + .map(|idx| { + batch + .get(*idx) + .cloned() + .unwrap_or_else(|| Arc::new(NullVector::new(val_batch.row_count()))) + }) + .collect_vec(); let len = cur_input.len(); - cur_accum.update_batch(&expr.func, VectorDiff::from(cur_input))?; + cur_accum.update_batch(&cur_input)?; trace!("Reduce accum after take {} rows: {:?}", len, cur_accum); } - let final_output = cur_accum.eval(&expr.func)?; + let final_output = cur_accum.evaluate()?; trace!("Reduce accum final output: {:?}", final_output); accum_output.insert_output(*output_idx, final_output); - let cur_accum_value = cur_accum.into_state(); + let cur_accum_value = cur_accum.state()?.into_iter().map(|(v, _)| v).collect(); accum_output.insert_accum(*output_idx, cur_accum_value); } @@ -679,6 +689,7 @@ fn reduce_subgraph( send, }, ), + ReducePlan::AccumulableV2(_) => unimplemented!(), }; } @@ -1211,12 +1222,14 @@ mod test { use std::time::Duration; use common_time::Timestamp; + use datafusion::functions_aggregate::sum::sum_udaf; use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT}; use hydroflow::scheduled::graph::Hydroflow; use super::*; use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check}; use crate::compute::state::DataflowState; + use crate::expr::relation::{AggregateExprV2, OrderingReq}; use crate::expr::{ self, AggregateExpr, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc, }; @@ -1581,26 +1594,31 @@ mod test { val_plan: MapFilterProject::new(1).project([0]).unwrap().into_safe(), }; - let simple_aggrs = vec![AggrWithIndex::new( - AggregateExpr { - func: AggregateFunc::SumInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - 0, - 0, - )]; - let accum_plan = AccumulablePlan { - full_aggrs: vec![AggregateExpr { - func: AggregateFunc::SumInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }], - simple_aggrs, - distinct_aggrs: vec![], + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::int64_datatype(), false)) + ], + return_type: CDT::int64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::int32_datatype(), + false, + )]) + .into_named(vec![Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::int64_datatype()], + is_nullable: true, }; - let reduce_plan = ReducePlan::Accumulable(accum_plan); + let accum_plan = AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], + }; + + let reduce_plan = ReducePlan::AccumulableV2(accum_plan); let bundle = ctx .render_reduce_batch( Box::new(input_plan.with_types(typ.into_unnamed())), diff --git a/src/flow/src/error.rs b/src/flow/src/error.rs index 703b47641c5c..bf44c12fba0a 100644 --- a/src/flow/src/error.rs +++ b/src/flow/src/error.rs @@ -148,10 +148,10 @@ pub enum Error { location: Location, }, - #[snafu(display("Datatypes error: 
{source} with extra message: {extra}"))]
+    #[snafu(display("Datatypes error: {source} with extra message: {context}"))]
     Datatypes {
         source: datatypes::Error,
-        extra: String,
+        context: String,
         #[snafu(implicit)]
         location: Location,
     },
diff --git a/src/flow/src/expr.rs b/src/flow/src/expr.rs
index a3c12a974247..86055195326b 100644
--- a/src/flow/src/expr.rs
+++ b/src/flow/src/expr.rs
@@ -34,7 +34,7 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}
 pub(crate) use id::{GlobalId, Id, LocalId};
 use itertools::Itertools;
 pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
-pub(crate) use relation::{Accum, Accumulator, AggregateExpr, AggregateFunc};
+pub(crate) use relation::{AggregateExpr, AggregateFunc};
 pub(crate) use scalar::{ScalarExpr, TypedExpr};
 use snafu::{ensure, ResultExt};
 
diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs
index b5d7e4ef2078..a2f20719bae6 100644
--- a/src/flow/src/expr/relation.rs
+++ b/src/flow/src/expr/relation.rs
@@ -14,13 +14,24 @@
 
 //! Describes an aggregation function and its input expression.
 
-pub(crate) use accum::{Accum, Accumulator};
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::AggregateUDF;
+use datatypes::prelude::{ConcreteDataType, DataType};
+use derive_where::derive_where;
 pub(crate) use func::AggregateFunc;
+use snafu::ResultExt;
+pub use udaf::{OrderingReq, SortExpr};
 
-use crate::expr::ScalarExpr;
+use crate::error::DatafusionSnafu;
+use crate::expr::relation::accum_v2::{AccumulatorV2, DfAccumulatorAdapter};
+use crate::expr::{ScalarExpr, TypedExpr};
+use crate::repr::RelationDesc;
+use crate::Error;
 
 mod accum;
+mod accum_v2;
 mod func;
+mod udaf;
 
 /// Describes an aggregation expression.
 #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
@@ -34,3 +45,56 @@ pub struct AggregateExpr {
     /// Should the aggregation be applied only to distinct results in each group.
     pub distinct: bool,
 }
+
+#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)]
+#[derive_where(Ord)]
+pub struct AggregateExprV2 {
+    /// skipping the `Ord` impl for `func` for convenience
+    #[derive_where(skip)]
+    pub func: AggregateUDF,
+    /// should only be a simple list of column refs
+    pub args: Vec<TypedExpr>,
+    /// Output / return type of this aggregate
+    pub return_type: ConcreteDataType,
+    pub name: String,
+    /// The schema of the input relation to this aggregate
+    pub schema: RelationDesc,
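+    /// The ordering requirement of this aggregate function,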
+    /// e.g. `FIRST_VALUE(a ORDER BY b)`
+    pub ordering_req: OrderingReq,
+    pub ignore_nulls: bool,
+    pub is_distinct: bool,
+    pub is_reversed: bool,
+    /// The types of the arguments to this aggregate
+    pub input_types: Vec<ConcreteDataType>,
+    pub is_nullable: bool,
+}
+
+impl AggregateExprV2 {
+    pub fn create_accumulator(&self) -> Result<Box<dyn AccumulatorV2>, Error> {
+        let data_type = self.return_type.as_arrow_type();
+        let schema = self.schema.to_df_schema()?;
+        let ordering_req = self.ordering_req.to_lex_ordering(&schema)?;
+        let exprs = self
+            .args
+            .iter()
+            .map(|e| e.expr.as_physical_expr(&schema))
+            .collect::<Result<Vec<_>, _>>()?;
+        let accum_args = AccumulatorArgs {
+            return_type: &data_type,
+            schema: schema.as_arrow(),
+            ignore_nulls: self.ignore_nulls,
+            ordering_req: &ordering_req,
+            is_reversed: self.is_reversed,
+            name: &self.name,
+            is_distinct: self.is_distinct,
+            exprs: &exprs,
+        };
+        let acc = self.func.accumulator(accum_args).context(DatafusionSnafu {
+            context: "Failed to build accumulator",
+        })?;
+        let acc = DfAccumulatorAdapter::new_unchecked(acc);
+        Ok(Box::new(acc))
+    }
+}
diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs
index 252913de56f6..9e048e0ad68b 100644
--- a/src/flow/src/expr/relation/accum.rs
+++ b/src/flow/src/expr/relation/accum.rs
@@ -36,6 +36,7 @@ use crate::expr::{AggregateFunc, EvalError};
 use crate::repr::Diff;
 
 /// Accumulates values for the various types of accumulable aggregations.
+/// TODO(discord9): refactor it to be more like datafusion's Accumulator
 #[enum_dispatch]
 pub trait Accumulator: Sized {
     fn into_state(self) -> Vec<Value>;
diff --git a/src/flow/src/expr/relation/accum_v2.rs b/src/flow/src/expr/relation/accum_v2.rs
new file mode 100644
index 000000000000..835ec5949347
--- /dev/null
+++ b/src/flow/src/expr/relation/accum_v2.rs
@@ -0,0 +1,308 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! A new accumulator trait that is more flexible and can later support more complex accumulators.
+
+use datafusion::logical_expr::Accumulator as DfAccumulator;
+use datatypes::prelude::{ConcreteDataType as CDT, DataType};
+use datatypes::value::Value;
+use datatypes::vectors::VectorRef;
+use snafu::{ensure, ResultExt};
+
+use crate::expr::error::{DataTypeSnafu, DatafusionSnafu, InternalSnafu};
+use crate::expr::EvalError;
+
+/// Basically a copy of datafusion's `Accumulator`, with a few modifications
+/// to accommodate flow's needs while keeping easy upgradability with datafusion
+pub trait AccumulatorV2: Send + Sync + std::fmt::Debug {
+    /// Updates the accumulator's state from its input.
+    fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError>;
+
+    /// Returns the current aggregate value WITHOUT consuming the internal state, so it can be called multiple times.
+    fn evaluate(&mut self) -> Result<Value, EvalError>;
+
+    /// Returns the allocated size required for this accumulator, in bytes, including `Self`.
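+    /// (mirrors datafusion's `Accumulator::size` and is used for memory accounting)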
+    fn size(&self) -> usize;
+
+    /// Returns the intermediate state of the accumulator, consuming it.
+    ///
+    /// note that `Value::Null`'s type is unknown, so `(Value, ConcreteDataType)` is used instead of just `Value`
+    fn state(&mut self) -> Result<Vec<(Value, CDT)>, EvalError>;
+
+    /// Merges the states of multiple accumulators into this accumulator.
+    /// The states array passed was formed by concatenating the results of calling `Self::state` on zero or more other accumulator instances.
+    fn merge_batch(&mut self, states: &[VectorRef]) -> Result<(), EvalError>;
+
+    /// Retracts (removes) an update (caused by the given inputs) from the accumulator's state.
+    ///
+    /// currently unused, but will be needed in the future for e.g. windowed aggregates
+    fn retract_batch(&mut self, _values: &[VectorRef]) -> Result<(), EvalError> {
+        InternalSnafu {
+            reason: format!(
+                "retract_batch not implemented for this accumulator {:?}",
+                self
+            ),
+        }
+        .fail()
+    }
+
+    /// Whether the accumulator supports incrementally updating its value by removing values.
+    fn supports_retract_batch(&self) -> bool {
+        false
+    }
+
+    /// Merge states given as `&[Vec<Value>]` instead of `&[VectorRef]`
+    fn merge_states(&mut self, states: &[Vec<Value>], typs: &[CDT]) -> Result<(), EvalError> {
+        let states = states_to_batch(states, typs)?;
+        self.merge_batch(&states)
+    }
+}
+
+/// Adapter for several hand-picked datafusion accumulators that can be used in flow,
+/// i.e. those that can call `evaluate` multiple times.
+#[derive(Debug)]
+pub struct DfAccumulatorAdapter {
+    /// the wrapped datafusion accumulator
+    inner: Box<dyn DfAccumulator>,
+}
+
+pub trait AcceptDfAccumulator: DfAccumulator {}
+
+impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MaxAccumulator {}
+
+impl AcceptDfAccumulator for datafusion::functions_aggregate::min_max::MinAccumulator {}
+
+impl DfAccumulatorAdapter {
+    /// create a new adapter from a datafusion accumulator without checking whether it is supported in flow
+    pub fn new_unchecked(acc: Box<dyn DfAccumulator>) -> Self {
+        Self { inner: acc }
+    }
+
+    pub fn new<T: AcceptDfAccumulator + 'static>(acc: T) -> Self {
+        Self::new_unchecked(Box::new(acc))
+    }
+}
+
+impl AccumulatorV2 for DfAccumulatorAdapter {
+    fn update_batch(&mut self, values: &[VectorRef]) -> Result<(), EvalError> {
+        let values = values
+            .iter()
+            .map(|v| v.to_arrow_array().clone())
+            .collect::<Vec<_>>();
+        self.inner.update_batch(&values).context(DatafusionSnafu {
+            context: "failed to update batch",
+        })
+    }
+
+    fn evaluate(&mut self) -> Result<Value, EvalError> {
+        // TODO(discord9): find a way to confirm the internal state is not consumed
+        let value = self.inner.evaluate().context(DatafusionSnafu {
+            context: "failed to evaluate accumulator",
+        })?;
+        let value = Value::try_from(value).context(DataTypeSnafu {
+            msg: "failed to convert evaluate result from `ScalarValue` to `Value`",
+        })?;
+        Ok(value)
+    }
+
+    fn size(&self) -> usize {
+        self.inner.size()
+    }
+
+    fn state(&mut self) -> Result<Vec<(Value, CDT)>, EvalError> {
+        let state = self.inner.state().context(DatafusionSnafu {
+            context: "failed to get state",
+        })?;
+        let state = state
+            .into_iter()
+            .map(|v| -> Result<_, _> {
+                let dt = CDT::try_from(&v.data_type())?;
+                let val = Value::try_from(v)?;
+                Ok((val, dt))
+            })
+            .collect::<Result<Vec<_>, _>>()
+            .context(DataTypeSnafu {
+                msg: "failed to convert `ScalarValue` state to `Value`",
+            })?;
+        Ok(state)
+    }
+
+    fn merge_batch(&mut self, states: &[VectorRef]) -> Result<(), EvalError> {
+        let states = states
+            .iter()
+            .map(|v| v.to_arrow_array().clone())
+            .collect::<Vec<_>>();
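+        // delegate merging to the wrapped datafusion accumulator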
+        self.inner.merge_batch(&states).context(DatafusionSnafu {
+            context: "failed to merge batch",
+        })
+    }
+}
+
+/// Convert a list of states (from `AccumulatorV2::state`)
+/// to a batch of vectors (which can be fed to `AccumulatorV2::merge_batch`)
+fn states_to_batch(states: &[Vec<Value>], dts: &[CDT]) -> Result<Vec<VectorRef>, EvalError> {
+    if states.is_empty() || states[0].is_empty() {
+        return Ok(vec![]);
+    }
+    ensure!(
+        states.iter().map(|v| v.len()).all(|l| l == states[0].len()),
+        InternalSnafu {
+            reason: "states have different lengths"
+        }
+    );
+    let state_len = states[0].len();
+    ensure!(
+        state_len == dts.len(),
+        InternalSnafu {
+            reason: format!(
+                "states and data types have different lengths: {} != {}",
+                state_len,
+                dts.len()
+            )
+        }
+    );
+    let mut ret = dts
+        .iter()
+        .map(|dt| dt.create_mutable_vector(state_len))
+        .collect::<Vec<_>>();
+    for (i, vectors) in ret.iter_mut().enumerate() {
+        for state in states {
+            vectors.push_value_ref(state[i].as_value_ref());
+        }
+    }
+    let ret = ret.into_iter().map(|mut v| v.to_vector()).collect();
+    Ok(ret)
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use datatypes::prelude::ConcreteDataType as CDT;
+    use datatypes::vectors::{UInt32Vector, UInt64Vector, VectorRef};
+
+    use crate::adapter::node_context::all_built_in_udaf;
+    use crate::expr::relation::{AggregateExprV2, OrderingReq};
+    use crate::expr::ScalarExpr;
+    use crate::repr::{ColumnType, RelationType};
+
+    #[test]
+    pub fn test_can_get_state_after_eval() {
+        let udaf_list = all_built_in_udaf();
+        let test_cases = [
+            (
+                AggregateExprV2 {
+                    func: udaf_list.get("sum").unwrap().as_ref().clone(),
+                    args: vec![ScalarExpr::Column(0)
+                        .with_type(ColumnType::new(CDT::uint64_datatype(), true))],
+                    return_type: CDT::uint64_datatype(),
+                    name: "sum".to_string(),
+                    schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)])
+                        .into_named(vec![None]),
+                    ordering_req: OrderingReq::empty(),
+                    ignore_nulls: false,
+                    is_distinct: false,
+                    is_reversed: false,
+                    input_types: vec![CDT::uint64_datatype()],
+                    is_nullable: true,
+                },
+                vec![Arc::new(UInt64Vector::from_slice([1, 2, 3])) as VectorRef],
+            ),
+            (
+                AggregateExprV2 {
+                    func: udaf_list.get("max").unwrap().as_ref().clone(),
+                    args: vec![ScalarExpr::Column(0)
+                        .with_type(ColumnType::new(CDT::uint32_datatype(), false))],
+                    return_type: CDT::uint32_datatype(),
+                    name: "max".to_string(),
+                    schema: RelationType::new(vec![
+                        ColumnType::new(CDT::uint32_datatype(), false),
+                        ColumnType::new(CDT::timestamp_millisecond_datatype(), false),
+                    ])
+                    .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
+                    ordering_req: OrderingReq::empty(),
+                    ignore_nulls: false,
+                    is_distinct: false,
+                    is_reversed: false,
+                    input_types: vec![CDT::uint32_datatype()],
+                    is_nullable: true,
+                },
+                vec![Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef],
+            ),
+            (
+                AggregateExprV2 {
+                    func: udaf_list.get("count").unwrap().as_ref().clone(),
+                    args: vec![ScalarExpr::Column(1)
+                        .with_type(ColumnType::new(CDT::uint32_datatype(), false))],
+                    return_type: CDT::int64_datatype(),
+                    name: "count".to_string(),
+                    schema: RelationType::new(vec![
+                        ColumnType::new(CDT::uint64_datatype(), true),
+                        ColumnType::new(CDT::uint32_datatype(), false),
+                    ])
+                    .into_named(vec![None, Some("number".to_string())]),
+                    ordering_req: OrderingReq::empty(),
+                    ignore_nulls: false,
+                    is_distinct: false,
+                    is_reversed: false,
+                    input_types: vec![CDT::uint32_datatype()],
+                    is_nullable: true,
+                },
+                vec![Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef],
+            ),
+            (
AggregateExprV2 { + func: udaf_list.get("min").unwrap().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "min".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(CDT::uint32_datatype(), false), + ColumnType::new(CDT::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![Arc::new(UInt32Vector::from_slice([1, 2, 3])) as VectorRef], + ), + ]; + + for (aggr, input) in test_cases { + let mut accum = aggr.create_accumulator().unwrap(); + accum.update_batch(&input).unwrap(); + let state1 = accum.state().unwrap(); + let (state, dt): (Vec<_>, Vec<_>) = state1.clone().into_iter().unzip(); + + // merge_states() & state() works in pair as expected + let mut accum = aggr.create_accumulator().unwrap(); + accum.merge_states(&[state.clone()], &dt).unwrap(); + let state2 = accum.state().unwrap(); + assert_eq!(state1, state2); + + // call state() after evaluate() works as expected(although this is undefined behavior by datafusion?) + let mut accum = aggr.create_accumulator().unwrap(); + accum.merge_states(&[state.clone()], &dt).unwrap(); + accum.evaluate().unwrap(); + let state3 = accum.state().unwrap(); + assert_eq!(state1, state3); + } + } +} diff --git a/src/flow/src/expr/relation/udaf.rs b/src/flow/src/expr/relation/udaf.rs new file mode 100644 index 000000000000..ee638f5fe07d --- /dev/null +++ b/src/flow/src/expr/relation/udaf.rs @@ -0,0 +1,64 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use std::fmt::Debug;
+
+use datafusion_common::DFSchema;
+use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
+
+use crate::expr::ScalarExpr;
+use crate::Result;
+
+#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
+pub struct OrderingReq {
+    pub exprs: Vec<SortExpr>,
+}
+
+impl OrderingReq {
+    pub fn empty() -> Self {
+        Self { exprs: vec![] }
+    }
+
+    pub fn to_lex_ordering(&self, schema: &DFSchema) -> Result<LexOrdering> {
+        Ok(LexOrdering::new(
+            self.exprs
+                .iter()
+                .map(|e| e.to_sort_phy_expr(schema))
+                .collect::<Result<Vec<_>>>()?,
+        ))
+    }
+}
+
+#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
+pub struct SortExpr {
+    /// expression representing the column to sort
+    pub expr: ScalarExpr,
+    /// Whether to sort in descending order
+    pub descending: bool,
+    /// Whether to sort nulls first
+    pub nulls_first: bool,
+}
+
+impl SortExpr {
+    pub fn to_sort_phy_expr(&self, schema: &DFSchema) -> Result<PhysicalSortExpr> {
+        let phy = self.expr.as_physical_expr(schema)?;
+        let sort_options = datafusion_common::arrow::compute::SortOptions {
+            descending: self.descending,
+            nulls_first: self.nulls_first,
+        };
+        Ok(PhysicalSortExpr {
+            expr: phy,
+            options: sort_options,
+        })
+    }
+}
diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs
index 94af5e0e425a..59a2dced63ce 100644
--- a/src/flow/src/expr/scalar.rs
+++ b/src/flow/src/expr/scalar.rs
@@ -15,9 +15,12 @@
 //! Scalar expressions.
 
 use std::collections::{BTreeMap, BTreeSet};
+use std::sync::Arc;
 
 use arrow::array::{make_array, ArrayData, ArrayRef};
 use common_error::ext::BoxedError;
+use datafusion_common::DFSchema;
+use datafusion_physical_expr::PhysicalExpr;
 use datatypes::prelude::{ConcreteDataType, DataType};
 use datatypes::value::Value;
 use datatypes::vectors::{BooleanVector, Helper, VectorRef};
@@ -26,7 +29,8 @@ use itertools::Itertools;
 use snafu::{ensure, OptionExt, ResultExt};
 
 use crate::error::{
-    DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
+    DatafusionSnafu, DatatypesSnafu, Error, InvalidQuerySnafu, NotImplementedSnafu,
+    UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
 };
 use crate::expr::error::{
     ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu,
@@ -95,6 +99,41 @@ pub enum ScalarExpr {
 }
 
 impl ScalarExpr {
+    // TODO(discord9): impl more conversions?
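+    /// Convert this `ScalarExpr` into a datafusion `PhysicalExpr` against `df_schema`.
+    /// Only `Column` and `Literal` are supported for now; other variants return `NotImplemented`.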
+    pub fn as_physical_expr(&self, df_schema: &DFSchema) -> Result<Arc<dyn PhysicalExpr>, Error> {
+        match self {
+            Self::Column(i) => {
+                ensure!(
+                    *i < df_schema.fields().len(),
+                    InvalidQuerySnafu {
+                        reason: format!(
+                            "column index {} out of range of len={} in df_schema={:?}",
+                            i,
+                            df_schema.fields().len(),
+                            df_schema
+                        ),
+                    }
+                );
+                let field = df_schema.field(*i);
+                datafusion::physical_expr::expressions::col(field.name(), df_schema.as_arrow())
+                    .with_context(|_| DatafusionSnafu {
+                        context: "Failed to create datafusion column expression",
+                    })
+            }
+            Self::Literal(val, typ) => {
+                let val = val.try_to_scalar_value(typ).context(DatatypesSnafu {
+                    context: format!("Failed to convert val={:?} to literal", val),
+                })?;
+                let ret = datafusion::physical_expr::expressions::lit(val);
+                Ok(ret)
+            }
+            _ => NotImplementedSnafu {
+                reason: format!("Converting {:?} to a physical expression is not implemented yet", self),
+            }
+            .fail()?,
+        }
+    }
+
     pub fn with_type(self, typ: ColumnType) -> TypedExpr {
         TypedExpr::new(self, typ)
     }
diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs
index b2c91015e065..e7aa8e2ed5f8 100644
--- a/src/flow/src/plan.rs
+++ b/src/flow/src/plan.rs
@@ -23,7 +23,9 @@ use std::collections::BTreeSet;
 use crate::error::Error;
 use crate::expr::{GlobalId, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr, TypedExpr};
 use crate::plan::join::JoinPlan;
-pub(crate) use crate::plan::reduce::{AccumulablePlan, AggrWithIndex, KeyValPlan, ReducePlan};
+pub(crate) use crate::plan::reduce::{
+    AccumulablePlan, AccumulablePlanV2, AggrWithIndex, AggrWithIndexV2, KeyValPlan, ReducePlan,
+};
 use crate::repr::{DiffRow, RelationDesc};
 
 /// A plan for a dataflow component. But with type to indicate the output type of the relation.
diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs
index 65a83756b57f..af7f40213abe 100644
--- a/src/flow/src/plan/reduce.rs
+++ b/src/flow/src/plan/reduce.rs
@@ -12,6 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use datatypes::prelude::ConcreteDataType;
+
+use crate::expr::relation::AggregateExprV2;
 use crate::expr::{AggregateExpr, SafeMfpPlan, ScalarExpr};
 
 /// Describe how to extract key-value pair from a `Row`
@@ -43,6 +46,8 @@ pub enum ReducePlan {
     /// Plan for computing only accumulable aggregations.
     /// Including simple functions like `sum`, `count`, `min/max`(without deletion)
     Accumulable(AccumulablePlan),
+    /// Plan for aggregations computed via `AggregateExprV2`
+    AccumulableV2(AccumulablePlanV2),
 }
 
 /// Accumulable plan for the execution of a reduction.
@@ -85,3 +90,45 @@ impl AggrWithIndex {
     }
 }
+
+/// Accumulable plan for the execution of a reduction.
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
+pub struct AccumulablePlanV2 {
+    /// All of the aggregations we were asked to compute, stored
+    /// in order.
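+    /// The i-th aggregation here writes to output column i (see the invariant on `AggrWithIndexV2`).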
+    pub full_aggrs: Vec<AggrWithIndexV2>,
+}
+
+/// This struct extracts useful info from `expr` and stores it so it doesn't
+/// need to be computed repeatedly.
+///
+/// Invariant: the output index is the index of the aggregation in `full_aggrs`,
+/// which means the output index is always smaller than the length of `full_aggrs`.
+#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
+pub struct AggrWithIndexV2 {
+    /// aggregation expression
+    pub expr: AggregateExprV2,
+    /// indices of the aggregation's inputs among the input row, taken from `self.expr.args`
+    pub input_idxs: Vec<usize>,
+    /// index of the aggregation's output among the output row
+    pub output_idx: usize,
+    /// The types of the intermediate state fields
+    pub state_types: Vec<ConcreteDataType>,
+}
+
+impl AggrWithIndexV2 {
+    /// Create a new `AggrWithIndexV2`
+    pub fn new(
+        expr: AggregateExprV2,
+        input_idxs: Vec<usize>,
+        output_idx: usize,
+    ) -> Result<Self, crate::Error> {
+        let mut test_accum = expr.create_accumulator()?;
+        let states = test_accum.state()?;
+        Ok(Self {
+            expr,
+            input_idxs,
+            output_idx,
+            state_types: states.into_iter().map(|(_, ty)| ty).collect(),
+        })
+    }
+}
diff --git a/src/flow/src/transform/aggr.rs b/src/flow/src/transform/aggr.rs
index 27c7acfb1da9..35c7a9a0c3a3 100644
--- a/src/flow/src/transform/aggr.rs
+++ b/src/flow/src/transform/aggr.rs
@@ -12,18 +12,23 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use datatypes::prelude::{ConcreteDataType, DataType};
 use itertools::Itertools;
-use snafu::OptionExt;
+use snafu::{OptionExt, ResultExt};
 use substrait_proto::proto;
 use substrait_proto::proto::aggregate_function::AggregationInvocation;
 use substrait_proto::proto::aggregate_rel::{Grouping, Measure};
 use substrait_proto::proto::function_argument::ArgType;
+use substrait_proto::proto::sort_field::{SortDirection, SortKind};
 
-use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
+use crate::error::{
+    DatafusionSnafu, DatatypesSnafu, Error, NotImplementedSnafu, PlanSnafu, UnexpectedSnafu,
+};
+use crate::expr::relation::{AggregateExprV2, OrderingReq, SortExpr};
 use crate::expr::{
     AggregateExpr, AggregateFunc, MapFilterProject, ScalarExpr, TypedExpr, UnaryFunc,
 };
-use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, Plan, ReducePlan, TypedPlan};
+use crate::plan::{AccumulablePlanV2, AggrWithIndexV2, KeyValPlan, Plan, ReducePlan, TypedPlan};
 use crate::repr::{ColumnType, RelationDesc, RelationType};
 use crate::transform::{substrait_proto, FlownodeContext, FunctionExtensions};
 
@@ -86,7 +91,7 @@ impl TypedExpr {
 impl AggregateExpr {
     /// Convert list of `Measure` into Flow's AggregateExpr
     ///
-    /// Return both the AggregateExpr and a MapFilterProject that is the final output of the aggregate function
+    /// Return the list of `AggregateExpr`s that is the final output of the aggregate function
     async fn from_substrait_agg_measures(
         ctx: &mut FlownodeContext,
         measures: &[Measure],
@@ -197,12 +202,161 @@ impl AggregateExpr {
     }
 }
 
+impl AggregateExprV2 {
+    /// Convert a list of substrait `Measure`s into Flow's `AggregateExprV2`s
+    ///
+    /// Return the list of `AggregateExprV2`s that is the final output of the aggregate function
+    async fn from_substrait_agg_measures(
+        ctx: &mut FlownodeContext,
+        measures: &[Measure],
+        typ: &RelationDesc,
+        extensions: &FunctionExtensions,
+    ) -> Result<Vec<Self>, Error> {
+        let mut all_aggr_exprs = Vec::with_capacity(measures.len());
+
+        for m in measures {
+            let filter = match &m.filter {
+                Some(fil) => Some(TypedExpr::from_substrait_rex(fil, typ, extensions).await?),
+                None => None,
+            };
+
+            let Some(f) = &m.measure else {
+                not_impl_err!("Expect aggregate function")?
+            };
function")? + }; + + let aggr_expr = Self::from_substrait_agg_func(ctx, f, typ, extensions, &filter).await?; + all_aggr_exprs.push(aggr_expr); + } + Ok(all_aggr_exprs) + } + + /// Convert AggregateFunction into Flow's AggregateExpr + /// + /// the returned value is a tuple of AggregateExpr and a optional ScalarExpr that if exist is the final output of the aggregate function + /// since aggr functions like `avg` need to be transform to `sum(x)/cast(count(x) as x_type)` + pub async fn from_substrait_agg_func( + ctx: &mut FlownodeContext, + f: &proto::AggregateFunction, + input_schema: &RelationDesc, + extensions: &FunctionExtensions, + filter: &Option, + ) -> Result { + // TODO(discord9): impl filter + let _ = filter; + + let mut args = Vec::with_capacity(f.arguments.len()); + for arg in &f.arguments { + let arg_expr = match &arg.arg_type { + Some(ArgType::Value(e)) => { + TypedExpr::from_substrait_rex(e, input_schema, extensions).await + } + _ => not_impl_err!("Aggregated function argument non-Value type not supported"), + }?; + args.push(arg_expr); + } + let args = args; + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => true, + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, + }; + + let fn_name = extensions + .get(&f.function_reference) + .cloned() + .with_context(|| PlanSnafu { + reason: format!( + "Aggregated function not found: function anchor = {:?}", + f.function_reference + ), + })?; + + let fn_impl = ctx + .aggregate_functions + .get(&fn_name) + .with_context(|| PlanSnafu { + reason: format!("Aggregate function not found: {:?}", fn_name), + })?; + + let input_types = args + .iter() + .map(|a| a.typ.scalar_type().clone()) + .collect_vec(); + + let return_type = { + let arrow_input_types = input_types.iter().map(|t| t.as_arrow_type()).collect_vec(); + let ret = fn_impl + .return_type(&arrow_input_types) + .context(DatafusionSnafu { + context: "failed to get return type of aggregate function", + })?; + ConcreteDataType::try_from(&ret).context(DatatypesSnafu { + context: "failed to convert return type to ConcreteDataType", + })? 
+        };
+
+        let ordering_req = {
+            let mut ret = Vec::with_capacity(f.sorts.len());
+            for sort in &f.sorts {
+                let Some(raw_expr) = sort.expr.as_ref() else {
+                    return not_impl_err!("Sort expression not found in sort");
+                };
+                let expr =
+                    TypedExpr::from_substrait_rex(raw_expr, input_schema, extensions).await?;
+                let sort_dir = sort.sort_kind;
+                let Some(SortKind::Direction(dir)) = sort_dir else {
+                    return not_impl_err!("Sort direction not found in sort");
+                };
+                let dir = SortDirection::try_from(dir).map_err(|e| {
+                    PlanSnafu {
+                        reason: format!("{} is not a valid direction", e.0),
+                    }
+                    .build()
+                })?;
+
+                let (descending, nulls_first) = match dir {
+                    // align with datafusion's default options
+                    SortDirection::Unspecified => (false, true),
+                    SortDirection::AscNullsFirst => (false, true),
+                    SortDirection::AscNullsLast => (false, false),
+                    SortDirection::DescNullsFirst => (true, true),
+                    SortDirection::DescNullsLast => (true, false),
+                    SortDirection::Clustered => not_impl_err!("Clustered sort not supported")?,
+                };
+
+                let sort_expr = SortExpr {
+                    expr: expr.expr,
+                    descending,
+                    nulls_first,
+                };
+                ret.push(sort_expr);
+            }
+            OrderingReq { exprs: ret }
+        };
+
+        // TODO(discord9): determine the other options from substrait too instead of using defaults
+        Ok(Self {
+            func: fn_impl.as_ref().clone(),
+            args,
+            return_type,
+            name: fn_name,
+            schema: input_schema.clone(),
+            ordering_req,
+            ignore_nulls: false,
+            is_distinct: distinct,
+            is_reversed: false,
+            input_types,
+            is_nullable: true,
+        })
+    }
+}
+
 impl KeyValPlan {
     /// Generate KeyValPlan from AggregateExpr and group_exprs
     ///
     /// will also change aggregate expr to use column ref if necessary
     fn from_substrait_gen_key_val_plan(
-        aggr_exprs: &mut [AggregateExpr],
+        aggr_exprs: &mut [AggregateExprV2],
         group_exprs: &[TypedExpr],
         input_arity: usize,
     ) -> Result<KeyValPlan, Error> {
@@ -217,25 +371,52 @@ impl KeyValPlan {
             .project(input_arity..input_arity + output_arity)?;
 
         // val_plan is extracted from aggr_exprs to give the aggr function its necessary input
-        // and since aggr func need inputs that is column ref, we just add a prefix mfp to transform any expr that is not into a column ref
+        // and since aggr funcs need inputs that are column refs (or literals), we just add a prefix mfp to transform any other expr into a column ref
        let val_plan = {
-            let need_mfp = aggr_exprs.iter().any(|agg| agg.expr.as_column().is_none());
+            let need_mfp = aggr_exprs
+                .iter()
+                .any(|agg| agg.args.iter().any(|e| e.expr.as_column().is_none()));
             if need_mfp {
                 // create mfp from aggr_expr, and modify aggr_expr to use the output column of mfp
-                let input_exprs = aggr_exprs
-                    .iter_mut()
-                    .enumerate()
-                    .map(|(idx, aggr)| {
-                        let ret = aggr.expr.clone();
-                        aggr.expr = ScalarExpr::Column(idx);
-                        ret
-                    })
-                    .collect_vec();
-                let aggr_arity = aggr_exprs.len();
-
-                MapFilterProject::new(input_arity)
+                let mut input_exprs = Vec::with_capacity(aggr_exprs.len());
+                for aggr_expr in aggr_exprs.iter_mut() {
+                    // FIX: also modify input_schema to fit input_exprs; a `new_input_schema` is needed
+                    // so we can separate all scalar computation into a mfp before the aggr
+                    for arg in aggr_expr.args.iter_mut() {
+                        match arg.expr.clone() {
+                            ScalarExpr::Column(idx) => {
+                                // directly refer to the column in the mfp
+                                arg.expr = ScalarExpr::Column(input_exprs.len());
+                                input_exprs.push(ScalarExpr::Column(idx));
+                            }
+                            ScalarExpr::Literal(_, _) => {
+                                // already a literal, but it still needs to become a column ref
+                                let ret = arg.expr.clone();
+                                arg.expr = ScalarExpr::Column(input_exprs.len());
+                                input_exprs.push(ret);
+                            }
+                            _ => {
+                                // create a new expr and let the arg refer to that expr's output column instead
+                                let ret = arg.expr.clone();
+                                arg.expr = ScalarExpr::Column(input_exprs.len());
+                                input_exprs.push(ret);
+                            }
+                        }
+                    }
+                }
+                let new_input_len = input_exprs.len();
+                let pre_mfp = MapFilterProject::new(input_arity)
                     .map(input_exprs)?
-                    .project(input_arity..input_arity + aggr_arity)?
+                    .project(input_arity..input_arity + new_input_len)?;
+                // adjust the input schema according to pre_mfp
+                if let Some(first) = aggr_exprs.first() {
+                    let new_input_schema = first.schema.apply_mfp(&pre_mfp.clone().into_safe())?;
+
+                    for aggr in aggr_exprs.iter_mut() {
+                        aggr.schema = new_input_schema.clone();
+                    }
+                }
+                pre_mfp
             } else {
                 // simply take all inputs as value
                 MapFilterProject::new(input_arity)
@@ -292,7 +473,7 @@ impl TypedPlan {
 
         let time_index = find_time_index_in_group_exprs(&group_exprs);
 
-        let mut aggr_exprs = AggregateExpr::from_substrait_agg_measures(
+        let mut aggr_exprs = AggregateExprV2::from_substrait_agg_measures(
            ctx,
             &agg.measures,
             &input.schema,
@@ -324,9 +505,7 @@ impl TypedPlan {
         }
 
         for aggr in &aggr_exprs {
-            output_types.push(ColumnType::new_nullable(
-                aggr.func.signature().output.clone(),
-            ));
+            output_types.push(ColumnType::new_nullable(aggr.return_type.clone()));
             // TODO(discord9): find a clever way to name them?
             output_names.push(None);
         }
@@ -342,36 +521,26 @@ impl TypedPlan {
 
         // copy aggr_exprs to full_aggrs, and split them into simple_aggrs and distinct_aggrs
         // also set them input/output column
-        let full_aggrs = aggr_exprs;
-        let mut simple_aggrs = Vec::new();
-        let mut distinct_aggrs = Vec::new();
-        for (output_column, aggr_expr) in full_aggrs.iter().enumerate() {
-            let input_column = aggr_expr.expr.as_column().with_context(|| PlanSnafu {
-                reason: "Expect aggregate argument to be transformed into a column at this point",
-            })?;
-            if aggr_expr.distinct {
-                distinct_aggrs.push(AggrWithIndex::new(
-                    aggr_expr.clone(),
-                    input_column,
-                    output_column,
-                ));
-            } else {
-                simple_aggrs.push(AggrWithIndex::new(
-                    aggr_expr.clone(),
-                    input_column,
-                    output_column,
-                ));
-            }
+        let mut full_aggrs = Vec::with_capacity(aggr_exprs.len());
+        for (idx, aggr) in aggr_exprs.into_iter().enumerate() {
+            let input_idxs = aggr
+                .args
+                .iter()
+                .map(|a| {
+                    a.expr.as_column().with_context(|| UnexpectedSnafu {
+                        reason: format!("Expect {:?} to be a column", a),
+                    })
+                })
+                .collect::<Result<Vec<_>, _>>()?;
+            let aggr = AggrWithIndexV2::new(aggr, input_idxs, idx)?;
+            full_aggrs.push(aggr);
         }
-        let accum_plan = AccumulablePlan {
-            full_aggrs,
-            simple_aggrs,
-            distinct_aggrs,
-        };
+
+        let accum_plan = AccumulablePlanV2 { full_aggrs };
         let plan = Plan::Reduce {
             input: Box::new(input),
             key_val_plan,
-            reduce_plan: ReducePlan::Accumulable(accum_plan),
+            reduce_plan: ReducePlan::AccumulableV2(accum_plan),
         };
         // FIX(discord9): deal with key first
         return Ok(TypedPlan {
@@ -387,6 +556,9 @@ mod test {
 
     use bytes::BytesMut;
     use common_time::{IntervalMonthDayNano, Timestamp};
+    use datafusion::functions_aggregate::count::count_udaf;
+    use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf};
+    use datafusion::functions_aggregate::sum::sum_udaf;
     use datatypes::prelude::ConcreteDataType;
     use datatypes::value::Value;
     use pretty_assertions::assert_eq;
@@ -409,10 +581,24 @@
         .await
         .unwrap();
 
-        let aggr_expr = AggregateExpr {
-            func: AggregateFunc::SumUInt64,
-            expr: ScalarExpr::Column(0),
-            distinct: false,
+        let aggr_expr = AggregateExprV2 {
+            func: sum_udaf().as_ref().clone(),
+            args: vec![
ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -508,10 +694,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -550,10 +734,24 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -622,10 +820,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -688,16 +884,52 @@ mod test { .unwrap(); let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::Count, - expr: ScalarExpr::Column(1), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: count_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + 
ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![1], + 1, + ) + .unwrap(), ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(4).call_binary( @@ -769,13 +1001,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![ - AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), - ], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( @@ -839,11 +1066,26 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; + let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) @@ -907,10 +1149,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -949,11 +1189,26 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; + let expected = TypedPlan { schema: RelationType::new(vec![ ColumnType::new(CDT::uint64_datatype(), true), // sum(number) @@ -1021,10 +1276,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -1062,16 +1315,52 @@ mod test { let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::Count, - expr: ScalarExpr::Column(1), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: 
sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: count_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![1], + 1, + ) + .unwrap(), ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(2).call_binary( @@ -1137,13 +1426,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![ - AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), - ], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( @@ -1187,16 +1471,52 @@ mod test { .unwrap(); let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::Count, - expr: ScalarExpr::Column(1), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint64_datatype(), true))], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: count_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(1) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::int64_datatype(), + name: "count".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint64_datatype(), true), + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ]) + .into_named(vec![None, Some("number".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![1], + 1, + ) + .unwrap(), ]; let avg_expr = ScalarExpr::If { cond: Box::new(ScalarExpr::Column(1).call_binary( @@ -1259,13 +1579,8 @@ mod test { .unwrap() 
.into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![ - AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1), - ], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( @@ -1298,10 +1613,24 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) @@ -1334,10 +1663,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), }, }; @@ -1388,11 +1715,7 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![], - simple_aggrs: vec![], - distinct_aggrs: vec![], - }), + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { full_aggrs: vec![] }), }, }; @@ -1410,10 +1733,24 @@ mod test { .await .unwrap(); - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -1457,10 +1794,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), } .with_types( @@ -1492,10 +1827,24 @@ mod test { let mut ctx = create_test_ctx(); let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; - let aggr_expr = AggregateExpr { - func: AggregateFunc::SumUInt64, - expr: ScalarExpr::Column(0), - distinct: false, + let aggr_expr = AggregateExprV2 { + func: sum_udaf().as_ref().clone(), + args: vec![ + ScalarExpr::Column(0).with_type(ColumnType::new(CDT::uint64_datatype(), 
true)) + ], + return_type: CDT::uint64_datatype(), + name: "sum".to_string(), + schema: RelationType::new(vec![ColumnType::new( + ConcreteDataType::uint64_datatype(), + true, + )]) + .into_named(vec![None]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint64_datatype()], + is_nullable: true, }; let expected = TypedPlan { schema: RelationType::new(vec![ColumnType::new(CDT::uint64_datatype(), true)]) @@ -1541,10 +1890,8 @@ mod test { .unwrap() .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: vec![aggr_expr.clone()], - simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()], }), }, }; @@ -1561,16 +1908,52 @@ mod test { let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await; let aggr_exprs = vec![ - AggregateExpr { - func: AggregateFunc::MaxUInt32, - expr: ScalarExpr::Column(0), - distinct: false, - }, - AggregateExpr { - func: AggregateFunc::MinUInt32, - expr: ScalarExpr::Column(0), - distinct: false, - }, + AggrWithIndexV2::new( + AggregateExprV2 { + func: max_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "max".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![0], + 0, + ) + .unwrap(), + AggrWithIndexV2::new( + AggregateExprV2 { + func: min_udaf().as_ref().clone(), + args: vec![ScalarExpr::Column(0) + .with_type(ColumnType::new(CDT::uint32_datatype(), false))], + return_type: CDT::uint32_datatype(), + name: "min".to_string(), + schema: RelationType::new(vec![ + ColumnType::new(ConcreteDataType::uint32_datatype(), false), + ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false), + ]) + .into_named(vec![Some("number".to_string()), Some("ts".to_string())]), + ordering_req: OrderingReq::empty(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + input_types: vec![CDT::uint32_datatype()], + is_nullable: true, + }, + vec![0], + 1, + ) + .unwrap(), ]; let expected = TypedPlan { schema: RelationType::new(vec![ @@ -1648,11 +2031,8 @@ mod test { val_plan: MapFilterProject::new(2) .into_safe(), }, - reduce_plan: ReducePlan::Accumulable(AccumulablePlan { - full_aggrs: aggr_exprs.clone(), - simple_aggrs: vec![AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0), - AggrWithIndex::new(aggr_exprs[1].clone(), 0, 1)], - distinct_aggrs: vec![], + reduce_plan: ReducePlan::AccumulableV2(AccumulablePlanV2 { + full_aggrs: aggr_exprs, }), } .with_types( diff --git a/src/flow/src/transform/expr.rs b/src/flow/src/transform/expr.rs index ed75252ee21c..981748267f65 100644 --- a/src/flow/src/transform/expr.rs +++ b/src/flow/src/transform/expr.rs @@ -292,6 +292,7 @@ impl TypedExpr { 1 if UnaryFunc::is_valid_func_name(fn_name) => { let func = UnaryFunc::from_str_and_type(fn_name, None)?; let arg = arg_exprs[0].clone(); + // TODO(discord9); forward nullable to return type 
too? let ret_type = ColumnType::new_nullable(func.signature().output.clone()); Ok(TypedExpr::new(arg.call_unary(func), ret_type)) @@ -347,7 +348,7 @@ impl TypedExpr { datatypes::types::cast(val.clone(), &dest_type) .with_context(|_| DatatypesSnafu{ - extra: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}") + context: format!("Failed to implicitly cast literal {val:?} to type {dest_type:?}") })? } else { val.clone()