Source code for postgast.split

"""SQL statement splitting via libpg_query."""

from __future__ import annotations

from typing import Literal

from postgast.errors import check_error
from postgast.native import lib

# Dispatch table: maps the public ``method`` name accepted by ``split`` to the
# libpg_query entry point that implements it.
_SPLIT_METHODS = dict(
    scanner=lib.pg_query_split_with_scanner,
    parser=lib.pg_query_split_with_parser,
)


def split(sql: str, *, method: Literal["scanner", "parser"] = "parser") -> list[str]:
    """Break a SQL string holding several statements into one string per statement.

    The work is delegated to one of the libpg_query split functions, chosen by
    *method*. ``"parser"`` (the default) runs the full PostgreSQL parser for
    improved accuracy; ``"scanner"`` is a faster scanner-based approach that
    tolerates invalid SQL.

    Args:
        sql: SQL text that may contain more than one statement.
        method: Name of the libpg_query splitter to call. ``"parser"``
            (default) uses ``pg_query_split_with_parser`` for improved accuracy
            on valid SQL. ``"scanner"`` uses ``pg_query_split_with_scanner``,
            which tolerates malformed SQL but may miss some edge cases.

    Returns:
        The individual statements, one string per list entry.

    Raises:
        PgQueryError: If the SQL causes a parse/scanner error.
        ValueError: If *method* is not ``"scanner"`` or ``"parser"``.

    Example:
        >>> split("SELECT 1; SELECT 2")
        ['SELECT 1', ' SELECT 2']
        >>> split("SELECT 'hello;world'")
        ["SELECT 'hello;world'"]
    """
    # Guard clause: reject unsupported method names before touching native code.
    if method not in _SPLIT_METHODS:
        raise ValueError(f"Unknown split method {method!r}; expected 'scanner' or 'parser'")
    splitter = _SPLIT_METHODS[method]

    # Statement locations are applied to the encoded bytes, not to the
    # original ``str``, so the slicing below operates on this buffer.
    encoded = sql.encode("utf-8")
    outcome = splitter(encoded)
    try:
        check_error(outcome)
        entries = (outcome.stmts[idx].contents for idx in range(outcome.n_stmts))
        return [
            encoded[entry.stmt_location : entry.stmt_location + entry.stmt_len].decode("utf-8")
            for entry in entries
        ]
    finally:
        # Always hand the native result back to libpg_query, on success or error.
        lib.pg_query_free_split_result(outcome)