Files
media-downloader/modules/pg_adapter.py
Todd 0d7b2b1aab Initial commit
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-29 22:42:55 -04:00

1026 lines
35 KiB
Python

#!/usr/bin/env python3
"""
PostgreSQL Adapter — Drop-in sqlite3 replacement module.
When DATABASE_BACKEND=postgresql, this module is monkey-patched into sys.modules['sqlite3'].
Every existing `import sqlite3; sqlite3.connect(...)` transparently routes through psycopg2
with automatic SQL dialect translation. Zero existing queries need to change.
Architecture:
connect(db_path) → PgConnection(psycopg2_conn)
→ PgCursor.execute(sql, params)
→ _translate_sql(sql) [cached]
→ psycopg2 cursor
"""
import os
import re
import threading
from collections import OrderedDict
import psycopg2
import psycopg2.errors
import psycopg2.extensions
import psycopg2.pool
# ---------------------------------------------------------------------------
# Module-level constants (sqlite3 API compatibility)
# ---------------------------------------------------------------------------
# Callers probing sqlite3's DB-API / version attributes get plausible values;
# none reflect a real SQLite library.
PARSE_DECLTYPES = 1
PARSE_COLNAMES = 2
SQLITE_OK = 0
SQLITE_ERROR = 1
sqlite_version = "3.45.0"
sqlite_version_info = (3, 45, 0)
version = "2.6.0"
version_info = (2, 6, 0)
# Declared qmark to match sqlite3; '?' placeholders are converted to psycopg2's
# '%s' style inside _translate_sql/_replace_question_marks.
paramstyle = "qmark"
apilevel = "2.0"
threadsafety = 1
# Re-export exception types so `sqlite3.OperationalError` etc. work
Error = psycopg2.Error
OperationalError = psycopg2.OperationalError
IntegrityError = psycopg2.IntegrityError
InterfaceError = psycopg2.InterfaceError
DatabaseError = psycopg2.DatabaseError
DataError = psycopg2.DataError
InternalError = psycopg2.InternalError
NotSupportedError = psycopg2.NotSupportedError
ProgrammingError = psycopg2.ProgrammingError
Warning = psycopg2.Warning
# ---------------------------------------------------------------------------
# Connection pool (singleton)
# ---------------------------------------------------------------------------
# NOTE(review): hardcoded fallback credentials in the default DSN — prefer
# requiring DATABASE_URL to be set and failing fast when it is absent.
_DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://media_downloader:PNsihOXvvuPwWiIvGlsc9Fh2YmMmB@localhost/media_downloader",
)
# Lazily-created ThreadedConnectionPool shared by every PgConnection.
_pool = None
_pool_lock = threading.Lock()
# Adapter/converter registries (sqlite3 compat — mostly no-ops for PG)
_adapters = {}
_converters = {}
def register_adapter(type_, adapter):
    """Accept sqlite3.register_adapter() calls for API compatibility.

    The adapter is remembered but never invoked: psycopg2 performs its own
    Python-to-PostgreSQL type adaptation.
    """
    _adapters[type_] = adapter
def register_converter(typename, converter):
    """Accept sqlite3.register_converter() calls for API compatibility.

    The converter is remembered but never invoked: psycopg2 already returns
    native Python types for PostgreSQL values.
    """
    _converters[typename] = converter
def _get_pool():
    """Return the process-wide psycopg2 connection pool, creating it lazily.

    Double-checked locking: the unlocked fast path avoids contention once
    the pool exists; the locked re-check prevents duplicate creation.
    """
    global _pool
    if _pool is not None:
        return _pool
    with _pool_lock:
        if _pool is None:
            _pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=2,
                maxconn=30,
                dsn=_DATABASE_URL,
            )
    return _pool
def connect(database=":memory:", timeout=5.0, detect_types=0,
            isolation_level="", check_same_thread=True,
            factory=None, cached_statements=128, uri=False):
    """Drop-in replacement for sqlite3.connect() returning a PgConnection.

    Every sqlite3-specific argument besides ``database`` and ``timeout`` is
    accepted purely for signature compatibility and ignored; the connection
    is drawn from the shared psycopg2 pool regardless of ``database``.
    """
    return PgConnection(database, timeout)
