@@ -28,6 +28,7 @@ use std::sync::Arc;
28
28
use std:: task:: { Context , Poll } ;
29
29
30
30
use crate :: error:: { DataFusionError , Result } ;
31
+ use crate :: logical_plan:: { Subquery , SubqueryType } ;
31
32
use crate :: physical_plan:: { DisplayFormatType , ExecutionPlan , Partitioning } ;
32
33
use arrow:: array:: new_null_array;
33
34
use arrow:: datatypes:: { Schema , SchemaRef } ;
@@ -46,7 +47,7 @@ use futures::stream::StreamExt;
46
47
#[ derive( Debug ) ]
47
48
pub struct SubqueryExec {
48
49
/// Sub queries
49
- subqueries : Vec < Arc < dyn ExecutionPlan > > ,
50
+ subqueries : Vec < ( Arc < dyn ExecutionPlan > , SubqueryType ) > ,
50
51
/// Merged schema
51
52
schema : SchemaRef ,
52
53
/// The input plan
@@ -58,15 +59,22 @@ pub struct SubqueryExec {
58
59
impl SubqueryExec {
59
60
/// Create a projection on an input
60
61
pub fn try_new (
61
- subqueries : Vec < Arc < dyn ExecutionPlan > > ,
62
+ subqueries : Vec < ( Arc < dyn ExecutionPlan > , SubqueryType ) > ,
62
63
input : Arc < dyn ExecutionPlan > ,
63
64
cursor : Arc < OuterQueryCursor > ,
64
65
) -> Result < Self > {
65
66
let input_schema = input. schema ( ) ;
66
67
67
68
let mut total_fields = input_schema. fields ( ) . clone ( ) ;
68
- for q in subqueries. iter ( ) {
69
- total_fields. append ( & mut q. schema ( ) . fields ( ) . clone ( ) ) ;
69
+ for ( q, t) in subqueries. iter ( ) {
70
+ total_fields. append (
71
+ & mut q
72
+ . schema ( )
73
+ . fields ( )
74
+ . iter ( )
75
+ . map ( |f| Subquery :: transform_field ( f, * t) )
76
+ . collect ( ) ,
77
+ ) ;
70
78
}
71
79
72
80
let merged_schema = Schema :: new_with_metadata ( total_fields, HashMap :: new ( ) ) ;
@@ -100,7 +108,7 @@ impl ExecutionPlan for SubqueryExec {
100
108
101
109
fn children ( & self ) -> Vec < Arc < dyn ExecutionPlan > > {
102
110
let mut res = vec ! [ self . input. clone( ) ] ;
103
- res. extend ( self . subqueries . iter ( ) . cloned ( ) ) ;
111
+ res. extend ( self . subqueries . iter ( ) . map ( | ( i , _ ) | i ) . cloned ( ) ) ;
104
112
res
105
113
}
106
114
@@ -134,7 +142,13 @@ impl ExecutionPlan for SubqueryExec {
134
142
}
135
143
136
144
Ok ( Arc :: new ( SubqueryExec :: try_new (
137
- children. iter ( ) . skip ( 1 ) . cloned ( ) . collect ( ) ,
145
+ children
146
+ . iter ( )
147
+ . skip ( 1 )
148
+ . cloned ( )
149
+ . zip ( self . subqueries . iter ( ) )
150
+ . map ( |( p, ( _, t) ) | ( p, * t) )
151
+ . collect ( ) ,
138
152
children[ 0 ] . clone ( ) ,
139
153
self . cursor . clone ( ) ,
140
154
) ?) )
@@ -151,71 +165,78 @@ impl ExecutionPlan for SubqueryExec {
151
165
let context = context. clone ( ) ;
152
166
let size_hint = stream. size_hint ( ) ;
153
167
let schema = self . schema . clone ( ) ;
154
- let res_stream =
155
- stream. then ( move |batch| {
156
- let cursor = cursor. clone ( ) ;
157
- let context = context. clone ( ) ;
158
- let subqueries = subqueries. clone ( ) ;
159
- let schema = schema. clone ( ) ;
160
- async move {
161
- let batch = batch?;
162
- let b = Arc :: new ( batch. clone ( ) ) ;
163
- cursor. set_batch ( b) ?;
164
- let mut subquery_arrays = vec ! [ Vec :: new( ) ; subqueries. len( ) ] ;
165
- for i in 0 ..batch. num_rows ( ) {
166
- cursor. set_position ( i) ?;
167
- for ( subquery_i, subquery) in subqueries. iter ( ) . enumerate ( ) {
168
- let null_array = || {
169
- let schema = subquery. schema ( ) ;
170
- let fields = schema. fields ( ) ;
171
- if fields. len ( ) != 1 {
172
- return Err ( ArrowError :: ComputeError ( format ! (
173
- "Sub query should have only one column but got {}" ,
174
- fields. len( )
175
- ) ) ) ;
176
- }
177
-
178
- let data_type = fields. get ( 0 ) . unwrap ( ) . data_type ( ) ;
179
- Ok ( new_null_array ( data_type, 1 ) )
180
- } ;
168
+ let res_stream = stream. then ( move |batch| {
169
+ let cursor = cursor. clone ( ) ;
170
+ let context = context. clone ( ) ;
171
+ let subqueries = subqueries. clone ( ) ;
172
+ let schema = schema. clone ( ) ;
173
+ async move {
174
+ let batch = batch?;
175
+ let b = Arc :: new ( batch. clone ( ) ) ;
176
+ cursor. set_batch ( b) ?;
177
+ let mut subquery_arrays = vec ! [ Vec :: new( ) ; subqueries. len( ) ] ;
178
+ for i in 0 ..batch. num_rows ( ) {
179
+ cursor. set_position ( i) ?;
180
+ for ( subquery_i, ( subquery, subquery_type) ) in
181
+ subqueries. iter ( ) . enumerate ( )
182
+ {
183
+ let schema = subquery. schema ( ) ;
184
+ let fields = schema. fields ( ) ;
185
+ if fields. len ( ) != 1 {
186
+ return Err ( ArrowError :: ComputeError ( format ! (
187
+ "Sub query should have only one column but got {}" ,
188
+ fields. len( )
189
+ ) ) ) ;
190
+ }
191
+ let data_type = fields. get ( 0 ) . unwrap ( ) . data_type ( ) ;
192
+ let null_array = || new_null_array ( data_type, 1 ) ;
181
193
182
- if subquery. output_partitioning ( ) . partition_count ( ) != 1 {
183
- return Err ( ArrowError :: ComputeError ( format ! (
184
- "Sub query should have only one partition but got {}" ,
185
- subquery. output_partitioning( ) . partition_count( )
186
- ) ) ) ;
187
- }
188
- let mut stream = subquery. execute ( 0 , context. clone ( ) ) . await ?;
189
- let res = stream. next ( ) . await ;
190
- if let Some ( subquery_batch) = res {
191
- let subquery_batch = subquery_batch?;
192
- match subquery_batch. column ( 0 ) . len ( ) {
193
- 0 => subquery_arrays[ subquery_i] . push ( null_array ( ) ?) ,
194
+ if subquery. output_partitioning ( ) . partition_count ( ) != 1 {
195
+ return Err ( ArrowError :: ComputeError ( format ! (
196
+ "Sub query should have only one partition but got {}" ,
197
+ subquery. output_partitioning( ) . partition_count( )
198
+ ) ) ) ;
199
+ }
200
+ let mut stream = subquery. execute ( 0 , context. clone ( ) ) . await ?;
201
+ let res = stream. next ( ) . await ;
202
+ if let Some ( subquery_batch) = res {
203
+ let subquery_batch = subquery_batch?;
204
+ match subquery_type {
205
+ SubqueryType :: Scalar => match subquery_batch
206
+ . column ( 0 )
207
+ . len ( )
208
+ {
209
+ 0 => subquery_arrays[ subquery_i] . push ( null_array ( ) ) ,
194
210
1 => subquery_arrays[ subquery_i]
195
211
. push ( subquery_batch. column ( 0 ) . clone ( ) ) ,
196
212
_ => return Err ( ArrowError :: ComputeError (
197
213
"Sub query should return no more than one row"
198
214
. to_string ( ) ,
199
215
) ) ,
200
- } ;
201
- } else {
202
- subquery_arrays[ subquery_i] . push ( null_array ( ) ?) ;
203
- }
216
+ } ,
217
+ } ;
218
+ } else {
219
+ match subquery_type {
220
+ SubqueryType :: Scalar => {
221
+ subquery_arrays[ subquery_i] . push ( null_array ( ) )
222
+ }
223
+ } ;
204
224
}
205
225
}
206
- let mut new_columns = batch. columns ( ) . to_vec ( ) ;
207
- for subquery_array in subquery_arrays {
208
- new_columns. push ( concat (
209
- subquery_array
210
- . iter ( )
211
- . map ( |a| a. as_ref ( ) )
212
- . collect :: < Vec < _ > > ( )
213
- . as_slice ( ) ,
214
- ) ?) ;
215
- }
216
- RecordBatch :: try_new ( schema. clone ( ) , new_columns)
217
226
}
218
- } ) ;
227
+ let mut new_columns = batch. columns ( ) . to_vec ( ) ;
228
+ for subquery_array in subquery_arrays {
229
+ new_columns. push ( concat (
230
+ subquery_array
231
+ . iter ( )
232
+ . map ( |a| a. as_ref ( ) )
233
+ . collect :: < Vec < _ > > ( )
234
+ . as_slice ( ) ,
235
+ ) ?) ;
236
+ }
237
+ RecordBatch :: try_new ( schema. clone ( ) , new_columns)
238
+ }
239
+ } ) ;
219
240
Ok ( Box :: pin ( SubQueryStream {
220
241
schema : self . schema . clone ( ) ,
221
242
stream : Box :: pin ( res_stream) ,
0 commit comments