diff --git a/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/common.rs b/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/common.rs index 7a3b86859b438..5ca1ddd776f6a 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/common.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/common.rs @@ -53,9 +53,6 @@ impl PrettyPrint for MultiStageAppliedState { } result.println("time_shifts:", &state); - if let Some(common) = &self.time_shifts().common_time_shift { - result.println(&format!("- common: {}", common.to_sql()), &details_state); - } for (_, time_shift) in self.time_shifts().dimensions_shifts.iter() { result.println( &format!( diff --git a/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/leaf_measure.rs b/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/leaf_measure.rs index 394c79125451c..214ab412fb8d8 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/leaf_measure.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/logical_plan/multistage/leaf_measure.rs @@ -22,12 +22,9 @@ impl PrettyPrint for MultiStageLeafMeasure { if self.render_measure_for_ungrouped { result.println("render_measure_for_ungrouped: true", &state); } - if !self.time_shifts.dimensions_shifts.is_empty() { + if !self.time_shifts.is_empty() { result.println("time_shifts:", &state); let details_state = state.new_level(); - if let Some(common) = &self.time_shifts.common_time_shift { - result.println(&format!("- common: {}", common.to_sql()), &details_state); - } for (_, time_shift) in self.time_shifts.dimensions_shifts.iter() { result.println( &format!( diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs index 150cbf4a85265..8ccbec0c29b68 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/planners/multi_stage/applied_state.rs @@ -1,7 +1,7 @@ use crate::plan::{FilterGroup, FilterItem}; use crate::planner::filter::FilterOperator; use crate::planner::sql_evaluator::{DimensionTimeShift, MeasureTimeShifts, MemberSymbol}; -use crate::planner::{BaseDimension, BaseMember, BaseTimeDimension, SqlInterval}; +use crate::planner::{BaseDimension, BaseMember, BaseTimeDimension}; use itertools::Itertools; use std::cmp::PartialEq; use std::collections::HashMap; @@ -11,12 +11,11 @@ use std::rc::Rc; #[derive(Clone, Default, Debug)] pub struct TimeShiftState { pub dimensions_shifts: HashMap, - pub common_time_shift: Option, } impl TimeShiftState { pub fn is_empty(&self) -> bool { - self.dimensions_shifts.is_empty() && self.common_time_shift.is_none() + self.dimensions_shifts.is_empty() } } @@ -74,28 +73,28 @@ impl MultiStageAppliedState { } pub fn add_time_shifts(&mut self, time_shifts: MeasureTimeShifts) { - match time_shifts { - MeasureTimeShifts::Dimensions(dimensions) => { - for ts in dimensions.into_iter() { - if let Some(exists) = self - .time_shifts - .dimensions_shifts - .get_mut(&ts.dimension.full_name()) - { - exists.interval += ts.interval; - } else { - self.time_shifts - .dimensions_shifts - .insert(ts.dimension.full_name(), ts); - } - } - } - MeasureTimeShifts::Common(interval) => { - if let Some(common) = self.time_shifts.common_time_shift.as_mut() { - *common += interval; - } else { - self.time_shifts.common_time_shift = Some(interval); - } + let resolved_shifts = match time_shifts { + MeasureTimeShifts::Dimensions(dimensions) => dimensions, + MeasureTimeShifts::Common(interval) => self + .all_time_members() + .into_iter() + .map(|m| DimensionTimeShift { + interval: interval.clone(), + dimension: m, + }) + .collect_vec(), + }; + for ts in resolved_shifts.into_iter() { + if let Some(exists) = self + .time_shifts + .dimensions_shifts + .get_mut(&ts.dimension.full_name()) + { + exists.interval += ts.interval; + } else { + self.time_shifts + .dimensions_shifts + .insert(ts.dimension.full_name(), ts); } } } @@ -104,6 +103,40 @@ impl MultiStageAppliedState { &self.time_shifts } + fn all_time_members(&self) -> Vec> { + let mut filter_symbols = self.all_dimensions_symbols(); + for filter_item in self + .time_dimensions_filters + .iter() + .chain(self.dimensions_filters.iter()) + .chain(self.segments.iter()) + { + filter_item.find_all_member_evaluators(&mut filter_symbols); + } + + let time_symbols = filter_symbols + .into_iter() + .filter_map(|m| { + let symbol = if let Ok(time_dim) = m.as_time_dimension() { + time_dim.base_symbol().clone().resolve_reference_chain() + } else { + m.resolve_reference_chain() + }; + if let Ok(dim) = symbol.as_dimension() { + if dim.dimension_type() == "time" { + Some(symbol) + } else { + None + } + } else { + None + } + }) + .unique_by(|s| s.full_name()) + .collect_vec(); + time_symbols + } + pub fn time_dimensions_filters(&self) -> &Vec { &self.time_dimensions_filters } @@ -122,6 +155,14 @@ impl MultiStageAppliedState { .collect() } + pub fn all_dimensions_symbols(&self) -> Vec> { + self.time_dimensions + .iter() + .map(|d| d.member_evaluator().clone()) + .chain(self.dimensions.iter().map(|d| d.member_evaluator().clone())) + .collect() + } + pub fn dimensions_filters(&self) -> &Vec { &self.dimensions_filters } @@ -343,7 +384,6 @@ impl PartialEq for MultiStageAppliedState { && self.time_dimensions_filters == other.time_dimensions_filters && self.dimensions_filters == other.dimensions_filters && self.measures_filters == other.measures_filters - && self.time_shifts.common_time_shift == other.time_shifts.common_time_shift && self.time_shifts.dimensions_shifts == other.time_shifts.dimensions_shifts } } diff --git a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs index d658fad328906..2f30043a0df38 100644 --- a/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs +++ b/rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/time_shift.rs @@ -4,7 +4,6 @@ use crate::planner::query_tools::QueryTools; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; -use crate::planner::SqlInterval; use cubenativeutils::CubeError; use std::any::Any; use std::rc::Rc; @@ -43,16 +42,12 @@ impl SqlNode for TimeShiftSqlNode { let res = match node.as_ref() { MemberSymbol::Dimension(ev) => { if !ev.is_reference() && ev.dimension_type() == "time" { - let mut interval = self.shifts.common_time_shift.clone().unwrap_or_default(); if let Some(shift) = self.shifts.dimensions_shifts.get(&ev.full_name()) { - interval += &shift.interval; - } - if interval == SqlInterval::default() { - input - } else { - let shift = interval.to_sql(); + let shift = shift.interval.to_sql(); let res = templates.add_timestamp_interval(input, shift)?; format!("({})", res) + } else { + input } } else { input