# ---------------------------------------------------------------------------
# Unique-constraint cache (for INSERT OR REPLACE translation)
# ---------------------------------------------------------------------------
# Maps table name -> list of conflict-target columns (PK preferred, else a
# UNIQUE constraint). Populated once per process by _load_pk_cache().
_pk_cache = {} # table_name -> [column_names]
_pk_cache_lock = threading.Lock()
_pk_cache_loaded = False
def _load_pk_cache(pg_conn):
    """Populate _pk_cache with conflict-target columns per public table.

    For each table we record its PRIMARY KEY columns; tables without a PK
    fall back to the columns of their first UNIQUE constraint. Runs at most
    once per process (double-checked via _pk_cache_loaded).

    Args:
        pg_conn: a live psycopg2 connection used for the catalog query.
    """
    global _pk_cache_loaded
    if _pk_cache_loaded:
        return
    with _pk_cache_lock:
        if _pk_cache_loaded:
            return
        cur = pg_conn.cursor()
        try:
            cur.execute("""
                SELECT tc.table_name, kcu.column_name, tc.constraint_type,
                       tc.constraint_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                WHERE tc.table_schema = 'public'
                  AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
                ORDER BY tc.table_name, tc.constraint_type,
                         tc.constraint_name, kcu.ordinal_position
            """)
            rows = cur.fetchall()
        finally:
            # BUG FIX: close the cursor even if the catalog query fails.
            cur.close()
        pk_map = {}      # table -> [pk_cols]
        uniq_map = {}    # table -> {constraint_name: [cols]}
        uniq_first = {}  # table -> first-seen UNIQUE constraint name
        for table, col, ctype, cname in rows:
            if ctype == "PRIMARY KEY":
                pk_map.setdefault(table, []).append(col)
            else:
                # BUG FIX: previously every UNIQUE constraint's columns were
                # merged into a single list, yielding an invalid ON CONFLICT
                # target for tables with more than one UNIQUE constraint.
                # Keep constraints separate and use only the first one.
                uniq_map.setdefault(table, {}).setdefault(cname, []).append(col)
                uniq_first.setdefault(table, cname)
        # Prefer PK; fall back to the first UNIQUE constraint.
        for t in set(pk_map) | set(uniq_map):
            if t in pk_map:
                _pk_cache[t] = pk_map[t]
            else:
                _pk_cache[t] = uniq_map[t][uniq_first[t]]
        _pk_cache_loaded = True
# ---------------------------------------------------------------------------
# SQL Translation Engine (cached)
# ---------------------------------------------------------------------------
# LRU cache of translation results, keyed by the original SQL text.
_sql_cache = OrderedDict()
_SQL_CACHE_MAX = 4096
_sql_cache_lock = threading.Lock()
# Pre-compiled regexes
# PRAGMAs with no PostgreSQL equivalent — turned into no-ops.
_RE_PRAGMA = re.compile(
    r"^\s*PRAGMA\s+(journal_mode|busy_timeout|synchronous|foreign_keys|wal_checkpoint|cache_size|temp_store)\b",
    re.IGNORECASE,
)
# PRAGMA table_info(t) — rewritten to query information_schema.columns.
_RE_PRAGMA_TABLE_INFO = re.compile(
    r"^\s*PRAGMA\s+table_info\s*\(\s*['\"]?(\w+)['\"]?\s*\)",
    re.IGNORECASE,
)
_RE_INSERT_OR_IGNORE = re.compile(
    r"\bINSERT\s+OR\s+IGNORE\s+INTO\b",
    re.IGNORECASE,
)
_RE_INSERT_OR_REPLACE = re.compile(
    r"\bINSERT\s+OR\s+REPLACE\s+INTO\s+(\w+)\b",
    re.IGNORECASE,
)
# datetime('now'[, 'localtime']) → NOW()
_RE_DATETIME_NOW = re.compile(
    r"datetime\s*\(\s*'now'\s*(?:,\s*'localtime'\s*)?\)",
    re.IGNORECASE,
)
# datetime('now', '-7 days') style literal offsets → NOW() + INTERVAL.
_RE_DATETIME_NOW_INTERVAL = re.compile(
    r"datetime\s*\(\s*'now'\s*(?:,\s*'localtime'\s*)?,\s*'(-?\d+)\s+(day|days|hour|hours|minute|minutes|second|seconds|month|months)'\s*\)",
    re.IGNORECASE,
)
# datetime('now', ? || ' days') style parameterised offsets.
_RE_DATETIME_NOW_PARAM_INTERVAL = re.compile(
    r"datetime\s*\(\s*'now'\s*,\s*\?\s*\|\|\s*'\s*(day|days|hour|hours|minute|minutes|second|seconds)'\s*\)",
    re.IGNORECASE,
)
_RE_DATE_NOW = re.compile(
    r"date\s*\(\s*'now'\s*\)",
    re.IGNORECASE,
)
_RE_DATE_NOW_INTERVAL = re.compile(
    r"date\s*\(\s*'now'\s*,\s*'(-?\d+)\s+(day|days)'\s*\)",
    re.IGNORECASE,
)
# Two-argument scalar MIN/MAX (SQLite idiom) → LEAST/GREATEST.
_RE_SCALAR_MIN = re.compile(
    r"\bMIN\s*\(\s*([^,()]+)\s*,\s*([^()]+)\s*\)",
    re.IGNORECASE,
)
_RE_SCALAR_MAX = re.compile(
    r"\bMAX\s*\(\s*([^,()]+)\s*,\s*([^()]+)\s*\)",
    re.IGNORECASE,
)
# json_extract(col, '$.key') → col::jsonb->>'key'
_RE_JSON_EXTRACT = re.compile(
    r"json_extract\s*\(\s*(\w+(?:\.\w+)?)\s*,\s*'\$\.(\w+)'\s*\)",
    re.IGNORECASE,
)
# strftime('%format', expr) → TO_CHAR(expr, 'pg_format')
# Uses a function to handle nested parentheses instead of a simple regex
_RE_STRFTIME_PREFIX = re.compile(
    r"strftime\s*\(\s*'([^']+)'\s*,\s*",
    re.IGNORECASE,
)
# datetime(column_expr) → cast as timestamp (not datetime('now') which is handled separately)
_RE_DATETIME_COLUMN = re.compile(
    r"datetime\s*\(\s*(\w+(?:\.\w+)?|\?)\s*\)",
    re.IGNORECASE,
)
# Columns stored as BOOLEAN in PG but compared against 0/1 in legacy queries.
_BOOLEAN_COLUMNS = {'has_match', 'notified', 'downloaded', 'has_images'}
_RE_BOOL_EQ_1 = re.compile(
    r"\b(\w+(?:\.\w+)?)\s*=\s*1\b",
)
_RE_BOOL_EQ_0 = re.compile(
    r"\b(\w+(?:\.\w+)?)\s*=\s*0\b",
)
# GROUP_CONCAT variants → STRING_AGG with explicit ::text casts.
_RE_GROUP_CONCAT_DISTINCT = re.compile(
    r"GROUP_CONCAT\s*\(\s*DISTINCT\s+([\w.]+)\s*\)",
    re.IGNORECASE,
)
_RE_GROUP_CONCAT_SEP = re.compile(
    r"GROUP_CONCAT\s*\(\s*([\w.]+)\s*,\s*'([^']*)'\s*\)",
    re.IGNORECASE,
)
_RE_GROUP_CONCAT = re.compile(
    r"GROUP_CONCAT\s*\(\s*([\w.]+)\s*\)",
    re.IGNORECASE,
)
# DDL keyword rewrites (mainly for CREATE TABLE statements).
_RE_AUTOINCREMENT = re.compile(r"\bAUTOINCREMENT\b", re.IGNORECASE)
_RE_BLOB = re.compile(r"\bBLOB\b", re.IGNORECASE)
_RE_BEGIN_IMMEDIATE = re.compile(r"^\s*BEGIN\s+IMMEDIATE\b", re.IGNORECASE)
_RE_SQLITE_MASTER = re.compile(
    r"\bsqlite_master\b",
    re.IGNORECASE,
)
_RE_COLLATE_NOCASE = re.compile(r"\s*COLLATE\s+NOCASE\b", re.IGNORECASE)
_RE_IFNULL = re.compile(r"\bIFNULL\s*\(", re.IGNORECASE)
# Simple strftime over a bare column reference (complement to the
# nested-paren-aware pass driven by _RE_STRFTIME_PREFIX).
_RE_STRFTIME = re.compile(
    r"strftime\s*\(\s*'([^']+)'\s*,\s*(\w+(?:\.\w+)?)\s*\)",
    re.IGNORECASE,
)
# CREATE TRIGGER bodies are SQLite-specific and skipped entirely.
_RE_CREATE_TRIGGER = re.compile(
    r"^\s*CREATE\s+TRIGGER\b.*?\bEND\s*;?\s*$",
    re.IGNORECASE | re.DOTALL,
)
_RE_BOOLEAN_DEFAULT_0 = re.compile(
    r"\bBOOLEAN\s+DEFAULT\s+0\b",
    re.IGNORECASE,
)
_RE_BOOLEAN_DEFAULT_1 = re.compile(
    r"\bBOOLEAN\s+DEFAULT\s+1\b",
    re.IGNORECASE,
)
# strftime format map: SQLite → PostgreSQL TO_CHAR
_STRFTIME_MAP = {
    "%Y": "YYYY",
    "%m": "MM",
    "%d": "DD",
    "%H": "HH24",
    "%M": "MI",
    "%S": "SS",
    "%W": "IW",
    "%j": "DDD",
}
def _convert_strftime_format(sqlite_fmt):
    """Map SQLite strftime % codes in *sqlite_fmt* to TO_CHAR tokens.

    Substitutions are applied in _STRFTIME_MAP insertion order; characters
    with no mapping pass through unchanged.
    """
    converted = sqlite_fmt
    for sqlite_code, pg_code in _STRFTIME_MAP.items():
        converted = converted.replace(sqlite_code, pg_code)
    return converted
