Skip to content

Commit

Permalink
Use take_function_args in more places (#14525)
Browse files Browse the repository at this point in the history
* refactor: apply take_function_args() in functions crate

* fix: handle plural vs. singular grammar for "argument(s)"

* fix: run cargo clippy and fix errors

* style: apply cargo fmt

* refactor: move func to datafusion_common and update imports

* refactor: apply take_function_args

* fix: update test output language

* fix: simplify doc test for take_function_args

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
lgingerich and alamb authored Feb 12, 2025
1 parent 82461b7 commit 2d30334
Show file tree
Hide file tree
Showing 44 changed files with 298 additions and 470 deletions.
41 changes: 40 additions & 1 deletion datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod memory;
pub mod proxy;
pub mod string_utils;

use crate::error::{_internal_datafusion_err, _internal_err};
use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
use crate::{DataFusionError, Result, ScalarValue};
use arrow::array::{
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
Expand Down Expand Up @@ -905,6 +905,45 @@ pub fn get_available_parallelism() -> usize {
.get()
}

/// Converts a collection of function arguments into an fixed-size array of length N
/// producing a reasonable error message in case of unexpected number of arguments.
///
/// # Example
/// ```
/// # use datafusion_common::Result;
/// # use datafusion_common::utils::take_function_args;
/// # use datafusion_common::ScalarValue;
/// fn my_function(args: &[ScalarValue]) -> Result<()> {
/// // function expects 2 args, so create a 2-element array
/// let [arg1, arg2] = take_function_args("my_function", args)?;
/// // ... do stuff..
/// Ok(())
/// }
///
/// // Calling the function with 1 argument produces an error:
/// let args = vec![ScalarValue::Int32(Some(10))];
/// let err = my_function(&args).unwrap_err();
/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1");
/// // Calling the function with 2 arguments works great
/// let args = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(20))];
/// my_function(&args).unwrap();
/// ```
pub fn take_function_args<const N: usize, T>(
function_name: &str,
args: impl IntoIterator<Item = T>,
) -> Result<[T; N]> {
let args = args.into_iter().collect::<Vec<_>>();
args.try_into().map_err(|v: Vec<T>| {
_exec_datafusion_err!(
"{} function requires {} {}, got {}",
function_name,
N,
if N == 1 { "argument" } else { "arguments" },
v.len()
)
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
8 changes: 3 additions & 5 deletions datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::datatypes::{
DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};

use datafusion_common::{exec_err, not_impl_err, Result};
use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};

use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
use crate::Volatility::Immutable;
Expand Down Expand Up @@ -125,9 +125,7 @@ impl AggregateUDFImpl for Sum {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!("SUM expects exactly one argument");
}
let [array] = take_function_args(self.name(), arg_types)?;

// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
Expand All @@ -147,7 +145,7 @@ impl AggregateUDFImpl for Sum {
}
}

Ok(vec![coerced_type(&arg_types[0])?])
Ok(vec![coerced_type(array)?])
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Float64Type, UInt64Type,
};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
use datafusion_expr::utils::format_state_name;
Expand Down Expand Up @@ -247,10 +249,8 @@ impl AggregateUDFImpl for Avg {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!("{} expects exactly one argument.", self.name());
}
coerce_avg_type(self.name(), arg_types)
let [args] = take_function_args(self.name(), arg_types)?;
coerce_avg_type(self.name(), std::slice::from_ref(args))
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use arrow::datatypes::{
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::utils::format_state_name;
Expand Down Expand Up @@ -125,9 +127,7 @@ impl AggregateUDFImpl for Sum {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!("SUM expects exactly one argument");
}
let [args] = take_function_args(self.name(), arg_types)?;

// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
Expand All @@ -147,7 +147,7 @@ impl AggregateUDFImpl for Sum {
}
}

Ok(vec![coerced_type(&arg_types[0])?])
Ok(vec![coerced_type(args)?])
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Expand Down
14 changes: 6 additions & 8 deletions datafusion/functions-nested/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::datatypes::{
DataType::{FixedSizeList, LargeList, List, Map, UInt64},
};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array};
use datafusion_common::utils::take_function_args;
use datafusion_common::Result;
use datafusion_common::{exec_err, plan_err};
use datafusion_expr::{
Expand Down Expand Up @@ -127,21 +128,18 @@ impl ScalarUDFImpl for Cardinality {

/// Cardinality SQL function
pub fn cardinality_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("cardinality expects one argument");
}

match &args[0].data_type() {
let [array] = take_function_args("cardinality", args)?;
match &array.data_type() {
List(_) => {
let list_array = as_list_array(&args[0])?;
let list_array = as_list_array(&array)?;
generic_list_cardinality::<i32>(list_array)
}
LargeList(_) => {
let list_array = as_large_list_array(&args[0])?;
let list_array = as_large_list_array(&array)?;
generic_list_cardinality::<i64>(list_array)
}
Map(_, _) => {
let map_array = as_map_array(&args[0])?;
let map_array = as_map_array(&array)?;
generic_map_cardinality(map_array)
}
other => {
Expand Down
16 changes: 7 additions & 9 deletions datafusion/functions-nested/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_common::{
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
cast::as_generic_list_array,
exec_err, not_impl_err, plan_err,
utils::{list_ndims, take_function_args},
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -415,23 +417,19 @@ fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_append SQL function
pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}
let [array, _] = take_function_args("array_append", args)?;

match args[0].data_type() {
match array.data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, true),
_ => general_append_and_prepend::<i32>(args, true),
}
}

/// Array_prepend SQL function
pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}
let [_, array] = take_function_args("array_prepend", args)?;

match args[1].data_type() {
match array.data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, false),
_ => general_append_and_prepend::<i32>(args, false),
}
Expand Down
22 changes: 9 additions & 13 deletions datafusion/functions-nested/src/dimension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::datatypes::{
use std::any::Any;

use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};

use crate::utils::{compute_array_dims, make_scalar_function};
use datafusion_expr::{
Expand Down Expand Up @@ -203,20 +203,18 @@ impl ScalarUDFImpl for ArrayNdims {

/// Array_dims SQL function
pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_dims needs one argument");
}
let [array] = take_function_args("array_dims", args)?;

let data = match args[0].data_type() {
let data = match array.data_type() {
List(_) => {
let array = as_list_array(&args[0])?;
let array = as_list_array(&array)?;
array
.iter()
.map(compute_array_dims)
.collect::<Result<Vec<_>>>()?
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let array = as_large_list_array(&array)?;
array
.iter()
.map(compute_array_dims)
Expand All @@ -234,9 +232,7 @@ pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_ndims SQL function
pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_ndims needs one argument");
}
let [array_dim] = take_function_args("array_ndims", args)?;

fn general_list_ndims<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
Expand All @@ -254,13 +250,13 @@ pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
}
match args[0].data_type() {
match array_dim.data_type() {
List(_) => {
let array = as_list_array(&args[0])?;
let array = as_list_array(&array_dim)?;
general_list_ndims::<i32>(array)
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let array = as_large_list_array(&array_dim)?;
general_list_ndims::<i64>(array)
}
array_type => exec_err!("array_ndims does not support type {array_type:?}"),
Expand Down
14 changes: 6 additions & 8 deletions datafusion/functions-nested/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ use datafusion_common::cast::{
as_int64_array,
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{exec_err, internal_datafusion_err, Result};
use datafusion_common::{
exec_err, internal_datafusion_err, utils::take_function_args, Result,
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -110,9 +112,7 @@ impl ScalarUDFImpl for ArrayDistance {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}
let [_, _] = take_function_args(self.name(), arg_types)?;
let mut result = Vec::new();
for arg_type in arg_types {
match arg_type {
Expand Down Expand Up @@ -142,11 +142,9 @@ impl ScalarUDFImpl for ArrayDistance {
}

pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}
let [array1, array2] = take_function_args("array_distance", args)?;

match (&args[0].data_type(), &args[1].data_type()) {
match (&array1.data_type(), &array2.data_type()) {
(List(_), List(_)) => general_array_distance::<i32>(args),
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
(array_type1, array_type2) => {
Expand Down
12 changes: 5 additions & 7 deletions datafusion/functions-nested/src/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::datatypes::{
DataType::{Boolean, FixedSizeList, LargeList, List},
};
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -117,14 +117,12 @@ impl ScalarUDFImpl for ArrayEmpty {

/// Array_empty SQL function
pub fn array_empty_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_empty expects one argument");
}
let [array] = take_function_args("array_empty", args)?;

let array_type = args[0].data_type();
let array_type = array.data_type();
match array_type {
List(_) => general_array_empty::<i32>(&args[0]),
LargeList(_) => general_array_empty::<i64>(&args[0]),
List(_) => general_array_empty::<i32>(array),
LargeList(_) => general_array_empty::<i64>(array),
_ => exec_err!("array_empty does not support type '{array_type:?}'."),
}
}
Expand Down
10 changes: 3 additions & 7 deletions datafusion/functions-nested/src/except.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeT
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, FieldRef};
use arrow::row::{RowConverter, SortField};
use datafusion_common::{exec_err, internal_err, HashSet, Result};
use datafusion_common::utils::take_function_args;
use datafusion_common::{internal_err, HashSet, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -124,12 +125,7 @@ impl ScalarUDFImpl for ArrayExcept {

/// Array_except SQL function
pub fn array_except_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_except needs two arguments");
}

let array1 = &args[0];
let array2 = &args[1];
let [array1, array2] = take_function_args("array_except", args)?;

match (array1.data_type(), array2.data_type()) {
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),
Expand Down
Loading

0 comments on commit 2d30334

Please sign in to comment.