Skip to content

Commit

Permalink
use config context in plan compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam committed Sep 20, 2024
1 parent 30952bb commit 03f25b5
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 83 deletions.
4 changes: 3 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,11 @@ class ConfigContext:
def __init__(self, session) -> None:
self.session = session
self.configs = {
"cte_optimization_enabled",
"_query_compilation_stage_enabled",
"cte_optimization_enabled",
"eliminate_numeric_sql_value_cast_enabled",
"large_query_breakdown_complexity_bounds",
"large_query_breakdown_enabled",
}

def __getattr__(self, name: str) -> Any:
Expand Down
36 changes: 19 additions & 17 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_complexity_score,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan import (
ConfigContext,
PlanQueryType,
Query,
SnowflakePlan,
Expand Down Expand Up @@ -47,14 +48,7 @@ class PlanCompiler:
def __init__(self, plan: SnowflakePlan) -> None:
self._plan = plan
current_session = self._plan.session
self.cte_optimization_enabled = current_session.cte_optimization_enabled
self.large_query_breakdown_enabled = (
current_session.large_query_breakdown_enabled
)
self.query_compilation_stage_enabled = (
current_session._query_compilation_stage_enabled
)
self.complexity_bounds = current_session.large_query_breakdown_complexity_bounds
self.config_context = ConfigContext(current_session)

def should_start_query_compilation(self) -> bool:
"""
Expand All @@ -74,11 +68,18 @@ def should_start_query_compilation(self) -> bool:
return (
not isinstance(current_session._conn, MockServerConnection)
and (self._plan.source_plan is not None)
and self.query_compilation_stage_enabled
and (self.cte_optimization_enabled or self.large_query_breakdown_enabled)
and self.config_context._query_compilation_stage_enabled
and (
self.config_context.cte_optimization_enabled
or self.config_context.large_query_breakdown_enabled
)
)

def compile(self) -> Dict[PlanQueryType, List[Query]]:
with self.config_context:
return self._compile()

def _compile(self) -> Dict[PlanQueryType, List[Query]]:
if self.should_start_query_compilation():
# preparation for compilation
# 1. make a copy of the original plan
Expand All @@ -95,7 +96,7 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
# 3. apply each optimizations if needed
# CTE optimization
cte_start_time = time.time()
if self.cte_optimization_enabled:
if self.config_context.cte_optimization_enabled:
repeated_subquery_eliminator = RepeatedSubqueryElimination(
logical_plans, query_generator
)
Expand All @@ -108,12 +109,12 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
]

# Large query breakdown
if self.large_query_breakdown_enabled:
if self.config_context.large_query_breakdown_enabled:
large_query_breakdown = LargeQueryBreakdown(
self._plan.session,
query_generator,
logical_plans,
self.complexity_bounds,
self.config_context.large_query_breakdown_complexity_bounds,
)
logical_plans = large_query_breakdown.apply()

Expand All @@ -133,9 +134,9 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
total_time = time.time() - start_time
session = self._plan.session
summary_value = {
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.large_query_breakdown_enabled,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.complexity_bounds,
TelemetryField.CTE_OPTIMIZATION_ENABLED.value: self.config_context.cte_optimization_enabled,
TelemetryField.LARGE_QUERY_BREAKDOWN_ENABLED.value: self.config_context.large_query_breakdown_enabled,
CompilationStageTelemetryField.COMPLEXITY_SCORE_BOUNDS.value: self.config_context.large_query_breakdown_complexity_bounds,
CompilationStageTelemetryField.TIME_TAKEN_FOR_COMPILATION.value: total_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_DEEP_COPY_PLAN.value: deep_copy_time,
CompilationStageTelemetryField.TIME_TAKEN_FOR_CTE_OPTIMIZATION.value: cte_time,
Expand All @@ -153,7 +154,8 @@ def compile(self) -> Dict[PlanQueryType, List[Query]]:
else:
final_plan = self._plan
final_plan = final_plan.replace_repeated_subquery_with_cte(
self.cte_optimization_enabled, self.query_compilation_stage_enabled
self.config_context.cte_optimization_enabled,
self.config_context._query_compilation_stage_enabled,
)
return {
PlanQueryType.QUERIES: final_plan.queries,
Expand Down
Loading

0 comments on commit 03f25b5

Please sign in to comment.