def _translate_sql(sql):
    """
    Translate a SQLite SQL string to PostgreSQL dialect.
    Returns (translated_sql, is_noop, extra_params) where is_noop=True means
    skip execution. extra_params is a tuple of additional parameters that must
    be prepended to any user-supplied parameters (used for parameterised
    PRAGMA table_info translation).

    Results are memoised in an LRU cache keyed by the original SQL text, so
    the regex pipeline below runs once per distinct statement. The order of
    the substitutions is significant — e.g. datetime('now') must be handled
    before the generic datetime(column) cast, and '?'→'%s' must run last.
    """
    original = sql
    # Check cache first
    with _sql_cache_lock:
        cached = _sql_cache.get(original)
        if cached is not None:
            # Move to end (LRU)
            _sql_cache.move_to_end(original)
            return cached
    is_noop = False
    # NOTE(review): extra_params is currently never populated — reserved for
    # a future parameterised PRAGMA table_info translation.
    extra_params = None
    translated = sql
    # --- PRAGMA handling ---
    # Check for PRAGMA table_info first (special case that returns data)
    m = _RE_PRAGMA_TABLE_INFO.match(translated)
    if m:
        # table_name is safe: regex \w+ guarantees only alphanumeric/underscore
        table_name = m.group(1)
        # Emulate sqlite3's table_info columns (cid, name, type, notnull,
        # dflt_value, pk) from information_schema. pk is always reported 0.
        translated = (
            "SELECT ordinal_position - 1 as cid, column_name as name, "
            "data_type as type, "
            "CASE WHEN is_nullable = 'NO' THEN 1 ELSE 0 END as notnull, "
            "column_default as dflt_value, "
            "0 as pk "
            "FROM information_schema.columns "
            f"WHERE table_schema = 'public' AND table_name = '{table_name}' "
            "ORDER BY ordinal_position"
        )
    elif _RE_PRAGMA.match(translated):
        is_noop = True
        translated = "SELECT 1 WHERE false"
    # --- CREATE TRIGGER (skip entirely) ---
    if not is_noop and _RE_CREATE_TRIGGER.match(translated):
        is_noop = True
        translated = "SELECT 1 WHERE false"
    if not is_noop:
        # --- INSERT OR IGNORE ---
        translated = _RE_INSERT_OR_IGNORE.sub("INSERT INTO", translated)
        if "INSERT INTO" in translated and "ON CONFLICT" not in translated.upper() and original != translated:
            # Only add ON CONFLICT DO NOTHING if we actually replaced INSERT OR IGNORE
            # (original != translated can only be true here when the sub fired).
            translated = _add_on_conflict_do_nothing(translated)
        # --- INSERT OR REPLACE ---
        # Searched in the original text because the upsert rewrite works from
        # the untranslated statement.
        m_ior = _RE_INSERT_OR_REPLACE.search(original)
        if m_ior and "ON CONFLICT" not in translated.upper():
            translated = _translate_insert_or_replace(original)
        # --- datetime/date functions ---
        # Literal offsets first, then the bare datetime('now') form.
        translated = _RE_DATETIME_NOW_INTERVAL.sub(
            lambda m: f"NOW() + INTERVAL '{m.group(1)} {m.group(2)}'", translated
        )
        translated = _RE_DATETIME_NOW_PARAM_INTERVAL.sub(
            lambda m: "NOW() + (INTERVAL '1 " + m.group(1).rstrip("s") + "' * ?)", translated
        )
        translated = _RE_DATETIME_NOW.sub("NOW()", translated)
        translated = _RE_DATE_NOW_INTERVAL.sub(
            lambda m: f"CURRENT_DATE + INTERVAL '{m.group(1)} {m.group(2)}'", translated
        )
        translated = _RE_DATE_NOW.sub("CURRENT_DATE", translated)
        # --- scalar MIN/MAX → LEAST/GREATEST (SQLite uses MIN/MAX for 2-arg comparison) ---
        translated = _RE_SCALAR_MIN.sub(r"LEAST(\1, \2)", translated)
        translated = _RE_SCALAR_MAX.sub(r"GREATEST(\1, \2)", translated)
        # --- datetime(column) → column::timestamp (after datetime('now') handled above) ---
        translated = _RE_DATETIME_COLUMN.sub(
            lambda m: f"{m.group(1)}::timestamp", translated
        )
        # --- strftime('%fmt', expr) → TO_CHAR(expr::timestamp, 'pg_fmt') ---
        # Manual scan rather than a regex because the expression argument may
        # itself contain nested parentheses.
        def _translate_strftime(sql):
            result = []
            pos = 0
            while pos < len(sql):
                m = _RE_STRFTIME_PREFIX.search(sql, pos)
                if not m:
                    result.append(sql[pos:])
                    break
                result.append(sql[pos:m.start()])
                fmt = m.group(1)
                # Find matching closing paren accounting for nesting
                expr_start = m.end()
                depth = 1
                i = expr_start
                while i < len(sql) and depth > 0:
                    if sql[i] == '(':
                        depth += 1
                    elif sql[i] == ')':
                        depth -= 1
                    i += 1
                expr = sql[expr_start:i - 1].strip()
                pg_fmt = fmt
                for sqlite_code, pg_code in _STRFTIME_MAP.items():
                    pg_fmt = pg_fmt.replace(sqlite_code, pg_code)
                result.append(f"TO_CHAR(({expr})::timestamp, '{pg_fmt}')")
                pos = i
            return ''.join(result)
        translated = _translate_strftime(translated)
        # --- json_extract ---
        translated = _RE_JSON_EXTRACT.sub(
            lambda m: f"{m.group(1)}::jsonb->>'{m.group(2)}'", translated
        )
        # --- boolean = 1/0 → TRUE/FALSE for known boolean columns ---
        def _bool_eq_1(m):
            col = m.group(1).split('.')[-1] # strip table alias
            if col in _BOOLEAN_COLUMNS:
                return f"{m.group(1)} = TRUE"
            return m.group(0)
        def _bool_eq_0(m):
            col = m.group(1).split('.')[-1]
            if col in _BOOLEAN_COLUMNS:
                return f"{m.group(1)} = FALSE"
            return m.group(0)
        translated = _RE_BOOL_EQ_1.sub(_bool_eq_1, translated)
        translated = _RE_BOOL_EQ_0.sub(_bool_eq_0, translated)
        # --- GROUP_CONCAT → STRING_AGG ---
        translated = _RE_GROUP_CONCAT_DISTINCT.sub(
            lambda m: f"STRING_AGG(DISTINCT {m.group(1)}::text, ',')", translated
        )
        translated = _RE_GROUP_CONCAT_SEP.sub(
            lambda m: f"STRING_AGG({m.group(1)}::text, '{m.group(2)}')", translated
        )
        translated = _RE_GROUP_CONCAT.sub(
            lambda m: f"STRING_AGG({m.group(1)}::text, ',')", translated
        )
        # --- IFNULL → COALESCE ---
        translated = _RE_IFNULL.sub("COALESCE(", translated)
        # --- strftime (simple bare-column form) ---
        translated = _RE_STRFTIME.sub(
            lambda m: f"TO_CHAR({m.group(2)}, '{_convert_strftime_format(m.group(1))}')",
            translated,
        )
        # --- Type replacements (mainly for CREATE TABLE) ---
        translated = _RE_AUTOINCREMENT.sub("", translated)
        translated = _RE_BLOB.sub("BYTEA", translated)
        translated = _RE_BOOLEAN_DEFAULT_0.sub("BOOLEAN DEFAULT false", translated)
        translated = _RE_BOOLEAN_DEFAULT_1.sub("BOOLEAN DEFAULT true", translated)
        # DATETIME → TIMESTAMP (PG doesn't have DATETIME type)
        translated = re.sub(r"\bDATETIME\b", "TIMESTAMP", translated, flags=re.IGNORECASE)
        # INTEGER PRIMARY KEY AUTOINCREMENT → SERIAL PRIMARY KEY (already removed AUTOINCREMENT above)
        translated = re.sub(
            r"\bINTEGER\s+PRIMARY\s+KEY\b",
            "SERIAL PRIMARY KEY",
            translated,
            flags=re.IGNORECASE,
        )
        # --- BEGIN / BEGIN IMMEDIATE → no-op (psycopg2 manages transactions) ---
        if _RE_BEGIN_IMMEDIATE.match(translated) or re.match(r"^\s*BEGIN\s*;?\s*$", translated, re.IGNORECASE):
            is_noop = True
            translated = "SELECT 1 WHERE false"
        # --- sqlite_master → information_schema.tables ---
        if _RE_SQLITE_MASTER.search(translated):
            translated = _translate_sqlite_master(translated)
        # --- COLLATE NOCASE → remove ---
        translated = _RE_COLLATE_NOCASE.sub("", translated)
        # --- Parameter placeholder: ? → %s ---
        # Must be done AFTER all other translations; skip string literals
        translated = _replace_question_marks(translated)
    result = (translated, is_noop, extra_params)
    # Store in cache
    with _sql_cache_lock:
        _sql_cache[original] = result
        if len(_sql_cache) > _SQL_CACHE_MAX:
            # Evict the least-recently-used entry.
            _sql_cache.popitem(last=False)
    return result
