Skip to content

Commit

Permalink
Remove mutability and update return types
Browse files Browse the repository at this point in the history
  • Loading branch information
jkosh44 committed Feb 10, 2025
1 parent 0de206b commit 3a7c0e6
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 160 deletions.
55 changes: 12 additions & 43 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,76 +591,45 @@ pub fn base_type(data_type: &DataType) -> DataType {
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum ArrayFunctionMutability {
/// The array function does not modify the array.
Immutable,
/// The array function does modify the array.
Mutable,
}

/// A helper function to coerce base type in List.
///
/// Example
/// ```
/// use arrow::datatypes::{DataType, Field};
/// use datafusion_common::utils::{coerced_type_with_base_type_only, ArrayFunctionMutability};
/// use datafusion_common::utils::{coerced_type_with_base_type_only};
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
/// let base_type = DataType::Float64;
/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, &ArrayFunctionMutability::Mutable);
/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type);
/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))));
pub fn coerced_type_with_base_type_only(
data_type: &DataType,
base_type: &DataType,
mutability: &ArrayFunctionMutability,
) -> DataType {
match (data_type, mutability) {
(DataType::List(field), _) => {
let field_type = coerced_type_with_base_type_only(
field.data_type(),
base_type,
mutability,
);

DataType::List(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}
(DataType::FixedSizeList(field, _), ArrayFunctionMutability::Mutable) => {
let field_type = coerced_type_with_base_type_only(
field.data_type(),
base_type,
mutability,
);
match data_type {
DataType::List(field) => {
let field_type =
coerced_type_with_base_type_only(field.data_type(), base_type);

DataType::List(Arc::new(Field::new(
field.name(),
field_type,
field.is_nullable(),
)))
}
(DataType::FixedSizeList(field, len), ArrayFunctionMutability::Immutable) => {
let field_type = coerced_type_with_base_type_only(
field.data_type(),
base_type,
mutability,
);
DataType::FixedSizeList(field, len) => {
let field_type =
coerced_type_with_base_type_only(field.data_type(), base_type);

DataType::FixedSizeList(
Arc::new(Field::new(field.name(), field_type, field.is_nullable())),
*len,
)
}
(DataType::LargeList(field), _) => {
let field_type = coerced_type_with_base_type_only(
field.data_type(),
base_type,
mutability,
);
DataType::LargeList(field) => {
let field_type =
coerced_type_with_base_type_only(field.data_type(), base_type);

DataType::LargeList(Arc::new(Field::new(
field.name(),
Expand Down
31 changes: 5 additions & 26 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use std::fmt::Display;
use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datafusion_common::types::{LogicalTypeRef, NativeType};
use datafusion_common::utils::ArrayFunctionMutability;
use itertools::Itertools;

/// Constant that is used as a placeholder for any valid timezone.
Expand Down Expand Up @@ -231,8 +230,6 @@ pub enum ArrayFunctionSignature {
Array {
/// A full list of the arguments accepted by this function.
arguments: ArrayFunctionArguments,
/// Whether any of the input arrays are modified.
mutability: ArrayFunctionMutability,
},
/// A function takes a single argument that must be a List/LargeList/FixedSizeList
/// which gets coerced to List, with element type recursively coerced to List too if it is list-like.
Expand Down Expand Up @@ -613,10 +610,7 @@ impl Signature {
}
}
/// Specialized Signature for ArrayAppend and similar functions
pub fn array_and_element(
volatility: Volatility,
mutability: ArrayFunctionMutability,
) -> Self {
pub fn array_and_element(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
Expand All @@ -625,17 +619,13 @@ impl Signature {
ArrayFunctionArgument::Element,
])
.expect("contains array"),
mutability,
},
),
volatility,
}
}
/// Specialized Signature for Array functions with an optional index
pub fn array_and_element_and_optional_index(
volatility: Volatility,
mutability: ArrayFunctionMutability,
) -> Self {
pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::OneOf(vec![
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
Expand All @@ -644,7 +634,6 @@ impl Signature {
ArrayFunctionArgument::Element,
])
.expect("contains array"),
mutability: mutability.clone(),
}),
TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
arguments: ArrayFunctionArguments::new(vec![
Expand All @@ -653,17 +642,13 @@ impl Signature {
ArrayFunctionArgument::Index,
])
.expect("contains array"),
mutability,
}),
]),
volatility,
}
}
/// Specialized Signature for ArrayPrepend and similar functions
pub fn element_and_array(
volatility: Volatility,
mutability: ArrayFunctionMutability,
) -> Self {
pub fn element_and_array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
Expand All @@ -672,17 +657,13 @@ impl Signature {
ArrayFunctionArgument::Array,
])
.expect("contains array"),
mutability,
},
),
volatility,
}
}
/// Specialized Signature for ArrayElement and similar functions
pub fn array_and_index(
volatility: Volatility,
mutability: ArrayFunctionMutability,
) -> Self {
pub fn array_and_index(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
Expand All @@ -691,22 +672,20 @@ impl Signature {
ArrayFunctionArgument::Index,
])
.expect("contains array"),
mutability,
},
),
volatility,
}
}
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility, mutability: ArrayFunctionMutability) -> Self {
pub fn array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::Array {
arguments: ArrayFunctionArguments::new(vec![
ArrayFunctionArgument::Array,
])
.expect("contains array"),
mutability,
},
),
volatility,
Expand Down
11 changes: 3 additions & 8 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::{
coerced_fixed_size_list_to_list, ArrayFunctionMutability,
};
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
types::{LogicalType, NativeType},
Expand Down Expand Up @@ -364,7 +362,6 @@ fn get_valid_types(
function_name: &str,
current_types: &[DataType],
arguments: &[ArrayFunctionArgument],
mutability: &ArrayFunctionMutability,
) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != arguments.len() {
return Ok(vec![vec![]]);
Expand Down Expand Up @@ -399,7 +396,6 @@ fn get_valid_types(
let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
&array_type,
&new_base_type,
mutability,
);

let new_elem_type = match new_array_type {
Expand All @@ -422,7 +418,6 @@ fn get_valid_types(
datafusion_common::utils::coerced_type_with_base_type_only(
&current_type,
&new_base_type,
mutability,
);
// All array arguments must be coercible to the same type
if new_type != new_array_type {
Expand Down Expand Up @@ -705,8 +700,8 @@ fn get_valid_types(
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArraySignature(ref function_signature) => {
match function_signature {
ArrayFunctionSignature::Array { arguments, mutability } => {
array_valid_types(function_name, current_types, arguments.inner(), mutability)?
ArrayFunctionSignature::Array { arguments } => {
array_valid_types(function_name, current_types, arguments.inner())?
}
ArrayFunctionSignature::RecursiveArray => {
if current_types.len() != 1 {
Expand Down
6 changes: 1 addition & 5 deletions datafusion/functions-nested/src/array_has.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use itertools::Itertools;

use crate::utils::make_scalar_function;

use datafusion_common::utils::ArrayFunctionMutability;
use std::any::Any;
use std::sync::Arc;

Expand Down Expand Up @@ -95,10 +94,7 @@ impl Default for ArrayHas {
impl ArrayHas {
pub fn new() -> Self {
Self {
signature: Signature::array_and_element(
Volatility::Immutable,
ArrayFunctionMutability::Immutable,
),
signature: Signature::array_and_element(Volatility::Immutable),
aliases: vec![
String::from("list_has"),
String::from("array_contains"),
Expand Down
2 changes: 0 additions & 2 deletions datafusion/functions-nested/src/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow::array::{
use arrow_schema::DataType;
use arrow_schema::DataType::{FixedSizeList, LargeList, List, Map, UInt64};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array};
use datafusion_common::utils::ArrayFunctionMutability;
use datafusion_common::Result;
use datafusion_common::{exec_err, plan_err};
use datafusion_expr::{
Expand Down Expand Up @@ -53,7 +52,6 @@ impl Cardinality {
ArrayFunctionArgument::Array,
])
.expect("contains array"),
mutability: ArrayFunctionMutability::Immutable,
}),
TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
],
Expand Down
16 changes: 5 additions & 11 deletions datafusion/functions-nested/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::array::{
};
use arrow::buffer::OffsetBuffer;
use arrow_schema::{DataType, Field};
use datafusion_common::utils::ArrayFunctionMutability;
use datafusion_common::utils::coerced_fixed_size_list_to_list;
use datafusion_common::Result;
use datafusion_common::{
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
Expand Down Expand Up @@ -79,10 +79,7 @@ impl Default for ArrayAppend {
impl ArrayAppend {
pub fn new() -> Self {
Self {
signature: Signature::array_and_element(
Volatility::Immutable,
ArrayFunctionMutability::Mutable,
),
signature: Signature::array_and_element(Volatility::Immutable),
aliases: vec![
String::from("list_append"),
String::from("array_push_back"),
Expand All @@ -106,7 +103,7 @@ impl ScalarUDFImpl for ArrayAppend {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].clone())
Ok(coerced_fixed_size_list_to_list(&arg_types[0]))
}

fn invoke_batch(
Expand Down Expand Up @@ -167,10 +164,7 @@ impl Default for ArrayPrepend {
impl ArrayPrepend {
pub fn new() -> Self {
Self {
signature: Signature::element_and_array(
Volatility::Immutable,
ArrayFunctionMutability::Mutable,
),
signature: Signature::element_and_array(Volatility::Immutable),
aliases: vec![
String::from("list_prepend"),
String::from("array_push_front"),
Expand All @@ -194,7 +188,7 @@ impl ScalarUDFImpl for ArrayPrepend {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[1].clone())
Ok(coerced_fixed_size_list_to_list(&arg_types[1]))
}

fn invoke_batch(
Expand Down
11 changes: 2 additions & 9 deletions datafusion/functions-nested/src/dimension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use datafusion_common::{exec_err, plan_err, Result};
use crate::utils::{compute_array_dims, make_scalar_function};
use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64};
use arrow_schema::Field;
use datafusion_common::utils::ArrayFunctionMutability;
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
Expand Down Expand Up @@ -76,10 +75,7 @@ impl Default for ArrayDims {
impl ArrayDims {
pub fn new() -> Self {
Self {
signature: Signature::array(
Volatility::Immutable,
ArrayFunctionMutability::Immutable,
),
signature: Signature::array(Volatility::Immutable),
aliases: vec!["list_dims".to_string()],
}
}
Expand Down Expand Up @@ -159,10 +155,7 @@ pub(super) struct ArrayNdims {
impl ArrayNdims {
pub fn new() -> Self {
Self {
signature: Signature::array(
Volatility::Immutable,
ArrayFunctionMutability::Immutable,
),
signature: Signature::array(Volatility::Immutable),
aliases: vec![String::from("list_ndims")],
}
}
Expand Down
6 changes: 1 addition & 5 deletions datafusion/functions-nested/src/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use arrow::array::{ArrayRef, BooleanArray, OffsetSizeTrait};
use arrow_schema::DataType;
use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List};
use datafusion_common::cast::as_generic_list_array;
use datafusion_common::utils::ArrayFunctionMutability;
use datafusion_common::{exec_err, plan_err, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -70,10 +69,7 @@ impl Default for ArrayEmpty {
impl ArrayEmpty {
pub fn new() -> Self {
Self {
signature: Signature::array(
Volatility::Immutable,
ArrayFunctionMutability::Immutable,
),
signature: Signature::array(Volatility::Immutable),
aliases: vec!["array_empty".to_string(), "list_empty".to_string()],
}
}
Expand Down
Loading

0 comments on commit 3a7c0e6

Please sign in to comment.