Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use take_function_args in more places #14525

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
41 changes: 40 additions & 1 deletion datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod memory;
pub mod proxy;
pub mod string_utils;

use crate::error::{_internal_datafusion_err, _internal_err};
use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
use crate::{DataFusionError, Result, ScalarValue};
use arrow::array::{
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
Expand Down Expand Up @@ -905,6 +905,45 @@ pub fn get_available_parallelism() -> usize {
.get()
}

/// Turns the arguments passed to a function into a fixed-size array of
/// exactly `N` elements, so callers can destructure them by position.
///
/// # Errors
/// Returns an execution error naming `function_name`, the expected count `N`
/// and the actual count when the number of supplied arguments is not `N`.
///
/// # Example
/// ```
/// # use datafusion_common::Result;
/// # use datafusion_common::utils::take_function_args;
/// # use datafusion_common::ScalarValue;
/// fn my_function(args: &[ScalarValue]) -> Result<()> {
///     // function expects 2 args, so create a 2-element array
///     let [arg1, arg2] = take_function_args("my_function", args)?;
///     // ... do stuff..
///     Ok(())
/// }
///
/// // Calling the function with 1 argument produces an error:
/// let args = vec![ScalarValue::Int32(Some(10))];
/// let err = my_function(&args).unwrap_err();
/// assert_eq!(err.to_string(), "Execution error: my_function function requires 2 arguments, got 1");
/// // Calling the function with 2 arguments works great
/// let args = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(20))];
/// my_function(&args).unwrap();
/// ```
pub fn take_function_args<const N: usize, T>(
    function_name: &str,
    args: impl IntoIterator<Item = T>,
) -> Result<[T; N]> {
    // Gather everything first: the Vec -> array conversion succeeds only
    // when the length is exactly N, and hands the Vec back otherwise.
    let collected: Vec<T> = args.into_iter().collect();
    match <[T; N]>::try_from(collected) {
        Ok(fixed) => Ok(fixed),
        Err(rejected) => {
            // Keep the message grammatical for single-argument functions.
            let noun = if N == 1 { "argument" } else { "arguments" };
            Err(_exec_datafusion_err!(
                "{} function requires {} {}, got {}",
                function_name,
                N,
                noun,
                rejected.len()
            ))
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
8 changes: 3 additions & 5 deletions datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::datatypes::{
DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};

use datafusion_common::{exec_err, not_impl_err, Result};
use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result};

use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS};
use crate::Volatility::Immutable;
Expand Down Expand Up @@ -125,9 +125,7 @@ impl AggregateUDFImpl for Sum {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!("SUM expects exactly one argument");
}
let [array] = take_function_args(self.name(), arg_types)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is so much nicer


// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
Expand All @@ -147,7 +145,7 @@ impl AggregateUDFImpl for Sum {
}
}

Ok(vec![coerced_type(&arg_types[0])?])
Ok(vec![coerced_type(array)?])
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Float64Type, UInt64Type,
};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type};
use datafusion_expr::utils::format_state_name;
Expand Down Expand Up @@ -247,10 +249,8 @@ impl AggregateUDFImpl for Avg {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!("{} expects exactly one argument.", self.name());
}
coerce_avg_type(self.name(), arg_types)
let [args] = take_function_args(self.name(), arg_types)?;
coerce_avg_type(self.name(), std::slice::from_ref(args))
}

fn documentation(&self) -> Option<&Documentation> {
Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use arrow::datatypes::{
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::utils::format_state_name;
Expand Down Expand Up @@ -125,9 +127,7 @@ impl AggregateUDFImpl for Sum {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return exec_err!("SUM expects exactly one argument");
}
let [args] = take_function_args(self.name(), arg_types)?;

// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
Expand All @@ -147,7 +147,7 @@ impl AggregateUDFImpl for Sum {
}
}

Ok(vec![coerced_type(&arg_types[0])?])
Ok(vec![coerced_type(args)?])
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Expand Down
14 changes: 6 additions & 8 deletions datafusion/functions-nested/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use arrow::datatypes::{
DataType::{FixedSizeList, LargeList, List, Map, UInt64},
};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array};
use datafusion_common::utils::take_function_args;
use datafusion_common::Result;
use datafusion_common::{exec_err, plan_err};
use datafusion_expr::{
Expand Down Expand Up @@ -127,21 +128,18 @@ impl ScalarUDFImpl for Cardinality {

/// Cardinality SQL function
pub fn cardinality_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("cardinality expects one argument");
}

match &args[0].data_type() {
let [array] = take_function_args("cardinality", args)?;
match &array.data_type() {
List(_) => {
let list_array = as_list_array(&args[0])?;
let list_array = as_list_array(&array)?;
generic_list_cardinality::<i32>(list_array)
}
LargeList(_) => {
let list_array = as_large_list_array(&args[0])?;
let list_array = as_large_list_array(&array)?;
generic_list_cardinality::<i64>(list_array)
}
Map(_, _) => {
let map_array = as_map_array(&args[0])?;
let map_array = as_map_array(&array)?;
generic_map_cardinality(map_array)
}
other => {
Expand Down
16 changes: 7 additions & 9 deletions datafusion/functions-nested/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_common::{
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
cast::as_generic_list_array,
exec_err, not_impl_err, plan_err,
utils::{list_ndims, take_function_args},
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -415,23 +417,19 @@ fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_append SQL function
pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}
let [array, _] = take_function_args("array_append", args)?;

match args[0].data_type() {
match array.data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, true),
_ => general_append_and_prepend::<i32>(args, true),
}
}

/// Array_prepend SQL function
pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}
let [_, array] = take_function_args("array_prepend", args)?;

match args[1].data_type() {
match array.data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, false),
_ => general_append_and_prepend::<i32>(args, false),
}
Expand Down
22 changes: 9 additions & 13 deletions datafusion/functions-nested/src/dimension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::datatypes::{
use std::any::Any;

use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};

