Commit b39fc37

feat: Support scalar subquery in WHERE

1 parent 6b006e5

8 files changed: +269 −113 lines changed

datafusion/core/src/logical_plan/builder.rs

Lines changed: 3 additions & 2 deletions

@@ -47,7 +47,8 @@ use super::{dfschema::ToDFSchema, expr_rewriter::coerce_plan_expr_for_schema, Di
 use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
 use crate::logical_plan::{
     columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column,
-    CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values,
+    CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition,
+    SubqueryType, Values,
 };
 use crate::sql::utils::group_window_expr_by_sort_keys;

@@ -527,7 +528,7 @@ impl LogicalPlanBuilder {
     /// Apply correlated sub query
     pub fn subquery(
         &self,
-        subqueries: impl IntoIterator<Item = impl Into<LogicalPlan>>,
+        subqueries: impl IntoIterator<Item = impl Into<(LogicalPlan, SubqueryType)>>,
     ) -> Result<Self> {
         let subqueries = subqueries.into_iter().map(|l| l.into()).collect::<Vec<_>>();
         let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries));
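The builder now takes `(LogicalPlan, SubqueryType)` pairs instead of bare plans. A minimal sketch of what a call site could look like after this change — the helper function and its arguments are hypothetical; only the `subquery` signature comes from this commit:

```rust
use datafusion::error::Result;
use datafusion::logical_plan::{LogicalPlan, LogicalPlanBuilder, SubqueryType};

// Hypothetical helper: attach `inner` to `outer` as a scalar subquery.
// A `(LogicalPlan, SubqueryType)` tuple satisfies the new
// `impl Into<(LogicalPlan, SubqueryType)>` bound as-is.
fn attach_scalar_subquery(
    outer: &LogicalPlanBuilder,
    inner: LogicalPlan,
) -> Result<LogicalPlanBuilder> {
    outer.subquery(vec![(inner, SubqueryType::Scalar)])
}
```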

datafusion/core/src/logical_plan/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -68,6 +68,6 @@ pub use plan::{
     CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CrossJoin, Distinct,
     DropTable, EmptyRelation, Filter, JoinConstraint, JoinType, Limit, LogicalPlan,
     Partitioning, PlanType, PlanVisitor, Repartition, StringifiedPlan, Subquery,
-    TableScan, ToStringifiedPlan, Union, Values,
+    SubqueryType, TableScan, ToStringifiedPlan, Union, Values,
 };
 pub use registry::FunctionRegistry;

datafusion/core/src/logical_plan/plan.rs

Lines changed: 39 additions & 9 deletions

@@ -268,21 +268,51 @@ pub struct Limit {
 #[derive(Clone)]
 pub struct Subquery {
     /// The list of sub queries
-    pub subqueries: Vec<LogicalPlan>,
+    pub subqueries: Vec<(LogicalPlan, SubqueryType)>,
     /// The incoming logical plan
     pub input: Arc<LogicalPlan>,
     /// The schema description of the output
     pub schema: DFSchemaRef,
 }

+/// Subquery type
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum SubqueryType {
+    /// Scalar (SELECT, WHERE) evaluating to one value
+    Scalar,
+    // This will be extended with `Exists` and `AnyAll` types.
+}
+
 impl Subquery {
     /// Merge schema of main input and correlated subquery columns
-    pub fn merged_schema(input: &LogicalPlan, subqueries: &[LogicalPlan]) -> DFSchema {
-        subqueries.iter().fold((**input.schema()).clone(), |a, b| {
-            let mut res = a;
-            res.merge(b.schema());
-            res
-        })
+    pub fn merged_schema(
+        input: &LogicalPlan,
+        subqueries: &[(LogicalPlan, SubqueryType)],
+    ) -> DFSchema {
+        subqueries
+            .iter()
+            .fold((**input.schema()).clone(), |input_schema, (plan, typ)| {
+                let mut res = input_schema;
+                let subquery_schema = Self::transform_dfschema(plan.schema(), *typ);
+                res.merge(&subquery_schema);
+                res
+            })
+    }
+
+    /// Transform DataFusion schema according to subquery type
+    pub fn transform_dfschema(schema: &DFSchema, typ: SubqueryType) -> DFSchema {
+        match typ {
+            SubqueryType::Scalar => schema.clone(),
+            // Schema will be transformed for `Exists` and `AnyAll`
+        }
+    }
+
+    /// Transform Arrow field according to subquery type
+    pub fn transform_field(field: &Field, typ: SubqueryType) -> Field {
+        match typ {
+            SubqueryType::Scalar => field.clone(),
+            // Field will be transformed for `Exists` and `AnyAll`
+        }
     }
 }

@@ -585,7 +615,7 @@ impl LogicalPlan {
                 input, subqueries, ..
             }) => vec![input.as_ref()]
                 .into_iter()
-                .chain(subqueries.iter())
+                .chain(subqueries.iter().map(|(q, _)| q))
                 .collect(),
             LogicalPlan::Filter(Filter { input, .. }) => vec![input],
             LogicalPlan::Repartition(Repartition { input, .. }) => vec![input],

@@ -728,7 +758,7 @@ impl LogicalPlan {
                 input, subqueries, ..
             }) => {
                 input.accept(visitor)?;
-                for input in subqueries {
+                for (input, _) in subqueries {
                     if !input.accept(visitor)? {
                         return Ok(false);
                     }
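`merged_schema` folds the outer input's schema together with each subquery's type-transformed schema, so a `Subquery` node exposes the outer columns followed by one column per scalar subquery. A simplified standalone sketch of that fold, using stand-in types rather than the real `DFSchema` API:

```rust
// Stand-in types to illustrate the fold in `merged_schema`;
// not the DataFusion API.
#[derive(Clone, Copy, PartialEq, Debug)]
enum SubqueryType {
    Scalar,
}

#[derive(Clone, Debug, PartialEq)]
struct Schema {
    fields: Vec<String>,
}

impl Schema {
    fn merge(&mut self, other: &Schema) {
        self.fields.extend(other.fields.iter().cloned());
    }
}

// For `Scalar`, the subquery schema is taken as-is; future variants
// (`Exists`, `AnyAll`) would rewrite it here.
fn transform(schema: &Schema, typ: SubqueryType) -> Schema {
    match typ {
        SubqueryType::Scalar => schema.clone(),
    }
}

fn merged_schema(input: &Schema, subqueries: &[(Schema, SubqueryType)]) -> Schema {
    subqueries.iter().fold(input.clone(), |mut acc, (schema, typ)| {
        acc.merge(&transform(schema, *typ));
        acc
    })
}

fn main() {
    let input = Schema { fields: vec!["a".to_string(), "b".to_string()] };
    let sub = Schema { fields: vec!["max_x".to_string()] };
    let merged = merged_schema(&input, &[(sub, SubqueryType::Scalar)]);
    // Outer columns first, then one column per scalar subquery.
    assert_eq!(
        merged.fields,
        vec!["a".to_string(), "b".to_string(), "max_x".to_string()]
    );
}
```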

datafusion/core/src/optimizer/projection_push_down.rs

Lines changed: 1 addition & 1 deletion

@@ -456,7 +456,7 @@ fn optimize_plan(
             input, subqueries, ..
         }) => {
            let mut subquery_required_columns = HashSet::new();
-            for subquery in subqueries.iter() {
+            for subquery in subqueries.iter().map(|(q, _)| q) {
                let mut inputs = vec![subquery];
                while !inputs.is_empty() {
                    let mut next_inputs = Vec::new();

datafusion/core/src/optimizer/utils.rs

Lines changed: 11 additions & 7 deletions

@@ -161,13 +161,17 @@ pub fn from_plan(
                alias: alias.clone(),
            }))
        }
-        LogicalPlan::Subquery(Subquery { schema, .. }) => {
-            Ok(LogicalPlan::Subquery(Subquery {
-                subqueries: inputs[1..inputs.len()].to_vec(),
-                input: Arc::new(inputs[0].clone()),
-                schema: schema.clone(),
-            }))
-        }
+        LogicalPlan::Subquery(Subquery {
+            schema, subqueries, ..
+        }) => Ok(LogicalPlan::Subquery(Subquery {
+            subqueries: inputs[1..inputs.len()]
+                .iter()
+                .zip(subqueries.iter())
+                .map(|(input, (_, t))| (input.clone(), *t))
+                .collect(),
+            input: Arc::new(inputs[0].clone()),
+            schema: schema.clone(),
+        })),
        LogicalPlan::TableUDFs(TableUDFs { .. }) => {
            Ok(LogicalPlan::TableUDFs(TableUDFs {
                expr: expr.to_vec(),
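`from_plan` rebuilds a plan node from a list of fresh child plans, but the optimizer hands back bare `LogicalPlan`s, so the subquery types must be recovered by zipping the new children against the old typed list. The pairing pattern in isolation — a toy sketch, not DataFusion code:

```rust
// Rebuild (plan, type) pairs by zipping fresh inputs with the old
// typed list; positions are assumed to correspond one-to-one.
fn rezip<P: Clone, T: Copy>(new_inputs: &[P], old: &[(P, T)]) -> Vec<(P, T)> {
    new_inputs
        .iter()
        .zip(old.iter())
        .map(|(input, (_, t))| (input.clone(), *t))
        .collect()
}

fn main() {
    let old = vec![("sq1_old", 'S'), ("sq2_old", 'S')];
    let new_inputs = vec!["sq1_new", "sq2_new"];
    assert_eq!(
        rezip(&new_inputs, &old),
        vec![("sq1_new", 'S'), ("sq2_new", 'S')]
    );
}
```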

datafusion/core/src/physical_plan/planner.rs

Lines changed: 5 additions & 4 deletions

@@ -31,7 +31,7 @@ use crate::logical_plan::plan::{
 };
 use crate::logical_plan::{
     unalias, unnormalize_cols, CrossJoin, DFSchema, Distinct, Expr, Like, LogicalPlan,
-    Operator, Partitioning as LogicalPartitioning, PlanType, Repartition,
+    Operator, Partitioning as LogicalPartitioning, PlanType, Repartition, SubqueryType,
     ToStringifiedPlan, Union, UserDefinedLogicalNode,
 };
 use crate::logical_plan::{Limit, Values};

@@ -923,11 +923,12 @@ impl DefaultPhysicalPlanner {
         new_session_state.execution_props = new_session_state.execution_props.with_outer_query_cursor(cursor.clone());
         new_session_state.config.target_partitions = 1;
         let subqueries = futures::stream::iter(subqueries)
-            .then(|lp| self.create_initial_plan(lp, &new_session_state))
+            .then(|(lp, _)| self.create_initial_plan(lp, &new_session_state))
            .try_collect::<Vec<_>>()
            .await?.into_iter()
-            .map(|p| -> Arc<dyn ExecutionPlan> {
-                Arc::new(CoalescePartitionsExec::new(p))
+            .zip(subqueries.iter())
+            .map(|(p, (_, t))| -> (Arc<dyn ExecutionPlan>, SubqueryType) {
+                (Arc::new(CoalescePartitionsExec::new(p)), *t)
            })
            .collect::<Vec<_>>();
        let input = self.create_initial_plan(input, &new_session_state).await?;
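Each subquery is planned against a copy of the session state that carries the `OuterQueryCursor` and is forced to a single partition, then wrapped in a `CoalescePartitionsExec` paired with its `SubqueryType`. For context, a hedged end-to-end sketch of the kind of query this commit enables — the table name, data file, and exact SQL accepted all depend on this fork's SQL planner wiring and are assumptions, not taken from the diff:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

// Hypothetical usage: a scalar subquery in a WHERE clause.
// Assumes an `orders(id, amount)` CSV exists at the given path.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("orders", "orders.csv", CsvReadOptions::new())
        .await?;
    let df = ctx
        .sql("SELECT id FROM orders WHERE amount = (SELECT MAX(amount) FROM orders)")
        .await?;
    df.show().await?;
    Ok(())
}
```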

datafusion/core/src/physical_plan/subquery.rs

Lines changed: 82 additions & 61 deletions

@@ -28,6 +28,7 @@ use std::sync::Arc;
 use std::task::{Context, Poll};

 use crate::error::{DataFusionError, Result};
+use crate::logical_plan::{Subquery, SubqueryType};
 use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
 use arrow::array::new_null_array;
 use arrow::datatypes::{Schema, SchemaRef};

@@ -46,7 +47,7 @@ use futures::stream::StreamExt;
 #[derive(Debug)]
 pub struct SubqueryExec {
     /// Sub queries
-    subqueries: Vec<Arc<dyn ExecutionPlan>>,
+    subqueries: Vec<(Arc<dyn ExecutionPlan>, SubqueryType)>,
     /// Merged schema
     schema: SchemaRef,
     /// The input plan

@@ -58,15 +59,22 @@ pub struct SubqueryExec {
 impl SubqueryExec {
     /// Create a projection on an input
     pub fn try_new(
-        subqueries: Vec<Arc<dyn ExecutionPlan>>,
+        subqueries: Vec<(Arc<dyn ExecutionPlan>, SubqueryType)>,
         input: Arc<dyn ExecutionPlan>,
         cursor: Arc<OuterQueryCursor>,
     ) -> Result<Self> {
         let input_schema = input.schema();

         let mut total_fields = input_schema.fields().clone();
-        for q in subqueries.iter() {
-            total_fields.append(&mut q.schema().fields().clone());
+        for (q, t) in subqueries.iter() {
+            total_fields.append(
+                &mut q
+                    .schema()
+                    .fields()
+                    .iter()
+                    .map(|f| Subquery::transform_field(f, *t))
+                    .collect(),
+            );
         }

         let merged_schema = Schema::new_with_metadata(total_fields, HashMap::new());

@@ -100,7 +108,7 @@ impl ExecutionPlan for SubqueryExec {

     fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
         let mut res = vec![self.input.clone()];
-        res.extend(self.subqueries.iter().cloned());
+        res.extend(self.subqueries.iter().map(|(i, _)| i).cloned());
         res
     }

@@ -134,7 +142,13 @@ impl ExecutionPlan for SubqueryExec {
        }

        Ok(Arc::new(SubqueryExec::try_new(
-            children.iter().skip(1).cloned().collect(),
+            children
+                .iter()
+                .skip(1)
+                .cloned()
+                .zip(self.subqueries.iter())
+                .map(|(p, (_, t))| (p, *t))
+                .collect(),
            children[0].clone(),
            self.cursor.clone(),
        )?))

@@ -151,71 +165,78 @@ impl ExecutionPlan for SubqueryExec {
         let context = context.clone();
         let size_hint = stream.size_hint();
         let schema = self.schema.clone();
-        let res_stream =
-            stream.then(move |batch| {
-                let cursor = cursor.clone();
-                let context = context.clone();
-                let subqueries = subqueries.clone();
-                let schema = schema.clone();
-                async move {
-                    let batch = batch?;
-                    let b = Arc::new(batch.clone());
-                    cursor.set_batch(b)?;
-                    let mut subquery_arrays = vec![Vec::new(); subqueries.len()];
-                    for i in 0..batch.num_rows() {
-                        cursor.set_position(i)?;
-                        for (subquery_i, subquery) in subqueries.iter().enumerate() {
-                            let null_array = || {
-                                let schema = subquery.schema();
-                                let fields = schema.fields();
-                                if fields.len() != 1 {
-                                    return Err(ArrowError::ComputeError(format!(
-                                        "Sub query should have only one column but got {}",
-                                        fields.len()
-                                    )));
-                                }
-
-                                let data_type = fields.get(0).unwrap().data_type();
-                                Ok(new_null_array(data_type, 1))
-                            };
+        let res_stream = stream.then(move |batch| {
+            let cursor = cursor.clone();
+            let context = context.clone();
+            let subqueries = subqueries.clone();
+            let schema = schema.clone();
+            async move {
+                let batch = batch?;
+                let b = Arc::new(batch.clone());
+                cursor.set_batch(b)?;
+                let mut subquery_arrays = vec![Vec::new(); subqueries.len()];
+                for i in 0..batch.num_rows() {
+                    cursor.set_position(i)?;
+                    for (subquery_i, (subquery, subquery_type)) in
+                        subqueries.iter().enumerate()
+                    {
+                        let schema = subquery.schema();
+                        let fields = schema.fields();
+                        if fields.len() != 1 {
+                            return Err(ArrowError::ComputeError(format!(
+                                "Sub query should have only one column but got {}",
+                                fields.len()
+                            )));
+                        }
+                        let data_type = fields.get(0).unwrap().data_type();
+                        let null_array = || new_null_array(data_type, 1);

-                            if subquery.output_partitioning().partition_count() != 1 {
-                                return Err(ArrowError::ComputeError(format!(
-                                    "Sub query should have only one partition but got {}",
-                                    subquery.output_partitioning().partition_count()
-                                )));
-                            }
-                            let mut stream = subquery.execute(0, context.clone()).await?;
-                            let res = stream.next().await;
-                            if let Some(subquery_batch) = res {
-                                let subquery_batch = subquery_batch?;
-                                match subquery_batch.column(0).len() {
-                                    0 => subquery_arrays[subquery_i].push(null_array()?),
+                        if subquery.output_partitioning().partition_count() != 1 {
+                            return Err(ArrowError::ComputeError(format!(
+                                "Sub query should have only one partition but got {}",
+                                subquery.output_partitioning().partition_count()
+                            )));
+                        }
+                        let mut stream = subquery.execute(0, context.clone()).await?;
+                        let res = stream.next().await;
+                        if let Some(subquery_batch) = res {
+                            let subquery_batch = subquery_batch?;
+                            match subquery_type {
+                                SubqueryType::Scalar => match subquery_batch
+                                    .column(0)
+                                    .len()
+                                {
+                                    0 => subquery_arrays[subquery_i].push(null_array()),
                                     1 => subquery_arrays[subquery_i]
                                         .push(subquery_batch.column(0).clone()),
                                     _ => return Err(ArrowError::ComputeError(
                                         "Sub query should return no more than one row"
                                             .to_string(),
                                     )),
-                                };
-                            } else {
-                                subquery_arrays[subquery_i].push(null_array()?);
-                            }
+                                },
+                            };
+                        } else {
+                            match subquery_type {
+                                SubqueryType::Scalar => {
+                                    subquery_arrays[subquery_i].push(null_array())
+                                }
+                            };
+                        }
                     }
                 }
-                let mut new_columns = batch.columns().to_vec();
-                for subquery_array in subquery_arrays {
-                    new_columns.push(concat(
-                        subquery_array
-                            .iter()
-                            .map(|a| a.as_ref())
-                            .collect::<Vec<_>>()
-                            .as_slice(),
-                    )?);
-                }
-                RecordBatch::try_new(schema.clone(), new_columns)
                 }
-            });
+                let mut new_columns = batch.columns().to_vec();
+                for subquery_array in subquery_arrays {
+                    new_columns.push(concat(
+                        subquery_array
+                            .iter()
+                            .map(|a| a.as_ref())
+                            .collect::<Vec<_>>()
+                            .as_slice(),
+                    )?);
+                }
+                RecordBatch::try_new(schema.clone(), new_columns)
+            }
+        });
         Ok(Box::pin(SubQueryStream {
             schema: self.schema.clone(),
             stream: Box::pin(res_stream),
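The execution loop above pins the outer-query cursor to each row of the incoming batch, runs every subquery once per row, and applies the scalar rule: an empty result contributes NULL, a single row contributes its value, and anything more is an error. That rule in isolation, as a standalone sketch with stand-in types rather than Arrow arrays:

```rust
// Stand-in for the per-row scalar-subquery result, instead of an
// Arrow array: 0 rows -> NULL, 1 row -> the value, >1 rows -> error.
#[derive(Debug, PartialEq)]
enum ScalarResult {
    Null,
    Value(i64),
}

fn scalar_value(rows: &[i64]) -> Result<ScalarResult, String> {
    match rows.len() {
        0 => Ok(ScalarResult::Null),
        1 => Ok(ScalarResult::Value(rows[0])),
        _ => Err("Sub query should return no more than one row".to_string()),
    }
}

fn main() {
    assert_eq!(scalar_value(&[]), Ok(ScalarResult::Null));
    assert_eq!(scalar_value(&[42]), Ok(ScalarResult::Value(42)));
    assert!(scalar_value(&[1, 2]).is_err());
}
```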
