Common Usage Patterns¶
This page demonstrates practical patterns for common tasks. For API fundamentals, see the User Guide.
Query Analysis¶
Audit which tables a query touches¶
import postgast

# A three-way join with a date filter, used to show what each extractor reports.
query = """
SELECT o.id, c.name, p.title
FROM orders o
JOIN customers c ON o.customer_id = c.id
JOIN products p ON o.product_id = p.id
WHERE o.created_at > '2024-01-01'
"""

# Parse once, then run every extractor over the same tree.
parsed = postgast.parse(query)
referenced_tables = postgast.extract_tables(parsed)
referenced_columns = postgast.extract_columns(parsed)
called_functions = postgast.extract_functions(parsed)

print("Tables:", referenced_tables)  # ['orders', 'customers', 'products']
print("Columns:", referenced_columns)  # ['o.id', 'c.name', 'p.title', ...]
print("Functions:", called_functions)  # []
Detect queries that use subqueries¶
from postgast import parse, find_nodes
from postgast.pg_query_pb2 import SelectStmt


def has_subquery(sql: str) -> bool:
    """Return True if the SQL contains a nested SELECT.

    The check counts every ``SelectStmt`` node that ``find_nodes`` yields for
    the parsed tree and reports True when there is more than one.

    NOTE(review): set operations such as ``UNION`` presumably also produce
    several ``SelectStmt`` nodes and would be reported as subqueries here;
    confirm before relying on this for such queries.
    """
    tree = parse(sql)
    select_count = sum(1 for _ in find_nodes(tree, SelectStmt))
    return select_count > 1


has_subquery("SELECT * FROM users")
# => False
has_subquery("SELECT * FROM (SELECT id FROM users) AS sub")
# => True
Find all function calls in a query¶
import postgast

parsed = postgast.parse("SELECT upper(name), count(*) FROM users GROUP BY upper(name)")
called = postgast.extract_functions(parsed)

# Every call site is reported, so a function used twice appears twice.
print(called)
# ['upper', 'count', 'upper']

# Unique function names:
print(set(called))
# {'upper', 'count'}
Query Monitoring¶
Group queries with normalization¶
Replace literal values with positional placeholders so that structurally identical queries collapse into a single group:
import postgast

# The first two queries differ only in the literal, so they normalize to the
# same text; the third targets a different table and column.
queries = [
    "SELECT * FROM users WHERE id = 42",
    "SELECT * FROM users WHERE id = 99",
    "SELECT * FROM orders WHERE status = 'pending'",
]

for q in queries:
    print(postgast.normalize(q))
# SELECT * FROM users WHERE id = $1
# SELECT * FROM users WHERE id = $1
# SELECT * FROM orders WHERE status = $1
Fingerprint queries for deduplication¶
Two queries are structurally equivalent when they have the same fingerprint, regardless of literal values, whitespace, or formatting:
import postgast

# Same statement shape; only the literal value and keyword case differ.
upper_fp = postgast.fingerprint("SELECT * FROM users WHERE id = 1")
lower_fp = postgast.fingerprint("select * from users where id = 999")
assert upper_fp.hex == lower_fp.hex  # same structure

# Querying another table changes the fingerprint.
orders_fp = postgast.fingerprint("SELECT * FROM orders WHERE id = 1")
assert upper_fp.hex != orders_fp.hex  # different table
SQL Formatting¶
Pretty-print SQL¶
import postgast

# One long, unformatted statement (adjacent literals are joined into a single string).
compact_sql = (
    "select u.id,u.name,o.total from users u "
    "join orders o on u.id=o.user_id "
    "where o.total>100 order by o.total desc"
)
formatted = postgast.format_sql(compact_sql)
print(formatted)
Output:
SELECT
u.id,
u.name,
o.total
FROM
users u
JOIN orders o ON u.id = o.user_id
WHERE
o.total > 100
ORDER BY
o.total DESC;
Format an already-parsed tree¶
format_sql also accepts a ParseResult, so you can format after
making AST modifications:
import postgast

view_tree = postgast.parse("CREATE VIEW v AS SELECT 1")

# set_or_replace is called for its effect on the tree; its return value is unused.
postgast.set_or_replace(view_tree)

# format_sql receives the already-parsed (and modified) tree, not SQL text.
print(postgast.format_sql(view_tree))
Output:
CREATE OR REPLACE VIEW v AS
SELECT
1;
Batch Processing¶
Process a SQL migration file¶
Use split() to break a multi-statement file into individual
statements, then analyze each one:
import postgast

# A small migration containing three statements of different kinds.
migration = """
CREATE TABLE users (
id serial PRIMARY KEY,
name text NOT NULL
);
CREATE INDEX idx_users_name ON users (name);
INSERT INTO users (name) VALUES ('alice'), ('bob');
"""

