Files
media-downloader/modules/task_checkpoint.py
Todd 0d7b2b1aab Initial commit
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-29 22:42:55 -04:00

296 lines
11 KiB
Python

"""
Task Checkpoint Module for Crash Recovery
Tracks progress of long-running scheduler tasks so that if the scheduler
crashes mid-task, it can resume from where it left off instead of
re-processing everything from scratch.
Uses the scheduler_state database (a SQLite file, accessed via the stdlib sqlite3 module).
"""
import json
import sqlite3
import threading
import time
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Callable, List, Optional, Set
from modules.universal_logger import get_logger
logger = get_logger('TaskCheckpoint')
# Path to the scheduler state database file, resolved relative to this
# module: <project root>/database/scheduler_state.db
_SCHEDULER_DB_PATH: Path = Path(__file__).parent.parent / 'database' / 'scheduler_state.db'
# How many completed items to buffer in memory before flushing to DB
# (batches writes so each mark_completed() is not a disk round-trip)
_FLUSH_INTERVAL: int = 5
# Stale checkpoint threshold (hours) — abandon checkpoints older than this
# NOTE(review): not referenced anywhere in this module; presumably consumed
# by the scheduler that calls TaskCheckpoint.abandon() — verify before removing
STALE_THRESHOLD_HOURS: int = 48
class TaskCheckpoint:
    """Track progress of a scheduler task for crash recovery.

    Completed item IDs are persisted as a JSON array in the
    ``scheduler_task_checkpoints`` table so that a crashed task can resume
    without re-processing items it already finished.

    Usage::
        checkpoint = TaskCheckpoint('instagram_unified:all')
        checkpoint.start(total_items=len(accounts))
        for account in accounts:
            if checkpoint.is_completed(account['username']):
                continue
            checkpoint.set_current(account['username'])
            process(account)
            checkpoint.mark_completed(account['username'])
        checkpoint.finish()
    """

    def __init__(self, task_id: str, task_type: str = 'scraping'):
        """
        Args:
            task_id: Unique identifier of the task (row key in the DB).
            task_type: Free-form category stored alongside the checkpoint.
        """
        self.task_id = task_id
        self.task_type = task_type
        self._started = False           # True between start() and finish()
        self._recovering = False        # True when resuming after a crash
        self._completed_items: Set[str] = set()
        # Completed items not yet written to the DB (flushed in batches).
        self._pending_flush: List[str] = []
        self._current_item: Optional[str] = None
        self._total_items: int = 0
        # Guards _completed_items / _pending_flush. Non-reentrant, so it is
        # never held across DB I/O (see mark_completed / _flush).
        self._lock = threading.Lock()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def start(self, total_items: int = 0) -> None:
        """Create or resume a checkpoint record.

        If a prior 'running' checkpoint exists for this task_id (left behind
        by a crash), we load the completed items from it and set recovery
        mode; otherwise a fresh record is inserted.
        """
        self._total_items = total_items
        self._started = True
        existing = self._load_existing()
        if existing is not None:
            # Resuming from a crash — already-completed items will be skipped.
            self._completed_items = existing
            self._recovering = True
            logger.info(
                f"Resuming checkpoint for {self.task_id}: "
                f"{len(self._completed_items)}/{total_items} items already completed",
                module='Checkpoint',
            )
        else:
            # Fresh run
            self._completed_items = set()
            self._recovering = False
            self._create_record(total_items)

    def is_recovering(self) -> bool:
        """True if we are resuming from a prior crash."""
        return self._recovering

    def is_completed(self, item_id: str) -> bool:
        """Check whether *item_id* was already processed in a previous run."""
        return str(item_id) in self._completed_items

    def get_remaining(self, items: list, key_fn: Callable) -> list:
        """Return only items not yet completed.

        Args:
            items: Full list of items.
            key_fn: Function that extracts the item key from each element.
        """
        done = self._completed_items
        return [item for item in items if str(key_fn(item)) not in done]

    def set_current(self, item_id: str) -> None:
        """Record which item is currently being processed (for crash diagnostics)."""
        self._current_item = str(item_id)
        self._update_current_item()

    def mark_completed(self, item_id: str) -> None:
        """Mark an item as done. Batches DB writes every _FLUSH_INTERVAL items."""
        item_id = str(item_id)
        with self._lock:
            self._completed_items.add(item_id)
            self._pending_flush.append(item_id)
            should_flush = len(self._pending_flush) >= _FLUSH_INTERVAL
        # _flush() re-acquires the (non-reentrant) lock, so it MUST be called
        # after releasing it here — calling it inside the block would deadlock.
        if should_flush:
            self._flush()

    def finish(self) -> None:
        """Task completed successfully — delete the checkpoint record."""
        if not self._started:
            return
        self._flush()  # persist any still-buffered items first
        self._delete_record()
        self._started = False

    def finish_if_started(self) -> None:
        """No-op if start() was never called; otherwise calls finish()."""
        if self._started:
            self.finish()

    # ------------------------------------------------------------------
    # Class methods for discovery
    # ------------------------------------------------------------------
    @classmethod
    def get_interrupted(cls) -> list:
        """Find checkpoint records left behind by crashed tasks.

        Returns a list of dicts with keys:
            task_id, task_type, started_at, completed_count, total_items,
            current_item
        """
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT task_id, task_type, started_at, completed_items, "
                    "total_items, current_item FROM scheduler_task_checkpoints "
                    "WHERE status = 'running'"
                )
                rows = cursor.fetchall()
                results = []
                for task_id, task_type, started_at, completed_json, total_items, current_item in rows:
                    completed = cls._parse_completed_json(completed_json)
                    results.append({
                        'task_id': task_id,
                        'task_type': task_type,
                        'started_at': started_at,
                        'completed_count': len(completed),
                        'total_items': total_items or 0,
                        'current_item': current_item,
                    })
                return results
        except Exception as e:
            # A missing table just means no checkpoint was ever written.
            if 'no such table' not in str(e).lower():
                logger.warning(f"Error reading interrupted checkpoints: {e}", module='Checkpoint')
            return []

    @classmethod
    def abandon(cls, task_id: str) -> None:
        """Mark a checkpoint as abandoned (e.g. task no longer registered)."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints SET status = 'abandoned', "
                    "updated_at = ? WHERE task_id = ?",
                    (datetime.now().isoformat(), task_id),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error abandoning checkpoint {task_id}: {e}", module='Checkpoint')

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _load_existing(self) -> Optional[Set[str]]:
        """Load completed items from an existing 'running' checkpoint, or return None."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT completed_items FROM scheduler_task_checkpoints "
                    "WHERE task_id = ? AND status = 'running'",
                    (self.task_id,),
                )
                row = cursor.fetchone()
                if row is None:
                    return None
                return self._parse_completed_json(row[0])
        except Exception as e:
            # Missing table == schema not created yet; treat as a fresh run.
            if 'no such table' not in str(e).lower():
                logger.warning(f"Error loading checkpoint for {self.task_id}: {e}", module='Checkpoint')
            return None

    def _create_record(self, total_items: int) -> None:
        """Insert a fresh checkpoint row (or replace an existing abandoned one)."""
        try:
            # One timestamp for both columns so started_at == updated_at
            # on a brand-new record.
            now = datetime.now().isoformat()
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "INSERT OR REPLACE INTO scheduler_task_checkpoints "
                    "(task_id, task_type, started_at, completed_items, current_item, "
                    "total_items, status, updated_at) "
                    "VALUES (?, ?, ?, '[]', NULL, ?, 'running', ?)",
                    (self.task_id, self.task_type, now, total_items, now),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error creating checkpoint for {self.task_id}: {e}", module='Checkpoint')

    def _flush(self) -> None:
        """Write pending completed items to the database.

        Must be called WITHOUT holding self._lock (it acquires the lock
        itself). On a DB error the pending batch is re-queued so a later
        flush retries it instead of silently dropping progress; retrying is
        safe because each flush writes the full completed set.
        """
        with self._lock:
            if not self._pending_flush:
                return
            items_snapshot = list(self._completed_items)
            batch = self._pending_flush[:]
            self._pending_flush.clear()
        try:
            completed_json = json.dumps(items_snapshot)
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints "
                    "SET completed_items = ?, total_items = ?, updated_at = ? "
                    "WHERE task_id = ?",
                    (
                        completed_json,
                        self._total_items,
                        datetime.now().isoformat(),
                        self.task_id,
                    ),
                )
                conn.commit()
        except Exception as e:
            # Put the batch back so the next flush (or finish()) retries.
            with self._lock:
                self._pending_flush.extend(batch)
            logger.warning(f"Error flushing checkpoint for {self.task_id}: {e}", module='Checkpoint')

    def _update_current_item(self) -> None:
        """Update the current_item column for crash diagnostics (best-effort)."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints "
                    "SET current_item = ?, updated_at = ? WHERE task_id = ?",
                    (self._current_item, datetime.now().isoformat(), self.task_id),
                )
                conn.commit()
        except Exception:
            # Non-critical — purely diagnostic; never fail the task over this.
            pass

    def _delete_record(self) -> None:
        """Remove the checkpoint row on successful completion."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "DELETE FROM scheduler_task_checkpoints WHERE task_id = ?",
                    (self.task_id,),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error deleting checkpoint for {self.task_id}: {e}", module='Checkpoint')

    @staticmethod
    def _parse_completed_json(raw: str) -> Set[str]:
        """Parse JSON array of completed item IDs, tolerating corruption."""
        if not raw:
            return set()
        try:
            items = json.loads(raw)
            if isinstance(items, list):
                return {str(i) for i in items}
        except (json.JSONDecodeError, TypeError):
            logger.warning("Corrupted checkpoint data — starting fresh (scrapers deduplicate)", module='Checkpoint')
        return set()