sqlglot.optimizer.unnest_subqueries
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope


def unnest_subqueries(expression):
    """
    Rewrite sqlglot AST to convert some predicates with subqueries into joins.

    Convert scalar subqueries into cross joins.
    Convert correlated or vectorized subqueries into a group by so it is not a many to many left join.

    The expression is modified in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # A single generator is shared by every rewrite so that all join aliases
    # ("_u_0" in the example above) come from one sequence.
    next_alias_name = name_sequence("_u_")

    for scope in traverse_scope(expression):
        select = scope.expression
        parent = select.parent_select
        # Scopes without an enclosing select have nothing to be joined onto.
        if not parent:
            continue
        if scope.external_columns:
            # The scope references outer columns: correlated subquery.
            decorrelate(select, parent, scope.external_columns, next_alias_name)
        elif scope.scope_type == ScopeType.SUBQUERY:
            unnest(select, parent, next_alias_name)

    return expression


def unnest(select, parent_select, next_alias_name):
    """
    Convert an uncorrelated, single-projection subquery used inside a condition
    into a join on `parent_select`. All changes are made in place.

    Two rewrites are implemented:

    * Any predicate other than IN / ANY: the subquery is treated as a scalar,
      added as a CROSS JOIN, and its usage is replaced by a column on the join
      alias (wrapped in MAX in the cases described in the body).
    * IN / ``= ANY``: the subquery is grouped by its projected value,
      LEFT JOINed on equality with the other operand, and the predicate is
      replaced by a "join key is not null" check.

    The function returns without touching the tree when the subquery has more
    than one projection, when no enclosing condition belonging to
    `parent_select` is found, when `parent_select` has no "from" argument, or
    (IN / ANY only) when the subquery contains LIMIT or OFFSET.

    Args:
        select: the subquery expression to unnest.
        parent_select: the select that encloses the subquery and receives the join.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if (
        not predicate
        or parent_select is not predicate.parent_select
        or not parent_select.args.get("from")
    ):
        return

    # A set operation is wrapped in a new select over itself so the rest of the
    # function can work with an ordinary select.
    if isinstance(select, exp.Union):
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    # NOTE: the alias is drawn before the remaining early returns, so a name
    # from the sequence is consumed even when no rewrite happens.
    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # This subquery returns a scalar and can just be converted to a cross join
    if not isinstance(predicate, (exp.In, exp.Any)):
        column = exp.column(select.selects[0].alias_or_name, alias)

        clause_parent_select = clause.parent_select if clause else None

        # Wrap the joined column in MAX when the predicate sits in the parent's
        # own HAVING, or when it is outside any HAVING/WHERE/JOIN of the parent
        # and the parent groups or has an aggregate in its projections.
        if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
            (not clause or clause_parent_select is not parent_select)
            and (
                parent_select.args.get("group")
                or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
            )
        ):
            column = exp.Max(this=column)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, column)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    # For ANY, the node to rewrite is the enclosing equality; other comparison
    # operators are not handled.
    if isinstance(predicate, exp.Any):
        predicate = predicate.find_ancestor(exp.EQ)

        if not predicate or parent_select is not predicate.parent_select:
            return

    column = _other_operand(predicate)
    value = select.selects[0]

    # NOTE(review): value.alias (not alias_or_name) is used here, so this
    # presumably expects the projection to carry an explicit alias - confirm.
    join_key = exp.column(value.alias, alias)
    join_key_not_null = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # Inside a JOIN clause the predicate is neutralised and the not-null
        # check is added to the parent's WHERE instead.
        _replace(predicate, exp.true())
        parent_select.where(join_key_not_null, copy=False)
    else:
        _replace(predicate, join_key_not_null)

    group = select.args.get("group")

    if group:
        # The subquery is already grouped by something other than exactly its
        # value: wrap it as "_q" and group the wrapper by the value column.
        if {value.this} != set(group.expressions):
            select = (
                exp.select(exp.column(value.alias, "_q"))
                .from_(select.subquery("_q", copy=False), copy=False)
                .group_by(exp.column(value.alias, "_q"), copy=False)
            )
    else:
        select = select.group_by(value.this, copy=False)

    parent_select.join(
        select,
        on=column.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )


