"""Tree walking and visitor pattern for protobuf parse trees."""
from __future__ import annotations
from typing import TYPE_CHECKING
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.message import Message
if TYPE_CHECKING:
from collections.abc import Generator
from postgast.nodes.base import AstNode
# Name of the ``oneof`` group whose presence (as the only oneof) identifies a generic ``Node`` wrapper message.
_NODE_ONEOF = "node"
def unwrap_node(node: Message) -> Message:
    """Peel a generic ``Node`` oneof wrapper and return the concrete message inside it.

    libpg_query's protobuf schema wraps every child reference in a generic ``Node`` message holding a single
    ``oneof node`` field. This helper strips that layer so callers can work with the concrete message type
    (``SelectStmt``, ``ColumnRef``, etc.) directly.

    A message that is not such a wrapper, or a wrapper whose oneof has no member set, is returned unchanged,
    so calling this unconditionally is safe.

    Args:
        node: Any protobuf ``Message`` — usually a ``pg_query_pb2.Node``, but concrete messages work too.

    Returns:
        The wrapped concrete message when *node* is a populated ``Node`` wrapper, otherwise *node* itself.

    Example:
        >>> from postgast import parse
        >>> from postgast.walk import unwrap_node
        >>> tree = parse("SELECT 1")
        >>> raw_stmt = tree.stmts[0]
        >>> # raw_stmt.stmt is a Node wrapper — unwrap to get the SelectStmt
        >>> select = unwrap_node(raw_stmt.stmt)
        >>> type(select).__name__
        'SelectStmt'
    """
    declared_oneofs = type(node).DESCRIPTOR.oneofs
    # A wrapper is recognised by having exactly one oneof group, and that group being the "node" one.
    is_wrapper = len(declared_oneofs) == 1 and declared_oneofs[0].name == _NODE_ONEOF
    if not is_wrapper:
        return node

    populated_field = node.WhichOneof(_NODE_ONEOF)
    if populated_field is None:
        # Nothing is set inside the wrapper, so there is no inner message to hand back.
        return node
    return getattr(node, populated_field)
def _iter_children(node: Message) -> Generator[tuple[str, Message], None, None]:
    """Yield a ``(field_name, child_message)`` pair for each message-typed value set on *node*.

    Only fields reported by ``ListFields()`` are considered. Non-message fields are skipped. A field whose value
    is a single ``Message`` contributes one pair; any other message-typed value is iterated and contributes one
    pair per item. Every child is passed through :func:`unwrap_node` before being yielded.
    """
    for descriptor, payload in node.ListFields():
        if descriptor.type == FieldDescriptor.TYPE_MESSAGE:
            # Treat a singular message as a one-element sequence so both shapes share one loop.
            members = (payload,) if isinstance(payload, Message) else payload
            for member in members:
                yield descriptor.name, unwrap_node(member)
def walk(node: Message) -> Generator[tuple[str, Message], None, None]:
    """Traverse a protobuf message tree depth-first, in pre-order.

    Every message reached is yielded as a ``(field_name, message)`` tuple, where *field_name* is the protobuf
    field that led to the message (e.g. ``"where_clause"``, ``"target_list"``). The root is yielded first with an
    empty string as its field name. ``Node`` oneof wrappers are unwrapped via :func:`unwrap_node`, so only the
    inner messages appear in the output.

    The traversal uses an explicit stack rather than recursion. Children of a message are looked up only after
    that message has been yielded and the generator is resumed.

    Args:
        node: Any protobuf ``Message`` instance (``ParseResult``, ``SelectStmt``, etc.).

    Yields:
        ``(field_name, message)`` tuples in depth-first pre-order.

    Example:
        >>> from postgast import parse, walk
        >>> tree = parse("SELECT 1")
        >>> for field_name, node in walk(tree):
        ...     if field_name:
        ...         print(f"{field_name}: {type(node).__name__}")
        stmts: RawStmt
        stmt: SelectStmt
        target_list: ResTarget
        val: A_Const
        ival: Integer
    """
    # Seed the stack with the (unwrapped) root under the empty field name.
    pending: list[tuple[str, Message]] = [("", unwrap_node(node))]
    while pending:
        field_name, current = pending.pop()
        yield field_name, current
        # Push children reversed so the first child is popped — and therefore yielded — first.
        pending.extend(reversed(list(_iter_children(current))))
