295
modules/task_checkpoint.py
Normal file
295
modules/task_checkpoint.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Task Checkpoint Module for Crash Recovery
|
||||
|
||||
Tracks progress of long-running scheduler tasks so that if the scheduler
|
||||
crashes mid-task, it can resume from where it left off instead of
|
||||
re-processing everything from scratch.
|
||||
|
||||
Uses the scheduler_state database (PostgreSQL via pgadapter).
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('TaskCheckpoint')
|
||||
|
||||
# Path to the scheduler state database
|
||||
_SCHEDULER_DB_PATH = Path(__file__).parent.parent / 'database' / 'scheduler_state.db'
|
||||
|
||||
# How many items to buffer before flushing to DB
|
||||
_FLUSH_INTERVAL = 5
|
||||
|
||||
# Stale checkpoint threshold (hours) — abandon checkpoints older than this
|
||||
STALE_THRESHOLD_HOURS = 48
|
||||
|
||||
|
||||
class TaskCheckpoint:
|
||||
"""Track progress of a scheduler task for crash recovery.
|
||||
|
||||
Usage::
|
||||
|
||||
checkpoint = TaskCheckpoint('instagram_unified:all')
|
||||
checkpoint.start(total_items=len(accounts))
|
||||
for account in accounts:
|
||||
if checkpoint.is_completed(account['username']):
|
||||
continue
|
||||
checkpoint.set_current(account['username'])
|
||||
process(account)
|
||||
checkpoint.mark_completed(account['username'])
|
||||
checkpoint.finish()
|
||||
"""
|
||||
|
||||
def __init__(self, task_id: str, task_type: str = 'scraping'):
|
||||
self.task_id = task_id
|
||||
self.task_type = task_type
|
||||
self._started = False
|
||||
self._recovering = False
|
||||
self._completed_items: Set[str] = set()
|
||||
self._pending_flush: List[str] = [] # items not yet flushed to DB
|
||||
self._current_item: Optional[str] = None
|
||||
self._total_items: int = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self, total_items: int = 0):
|
||||
"""Create or resume a checkpoint record.
|
||||
|
||||
If a prior checkpoint exists for this task_id (left behind by a crash),
|
||||
we load the completed items from it and set recovery mode.
|
||||
"""
|
||||
self._total_items = total_items
|
||||
self._started = True
|
||||
|
||||
existing = self._load_existing()
|
||||
if existing is not None:
|
||||
# Resuming from a crash
|
||||
self._completed_items = existing
|
||||
self._recovering = True
|
||||
logger.info(
|
||||
f"Resuming checkpoint for {self.task_id}: "
|
||||
f"{len(self._completed_items)}/{total_items} items already completed",
|
||||
module='Checkpoint',
|
||||
)
|
||||
else:
|
||||
# Fresh run
|
||||
self._completed_items = set()
|
||||
self._recovering = False
|
||||
self._create_record(total_items)
|
||||
|
||||
def is_recovering(self) -> bool:
|
||||
"""True if we are resuming from a prior crash."""
|
||||
return self._recovering
|
||||
|
||||
def is_completed(self, item_id: str) -> bool:
|
||||
"""Check whether *item_id* was already processed in a previous run."""
|
||||
return str(item_id) in self._completed_items
|
||||
|
||||
def get_remaining(self, items: list, key_fn: Callable) -> list:
|
||||
"""Return only items not yet completed.
|
||||
|
||||
Args:
|
||||
items: Full list of items.
|
||||
key_fn: Function that extracts the item key from each element.
|
||||
"""
|
||||
return [item for item in items if str(key_fn(item)) not in self._completed_items]
|
||||
|
||||
def set_current(self, item_id: str):
|
||||
"""Record which item is currently being processed (for crash diagnostics)."""
|
||||
self._current_item = str(item_id)
|
||||
self._update_current_item()
|
||||
|
||||
def mark_completed(self, item_id: str):
|
||||
"""Mark an item as done. Batches DB writes every _FLUSH_INTERVAL items."""
|
||||
item_id = str(item_id)
|
||||
with self._lock:
|
||||
self._completed_items.add(item_id)
|
||||
self._pending_flush.append(item_id)
|
||||
should_flush = len(self._pending_flush) >= _FLUSH_INTERVAL
|
||||
if should_flush:
|
||||
self._flush()
|
||||
|
||||
def finish(self):
|
||||
"""Task completed successfully — delete the checkpoint record."""
|
||||
if not self._started:
|
||||
return
|
||||
self._flush() # flush any remaining items
|
||||
self._delete_record()
|
||||
self._started = False
|
||||
|
||||
def finish_if_started(self):
|
||||
"""No-op if start() was never called; otherwise calls finish()."""
|
||||
if self._started:
|
||||
self.finish()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Class methods for discovery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def get_interrupted(cls) -> list:
|
||||
"""Find checkpoint records left behind by crashed tasks.
|
||||
|
||||
Returns a list of dicts with keys:
|
||||
task_id, task_type, started_at, completed_count, total_items, current_item
|
||||
"""
|
||||
try:
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT task_id, task_type, started_at, completed_items, "
|
||||
"total_items, current_item FROM scheduler_task_checkpoints "
|
||||
"WHERE status = 'running'"
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
task_id, task_type, started_at, completed_json, total_items, current_item = row
|
||||
completed = cls._parse_completed_json(completed_json)
|
||||
results.append({
|
||||
'task_id': task_id,
|
||||
'task_type': task_type,
|
||||
'started_at': started_at,
|
||||
'completed_count': len(completed),
|
||||
'total_items': total_items or 0,
|
||||
'current_item': current_item,
|
||||
})
|
||||
return results
|
||||
except Exception as e:
|
||||
if 'no such table' not in str(e).lower():
|
||||
logger.warning(f"Error reading interrupted checkpoints: {e}", module='Checkpoint')
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def abandon(cls, task_id: str):
|
||||
"""Mark a checkpoint as abandoned (e.g. task no longer registered)."""
|
||||
try:
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
conn.execute(
|
||||
"UPDATE scheduler_task_checkpoints SET status = 'abandoned', "
|
||||
"updated_at = ? WHERE task_id = ?",
|
||||
(datetime.now().isoformat(), task_id),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error abandoning checkpoint {task_id}: {e}", module='Checkpoint')
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _load_existing(self) -> Optional[Set[str]]:
|
||||
"""Load completed items from an existing checkpoint, or return None."""
|
||||
try:
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT completed_items FROM scheduler_task_checkpoints "
|
||||
"WHERE task_id = ? AND status = 'running'",
|
||||
(self.task_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._parse_completed_json(row[0])
|
||||
except Exception as e:
|
||||
if 'no such table' not in str(e).lower():
|
||||
logger.warning(f"Error loading checkpoint for {self.task_id}: {e}", module='Checkpoint')
|
||||
return None
|
||||
|
||||
def _create_record(self, total_items: int):
|
||||
"""Insert a fresh checkpoint row (or replace an existing abandoned one)."""
|
||||
try:
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO scheduler_task_checkpoints "
|
||||
"(task_id, task_type, started_at, completed_items, current_item, "
|
||||
"total_items, status, updated_at) "
|
||||
"VALUES (?, ?, ?, '[]', NULL, ?, 'running', ?)",
|
||||
(
|
||||
self.task_id,
|
||||
self.task_type,
|
||||
datetime.now().isoformat(),
|
||||
total_items,
|
||||
datetime.now().isoformat(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error creating checkpoint for {self.task_id}: {e}", module='Checkpoint')
|
||||
|
||||
def _flush(self):
|
||||
"""Write pending completed items to the database."""
|
||||
with self._lock:
|
||||
if not self._pending_flush:
|
||||
return
|
||||
items_snapshot = list(self._completed_items)
|
||||
self._pending_flush.clear()
|
||||
|
||||
try:
|
||||
completed_json = json.dumps(items_snapshot)
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
conn.execute(
|
||||
"UPDATE scheduler_task_checkpoints "
|
||||
"SET completed_items = ?, total_items = ?, updated_at = ? "
|
||||
"WHERE task_id = ?",
|
||||
(
|
||||
completed_json,
|
||||
self._total_items,
|
||||
datetime.now().isoformat(),
|
||||
self.task_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error flushing checkpoint for {self.task_id}: {e}", module='Checkpoint')
|
||||
|
||||
def _update_current_item(self):
|
||||
"""Update the current_item column for crash diagnostics."""
|
||||
try:
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
conn.execute(
|
||||
"UPDATE scheduler_task_checkpoints "
|
||||
"SET current_item = ?, updated_at = ? WHERE task_id = ?",
|
||||
(self._current_item, datetime.now().isoformat(), self.task_id),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
# Non-critical — just diagnostics
|
||||
pass
|
||||
|
||||
def _delete_record(self):
|
||||
"""Remove the checkpoint row on successful completion."""
|
||||
try:
|
||||
with closing(sqlite3.connect(str(_SCHEDULER_DB_PATH), timeout=10)) as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM scheduler_task_checkpoints WHERE task_id = ?",
|
||||
(self.task_id,),
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error deleting checkpoint for {self.task_id}: {e}", module='Checkpoint')
|
||||
|
||||
@staticmethod
|
||||
def _parse_completed_json(raw: str) -> Set[str]:
|
||||
"""Parse JSON array of completed item IDs, tolerating corruption."""
|
||||
if not raw:
|
||||
return set()
|
||||
try:
|
||||
items = json.loads(raw)
|
||||
if isinstance(items, list):
|
||||
return set(str(i) for i in items)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Corrupted checkpoint data — starting fresh (scrapers deduplicate)", module='Checkpoint')
|
||||
return set()
|
||||
Reference in New Issue
Block a user