# Each statement is parsed on its own so the table list is per-statement.
for stmt in postgast.split(migration):
    tree = postgast.parse(stmt)
    tables = postgast.extract_tables(tree)
    # Show the table list in a 30-character column, then the first 60
    # characters of the statement text.
    print(f"Tables: {tables!r:30s} SQL: {stmt.strip()[:60]}...")
Split with tolerance for invalid SQL¶
The "scanner" method splits on semicolons without parsing, so it works
even when the SQL contains syntax errors:
import postgast

# The middle statement is deliberately not valid SQL.
pieces = postgast.split("SELECT 1; INVALID SYNTAX HERE; SELECT 2", method="scanner")
print(pieces)
# ['SELECT 1', ' INVALID SYNTAX HERE', ' SELECT 2']
DDL Tooling¶
Generate rollback DROP statements¶
Automatically produce DROP statements from CREATE DDL for migration
rollback scripts:
import postgast

# One CREATE statement per supported object kind shown here:
# a function, a view, and a trigger.
creates = [
    "CREATE FUNCTION public.add(a int, b int) RETURNS int LANGUAGE sql AS $$ SELECT a + b $$",
    "CREATE VIEW active_users AS SELECT * FROM users WHERE active",
    "CREATE TRIGGER audit_trg BEFORE UPDATE ON users FOR EACH ROW EXECUTE FUNCTION audit()",
]

for sql in creates:
    print(postgast.to_drop(sql))
# DROP FUNCTION public.add(int, int)
# DROP VIEW active_users
# DROP TRIGGER audit_trg ON users
Make CREATE statements idempotent¶
Add OR REPLACE to CREATE FUNCTION, CREATE TRIGGER, and
CREATE VIEW statements so they can be re-run safely:
import postgast

create_view = "CREATE VIEW active_users AS SELECT * FROM users WHERE active"
print(postgast.ensure_or_replace(create_view))
# CREATE OR REPLACE VIEW active_users AS SELECT * FROM users WHERE active = true
# NOTE(review): the input has a bare `WHERE active`; confirm whether the
# printed SQL really ends in `= true`.

# Already idempotent input is unchanged:
once = postgast.ensure_or_replace(create_view)
postgast.ensure_or_replace(once)
Tree Walking¶
Walk the AST to inspect structure¶
walk() yields every node in the tree with its parent field
name, useful for debugging or generic transforms:
import postgast

tree = postgast.parse("SELECT a FROM t WHERE x = 1")

for field_name, node in postgast.walk(tree):
    # Skip entries with an empty field name so only nodes reached through
    # a named parent field are printed.
    if field_name:
        print(f" {field_name}: {type(node).__name__}")
Output:
stmts: RawStmt
stmt: SelectStmt
target_list: ResTarget
val: ColumnRef
fields: String
from_clause: RangeVar
where_clause: A_Expr
lexpr: ColumnRef
fields: String
rexpr: A_Const
Collect information with the Visitor pattern¶
Create a Visitor subclass with visit_<TypeName>
methods. Unhandled node types automatically recurse into their children:
import postgast


class QueryAnalyzer(postgast.Visitor):
    """Collect table names, column references, and WHERE usage from a parse tree."""

    def __init__(self):
        self.tables = []  # relname of every RangeVar visited
        self.columns = []  # dotted column references, e.g. 'u.name'
        self.has_where = False  # set once a SelectStmt with a WHERE clause is seen

    def visit_RangeVar(self, node):
        """Record the relation name of a table reference."""
        self.tables.append(node.relname)

    def visit_ColumnRef(self, node):
        """Record a column reference as its string parts joined with dots."""
        parts = []
        for f in node.fields:
            # Each field is a oneof wrapper; unwrap it to the concrete node.
            inner = getattr(f, f.WhichOneof("node"))
            # Only nodes that carry a string value contribute a name part.
            if hasattr(inner, "sval"):
                parts.append(inner.sval)
        self.columns.append(".".join(parts))

    def visit_SelectStmt(self, node):
        """Note whether the SELECT has a WHERE clause, then keep descending."""
        if node.HasField("where_clause"):
            self.has_where = True
        self.generic_visit(node)  # continue into children


tree = postgast.parse("SELECT u.name FROM users u WHERE u.active = true")
analyzer = QueryAnalyzer()
analyzer.visit(tree)
print(analyzer.tables)  # ['users']
print(analyzer.columns)  # ['u.name', 'u.active']
print(analyzer.has_where)  # True
Control traversal depth¶
Omitting the call to self.generic_visit(node) in a handler stops
recursion into that node’s children. This lets you skip subtrees:
import postgast


class TopLevelTables(postgast.Visitor):
    """Collect tables from the top-level FROM clause only, ignoring subqueries."""

    def __init__(self):
        self.tables = []  # relname of every RangeVar reached by the traversal

    def visit_RangeVar(self, node):
        """Record the relation name of a table reference."""
        self.tables.append(node.relname)

    def visit_SubLink(self, _node):
        """Handle SubLink nodes without calling generic_visit, pruning the subtree."""
        pass  # don't recurse into subqueries


