296 lines
11 KiB
Python
296 lines
11 KiB
Python
"""
Task Checkpoint Module for Crash Recovery

Tracks progress of long-running scheduler tasks so that if the scheduler
crashes mid-task, it can resume from where it left off instead of
re-processing everything from scratch.

Uses the scheduler_state database (PostgreSQL via pgadapter).
"""
|
|
|
|
import json
|
|
import sqlite3
|
|
import threading
|
|
import time
|
|
from contextlib import closing
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Callable, List, Optional, Set
|
|
|
|
from modules.universal_logger import get_logger
|
|
|
|
# Module-wide logger; every call in this file tags its records with
# module='Checkpoint'.
logger = get_logger('TaskCheckpoint')

# Path to the scheduler state database:
# <parent of this file's directory>/database/scheduler_state.db
_SCHEDULER_DB_PATH = Path(__file__).parent.parent / 'database' / 'scheduler_state.db'

# How many completed items to buffer before flushing to the DB
_FLUSH_INTERVAL = 5

# Stale checkpoint threshold (hours) — abandon checkpoints older than this.
# NOTE(review): not referenced anywhere in this module; presumably consumed
# by the scheduler that imports it — confirm before changing or removing.
STALE_THRESHOLD_HOURS = 48
|
|
|
|
|
|
class TaskCheckpoint:
    """Track progress of a scheduler task for crash recovery.

    One row per ``task_id`` is kept in the ``scheduler_task_checkpoints``
    table. The row stores the JSON-encoded list of completed item ids, the
    item currently being processed, and a ``status`` column. A row whose
    status is ``'running'`` when ``start()`` is called is treated as the
    leftover of an interrupted run and is resumed.

    All database errors are caught and logged as warnings (or ignored, for
    the purely diagnostic ``current_item`` update); none of the methods
    propagate them to the caller.

    Usage::

        checkpoint = TaskCheckpoint('instagram_unified:all')
        checkpoint.start(total_items=len(accounts))
        for account in accounts:
            if checkpoint.is_completed(account['username']):
                continue
            checkpoint.set_current(account['username'])
            process(account)
            checkpoint.mark_completed(account['username'])
        checkpoint.finish()
    """

    # Value passed as ``timeout`` to every ``sqlite3.connect`` call.
    _DB_TIMEOUT = 10

    def __init__(self, task_id: str, task_type: str = 'scraping'):
        """Initialise in-memory state only; no database access happens here.

        Args:
            task_id: Key identifying the checkpoint row for this task.
            task_type: Free-form label stored alongside the row.
        """
        self.task_id = task_id
        self.task_type = task_type
        self._started = False
        self._recovering = False
        self._completed_items: Set[str] = set()
        self._pending_flush: List[str] = []  # items not yet flushed to DB
        self._current_item: Optional[str] = None
        self._total_items: int = 0
        # Guards _completed_items / _pending_flush in mark_completed and _flush.
        self._lock = threading.Lock()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start(self, total_items: int = 0):
        """Create or resume a checkpoint record.

        If a prior checkpoint exists for this task_id (left behind by a crash),
        we load the completed items from it and set recovery mode. Otherwise
        a fresh ``'running'`` row is written.

        Args:
            total_items: Number of items the task intends to process.
        """
        self._total_items = total_items
        self._started = True

        existing = self._load_existing()
        if existing is not None:
            # Resuming from a crash
            self._completed_items = existing
            self._recovering = True
            logger.info(
                f"Resuming checkpoint for {self.task_id}: "
                f"{len(self._completed_items)}/{total_items} items already completed",
                module='Checkpoint',
            )
        else:
            # Fresh run
            self._completed_items = set()
            self._recovering = False
            self._create_record(total_items)

    def is_recovering(self) -> bool:
        """True if we are resuming from a prior crash."""
        return self._recovering

    def is_completed(self, item_id: str) -> bool:
        """Check whether *item_id* (compared as ``str``) is already marked done."""
        return str(item_id) in self._completed_items

    def get_remaining(self, items: list, key_fn: Callable) -> list:
        """Return only items not yet completed, preserving input order.

        Args:
            items: Full list of items.
            key_fn: Function that extracts the item key from each element;
                the key is converted with ``str`` before the lookup.
        """
        return [item for item in items if str(key_fn(item)) not in self._completed_items]

    def set_current(self, item_id: str):
        """Record which item is currently being processed (for crash diagnostics)."""
        self._current_item = str(item_id)
        self._update_current_item()

    def mark_completed(self, item_id: str):
        """Mark an item as done. Batches DB writes every _FLUSH_INTERVAL items."""
        item_id = str(item_id)
        with self._lock:
            self._completed_items.add(item_id)
            self._pending_flush.append(item_id)
            should_flush = len(self._pending_flush) >= _FLUSH_INTERVAL
        # _flush() takes the same (non-reentrant) lock, so it must be called
        # only after the lock has been released.
        if should_flush:
            self._flush()

    def finish(self):
        """Task completed successfully — delete the checkpoint record.

        Does nothing when ``start()`` has not been called.
        """
        if not self._started:
            return
        self._flush()  # flush any remaining items
        self._delete_record()
        self._started = False

    def finish_if_started(self):
        """No-op if start() was never called; otherwise calls finish()."""
        if self._started:
            self.finish()

    # ------------------------------------------------------------------
    # Class methods for discovery
    # ------------------------------------------------------------------

    @classmethod
    def get_interrupted(cls) -> List[dict]:
        """Find checkpoint records left behind by crashed tasks.

        Selects every row whose status is ``'running'``.

        Returns a list of dicts with keys:
            task_id, task_type, started_at, completed_count, total_items, current_item

        On any error an empty list is returned; the error is logged unless
        its message contains "no such table".
        """
        try:
            with cls._connect() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT task_id, task_type, started_at, completed_items, "
                    "total_items, current_item FROM scheduler_task_checkpoints "
                    "WHERE status = 'running'"
                )
                rows = cursor.fetchall()

            results = []
            for row in rows:
                task_id, task_type, started_at, completed_json, total_items, current_item = row
                completed = cls._parse_completed_json(completed_json)
                results.append({
                    'task_id': task_id,
                    'task_type': task_type,
                    'started_at': started_at,
                    'completed_count': len(completed),
                    'total_items': total_items or 0,
                    'current_item': current_item,
                })
            return results
        except Exception as e:
            if 'no such table' not in str(e).lower():
                logger.warning(f"Error reading interrupted checkpoints: {e}", module='Checkpoint')
            return []

    @classmethod
    def abandon(cls, task_id: str):
        """Mark a checkpoint as abandoned (e.g. task no longer registered)."""
        try:
            with cls._connect() as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints SET status = 'abandoned', "
                    "updated_at = ? WHERE task_id = ?",
                    (datetime.now().isoformat(), task_id),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error abandoning checkpoint {task_id}: {e}", module='Checkpoint')

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _connect(cls):
        """Open the scheduler DB wrapped in ``closing`` so it is always closed.

        ``closing`` only closes the connection on exit; callers that write
        must still call ``commit()`` themselves.
        """
        return closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=cls._DB_TIMEOUT))

    def _load_existing(self) -> Optional[Set[str]]:
        """Load completed items from an existing checkpoint, or return None.

        Returns None both when no ``'running'`` row exists for this task_id
        and when the lookup fails (the failure is logged unless its message
        contains "no such table").
        """
        try:
            with self._connect() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT completed_items FROM scheduler_task_checkpoints "
                    "WHERE task_id = ? AND status = 'running'",
                    (self.task_id,),
                )
                row = cursor.fetchone()
            if row is None:
                return None
            return self._parse_completed_json(row[0])
        except Exception as e:
            if 'no such table' not in str(e).lower():
                logger.warning(f"Error loading checkpoint for {self.task_id}: {e}", module='Checkpoint')
            return None

    def _create_record(self, total_items: int):
        """Insert a fresh checkpoint row (or replace an existing abandoned one)."""
        # One timestamp for both columns so a new row has started_at == updated_at.
        now = datetime.now().isoformat()
        try:
            with self._connect() as conn:
                conn.execute(
                    "INSERT OR REPLACE INTO scheduler_task_checkpoints "
                    "(task_id, task_type, started_at, completed_items, current_item, "
                    "total_items, status, updated_at) "
                    "VALUES (?, ?, ?, '[]', NULL, ?, 'running', ?)",
                    (self.task_id, self.task_type, now, total_items, now),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error creating checkpoint for {self.task_id}: {e}", module='Checkpoint')

    def _flush(self):
        """Write the full set of completed items to the database.

        Does nothing when no items are pending. The snapshot is taken and the
        pending list cleared under the lock; the database write itself
        happens after the lock is released.
        """
        with self._lock:
            if not self._pending_flush:
                return
            # The whole completed set is persisted, not just the pending
            # delta, so each successful flush fully overwrites the column.
            items_snapshot = list(self._completed_items)
            self._pending_flush.clear()

        try:
            completed_json = json.dumps(items_snapshot)
            with self._connect() as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints "
                    "SET completed_items = ?, total_items = ?, updated_at = ? "
                    "WHERE task_id = ?",
                    (
                        completed_json,
                        self._total_items,
                        datetime.now().isoformat(),
                        self.task_id,
                    ),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error flushing checkpoint for {self.task_id}: {e}", module='Checkpoint')

    def _update_current_item(self):
        """Update the current_item column for crash diagnostics."""
        try:
            with self._connect() as conn:
                conn.execute(
                    "UPDATE scheduler_task_checkpoints "
                    "SET current_item = ?, updated_at = ? WHERE task_id = ?",
                    (self._current_item, datetime.now().isoformat(), self.task_id),
                )
                conn.commit()
        except Exception:
            # Non-critical — just diagnostics. Deliberately best-effort so a
            # failed write never interrupts the task being tracked.
            pass

    def _delete_record(self):
        """Remove the checkpoint row on successful completion."""
        try:
            with self._connect() as conn:
                conn.execute(
                    "DELETE FROM scheduler_task_checkpoints WHERE task_id = ?",
                    (self.task_id,),
                )
                conn.commit()
        except Exception as e:
            logger.warning(f"Error deleting checkpoint for {self.task_id}: {e}", module='Checkpoint')

    @staticmethod
    def _parse_completed_json(raw: Optional[str]) -> Set[str]:
        """Parse JSON array of completed item IDs, tolerating corruption.

        Args:
            raw: Stored column value; None or an empty string means "nothing
                completed yet".

        Returns:
            The item ids as a set of strings. An empty set is returned, with
            a warning logged, when *raw* is not valid JSON or does not decode
            to a list.
        """
        if not raw:
            return set()
        try:
            items = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            items = None
        if isinstance(items, list):
            return {str(i) for i in items}
        # Undecodable payloads and well-formed non-list payloads are both
        # treated as corruption, so neither case passes silently.
        logger.warning("Corrupted checkpoint data — starting fresh (scrapers deduplicate)", module='Checkpoint')
        return set()
|