feat(flow): flow aggr udaf refactor #5515

Open · wants to merge 18 commits into base: main
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/flow/Cargo.toml
@@ -39,6 +39,7 @@ datafusion-expr.workspace = true
datafusion-physical-expr.workspace = true
datafusion-substrait.workspace = true
datatypes.workspace = true
derive-where = "1.2.7"
enum-as-inner = "0.6.0"
enum_dispatch = "0.3"
futures = "0.3"
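The new `derive-where` dependency is used later in this diff (in `src/flow/src/expr/relation.rs`) to derive `Ord` for `AggregateExprV2` while skipping its `AggregateUDF` field. A minimal sketch of the pattern, with hypothetical types (`Keyed`, `attrs`) that are not taken from this PR:

```rust
use std::collections::HashMap;

use derive_where::derive_where;

// Derive comparisons with `derive_where` so that a field whose type has no
// total order (`HashMap` here) can be excluded from them, mirroring how the
// PR skips the `AggregateUDF` field.
#[derive(Debug, PartialEq, Eq)]
#[derive_where(PartialOrd, Ord)]
struct Keyed {
    name: String,
    // `HashMap` is `Eq` but not `Ord`; skip it in the derived comparisons.
    #[derive_where(skip)]
    attrs: HashMap<String, String>,
}

fn main() {
    let a = Keyed { name: "a".into(), attrs: HashMap::new() };
    let b = Keyed { name: "b".into(), attrs: HashMap::new() };
    assert!(a < b); // ordering considers `name` only
}
```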
35 changes: 33 additions & 2 deletions src/flow/src/adapter/node_context.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;

use common_recordbatch::RecordBatch;
use common_telemetry::trace;
use datafusion_expr::AggregateUDF;
use datatypes::prelude::ConcreteDataType;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
@@ -59,11 +60,26 @@ pub struct FlownodeContext {
/// All the tables that have been registered in the worker
pub table_repr: IdToNameMap,
pub query_context: Option<Arc<QueryContext>>,
/// Aggregate functions registered in the context
pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
}

pub fn all_built_in_udaf() -> HashMap<String, Arc<AggregateUDF>> {
// sum/min/max/count are built-in aggregate functions
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::min_max::{max_udaf, min_udaf};
use datafusion::functions_aggregate::sum::sum_udaf;
HashMap::from([
("sum".to_string(), sum_udaf()),
("min".to_string(), min_udaf()),
("max".to_string(), max_udaf()),
("count".to_string(), count_udaf()),
])
}

impl FlownodeContext {
pub fn new(table_source: Box<dyn FlowTableSource>) -> Self {
Self {
let mut ret = Self {
source_to_tasks: Default::default(),
flow_to_sink: Default::default(),
flow_plans: Default::default(),
@@ -73,7 +89,22 @@ impl FlownodeContext {
table_source,
table_repr: Default::default(),
query_context: Default::default(),
}
aggregate_functions: Default::default(),
};
ret.register_all_built_in_aggr_fns();
ret
}

pub fn register_all_built_in_aggr_fns(&mut self) {
self.aggregate_functions.extend(all_built_in_udaf());
}

pub fn register_aggr_fn(
&mut self,
name: String,
aggr_fn: Arc<AggregateUDF>,
) -> Option<Arc<AggregateUDF>> {
self.aggregate_functions.insert(name, aggr_fn)
}

pub fn get_flow_ids(&self, table_id: TableId) -> Option<&BTreeSet<FlowId>> {
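For orientation, a minimal usage sketch of the context API added above (the alias name `sum_alias` and the helper function are hypothetical, not part of this PR): `new` pre-registers the sum/min/max/count built-ins, and `register_aggr_fn` returns the UDAF previously registered under that name, if any.

```rust
use std::sync::Arc;

use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion_expr::AggregateUDF;

// Register `sum` under an extra alias; `None` means the name was unused.
fn register_sum_alias(ctx: &mut FlownodeContext) -> Option<Arc<AggregateUDF>> {
    ctx.register_aggr_fn("sum_alias".to_string(), sum_udaf())
}
```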
110 changes: 64 additions & 46 deletions src/flow/src/compute/render/reduce.rs
@@ -17,6 +17,7 @@ use std::ops::Range;
use std::sync::Arc;

use arrow::array::new_null_array;
use common_error::ext::BoxedError;
use common_telemetry::trace;
use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::DataType;
@@ -29,9 +30,14 @@ use snafu::{ensure, OptionExt, ResultExt};
use crate::compute::render::{Context, SubgraphArg};
use crate::compute::types::{Arranged, Collection, CollectionBundle, ErrCollector, Toff};
use crate::error::{Error, NotImplementedSnafu, PlanSnafu};
use crate::expr::error::{ArrowSnafu, DataAlreadyExpiredSnafu, DataTypeSnafu, InternalSnafu};
use crate::expr::{Accum, Accumulator, Batch, EvalError, ScalarExpr, VectorDiff};
use crate::plan::{AccumulablePlan, AggrWithIndex, KeyValPlan, ReducePlan, TypedPlan};
use crate::expr::error::{
ArrowSnafu, DataAlreadyExpiredSnafu, DataTypeSnafu, ExternalSnafu, InternalSnafu,
};
use crate::expr::{Batch, EvalError, ScalarExpr};
use crate::plan::{
AccumulablePlan, AccumulablePlanV2, AggrWithIndex, AggrWithIndexV2, KeyValPlan, ReducePlan,
TypedPlan,
};
use crate::repr::{self, DiffRow, KeyValDiffRow, RelationType, Row};
use crate::utils::{ArrangeHandler, ArrangeReader, ArrangeWriter, KeyExpiryManager};

@@ -48,13 +54,7 @@ impl Context<'_, '_> {
reduce_plan: &ReducePlan,
output_type: &RelationType,
) -> Result<CollectionBundle<Batch>, Error> {
let accum_plan = if let ReducePlan::Accumulable(accum_plan) = reduce_plan {
if !accum_plan.distinct_aggrs.is_empty() {
NotImplementedSnafu {
reason: "Distinct aggregation is not supported in batch mode",
}
.fail()?
}
let accum_plan = if let ReducePlan::AccumulableV2(accum_plan) = reduce_plan {
accum_plan.clone()
} else {
NotImplementedSnafu {
@@ -252,6 +252,7 @@
) -> Option<Vec<ArrangeHandler>> {
match reduce_plan {
ReducePlan::Distinct => None,
ReducePlan::AccumulableV2(_) => None,
ReducePlan::Accumulable(AccumulablePlan { distinct_aggrs, .. }) => {
(!distinct_aggrs.is_empty()).then(|| {
std::iter::repeat_with(|| {
@@ -357,7 +358,7 @@ fn reduce_batch_subgraph(
arrange: &ArrangeHandler,
src_data: impl IntoIterator<Item = Batch>,
key_val_plan: &KeyValPlan,
accum_plan: &AccumulablePlan,
accum_plan: &AccumulablePlanV2,
output_type: &RelationType,
SubgraphArg {
now,
@@ -529,39 +530,48 @@ fn reduce_batch_subgraph(
err_collector.run(|| -> Result<(), _> {
let (accums, _, _) = arrange.get(now, &key).unwrap_or_default();
let accum_list =
from_accum_values_to_live_accums(accums.unpack(), accum_plan.simple_aggrs.len())?;
from_accum_values_to_live_accums(accums.unpack(), accum_plan.full_aggrs.len())?;

let mut accum_output = AccumOutput::new();
for AggrWithIndex {
for AggrWithIndexV2 {
expr,
input_idx,
input_idxs,
output_idx,
} in accum_plan.simple_aggrs.iter()
state_types,
} in accum_plan.full_aggrs.iter()
{
let cur_accum_value = accum_list.get(*output_idx).cloned().unwrap_or_default();
let mut cur_accum = if cur_accum_value.is_empty() {
Accum::new_accum(&expr.func.clone())?
} else {
Accum::try_into_accum(&expr.func, cur_accum_value)?
};
let mut cur_accum = expr
.create_accumulator()
.map_err(BoxedError::new)
.context(ExternalSnafu)?;
if !cur_accum_value.is_empty() {
cur_accum.merge_states(&[cur_accum_value], state_types)?;
}

for val_batch in val_batches.iter() {
// if the input column is missing, feed nulls instead
let cur_input = val_batch
.batch()
.get(*input_idx)
.cloned()
.unwrap_or_else(|| Arc::new(NullVector::new(val_batch.row_count())));
let batch = val_batch.batch();

let cur_input = input_idxs
.iter()
.map(|idx| {
batch
.get(*idx)
.cloned()
.unwrap_or_else(|| Arc::new(NullVector::new(val_batch.row_count())))
})
.collect_vec();
let len = cur_input.len();
cur_accum.update_batch(&expr.func, VectorDiff::from(cur_input))?;
cur_accum.update_batch(&cur_input)?;

trace!("Reduce accum after take {} rows: {:?}", len, cur_accum);
}
let final_output = cur_accum.eval(&expr.func)?;
let final_output = cur_accum.evaluate()?;
trace!("Reduce accum final output: {:?}", final_output);
accum_output.insert_output(*output_idx, final_output);

let cur_accum_value = cur_accum.into_state();
let cur_accum_value = cur_accum.state()?.into_iter().map(|(v, _)| v).collect();
accum_output.insert_accum(*output_idx, cur_accum_value);
}

@@ -679,6 +689,7 @@ fn reduce_subgraph(
send,
},
),
ReducePlan::AccumulableV2(_) => unimplemented!(),
};
}

@@ -1211,12 +1222,14 @@ mod test {
use std::time::Duration;

use common_time::Timestamp;
use datafusion::functions_aggregate::sum::sum_udaf;
use datatypes::data_type::{ConcreteDataType, ConcreteDataType as CDT};
use hydroflow::scheduled::graph::Hydroflow;

use super::*;
use crate::compute::render::test::{get_output_handle, harness_test_ctx, run_and_check};
use crate::compute::state::DataflowState;
use crate::expr::relation::{AggregateExprV2, OrderingReq};
use crate::expr::{
self, AggregateExpr, AggregateFunc, BinaryFunc, GlobalId, MapFilterProject, UnaryFunc,
};
@@ -1581,26 +1594,31 @@
val_plan: MapFilterProject::new(1).project([0]).unwrap().into_safe(),
};

let simple_aggrs = vec![AggrWithIndex::new(
AggregateExpr {
func: AggregateFunc::SumInt64,
expr: ScalarExpr::Column(0),
distinct: false,
},
0,
0,
)];
let accum_plan = AccumulablePlan {
full_aggrs: vec![AggregateExpr {
func: AggregateFunc::SumInt64,
expr: ScalarExpr::Column(0),
distinct: false,
}],
simple_aggrs,
distinct_aggrs: vec![],
let aggr_expr = AggregateExprV2 {
func: sum_udaf().as_ref().clone(),
args: vec![
ScalarExpr::Column(0).with_type(ColumnType::new(CDT::int64_datatype(), false))
],
return_type: CDT::int64_datatype(),
name: "sum".to_string(),
schema: RelationType::new(vec![ColumnType::new(
ConcreteDataType::int32_datatype(),
false,
)])
.into_named(vec![Some("number".to_string())]),
ordering_req: OrderingReq::empty(),
ignore_nulls: false,
is_distinct: false,
is_reversed: false,
input_types: vec![CDT::int64_datatype()],
is_nullable: true,
};

let reduce_plan = ReducePlan::Accumulable(accum_plan);
let accum_plan = AccumulablePlanV2 {
full_aggrs: vec![AggrWithIndexV2::new(aggr_expr, vec![0], 0).unwrap()],
};

let reduce_plan = ReducePlan::AccumulableV2(accum_plan);
let bundle = ctx
.render_reduce_batch(
Box::new(input_plan.with_types(typ.into_unnamed())),
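The `create_accumulator` / `merge_states` / `update_batch` / `evaluate` / `state` sequence above follows the standard DataFusion `Accumulator` lifecycle that `DfAccumulatorAdapter` wraps. A minimal sketch of that lifecycle with a hand-rolled accumulator (`SumI64` is invented for illustration; signatures follow `datafusion_expr::Accumulator` in recent DataFusion, where `evaluate`/`state` take `&mut self`):

```rust
use arrow::array::{Array, ArrayRef, Int64Array};
use arrow::compute;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

/// A toy i64 sum accumulator: `update_batch` folds in raw input columns,
/// `state`/`merge_batch` round-trip partial state, `evaluate` emits the result.
#[derive(Debug, Default)]
struct SumI64 {
    sum: i64,
}

impl Accumulator for SumI64 {
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let arr = values[0]
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("i64 input");
        self.sum += compute::sum(arr).unwrap_or(0);
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        // Partial states use the same single-column layout as `state()`.
        self.update_batch(states)
    }

    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Int64(Some(self.sum))])
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.sum)))
    }

    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
    }
}
```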
4 changes: 2 additions & 2 deletions src/flow/src/error.rs
@@ -148,10 +148,10 @@ pub enum Error {
location: Location,
},

#[snafu(display("Datatypes error: {source} with extra message: {extra}"))]
#[snafu(display("Datatypes error: {source} with extra message: {context}"))]
Datatypes {
source: datatypes::Error,
extra: String,
context: String,
#[snafu(implicit)]
location: Location,
},
2 changes: 1 addition & 1 deletion src/flow/src/expr.rs
@@ -34,7 +34,7 @@ pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
pub(crate) use id::{GlobalId, Id, LocalId};
use itertools::Itertools;
pub(crate) use linear::{MapFilterProject, MfpPlan, SafeMfpPlan};
pub(crate) use relation::{Accum, Accumulator, AggregateExpr, AggregateFunc};
pub(crate) use relation::{AggregateExpr, AggregateFunc};
pub(crate) use scalar::{ScalarExpr, TypedExpr};
use snafu::{ensure, ResultExt};

68 changes: 66 additions & 2 deletions src/flow/src/expr/relation.rs
@@ -14,13 +14,24 @@

//! Describes an aggregation function and its input expression.

pub(crate) use accum::{Accum, Accumulator};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::AggregateUDF;
use datatypes::prelude::{ConcreteDataType, DataType};
use derive_where::derive_where;
pub(crate) use func::AggregateFunc;
use snafu::ResultExt;
pub use udaf::{OrderingReq, SortExpr};

use crate::expr::ScalarExpr;
use crate::error::DatafusionSnafu;
use crate::expr::relation::accum_v2::{AccumulatorV2, DfAccumulatorAdapter};
use crate::expr::{ScalarExpr, TypedExpr};
use crate::repr::RelationDesc;
use crate::Error;

mod accum;
mod accum_v2;
mod func;
mod udaf;

/// Describes an aggregation expression.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
@@ -34,3 +45,56 @@ pub struct AggregateExpr {
/// Should the aggregation be applied only to distinct results in each group.
pub distinct: bool,
}

#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)]
#[derive_where(Ord)]
pub struct AggregateExprV2 {
/// The derived `Ord` impl skips `func` for convenience.
#[derive_where(skip)]
pub func: AggregateUDF,
/// Should only be a simple list of column references.
pub args: Vec<TypedExpr>,
/// Output / return type of this aggregate
pub return_type: ConcreteDataType,
pub name: String,
/// The schema of the input relation to this aggregate
pub schema: RelationDesc,
/// e.g. `FIRST_VALUE(a ORDER BY b)`
pub ordering_req: OrderingReq,
pub ignore_nulls: bool,
pub is_distinct: bool,
pub is_reversed: bool,
/// The types of the arguments to this aggregate
pub input_types: Vec<ConcreteDataType>,
pub is_nullable: bool,
}

impl AggregateExprV2 {
pub fn create_accumulator(&self) -> Result<Box<dyn AccumulatorV2>, Error> {
let data_type = self.return_type.as_arrow_type();
let schema = self.schema.to_df_schema()?;
let ordering_req = self.ordering_req.to_lex_ordering(&schema)?;
let exprs = self
.args
.iter()
.map(|e| e.expr.as_physical_expr(&schema))
.collect::<Result<Vec<_>, _>>()?;
let accum_args = AccumulatorArgs {
return_type: &data_type,
schema: schema.as_arrow(),
ignore_nulls: self.ignore_nulls,
ordering_req: &ordering_req,
is_reversed: self.is_reversed,
name: &self.name,
is_distinct: self.is_distinct,
exprs: &exprs,
};
let acc = self.func.accumulator(accum_args).context(DatafusionSnafu {
context: "Fail to build accumulator",
})?;
let acc = DfAccumulatorAdapter::new_unchecked(acc);
Ok(Box::new(acc))
}
}