tree = postgast.parse(
    "SELECT * FROM orders WHERE customer_id IN (SELECT id FROM vip_customers)"
)
v = TopLevelTables()
v.visit(tree)
print(v.tables)  # ['orders'] — vip_customers is skipped
Use typed AST wrappers¶
Wrap a parse tree with wrap() for typed attribute access.
The wrapped tree works with walk_typed() and TypedVisitor:
from postgast import parse, wrap, walk_typed, TypedVisitor

# Wrap the raw parse result so traversal yields typed wrapper objects.
tree = wrap(parse("SELECT a, b FROM t"))

for field_name, node in walk_typed(tree):
    # Skip entries with an empty field name.
    if field_name:
        print(f" {field_name}: {type(node).__name__}")
Working with the Protobuf AST¶
Access raw protobuf nodes¶
The parse tree is a standard protobuf Message. You can inspect it
using all the usual protobuf APIs:
import postgast

tree = postgast.parse("SELECT id, name FROM users WHERE active = true")

# Navigate to the SelectStmt
raw_stmt = tree.stmts[0]
select = raw_stmt.stmt.select_stmt

# Inspect the target list (SELECT columns)
for target in select.target_list:
    col = target.res_target.val.column_ref
    # Both targets in this query are plain, unqualified columns, so the
    # first field holds the column name.
    name = col.fields[0].string.sval
    print(f"Column: {name}")

# Inspect the FROM clause
table = select.from_clause[0].range_var
print(f"Table: {table.relname}")  # 'users'
Find specific node types¶
find_nodes() filters the walk to a single protobuf
message type:
from postgast import parse, find_nodes
from postgast.pg_query_pb2 import FuncCall, RangeVar

tree = parse("SELECT lower(name), count(*) FROM users GROUP BY lower(name)")

# All table references
for rv in find_nodes(tree, RangeVar):
    print(f"Table: {rv.relname}")

# All function calls
for fc in find_nodes(tree, FuncCall):
    # The functions in this query are unqualified, so the first name part
    # is the function name.
    func_name = fc.funcname[0].string.sval
    print(f"Function: {func_name}")
PL/pgSQL Parsing¶
Parse a PL/pgSQL function body¶
parse_plpgsql() returns a structured representation of a
PL/pgSQL function’s body, including declarations, assignments, and control
flow:
import json
import postgast

# A minimal PL/pgSQL function: one declaration, one assignment, one RETURN.
function_source = """
CREATE FUNCTION greet(name text) RETURNS text LANGUAGE plpgsql AS $$
DECLARE
result text;
BEGIN
result := 'Hello, ' || name;
RETURN result;
END;
$$
"""

# Dump the parsed structure as indented JSON for inspection.
plpgsql_tree = postgast.parse_plpgsql(function_source)
print(json.dumps(plpgsql_tree, indent=2))
Error Handling¶
Catch and inspect parse errors¶
Functions that parse SQL raise PgQueryError on invalid input. The
exception carries the error message, cursor position, and source location:
import postgast

try:
    # "FORM" is a deliberate misspelling of FROM.
    postgast.parse("SELECT * FORM users")
except postgast.PgQueryError as e:
    print(f"Error: {e.message}")
    print(f"Position: {e.cursorpos}")
# Error: syntax error at or near "users"
# Position: 15
# NOTE(review): confirm the reported token and position against a real run;
# the parser may flag "FORM" rather than "users".
Validate SQL before execution¶
Use parse as a fast syntax check without hitting the database:
import postgast


def is_valid_sql(sql: str) -> bool:
    """Return True if ``postgast.parse`` accepts ``sql``, False if it raises PgQueryError.

    This is a syntax check only: the statement is never executed, and any
    exception other than PgQueryError propagates to the caller.
    """
    try:
        postgast.parse(sql)
    except postgast.PgQueryError:
        return False
    return True


is_valid_sql("SELECT * FROM users")  # True
is_valid_sql("SLECT * FORM users")  # False
is_valid_sql("SELECT 1; DROP TABLE users; --")  # True (valid SQL!)
Tokenization¶
Scan SQL into tokens¶
scan() returns the raw token stream, useful for syntax
highlighting, keyword detection, or building custom splitters:
import postgast

sql = "SELECT id FROM users WHERE active = true"
result = postgast.scan(sql)

# Token positions are byte offsets, so slice the encoded bytes rather than
# the str; the two only coincide when the SQL is pure ASCII.
sql_bytes = sql.encode("utf-8")

for token in result.tokens:
    # Extract the token text using byte positions
    text = sql_bytes[token.start:token.end].decode("utf-8")
    print(f"{text:12s} token={token.token} keyword={token.keyword_kind}")