Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d89cbc3
add initial rule generator v2 scaffolding
colinthebomb1 Apr 14, 2026
519df4c
add literals and tables in v2
colinthebomb1 Apr 14, 2026
a4f5f12
add variablize literal and table in v2
colinthebomb1 Apr 14, 2026
d3f66f3
remove regex and keep x y placeholders
colinthebomb1 Apr 14, 2026
5ed9a50
canonicalize x y placeholders
colinthebomb1 Apr 14, 2026
6c1352c
add variable list discovery in v2
colinthebomb1 Apr 14, 2026
a2e6978
add merge variable list in v2
colinthebomb1 Apr 14, 2026
a9067fb
add branches support in v2
colinthebomb1 Apr 14, 2026
408a3ee
add fingerprint support in v2
colinthebomb1 Apr 14, 2026
7afd25f
add unify variable names in v2
colinthebomb1 Apr 14, 2026
caff39d
add number of variables in v2
colinthebomb1 Apr 14, 2026
6d6d21a
add initial generate general rule in v2
colinthebomb1 Apr 14, 2026
4abae95
compound query support
colinthebomb1 Apr 21, 2026
99f3934
pass all existing tests
colinthebomb1 Apr 23, 2026
74800d4
fix tests
colinthebomb1 Apr 30, 2026
bd69c1d
remove any special rules from generalizations
colinthebomb1 Apr 30, 2026
e4996c5
migrate rule generator to v2 with full AST-based generalization
colinthebomb1 Apr 30, 2026
2aac818
remove dead code from rule_generator_v2
colinthebomb1 Apr 30, 2026
4956f66
add docstrings
colinthebomb1 Apr 30, 2026
fb7e59e
improve tests
colinthebomb1 May 5, 2026
aa65545
add v2 rule helper
colinthebomb1 May 5, 2026
159d0b5
cleanup
colinthebomb1 May 5, 2026
533f038
rule v2 fix
colinthebomb1 May 5, 2026
33e3428
Fix v2 rules
colinthebomb1 May 5, 2026
03b14cb
fix spreadsheet id 18
colinthebomb1 May 7, 2026
d11796d
address comments
colinthebomb1 May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions core/ast/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Set, Optional, Union
from typing import List, Set, Optional, Tuple, Union
from abc import ABC

from .enums import NodeType, JoinType, SortOrder
Expand Down Expand Up @@ -101,18 +101,20 @@ def __hash__(self):

class LiteralNode(Node):
"""Literal value node"""
def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs):
def __init__(self, _value: str|int|float|bool|datetime|None, _alias: Optional[str] = None, **kwargs):
super().__init__(NodeType.LITERAL, **kwargs)
self.value = _value
self.alias = _alias

def __eq__(self, other):
if not isinstance(other, LiteralNode):
return False
return (super().__eq__(other) and
self.value == other.value)
self.value == other.value and
self.alias == other.alias)

def __hash__(self):
return hash((super().__hash__(), self.value))
return hash((super().__hash__(), self.value, self.alias))

class DataTypeNode(Node):
"""SQL data type node used in CAST expressions (e.g. TEXT, DATE, INTEGER)"""
Expand Down Expand Up @@ -249,24 +251,31 @@ def __hash__(self):

class JoinNode(Node):
"""JOIN clause node"""
def __init__(self, _left_table: Union['TableNode', 'JoinNode', 'SubqueryNode'], _right_table: Union['TableNode', 'SubqueryNode'], _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs):
def __init__(self, _left_table: Union['TableNode', 'JoinNode', 'SubqueryNode'], _right_table: Union['TableNode', 'SubqueryNode'], _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, _using: Optional[List['Node']] = None, **kwargs):
children = [_left_table, _right_table]
if _on_condition:
children.append(_on_condition)
if _using:
children.extend(_using)
super().__init__(NodeType.JOIN, children=children, **kwargs)
self.left_table = _left_table
self.right_table = _right_table
self.join_type = _join_type
self.on_condition = _on_condition

self.using = list(_using) if _using else None

def __eq__(self, other):
if not isinstance(other, JoinNode):
return False
return (super().__eq__(other) and
self.join_type == other.join_type)
self.join_type == other.join_type and
self.using == other.using)

def __hash__(self):
return hash((super().__hash__(), self.join_type))
using_key: Tuple = ()
if self.using:
using_key = tuple(self.using)
return hash((super().__hash__(), self.join_type, using_key))

# ============================================================================
# Query Structure Nodes
Expand Down Expand Up @@ -463,4 +472,4 @@ def __eq__(self, other):
return super().__eq__(other) and self.whens == other.whens and self.else_val == other.else_val