use crate::utils::{compute_array_dims, make_scalar_function};
use datafusion_expr::{
Expand Down Expand Up @@ -203,20 +203,18 @@ impl ScalarUDFImpl for ArrayNdims {

/// Array_dims SQL function
pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_dims needs one argument");
}
let [array] = take_function_args("array_dims", args)?;

let data = match args[0].data_type() {
let data = match array.data_type() {
List(_) => {
let array = as_list_array(&args[0])?;
let array = as_list_array(&array)?;
array
.iter()
.map(compute_array_dims)
.collect::<Result<Vec<_>>>()?
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let array = as_large_list_array(&array)?;
array
.iter()
.map(compute_array_dims)
Expand All @@ -234,9 +232,7 @@ pub fn array_dims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Array_ndims SQL function
pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_ndims needs one argument");
}
let [array_dim] = take_function_args("array_ndims", args)?;

fn general_list_ndims<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
Expand All @@ -254,13 +250,13 @@ pub fn array_ndims_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
}
match args[0].data_type() {
match array_dim.data_type() {
List(_) => {
let array = as_list_array(&args[0])?;
let array = as_list_array(&array_dim)?;
general_list_ndims::<i32>(array)
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let array = as_large_list_array(&array_dim)?;
general_list_ndims::<i64>(array)
}
array_type => exec_err!("array_ndims does not support type {array_type:?}"),
Expand Down
14 changes: 6 additions & 8 deletions datafusion/functions-nested/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ use datafusion_common::cast::{
as_int64_array,
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{exec_err, internal_datafusion_err, Result};
use datafusion_common::{
exec_err, internal_datafusion_err, utils::take_function_args, Result,
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -110,9 +112,7 @@ impl ScalarUDFImpl for ArrayDistance {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}
let [_, _] = take_function_args(self.name(), arg_types)?;
let mut result = Vec::new();
for arg_type in arg_types {
match arg_type {
Expand Down Expand Up @@ -142,11 +142,9 @@ impl ScalarUDFImpl for ArrayDistance {
}

pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_distance expects exactly two arguments");
}
let [array1, array2] = take_function_args("array_distance", args)?;

match (&args[0].data_type(), &args[1].data_type()) {
match (&array1.data_type(), &array2.data_type()) {
(List(_), List(_)) => general_array_distance::<i32>(args),
(LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
(array_type1, array_type2) => {
Expand Down
12 changes: 5 additions & 7 deletions datafusion/functions-nested/src/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::datatypes::{
DataType::{Boolean, FixedSizeList, LargeList, List},
};
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -117,14 +117,12 @@ impl ScalarUDFImpl for ArrayEmpty {

/// Array_empty SQL function
pub fn array_empty_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_empty expects one argument");
}
let [array] = take_function_args("array_empty", args)?;

let array_type = args[0].data_type();
let array_type = array.data_type();
match array_type {
List(_) => general_array_empty::<i32>(&args[0]),
LargeList(_) => general_array_empty::<i64>(&args[0]),
List(_) => general_array_empty::<i32>(array),
LargeList(_) => general_array_empty::<i64>(array),
_ => exec_err!("array_empty does not support type '{array_type:?}'."),
}
}
Expand Down
10 changes: 3 additions & 7 deletions datafusion/functions-nested/src/except.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeT
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, FieldRef};
use arrow::row::{RowConverter, SortField};
use datafusion_common::{exec_err, internal_err, HashSet, Result};
use datafusion_common::utils::take_function_args;
use datafusion_common::{internal_err, HashSet, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -124,12 +125,7 @@ impl ScalarUDFImpl for ArrayExcept {

/// Array_except SQL function
pub fn array_except_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_except needs two arguments");
}

let array1 = &args[0];
let array2 = &args[1];
let [array1, array2] = take_function_args("array_except", args)?;

match (array1.data_type(), array2.data_type()) {
(DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()),
Expand Down
Loading
Loading