#!/usr/bin/env python3
|
|
"""
|
|
PostgreSQL Adapter — Drop-in sqlite3 replacement module.
|
|
|
|
When DATABASE_BACKEND=postgresql, this module is monkey-patched into sys.modules['sqlite3'].
|
|
Every existing `import sqlite3; sqlite3.connect(...)` transparently routes through psycopg2
|
|
with automatic SQL dialect translation. Zero existing queries need to change.
|
|
|
|
Architecture:
|
|
connect(db_path) → PgConnection(psycopg2_conn)
|
|
→ PgCursor.execute(sql, params)
|
|
→ _translate_sql(sql) [cached]
|
|
→ psycopg2 cursor
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import threading
|
|
from collections import OrderedDict
|
|
|
|
import psycopg2
|
|
import psycopg2.errors
|
|
import psycopg2.extensions
|
|
import psycopg2.pool
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level constants (sqlite3 API compatibility)
|
|
# ---------------------------------------------------------------------------
|
|
# sqlite3 detect_types flags — accepted for signature compatibility only;
# PostgreSQL/psycopg2 performs its own type conversion.
PARSE_DECLTYPES = 1
PARSE_COLNAMES = 2
# sqlite3 C result codes (compat constants, never produced by this adapter).
SQLITE_OK = 0
SQLITE_ERROR = 1
# Version metadata reported to callers that introspect the sqlite3 module.
sqlite_version = "3.45.0"
sqlite_version_info = (3, 45, 0)
version = "2.6.0"
version_info = (2, 6, 0)
# DB-API 2.0 metadata. paramstyle stays "qmark" because callers keep writing
# '?' placeholders; translation to psycopg2's %s happens in
# _replace_question_marks().
paramstyle = "qmark"
apilevel = "2.0"
threadsafety = 1

# Re-export exception types so `sqlite3.OperationalError` etc. work.
# NOTE: `Warning` intentionally shadows the builtin — it mirrors the
# sqlite3 module's public API.
Error = psycopg2.Error
OperationalError = psycopg2.OperationalError
IntegrityError = psycopg2.IntegrityError
InterfaceError = psycopg2.InterfaceError
DatabaseError = psycopg2.DatabaseError
DataError = psycopg2.DataError
InternalError = psycopg2.InternalError
NotSupportedError = psycopg2.NotSupportedError
ProgrammingError = psycopg2.ProgrammingError
Warning = psycopg2.Warning
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connection pool (singleton)
|
|
# ---------------------------------------------------------------------------
|
|
# SECURITY NOTE(review): the fallback DSN embeds a real-looking password in
# source. It is only used when DATABASE_URL is unset, but this credential
# should be rotated and the fallback removed (fail fast instead) — confirm
# with the deployment owners.
_DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://media_downloader:PNsihOXvvuPwWiIvGlsc9Fh2YmMmB@localhost/media_downloader",
)
_pool = None  # lazily-created ThreadedConnectionPool singleton (see _get_pool)
_pool_lock = threading.Lock()  # guards one-time creation of _pool

# Adapter/converter registries (sqlite3 compat — mostly no-ops for PG).
_adapters = {}    # type -> adapter callable; stored but never consulted
_converters = {}  # typename -> converter callable; stored but never consulted
|
|
|
|
|
|
def register_adapter(type_, adapter):
    """Record *adapter* for *type_* (sqlite3 API shim).

    psycopg2 adapts Python types natively, so the registry exists only for
    interface compatibility and is never consulted.
    """
    _adapters[type_] = adapter
|
|
|
|
|
|
def register_converter(typename, converter):
    """Record *converter* for *typename* (sqlite3 API shim).

    psycopg2 converts result types natively, so the registry exists only for
    interface compatibility and is never consulted.
    """
    _converters[typename] = converter
|
|
|
|
|
|
def _get_pool():
    """Return the process-wide connection pool, creating it on first use.

    Double-checked locking keeps the common (already-created) path lock-free
    while guaranteeing a single pool is built under concurrency.
    """
    global _pool
    if _pool is not None:
        return _pool
    with _pool_lock:
        if _pool is None:
            _pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=2,
                maxconn=30,
                dsn=_DATABASE_URL,
            )
    return _pool
|
|
|
|
|
|
def connect(database=":memory:", timeout=5.0, detect_types=0,
            isolation_level="", check_same_thread=True,
            factory=None, cached_statements=128, uri=False):
    """Drop-in replacement for sqlite3.connect().

    Every argument other than *database* and *timeout* is accepted purely
    for signature compatibility and ignored; a pooled PgConnection is
    returned regardless of the file path given.
    """
    return PgConnection(database, timeout)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unique-constraint cache (for INSERT OR REPLACE translation)
|
|
# ---------------------------------------------------------------------------
|
|
# Conflict-target cache used when translating INSERT OR REPLACE into
# PostgreSQL "ON CONFLICT (...) DO UPDATE".
_pk_cache = {}  # table_name -> [column_names] (PK preferred, else UNIQUE)
_pk_cache_lock = threading.Lock()  # guards the one-time load in _load_pk_cache
_pk_cache_loaded = False  # set True once _load_pk_cache() has populated _pk_cache
|
|
|
|
|
|
def _load_pk_cache(pg_conn):
    """Populate _pk_cache from information_schema (runs once per process).

    For each public table, record its PRIMARY KEY columns; when a table has
    no PK, fall back to the columns of exactly ONE unique constraint.
    Thread-safe via double-checked locking on _pk_cache_lock.

    Bug fix: the previous version merged the columns of ALL unique
    constraints of a table into a single list, which yields an invalid
    ON CONFLICT target for tables with more than one unique constraint.
    Columns are now grouped per constraint name and only the first unique
    constraint (in information_schema ordering) is kept.
    """
    global _pk_cache_loaded
    if _pk_cache_loaded:
        return
    with _pk_cache_lock:
        if _pk_cache_loaded:
            return
        cur = pg_conn.cursor()
        try:
            cur.execute("""
                SELECT tc.table_name, kcu.column_name, tc.constraint_type,
                       tc.constraint_name
                FROM information_schema.table_constraints tc
                JOIN information_schema.key_column_usage kcu
                  ON tc.constraint_name = kcu.constraint_name
                 AND tc.table_schema = kcu.table_schema
                WHERE tc.table_schema = 'public'
                  AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
                ORDER BY tc.table_name, tc.constraint_type, kcu.ordinal_position
            """)
            pk_map = {}    # table -> [pk columns, in ordinal order]
            uniq_map = {}  # table -> OrderedDict(constraint_name -> [columns])
            for table, col, ctype, cname in cur.fetchall():
                if ctype == "PRIMARY KEY":
                    pk_map.setdefault(table, []).append(col)
                else:
                    uniq_map.setdefault(table, OrderedDict()) \
                            .setdefault(cname, []).append(col)
        finally:
            cur.close()
        # Prefer the PK; otherwise use the FIRST unique constraint only.
        for t in set(pk_map) | set(uniq_map):
            if t in pk_map:
                _pk_cache[t] = pk_map[t]
            else:
                _pk_cache[t] = next(iter(uniq_map[t].values()))
        _pk_cache_loaded = True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SQL Translation Engine (cached)
|
|
# ---------------------------------------------------------------------------
|
|
# LRU cache of translations: original SQL -> (translated, is_noop, extra_params).
_sql_cache = OrderedDict()
_SQL_CACHE_MAX = 4096  # max cached statements; oldest evicted beyond this
_sql_cache_lock = threading.Lock()  # _sql_cache is shared across threads

# Pre-compiled regexes

# PRAGMAs with no PostgreSQL equivalent — executed as no-ops.
_RE_PRAGMA = re.compile(
    r"^\s*PRAGMA\s+(journal_mode|busy_timeout|synchronous|foreign_keys|wal_checkpoint|cache_size|temp_store)\b",
    re.IGNORECASE,
)
# PRAGMA table_info(t) — rewritten to an information_schema.columns query.
_RE_PRAGMA_TABLE_INFO = re.compile(
    r"^\s*PRAGMA\s+table_info\s*\(\s*['\"]?(\w+)['\"]?\s*\)",
    re.IGNORECASE,
)
# INSERT OR IGNORE → INSERT ... ON CONFLICT DO NOTHING.
_RE_INSERT_OR_IGNORE = re.compile(
    r"\bINSERT\s+OR\s+IGNORE\s+INTO\b",
    re.IGNORECASE,
)
# INSERT OR REPLACE → upsert; group(1) captures the table name.
_RE_INSERT_OR_REPLACE = re.compile(
    r"\bINSERT\s+OR\s+REPLACE\s+INTO\s+(\w+)\b",
    re.IGNORECASE,
)
# datetime('now'[, 'localtime']) → NOW().
_RE_DATETIME_NOW = re.compile(
    r"datetime\s*\(\s*'now'\s*(?:,\s*'localtime'\s*)?\)",
    re.IGNORECASE,
)
# datetime('now', '<n> <unit>') → NOW() + INTERVAL; must run BEFORE the
# plain datetime('now') rule (see _translate_sql ordering).
_RE_DATETIME_NOW_INTERVAL = re.compile(
    r"datetime\s*\(\s*'now'\s*(?:,\s*'localtime'\s*)?,\s*'(-?\d+)\s+(day|days|hour|hours|minute|minutes|second|seconds|month|months)'\s*\)",
    re.IGNORECASE,
)
# datetime('now', ? || ' unit') — parameterised offset form.
_RE_DATETIME_NOW_PARAM_INTERVAL = re.compile(
    r"datetime\s*\(\s*'now'\s*,\s*\?\s*\|\|\s*'\s*(day|days|hour|hours|minute|minutes|second|seconds)'\s*\)",
    re.IGNORECASE,
)
# date('now') → CURRENT_DATE.
_RE_DATE_NOW = re.compile(
    r"date\s*\(\s*'now'\s*\)",
    re.IGNORECASE,
)
# date('now', '<n> days') → CURRENT_DATE + INTERVAL.
_RE_DATE_NOW_INTERVAL = re.compile(
    r"date\s*\(\s*'now'\s*,\s*'(-?\d+)\s+(day|days)'\s*\)",
    re.IGNORECASE,
)
# Two-argument scalar MIN/MAX (SQLite) → LEAST/GREATEST (PG). The argument
# patterns deliberately exclude parentheses/commas to avoid matching the
# aggregate MIN(col)/MAX(col).
_RE_SCALAR_MIN = re.compile(
    r"\bMIN\s*\(\s*([^,()]+)\s*,\s*([^()]+)\s*\)",
    re.IGNORECASE,
)
_RE_SCALAR_MAX = re.compile(
    r"\bMAX\s*\(\s*([^,()]+)\s*,\s*([^()]+)\s*\)",
    re.IGNORECASE,
)
# json_extract(col, '$.key') → col::jsonb->>'key' (top-level keys only).
_RE_JSON_EXTRACT = re.compile(
    r"json_extract\s*\(\s*(\w+(?:\.\w+)?)\s*,\s*'\$\.(\w+)'\s*\)",
    re.IGNORECASE,
)
# strftime('%format', expr) → TO_CHAR(expr, 'pg_format')
# Uses a function to handle nested parentheses instead of a simple regex;
# this prefix only anchors the call head, the matching ')' is found by hand.
_RE_STRFTIME_PREFIX = re.compile(
    r"strftime\s*\(\s*'([^']+)'\s*,\s*",
    re.IGNORECASE,
)
# datetime(column_expr) → cast as timestamp (not datetime('now') which is handled separately)
_RE_DATETIME_COLUMN = re.compile(
    r"datetime\s*\(\s*(\w+(?:\.\w+)?|\?)\s*\)",
    re.IGNORECASE,
)
# Columns stored as real BOOLEAN in PG but compared with 0/1 by SQLite-era
# callers; the "= 1"/"= 0" rewrites apply only to these names.
_BOOLEAN_COLUMNS = {'has_match', 'notified', 'downloaded', 'has_images'}
_RE_BOOL_EQ_1 = re.compile(
    r"\b(\w+(?:\.\w+)?)\s*=\s*1\b",
)
_RE_BOOL_EQ_0 = re.compile(
    r"\b(\w+(?:\.\w+)?)\s*=\s*0\b",
)
# GROUP_CONCAT variants → STRING_AGG (DISTINCT / explicit separator / default).
_RE_GROUP_CONCAT_DISTINCT = re.compile(
    r"GROUP_CONCAT\s*\(\s*DISTINCT\s+([\w.]+)\s*\)",
    re.IGNORECASE,
)
_RE_GROUP_CONCAT_SEP = re.compile(
    r"GROUP_CONCAT\s*\(\s*([\w.]+)\s*,\s*'([^']*)'\s*\)",
    re.IGNORECASE,
)
_RE_GROUP_CONCAT = re.compile(
    r"GROUP_CONCAT\s*\(\s*([\w.]+)\s*\)",
    re.IGNORECASE,
)
# DDL keyword/type translations.
_RE_AUTOINCREMENT = re.compile(r"\bAUTOINCREMENT\b", re.IGNORECASE)
_RE_BLOB = re.compile(r"\bBLOB\b", re.IGNORECASE)
# BEGIN IMMEDIATE has no PG equivalent; psycopg2 manages transactions.
_RE_BEGIN_IMMEDIATE = re.compile(r"^\s*BEGIN\s+IMMEDIATE\b", re.IGNORECASE)
# Queries over sqlite_master are mapped onto information_schema.tables.
_RE_SQLITE_MASTER = re.compile(
    r"\bsqlite_master\b",
    re.IGNORECASE,
)
# COLLATE NOCASE is dropped (PG collation semantics differ).
_RE_COLLATE_NOCASE = re.compile(r"\s*COLLATE\s+NOCASE\b", re.IGNORECASE)
# IFNULL( → COALESCE( — same arity, direct rename.
_RE_IFNULL = re.compile(r"\bIFNULL\s*\(", re.IGNORECASE)
# Simple strftime('%fmt', column) form (no nested parens).
_RE_STRFTIME = re.compile(
    r"strftime\s*\(\s*'([^']+)'\s*,\s*(\w+(?:\.\w+)?)\s*\)",
    re.IGNORECASE,
)
# CREATE TRIGGER bodies use SQLite syntax — skipped entirely as no-ops.
_RE_CREATE_TRIGGER = re.compile(
    r"^\s*CREATE\s+TRIGGER\b.*?\bEND\s*;?\s*$",
    re.IGNORECASE | re.DOTALL,
)
# BOOLEAN DEFAULT 0/1 → BOOLEAN DEFAULT false/true in DDL.
_RE_BOOLEAN_DEFAULT_0 = re.compile(
    r"\bBOOLEAN\s+DEFAULT\s+0\b",
    re.IGNORECASE,
)
_RE_BOOLEAN_DEFAULT_1 = re.compile(
    r"\bBOOLEAN\s+DEFAULT\s+1\b",
    re.IGNORECASE,
)


# strftime format map: SQLite → PostgreSQL TO_CHAR
# NOTE: str.replace is case-sensitive, so %m (month) and %M (minute) map
# independently as intended.
_STRFTIME_MAP = {
    "%Y": "YYYY",
    "%m": "MM",
    "%d": "DD",
    "%H": "HH24",
    "%M": "MI",
    "%S": "SS",
    "%W": "IW",
    "%j": "DDD",
}
|
|
|
|
|
|
def _convert_strftime_format(sqlite_fmt):
    """Map SQLite strftime() codes in *sqlite_fmt* to PostgreSQL TO_CHAR codes.

    Unrecognised characters pass through unchanged.
    """
    result = sqlite_fmt
    for sqlite_code, pg_code in _STRFTIME_MAP.items():
        result = result.replace(sqlite_code, pg_code)
    return result
|
|
|
|
|
|
def _translate_sql(sql):
    """
    Translate a SQLite SQL string to PostgreSQL dialect.

    Returns (translated_sql, is_noop, extra_params) where is_noop=True means
    skip execution. extra_params is reserved for translations that need to
    inject parameters ahead of the caller's; NOTE(review): no current rule
    sets it, so it is always None today.

    Results are memoised in an LRU cache keyed on the original string, so
    the regex pipeline below runs at most once per distinct statement.
    Rule ordering matters: e.g. datetime('now', interval) must be rewritten
    before the plain datetime('now') rule, and '?'→'%s' must run last.
    """
    original = sql

    # Check cache first
    with _sql_cache_lock:
        cached = _sql_cache.get(original)
        if cached is not None:
            # Move to end (LRU)
            _sql_cache.move_to_end(original)
            return cached

    is_noop = False
    extra_params = None
    translated = sql

    # --- PRAGMA handling ---
    # Check for PRAGMA table_info first (special case that returns data)
    m = _RE_PRAGMA_TABLE_INFO.match(translated)
    if m:
        # table_name is safe: regex \w+ guarantees only alphanumeric/underscore
        table_name = m.group(1)
        translated = (
            "SELECT ordinal_position - 1 as cid, column_name as name, "
            "data_type as type, "
            "CASE WHEN is_nullable = 'NO' THEN 1 ELSE 0 END as notnull, "
            "column_default as dflt_value, "
            "0 as pk "
            "FROM information_schema.columns "
            f"WHERE table_schema = 'public' AND table_name = '{table_name}' "
            "ORDER BY ordinal_position"
        )
    elif _RE_PRAGMA.match(translated):
        is_noop = True
        translated = "SELECT 1 WHERE false"

    # --- CREATE TRIGGER (skip entirely) ---
    if not is_noop and _RE_CREATE_TRIGGER.match(translated):
        is_noop = True
        translated = "SELECT 1 WHERE false"

    if not is_noop:
        # --- INSERT OR IGNORE ---
        translated = _RE_INSERT_OR_IGNORE.sub("INSERT INTO", translated)
        if "INSERT INTO" in translated and "ON CONFLICT" not in translated.upper() and original != translated:
            # Only add ON CONFLICT DO NOTHING if we actually replaced INSERT OR IGNORE
            # (original != translated is the "we replaced something" signal)
            translated = _add_on_conflict_do_nothing(translated)

        # --- INSERT OR REPLACE ---
        m_ior = _RE_INSERT_OR_REPLACE.search(original)
        if m_ior and "ON CONFLICT" not in translated.upper():
            translated = _translate_insert_or_replace(original)

        # --- datetime/date functions ---
        # Interval forms first, so the plain 'now' rule can't eat them.
        translated = _RE_DATETIME_NOW_INTERVAL.sub(
            lambda m: f"NOW() + INTERVAL '{m.group(1)} {m.group(2)}'", translated
        )
        translated = _RE_DATETIME_NOW_PARAM_INTERVAL.sub(
            lambda m: "NOW() + (INTERVAL '1 " + m.group(1).rstrip("s") + "' * ?)", translated
        )
        translated = _RE_DATETIME_NOW.sub("NOW()", translated)
        translated = _RE_DATE_NOW_INTERVAL.sub(
            lambda m: f"CURRENT_DATE + INTERVAL '{m.group(1)} {m.group(2)}'", translated
        )
        translated = _RE_DATE_NOW.sub("CURRENT_DATE", translated)

        # --- scalar MIN/MAX → LEAST/GREATEST (SQLite uses MIN/MAX for 2-arg comparison) ---
        translated = _RE_SCALAR_MIN.sub(r"LEAST(\1, \2)", translated)
        translated = _RE_SCALAR_MAX.sub(r"GREATEST(\1, \2)", translated)

        # --- datetime(column) → column::timestamp (after datetime('now') handled above) ---
        translated = _RE_DATETIME_COLUMN.sub(
            lambda m: f"{m.group(1)}::timestamp", translated
        )

        # --- strftime('%fmt', expr) → TO_CHAR(expr::timestamp, 'pg_fmt') ---
        # Hand-rolled scan because the expr may contain nested parentheses,
        # which a single regex cannot balance.
        def _translate_strftime(sql):
            result = []
            pos = 0
            while pos < len(sql):
                m = _RE_STRFTIME_PREFIX.search(sql, pos)
                if not m:
                    result.append(sql[pos:])
                    break
                result.append(sql[pos:m.start()])
                fmt = m.group(1)
                # Find matching closing paren accounting for nesting
                expr_start = m.end()
                depth = 1
                i = expr_start
                while i < len(sql) and depth > 0:
                    if sql[i] == '(':
                        depth += 1
                    elif sql[i] == ')':
                        depth -= 1
                    i += 1
                expr = sql[expr_start:i - 1].strip()
                pg_fmt = fmt
                for sqlite_code, pg_code in _STRFTIME_MAP.items():
                    pg_fmt = pg_fmt.replace(sqlite_code, pg_code)
                result.append(f"TO_CHAR(({expr})::timestamp, '{pg_fmt}')")
                pos = i
            return ''.join(result)
        translated = _translate_strftime(translated)

        # --- json_extract ---
        translated = _RE_JSON_EXTRACT.sub(
            lambda m: f"{m.group(1)}::jsonb->>'{m.group(2)}'", translated
        )

        # --- boolean = 1/0 → TRUE/FALSE for known boolean columns ---
        def _bool_eq_1(m):
            col = m.group(1).split('.')[-1]  # strip table alias
            if col in _BOOLEAN_COLUMNS:
                return f"{m.group(1)} = TRUE"
            return m.group(0)
        def _bool_eq_0(m):
            col = m.group(1).split('.')[-1]
            if col in _BOOLEAN_COLUMNS:
                return f"{m.group(1)} = FALSE"
            return m.group(0)
        translated = _RE_BOOL_EQ_1.sub(_bool_eq_1, translated)
        translated = _RE_BOOL_EQ_0.sub(_bool_eq_0, translated)

        # --- GROUP_CONCAT → STRING_AGG ---
        # DISTINCT and explicit-separator forms first so the generic rule
        # cannot partially match them.
        translated = _RE_GROUP_CONCAT_DISTINCT.sub(
            lambda m: f"STRING_AGG(DISTINCT {m.group(1)}::text, ',')", translated
        )
        translated = _RE_GROUP_CONCAT_SEP.sub(
            lambda m: f"STRING_AGG({m.group(1)}::text, '{m.group(2)}')", translated
        )
        translated = _RE_GROUP_CONCAT.sub(
            lambda m: f"STRING_AGG({m.group(1)}::text, ',')", translated
        )

        # --- IFNULL → COALESCE ---
        translated = _RE_IFNULL.sub("COALESCE(", translated)

        # --- strftime (simple no-nested-paren form that survived the scan above) ---
        translated = _RE_STRFTIME.sub(
            lambda m: f"TO_CHAR({m.group(2)}, '{_convert_strftime_format(m.group(1))}')",
            translated,
        )

        # --- Type replacements (mainly for CREATE TABLE) ---
        translated = _RE_AUTOINCREMENT.sub("", translated)
        translated = _RE_BLOB.sub("BYTEA", translated)
        translated = _RE_BOOLEAN_DEFAULT_0.sub("BOOLEAN DEFAULT false", translated)
        translated = _RE_BOOLEAN_DEFAULT_1.sub("BOOLEAN DEFAULT true", translated)
        # DATETIME → TIMESTAMP (PG doesn't have DATETIME type)
        translated = re.sub(r"\bDATETIME\b", "TIMESTAMP", translated, flags=re.IGNORECASE)
        # INTEGER PRIMARY KEY AUTOINCREMENT → SERIAL PRIMARY KEY (already removed AUTOINCREMENT above)
        translated = re.sub(
            r"\bINTEGER\s+PRIMARY\s+KEY\b",
            "SERIAL PRIMARY KEY",
            translated,
            flags=re.IGNORECASE,
        )

        # --- BEGIN / BEGIN IMMEDIATE → no-op (psycopg2 manages transactions) ---
        if _RE_BEGIN_IMMEDIATE.match(translated) or re.match(r"^\s*BEGIN\s*;?\s*$", translated, re.IGNORECASE):
            is_noop = True
            translated = "SELECT 1 WHERE false"

        # --- sqlite_master → information_schema.tables ---
        if _RE_SQLITE_MASTER.search(translated):
            translated = _translate_sqlite_master(translated)

        # --- COLLATE NOCASE → remove ---
        translated = _RE_COLLATE_NOCASE.sub("", translated)

        # --- Parameter placeholder: ? → %s ---
        # Must be done AFTER all other translations; skip string literals
        translated = _replace_question_marks(translated)

    result = (translated, is_noop, extra_params)

    # Store in cache (evicting the least-recently-used entry if full)
    with _sql_cache_lock:
        _sql_cache[original] = result
        if len(_sql_cache) > _SQL_CACHE_MAX:
            _sql_cache.popitem(last=False)

    return result
|
|
|
|
|
|
def _add_on_conflict_do_nothing(sql):
|
|
"""Append ON CONFLICT DO NOTHING to an INSERT statement."""
|
|
# Handle trailing semicolons
|
|
stripped = sql.rstrip().rstrip(";")
|
|
return stripped + " ON CONFLICT DO NOTHING"
|
|
|
|
|
|
def _translate_insert_or_replace(sql):
    """Rewrite INSERT OR REPLACE INTO ... as a PostgreSQL upsert.

    The conflict target comes from _pk_cache; whenever it is unknown, or
    the column list cannot be parsed, we degrade to ON CONFLICT DO NOTHING
    (best effort, keeping SQLite's non-raising behaviour).
    """
    match = _RE_INSERT_OR_REPLACE.search(sql)
    if match is None:
        return sql
    table = match.group(1).lower()
    rewritten = _RE_INSERT_OR_REPLACE.sub(f"INSERT INTO {table}", sql)

    conflict_cols = _pk_cache.get(table, [])
    if not conflict_cols:
        # No known PK/unique constraint — can't build a conflict target.
        return _add_on_conflict_do_nothing(rewritten)

    # Parse the explicit column list: INSERT INTO table (col1, col2, ...)
    col_list = re.search(
        r"INSERT\s+INTO\s+\w+\s*\(([^)]+)\)",
        rewritten,
        re.IGNORECASE,
    )
    if col_list is None:
        return _add_on_conflict_do_nothing(rewritten)

    insert_cols = [c.strip().strip('"').strip("'") for c in col_list.group(1).split(",")]
    conflict_lowered = [c.lower() for c in conflict_cols]
    to_update = [c for c in insert_cols if c.lower() not in conflict_lowered]
    if not to_update:
        # Every inserted column is part of the key — nothing to update.
        return _add_on_conflict_do_nothing(rewritten)

    target = ", ".join(conflict_cols)
    assignments = ", ".join(f"{c} = EXCLUDED.{c}" for c in to_update)
    base = rewritten.rstrip().rstrip(";")
    return f"{base} ON CONFLICT ({target}) DO UPDATE SET {assignments}"
|
|
|
|
|
|
def _translate_sqlite_master(sql):
|
|
"""Translate queries against sqlite_master to information_schema.tables."""
|
|
# Replace table name first
|
|
translated = re.sub(
|
|
r"\bsqlite_master\b",
|
|
"information_schema.tables",
|
|
sql,
|
|
flags=re.IGNORECASE,
|
|
)
|
|
# Replace type='table' with schema/type filter
|
|
translated = re.sub(
|
|
r"\btype\s*=\s*'table'",
|
|
"table_schema = 'public' AND table_type = 'BASE TABLE'",
|
|
translated,
|
|
flags=re.IGNORECASE,
|
|
)
|
|
# Remove sqlite_-prefixed filters
|
|
translated = re.sub(
|
|
r"\bAND\s+name\s+NOT\s+LIKE\s+'sqlite_%'",
|
|
"",
|
|
translated,
|
|
flags=re.IGNORECASE,
|
|
)
|
|
# Map 'name' column to 'table_name' everywhere (SELECT list, WHERE, etc.)
|
|
# But preserve 'table_name' if already present
|
|
translated = re.sub(r"\bname\b", "table_name", translated, flags=re.IGNORECASE)
|
|
# Remove 'sql' column references (not in information_schema)
|
|
translated = re.sub(r",\s*\bsql\b", "", translated, flags=re.IGNORECASE)
|
|
translated = re.sub(r"\bsql\b\s*,\s*", "", translated, flags=re.IGNORECASE)
|
|
# If 'sql' is the only selected column, replace with a placeholder
|
|
translated = re.sub(r"SELECT\s+sql\s+FROM", "SELECT '' as sql FROM", translated, flags=re.IGNORECASE)
|
|
# Fix double "table_table_name" if it occurs
|
|
translated = translated.replace("table_table_name", "table_name")
|
|
return translated
|
|
|
|
|
|
def _replace_question_marks(sql):
|
|
"""Replace ? placeholders with %s, but not inside string literals.
|
|
Also escape literal % characters for psycopg2."""
|
|
result = []
|
|
in_single_quote = False
|
|
in_double_quote = False
|
|
i = 0
|
|
while i < len(sql):
|
|
ch = sql[i]
|
|
if ch == "'" and not in_double_quote:
|
|
# Check for escaped quote ''
|
|
if in_single_quote and i + 1 < len(sql) and sql[i + 1] == "'":
|
|
result.append("''")
|
|
i += 2
|
|
continue
|
|
in_single_quote = not in_single_quote
|
|
result.append(ch)
|
|
elif ch == '"' and not in_single_quote:
|
|
in_double_quote = not in_double_quote
|
|
result.append(ch)
|
|
elif ch == "?" and not in_single_quote and not in_double_quote:
|
|
result.append("%s")
|
|
elif ch == "%":
|
|
# Escape ALL literal % for psycopg2 (including inside quotes,
|
|
# because psycopg2's parameter substitution scans the entire string)
|
|
result.append("%%")
|
|
else:
|
|
result.append(ch)
|
|
i += 1
|
|
return "".join(result)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Row class — sqlite3.Row compatible
|
|
# ---------------------------------------------------------------------------
|
|
class Row:
    """sqlite3.Row work-alike: positional and case-insensitive name access."""

    __slots__ = ("_values", "_keys", "_key_map")

    def __init__(self, cursor_description, values):
        # cursor_description is a DB-API description sequence; element [0]
        # of each entry is the column name.
        if cursor_description:
            self._keys = tuple(col[0] for col in cursor_description)
        else:
            self._keys = ()
        self._values = tuple(values)
        # lowercase name -> position, enabling case-insensitive lookup
        self._key_map = {name.lower(): pos for pos, name in enumerate(self._keys)}

    def __getitem__(self, key):
        if isinstance(key, (int, slice)):
            return self._values[key]
        try:
            return self._values[self._key_map[key.lower()]]
        except KeyError:
            # sqlite3.Row also raises IndexError for unknown column names
            raise IndexError(f"No column named '{key}'") from None

    def __len__(self):
        return len(self._values)

    def __iter__(self):
        return iter(self._values)

    def __repr__(self):
        return f"<Row {dict(zip(self._keys, self._values))}>"

    def __eq__(self, other):
        if isinstance(other, Row):
            return self._values == other._values
        if isinstance(other, tuple):
            return self._values == other
        return NotImplemented

    def __hash__(self):
        return hash(self._values)

    def keys(self):
        """Return column names in cursor order (sqlite3.Row API)."""
        return list(self._keys)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PgCursor — wraps psycopg2 cursor with SQL translation
|
|
# ---------------------------------------------------------------------------
|
|
class PgCursor:
    """sqlite3-compatible cursor backed by psycopg2.

    Responsibilities:
      * translate each statement from SQLite dialect to PostgreSQL via
        _translate_sql before execution;
      * emulate sqlite3 semantics: lastrowid (via RETURNING id), silent
        duplicate handling for INSERT OR IGNORE / OR REPLACE, and the
        Row factory protocol.
    """

    def __init__(self, pg_conn, row_factory=None):
        self._pg_conn = pg_conn
        self._cursor = pg_conn.cursor()
        self._row_factory = row_factory  # Row class, or None for plain tuples
        self._lastrowid = None
        self._arraysize = 1

    @property
    def description(self):
        return self._cursor.description

    @property
    def rowcount(self):
        return self._cursor.rowcount

    @property
    def lastrowid(self):
        # Populated only for INSERTs into tables whose PK is exactly ['id'].
        return self._lastrowid

    @property
    def arraysize(self):
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._arraysize = value

    def execute(self, sql, parameters=None):
        """Execute *sql* with automatic SQLite→PG translation; return self."""
        # PK metadata is needed by both the translator (upserts) and the
        # RETURNING-id logic below.
        _load_pk_cache(self._pg_conn)

        # BUG FIX: decide whether SQLite would silently swallow a duplicate
        # BEFORE translation. The previous code tested the TRANSLATED SQL,
        # but translation rewrites "OR IGNORE"/"OR REPLACE" away, so the
        # flag was always False and the silent path was unreachable.
        orig_upper = sql.strip().upper()
        swallow_duplicates = orig_upper.startswith("INSERT") and (
            "OR IGNORE" in orig_upper or "OR REPLACE" in orig_upper
        )

        translated, is_noop, extra_params = _translate_sql(sql)

        if is_noop:
            # Execute a harmless query to keep the cursor in a valid state.
            self._cursor.execute("SELECT 1 WHERE false")
            return self

        # Normalise parameters: dicts pass through, sequences become tuples.
        if parameters is None:
            params = None
        elif isinstance(parameters, dict):
            params = parameters
        else:
            params = tuple(parameters)

        # Prepend translator-generated parameters (reserved; currently the
        # translator never produces any).
        if extra_params is not None:
            if params is None:
                params = extra_params
            elif isinstance(params, tuple):
                params = extra_params + params

        # Emulate sqlite3's cursor.lastrowid by appending RETURNING id —
        # but only when the target table's primary key is exactly ['id'].
        needs_returning = False
        sql_upper = translated.strip().upper()
        if sql_upper.startswith("INSERT") and "RETURNING" not in sql_upper:
            table_match = re.search(r"INSERT\s+INTO\s+(\w+)", translated, re.IGNORECASE)
            if table_match:
                tbl = table_match.group(1).lower()
                if _pk_cache.get(tbl, []) == ['id']:
                    needs_returning = True
                    translated = translated.rstrip().rstrip(";") + " RETURNING id"

        try:
            self._cursor.execute(translated, params)
        except (psycopg2.errors.DuplicateColumn, psycopg2.errors.DuplicateTable,
                psycopg2.errors.DuplicateObject):
            # Schema object already exists — match SQLite's silent behavior.
            try:
                self._pg_conn.rollback()
            except Exception:
                pass
            return self
        except psycopg2.errors.UniqueViolation:
            try:
                self._pg_conn.rollback()
            except Exception:
                pass
            if swallow_duplicates:
                # INSERT OR IGNORE / OR REPLACE: SQLite would not raise.
                return self
            # Plain INSERT: surface as IntegrityError. This module's
            # IntegrityError IS sqlite3.IntegrityError once the adapter is
            # monkey-patched into sys.modules['sqlite3'], so existing
            # `except sqlite3.IntegrityError` handlers keep working.
            raise IntegrityError("UNIQUE constraint failed")

        if needs_returning:
            row = self._cursor.fetchone()
            if row:
                self._lastrowid = row[0]

        return self

    def executemany(self, sql, seq_of_parameters):
        """Execute the translated statement once per parameter set."""
        _load_pk_cache(self._pg_conn)
        translated, is_noop, extra_params = _translate_sql(sql)
        if is_noop:
            return self
        # No RETURNING handling here: lastrowid is meaningless for batches.
        for params in seq_of_parameters:
            if isinstance(params, dict):
                self._cursor.execute(translated, params)
            else:
                p = tuple(params)
                if extra_params is not None:
                    p = extra_params + p
                self._cursor.execute(translated, p)
        return self

    def executescript(self, sql_script):
        """Execute multiple semicolon-separated statements one by one.

        NOTE(review): the split is naive — a ';' inside a string literal
        would break the script apart; acceptable for the DDL scripts this
        adapter sees, but worth confirming against callers.
        """
        for stmt in sql_script.split(";"):
            stmt = stmt.strip()
            if stmt:
                self.execute(stmt)
        return self

    def fetchone(self):
        row = self._cursor.fetchone()
        if row is None:
            return None
        if self._row_factory is Row:
            return Row(self._cursor.description, row)
        return row

    def fetchall(self):
        rows = self._cursor.fetchall()
        if self._row_factory is Row:
            desc = self._cursor.description
            return [Row(desc, r) for r in rows]
        return rows

    def fetchmany(self, size=None):
        if size is None:
            size = self._arraysize
        rows = self._cursor.fetchmany(size)
        if self._row_factory is Row:
            desc = self._cursor.description
            return [Row(desc, r) for r in rows]
        return rows

    def close(self):
        try:
            self._cursor.close()
        except Exception:
            pass  # already closed / connection gone — nothing to do

    def __iter__(self):
        return self

    def __next__(self):
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PgConnection — wraps psycopg2 connection with sqlite3-compatible API
|
|
# ---------------------------------------------------------------------------
|
|
class PgConnection:
    """sqlite3-compatible connection backed by psycopg2.

    Borrows a connection from the module-level pool, health-checks it, and
    runs in autocommit to mirror the codebase's sqlite3 usage. close()
    returns the connection to the pool rather than closing it.
    """

    def __init__(self, database=":memory:", timeout=5.0):
        # *database* and *timeout* are kept only for sqlite3 signature
        # compatibility; the actual target comes from _DATABASE_URL.
        self._database = database
        self._row_factory = None
        self._isolation_level = ""
        self._in_transaction = False  # True after an explicit BEGIN via execute()

        self._pg_conn = self._get_valid_conn()
        self._pg_conn.autocommit = True  # Match sqlite3 isolation_level=None behavior

    @staticmethod
    def _check_conn(conn):
        """Test if a connection is alive. Returns True if healthy."""
        try:
            if conn.closed:
                return False
            # Cheap local sanity check — reads client-side state and raises
            # if the connection object is broken (no server round-trip).
            conn.isolation_level
            # Real liveness probe: a trivial query against the server.
            cur = conn.cursor()
            cur.execute("SELECT 1")
            cur.fetchone()
            cur.close()
            # Roll back any implicit transaction the health check may have opened
            if conn.status != psycopg2.extensions.STATUS_READY:
                conn.rollback()
            return True
        except Exception:
            # Any failure at all means "not usable" — caller will discard it.
            return False

    def _get_valid_conn(self):
        """Get a connection from the pool, validating it's alive.

        Retries exactly once with a fresh connection; raises
        OperationalError if the second attempt is also dead.
        """
        pool = _get_pool()
        conn = pool.getconn()
        if self._check_conn(conn):
            return conn
        # Connection is stale — discard it and get a fresh one
        try:
            pool.putconn(conn, close=True)
        except Exception:
            try:
                conn.close()
            except Exception:
                pass
        conn = pool.getconn()
        if not self._check_conn(conn):
            raise OperationalError("Unable to obtain a valid database connection from pool")
        return conn

    def _ensure_conn(self):
        """Ensure the connection is still alive, reconnect if not."""
        if self._pg_conn is None or self._pg_conn.closed or not self._check_conn(self._pg_conn):
            # Connection gone — return it to the pool (closed) and get a new one
            if self._pg_conn is not None:
                try:
                    pool = _get_pool()
                    pool.putconn(self._pg_conn, close=True)
                except Exception:
                    try:
                        self._pg_conn.close()
                    except Exception:
                        pass
            self._pg_conn = self._get_valid_conn()
            self._pg_conn.autocommit = True

    @property
    def row_factory(self):
        # Typically set to Row by callers (sqlite3.Row-style access).
        return self._row_factory

    @row_factory.setter
    def row_factory(self, value):
        self._row_factory = value

    @property
    def isolation_level(self):
        return self._isolation_level

    @isolation_level.setter
    def isolation_level(self, value):
        # sqlite3 semantics: None/"" means autocommit; any other value
        # enables psycopg2-managed transactions.
        self._isolation_level = value
        if value is None or value == "":
            self._pg_conn.autocommit = True
        else:
            self._pg_conn.autocommit = False

    @property
    def in_transaction(self):
        return self._in_transaction

    @property
    def total_changes(self):
        return 0  # Not tracked by this adapter

    def cursor(self):
        # Re-validate the underlying connection before handing out a cursor.
        self._ensure_conn()
        return PgCursor(self._pg_conn, row_factory=self._row_factory)

    def execute(self, sql, parameters=None):
        """Convenience method — creates a cursor, executes, returns it."""
        # Track explicit BEGIN for proper commit/rollback handling
        # (the translator turns BEGIN into a no-op, so only the flag matters).
        sql_stripped = sql.strip().upper()
        if sql_stripped.startswith("BEGIN"):
            self._in_transaction = True
        cur = self.cursor()
        cur.execute(sql, parameters)
        return cur

    def executemany(self, sql, seq_of_parameters):
        cur = self.cursor()
        cur.executemany(sql, seq_of_parameters)
        return cur

    def executescript(self, sql_script):
        cur = self.cursor()
        cur.executescript(sql_script)
        return cur

    def commit(self):
        if self._pg_conn is not None:
            if self._in_transaction:
                # Explicit transaction started by BEGIN — send COMMIT as SQL
                # (psycopg2.commit() is a no-op when autocommit=True)
                cur = self._pg_conn.cursor()
                try:
                    cur.execute("COMMIT")
                except Exception:
                    pass  # best effort — connection may be gone
                finally:
                    cur.close()
            else:
                self._pg_conn.commit()
        self._in_transaction = False

    def rollback(self):
        if self._pg_conn is not None:
            if self._in_transaction:
                # Mirror commit(): explicit BEGIN needs an explicit ROLLBACK.
                cur = self._pg_conn.cursor()
                try:
                    cur.execute("ROLLBACK")
                except Exception:
                    pass  # best effort — connection may be gone
                finally:
                    cur.close()
            else:
                self._pg_conn.rollback()
        self._in_transaction = False

    def close(self):
        # Return the connection to the pool (do NOT close the socket) so it
        # can be reused; fall back to a hard close if the pool rejects it.
        if self._pg_conn is not None:
            try:
                pool = _get_pool()
                pool.putconn(self._pg_conn)
            except Exception:
                try:
                    self._pg_conn.close()
                except Exception:
                    pass
            self._pg_conn = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # sqlite3 context-manager semantics: commit on success, rollback on
        # error; the connection is NOT closed here.
        if exc_type is not None:
            self.rollback()
        else:
            self.commit()
        return False

    def __del__(self):
        try:
            self.close()
        except Exception:
            pass  # Suppress errors during GC teardown

    # --- sqlite3 API surface implemented as no-ops for compatibility ---
    def set_trace_callback(self, callback):
        pass  # No-op

    def create_function(self, name, num_params, func, deterministic=False):
        pass  # No-op — PG has its own functions

    def create_aggregate(self, name, num_params, aggregate_class):
        pass  # No-op

    def create_collation(self, name, callable_):
        pass  # No-op

    def interrupt(self):
        # Cancel the currently-running server command, mirroring
        # sqlite3.Connection.interrupt().
        if self._pg_conn:
            self._pg_conn.cancel()

    def set_authorizer(self, authorizer_callback):
        pass  # No-op

    def set_progress_handler(self, handler, n):
        pass  # No-op

    def enable_load_extension(self, enabled):
        pass  # No-op

    def load_extension(self, path):
        pass  # No-op

    def iterdump(self):
        raise NotSupportedError("iterdump() not supported in PostgreSQL adapter")

    def backup(self, target, *, pages=-1, progress=None, name="main", sleep=0.250):
        raise NotSupportedError("backup() not supported in PostgreSQL adapter")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# sqlite3 API type aliases
|
|
# ---------------------------------------------------------------------------
|
|
# sqlite3 module-level type aliases so `sqlite3.Connection` / `sqlite3.Cursor`
# isinstance checks and annotations keep working after monkey-patching.
Connection = PgConnection
Cursor = PgCursor
|