def __hash__(self):
return hash((super().__hash__(), tuple(self.whens), self.else_val))
return hash((super().__hash__(), tuple(self.whens), self.else_val))
65 changes: 32 additions & 33 deletions core/query_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,10 @@ def format_select(select_node: SelectNode) -> dict:

items = []
for child in children:
if child.type == NodeType.COLUMN:
if child.alias:
items.append({'name': child.alias, 'value': format_expression(child)})
else:
items.append({'value': format_expression(child)})
elif child.type == NodeType.FUNCTION:
func_expr = format_expression(child)
if hasattr(child, 'alias') and child.alias:
items.append({'name': child.alias, 'value': func_expr})
else:
items.append({'value': func_expr})
else:
items.append({'value': format_expression(child)})
item = {'value': format_expression(child)}
if hasattr(child, 'alias') and child.alias:
item['name'] = child.alias
items.append(item)

select_key = 'select_distinct' if select_node.distinct else 'select'
result[select_key] = items
Expand Down Expand Up @@ -172,47 +163,43 @@ def format_from(from_node: FromNode):

def format_join(join_node: JoinNode) -> list:
"""Format a JOIN node"""
children = list(join_node.children)

if len(children) < 2:
raise ValueError("JoinNode must have at least 2 children (left and right tables)")

left_node = children[0]
right_node = children[1]
join_condition = children[2] if len(children) > 2 else None

left_node = join_node.left_table
right_node = join_node.right_table
join_condition = join_node.on_condition
using_columns = join_node.using

result = []

# Format left side (could be a table or nested join)

if left_node.type == NodeType.JOIN:
# Nested join - recursively format
result.extend(format_join(left_node))
else:
# Simple table - this becomes the FROM table
result.append(format_source(left_node))

# Format the join itself
join_dict = {}

# Map join types to mosql format
join_type_map = {
JoinType.JOIN: 'join',
JoinType.INNER: 'inner join',
JoinType.LEFT: 'left join',
JoinType.RIGHT: 'right join',
JoinType.FULL: 'full join',
JoinType.CROSS: 'cross join',
JoinType.NATURAL: 'natural join',
}

join_key = join_type_map.get(join_node.join_type, 'join')
join_dict[join_key] = format_source(right_node)

# Add join condition if it exists

if join_condition:
join_dict['on'] = format_expression(join_condition)

if using_columns:
if len(using_columns) == 1:
join_dict['using'] = format_expression(using_columns[0])
else:
join_dict['using'] = [format_expression(col) for col in using_columns]

result.append(join_dict)

return result


Expand Down Expand Up @@ -401,5 +388,17 @@ def format_expression(node: Node):
unit = node.unit.name.lower()
return {'interval': [value, unit]}

elif node.type == NodeType.VAR:
return node.name

elif node.type == NodeType.VARSET:
return node.name

elif node.type == NodeType.QUERY:
return ast_to_json(node)

elif node.type == NodeType.COMPOUND_QUERY:
return compound_to_mosql_json(node)

else:
raise ValueError(f"Unsupported node type in expression: {node.type}")
raise ValueError(f"Unsupported node type in expression: {node.type}")
20 changes: 18 additions & 2 deletions core/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
# TODO: implement ElementVariableNode, SetVariableNode
from core.ast.enums import JoinType, SortOrder
from typing import List, Optional
import mo_sql_parsing as mosql
import json

Expand Down Expand Up @@ -133,8 +134,21 @@ def _append_source(node: Node, alias):
if 'on' in item:
on_condition = self.parse_expression(item['on'], aliases)

using_columns: Optional[List[Node]] = None
if 'using' in item:
using_value = item['using']
if isinstance(using_value, list):
using_columns = [
ColumnNode(str(c)) if not isinstance(c, dict) else self.parse_expression(c, aliases)
for c in using_value
]
elif isinstance(using_value, dict):
using_columns = [self.parse_expression(using_value, aliases)]
else:
using_columns = [ColumnNode(str(using_value))]

join_type = self.parse_join_type(join_key)
join_node = JoinNode(left_source, right_source, join_type, on_condition)
join_node = JoinNode(left_source, right_source, join_type, on_condition, using_columns)
left_source = join_node

elif 'value' in item:
Expand Down Expand Up @@ -592,7 +606,9 @@ def parse_join_type(join_key: str) -> JoinType:
"""Extract JoinType from mo_sql_parsing join key."""
key_lower = join_key.lower().replace(' ', '_')

if 'inner' in key_lower:
if 'natural' in key_lower:
return JoinType.NATURAL
Comment thread
colinthebomb1 marked this conversation as resolved.
elif 'inner' in key_lower:
return JoinType.INNER
elif 'left' in key_lower:
return JoinType.LEFT
Expand Down
Loading
Loading