"""
Task Checkpoint Module for Crash Recovery

Tracks progress of long-running scheduler tasks so that if the scheduler
crashes mid-task, it can resume from where it left off instead of
re-processing everything from scratch.

Uses the scheduler_state database (PostgreSQL via pgadapter).

NOTE(review): this module itself only calls ``sqlite3.connect`` on
``_SCHEDULER_DB_PATH``; any routing to PostgreSQL would have to happen
outside this file — confirm.
"""

import json
import sqlite3
import threading
import time
from contextlib import closing
from datetime import datetime
from pathlib import Path
from typing import Callable, List, Optional, Set

from modules.universal_logger import get_logger

logger = get_logger('TaskCheckpoint')

# Path to the scheduler state database
_SCHEDULER_DB_PATH = Path(__file__).parent.parent / 'database' / 'scheduler_state.db'

# How many items to buffer before flushing to DB
_FLUSH_INTERVAL = 5

# Stale checkpoint threshold (hours) — abandon checkpoints older than this
# (not referenced inside this module; kept for importers)
STALE_THRESHOLD_HOURS = 48


class TaskCheckpoint:
    """Track progress of a scheduler task for crash recovery.

    Completed item IDs are kept in memory and periodically written, as a
    JSON array, to the ``scheduler_task_checkpoints`` table. A row with
    ``status = 'running'`` that is still present when ``start()`` is called
    is treated as the leftover of a crashed run.

    Usage::

        checkpoint = TaskCheckpoint('instagram_unified:all')
        checkpoint.start(total_items=len(accounts))
        for account in accounts:
            if checkpoint.is_completed(account['username']):
                continue
            checkpoint.set_current(account['username'])
            process(account)
            checkpoint.mark_completed(account['username'])
        checkpoint.finish()
    """

    def __init__(self, task_id: str, task_type: str = 'scraping'):
        """Create an in-memory checkpoint handle; no database access happens here.

        Args:
            task_id: Key of the checkpoint row in the database.
            task_type: Stored alongside the row when it is created.
        """
        self.task_id = task_id
        self.task_type = task_type
        self._started = False
        self._recovering = False
        self._completed_items: Set[str] = set()
        self._pending_flush: List[str] = []  # items not yet flushed to DB
        self._current_item: Optional[str] = None
        self._total_items: int = 0
        # Guards _completed_items / _pending_flush.
        self._lock = threading.Lock()
        # Serializes snapshot + DB write so an older snapshot can never be
        # committed after (and overwrite) a newer one.
        self._flush_lock = threading.Lock()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start(self, total_items: int = 0):
        """Create or resume a checkpoint record.

        If a prior checkpoint exists for this task_id (left behind by a crash),
        we load the completed items from it and set recovery mode. Otherwise
        a fresh row is inserted.

        Args:
            total_items: Number of items the task expects to process.
        """
        self._total_items = total_items
        self._started = True

        existing = self._load_existing()
        if existing is not None:
            # Resuming from a crash
            self._completed_items = existing
            self._recovering = True
            logger.info(
                f"Resuming checkpoint for {self.task_id}: "
                f"{len(self._completed_items)}/{total_items} items already completed",
                module='Checkpoint',
            )
        else:
            # Fresh run
            self._completed_items = set()
            self._recovering = False
            self._create_record(total_items)

    def is_recovering(self) -> bool:
        """True if we are resuming from a prior crash."""
        return self._recovering

    def is_completed(self, item_id: str) -> bool:
        """Check whether *item_id* is already in the completed set.

        The ID is compared in its ``str()`` form.
        """
        return str(item_id) in self._completed_items

    def get_remaining(self, items: list, key_fn: Callable) -> list:
        """Return only items not yet completed.

        Args:
            items: Full list of items.
            key_fn: Function that extracts the item key from each element;
                the key is compared in its ``str()`` form.

        Returns:
            A new list preserving the order of *items*.
        """
        return [item for item in items if str(key_fn(item)) not in self._completed_items]

    def set_current(self, item_id: str):
        """Record which item is currently being processed (for crash diagnostics).

        Writes to the database on every call.
        """
        self._current_item = str(item_id)
        self._update_current_item()

    def mark_completed(self, item_id: str):
        """Mark an item as done. Batches DB writes every _FLUSH_INTERVAL items.

        Marking an item that is already completed is a no-op, so duplicates
        neither inflate the pending buffer nor trigger redundant writes.
        """
        item_id = str(item_id)
        with self._lock:
            if item_id in self._completed_items:
                return
            self._completed_items.add(item_id)
            self._pending_flush.append(item_id)
            should_flush = len(self._pending_flush) >= _FLUSH_INTERVAL
        # Flush outside _lock so other threads are not blocked on DB I/O.
        if should_flush:
            self._flush()

    def finish(self):
        """Task completed successfully — delete the checkpoint record.

        Does nothing if ``start()`` was never called.
        """
        if not self._started:
            return
        self._flush()  # flush any remaining items
        self._delete_record()
        self._started = False

    def finish_if_started(self):
        """No-op if start() was never called; otherwise calls finish()."""
        if self._started:
            self.finish()

    # ------------------------------------------------------------------
    # Class methods for discovery
    # ------------------------------------------------------------------

    @classmethod
    def get_interrupted(cls) -> list:
        """Find checkpoint records left behind by crashed tasks.

        Returns a list of dicts with keys:
            task_id, task_type, started_at, completed_count, total_items, current_item

        Any error yields an empty list; errors other than a missing table
        are logged as warnings.
        """
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT task_id, task_type, started_at, completed_items, "
                    "total_items, current_item FROM scheduler_task_checkpoints "
                    "WHERE status = 'running'"
                )
                rows = cursor.fetchall()

            results = []
            for row in rows:
                task_id, task_type, started_at, completed_json, total_items, current_item = row
                completed = cls._parse_completed_json(completed_json)
                results.append({
                    'task_id': task_id,
                    'task_type': task_type,
                    'started_at': started_at,
                    'completed_count': len(completed),
                    'total_items': total_items or 0,
                    'current_item': current_item,
                })
            return results
        except Exception as e:
            # A missing table just means no checkpoint was ever written.
            if 'no such table' not in str(e).lower():
                logger.warning(f"Error reading interrupted checkpoints: {e}", module='Checkpoint')
            return []

    @classmethod
    def abandon(cls, task_id: str):
        """Mark a checkpoint as abandoned (e.g. task no longer registered)."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints SET status = 'abandoned', "
                    "updated_at = ? WHERE task_id = ?",
                    (datetime.now().isoformat(), task_id),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error abandoning checkpoint {task_id}: {e}", module='Checkpoint')

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _load_existing(self) -> Optional[Set[str]]:
        """Load completed items from an existing checkpoint, or return None.

        Only rows with ``status = 'running'`` are considered. ``None`` is also
        returned when the lookup fails.
        """
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT completed_items FROM scheduler_task_checkpoints "
                    "WHERE task_id = ? AND status = 'running'",
                    (self.task_id,),
                )
                row = cursor.fetchone()
                if row is None:
                    return None
                return self._parse_completed_json(row[0])
        except Exception as e:
            if 'no such table' not in str(e).lower():
                logger.warning(f"Error loading checkpoint for {self.task_id}: {e}", module='Checkpoint')
            return None

    def _create_record(self, total_items: int):
        """Insert a fresh checkpoint row (or replace an existing abandoned one)."""
        # One timestamp so started_at and updated_at agree on a new row.
        now = datetime.now().isoformat()
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "INSERT OR REPLACE INTO scheduler_task_checkpoints "
                    "(task_id, task_type, started_at, completed_items, current_item, "
                    "total_items, status, updated_at) "
                    "VALUES (?, ?, ?, '[]', NULL, ?, 'running', ?)",
                    (self.task_id, self.task_type, now, total_items, now),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error creating checkpoint for {self.task_id}: {e}", module='Checkpoint')

    def _flush(self):
        """Write pending completed items to the database.

        The full completed set is written each time (not a delta). Snapshot
        and write happen under ``_flush_lock`` so concurrent callers commit in
        snapshot order; ``_lock`` is held only while copying in-memory state.
        """
        with self._flush_lock:
            with self._lock:
                if not self._pending_flush:
                    return
                items_snapshot = list(self._completed_items)
                total_items = self._total_items
                self._pending_flush.clear()

            try:
                completed_json = json.dumps(items_snapshot)
                with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                    conn.execute(
                        "UPDATE scheduler_task_checkpoints "
                        "SET completed_items = ?, total_items = ?, updated_at = ? "
                        "WHERE task_id = ?",
                        (
                            completed_json,
                            total_items,
                            datetime.now().isoformat(),
                            self.task_id,
                        ),
                    )
                    conn.commit()
            except Exception as e:
                # Items stay in _completed_items, so the next successful
                # flush still persists them.
                logger.warning(f"Error flushing checkpoint for {self.task_id}: {e}", module='Checkpoint')

    def _update_current_item(self):
        """Update the current_item column for crash diagnostics."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints "
                    "SET current_item = ?, updated_at = ? WHERE task_id = ?",
                    (self._current_item, datetime.now().isoformat(), self.task_id),
                )
                conn.commit()
        except Exception:
            # Non-critical — just diagnostics; deliberately best-effort.
            pass

    def _delete_record(self):
        """Remove the checkpoint row on successful completion."""
        try:
            with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
                conn.execute(
                    "DELETE FROM scheduler_task_checkpoints WHERE task_id = ?",
                    (self.task_id,),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error deleting checkpoint for {self.task_id}: {e}", module='Checkpoint')

    @staticmethod
    def _parse_completed_json(raw: Optional[str]) -> Set[str]:
        """Parse JSON array of completed item IDs, tolerating corruption.

        Args:
            raw: Stored column value; falsy values (``None``, ``''``) mean
                "nothing completed yet".

        Returns:
            The IDs as a set of strings. Malformed JSON, or JSON that is not
            a list, is logged as a warning and yields an empty set.
        """
        if not raw:
            return set()
        try:
            items = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            items = None
        if isinstance(items, list):
            return {str(i) for i in items}
        logger.warning("Corrupted checkpoint data — starting fresh (scrapers deduplicate)", module='Checkpoint')
        return set()