def _add_on_conflict_do_nothing(sql):
"""Append ON CONFLICT DO NOTHING to an INSERT statement."""
# Handle trailing semicolons
stripped = sql.rstrip().rstrip(";")
return stripped + " ON CONFLICT DO NOTHING"
def _translate_insert_or_replace(sql):
    """Translate INSERT OR REPLACE INTO table (...) VALUES (...) to PG upsert.

    Uses _pk_cache (populated by _load_pk_cache before any execute) to pick
    the conflict-target columns; every non-conflict column becomes a
    `col = EXCLUDED.col` assignment. Falls back to ON CONFLICT DO NOTHING
    when no key is known or the column list cannot be parsed.
    """
    m = _RE_INSERT_OR_REPLACE.search(sql)
    if not m:
        return sql
    # PG folds unquoted identifiers to lowercase, so cache keys are lowercase.
    table_name = m.group(1).lower()
    # Get the conflict columns
    conflict_cols = _pk_cache.get(table_name, [])
    if not conflict_cols:
        # Fallback: just do INSERT ... ON CONFLICT DO NOTHING
        translated = _RE_INSERT_OR_REPLACE.sub(f"INSERT INTO {table_name}", sql)
        return _add_on_conflict_do_nothing(translated)
    # Extract column names from the INSERT statement
    translated = _RE_INSERT_OR_REPLACE.sub(f"INSERT INTO {table_name}", sql)
    # Parse column list from INSERT INTO table (col1, col2, ...) VALUES (...)
    col_match = re.search(
        r"INSERT\s+INTO\s+\w+\s*\(([^)]+)\)",
        translated,
        re.IGNORECASE,
    )
    if not col_match:
        return _add_on_conflict_do_nothing(translated)
    all_cols = [c.strip().strip('"').strip("'") for c in col_match.group(1).split(",")]
    # Columns not in the conflict target get overwritten from EXCLUDED.
    update_cols = [c for c in all_cols if c.lower() not in [k.lower() for k in conflict_cols]]
    if not update_cols:
        # Every inserted column is part of the key — nothing to update.
        return _add_on_conflict_do_nothing(translated)
    conflict_str = ", ".join(conflict_cols)
    update_str = ", ".join(f"{c} = EXCLUDED.{c}" for c in update_cols)
    stripped = translated.rstrip().rstrip(";")
    return f"{stripped} ON CONFLICT ({conflict_str}) DO UPDATE SET {update_str}"