def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Convert a correlated subquery into a LEFT JOIN of `parent_select` against
    the subquery grouped by its correlation keys. All changes are made in place.

    The correlation predicates in the subquery's WHERE are replaced by TRUE and
    the equality ones are reused as the join's ON conditions; the place where
    the subquery was used is rewritten to refer to columns of the join alias.

    The function returns without touching the tree when the subquery has no
    WHERE, when the WHERE contains OR, when the subquery contains LIMIT or
    OFFSET, when an external column or its predicate is not under that WHERE,
    when an external column's predicate is not a binary expression, or when
    none of the correlation predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the select that encloses the subquery and receives the join.
        external_columns: columns inside `select` that refer to the outer query
            (the caller passes `scope.external_columns`).
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the remaining early returns, so a name from the
    # sequence is consumed even when no rewrite happens.
    table_alias = next_alias_name()
    # (key, column, predicate): `column` is the external column, `predicate`
    # the binary predicate containing it, `key` the predicate's other operand.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # Identity (not equality) is used to find which side holds the
            # external column; the opposite side is the key.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node itself is one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are aggregated, like the value above.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # From here on `alias` is the column on the join alias that stands in for
    # the subquery's value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the place where the subquery was used, depending on the kind of
    # predicate that encloses it.
    if isinstance(parent_predicate, exp.Exists):
        # The projections were cleared above, so the first key alias is used
        # for the null check.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlation predicate is removed from the subquery's WHERE; the
        # detached node is still used below (rendered into SQL or as a join
        # condition).
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # The key becomes the lambda parameter `_x`, so the detached
            # predicate renders as the lambda body in the string below.
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # The equality predicates (their keys now replaced in the loop above)
    # serve as the ON conditions of the join.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )


def _replace(expression, condition):
    """
    Replace `expression` in its tree with `condition`.

    `condition` is passed through `exp.condition` first, so callers in this
    module pass either an expression or a SQL string. The value returned by
    `expression.replace(...)` is passed back to the caller.
    """
    return expression.replace(exp.condition(condition))


def _other_operand(expression):
    """
    Return the operand that a subquery predicate is compared against.

    * IN: the tested expression (`expression.this`).
    * ANY / ALL: resolved through the parent node.
    * Binary: the right operand if the left one is a Subquery, Any, Exists or
      All node, otherwise the left operand.
    * Anything else (including None): None.
    """
    if isinstance(expression, exp.In):
        return expression.this

    if isinstance(expression, (exp.Any, exp.All)):
        return _other_operand(expression.parent)

    if isinstance(expression, exp.Binary):
        return (
            expression.right
            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
            else expression.left
        )

    return None
def unnest_subqueries(expression):
    """
    Replace subquery predicates in a sqlglot AST with equivalent joins.

    Uncorrelated subqueries are handed to `unnest` and subqueries that
    reference outer columns are handed to `decorrelate`. The tree is rewritten
    in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
        >>> unnest_subqueries(expression).sql()
        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'

    Args:
        expression (sqlglot.Expression): expression to unnest
    Returns:
        sqlglot.Expression: unnested expression
    """
    # One shared sequence, so every generated join alias is distinct.
    fresh_alias = name_sequence("_u_")

    for scope in traverse_scope(expression):
        subquery = scope.expression
        outer = subquery.parent_select

        # Only scopes nested inside another select can be turned into joins.
        if outer:
            correlated = scope.external_columns
            if correlated:
                decorrelate(subquery, outer, correlated, fresh_alias)
            elif scope.scope_type == ScopeType.SUBQUERY:
                unnest(subquery, outer, fresh_alias)

    return expression
