Skip to content

Commit

Permalink
fix: count_wildcard_to_time_index_rule doesn't handle table reference…
Browse files Browse the repository at this point in the history
… properly (#3847)

* validate time index col

Signed-off-by: Ruihang Xia <[email protected]>

* use TableReference instead

Signed-off-by: Ruihang Xia <[email protected]>

* add more tests

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Apr 30, 2024
1 parent e84b1ee commit e6eca8c
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 9 deletions.
79 changes: 70 additions & 9 deletions src/query/src/optimizer/count_wildcard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ use datafusion::datasource::DefaultTableSource;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::Result as DataFusionResult;
use datafusion_common::{Column, Result as DataFusionResult};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
use datafusion_optimizer::utils::NamePreserver;
use datafusion_optimizer::AnalyzerRule;
use datafusion_sql::TableReference;
use table::table::adapter::DfTableProviderAdapter;

/// A replacement to DataFusion's [`CountWildcardRule`]. This rule
Expand Down Expand Up @@ -77,11 +78,27 @@ impl CountWildcardToTimeIndexRule {
})
}

fn try_find_time_index_col(plan: &LogicalPlan) -> Option<String> {
fn try_find_time_index_col(plan: &LogicalPlan) -> Option<Column> {
let mut finder = TimeIndexFinder::default();
// Safety: `TimeIndexFinder` won't throw error.
plan.visit(&mut finder).unwrap();
finder.time_index
let col = finder.into_column();

// check if the time index is a valid column as for current plan
if let Some(col) = &col {
let mut is_valid = false;
for input in plan.inputs() {
if input.schema().has_column(col) {
is_valid = true;
break;
}
}
if !is_valid {
return None;
}
}

col
}
}

Expand Down Expand Up @@ -114,16 +131,16 @@ impl CountWildcardToTimeIndexRule {

#[derive(Default)]
struct TimeIndexFinder {
time_index: Option<String>,
table_alias: Option<String>,
time_index_col: Option<String>,
table_alias: Option<TableReference>,
}

impl TreeNodeVisitor for TimeIndexFinder {
type Node = LogicalPlan;

fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
self.table_alias = Some(subquery_alias.alias.to_string());
self.table_alias = Some(subquery_alias.alias.clone());
}

if let LogicalPlan::TableScan(table_scan) = &node {
Expand All @@ -138,9 +155,13 @@ impl TreeNodeVisitor for TimeIndexFinder {
.downcast_ref::<DfTableProviderAdapter>()
{
let table_info = adapter.table().table_info();
let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name);
let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name);
self.time_index = col_name.map(|s| format!("{}.{}", table_name, s));
self.table_alias
.get_or_insert(TableReference::bare(table_info.name.clone()));
self.time_index_col = table_info
.meta
.schema
.timestamp_column()
.map(|c| c.name.clone());

return Ok(TreeNodeRecursion::Stop);
}
Expand All @@ -154,3 +175,43 @@ impl TreeNodeVisitor for TimeIndexFinder {
Ok(TreeNodeRecursion::Stop)
}
}

impl TimeIndexFinder {
fn into_column(self) -> Option<Column> {
self.time_index_col
.map(|c| Column::new(self.table_alias, c))
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use datafusion_expr::{count, wildcard, LogicalPlanBuilder};
use table::table::numbers::NumbersTable;

use super::*;

#[test]
fn uppercase_table_name() {
let numbers_table = NumbersTable::table_with_name(0, "AbCdE".to_string());
let table_source = Arc::new(DefaultTableSource::new(Arc::new(
DfTableProviderAdapter::new(numbers_table),
)));

let plan = LogicalPlanBuilder::scan_with_filters("t", table_source, None, vec![])
.unwrap()
.aggregate(Vec::<Expr>::new(), vec![count(wildcard())])
.unwrap()
.alias(r#""FgHiJ""#)
.unwrap()
.build()
.unwrap();

let mut finder = TimeIndexFinder::default();
plan.visit(&mut finder).unwrap();

assert_eq!(finder.table_alias, Some(TableReference::bare("FgHiJ")));
assert!(finder.time_index_col.is_none());
}
}
56 changes: 56 additions & 0 deletions tests/cases/standalone/common/aggregate/count.result
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
create table "HelloWorld" (a string, b timestamp time index);

Affected Rows: 0

insert into "HelloWorld" values ("a", 1) ,("b", 2);

Affected Rows: 2

select count(*) from "HelloWorld";

+----------+
| COUNT(*) |
+----------+
| 2 |
+----------+

create table test (a string, "BbB" timestamp time index);

Affected Rows: 0

insert into test values ("c", 1) ;

Affected Rows: 1

select count(*) from test;

+----------+
| COUNT(*) |
+----------+
| 1 |
+----------+

select count(*) from (select count(*) from test where a = 'a');

+----------+
| COUNT(*) |
+----------+
| 1 |
+----------+

select count(*) from (select * from test cross join "HelloWorld");

+----------+
| COUNT(*) |
+----------+
| 2 |
+----------+

drop table "HelloWorld";

Affected Rows: 0

drop table test;

Affected Rows: 0

19 changes: 19 additions & 0 deletions tests/cases/standalone/common/aggregate/count.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
create table "HelloWorld" (a string, b timestamp time index);

insert into "HelloWorld" values ("a", 1) ,("b", 2);

select count(*) from "HelloWorld";

create table test (a string, "BbB" timestamp time index);

insert into test values ("c", 1) ;

select count(*) from test;

select count(*) from (select count(*) from test where a = 'a');

select count(*) from (select * from test cross join "HelloWorld");

drop table "HelloWorld";

drop table test;

0 comments on commit e6eca8c

Please sign in to comment.