def _translate_sqlite_master(sql):
"""Translate queries against sqlite_master to information_schema.tables."""
# Replace table name first
translated = re.sub(
r"\bsqlite_master\b",
"information_schema.tables",
sql,
flags=re.IGNORECASE,
)
# Replace type='table' with schema/type filter
translated = re.sub(
r"\btype\s*=\s*'table'",
"table_schema = 'public' AND table_type = 'BASE TABLE'",
translated,
flags=re.IGNORECASE,
)
# Remove sqlite_-prefixed filters
translated = re.sub(
r"\bAND\s+name\s+NOT\s+LIKE\s+'sqlite_%'",
"",
translated,
flags=re.IGNORECASE,
)
# Map 'name' column to 'table_name' everywhere (SELECT list, WHERE, etc.)
# But preserve 'table_name' if already present
translated = re.sub(r"\bname\b", "table_name", translated, flags=re.IGNORECASE)
# Remove 'sql' column references (not in information_schema)
translated = re.sub(r",\s*\bsql\b", "", translated, flags=re.IGNORECASE)
translated = re.sub(r"\bsql\b\s*,\s*", "", translated, flags=re.IGNORECASE)
# If 'sql' is the only selected column, replace with a placeholder
translated = re.sub(r"SELECT\s+sql\s+FROM", "SELECT '' as sql FROM", translated, flags=re.IGNORECASE)
# Fix double "table_table_name" if it occurs
translated = translated.replace("table_table_name", "table_name")
return translated
def _replace_question_marks(sql):
"""Replace ? placeholders with %s, but not inside string literals.
Also escape literal % characters for psycopg2."""
result = []
in_single_quote = False
in_double_quote = False
i = 0
while i < len(sql):
ch = sql[i]
if ch == "'" and not in_double_quote:
# Check for escaped quote ''
if in_single_quote and i + 1 < len(sql) and sql[i + 1] == "'":
result.append("''")
i += 2
continue
in_single_quote = not in_single_quote
result.append(ch)
elif ch == '"' and not in_single_quote:
in_double_quote = not in_double_quote
result.append(ch)
elif ch == "?" and not in_single_quote and not in_double_quote:
result.append("%s")
elif ch == "%":
# Escape ALL literal % for psycopg2 (including inside quotes,
# because psycopg2's parameter substitution scans the entire string)
result.append("%%")
else:
result.append(ch)
i += 1
return "".join(result)
# ---------------------------------------------------------------------------
# Row class — sqlite3.Row compatible
# ---------------------------------------------------------------------------
class Row:
    """sqlite3.Row work-alike: positional and case-insensitive named access."""
    __slots__ = ("_values", "_keys", "_key_map")

    def __init__(self, cursor_description, values):
        if cursor_description:
            self._keys = tuple(col[0] for col in cursor_description)
        else:
            self._keys = ()
        self._values = tuple(values)
        # Lower-cased column name -> position, for case-insensitive lookup.
        self._key_map = {name.lower(): pos for pos, name in enumerate(self._keys)}

    def __getitem__(self, key):
        # Integers and slices index positionally, like a tuple.
        if isinstance(key, (int, slice)):
            return self._values[key]
        pos = self._key_map.get(key.lower())
        if pos is None:
            raise IndexError(f"No column named '{key}'")
        return self._values[pos]

    def __len__(self):
        return len(self._values)

    def __iter__(self):
        return iter(self._values)

    def __repr__(self):
        mapping = dict(zip(self._keys, self._values))
        return f"<Row {mapping}>"

    def __eq__(self, other):
        # Compare by values against other Rows or plain tuples.
        if isinstance(other, Row):
            other = other._values
        if isinstance(other, tuple):
            return self._values == other
        return NotImplemented

    def __hash__(self):
        return hash(self._values)

    def keys(self):
        """Return the column names as a list (sqlite3.Row API)."""
        return list(self._keys)
