API Reference
Core Functions
- postgast.parse(query)
Parse a SQL query into a protobuf AST.
Calls libpg_query’s pg_query_parse_protobuf to parse the query and returns the deserialized ParseResult protobuf message containing the abstract syntax tree.
- Parameters:
query (str) – A SQL query string.
- Return type:
ParseResult
- Returns:
A ParseResult protobuf message with version (int) and stmts (list of RawStmt) fields.
- Raises:
PgQueryError – If the query contains a syntax error.
Example
>>> from postgast import parse
>>> tree = parse("SELECT id, name FROM users WHERE active = true")
>>> len(tree.stmts)
1
>>> tree.stmts[0].stmt.HasField("select_stmt")
True
- postgast.deparse(tree)
Convert a protobuf parse tree back into a SQL string.
Calls libpg_query’s pg_query_deparse_protobuf to convert a ParseResult AST back into SQL text. This is the inverse of postgast.parse(), up to the canonicalization described in the note below.
Note
The deparsed SQL is canonicalized by libpg_query and may differ from the original query in whitespace, casing, or parenthesization while remaining semantically equivalent.
- Parameters:
tree (ParseResult | AstNode) – A ParseResult protobuf message (as returned by postgast.parse()) or a typed AstNode wrapper.
- Return type:
str
- Returns:
The deparsed SQL string.
- Raises:
PgQueryError – If the parse tree cannot be deparsed.
Example
>>> from postgast import parse, deparse
>>> tree = parse("SELECT id FROM users")
>>> deparse(tree)
'SELECT id FROM users'
- postgast.normalize(query)
Normalize a SQL query by replacing literal constants with placeholders.
Calls libpg_query’s pg_query_normalize to replace literal values (strings, numbers, etc.) with positional parameter placeholders ($1, $2, …). Queries that differ only in their literal values normalize to the same text, which makes the result useful for grouping structurally equivalent queries.
- Parameters:
query (str) – A SQL query string.
- Return type:
str
- Returns:
The normalized query with constants replaced by positional placeholders.
- Raises:
PgQueryError – If the query cannot be parsed.
Example
>>> from postgast import normalize
>>> normalize("SELECT * FROM users WHERE id = 42 AND name = 'Alice'")
'SELECT * FROM users WHERE id = $1 AND name = $2'
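Because queries that differ only in literal values normalize to the same text, the output of normalize() can serve as a dictionary key for bucketing a query log. The sketch below is a hypothetical helper, not part of postgast; the key function is a parameter (you would pass postgast.normalize in practice) so that the sketch runs without libpg_query installed.

```python
from collections import defaultdict
from typing import Callable, Iterable


def group_queries(
    queries: Iterable[str], key: Callable[[str], str]
) -> dict[str, list[str]]:
    """Group raw queries by key(query), e.g. key=postgast.normalize.

    Hypothetical helper: groups appear in first-seen order, and the
    queries inside each group keep their input order.
    """
    groups: dict[str, list[str]] = defaultdict(list)
    for query in queries:
        groups[key(query)].append(query)
    return dict(groups)
```

With key=normalize, "SELECT * FROM users WHERE id = 1" and "SELECT * FROM users WHERE id = 2" both normalize to "SELECT * FROM users WHERE id = $1" and therefore land in the same list.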
- postgast.fingerprint(query)
Compute a structural fingerprint of a SQL query.
Calls libpg_query’s pg_query_fingerprint to produce a hash that identifies structurally equivalent queries regardless of literal values.
- Parameters:
query (str) – A SQL query string.
- Return type:
FingerprintResult
- Returns:
A FingerprintResult containing the numeric fingerprint and its hexadecimal string representation.
- Raises:
PgQueryError – If the query cannot be parsed.
Example
>>> from postgast import fingerprint
>>> result = fingerprint("SELECT * FROM users WHERE id = 1")
>>> result.hex
'0ca858a0484f5826'
>>> result == fingerprint("SELECT * FROM users WHERE id = 2")
True
- postgast.split(sql, *, method='parser')
Split a multi-statement SQL string into individual statements.
Calls the selected libpg_query split function to divide the input into individual SQL statements. The "parser" method (default) uses the full PostgreSQL parser for improved accuracy, while "scanner" uses a faster scanner-based approach that tolerates invalid SQL.
- Parameters:
sql (str) – A SQL string potentially containing multiple statements.
method (Literal['scanner', 'parser']) – Which libpg_query splitter to use. "parser" (default) calls pg_query_split_with_parser for improved accuracy on valid SQL. "scanner" calls pg_query_split_with_scanner, which tolerates malformed SQL but may miss some edge cases.
- Return type:
list[str]
- Returns:
A list of individual SQL statement strings.
- Raises:
PgQueryError – If the SQL causes a parser or scanner error.
ValueError – If method is not "scanner" or "parser".
Example
>>> from postgast import split
>>> split("SELECT 1; SELECT 2")
['SELECT 1', ' SELECT 2']
>>> split("SELECT 'hello;world'")
["SELECT 'hello;world'"]
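The first example shows that a statement after a semicolon keeps the whitespace that preceded it (' SELECT 2'). If downstream tools expect trimmed statements, a small post-processing step is enough. The helper below is hypothetical, not part of postgast; it also defensively drops pieces that are empty after stripping.

```python
def clean_statements(parts: list[str]) -> list[str]:
    """Strip surrounding whitespace from split() output and drop empty pieces.

    Hypothetical helper; pass it the list returned by postgast.split().
    """
    cleaned = []
    for part in parts:
        stripped = part.strip()
        if stripped:  # skip pieces that contained only whitespace
            cleaned.append(stripped)
    return cleaned
```

For instance, clean_statements(['SELECT 1', ' SELECT 2']) returns ['SELECT 1', 'SELECT 2'].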
- postgast.scan(sql)
Tokenize a SQL string into a sequence of scan tokens.
Calls libpg_query’s pg_query_scan to tokenize the input and returns the deserialized ScanResult protobuf message, which contains a list of ScanToken objects carrying the token type, keyword kind, and byte positions.
- Parameters:
sql (str) – A SQL string to tokenize.
- Return type:
ScanResult
- Returns:
A ScanResult protobuf message with version (int) and tokens (list of ScanToken) fields.
- Raises:
PgQueryError – If the input contains a scan error (e.g., an unterminated string literal).
Example
>>> from postgast import scan
>>> result = scan("SELECT 1")
>>> len(result.tokens)
2
>>> result.tokens[0].start, result.tokens[0].end
(0, 6)
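Token positions are byte offsets, and the (0, 6) span of SELECT in the example suggests that end is exclusive. To recover a token's text, slice the UTF-8 encoding of the input rather than the str itself, otherwise non-ASCII input shifts the positions. The sketch below uses a hypothetical Token stand-in with only start and end fields so it runs without postgast; any object with those two attributes, such as the ScanToken messages in result.tokens, works the same way.

```python
from typing import NamedTuple, Sequence


class Token(NamedTuple):
    """Hypothetical stand-in for a ScanToken: byte offsets, end exclusive."""

    start: int
    end: int


def token_texts(sql: str, tokens: Sequence[Token]) -> list[str]:
    """Return the source text of each token by slicing the UTF-8 bytes."""
    data = sql.encode("utf-8")  # offsets count bytes, not characters
    return [data[tok.start:tok.end].decode("utf-8") for tok in tokens]
```

With the spans from the example, token_texts("SELECT 1", [Token(0, 6), Token(7, 8)]) returns ['SELECT', '1'].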
Tree Walking
- postgast.walk(node)
Depth-first pre-order traversal of a protobuf message tree.
Yields (field_name, message) tuples for every protobuf message encountered. The field_name is the protobuf field name that led to the message (e.g. "where_clause", "target_list"), or an empty string for the root. Node oneof wrappers are transparently unwrapped so that only concrete message types appear in the output.
- Parameters:
node (Message) – Any protobuf Message instance (ParseResult, SelectStmt, etc.).
- Yields:
(field_name, message) tuples in depth-first pre-order.
Example
>>> from postgast import parse, walk
>>> tree = parse("SELECT 1")
>>> for field_name, node in walk(tree):
...     if field_name:
...         print(f"{field_name}: {type(node).__name__}")
stmts: RawStmt
stmt: SelectStmt
target_list: ResTarget
val: A_Const
ival: Integer
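To make the visiting order concrete, here is a pure-Python sketch of a depth-first pre-order walk over a toy tree of nested dicts. It illustrates the ordering only and is not postgast's implementation: the dict layout (a "_type" key plus child fields) is an assumption chosen so the sketch runs without protobuf, and it does not model Node oneof unwrapping.

```python
from typing import Any, Iterator


def walk_dicts(node: dict[str, Any], field_name: str = "") -> Iterator[tuple[str, str]]:
    """Yield (field_name, type_name) pairs in depth-first pre-order.

    Toy model (an assumption, not postgast's data structure): every node is a
    dict with a "_type" key. Any other key holding a dict (single child) or a
    list of dicts (repeated field) is treated as a child field; scalar values
    such as ints and strings are skipped.
    """
    yield field_name, node["_type"]  # pre-order: emit the node before its children
    for key, value in node.items():
        if key == "_type":
            continue
        children = value if isinstance(value, list) else [value]
        for child in children:
            if isinstance(child, dict):
                yield from walk_dicts(child, key)  # each child is reported under its field name


if __name__ == "__main__":
    # Roughly mirrors the shape of parse("SELECT 1") from the example above.
    tree = {
        "_type": "ParseResult",
        "stmts": [{
            "_type": "RawStmt",
            "stmt": {
                "_type": "SelectStmt",
                "target_list": [{
                    "_type": "ResTarget",
                    "val": {"_type": "A_Const", "ival": {"_type": "Integer"}, "location": 7},
                }],
            },
        }],
    }
    for name, type_name in walk_dicts(tree):
        if name:
            print(f"{name}: {type_name}")
```

Because the node is yielded before recursing, a parent always appears before all of its descendants, and siblings appear in field order.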
- class postgast.Visitor
Base class for protobuf parse tree visitors.
Subclass and override visit_<TypeName> methods (e.g. visit_SelectStmt, visit_ColumnRef) to handle specific node types. Unhandled types fall through to generic_visit(), which recurses into children.
Call visit() on a root message to start traversal:

    from postgast import Visitor, parse

    class TableCollector(Visitor):
        def __init__(self):
            super().__init__()
            self.tables = []

        def visit_RangeVar(self, node):
            self.tables.append(node.relname)

    collector = TableCollector()
    collector.visit(parse("SELECT * FROM users JOIN orders ON true"))
    print(collector.tables)

- visit(node)
Dispatch node to visit_<TypeName> or generic_visit().
Looks up a method named visit_<TypeName>, where <TypeName> matches the protobuf descriptor name (e.g. visit_SelectStmt), and falls back to generic_visit() if no specific handler exists.
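The name-based dispatch described above can be sketched in a few lines of plain Python; the standard library's ast.NodeVisitor follows the same pattern. This is an illustration, not postgast's code: it runs on a toy tree of nested dicts in which each node's type name is stored under a "_type" key (an assumption made so the sketch needs no protobuf), and DictVisitor is a hypothetical name.

```python
from typing import Any


class DictVisitor:
    """Toy visitor: dispatches on node["_type"] to visit_<TypeName>."""

    def visit(self, node: dict[str, Any]) -> None:
        # Look up visit_<TypeName>; fall back to generic_visit if it is absent.
        handler = getattr(self, f"visit_{node['_type']}", self.generic_visit)
        handler(node)

    def generic_visit(self, node: dict[str, Any]) -> None:
        # Recurse into child fields: dict values and lists of dicts.
        for key, value in node.items():
            if key == "_type":
                continue
            for child in (value if isinstance(value, list) else [value]):
                if isinstance(child, dict):
                    self.visit(child)


class TableCollector(DictVisitor):
    def __init__(self) -> None:
        self.tables: list[str] = []

    def visit_RangeVar(self, node: dict[str, Any]) -> None:
        self.tables.append(node["relname"])
```

In this sketch a specific handler replaces generic_visit for that node, so a handler that still needs to descend into its children should call self.generic_visit(node) itself.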
AST Helpers
- postgast.find_nodes(tree, node_type)
Yield all protobuf messages matching node_type from a parse tree.
Walks the tree in depth-first pre-order (same as walk()) and yields every message that is an instance of node_type.
- Parameters:
tree (Message) – Any protobuf Message (ParseResult, SelectStmt, etc.).
node_type – The protobuf message class to match (e.g. RangeVar).
- Yields:
Matching instances in depth-first pre-order.
Example
>>> from postgast import find_nodes, parse
>>> from postgast.pg_query_pb2 import RangeVar
>>> tree = parse("SELECT * FROM users JOIN orders ON users.id = orders.uid")
>>> [n.relname for n in find_nodes(tree, RangeVar)]
['users', 'orders']
- postgast.extract_tables(tree)
Yield table names referenced in a parse tree.
Walks all RangeVar nodes and yields their names as dot-joined strings ("schema.table" when schema-qualified, "table" otherwise).
Results preserve encounter order and include duplicates. Use set() on the result to get unique table names.
- Parameters:
tree (Message) – Any protobuf Message (ParseResult, SelectStmt, etc.).
- Yields:
Table names in encounter order.
Example
>>> from postgast import extract_tables, parse
>>> tree = parse("SELECT * FROM public.users JOIN orders ON true")
>>> list(extract_tables(tree))
['public.users', 'orders']
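A set() discards encounter order. When you want unique names in first-seen order, dict.fromkeys() keeps insertion order (guaranteed by the language since Python 3.7). The helper name unique_in_order is hypothetical; pass it the generator returned by extract_tables() (or extract_columns() / extract_functions()).

```python
from typing import Hashable, Iterable, TypeVar

T = TypeVar("T", bound=Hashable)


def unique_in_order(items: Iterable[T]) -> list[T]:
    """Drop duplicates while keeping the first occurrence of each item."""
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(items))
```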
- postgast.extract_columns(tree)
Yield column references found in a parse tree.
Walks all ColumnRef nodes and yields their names as dot-joined strings. SELECT * produces "*"; t.* produces "t.*".
Results preserve encounter order and include duplicates.
- Parameters:
tree (Message) – Any protobuf Message (ParseResult, SelectStmt, etc.).
- Yields:
Column references in encounter order.
Example
>>> from postgast import extract_columns, parse
>>> tree = parse("SELECT u.name, age FROM users u WHERE age > 18")
>>> list(extract_columns(tree))
['u.name', 'age', 'age']
- postgast.extract_functions(tree)
Yield function call names found in a parse tree.
Walks all FuncCall nodes and yields their names as dot-joined strings ("schema.func" when schema-qualified, "func" otherwise).
Results preserve encounter order and include duplicates.
- Parameters:
tree (Message) – Any protobuf Message (ParseResult, SelectStmt, etc.).
- Yields:
Function names in encounter order.
Example
>>> from postgast import extract_functions, parse
>>> tree = parse("SELECT lower(name), count(*) FROM users")
>>> list(extract_functions(tree))
['lower', 'count']
- postgast.extract_function_identity(tree)
Return the identity of the first CREATE FUNCTION statement in a parse tree.
Finds the first CreateFunctionStmt node whose is_procedure field is False (so CREATE PROCEDURE statements are skipped) and returns a FunctionIdentity with the schema and function name.
- Parameters:
tree (Message) – Any protobuf Message (ParseResult, SelectStmt, etc.).
- Return type:
FunctionIdentity | None
- Returns:
A FunctionIdentity, or None if no matching node is found.
Example
>>> from postgast import extract_function_identity, parse
>>> sql = "CREATE FUNCTION public.add(a int, b int) RETURNS int LANGUAGE sql AS $$ SELECT a + b $$"
>>> identity = extract_function_identity(parse(sql))
>>> identity.schema, identity.name
('public', 'add')
- postgast.extract_trigger_identity(tree)
Return the identity of the first CREATE TRIGGER statement in a parse tree.
Finds the first CreateTrigStmt node and returns a TriggerIdentity with the trigger name, schema, and table name.
- Parameters:
tree (Message) – Any protobuf Message (ParseResult, SelectStmt, etc.).
- Return type:
TriggerIdentity | None
- Returns:
A TriggerIdentity, or None if no matching node is found.
Example
>>> from postgast import extract_trigger_identity, parse
>>> sql = "CREATE TRIGGER my_trg AFTER INSERT ON orders FOR EACH ROW EXECUTE FUNCTION notify()"
>>> identity = extract_trigger_identity(parse(sql))
>>> identity.trigger, identity.table
('my_trg', 'orders')
DDL Helpers
- postgast.set_or_replace(tree)
Set replace = True on eligible DDL nodes in a parse tree.
Walks tree and sets the replace flag on CreateFunctionStmt, CreateTrigStmt, and ViewStmt nodes where it is currently False. The tree is modified in place.
- Parameters:
tree (Message) – A protobuf Message (typically a ParseResult).
- Return type:
int
- Returns:
Number of nodes that were modified.
Example
>>> from postgast import set_or_replace, parse, deparse
>>> tree = parse("CREATE VIEW v AS SELECT 1")
>>> set_or_replace(tree)
1
>>> "OR REPLACE" in deparse(tree)
True
- postgast.ensure_or_replace(sql)
Return sql with all eligible CREATE statements rewritten to CREATE OR REPLACE.
Parses the input, sets replace = True on CreateFunctionStmt, CreateTrigStmt, and ViewStmt nodes, and deparses the result back to SQL. Because the output is deparsed, it is canonicalized in the same way as deparse() output.
- Parameters:
sql (str) – One or more SQL statements.
- Return type:
str
- Returns:
The rewritten SQL text.
- Raises:
PgQueryError – If sql cannot be parsed.
Example
>>> from postgast import ensure_or_replace
>>> ensure_or_replace("CREATE VIEW v AS SELECT 1")
'CREATE OR REPLACE VIEW v AS SELECT 1'
- postgast.to_drop(sql)
Return the DROP statement corresponding to a CREATE statement.
Parses sql, builds a DropStmt protobuf from the parsed AST, and deparses it back to SQL. Supports:
  - CREATE FUNCTION / CREATE PROCEDURE
  - CREATE TRIGGER
  - CREATE VIEW
  - CREATE TABLE
  - CREATE INDEX
  - CREATE SEQUENCE
  - CREATE SCHEMA
  - CREATE EXTENSION
  - CREATE TYPE (enum, range, composite)
  - CREATE MATERIALIZED VIEW ... AS
All OR REPLACE and IF NOT EXISTS variants are accepted.
- Parameters:
sql (str) – A single CREATE statement.
- Return type:
str
- Returns:
The corresponding DROP statement.
- Raises:
ValueError – If sql contains zero or more than one statement, or if the statement is not a supported CREATE type.
PgQueryError – If sql is not valid SQL.
Example
>>> from postgast import to_drop
>>> to_drop("CREATE TABLE public.users (id int)")
'DROP TABLE public.users'
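A common use of to_drop() is generating a rollback for a migration script. Later objects may depend on earlier ones (an index on a table, a trigger that calls a function), so dropping in reverse creation order is the usual default, though not a guarantee for every dependency graph. The sketch below assumes the script has already been divided into one CREATE statement per element (for instance with split()) and takes the converter as a parameter, so you would pass postgast.to_drop; the function name rollback_statements is hypothetical.

```python
from typing import Callable, Sequence


def rollback_statements(
    creates: Sequence[str], to_drop: Callable[[str], str]
) -> list[str]:
    """Return DROP statements for creates, in reverse creation order.

    Hypothetical helper: pass postgast.to_drop as to_drop. Each element of
    creates should be a single supported CREATE statement; postgast.to_drop
    raises ValueError otherwise.
    """
    # Reverse so dependents (e.g. an index) are dropped before what they depend on.
    return [to_drop(stmt) for stmt in reversed(creates)]
```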
Types
- class postgast.FingerprintResult(fingerprint, hex)
Result of fingerprinting a SQL query.
- fingerprint
The uint64 numeric hash.
- hex
The hexadecimal string representation of the fingerprint.
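The two fields carry the same value in different forms. Judging from the 16-character example in fingerprint() ('0ca858a0484f5826', with a leading zero), hex appears to be the zero-padded, lowercase hexadecimal rendering of the 64-bit integer. The sketch below assumes that relationship and uses a hypothetical Fingerprint stand-in rather than the real FingerprintResult class.

```python
from typing import NamedTuple


class Fingerprint(NamedTuple):
    """Hypothetical stand-in with the same two fields as FingerprintResult."""

    fingerprint: int  # unsigned 64-bit hash
    hex: str  # assumed: 16 lowercase hex digits, zero-padded

    @classmethod
    def from_int(cls, value: int) -> "Fingerprint":
        # Derive hex from the integer under the assumed formatting rule.
        if not 0 <= value < 2**64:
            raise ValueError("fingerprint must fit in an unsigned 64-bit integer")
        return cls(value, format(value, "016x"))

    @classmethod
    def from_hex(cls, text: str) -> "Fingerprint":
        # Parse the hex form back to the integer, then re-derive both fields.
        return cls.from_int(int(text, 16))
```

Routing from_hex through from_int means both fields are always derived from a single value, so they cannot disagree.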
Exceptions
- class postgast.PgQueryError(message, *, cursorpos=0, context=None, funcname=None, filename=None, lineno=0)
Structured error raised when libpg_query rejects a SQL statement.
Every postgast function that calls into libpg_query (parse(), deparse(), normalize(), fingerprint(), split(), scan(), parse_plpgsql(), and format_sql()) may raise this exception. The error carries the same structured fields that the C library provides, so callers can build precise diagnostics (e.g., underlining the offending token) without parsing the message string.
cursorpos is a 1-based byte offset into the original SQL string that points to the token where the error was detected; 0 means the position is unknown. Because it counts bytes, e.cursorpos - 1 equals the corresponding Python string index only when the SQL is pure ASCII. For SQL containing multibyte UTF-8 characters (e.g., Unicode identifiers or string literals), index into the UTF-8-encoded bytes instead:

    prefix = sql.encode("utf-8")[: max(e.cursorpos - 1, 0)].decode("utf-8")
    char_offset = len(prefix)

The funcname, filename, and lineno fields refer to the internal C source of libpg_query / PostgreSQL’s parser, not to your Python code. They are mainly useful for filing upstream bug reports.
- message
Human-readable error description from the PostgreSQL parser.
- cursorpos
1-based byte offset in the SQL string where the error was detected (0 when the position is unavailable).
- context
Additional context from the parser (e.g., the PL/pgSQL function name), or None.
- funcname
Internal C function name where the error originated, or None.
- filename
Internal C source file where the error originated, or None.
- lineno
Line number in the internal C source file (0 when unavailable).
Examples
Catch a syntax error and inspect the cursor position:
>>> from postgast import parse, PgQueryError
>>> try:
...     parse("SELECT FROM")
... except PgQueryError as e:
...     print(e.cursorpos)
12
Use cursorpos to highlight the error location (ASCII-safe shortcut):
>>> from postgast import parse, PgQueryError
>>> sql = "SELECT * FORM users"
>>> try:
...     parse(sql)
... except PgQueryError as e:
...     idx = max(e.cursorpos - 1, 0)
...     print(sql)
...     print(" " * idx + "^")
...     print(e.message)
SELECT * FORM users
         ^
syntax error at or near "FORM"
For SQL that may contain non-ASCII characters, convert via the encoded bytes to get the correct character offset:

    from postgast import parse, PgQueryError

    sql = "SELECT 'ü' FORM t"
    try:
        parse(sql)
    except PgQueryError as e:
        prefix = sql.encode("utf-8")[: max(e.cursorpos - 1, 0)]
        idx = len(prefix.decode("utf-8"))
        print(sql)
        print(" " * idx + "^")
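The examples above can be folded into one reusable helper. The sketch below is hypothetical, not part of postgast: it takes the SQL text and a cursorpos value, returns None when the position is unknown (0), and converts the 1-based byte offset to a character column as described above. It decodes with errors="ignore" so that an offset landing inside a multibyte character cannot raise, and it assumes single-line SQL (for multi-line input the column would have to be measured from the last newline).

```python
from typing import Optional


def caret_line(sql: str, cursorpos: int) -> Optional[str]:
    """Return a line with "^" under the character at a 1-based byte offset.

    Hypothetical helper for PgQueryError.cursorpos. Returns None when the
    position is unknown (cursorpos == 0). Assumes the SQL is a single line.
    """
    if cursorpos <= 0:
        return None
    prefix = sql.encode("utf-8")[: cursorpos - 1]
    # A cut inside a multibyte character would be invalid UTF-8; drop the fragment.
    column = len(prefix.decode("utf-8", errors="ignore"))
    return " " * column + "^"
```

Inside an except PgQueryError as e: block, line = caret_line(sql, e.cursorpos) gives the marker to print under sql, or None when there is nothing to point at.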