Rewrite sqlglot AST to convert some predicates with subqueries into joins.
Convert scalar subqueries into cross joins. Convert correlated or vectorized subqueries into left joins against a grouped subquery, so that the join is not many-to-many.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Arguments:
- expression (sqlglot.Expression): expression to unnest
Returns:
sqlglot.Expression: unnested expression
def unnest(select, parent_select, next_alias_name):
    """
    Rewrite an uncorrelated, single-projection subquery that is used inside a
    condition as a join on `parent_select` (in place).

    * Predicates other than IN / ANY: the subquery is CROSS JOINed and its
      usage is replaced by a column on the join alias, wrapped in MAX when the
      parent aggregates at that position.
    * IN / ``= ANY``: the subquery is grouped by its value, LEFT JOINed on
      equality with the other operand, and the predicate becomes a
      "join key is not null" check.

    Nothing is changed when the subquery has several projections, when no
    enclosing condition owned by `parent_select` exists, when `parent_select`
    has no "from" argument, or (IN / ANY) when LIMIT / OFFSET is present.

    Args:
        select: the subquery expression to unnest.
        parent_select: the enclosing select that receives the join.
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    if len(select.selects) > 1:
        return

    predicate = select.find_ancestor(exp.Condition)
    if not predicate:
        return
    if predicate.parent_select is not parent_select:
        return
    if not parent_select.args.get("from"):
        return

    # Set operations are wrapped in a plain select over themselves.
    if isinstance(select, exp.Union):
        select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

    # Drawn here, ahead of the remaining bail-outs, exactly once per call.
    alias = next_alias_name()
    clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)

    # Scalar usage: a cross join is sufficient.
    if not isinstance(predicate, (exp.In, exp.Any)):
        scalar = exp.column(select.selects[0].alias_or_name, alias)

        clause_owner = clause.parent_select if clause else None
        owned_by_parent = clause_owner is parent_select

        if isinstance(clause, exp.Having) and owned_by_parent:
            # Inside the parent's own HAVING.
            aggregate = True
        elif not clause or not owned_by_parent:
            # Outside the parent's HAVING/WHERE/JOIN: aggregate only if the
            # parent groups or projects an aggregate.
            aggregate = bool(
                parent_select.args.get("group")
                or any(item.find(exp.AggFunc) for item in parent_select.selects)
            )
        else:
            aggregate = False

        if aggregate:
            scalar = exp.Max(this=scalar)
        elif not isinstance(select.parent, exp.Subquery):
            return

        _replace(select.parent, scalar)
        parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
        return

    if select.find(exp.Limit, exp.Offset):
        return

    if isinstance(predicate, exp.Any):
        # Only "= ANY" is handled; the equality is the node to rewrite.
        predicate = predicate.find_ancestor(exp.EQ)
        if not predicate or parent_select is not predicate.parent_select:
            return

    outer_operand = _other_operand(predicate)
    projection = select.selects[0]

    join_key = exp.column(projection.alias, alias)
    key_present = join_key.is_(exp.null()).not_()

    if isinstance(clause, exp.Join):
        # In a JOIN clause the check moves to the parent's WHERE.
        _replace(predicate, exp.true())
        parent_select.where(key_present, copy=False)
    else:
        _replace(predicate, key_present)

    existing_group = select.args.get("group")

    if not existing_group:
        select = select.group_by(projection.this, copy=False)
    elif {projection.this} != set(existing_group.expressions):
        # Grouped by something other than exactly the value: wrap as "_q" and
        # group the wrapper by the value column.
        wrapper = exp.select(exp.column(projection.alias, "_q"))
        wrapper = wrapper.from_(select.subquery("_q", copy=False), copy=False)
        select = wrapper.group_by(exp.column(projection.alias, "_q"), copy=False)

    parent_select.join(
        select,
        on=outer_operand.eq(join_key),
        join_type="LEFT",
        join_alias=alias,
        copy=False,
    )
def decorrelate(select, parent_select, external_columns, next_alias_name):
    """
    Convert a correlated subquery into a LEFT JOIN of `parent_select` against
    the subquery grouped by its correlation keys. All changes are made in place.

    The correlation predicates in the subquery's WHERE are replaced by TRUE and
    the equality ones are reused as the join's ON conditions; the place where
    the subquery was used is rewritten to refer to columns of the join alias.

    The function returns without touching the tree when the subquery has no
    WHERE, when the WHERE contains OR, when the subquery contains LIMIT or
    OFFSET, when an external column or its predicate is not under that WHERE,
    when an external column's predicate is not a binary expression, or when
    none of the correlation predicates is an equality.

    Args:
        select: the correlated subquery expression.
        parent_select: the select that encloses the subquery and receives the join.
        external_columns: columns inside `select` that refer to the outer query
            (presumably a scope's `external_columns`; verify against the caller).
        next_alias_name: zero-argument callable returning a fresh alias name.
    """
    where = select.args.get("where")

    if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
        return

    # NOTE: drawn before the remaining early returns, so a name from the
    # sequence is consumed even when no rewrite happens.
    table_alias = next_alias_name()
    # (key, column, predicate): `column` is the external column, `predicate`
    # the binary predicate containing it, `key` the predicate's other operand.
    keys = []

    # for all external columns in the where statement, find the relevant predicate
    # keys to convert it into a join
    for column in external_columns:
        if column.find_ancestor(exp.Where) is not where:
            return

        predicate = column.find_ancestor(exp.Predicate)

        if not predicate or predicate.find_ancestor(exp.Where) is not where:
            return

        if isinstance(predicate, exp.Binary):
            # Identity (not equality) is used to find which side holds the
            # external column; the opposite side is the key.
            key = (
                predicate.right
                if any(node is column for node in predicate.left.walk())
                else predicate.left
            )
        else:
            return

        keys.append((key, column, predicate))

    if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
        return

    # True when the subquery node itself is one of the parent's projections.
    is_subquery_projection = any(
        node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
    )

    value = select.selects[0]
    key_aliases = {}
    group_by = []

    for key, _, predicate in keys:
        # if we filter on the value of the subquery, it needs to be unique
        if key == value.this:
            key_aliases[key] = value.alias
            group_by.append(key)
        else:
            if key not in key_aliases:
                key_aliases[key] = next_alias_name()
            # all predicates that are equalities must also be in the unique
            # so that we don't do a many to many join
            if isinstance(predicate, exp.EQ) and key not in group_by:
                group_by.append(key)

    parent_predicate = select.find_ancestor(exp.Predicate)

    # if the value of the subquery is not an agg or a key, we need to collect it into an array
    # so that it can be grouped. For subquery projections, we use a MAX aggregation instead.
    agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg
    if not value.find(exp.AggFunc) and value.this not in group_by:
        select.select(
            exp.alias_(agg_func(this=value.this), value.alias, quoted=False),
            append=False,
            copy=False,
        )

    # exists queries should not have any selects as it only checks if there are any rows
    # all selects will be added by the optimizer and only used for join keys
    if isinstance(parent_predicate, exp.Exists):
        select.args["expressions"] = []

    for key, alias in key_aliases.items():
        if key in group_by:
            # add all keys to the projections of the subquery
            # so that we can use it as a join key
            if isinstance(parent_predicate, exp.Exists) or key != value.this:
                select.select(f"{key} AS {alias}", copy=False)
        else:
            # Keys that are not grouped on are aggregated, like the value above.
            select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False)

    # From here on `alias` is the column on the join alias that stands in for
    # the subquery's value.
    alias = exp.column(value.alias, table_alias)
    other = _other_operand(parent_predicate)

    # Rewrite the place where the subquery was used, depending on the kind of
    # predicate that encloses it.
    if isinstance(parent_predicate, exp.Exists):
        # The projections were cleared above, so the first key alias is used
        # for the null check.
        alias = exp.column(list(key_aliases.values())[0], table_alias)
        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
    elif isinstance(parent_predicate, exp.All):
        parent_predicate = _replace(
            parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
        )
    elif isinstance(parent_predicate, exp.Any):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
        else:
            parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
    elif isinstance(parent_predicate, exp.In):
        if value.this in group_by:
            parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
        else:
            parent_predicate = _replace(
                parent_predicate,
                f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
            )
    else:
        if is_subquery_projection:
            alias = exp.alias_(alias, select.parent.alias)

        # COUNT always returns 0 on empty datasets, so we need take that into consideration here
        # by transforming all counts into 0 and using that as the coalesced value
        if value.find(exp.Count):

            def remove_aggs(node):
                if isinstance(node, exp.Count):
                    return exp.Literal.number(0)
                elif isinstance(node, exp.AggFunc):
                    return exp.null()
                return node

            alias = exp.Coalesce(
                this=alias,
                expressions=[value.this.transform(remove_aggs)],
            )

        select.parent.replace(alias)

    for key, column, predicate in keys:
        # The correlation predicate is removed from the subquery's WHERE; the
        # detached node is still used below (rendered into SQL or as a join
        # condition).
        predicate.replace(exp.true())
        nested = exp.column(key_aliases[key], table_alias)

        if is_subquery_projection:
            key.replace(nested)
            continue

        if key in group_by:
            key.replace(nested)
        elif isinstance(predicate, exp.EQ):
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
            )
        else:
            # The key becomes the lambda parameter `_x`, so the detached
            # predicate renders as the lambda body in the string below.
            key.replace(exp.to_identifier("_x"))
            parent_predicate = _replace(
                parent_predicate,
                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
            )

    # The equality predicates (their keys now replaced in the loop above)
    # serve as the ON conditions of the join.
    parent_select.join(
        select.group_by(*group_by, copy=False),
        on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
        join_type="LEFT",
        join_alias=table_alias,
        copy=False,
    )