# ---------------------------------------------------------------------------
# PgCursor — wraps psycopg2 cursor with SQL translation
# ---------------------------------------------------------------------------
class PgCursor:
    """sqlite3-compatible cursor backed by psycopg2.

    Every statement passes through _translate_sql() (SQLite dialect →
    PostgreSQL) before reaching the underlying psycopg2 cursor. Rows are
    wrapped in Row objects when the owning connection's row_factory is Row.
    """

    def __init__(self, pg_conn, row_factory=None):
        self._pg_conn = pg_conn
        self._cursor = pg_conn.cursor()
        self._row_factory = row_factory  # Row enables name-based access
        self._lastrowid = None           # populated via RETURNING id on INSERTs
        self._arraysize = 1

    @property
    def description(self):
        return self._cursor.description

    @property
    def rowcount(self):
        return self._cursor.rowcount

    @property
    def lastrowid(self):
        return self._lastrowid

    @property
    def arraysize(self):
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._arraysize = value

    def execute(self, sql, parameters=None):
        """Execute *sql* (SQLite dialect) with automatic translation.

        Returns self, like sqlite3.Cursor.execute. UNIQUE violations are
        swallowed for INSERT OR IGNORE / OR REPLACE statements (matching
        SQLite's silent behavior) and surfaced as IntegrityError otherwise.
        """
        # PK metadata is needed by the INSERT OR REPLACE translation and the
        # RETURNING-id emulation below.
        _load_pk_cache(self._pg_conn)
        translated, is_noop, extra_params = _translate_sql(sql)
        if is_noop:
            # Execute a harmless empty query so description/rowcount stay valid.
            self._cursor.execute("SELECT 1 WHERE false")
            return self
        # Normalise parameters: dicts pass through, sequences become tuples.
        if parameters is None:
            params = None
        elif isinstance(parameters, dict):
            params = parameters
        else:
            params = tuple(parameters)
        # Prepend any parameters synthesised by the SQL translation itself.
        if extra_params is not None:
            if params is None:
                params = extra_params
            elif isinstance(params, tuple):
                params = extra_params + params
        # Emulate sqlite3's lastrowid: append RETURNING id when the target
        # table's primary key is exactly a single 'id' column.
        needs_returning = False
        translated_upper = translated.strip().upper()
        if translated_upper.startswith("INSERT") and "RETURNING" not in translated_upper:
            table_match = re.search(r"INSERT\s+INTO\s+(\w+)", translated, re.IGNORECASE)
            if table_match and _pk_cache.get(table_match.group(1).lower(), []) == ['id']:
                needs_returning = True
                translated = translated.rstrip().rstrip(";") + " RETURNING id"
        # BUG FIX: detect OR IGNORE / OR REPLACE on the ORIGINAL statement.
        # The translated text no longer contains those tokens (they were
        # rewritten to ON CONFLICT ...), so the old check against the
        # translated SQL always evaluated to False.
        original_upper = sql.strip().upper()
        swallow_unique = (
            original_upper.startswith("INSERT")
            and ("OR IGNORE" in original_upper or "OR REPLACE" in original_upper)
        )
        try:
            self._cursor.execute(translated, params)
        except (psycopg2.errors.DuplicateColumn, psycopg2.errors.DuplicateTable,
                psycopg2.errors.DuplicateObject):
            # Schema conflicts (re-running DDL) — match SQLite's silent behavior.
            try:
                self._pg_conn.rollback()
            except Exception:
                pass
            return self
        except psycopg2.errors.UniqueViolation:
            try:
                self._pg_conn.rollback()
            except Exception:
                pass
            if swallow_unique:
                return self
            # Plain INSERT: re-raise as sqlite3.IntegrityError so callers can
            # detect duplicates. When this module is monkey-patched into
            # sys.modules['sqlite3'] the import resolves back to this module.
            import sqlite3 as _sqlite3
            raise _sqlite3.IntegrityError("UNIQUE constraint failed")
        if needs_returning:
            # May be None when ON CONFLICT DO NOTHING suppressed the insert.
            row = self._cursor.fetchone()
            if row:
                self._lastrowid = row[0]
        return self

    def executemany(self, sql, seq_of_parameters):
        """Execute the translated statement once per parameter set."""
        _load_pk_cache(self._pg_conn)
        translated, is_noop, extra_params = _translate_sql(sql)
        if is_noop:
            return self
        for params in seq_of_parameters:
            if isinstance(params, dict):
                self._cursor.execute(translated, params)
            else:
                p = tuple(params)
                if extra_params is not None:
                    p = extra_params + p
                self._cursor.execute(translated, p)
        return self

    def executescript(self, sql_script):
        """Execute multiple SQL statements separated by semicolons.

        NOTE: naive split — semicolons inside string literals will break
        the statement; acceptable for the schema scripts this is used with.
        """
        for stmt in sql_script.split(";"):
            stmt = stmt.strip()
            if stmt:
                self.execute(stmt)
        return self

    def fetchone(self):
        row = self._cursor.fetchone()
        if row is None:
            return None
        if self._row_factory is Row:
            return Row(self._cursor.description, row)
        return row

    def fetchall(self):
        rows = self._cursor.fetchall()
        if self._row_factory is Row:
            desc = self._cursor.description
            return [Row(desc, r) for r in rows]
        return rows

    def fetchmany(self, size=None):
        if size is None:
            size = self._arraysize
        rows = self._cursor.fetchmany(size)
        if self._row_factory is Row:
            desc = self._cursor.description
            return [Row(desc, r) for r in rows]
        return rows

    def close(self):
        try:
            self._cursor.close()
        except Exception:
            pass  # already closed / connection gone — mirror sqlite3 leniency

    def __iter__(self):
        return self

    def __next__(self):
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row
# ---------------------------------------------------------------------------
# PgConnection — wraps psycopg2 connection with sqlite3-compatible API
# ---------------------------------------------------------------------------
class PgConnection:
    """sqlite3-compatible connection backed by psycopg2.

    Connections are borrowed from the module-level pool, validated, and run
    in autocommit mode by default so each statement commits immediately —
    emulating how this codebase used sqlite3. Explicit BEGIN statements are
    tracked via _in_transaction so commit()/rollback() can issue the matching
    COMMIT/ROLLBACK as plain SQL.
    """
    def __init__(self, database=":memory:", timeout=5.0):
        # `database` and `timeout` are accepted for sqlite3 signature
        # compatibility; the actual target comes from the pool's DSN.
        self._database = database
        self._row_factory = None
        self._isolation_level = ""
        self._in_transaction = False
        self._pg_conn = self._get_valid_conn()
        self._pg_conn.autocommit = True # Match sqlite3 isolation_level=None behavior

    @staticmethod
    def _check_conn(conn):
        """Test if a connection is alive. Returns True if healthy."""
        try:
            if conn.closed:
                return False
            # NOTE(review): isolation_level is read locally by psycopg2; it is
            # the SELECT 1 below that actually exercises the server — confirm.
            conn.isolation_level
            cur = conn.cursor()
            cur.execute("SELECT 1")
            cur.fetchone()
            cur.close()
            # Roll back any implicit transaction the health check may have opened
            if conn.status != psycopg2.extensions.STATUS_READY:
                conn.rollback()
            return True
        except Exception:
            return False

    def _get_valid_conn(self):
        """Get a connection from the pool, validating it's alive."""
        pool = _get_pool()
        conn = pool.getconn()
        if self._check_conn(conn):
            return conn
        # Connection is stale — discard it and get a fresh one
        try:
            pool.putconn(conn, close=True)
        except Exception:
            try:
                conn.close()
            except Exception:
                pass
        conn = pool.getconn()
        if not self._check_conn(conn):
            # Second failure in a row: give up rather than loop forever.
            raise OperationalError("Unable to obtain a valid database connection from pool")
        return conn

    def _ensure_conn(self):
        """Ensure the connection is still alive, reconnect if not."""
        if self._pg_conn is None or self._pg_conn.closed or not self._check_conn(self._pg_conn):
            # Connection gone — return it and get a new one
            if self._pg_conn is not None:
                try:
                    pool = _get_pool()
                    pool.putconn(self._pg_conn, close=True)
                except Exception:
                    try:
                        self._pg_conn.close()
                    except Exception:
                        pass
            self._pg_conn = self._get_valid_conn()
            self._pg_conn.autocommit = True

    @property
    def row_factory(self):
        # Assigning sqlite3.Row (this module's Row) makes cursors wrap rows.
        return self._row_factory

    @row_factory.setter
    def row_factory(self, value):
        self._row_factory = value

    @property
    def isolation_level(self):
        return self._isolation_level

    @isolation_level.setter
    def isolation_level(self, value):
        # sqlite3 semantics: None → autocommit; anything else → manual
        # transactions (psycopg2 autocommit off). "" is treated as autocommit
        # here, unlike real sqlite3 — the codebase relies on per-statement
        # commits.
        self._isolation_level = value
        if value is None or value == "":
            self._pg_conn.autocommit = True
        else:
            self._pg_conn.autocommit = False

    @property
    def in_transaction(self):
        return self._in_transaction

    @property
    def total_changes(self):
        return 0 # Not tracked

    def cursor(self):
        """Return a new PgCursor, reconnecting first if the link died."""
        self._ensure_conn()
        return PgCursor(self._pg_conn, row_factory=self._row_factory)

    def execute(self, sql, parameters=None):
        """Convenience method — creates a cursor, executes, returns it."""
        # Track explicit BEGIN for proper commit/rollback handling
        # (the statement itself becomes a no-op in _translate_sql).
        sql_stripped = sql.strip().upper()
        if sql_stripped.startswith("BEGIN"):
            self._in_transaction = True
        cur = self.cursor()
        cur.execute(sql, parameters)
        return cur

    def executemany(self, sql, seq_of_parameters):
        cur = self.cursor()
        cur.executemany(sql, seq_of_parameters)
        return cur

    def executescript(self, sql_script):
        cur = self.cursor()
        cur.executescript(sql_script)
        return cur

    def commit(self):
        if self._pg_conn is not None:
            if self._in_transaction:
                # Explicit transaction started by BEGIN — send COMMIT as SQL
                # (psycopg2.commit() is a no-op when autocommit=True)
                cur = self._pg_conn.cursor()
                try:
                    cur.execute("COMMIT")
                except Exception:
                    pass
                finally:
                    cur.close()
            else:
                self._pg_conn.commit()
        self._in_transaction = False

    def rollback(self):
        if self._pg_conn is not None:
            if self._in_transaction:
                # Mirror of commit(): ROLLBACK must be sent as SQL in autocommit.
                cur = self._pg_conn.cursor()
                try:
                    cur.execute("ROLLBACK")
                except Exception:
                    pass
                finally:
                    cur.close()
            else:
                self._pg_conn.rollback()
        self._in_transaction = False

    def close(self):
        """Return the underlying connection to the pool (does not close it)."""
        if self._pg_conn is not None:
            try:
                pool = _get_pool()
                pool.putconn(self._pg_conn)
            except Exception:
                # Pool rejected it (e.g. already closed) — close outright.
                try:
                    self._pg_conn.close()
                except Exception:
                    pass
            self._pg_conn = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # sqlite3 context-manager semantics: commit on success, rollback on
        # error; the connection itself stays open.
        if exc_type is not None:
            self.rollback()
        else:
            self.commit()
        return False

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass # Suppress errors during GC teardown

    # --- Properties for compatibility ---
    def set_trace_callback(self, callback):
        pass # No-op

    def create_function(self, name, num_params, func, deterministic=False):
        pass # No-op — PG has its own functions

    def create_aggregate(self, name, num_params, aggregate_class):
        pass # No-op

    def create_collation(self, name, callable_):
        pass # No-op

    def interrupt(self):
        # Best-effort server-side cancel of the running query.
        if self._pg_conn:
            self._pg_conn.cancel()

    def set_authorizer(self, authorizer_callback):
        pass # No-op

    def set_progress_handler(self, handler, n):
        pass # No-op

    def enable_load_extension(self, enabled):
        pass # No-op

    def load_extension(self, path):
        pass # No-op

    def iterdump(self):
        raise NotSupportedError("iterdump() not supported in PostgreSQL adapter")

    def backup(self, target, *, pages=-1, progress=None, name="main", sleep=0.250):
        raise NotSupportedError("backup() not supported in PostgreSQL adapter")
# ---------------------------------------------------------------------------
# sqlite3 API type aliases
# ---------------------------------------------------------------------------
# Expose the sqlite3 public class names so isinstance checks and annotations
# against sqlite3.Connection / sqlite3.Cursor keep working after the patch.
Connection = PgConnection
Cursor = PgCursor