def walk_typed(node: AstNode) -> Generator[tuple[str, AstNode], None, None]:
    """Traverse a typed AST wrapper tree depth-first, in pre-order.

    This is the typed counterpart of :func:`walk`: it takes an :class:`AstNode` wrapper, runs :func:`walk` over
    the wrapper's underlying protobuf message, and re-wraps every visited message before yielding it.

    Args:
        node: A typed ``AstNode`` wrapper (e.g. from :func:`postgast.wrap`).

    Yields:
        ``(field_name, wrapper)`` tuples in depth-first pre-order.

    Example:
        >>> from postgast import parse, wrap, walk_typed
        >>> tree = wrap(parse("SELECT 1"))
        >>> for field_name, node in walk_typed(tree):
        ...     if field_name:
        ...         print(f"{field_name}: {type(node).__name__}")
        stmts: RawStmt
        stmt: SelectStmt
        target_list: ResTarget
        val: A_Const
        ival: Integer
    """
    # Imported inside the function body rather than at module level, as in the original.
    from postgast.nodes.base import _wrap as wrap_message  # pyright: ignore[reportPrivateUsage]

    raw_root = node._pb  # pyright: ignore[reportPrivateUsage]
    for label, raw_message in walk(raw_root):
        typed_node = wrap_message(raw_message)
        yield label, typed_node
class Visitor:
    """Base class for protobuf parse tree visitors.

    Subclass and define ``visit_<TypeName>`` methods (e.g. ``visit_SelectStmt``, ``visit_ColumnRef``) to handle
    specific node types. Types without a dedicated handler fall through to :meth:`generic_visit`, which recurses
    into the node's children.

    A ``visit_*`` handler that is found replaces the default recursion for that node: its children are visited
    only if the handler calls :meth:`generic_visit` (or :meth:`visit`) itself.

    Call :meth:`visit` on a root message to start traversal::

        class TableCollector(Visitor):
            def __init__(self):
                self.tables = []

            def visit_RangeVar(self, node):
                self.tables.append(node.relname)

        collector = TableCollector()
        collector.visit(parse_result)
    """

    def visit(self, node: Message) -> None:
        """Dispatch *node* to ``visit_<TypeName>`` or :meth:`generic_visit`.

        *node* is first passed through :func:`unwrap_node`. The handler looked up is ``visit_<TypeName>``, where
        ``<TypeName>`` is the protobuf descriptor name of the unwrapped message (e.g. ``visit_SelectStmt``). If no
        such attribute exists, :meth:`generic_visit` is used instead.

        Args:
            node: Any protobuf ``Message`` instance.
        """
        node = unwrap_node(node)
        type_name = type(node).DESCRIPTOR.name
        handler = getattr(self, f"visit_{type_name}", self.generic_visit)
        handler(node)

    def generic_visit(self, node: Message) -> None:
        """Visit all message-typed children of *node*.

        Override this method to customize the default traversal behavior. Call ``super().generic_visit(node)``
        from a ``visit_*`` handler to continue recursion into a node's children after custom processing.

        Args:
            node: Any protobuf ``Message`` instance.
        """
        # The field name is not needed for dispatch; only the child message is visited.
        for _field_name, child in _iter_children(node):
            self.visit(child)
class TypedVisitor:
    """Base class for visitors over typed AST wrappers.

    Works like :class:`Visitor`, except that handlers receive typed :class:`AstNode` wrappers rather than raw
    protobuf ``Message`` objects.

    Subclass and define ``visit_<TypeName>`` methods to handle specific node types with full type safety::

        class TableCollector(TypedVisitor):
            def __init__(self):
                self.tables = []

            def visit_RangeVar(self, node):
                self.tables.append(node.relname)

        collector = TableCollector()
        collector.visit(wrap(parse_result))
    """

    def visit(self, node: AstNode) -> None:
        """Dispatch *node* to ``visit_<TypeName>`` or :meth:`generic_visit`.

        ``<TypeName>`` is the Python class name of *node*.
        """
        handler = getattr(self, "visit_" + type(node).__name__, self.generic_visit)
        handler(node)

    def generic_visit(self, node: AstNode) -> None:
        """Visit all child nodes of *node*.

        Override this method to customize the default traversal behavior. Call ``super().generic_visit(node)``
        from a ``visit_*`` handler to continue recursion into a node's children after custom processing.
        """
        # Imported inside the method body rather than at module level, as in the original.
        from postgast.nodes.base import _wrap as wrap_message  # pyright: ignore[reportPrivateUsage]

        raw_message = node._pb  # pyright: ignore[reportPrivateUsage]
        for _, raw_child in _iter_children(raw_message):
            self.visit(wrap_message(raw_child))