762
web/backend/api.py
Normal file
762
web/backend/api.py
Normal file
@@ -0,0 +1,762 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Media Downloader Web API
|
||||
FastAPI backend that integrates with the existing media-downloader system
|
||||
"""
|
||||
|
||||
# Suppress pkg_resources deprecation warning from face_recognition_models
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*")
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Bootstrap database backend (must be before any database imports)
|
||||
sys.path.insert(0, str(__import__('pathlib').Path(__file__).parent.parent.parent))
|
||||
import modules.db_bootstrap # noqa: E402,F401
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette_csrf import CSRFMiddleware
|
||||
import sqlite3
|
||||
import secrets
|
||||
import re
|
||||
|
||||
# Redis cache manager
|
||||
from cache_manager import cache_manager, invalidate_download_cache
|
||||
|
||||
# Rate limiting
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# Add parent directory to path to import media-downloader modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from modules.unified_database import UnifiedDatabase
|
||||
from modules.scheduler import DownloadScheduler
|
||||
from modules.settings_manager import SettingsManager
|
||||
from modules.universal_logger import get_logger
|
||||
from modules.discovery_system import get_discovery_system
|
||||
from modules.semantic_search import get_semantic_search
|
||||
from web.backend.auth_manager import AuthManager
|
||||
from web.backend.core.dependencies import AppState
|
||||
|
||||
# Initialize universal logger for API
|
||||
# Module-wide logger; every log call below tags its subsystem via module=...
logger = get_logger('API')

# ============================================================================
# CONFIGURATION (use centralized settings from core/config.py)
# ============================================================================
from web.backend.core.config import settings

# Legacy aliases for backward compatibility (prefer using settings.* directly).
# Kept so older code/routers that import these names from api.py keep working.
PROJECT_ROOT = settings.PROJECT_ROOT
DB_PATH = settings.DB_PATH
CONFIG_PATH = settings.CONFIG_PATH
LOG_PATH = settings.LOG_PATH
TEMP_DIR = settings.TEMP_DIR
DB_POOL_SIZE = settings.DB_POOL_SIZE
DB_CONNECTION_TIMEOUT = settings.DB_CONNECTION_TIMEOUT
PROCESS_TIMEOUT_SHORT = settings.PROCESS_TIMEOUT_SHORT
PROCESS_TIMEOUT_MEDIUM = settings.PROCESS_TIMEOUT_MEDIUM
PROCESS_TIMEOUT_LONG = settings.PROCESS_TIMEOUT_LONG
WEBSOCKET_TIMEOUT = settings.WEBSOCKET_TIMEOUT
ACTIVITY_LOG_INTERVAL = settings.ACTIVITY_LOG_INTERVAL
EMBEDDING_QUEUE_CHECK_INTERVAL = settings.EMBEDDING_QUEUE_CHECK_INTERVAL
EMBEDDING_BATCH_SIZE = settings.EMBEDDING_BATCH_SIZE
EMBEDDING_BATCH_LIMIT = settings.EMBEDDING_BATCH_LIMIT
THUMBNAIL_DB_TIMEOUT = settings.THUMBNAIL_DB_TIMEOUT

# ============================================================================
# PYDANTIC MODELS (imported from models/api_models.py to avoid duplication)
# ============================================================================

from web.backend.models.api_models import (
    DownloadResponse,
    StatsResponse,
    PlatformStatus,
    TriggerRequest,
    PushoverConfig,
    SchedulerConfig,
    ConfigUpdate,
    HealthStatus,
    LoginRequest,
)

# ============================================================================
# AUTHENTICATION DEPENDENCIES (imported from core/dependencies.py)
# ============================================================================

from web.backend.core.dependencies import (
    get_current_user,
    require_admin,
    get_app_state,
)

# ============================================================================
# GLOBAL STATE
# ============================================================================

# Use the shared singleton from core/dependencies (used by all routers).
# lifespan() below populates its attributes (db, auth, settings, scheduler, ...).
app_state = get_app_state()
|
||||
|
||||
# ============================================================================
|
||||
# LIFESPAN MANAGEMENT
|
||||
# ============================================================================
|
||||
|
||||
def init_thumbnail_db():
    """Create the thumbnail cache database and schema if they do not exist.

    The cache lives at ``<project_root>/database/thumbnails.db`` and stores
    rendered thumbnails keyed by file hash, with the source file's mtime so
    stale entries can be detected.

    Returns:
        Path: location of the thumbnail cache database file.
    """
    thumb_db_path = Path(__file__).parent.parent.parent / 'database' / 'thumbnails.db'
    thumb_db_path.parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(str(thumb_db_path))
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS thumbnails (
                file_hash TEXT PRIMARY KEY,
                file_path TEXT NOT NULL,
                thumbnail_data BLOB NOT NULL,
                created_at TEXT NOT NULL,
                file_mtime REAL NOT NULL
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_file_path ON thumbnails(file_path)")
        conn.commit()
    finally:
        # Always release the connection — previously it leaked if the DDL
        # statements raised before close() was reached.
        conn.close()
    return thumb_db_path
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and cleanup app resources.

    Startup order matters: database first (everything depends on it), then
    stuck-state cleanup, auth, routers, settings, scheduler handle, thumbnail
    cache, and finally the three long-running background tasks
    (WAL checkpoint, discovery/embedding queue, quality recheck).
    After ``yield``, the same resources are torn down in reverse.
    """
    # Startup
    logger.info("🚀 Starting Media Downloader API...", module="Core")

    # Initialize database with larger pool for API workers
    app_state.db = UnifiedDatabase(str(DB_PATH), use_pool=True, pool_size=DB_POOL_SIZE)
    logger.info(f"✓ Connected to database: {DB_PATH} (pool_size={DB_POOL_SIZE})", module="Core")

    # Reset any stuck downloads from previous API run (a crash mid-download
    # leaves rows in 'downloading' that would otherwise never be retried)
    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE video_download_queue
                SET status = 'pending', progress = 0, started_at = NULL
                WHERE status = 'downloading'
            ''')
            reset_count = cursor.rowcount
            conn.commit()
            if reset_count > 0:
                logger.info(f"✓ Reset {reset_count} stuck download(s) to pending", module="Core")
    except Exception as e:
        logger.warning(f"Could not reset stuck downloads: {e}", module="Core")

    # Clear any stale background tasks from previous API run
    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE background_task_status
                SET active = 0
                WHERE active = 1
            ''')
            cleared_count = cursor.rowcount
            conn.commit()
            if cleared_count > 0:
                logger.info(f"✓ Cleared {cleared_count} stale background task(s)", module="Core")
    except Exception as e:
        # Table might not exist yet, that's fine
        pass

    # Initialize face recognition semaphore (limit to 1 concurrent process)
    app_state.face_recognition_semaphore = asyncio.Semaphore(1)
    logger.info("✓ Face recognition semaphore initialized (max 1 concurrent)", module="Core")

    # Initialize authentication manager
    app_state.auth = AuthManager()
    logger.info("✓ Authentication manager initialized", module="Core")

    # Initialize and include 2FA router — needs the AuthManager instance,
    # which is why it is created here rather than at module import time
    twofa_router = create_2fa_router(app_state.auth, get_current_user)
    app.include_router(twofa_router, prefix="/api/auth/2fa", tags=["2FA"])
    logger.info("✓ 2FA routes initialized", module="Core")

    # Include modular API routers (imported lazily here so router modules can
    # themselves import from this module without a circular-import problem)
    from web.backend.routers import (
        auth_router,
        health_router,
        downloads_router,
        media_router,
        recycle_router,
        scheduler_router,
        video_router,
        config_router,
        review_router,
        face_router,
        platforms_router,
        discovery_router,
        scrapers_router,
        semantic_router,
        manual_import_router,
        stats_router,
        celebrity_router,
        video_queue_router,
        maintenance_router,
        files_router,
        appearances_router,
        easynews_router,
        dashboard_router,
        paid_content_router,
        private_gallery_router,
        instagram_unified_router,
        cloud_backup_router,
        press_router
    )

    # Include all modular routers
    # Note: websocket_manager is set later after ConnectionManager is created
    app.include_router(auth_router)
    app.include_router(health_router)
    app.include_router(downloads_router)
    app.include_router(media_router)
    app.include_router(recycle_router)
    app.include_router(scheduler_router)
    app.include_router(video_router)
    app.include_router(config_router)
    app.include_router(review_router)
    app.include_router(face_router)
    app.include_router(platforms_router)
    app.include_router(discovery_router)
    app.include_router(scrapers_router)
    app.include_router(semantic_router)
    app.include_router(manual_import_router)
    app.include_router(stats_router)
    app.include_router(celebrity_router)
    app.include_router(video_queue_router)
    app.include_router(maintenance_router)
    app.include_router(files_router)
    app.include_router(appearances_router)
    app.include_router(easynews_router)
    app.include_router(dashboard_router)
    app.include_router(paid_content_router)
    app.include_router(private_gallery_router)
    app.include_router(instagram_unified_router)
    app.include_router(cloud_backup_router)
    app.include_router(press_router)

    logger.info("✓ All 28 modular routers registered", module="Core")

    # Initialize settings manager
    app_state.settings = SettingsManager(str(DB_PATH))
    logger.info("✓ Settings manager initialized", module="Core")

    # Check if settings need to be migrated from JSON (one-time migration on
    # first boot of a database-backed deployment)
    existing_settings = app_state.settings.get_all()
    if not existing_settings or len(existing_settings) == 0:
        logger.info("⚠ No settings in database, migrating from JSON...", module="Core")
        if CONFIG_PATH.exists():
            app_state.settings.migrate_from_json(str(CONFIG_PATH))
            logger.info("✓ Settings migrated from JSON to database", module="Core")
        else:
            logger.info("⚠ No JSON config file found", module="Core")

    # Load configuration from database
    app_state.config = app_state.settings.get_all()
    if app_state.config:
        logger.info(f"✓ Loaded configuration from database", module="Core")
    else:
        logger.info("⚠ No configuration in database - please configure via web interface", module="Core")
        app_state.config = {}

    # Initialize scheduler (for status checking only - actual scheduler runs as separate service)
    app_state.scheduler = DownloadScheduler(config_path=None, unified_db=app_state.db, settings_manager=app_state.settings)
    logger.info("✓ Scheduler connected (status monitoring only)", module="Core")

    # Initialize thumbnail database
    init_thumbnail_db()
    logger.info("✓ Thumbnail cache initialized", module="Core")

    # Initialize recycle bin settings with defaults if not exists
    if not app_state.settings.get('recycle_bin'):
        recycle_bin_defaults = {
            'enabled': True,
            'path': '/opt/immich/recycle',
            'retention_days': 30,
            'max_size_gb': 50,
            'auto_cleanup': True
        }
        app_state.settings.set('recycle_bin', recycle_bin_defaults, category='recycle_bin', description='Recycle bin configuration', updated_by='system')
        logger.info("✓ Recycle bin settings initialized with defaults", module="Core")

    logger.info("✓ Media Downloader API ready!", module="Core")

    # Start periodic WAL checkpoint task
    async def periodic_wal_checkpoint():
        """Run WAL checkpoint every 5 minutes to prevent WAL file growth"""
        while True:
            await asyncio.sleep(300)  # 5 minutes
            try:
                if app_state.db:
                    app_state.db.checkpoint()
                    logger.debug("WAL checkpoint completed", module="Core")
            except Exception as e:
                logger.warning(f"WAL checkpoint failed: {e}", module="Core")

    checkpoint_task = asyncio.create_task(periodic_wal_checkpoint())
    logger.info("✓ WAL checkpoint task started (every 5 minutes)", module="Core")

    # Start discovery queue processor
    async def process_discovery_queue():
        """Background task to process discovery scan queue (embeddings)"""
        while True:
            try:
                await asyncio.sleep(30)  # Check queue every 30 seconds

                if not app_state.db:
                    continue

                # Process embeddings first (limit batch size for memory efficiency)
                pending_embeddings = app_state.db.get_pending_discovery_scans(limit=EMBEDDING_BATCH_SIZE, scan_type='embedding')

                for item in pending_embeddings:
                    try:
                        queue_id = item['id']
                        file_id = item['file_id']
                        file_path = item['file_path']

                        # Mark as started
                        app_state.db.mark_discovery_scan_started(queue_id)

                        # Check if file still exists
                        from pathlib import Path
                        if not Path(file_path).exists():
                            # File vanished since it was queued — treat as done
                            app_state.db.mark_discovery_scan_completed(queue_id)
                            continue

                        # Check if embedding already exists
                        with app_state.db.get_connection() as conn:
                            cursor = conn.cursor()
                            cursor.execute('SELECT 1 FROM content_embeddings WHERE file_id = ?', (file_id,))
                            if cursor.fetchone():
                                # Already has embedding
                                app_state.db.mark_discovery_scan_completed(queue_id)
                                continue

                            # Get content type
                            cursor.execute('SELECT content_type FROM file_inventory WHERE id = ?', (file_id,))
                            row = cursor.fetchone()
                            content_type = row['content_type'] if row else 'image'

                        # Generate embedding (CPU intensive - run in executor so
                        # the event loop is not blocked)
                        def generate():
                            semantic = get_semantic_search(app_state.db)
                            return semantic.generate_embedding_for_file(file_id, file_path, content_type)

                        loop = asyncio.get_running_loop()
                        success = await loop.run_in_executor(None, generate)

                        if success:
                            app_state.db.mark_discovery_scan_completed(queue_id)
                            logger.debug(f"Generated embedding for file_id {file_id}", module="Discovery")
                        else:
                            app_state.db.mark_discovery_scan_failed(queue_id, "Embedding generation returned False")

                    except Exception as e:
                        logger.warning(f"Failed to process embedding queue item {item.get('id')}: {e}", module="Discovery")
                        app_state.db.mark_discovery_scan_failed(item.get('id', 0), str(e))

                    # Small delay between items to avoid CPU spikes
                    await asyncio.sleep(0.5)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Discovery queue processor error: {e}", module="Discovery")
                await asyncio.sleep(60)  # Wait longer on errors

    discovery_task = asyncio.create_task(process_discovery_queue())
    logger.info("✓ Discovery queue processor started (embeddings, checks every 30s)", module="Core")

    # Start periodic quality recheck for Fansly videos (hourly)
    async def periodic_quality_recheck():
        """Recheck flagged Fansly attachments for higher quality every hour"""
        await asyncio.sleep(300)  # Wait 5 min after startup before first check
        while True:
            try:
                from web.backend.routers.paid_content import _auto_quality_recheck_background
                await _auto_quality_recheck_background()
            except Exception as e:
                logger.warning(f"Periodic quality recheck failed: {e}", module="Core")
            await asyncio.sleep(3600)  # 1 hour

    quality_recheck_task = asyncio.create_task(periodic_quality_recheck())
    logger.info("✓ Quality recheck task started (hourly)", module="Core")

    # Auto-start download queue if enabled
    try:
        queue_settings = app_state.settings.get('video_queue')
        if queue_settings and queue_settings.get('auto_start_on_restart'):
            # Check if there are pending items
            with app_state.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM video_download_queue WHERE status = 'pending'")
                pending_count = cursor.fetchone()[0]

            if pending_count > 0:
                # Import the queue processor and start it
                from web.backend.routers.video_queue import queue_processor, run_queue_processor
                if not queue_processor.is_running:
                    queue_processor.start()
                    asyncio.create_task(run_queue_processor(app_state))
                    logger.info(f"✓ Auto-started download queue with {pending_count} pending items", module="Core")
            else:
                logger.info("✓ Auto-start enabled but no pending items in queue", module="Core")
    except Exception as e:
        logger.warning(f"Could not auto-start download queue: {e}", module="Core")

    # Application serves requests while suspended here
    yield

    # Shutdown
    logger.info("🛑 Shutting down Media Downloader API...", module="Core")
    checkpoint_task.cancel()
    discovery_task.cancel()
    quality_recheck_task.cancel()
    # Await each cancelled task so it can unwind before resources close
    try:
        await checkpoint_task
    except asyncio.CancelledError:
        pass
    try:
        await discovery_task
    except asyncio.CancelledError:
        pass
    try:
        await quality_recheck_task
    except asyncio.CancelledError:
        pass
    # Close shared HTTP client
    try:
        from web.backend.core.http_client import http_client
        await http_client.close()
    except Exception:
        pass
    if app_state.scheduler:
        app_state.scheduler.stop()
    if app_state.db:
        app_state.db.close()
    logger.info("✓ Cleanup complete", module="Core")
|
||||
|
||||
# ============================================================================
|
||||
# FASTAPI APP
|
||||
# ============================================================================
|
||||
|
||||
# FastAPI application; lifespan() above wires up all resources and routers.
app = FastAPI(
    title="Media Downloader API",
    description="Web API for managing media downloads from Instagram, TikTok, Snapchat, and Forums",
    version="13.13.1",
    lifespan=lifespan
)

# CORS middleware - allow frontend origins (list comes from centralized settings)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
    allow_headers=[
        "Authorization",
        "Content-Type",
        "Accept",
        "Origin",
        "X-Requested-With",
        "X-CSRFToken",      # CSRF header used by the SPA (see CSRF middleware below)
        "Cache-Control",
        "Range",            # needed for media/video byte-range requests
    ],
    max_age=600,  # Cache preflight requests for 10 minutes
)

# GZip compression for responses > 1KB (70% smaller API payloads)
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# Session middleware for 2FA setup flow
|
||||
# Persist generated key to file to avoid invalidating sessions on restart
|
||||
def _load_session_secret():
|
||||
secret_file = Path(__file__).parent.parent.parent / '.session_secret'
|
||||
if os.getenv("SESSION_SECRET_KEY"):
|
||||
return os.getenv("SESSION_SECRET_KEY")
|
||||
if secret_file.exists():
|
||||
with open(secret_file, 'r') as f:
|
||||
secret = f.read().strip()
|
||||
if len(secret) >= 32:
|
||||
return secret
|
||||
new_secret = secrets.token_urlsafe(32)
|
||||
try:
|
||||
fd = os.open(str(secret_file), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, 'w') as f:
|
||||
f.write(new_secret)
|
||||
except Exception:
|
||||
pass # Fall back to in-memory only
|
||||
return new_secret
|
||||
|
||||
# Session middleware for the 2FA setup flow; the secret is persisted (see
# _load_session_secret) so a restart does not invalidate in-progress sessions.
SESSION_SECRET_KEY = _load_session_secret()
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)
|
||||
|
||||
# CSRF protection middleware - enabled for state-changing requests
|
||||
# Note: All endpoints are also protected by JWT authentication for double security
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
class CSRFMiddlewareHTTPOnly:
    """Apply the wrapped CSRF middleware to HTTP traffic only.

    WebSocket connections authenticate on their own (cookie / subprotocol
    token in the ``/ws`` endpoint), so they are routed straight to the inner
    ASGI app; every other request goes through the CSRF middleware.
    """

    def __init__(self, app: ASGIApp, csrf_middleware_class, **kwargs):
        self.app = app
        # Build the real CSRF middleware around the same downstream app
        self.csrf_middleware = csrf_middleware_class(app, **kwargs)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] == "websocket":
            # WebSocket handshake: skip CSRF, it has separate auth
            await self.app(scope, receive, send)
            return
        # Plain HTTP: enforce CSRF via the wrapped middleware
        await self.csrf_middleware(scope, receive, send)
|
||||
|
||||
# CSRF secret: prefer the dedicated settings value, fall back to the session
# secret so a partially configured deployment still gets signed tokens.
CSRF_SECRET_KEY = settings.CSRF_SECRET_KEY or SESSION_SECRET_KEY

app.add_middleware(
    CSRFMiddlewareHTTPOnly,
    csrf_middleware_class=CSRFMiddleware,
    secret=CSRF_SECRET_KEY,
    cookie_name="csrftoken",
    header_name="X-CSRFToken",
    cookie_secure=settings.SECURE_COOKIES,
    cookie_httponly=False,  # JS needs to read for SPA
    cookie_samesite="lax",  # CSRF protection
    exempt_urls=[
        # API endpoints are protected against CSRF via multiple layers:
        # 1. Primary: JWT in Authorization header (cannot be forged cross-origin)
        # 2. Secondary: Cookie auth uses samesite='lax' which blocks cross-origin POSTs
        # 3. Media endpoints (cookie-only) are GET requests, safe from CSRF
        # This makes additional CSRF token validation redundant for /api/* routes
        re.compile(r"^/api/.*"),
        re.compile(r"^/ws$"),  # WebSocket endpoint (has separate auth via cookie)
    ]
)

# Rate limiting with Redis persistence
# Falls back to in-memory if Redis unavailable
redis_uri = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
try:
    limiter = Limiter(key_func=get_remote_address, storage_uri=redis_uri)
    logger.info("Rate limiter using Redis storage", module="RateLimit")
except Exception as e:
    logger.warning(f"Redis rate limiting unavailable, using in-memory: {e}", module="RateLimit")
    limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Import and include 2FA routes (the router itself is instantiated inside
# lifespan(), once the AuthManager exists)
from web.backend.twofa_routes import create_2fa_router
|
||||
|
||||
# ============================================================================
|
||||
# WEBSOCKET CONNECTION MANAGER
|
||||
# ============================================================================
|
||||
|
||||
class ConnectionManager:
    """Registry of connected WebSocket clients with fan-out broadcasting.

    One instance is shared module-wide (``manager`` below). Async code awaits
    ``broadcast()``; synchronous background threads call ``broadcast_sync()``,
    which hops onto the event loop captured at first connect.
    """

    def __init__(self):
        # Sockets currently subscribed to broadcasts
        self.active_connections: List[WebSocket] = []
        # Event loop captured on first connect; needed for thread-safe sends
        self._loop = None
        # Strong references to fire-and-forget broadcast tasks — asyncio only
        # keeps weak references, so without this a task could be GC'd mid-send
        self._pending_tasks = set()

    async def connect(self, websocket: WebSocket):
        """Accept the handshake and register the socket."""
        await websocket.accept()
        self.active_connections.append(websocket)
        # Capture event loop for thread-safe broadcasts
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

    def disconnect(self, websocket: WebSocket):
        """Unregister a socket; no-op if it was already removed."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass  # Already removed

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON to all clients, pruning dead sockets."""
        logger.info(f"ConnectionManager: Broadcasting to {len(self.active_connections)} clients", module="WebSocket")

        dead_connections = []
        # Iterate over a snapshot: disconnect() (or another broadcast) may
        # mutate the list while we are awaiting send_json, which would
        # otherwise skip or double-visit clients.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.debug(f"ConnectionManager: Removing dead connection: {e}", module="WebSocket")
                dead_connections.append(connection)

        # Remove dead connections
        for connection in dead_connections:
            try:
                self.active_connections.remove(connection)
            except ValueError:
                pass  # Already removed

    def broadcast_sync(self, message: dict):
        """Thread-safe broadcast for use in background tasks (sync threads)"""
        msg_type = message.get('type', 'unknown')

        if not self.active_connections:
            logger.info(f"broadcast_sync: No active WebSocket connections for {msg_type}", module="WebSocket")
            return

        logger.info(f"broadcast_sync: Sending {msg_type} to {len(self.active_connections)} clients", module="WebSocket")

        try:
            # Raises RuntimeError when called from a plain (non-async) thread
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop - we're in a thread, use thread-safe call
            if self._loop and self._loop.is_running():
                asyncio.run_coroutine_threadsafe(self.broadcast(message), self._loop)
                logger.info(f"broadcast_sync: Used threadsafe call for {msg_type}", module="WebSocket")
            else:
                logger.warning(f"broadcast_sync: No event loop available for {msg_type}", module="WebSocket")
            return

        # Async context: schedule on the current loop, keeping a strong
        # reference until the task completes (per asyncio.create_task docs)
        task = asyncio.create_task(self.broadcast(message))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)
        logger.info(f"broadcast_sync: Created task for {msg_type}", module="WebSocket")
|
||||
|
||||
# Shared WebSocket connection manager used by the /ws endpoint below
manager = ConnectionManager()

# Store websocket manager in app_state for modular routers to access
app_state.websocket_manager = manager

# Initialize scraper event emitter for real-time monitoring; it pushes scraper
# progress events out through the same WebSocket manager
from modules.scraper_event_emitter import ScraperEventEmitter
app_state.scraper_event_emitter = ScraperEventEmitter(websocket_manager=manager, app_state=app_state)
logger.info("Scraper event emitter initialized for real-time monitoring", module="WebSocket")
|
||||
|
||||
# ============================================================================
|
||||
# WEBSOCKET ENDPOINT
|
||||
# ============================================================================
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates.

    Authentication priority:
    1. Cookie (most secure, set by browser automatically)
    2. Sec-WebSocket-Protocol header with 'auth.<token>' (secure, not logged)
    3. Query parameter (fallback, may be logged in access logs)

    After authentication, sends a ``connected`` message, then loops: echoes
    client text messages and sends a ``ping`` whenever 30s pass without
    receiving anything. Closes with code 4001 on auth failure.
    """
    auth_token = None

    # Priority 1: Try cookie (most common for browsers)
    auth_token = websocket.cookies.get('auth_token')

    # Priority 2: Check Sec-WebSocket-Protocol header for "auth.<token>" subprotocol
    # This is more secure than query params as it's not logged in access logs
    if not auth_token:
        protocols = websocket.headers.get('sec-websocket-protocol', '')
        for protocol in protocols.split(','):
            protocol = protocol.strip()
            if protocol.startswith('auth.'):
                auth_token = protocol[5:]  # Remove 'auth.' prefix
                break

    # Priority 3: Query parameter fallback (least preferred, logged in access logs)
    if not auth_token:
        auth_token = websocket.query_params.get('token')

    if not auth_token:
        await websocket.close(code=4001, reason="Authentication required")
        logger.warning("WebSocket: Connection rejected - no auth cookie or token", module="WebSocket")
        return

    # Verify the token
    payload = app_state.auth.verify_session(auth_token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid or expired token")
        logger.warning("WebSocket: Connection rejected - invalid token", module="WebSocket")
        return

    username = payload.get('sub', 'unknown')

    # Accept with the auth subprotocol if that was used — the RFC requires the
    # server to echo a selected subprotocol back, so we cannot use the plain
    # manager.connect() path in that case.
    protocols = websocket.headers.get('sec-websocket-protocol', '')
    if any(p.strip().startswith('auth.') for p in protocols.split(',')):
        await websocket.accept(subprotocol='auth')
        manager.active_connections.append(websocket)
        if manager._loop is None:
            manager._loop = asyncio.get_running_loop()
    else:
        await manager.connect(websocket)
    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "timestamp": datetime.now().isoformat(),
            "user": username
        })
        logger.info(f"WebSocket: Client connected (user: {username})", module="WebSocket")

        while True:
            # Keep connection alive with ping/pong
            # Use raw receive() instead of receive_text() for better proxy compatibility
            try:
                message = await asyncio.wait_for(websocket.receive(), timeout=30.0)
                msg_type = message.get("type", "unknown")

                if msg_type == "websocket.receive":
                    text = message.get("text")
                    if text:
                        # Echo back client messages
                        await websocket.send_json({
                            "type": "echo",
                            "message": text
                        })
                elif msg_type == "websocket.disconnect":
                    # Client closed connection
                    code = message.get("code", "unknown")
                    logger.info(f"WebSocket: Client {username} sent disconnect (code={code})", module="WebSocket")
                    break

            except asyncio.TimeoutError:
                # Send ping to keep connection alive
                try:
                    await websocket.send_json({
                        "type": "ping",
                        "timestamp": datetime.now().isoformat()
                    })
                except Exception as e:
                    # Connection lost
                    logger.info(f"WebSocket: Ping failed for {username}: {type(e).__name__}: {e}", module="WebSocket")
                    break
    except WebSocketDisconnect as e:
        logger.info(f"WebSocket: {username} disconnected (code={e.code})", module="WebSocket")
    except Exception as e:
        logger.error(f"WebSocket error for {username}: {type(e).__name__}: {e}", module="WebSocket")
    finally:
        # Always deregister, whatever path ended the loop
        manager.disconnect(websocket)
        logger.info(f"WebSocket: {username} removed from active connections ({len(manager.active_connections)} remaining)", module="WebSocket")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN ENTRY POINT
|
||||
# ============================================================================
|
||||
|
||||
# Direct-run entry point (development); production deployments may launch
# uvicorn externally with the same parameters.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # Single worker needed for WebSocket broadcasts to work correctly
        timeout_keep_alive=30,
        log_level="info",
        limit_concurrency=1000,  # Allow up to 1000 concurrent connections
        backlog=2048  # Queue size for pending connections
    )
|
||||
|
||||
565
web/backend/auth_manager.py
Normal file
565
web/backend/auth_manager.py
Normal file
@@ -0,0 +1,565 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Authentication Manager for Media Downloader
|
||||
Handles user authentication and sessions
|
||||
"""
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from passlib.context import CryptContext
|
||||
from jose import jwt, JWTError
|
||||
from pathlib import Path
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# JWT Configuration
|
||||
def _load_jwt_secret():
    """Resolve the JWT signing secret.

    Precedence: ``.jwt_secret`` file at the project root, then the
    ``JWT_SECRET_KEY`` environment variable, then a freshly generated secret
    (persisted back to the file when possible so tokens survive restarts).
    Candidates shorter than 32 characters are rejected and the next source
    is tried.
    """
    secret_file = Path(__file__).parent.parent.parent / '.jwt_secret'

    # 1) Persisted secret file.
    if secret_file.exists():
        candidate = secret_file.read_text().strip()
        if len(candidate) >= 32:
            return candidate
        # Too short: fall through and regenerate below.

    # 2) Environment variable.
    candidate = os.environ.get("JWT_SECRET_KEY", "")
    if len(candidate) >= 32:
        return candidate

    # 3) Generate a cryptographically strong secret (48 bytes / 384 bits).
    new_secret = secrets.token_urlsafe(48)
    try:
        secret_file.write_text(new_secret)
        os.chmod(secret_file, 0o600)
    except Exception:
        # Log warning but continue - in-memory secret will invalidate on restart
        import logging
        logging.getLogger(__name__).warning(
            "Could not save JWT secret to file. Tokens will be invalidated on restart."
        )
    return new_secret
|
||||
|
||||
# Module-wide JWT parameters: HMAC-SHA256 signing with a secret resolved once
# at import time (file -> env -> generated; see _load_jwt_secret above).
SECRET_KEY = _load_jwt_secret()
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24  # 24 hours (default)
ACCESS_TOKEN_REMEMBER_MINUTES = 60 * 24 * 30  # 30 days (remember me)
|
||||
|
||||
class AuthManager:
|
||||
    def __init__(self, db_path: str = None):
        """Open (and if needed create) the auth database at *db_path*.

        When *db_path* is omitted, defaults to ``database/auth.db`` at the
        project root. Schema creation runs before default-user creation —
        keep that order.
        """
        if db_path is None:
            db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')

        self.db_path = db_path

        # Rate limiting policy: lock an account after max_attempts consecutive
        # failures, for lockout_duration. (login_attempts dict appears unused;
        # persistent counters live in the login_attempts table.)
        self.login_attempts = {}
        self.max_attempts = 5
        self.lockout_duration = timedelta(minutes=15)

        self._init_database()
        self._create_default_user()
|
||||
|
||||
    def _init_database(self):
        """Initialize the authentication database schema.

        Creates (if absent) the users, backup_codes, sessions, login_attempts
        and auth_audit tables, and applies additive migrations for the Duo
        columns on databases created before those columns existed.
        """
        # Ensure the database/ directory exists before connecting.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        with sqlite3.connect(self.db_path, timeout=30.0) as conn:
            cursor = conn.cursor()

            # Users table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    username TEXT PRIMARY KEY,
                    password_hash TEXT NOT NULL,
                    role TEXT NOT NULL DEFAULT 'viewer',
                    email TEXT,
                    is_active INTEGER NOT NULL DEFAULT 1,
                    totp_secret TEXT,
                    totp_enabled INTEGER NOT NULL DEFAULT 0,
                    duo_enabled INTEGER NOT NULL DEFAULT 0,
                    duo_username TEXT,
                    created_at TEXT NOT NULL,
                    last_login TEXT,
                    preferences TEXT
                )
            """)

            # Migration: add Duo columns to pre-existing tables; the ALTER
            # fails harmlessly when the column is already present.
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_enabled INTEGER NOT NULL DEFAULT 0")
            except sqlite3.OperationalError:
                pass  # Column already exists

            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_username TEXT")
            except sqlite3.OperationalError:
                pass  # Column already exists

            # Backup codes table (stored as hashes, marked once used)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS backup_codes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    code_hash TEXT NOT NULL,
                    used INTEGER NOT NULL DEFAULT 0,
                    used_at TEXT,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY (username) REFERENCES users(username) ON DELETE CASCADE
                )
            """)

            # Sessions table (keyed by the issued JWT — see _create_session)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_token TEXT PRIMARY KEY,
                    username TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    FOREIGN KEY (username) REFERENCES users(username) ON DELETE CASCADE
                )
            """)

            # Login attempts table (for rate limiting)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS login_attempts (
                    username TEXT PRIMARY KEY,
                    attempts INTEGER NOT NULL DEFAULT 0,
                    last_attempt TEXT NOT NULL
                )
            """)

            # Audit log of authentication events (logins, password changes)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS auth_audit (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    action TEXT NOT NULL,
                    success INTEGER NOT NULL,
                    ip_address TEXT,
                    details TEXT,
                    timestamp TEXT NOT NULL
                )
            """)

            conn.commit()
|
||||
|
||||
    def _create_default_user(self):
        """Create default admin user if no users exist.

        The password comes from the ADMIN_PASSWORD environment variable when
        set; otherwise a random one is generated and written to a mode-0600
        ``.admin_password`` file at the project root. The password itself is
        never logged.
        """
        import logging
        _logger = logging.getLogger(__name__)

        with sqlite3.connect(self.db_path, timeout=30.0) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM users")
            if cursor.fetchone()[0] == 0:
                # Get password from environment or generate secure random one
                default_password = os.environ.get("ADMIN_PASSWORD")
                password_generated = False

                if not default_password:
                    # Generate a secure random password (16 random bytes,
                    # URL-safe encoded -> ~22 characters)
                    default_password = secrets.token_urlsafe(16)
                    password_generated = True

                password_hash = pwd_context.hash(default_password)

                cursor.execute("""
                    INSERT INTO users (username, password_hash, role, email, created_at, preferences)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    'admin',
                    password_hash,
                    'admin',
                    'admin@localhost',
                    datetime.now().isoformat(),
                    '{"theme": "dark"}'
                ))

                conn.commit()

                # Security: Never log the password - write to secure file instead
                if password_generated:
                    password_file = Path(__file__).parent.parent.parent / '.admin_password'
                    try:
                        with open(password_file, 'w') as f:
                            f.write(f"Admin password (delete after first login):\n{default_password}\n")
                        os.chmod(password_file, 0o600)
                        _logger.info("Created default admin user. Password saved to .admin_password")
                    except Exception:
                        # If we can't write file, show masked hint only
                        _logger.info("Created default admin user. Set ADMIN_PASSWORD env var to control password.")
                else:
                    _logger.info("Created default admin user with password from ADMIN_PASSWORD")
|
||||
|
||||
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(self, password: str) -> str:
|
||||
"""Hash a password"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
"""Create a JWT access token"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def verify_token(self, token: str) -> Optional[Dict]:
|
||||
"""Verify and decode a JWT token"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
def check_rate_limit(self, username: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Check if user is rate limited"""
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT attempts, last_attempt FROM login_attempts
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return True, None
|
||||
|
||||
attempts, last_attempt_str = row
|
||||
last_attempt = datetime.fromisoformat(last_attempt_str)
|
||||
|
||||
if attempts >= self.max_attempts:
|
||||
time_since = datetime.now() - last_attempt
|
||||
if time_since < self.lockout_duration:
|
||||
remaining = self.lockout_duration - time_since
|
||||
return False, f"Account locked. Try again in {int(remaining.total_seconds() / 60)} minutes"
|
||||
else:
|
||||
# Reset after lockout period
|
||||
self._reset_login_attempts(username)
|
||||
return True, None
|
||||
|
||||
return True, None
|
||||
|
||||
def _reset_login_attempts(self, username: str):
|
||||
"""Reset login attempts for a user"""
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM login_attempts WHERE username = ?", (username,))
|
||||
conn.commit()
|
||||
|
||||
def _record_login_attempt(self, username: str, success: bool):
|
||||
"""Record a login attempt"""
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if success:
|
||||
# Reset attempts on successful login
|
||||
cursor.execute("DELETE FROM login_attempts WHERE username = ?", (username,))
|
||||
else:
|
||||
# Increment attempts
|
||||
now = datetime.now().isoformat()
|
||||
cursor.execute("""
|
||||
INSERT INTO login_attempts (username, attempts, last_attempt)
|
||||
VALUES (?, 1, ?)
|
||||
ON CONFLICT(username) DO UPDATE SET
|
||||
attempts = login_attempts.attempts + 1,
|
||||
last_attempt = EXCLUDED.last_attempt
|
||||
""", (username, now))
|
||||
|
||||
conn.commit()
|
||||
|
||||
def authenticate(self, username: str, password: str, ip_address: str = None, remember_me: bool = False) -> Dict:
|
||||
"""Authenticate a user with username and password"""
|
||||
# Check rate limiting
|
||||
allowed, error_msg = self.check_rate_limit(username)
|
||||
if not allowed:
|
||||
return {'success': False, 'error': error_msg, 'requires_2fa': False}
|
||||
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT password_hash, role, is_active
|
||||
FROM users WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
self._record_login_attempt(username, False)
|
||||
self._log_audit(username, 'login_failed', False, ip_address, 'Invalid username')
|
||||
return {'success': False, 'error': 'Invalid credentials'}
|
||||
|
||||
password_hash, role, is_active = row
|
||||
|
||||
if not is_active:
|
||||
self._log_audit(username, 'login_failed', False, ip_address, 'Account inactive')
|
||||
return {'success': False, 'error': 'Account is inactive'}
|
||||
|
||||
if not self.verify_password(password, password_hash):
|
||||
self._record_login_attempt(username, False)
|
||||
self._log_audit(username, 'login_failed', False, ip_address, 'Invalid password')
|
||||
return {'success': False, 'error': 'Invalid credentials'}
|
||||
|
||||
# Password is correct, create session (2FA removed)
|
||||
return self._create_session(username, role, ip_address, remember_me)
|
||||
|
||||
    def _create_session(self, username: str, role: str, ip_address: str = None, remember_me: bool = False) -> Dict:
        """Create a new session and return token - matches backup-central format.

        Issues a JWT (longer-lived when *remember_me* is set), records it as a
        row in the sessions table, stamps last_login, logs an audit entry, and
        returns the login response payload.

        NOTE(review): sessionId is returned to the caller but never persisted;
        session lookup in verify_session is keyed by the JWT itself.
        """
        import json
        import uuid

        # Create JWT token
        with sqlite3.connect(self.db_path, timeout=30.0) as conn:
            cursor = conn.cursor()

            # Get user email and preferences
            cursor.execute("SELECT email, preferences FROM users WHERE username = ?", (username,))
            row = cursor.fetchone()
            email = row[0] if row and row[0] else None

            # Parse preferences JSON (malformed blob degrades to empty dict)
            preferences = {}
            if row and row[1]:
                try:
                    preferences = json.loads(row[1])
                except Exception:
                    pass

            # Use longer expiration if remember_me is enabled
            expire_minutes = ACCESS_TOKEN_REMEMBER_MINUTES if remember_me else ACCESS_TOKEN_EXPIRE_MINUTES
            token_data = {"sub": username, "role": role, "email": email}
            token = self.create_access_token(token_data, expires_delta=timedelta(minutes=expire_minutes))

            # Generate session ID
            sessionId = str(uuid.uuid4())

            now = datetime.now().isoformat()
            expires_at = (datetime.now() + timedelta(minutes=expire_minutes)).isoformat()

            # Session row keyed by the JWT; expiry mirrors the token's exp claim.
            cursor.execute("""
                INSERT INTO sessions (session_token, username, created_at, expires_at, ip_address)
                VALUES (?, ?, ?, ?, ?)
            """, (token, username, now, expires_at, ip_address))

            # Update last login
            cursor.execute("""
                UPDATE users SET last_login = ? WHERE username = ?
            """, (now, username))

            conn.commit()

            self._log_audit(username, 'login_success', True, ip_address)

            # Return structure matching backup-central (with preferences)
            return {
                'success': True,
                'token': token,
                'sessionId': sessionId,
                'user': {
                    'username': username,
                    'role': role,
                    'email': email,
                    'preferences': preferences
                }
            }
|
||||
|
||||
def _log_audit(self, username: str, action: str, success: bool, ip_address: str = None, details: str = None):
|
||||
"""Log an audit event"""
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO auth_audit (username, action, success, ip_address, details, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (username, action, int(success), ip_address, details, datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
|
||||
def get_user(self, username: str) -> Optional[Dict]:
|
||||
"""Get user information"""
|
||||
import json
|
||||
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT username, role, email, is_active, totp_enabled, duo_enabled, duo_username, created_at, last_login, preferences
|
||||
FROM users WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
# Parse preferences JSON
|
||||
preferences = {}
|
||||
if row[9]:
|
||||
try:
|
||||
preferences = json.loads(row[9])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
'username': row[0],
|
||||
'role': row[1],
|
||||
'email': row[2],
|
||||
'is_active': bool(row[3]),
|
||||
'totp_enabled': bool(row[4]),
|
||||
'duo_enabled': bool(row[5]),
|
||||
'duo_username': row[6],
|
||||
'created_at': row[7],
|
||||
'last_login': row[8],
|
||||
'preferences': preferences
|
||||
}
|
||||
|
||||
def verify_session(self, token: str) -> Optional[Dict]:
|
||||
"""Verify a session token"""
|
||||
payload = self.verify_token(token)
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
return None
|
||||
|
||||
# Check if session exists and is valid
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT expires_at FROM sessions
|
||||
WHERE session_token = ? AND username = ?
|
||||
""", (token, username))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(row[0])
|
||||
if expires_at < datetime.now():
|
||||
# Session expired
|
||||
return None
|
||||
|
||||
return payload
|
||||
|
||||
def logout(self, token: str):
|
||||
"""Logout and invalidate session"""
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM sessions WHERE session_token = ?", (token,))
|
||||
conn.commit()
|
||||
|
||||
# ============================================================================
|
||||
# DUO 2FA METHODS (Duo Universal Prompt)
|
||||
# ============================================================================
|
||||
# Duo authentication is now handled via DuoManager using Universal Prompt
|
||||
# The flow is: login → redirect to Duo → callback → complete login
|
||||
# All Duo operations are delegated to self.duo_manager
|
||||
|
||||
def change_password(self, username: str, current_password: str, new_password: str, ip_address: str = None) -> Dict:
|
||||
"""Change user password"""
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get current password hash
|
||||
cursor.execute("SELECT password_hash FROM users WHERE username = ?", (username,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return {'success': False, 'error': 'User not found'}
|
||||
|
||||
current_hash = row[0]
|
||||
|
||||
# Verify current password
|
||||
if not self.verify_password(current_password, current_hash):
|
||||
self._log_audit(username, 'password_change_failed', False, ip_address, 'Invalid current password')
|
||||
return {'success': False, 'error': 'Current password is incorrect'}
|
||||
|
||||
# Validate new password (minimum 8 characters)
|
||||
if len(new_password) < 8:
|
||||
return {'success': False, 'error': 'New password must be at least 8 characters'}
|
||||
|
||||
# Hash new password
|
||||
new_hash = self.get_password_hash(new_password)
|
||||
|
||||
# Update password
|
||||
cursor.execute("""
|
||||
UPDATE users SET password_hash = ?
|
||||
WHERE username = ?
|
||||
""", (new_hash, username))
|
||||
|
||||
# Security: Invalidate ALL existing sessions for this user
|
||||
# This ensures compromised sessions are revoked when password changes
|
||||
cursor.execute("""
|
||||
DELETE FROM sessions WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
sessions_invalidated = cursor.rowcount
|
||||
|
||||
conn.commit()
|
||||
|
||||
self._log_audit(username, 'password_changed', True, ip_address,
|
||||
f'Invalidated {sessions_invalidated} sessions')
|
||||
|
||||
return {'success': True, 'sessions_invalidated': sessions_invalidated}
|
||||
|
||||
def update_preferences(self, username: str, preferences: dict) -> Dict:
|
||||
"""Update user preferences (theme, etc.)"""
|
||||
import json
|
||||
|
||||
with sqlite3.connect(self.db_path, timeout=30.0) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get current preferences
|
||||
cursor.execute("SELECT preferences FROM users WHERE username = ?", (username,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return {'success': False, 'error': 'User not found'}
|
||||
|
||||
# Merge with existing preferences
|
||||
current_prefs = {}
|
||||
if row[0]:
|
||||
try:
|
||||
current_prefs = json.loads(row[0])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
current_prefs.update(preferences)
|
||||
|
||||
# Update preferences
|
||||
cursor.execute("""
|
||||
UPDATE users SET preferences = ?
|
||||
WHERE username = ?
|
||||
""", (json.dumps(current_prefs), username))
|
||||
|
||||
conn.commit()
|
||||
|
||||
return {'success': True, 'preferences': current_prefs}
|
||||
224
web/backend/cache_manager.py
Normal file
224
web/backend/cache_manager.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Redis-based cache manager for Media Downloader API
|
||||
Provides caching for expensive queries with configurable TTL
|
||||
"""
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Callable
|
||||
from functools import wraps
|
||||
import redis
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
# Add parent path to allow imports from modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from modules.universal_logger import get_logger
|
||||
from web.backend.core.config import settings
|
||||
|
||||
logger = get_logger('CacheManager')
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""Redis cache manager with automatic connection handling"""
|
||||
|
||||
    def __init__(self, host: str = '127.0.0.1', port: int = 6379, db: int = 0, ttl: int = 300):
        """
        Initialize cache manager

        Connects eagerly via _connect(); on failure the manager stays usable
        but every cache operation becomes a no-op (see is_available).

        Args:
            host: Redis host (default: 127.0.0.1)
            port: Redis port (default: 6379)
            db: Redis database number (default: 0)
            ttl: Default TTL in seconds (default: 300 = 5 minutes)
        """
        self.host = host
        self.port = port
        self.db = db
        self.default_ttl = ttl
        # Set by _connect(); None means caching is disabled.
        self._redis = None
        self._connect()
|
||||
|
||||
def _connect(self):
|
||||
"""Connect to Redis with error handling"""
|
||||
try:
|
||||
self._redis = redis.Redis(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
db=self.db,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=2,
|
||||
socket_timeout=2
|
||||
)
|
||||
# Test connection
|
||||
self._redis.ping()
|
||||
logger.info(f"Connected to Redis at {self.host}:{self.port}", module="Redis")
|
||||
except RedisError as e:
|
||||
logger.warning(f"Redis connection failed: {e}. Caching disabled.", module="Redis")
|
||||
self._redis = None
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Redis is available"""
|
||||
if self._redis is None:
|
||||
return False
|
||||
try:
|
||||
self._redis.ping()
|
||||
return True
|
||||
except RedisError:
|
||||
return False
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value (deserialized from JSON) or None if not found/error
|
||||
"""
|
||||
if not self.is_available:
|
||||
return None
|
||||
|
||||
try:
|
||||
value = self._redis.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
return json.loads(value)
|
||||
except (RedisError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Cache get error for key '{key}': {e}", module="Cache")
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: Any, ttl: Optional[int] = None):
|
||||
"""
|
||||
Set value in cache
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache (will be JSON serialized)
|
||||
ttl: TTL in seconds (default: use default_ttl)
|
||||
"""
|
||||
if not self.is_available:
|
||||
return
|
||||
|
||||
ttl = ttl if ttl is not None else self.default_ttl
|
||||
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
self._redis.setex(key, ttl, serialized)
|
||||
except (RedisError, TypeError) as e:
|
||||
logger.warning(f"Cache set error for key '{key}': {e}", module="Cache")
|
||||
|
||||
def delete(self, key: str):
|
||||
"""
|
||||
Delete key from cache
|
||||
|
||||
Args:
|
||||
key: Cache key to delete
|
||||
"""
|
||||
if not self.is_available:
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis.delete(key)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Cache delete error for key '{key}': {e}", module="Cache")
|
||||
|
||||
def clear(self, pattern: str = "*"):
|
||||
"""
|
||||
Clear cache keys matching pattern
|
||||
|
||||
Args:
|
||||
pattern: Redis key pattern (default: "*" clears all)
|
||||
"""
|
||||
if not self.is_available:
|
||||
return
|
||||
|
||||
try:
|
||||
keys = self._redis.keys(pattern)
|
||||
if keys:
|
||||
self._redis.delete(*keys)
|
||||
logger.info(f"Cleared {len(keys)} cache keys matching '{pattern}'", module="Cache")
|
||||
except RedisError as e:
|
||||
logger.warning(f"Cache clear error for pattern '{pattern}': {e}", module="Cache")
|
||||
|
||||
def cached(self, key_prefix: str, ttl: Optional[int] = None):
|
||||
"""
|
||||
Decorator for caching function results
|
||||
|
||||
Args:
|
||||
key_prefix: Prefix for cache key (full key includes function args)
|
||||
ttl: TTL in seconds (default: use default_ttl)
|
||||
|
||||
Example:
|
||||
@cache_manager.cached('stats', ttl=300)
|
||||
def get_download_stats(platform: str, days: int):
|
||||
# Expensive query
|
||||
return stats
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# Build cache key from function args
|
||||
cache_key = f"{key_prefix}:{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}"
|
||||
|
||||
# Try to get from cache
|
||||
cached_result = self.get(cache_key)
|
||||
if cached_result is not None:
|
||||
logger.debug(f"Cache HIT: {cache_key}", module="Cache")
|
||||
return cached_result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache MISS: {cache_key}", module="Cache")
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
self.set(cache_key, result, ttl)
|
||||
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# Build cache key from function args
|
||||
cache_key = f"{key_prefix}:{func.__name__}:{hash((args, tuple(sorted(kwargs.items()))))}"
|
||||
|
||||
# Try to get from cache
|
||||
cached_result = self.get(cache_key)
|
||||
if cached_result is not None:
|
||||
logger.debug(f"Cache HIT: {cache_key}", module="Cache")
|
||||
return cached_result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache MISS: {cache_key}", module="Cache")
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
self.set(cache_key, result, ttl)
|
||||
|
||||
return result
|
||||
|
||||
# Return appropriate wrapper based on function type
|
||||
import asyncio
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Global cache manager instance (use centralized config)
|
||||
# Global cache manager instance, configured from the centralized Settings
# object (env-overridable REDIS_* values). Import-time side effect: attempts
# a Redis connection; on failure the manager degrades to a no-op cache.
cache_manager = CacheManager(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    ttl=settings.REDIS_TTL
)
|
||||
|
||||
|
||||
def invalidate_download_cache():
    """Drop every cache namespace affected by a change to download data."""
    for pattern in ("downloads:*", "stats:*", "filters:*"):
        cache_manager.clear(pattern)
    logger.info("Invalidated download-related caches", module="Cache")
|
||||
1
web/backend/core/__init__.py
Normal file
1
web/backend/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Core module exports
|
||||
121
web/backend/core/config.py
Normal file
121
web/backend/core/config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Unified Configuration Manager
|
||||
|
||||
Provides a single source of truth for all configuration values with a clear hierarchy:
|
||||
1. Environment variables (highest priority)
|
||||
2. .env file
|
||||
3. Database settings
|
||||
4. Hardcoded defaults (lowest priority)
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from functools import lru_cache
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file at the repository root.
# Path resolution: web/backend/core/config.py -> four parents up = repo root.
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
ENV_PATH = PROJECT_ROOT / '.env'
if ENV_PATH.exists():
    # python-dotenv's default leaves values already present in os.environ
    # untouched, so real environment variables win over the .env file.
    load_dotenv(ENV_PATH)
|
||||
|
||||
|
||||
class Settings:
    """Centralized configuration settings.

    Values resolve as: environment variable (possibly loaded from .env at
    import time) first, then the hardcoded default given here. All values
    are class attributes evaluated once at import; path-valued settings are
    pathlib.Path objects.
    """

    # Project paths
    PROJECT_ROOT: Path = PROJECT_ROOT
    DB_PATH: Path = PROJECT_ROOT / 'database' / 'media_downloader.db'
    CONFIG_PATH: Path = PROJECT_ROOT / 'config' / 'settings.json'
    LOG_PATH: Path = PROJECT_ROOT / 'logs'
    TEMP_DIR: Path = PROJECT_ROOT / 'temp'
    COOKIES_DIR: Path = PROJECT_ROOT / 'cookies'
    DATA_DIR: Path = PROJECT_ROOT / 'data'
    VENV_BIN: Path = PROJECT_ROOT / 'venv' / 'bin'
    MEDIA_BASE_PATH: Path = Path(os.getenv("MEDIA_BASE_PATH", "/opt/immich/md"))
    REVIEW_PATH: Path = Path(os.getenv("REVIEW_PATH", "/opt/immich/review"))
    RECYCLE_PATH: Path = Path(os.getenv("RECYCLE_PATH", "/opt/immich/recycle"))

    # External tool paths (use venv by default)
    YT_DLP_PATH: Path = VENV_BIN / 'yt-dlp'
    GALLERY_DL_PATH: Path = VENV_BIN / 'gallery-dl'
    PYTHON_PATH: Path = VENV_BIN / 'python3'

    # Database configuration
    DATABASE_BACKEND: str = os.getenv("DATABASE_BACKEND", "sqlite")
    DATABASE_URL: str = os.getenv("DATABASE_URL", "")
    DB_POOL_SIZE: int = int(os.getenv("DB_POOL_SIZE", "20"))
    DB_CONNECTION_TIMEOUT: float = float(os.getenv("DB_CONNECTION_TIMEOUT", "30.0"))

    # Process timeouts (seconds)
    PROCESS_TIMEOUT_SHORT: int = int(os.getenv("PROCESS_TIMEOUT_SHORT", "2"))
    PROCESS_TIMEOUT_MEDIUM: int = int(os.getenv("PROCESS_TIMEOUT_MEDIUM", "10"))
    PROCESS_TIMEOUT_LONG: int = int(os.getenv("PROCESS_TIMEOUT_LONG", "120"))

    # WebSocket configuration
    WEBSOCKET_TIMEOUT: float = float(os.getenv("WEBSOCKET_TIMEOUT", "30.0"))

    # Background task intervals (seconds)
    ACTIVITY_LOG_INTERVAL: int = int(os.getenv("ACTIVITY_LOG_INTERVAL", "300"))
    EMBEDDING_QUEUE_CHECK_INTERVAL: int = int(os.getenv("EMBEDDING_QUEUE_CHECK_INTERVAL", "30"))
    EMBEDDING_BATCH_SIZE: int = int(os.getenv("EMBEDDING_BATCH_SIZE", "10"))
    EMBEDDING_BATCH_LIMIT: int = int(os.getenv("EMBEDDING_BATCH_LIMIT", "50000"))

    # Thumbnail generation
    THUMBNAIL_DB_TIMEOUT: float = float(os.getenv("THUMBNAIL_DB_TIMEOUT", "30.0"))
    THUMBNAIL_SIZE: tuple = (300, 300)

    # Video proxy configuration
    PROXY_FILE_CACHE_DURATION: int = int(os.getenv("PROXY_FILE_CACHE_DURATION", "300"))

    # Redis cache configuration
    REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1")
    REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
    REDIS_DB: int = int(os.getenv("REDIS_DB", "0"))
    REDIS_TTL: int = int(os.getenv("REDIS_TTL", "300"))

    # API configuration
    API_VERSION: str = "13.13.1"
    API_TITLE: str = "Media Downloader API"
    API_DESCRIPTION: str = "Web API for managing media downloads"

    # CORS configuration.
    # Fix: strip whitespace around each origin and drop empty entries, so
    # values like "http://a, http://b" or a trailing comma don't produce
    # origins that can never match (CORS origin comparison is exact-string).
    ALLOWED_ORIGINS: list = [
        origin.strip()
        for origin in os.getenv(
            "ALLOWED_ORIGINS",
            "http://localhost:5173,http://localhost:3000,http://127.0.0.1:5173,http://127.0.0.1:3000"
        ).split(",")
        if origin.strip()
    ]

    # Security
    SECURE_COOKIES: bool = os.getenv("SECURE_COOKIES", "false").lower() == "true"
    SESSION_SECRET_KEY: str = os.getenv("SESSION_SECRET_KEY", "")
    CSRF_SECRET_KEY: str = os.getenv("CSRF_SECRET_KEY", "")

    # Rate limiting
    RATE_LIMIT_DEFAULT: str = os.getenv("RATE_LIMIT_DEFAULT", "100/minute")
    RATE_LIMIT_STRICT: str = os.getenv("RATE_LIMIT_STRICT", "10/minute")

    @classmethod
    def get(cls, key: str, default: Any = None) -> Any:
        """Get a configuration value by key (attribute lookup with default)."""
        return getattr(cls, key, default)

    @classmethod
    def get_path(cls, key: str) -> Path:
        """Get a path configuration value, ensuring it's a Path object.

        Raises:
            ValueError: if *key* does not name a defined setting.
        """
        value = getattr(cls, key, None)
        if value is None:
            raise ValueError(f"Configuration key {key} not found")
        if isinstance(value, Path):
            return value
        return Path(value)
|
||||
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance (memoized via lru_cache).

    All values are class attributes evaluated at import time, so the
    instance is effectively a read-only snapshot of the environment.
    """
    return Settings()


# Convenience export: module-level singleton used across the backend.
settings = get_settings()
|
||||
206
web/backend/core/dependencies.py
Normal file
206
web/backend/core/dependencies.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Shared Dependencies
|
||||
|
||||
Provides dependency injection for FastAPI routes.
|
||||
All authentication, database access, and common dependencies are defined here.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request, WebSocket
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
|
||||
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
# Security
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GLOBAL STATE (imported from main app)
|
||||
# ============================================================================
|
||||
|
||||
class AppState:
    """Global application state - singleton pattern"""
    # Class-level slot holding the single shared instance.
    _instance = None

    def __new__(cls):
        # Create the singleton on first use; later calls return the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every AppState() call; bail out after the first.
        if self._initialized:
            return
        self.db = None  # database handle; None until the main app wires it up
        self.config: Dict = {}  # application configuration dict
        self.scheduler = None  # scheduler instance; None until initialized
        self.websocket_manager = None  # ConnectionManager instance for broadcasts
        self.scraper_event_emitter = None  # ScraperEventEmitter for real-time scraping monitor
        self.active_scraper_sessions: Dict = {}  # Track active scraping sessions for monitor page
        self.running_platform_downloads: Dict = {}  # Track running platform downloads {platform: process}
        self.download_tasks: Dict = {}  # in-flight download task handles
        self.auth = None  # auth backend; verify_session()/get_user() are called on it
        self.settings = None  # settings manager; None until initialized
        self.face_recognition_semaphore = None  # limits concurrent face-recognition work
        self.indexing_running: bool = False  # True while an indexing pass is active
        self.indexing_start_time = None  # start time of the current indexing pass
        self.review_rescan_running: bool = False  # True while a review rescan is active
        self.review_rescan_progress: Dict = {"current": 0, "total": 0, "current_file": ""}
        self._initialized = True
|
||||
|
||||
|
||||
def get_app_state() -> AppState:
    """Return the process-wide AppState singleton."""
    return AppState()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION DEPENDENCIES
|
||||
# ============================================================================
|
||||
|
||||
async def get_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict:
    """
    Resolve the authenticated user for this request from a JWT token.

    The bearer token is taken from the Authorization header or, failing
    that, from the 'auth_token' cookie. Returns the decoded session payload.

    Raises:
        HTTPException: 401 when no/invalid token, 500 when auth is not set up.
    """
    state = get_app_state()

    # Header takes precedence over the cookie.
    bearer = credentials.credentials if credentials else request.cookies.get('auth_token')

    if not bearer:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if not state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")

    session = state.auth.verify_session(bearer)
    if not session:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return session
|
||||
|
||||
|
||||
async def get_current_user_optional(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict]:
    """Like get_current_user, but yields None instead of raising on auth failure."""
    user: Optional[Dict] = None
    try:
        user = await get_current_user(request, credentials)
    except HTTPException:
        pass  # anonymous access is allowed here
    return user
|
||||
|
||||
|
||||
async def get_current_user_media(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    token: Optional[str] = Query(None)
) -> Dict:
    """
    Authentication for media endpoints.

    Accepts the token from three places, in strict priority order:
      1. Authorization header (preferred)
      2. 'auth_token' cookie (lets <img>/<video> tags authenticate)
      3. ?token= query parameter (kept for backward compatibility)

    Raises:
        HTTPException: 401 for missing/invalid tokens, 500 if auth is unset.
    """
    state = get_app_state()

    if credentials:
        bearer = credentials.credentials
    elif 'auth_token' in request.cookies:
        # Cookie key present wins over the query parameter, even if empty.
        bearer = request.cookies.get('auth_token')
    else:
        bearer = token

    if not bearer:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if not state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")

    session = state.auth.verify_session(bearer)
    if not session:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return session
|
||||
|
||||
|
||||
async def require_admin(
    request: Request,
    current_user: Dict = Depends(get_current_user)
) -> Dict:
    """
    Dependency that restricts an endpoint to admin users.

    Builds on get_current_user and additionally verifies the user's stored
    role is 'admin'.

    Raises:
        HTTPException: 401 for a bad session or unknown user, 403 for
            non-admin users, 500 if auth was never initialized.
    """
    app_state = get_app_state()
    username = current_user.get('sub')

    if not username:
        raise HTTPException(status_code=401, detail="Invalid user session")

    # Consistent with get_current_user: fail loudly (500) instead of an
    # AttributeError if the auth backend was never wired up.
    if not app_state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")

    # Get user info to check role
    user_info = app_state.auth.get_user(username)
    if not user_info:
        raise HTTPException(status_code=401, detail="User not found")

    if user_info.get('role') != 'admin':
        logger.warning(f"User {username} attempted admin operation without admin role", module="Auth")
        raise HTTPException(status_code=403, detail="Admin role required")

    return current_user
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE DEPENDENCIES
|
||||
# ============================================================================
|
||||
|
||||
def get_database():
    """Dependency returning the shared database handle (500 if not ready)."""
    db = get_app_state().db
    if not db:
        raise HTTPException(status_code=500, detail="Database not initialized")
    return db
|
||||
|
||||
|
||||
def get_settings_manager():
    """Dependency returning the settings manager (500 if not ready)."""
    manager = get_app_state().settings
    if not manager:
        raise HTTPException(status_code=500, detail="Settings manager not initialized")
    return manager
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# UTILITY DEPENDENCIES
|
||||
# ============================================================================
|
||||
|
||||
def get_scheduler():
    """Dependency returning the scheduler instance (may be None)."""
    return get_app_state().scheduler
|
||||
|
||||
|
||||
def get_face_semaphore():
    """Dependency returning the semaphore bounding concurrent face recognition."""
    return get_app_state().face_recognition_semaphore
|
||||
346
web/backend/core/exceptions.py
Normal file
346
web/backend/core/exceptions.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Custom Exception Classes
|
||||
|
||||
Provides specific exception types for better error handling and reporting.
|
||||
Replaces broad 'except Exception' handlers with specific, meaningful exceptions.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BASE EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class MediaDownloaderError(Exception):
    """Root of the Media Downloader exception hierarchy.

    Carries a human-readable message plus an optional details mapping that
    is surfaced verbatim in API error payloads.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.details = {} if not details else details

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the error as an {error, message, details} mapping."""
        return {
            "error": type(self).__name__,
            "message": self.message,
            "details": self.details,
        }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class DatabaseError(MediaDownloaderError):
    """Database operation failed"""
    # Maps to HTTP 500 via to_http_exception.
    pass


class DatabaseConnectionError(DatabaseError):
    """Failed to connect to database"""
    # Maps to HTTP 500 via to_http_exception.
    pass


class DatabaseQueryError(DatabaseError):
    """Database query failed"""
    # Maps to HTTP 500 via to_http_exception.
    pass


class RecordNotFoundError(DatabaseError):
    """Requested record not found in database"""
    # Maps to HTTP 404 via to_http_exception.
    pass


# Alias for generic "not found" scenarios
NotFoundError = RecordNotFoundError


class DuplicateRecordError(DatabaseError):
    """Attempted to insert duplicate record"""
    # Maps to HTTP 409 via to_http_exception.
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DOWNLOAD EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class DownloadError(MediaDownloaderError):
    """Download operation failed"""
    # Not in to_http_exception's map; falls back to HTTP 500.
    pass


class NetworkError(DownloadError):
    """Network request failed"""
    # Maps to HTTP 502 via to_http_exception.
    pass


class RateLimitError(DownloadError):
    """Rate limit exceeded"""
    # Maps to HTTP 429 via to_http_exception.
    pass


class AuthenticationError(DownloadError):
    """Authentication failed for external service"""
    # Maps to HTTP 401 via to_http_exception.
    pass


class PlatformUnavailableError(DownloadError):
    """Platform or service is unavailable"""
    # Maps to HTTP 503 via to_http_exception.
    pass


class ContentNotFoundError(DownloadError):
    """Requested content not found on platform"""
    # Maps to HTTP 404 via to_http_exception.
    pass


class InvalidURLError(DownloadError):
    """Invalid or malformed URL"""
    # Maps to HTTP 400 via to_http_exception.
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FILE OPERATION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class FileOperationError(MediaDownloaderError):
    """File operation failed"""
    # Not in to_http_exception's map; falls back to HTTP 500.
    pass


class MediaFileNotFoundError(FileOperationError):
    """File not found on filesystem"""
    # Maps to HTTP 404 via to_http_exception.
    pass


class FileAccessError(FileOperationError):
    """Cannot access file (permissions, etc.)"""
    # Maps to HTTP 403 via to_http_exception.
    pass


class FileHashError(FileOperationError):
    """File hash computation failed"""
    pass


class FileMoveError(FileOperationError):
    """Failed to move file"""
    pass


class ThumbnailError(FileOperationError):
    """Thumbnail generation failed"""
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VALIDATION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class ValidationError(MediaDownloaderError):
    """Input validation failed"""
    # Maps to HTTP 400 via to_http_exception.
    pass


class InvalidParameterError(ValidationError):
    """Invalid parameter value"""
    # Maps to HTTP 400 via to_http_exception.
    pass


class MissingParameterError(ValidationError):
    """Required parameter missing"""
    # Maps to HTTP 400 via to_http_exception.
    pass


class PathTraversalError(ValidationError):
    """Path traversal attempt detected"""
    # Maps to HTTP 403 via to_http_exception.
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION/AUTHORIZATION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class AuthError(MediaDownloaderError):
    """Authentication or authorization error"""
    # Maps to HTTP 401 via to_http_exception.
    pass


class TokenExpiredError(AuthError):
    """Authentication token has expired"""
    # Maps to HTTP 401 via to_http_exception.
    pass


class InvalidTokenError(AuthError):
    """Authentication token is invalid"""
    # Maps to HTTP 401 via to_http_exception.
    pass


class InsufficientPermissionsError(AuthError):
    """User lacks required permissions"""
    # Maps to HTTP 403 via to_http_exception.
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SERVICE EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class ServiceError(MediaDownloaderError):
    """External service error"""
    # Maps to HTTP 502 via to_http_exception.
    pass


class FlareSolverrError(ServiceError):
    """FlareSolverr service error"""
    # Maps to HTTP 502 via to_http_exception.
    pass


class RedisError(ServiceError):
    """Redis cache error"""
    # Matches the ServiceError entry in to_http_exception -> HTTP 502.
    pass


class SchedulerError(ServiceError):
    """Scheduler service error"""
    # Maps to HTTP 503 via to_http_exception.
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FACE RECOGNITION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class FaceRecognitionError(MediaDownloaderError):
    """Face recognition operation failed"""
    # Not in to_http_exception's map; falls back to HTTP 500.
    pass


class NoFaceDetectedError(FaceRecognitionError):
    """No face detected in image"""
    pass


class FaceEncodingError(FaceRecognitionError):
    """Failed to encode face"""
    pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HTTP EXCEPTION HELPERS
|
||||
# ============================================================================
|
||||
|
||||
def to_http_exception(error: MediaDownloaderError) -> HTTPException:
    """Translate a MediaDownloaderError into an HTTPException.

    The mapping is ordered most-specific-first so isinstance() matches a
    subclass before its parent; unmapped error types default to 500.
    """
    # (exception type, HTTP status) — order matters: subclasses precede parents.
    mapping = (
        # 400 Bad Request
        (InvalidParameterError, 400),
        (MissingParameterError, 400),
        (InvalidURLError, 400),
        (ValidationError, 400),
        # 401 Unauthorized
        (TokenExpiredError, 401),
        (InvalidTokenError, 401),
        (AuthenticationError, 401),
        (AuthError, 401),
        # 403 Forbidden
        (InsufficientPermissionsError, 403),
        (PathTraversalError, 403),
        (FileAccessError, 403),
        # 404 Not Found
        (RecordNotFoundError, 404),
        (MediaFileNotFoundError, 404),
        (ContentNotFoundError, 404),
        # 409 Conflict
        (DuplicateRecordError, 409),
        # 429 Too Many Requests
        (RateLimitError, 429),
        # 500 Internal Server Error
        (DatabaseConnectionError, 500),
        (DatabaseQueryError, 500),
        (DatabaseError, 500),
        # 502 Bad Gateway
        (FlareSolverrError, 502),
        (ServiceError, 502),
        (NetworkError, 502),
        # 503 Service Unavailable
        (PlatformUnavailableError, 503),
        (SchedulerError, 503),
    )

    status_code = next(
        (code for exc_type, code in mapping if isinstance(error, exc_type)),
        500,  # default for unmapped error types
    )

    return HTTPException(status_code=status_code, detail=error.to_dict())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EXCEPTION HANDLER DECORATOR
|
||||
# ============================================================================
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar
|
||||
import asyncio
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def handle_exceptions(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator converting MediaDownloaderError into HTTPException.

    Use this on route handlers for consistent error responses. Works with
    both sync and async callables. Unexpected exceptions are logged and
    surfaced as a generic 500. Original exceptions are chained
    (``raise ... from e``) so the real traceback is preserved in logs.
    """

    def _unexpected(e: Exception) -> HTTPException:
        # Log and wrap anything that isn't one of our domain errors.
        from modules.universal_logger import get_logger
        get_logger('API').error(f"Unexpected error in {func.__name__}: {e}", module="Core")
        return HTTPException(
            status_code=500,
            detail={"error": "InternalError", "message": str(e)}
        )

    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except MediaDownloaderError as e:
                raise to_http_exception(e) from e
            except HTTPException:
                raise  # Already an HTTP error; pass through unchanged.
            except Exception as e:
                raise _unexpected(e) from e
        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except MediaDownloaderError as e:
            raise to_http_exception(e) from e
        except HTTPException:
            raise  # Already an HTTP error; pass through unchanged.
        except Exception as e:
            raise _unexpected(e) from e
    return sync_wrapper
|
||||
377
web/backend/core/http_client.py
Normal file
377
web/backend/core/http_client.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Async HTTP Client
|
||||
|
||||
Provides async HTTP client to replace synchronous requests library in FastAPI endpoints.
|
||||
This prevents blocking the event loop and improves concurrency.
|
||||
|
||||
Usage:
|
||||
from web.backend.core.http_client import http_client, get_http_client
|
||||
|
||||
# In async endpoint
|
||||
async def my_endpoint():
|
||||
async with get_http_client() as client:
|
||||
response = await client.get("https://api.example.com/data")
|
||||
return response.json()
|
||||
|
||||
# Or use the singleton
|
||||
response = await http_client.get("https://api.example.com/data")
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
from .config import settings
|
||||
from .exceptions import NetworkError, ServiceError
|
||||
|
||||
|
||||
# Default timeouts
|
||||
DEFAULT_TIMEOUT = httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=30.0,
|
||||
write=10.0,
|
||||
pool=10.0
|
||||
)
|
||||
|
||||
# Longer timeout for slow operations
|
||||
LONG_TIMEOUT = httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=120.0,
|
||||
write=30.0,
|
||||
pool=10.0
|
||||
)
|
||||
|
||||
|
||||
class AsyncHTTPClient:
    """
    Async HTTP client wrapper with retry logic and error handling.

    Features:
    - Connection pooling (one shared httpx.AsyncClient)
    - Automatic retries with exponential backoff
    - Timeout handling
    - Error conversion to custom exceptions (NetworkError / ServiceError)
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: Optional[httpx.Timeout] = None,
        headers: Optional[Dict[str, str]] = None,
        max_retries: int = 3
    ):
        """
        Args:
            base_url: Optional base URL prepended to request paths.
            timeout: httpx timeout config; defaults to DEFAULT_TIMEOUT.
            headers: Default headers sent with every request.
            max_retries: Number of attempts for retryable failures (clamped >= 1).
        """
        self.base_url = base_url
        self.timeout = timeout or DEFAULT_TIMEOUT
        self.headers = headers or {}
        # Clamp to at least one attempt: with 0 the retry loop never runs and
        # `raise last_error` would raise None (TypeError).
        self.max_retries = max(1, max_retries)
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Lazily create (or recreate after close) the pooled httpx client."""
        if self._client is None or self._client.is_closed:
            client_kwargs = {
                "timeout": self.timeout,
                "headers": self.headers,
                "follow_redirects": True,
                "limits": httpx.Limits(
                    max_keepalive_connections=20,
                    max_connections=100,
                    keepalive_expiry=30.0
                )
            }
            # Only add base_url if it's not None (httpx rejects None).
            if self.base_url is not None:
                client_kwargs["base_url"] = self.base_url

            self._client = httpx.AsyncClient(**client_kwargs)
        return self._client

    async def close(self):
        """Close the underlying client and release pooled connections."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def _request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> httpx.Response:
        """
        Make an HTTP request with retry logic.

        Server errors (5xx) and network failures are retried up to
        max_retries times with exponential backoff; client errors (4xx)
        and successes are returned immediately.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: URL to request
            **kwargs: Additional arguments for httpx

        Returns:
            httpx.Response

        Raises:
            NetworkError: On network failures
            ServiceError: On HTTP 5xx errors
        """
        client = await self._get_client()
        last_error: Optional[Exception] = None

        for attempt in range(self.max_retries):
            try:
                response = await client.request(method, url, **kwargs)

                if response.status_code >= 500:
                    raise ServiceError(
                        f"Server error: {response.status_code}",
                        {"url": url, "status": response.status_code}
                    )
                # 4xx and success responses are returned as-is (no retry).
                return response

            except ServiceError as e:
                # BUGFIX: previously this fell into the blanket `except
                # Exception` below and was re-wrapped as NetworkError, so
                # callers never saw the documented ServiceError for 5xx.
                last_error = e
            except httpx.ConnectError as e:
                last_error = NetworkError(
                    f"Connection failed: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            except httpx.TimeoutException as e:
                last_error = NetworkError(
                    f"Request timed out: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            except httpx.HTTPStatusError as e:
                last_error = ServiceError(
                    f"HTTP error: {e}",
                    {"url": url, "status": e.response.status_code}
                )
                # Don't retry client errors — they won't succeed later.
                if e.response.status_code < 500:
                    raise last_error
            except Exception as e:
                last_error = NetworkError(
                    f"Request failed: {e}",
                    {"url": url, "attempt": attempt + 1}
                )

            # Wait before retry (exponential backoff).
            if attempt < self.max_retries - 1:
                await asyncio.sleep(2 ** attempt)

        raise last_error

    # Convenience methods

    async def get(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> httpx.Response:
        """Make GET request."""
        request_timeout = httpx.Timeout(timeout) if timeout else None
        return await self._request(
            "GET",
            url,
            params=params,
            headers=headers,
            timeout=request_timeout,
            **kwargs
        )

    async def post(
        self,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> httpx.Response:
        """Make POST request."""
        request_timeout = httpx.Timeout(timeout) if timeout else None
        return await self._request(
            "POST",
            url,
            data=data,
            json=json,
            headers=headers,
            timeout=request_timeout,
            **kwargs
        )

    async def put(
        self,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make PUT request."""
        return await self._request(
            "PUT",
            url,
            data=data,
            json=json,
            headers=headers,
            **kwargs
        )

    async def delete(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make DELETE request."""
        return await self._request(
            "DELETE",
            url,
            headers=headers,
            **kwargs
        )

    async def head(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make HEAD request."""
        return await self._request(
            "HEAD",
            url,
            headers=headers,
            **kwargs
        )
|
||||
|
||||
|
||||
# Singleton client for general use
|
||||
http_client = AsyncHTTPClient()
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_http_client(
    base_url: Optional[str] = None,
    timeout: Optional[httpx.Timeout] = None,
    headers: Optional[Dict[str, str]] = None
):
    """
    Yield a short-lived AsyncHTTPClient, guaranteed to be closed on exit.

    Usage:
        async with get_http_client(base_url="https://api.example.com") as client:
            response = await client.get("/endpoint")
    """
    temp_client = AsyncHTTPClient(base_url=base_url, timeout=timeout, headers=headers)
    try:
        yield temp_client
    finally:
        await temp_client.close()
|
||||
|
||||
|
||||
# Helper functions for common patterns
|
||||
|
||||
async def fetch_json(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """
    GET *url* via the shared client and return the parsed JSON body.

    Args:
        url: URL to fetch
        params: Query parameters
        headers: Request headers

    Returns:
        Parsed JSON response
    """
    resp = await http_client.get(url, params=params, headers=headers)
    return resp.json()
|
||||
|
||||
|
||||
async def post_json(
    url: str,
    data: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """
    POST *data* as JSON to *url* and return the parsed JSON response.

    Args:
        url: URL to post to
        data: JSON data to send
        headers: Request headers

    Returns:
        Parsed JSON response
    """
    resp = await http_client.post(url, json=data, headers=headers)
    return resp.json()
|
||||
|
||||
|
||||
async def check_url_accessible(url: str, timeout: float = 5.0) -> bool:
    """
    Return True if a HEAD request to *url* succeeds with a status below 400.

    Any exception (network failure, timeout, HTTP error) counts as
    inaccessible rather than propagating.
    """
    try:
        resp = await http_client.head(url, timeout=timeout)
    except Exception:
        return False
    return resp.status_code < 400
|
||||
|
||||
|
||||
async def download_file(
    url: str,
    save_path: str,
    headers: Optional[Dict[str, str]] = None,
    chunk_size: int = 8192
) -> int:
    """
    Stream *url* to *save_path*, creating parent directories as needed.

    Args:
        url: URL to download
        save_path: Path to save file
        headers: Request headers
        chunk_size: Size of chunks for streaming

    Returns:
        Number of bytes downloaded
    """
    from pathlib import Path

    client = await http_client._get_client()
    written = 0

    async with client.stream("GET", url, headers=headers) as response:
        response.raise_for_status()

        destination = Path(save_path)
        destination.parent.mkdir(parents=True, exist_ok=True)

        with destination.open('wb') as out:
            async for piece in response.aiter_bytes(chunk_size):
                out.write(piece)
                written += len(piece)

    return written
|
||||
311
web/backend/core/responses.py
Normal file
311
web/backend/core/responses.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Standardized Response Format
|
||||
|
||||
Provides consistent response structures for all API endpoints.
|
||||
All responses follow a standardized format for error handling and data delivery.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Generic
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STANDARD RESPONSE MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ErrorDetail(BaseModel):
    """Standard error detail structure"""
    error: str = Field(..., description="Error type/code")
    message: str = Field(..., description="Human-readable error message")
    details: Optional[Dict[str, Any]] = Field(None, description="Additional error context")
    # UTC ISO-8601 with a trailing "Z" (e.g. "2024-01-01T00:00:00Z"),
    # generated at model construction time.
    timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"))
|
||||
|
||||
|
||||
class SuccessResponse(BaseModel):
    """Standard success response"""
    success: bool = True
    message: Optional[str] = None  # optional human-readable status message
    data: Optional[Any] = None  # optional payload
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
    """Standard paginated response"""
    items: List[Any]  # items for the current page
    total: int  # total item count across all pages
    page: int  # page number (1-based, per the paginated() builder's has_more math)
    page_size: int  # items per page
    has_more: bool  # True when further pages exist
|
||||
|
||||
|
||||
class ListResponse(BaseModel):
    """Standard list response with metadata"""
    items: List[Any]
    count: int  # number of entries in items
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RESPONSE BUILDERS
|
||||
# ============================================================================
|
||||
|
||||
def success(data: Any = None, message: Optional[str] = None) -> Dict[str, Any]:
    """Build a success response.

    Args:
        data: Optional payload; omitted from the response when None.
        message: Optional status message; omitted when falsy.
            (Annotation fixed to Optional[str] — PEP 484 disallows
            implicit Optional from a None default.)

    Returns:
        Dict containing at least {"success": True}.
    """
    response: Dict[str, Any] = {"success": True}
    if message:
        response["message"] = message
    if data is not None:
        response["data"] = data
    return response
|
||||
|
||||
|
||||
def error(
    error_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Build a standardized error response with a UTC "Z"-suffixed timestamp."""
    stamped_at = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return {
        "success": False,
        "error": error_type,
        "message": message,
        "details": {} if not details else details,
        "timestamp": stamped_at,
    }
|
||||
|
||||
|
||||
def paginated(
    items: List[Any],
    total: int,
    page: int,
    page_size: int
) -> Dict[str, Any]:
    """Build a page-number-based paginated response."""
    consumed = page * page_size  # items covered up to and including this page
    return {
        "items": items,
        "total": total,
        "page": page,
        "page_size": page_size,
        "has_more": consumed < total,
    }
|
||||
|
||||
|
||||
def list_response(items: List[Any], key: str = "items") -> Dict[str, Any]:
    """Wrap *items* under *key* alongside its element count."""
    return {key: items, "count": len(items)}
|
||||
|
||||
|
||||
def message_response(message: str) -> Dict[str, str]:
    """Wrap a bare status message in the standard shape."""
    return dict(message=message)
|
||||
|
||||
|
||||
def count_response(message: str, count: int) -> Dict[str, Any]:
    """Pair a status message with the number of affected items."""
    return dict(message=message, count=count)
|
||||
|
||||
|
||||
def id_response(resource_id: int, message: str = "Resource created successfully") -> Dict[str, Any]:
    """Report a newly created resource's ID plus a status message."""
    return dict(id=resource_id, message=message)
|
||||
|
||||
|
||||
def offset_paginated(
    items: List[Any],
    total: int,
    limit: int,
    offset: int,
    key: str = "items",
    include_timestamp: bool = False,
    **extra_fields
) -> Dict[str, Any]:
    """
    Build the offset/limit pagination envelope used across the routers.

    Args:
        items: Items for the current page.
        total: Total count of all matching items.
        limit: Page size.
        offset: Starting position of this page.
        key: Key under which the item list is placed (default "items").
        include_timestamp: When True, add a second-resolution UTC timestamp.
        **extra_fields: Arbitrary extra fields merged into the payload last
            (so they can override the standard keys if a caller insists).

    Returns:
        Dict shaped like ``{key: [...], "total": n, "limit": l, "offset": o, ...}``.

    Example:
        return offset_paginated(
            items=media_items,
            total=total_count,
            limit=50,
            offset=0,
            key="media",
            stats={"images": 100, "videos": 50}
        )
    """
    envelope: Dict[str, Any] = {
        key: items,
        "total": total,
        "limit": limit,
        "offset": offset,
    }

    if include_timestamp:
        envelope["timestamp"] = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    # Caller-supplied extras go in last.
    envelope.update(extra_fields)
    return envelope
|
||||
|
||||
|
||||
def batch_operation_response(
    succeeded: bool,
    processed: Optional[List[Any]] = None,
    errors: Optional[List[Dict[str, Any]]] = None,
    message: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build a response for batch operations (batch delete, batch move, etc.).

    This provides a standardized format for operations that affect multiple
    items. Counts are always present; the ``processed``/``errors``/``message``
    fields are included only when non-empty.

    Args:
        succeeded: Overall success status.
        processed: List of successfully processed items.
        errors: List of error details for failed items.
        message: Optional human-readable message.

    Returns:
        Dict with batch operation response structure, including a
        second-resolution UTC timestamp.

    Example:
        return batch_operation_response(
            succeeded=True,
            processed=["/path/to/file1.jpg", "/path/to/file2.jpg"],
            errors=[{"file": "/path/to/file3.jpg", "error": "File not found"}]
        )
    """
    # Annotations fixed to Optional[...] — the defaults are None, not lists.
    processed = processed or []
    errors = errors or []

    response: Dict[str, Any] = {
        "success": succeeded,
        "processed_count": len(processed),
        "error_count": len(errors),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    }

    if processed:
        response["processed"] = processed

    if errors:
        response["errors"] = errors

    if message:
        response["message"] = message

    return response
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATE/TIME UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def to_iso8601(dt: datetime) -> Optional[str]:
    """
    Convert a datetime to ISO 8601 in UTC with a trailing "Z".

    Naive datetimes are assumed to already be UTC; aware ones are converted.
    Returns None for None input.

    Example output: "2025-12-04T10:30:00Z"
    """
    if dt is None:
        return None

    utc_dt = (
        dt.replace(tzinfo=timezone.utc)
        if dt.tzinfo is None
        else dt.astimezone(timezone.utc)
    )
    return utc_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def from_iso8601(date_string: str) -> Optional[datetime]:
    """
    Parse an ISO 8601 date string into a timezone-aware UTC datetime.

    Tries a list of common explicit formats first, then falls back to
    ``datetime.fromisoformat``. Naive results are assumed to be UTC.

    Returns None for empty input or unparseable strings.
    """
    if not date_string:
        return None

    # Common formats to try (none carry a numeric offset; 'Z' is matched
    # literally by the first two patterns)
    formats = [
        "%Y-%m-%dT%H:%M:%SZ",        # ISO with Z
        "%Y-%m-%dT%H:%M:%S.%fZ",     # ISO with microseconds and Z
        "%Y-%m-%dT%H:%M:%S",         # ISO without Z
        "%Y-%m-%dT%H:%M:%S.%f",      # ISO with microseconds
        "%Y-%m-%d %H:%M:%S",         # Space separator
        "%Y-%m-%d",                  # Date only
    ]

    for fmt in formats:
        try:
            dt = datetime.strptime(date_string, fmt)
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            return dt
        except ValueError:
            continue

    # Fall back to fromisoformat for anything else (e.g. missing seconds or
    # an explicit numeric offset). Bug fix: previously a naive result from
    # this path was returned as-is, unlike the strptime path above — now
    # every successful parse returns an aware datetime.
    try:
        dt = datetime.fromisoformat(date_string.replace('Z', '+00:00'))
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt
    except ValueError:
        pass

    return None
|
||||
|
||||
|
||||
def format_timestamp(
    timestamp: Optional[str],
    output_format: str = "iso8601"
) -> Optional[str]:
    """
    Re-format a timestamp string.

    Args:
        timestamp: Input timestamp string (any form from_iso8601 accepts).
        output_format: "iso8601", "unix" (epoch seconds), or a strftime
            format string.

    Returns:
        The formatted timestamp; the original string unchanged when it
        cannot be parsed; None for empty input.
    """
    if not timestamp:
        return None

    parsed = from_iso8601(timestamp)
    if parsed is None:
        # Unparseable — hand back the caller's original value untouched.
        return timestamp

    if output_format == "iso8601":
        return to_iso8601(parsed)
    if output_format == "unix":
        return str(int(parsed.timestamp()))
    return parsed.strftime(output_format)
|
||||
|
||||
|
||||
def now_iso8601() -> str:
    """Get current UTC time in ISO 8601 format (e.g. "2025-12-04T10:30:00Z")."""
    return to_iso8601(datetime.now(timezone.utc))
|
||||
582
web/backend/core/utils.py
Normal file
582
web/backend/core/utils.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
Shared Utility Functions
|
||||
|
||||
Common helper functions used across multiple routers.
|
||||
"""
|
||||
|
||||
import io
|
||||
import sqlite3
|
||||
import hashlib
|
||||
import subprocess
|
||||
from collections import OrderedDict
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# THUMBNAIL LRU CACHE
|
||||
# ============================================================================
|
||||
|
||||
class ThumbnailLRUCache:
    """Thread-safe LRU cache holding raw thumbnail bytes in memory.

    Keeps hot thumbnails resident so repeated requests skip the SQLite
    lookup. Bounded both by entry count and by total byte size; oversized
    single items (>1 MB) are never cached. Used by the media.py and
    recycle.py routers.
    """

    # Entries larger than this are never cached.
    _SINGLE_ITEM_LIMIT = 1024 * 1024

    def __init__(self, max_size: int = 500, max_memory_mb: int = 100):
        self._cache: OrderedDict[str, bytes] = OrderedDict()
        self._lock = Lock()
        self._max_size = max_size
        self._max_memory = max_memory_mb * 1024 * 1024  # limit in bytes
        self._current_memory = 0

    def get(self, key: str) -> Optional[bytes]:
        """Return cached bytes for key (refreshing its LRU slot), or None."""
        with self._lock:
            data = self._cache.get(key)
            if data is not None:
                # Mark as most recently used.
                self._cache.move_to_end(key)
            return data

    def put(self, key: str, data: bytes) -> None:
        """Store data under key, evicting least-recently-used entries as needed."""
        size = len(data)
        # Refuse to cache single items that would dominate the budget.
        if size > self._SINGLE_ITEM_LIMIT:
            return

        with self._lock:
            # Replace semantics: drop any existing entry and its accounting.
            previous = self._cache.pop(key, None)
            if previous is not None:
                self._current_memory -= len(previous)

            # Evict from the LRU end until both limits can be satisfied.
            while self._cache and (
                len(self._cache) >= self._max_size
                or self._current_memory + size > self._max_memory
            ):
                _, evicted = self._cache.popitem(last=False)
                self._current_memory -= len(evicted)

            self._cache[key] = data
            self._current_memory += size

    def clear(self) -> None:
        """Drop every cached entry and reset the memory accounting."""
        with self._lock:
            self._cache.clear()
            self._current_memory = 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SQL FILTER CONSTANTS
|
||||
# ============================================================================
|
||||
|
||||
# Valid media file filters (excluding phrase checks, must have valid extension)
# Used by downloads, health, and analytics endpoints.
# NOTE(review): contains no '?' placeholders, so it is safe to interpolate
# directly into a WHERE clause; keep it parameter-free if edited.
MEDIA_FILTERS = """
    (filename NOT LIKE '%_phrase_checked_%' OR filename IS NULL)
    AND (file_path IS NOT NULL AND file_path != '' OR platform = 'forums')
    AND (LENGTH(filename) > 20 OR filename LIKE '%_%_%')
    AND (
        filename LIKE '%.jpg' OR filename LIKE '%.jpeg' OR
        filename LIKE '%.png' OR filename LIKE '%.gif' OR
        filename LIKE '%.heic' OR filename LIKE '%.heif' OR
        filename LIKE '%.mp4' OR filename LIKE '%.mov' OR
        filename LIKE '%.webm' OR filename LIKE '%.m4a' OR
        filename LIKE '%.mp3' OR filename LIKE '%.avi' OR
        filename LIKE '%.mkv' OR filename LIKE '%.flv'
    )
"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PATH VALIDATION
|
||||
# ============================================================================
|
||||
|
||||
# Allowed base paths for file operations — the path-traversal whitelist
# consulted by validate_file_path() when no explicit bases are supplied.
ALLOWED_PATHS = [
    settings.MEDIA_BASE_PATH,
    settings.REVIEW_PATH,
    settings.RECYCLE_PATH,
    Path('/opt/media-downloader/temp/manual_import'),  # manual import staging
    Path('/opt/immich/paid'),
    Path('/opt/immich/el'),
    Path('/opt/immich/elv'),
]
|
||||
|
||||
|
||||
def validate_file_path(
    file_path: str,
    allowed_bases: Optional[List[Path]] = None,
    require_exists: bool = False
) -> Path:
    """
    Resolve file_path and ensure it lives under an allowed base directory.

    Guards against path traversal attacks.

    Args:
        file_path: Path to validate.
        allowed_bases: Allowed base paths (defaults to ALLOWED_PATHS).
        require_exists: When True, also verify the file exists.

    Returns:
        The resolved Path object.

    Raises:
        HTTPException: 403 when outside the allowed bases, 404 when
            require_exists is set and the file is absent, 400 when the
            path cannot be resolved at all.
    """
    bases = ALLOWED_PATHS if allowed_bases is None else allowed_bases

    try:
        resolved = Path(file_path).resolve()

        def _is_under(base: Path) -> bool:
            # relative_to raises ValueError when resolved is outside base.
            try:
                resolved.relative_to(base.resolve())
                return True
            except ValueError:
                return False

        if not any(_is_under(base) for base in bases):
            raise HTTPException(status_code=403, detail="Access denied")

        if require_exists and not resolved.exists():
            raise HTTPException(status_code=404, detail="File not found")

    except HTTPException:
        raise
    except Exception:
        # Resolution itself failed (bad characters, loops, etc.).
        raise HTTPException(status_code=400, detail="Invalid file path")

    return resolved
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# QUERY FILTER BUILDER
|
||||
# ============================================================================
|
||||
|
||||
def build_media_filter_query(
    platform: Optional[str] = None,
    source: Optional[str] = None,
    media_type: Optional[str] = None,
    location: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    table_alias: str = "fi"
) -> Tuple[str, List]:
    """
    Build SQL filter clause for common media queries.

    This centralizes the filter building logic that was duplicated across
    media.py, downloads.py, and review.py routers.

    Args:
        platform: Filter by platform (e.g., 'instagram', 'tiktok')
        source: Filter by source (e.g., 'stories', 'posts')
        media_type: Filter by media type ('image', 'video', 'all')
        location: Filter by location ('final', 'review', 'recycle')
        date_from: Start date filter (ISO format, compared with DATE())
        date_to: End date filter (ISO format, compared with DATE())
        table_alias: SQL table alias (default 'fi' for file_inventory)

    Returns:
        Tuple of (SQL WHERE clause conditions, list of parameters).
        Returns ("1=1", []) when no filters apply so callers can always
        AND the result into a query.

    Example:
        conditions, params = build_media_filter_query(platform="instagram", media_type="video")
        query = f"SELECT * FROM file_inventory fi WHERE {conditions}"
        cursor.execute(query, params)
    """
    # The alias is interpolated directly into the SQL (not parameterizable),
    # so restrict it to a plain identifier to rule out injection.
    if not table_alias.isidentifier():
        table_alias = "fi"

    conditions = []
    params = []

    if platform:
        conditions.append(f"{table_alias}.platform = ?")
        params.append(platform)

    if source:
        conditions.append(f"{table_alias}.source = ?")
        params.append(source)

    # 'all' is a sentinel meaning "no media_type filter".
    if media_type and media_type != 'all':
        conditions.append(f"{table_alias}.media_type = ?")
        params.append(media_type)

    if location:
        conditions.append(f"{table_alias}.location = ?")
        params.append(location)

    if date_from:
        # Use COALESCE to handle both post_date from downloads and created_date from file_inventory
        conditions.append(f"""
            DATE(COALESCE(
                (SELECT MAX(d.post_date) FROM downloads d WHERE d.file_path = {table_alias}.file_path),
                {table_alias}.created_date
            )) >= ?
        """)
        params.append(date_from)

    if date_to:
        conditions.append(f"""
            DATE(COALESCE(
                (SELECT MAX(d.post_date) FROM downloads d WHERE d.file_path = {table_alias}.file_path),
                {table_alias}.created_date
            )) <= ?
        """)
        params.append(date_to)

    # "1=1" keeps 'WHERE {conditions}' syntactically valid with no filters.
    return " AND ".join(conditions) if conditions else "1=1", params
|
||||
|
||||
|
||||
def build_platform_list_filter(
    platforms: Optional[List[str]] = None,
    table_alias: str = "fi"
) -> Tuple[str, List[str]]:
    """
    Build an SQL ``IN (...)`` condition for filtering by multiple platforms.

    Args:
        platforms: List of platform names to filter by.
        table_alias: SQL table alias (must be a plain identifier).

    Returns:
        Tuple of (SQL condition string, list of parameters). Returns
        ("1=1", []) when no platforms are given so callers can always AND
        the condition into a WHERE clause.
    """
    if not platforms:
        return "1=1", []

    # Consistency/safety fix: sanitize the interpolated alias exactly like
    # build_media_filter_query does (platform values go through '?' params).
    if not table_alias.isidentifier():
        table_alias = "fi"

    placeholders = ",".join(["?"] * len(platforms))
    return f"{table_alias}.platform IN ({placeholders})", platforms
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# THUMBNAIL GENERATION
|
||||
# ============================================================================
|
||||
|
||||
def generate_image_thumbnail(file_path: Path, max_size: Tuple[int, int] = (300, 300)) -> Optional[bytes]:
    """
    Generate a JPEG thumbnail for an image file.

    Transparent/palette images are flattened onto a white background before
    JPEG encoding (JPEG has no alpha channel).

    Args:
        file_path: Path to image file
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes, or None if generation fails for any reason.
    """
    try:
        # Fix: use Image.open as a context manager so the underlying file
        # handle is closed even when thumbnailing raises (previously leaked).
        with Image.open(file_path) as img:
            img.thumbnail(max_size, Image.Resampling.LANCZOS)

            # Flatten alpha/palette modes onto white for JPEG output.
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background

            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=85)
            return buffer.getvalue()
    except Exception:
        # Best-effort helper: any failure (missing file, corrupt image,
        # unsupported mode) degrades to "no thumbnail".
        return None
|
||||
|
||||
|
||||
def generate_video_thumbnail(file_path: Path, max_size: Tuple[int, int] = (300, 300)) -> Optional[bytes]:
    """
    Generate thumbnail for video file using ffmpeg.

    Extracts a single frame via the external ``ffmpeg`` binary (must be on
    PATH), then resizes it with Pillow.

    Args:
        file_path: Path to video file
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes or None if generation fails
    """
    # Try seeking to 1s first (skips black intro frames), then fall back to
    # the very first frame for clips shorter than a second.
    for seek_time in ['00:00:01.000', '00:00:00.000']:
        try:
            # -ss before -i: fast input seek; one frame piped out as MJPEG.
            result = subprocess.run([
                'ffmpeg',
                '-ss', seek_time,
                '-i', str(file_path),
                '-vframes', '1',
                '-f', 'image2pipe',
                '-vcodec', 'mjpeg',
                '-'
            ], capture_output=True, timeout=30)

            if result.returncode != 0 or not result.stdout:
                continue

            # Resize the extracted frame to thumbnail dimensions.
            img = Image.open(io.BytesIO(result.stdout))
            img.thumbnail(max_size, Image.Resampling.LANCZOS)

            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=85)
            return buffer.getvalue()
        except Exception:
            # Timeout, missing ffmpeg, or decode failure — try next seek.
            continue

    return None
|
||||
|
||||
|
||||
def get_or_create_thumbnail(
    file_path: Union[str, Path],
    media_type: str,
    content_hash: Optional[str] = None,
    max_size: Tuple[int, int] = (300, 300)
) -> Optional[bytes]:
    """
    Get thumbnail from cache or generate and cache it.

    Uses the thumbnails.db schema: file_hash (PK), file_path, thumbnail_data,
    created_at, file_mtime.

    Lookup strategy:
    1. Try content_hash against file_hash column (survives file moves)
    2. Fall back to file_path lookup (legacy thumbnails)
    3. Generate and cache if not found

    Args:
        file_path: Path to media file
        media_type: 'video' selects the ffmpeg path; anything else is
            treated as an image
        content_hash: Optional pre-computed hash (when absent, a sha256 of
            the PATH STRING is used — not a true content hash, so it does
            not survive moves)
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes or None if generation fails
    """
    file_path = Path(file_path)
    thumb_db_path = settings.PROJECT_ROOT / 'database' / 'thumbnails.db'

    # Compute hash if not provided (hash of the path string — see docstring).
    file_hash = content_hash if content_hash else hashlib.sha256(str(file_path).encode()).hexdigest()

    # Try to get from cache (skip mtime check — downloaded media files don't change)
    try:
        with sqlite3.connect(str(thumb_db_path), timeout=30.0) as conn:
            cursor = conn.cursor()

            # 1. Try file_hash lookup (primary key)
            cursor.execute(
                "SELECT thumbnail_data FROM thumbnails WHERE file_hash = ?",
                (file_hash,)
            )
            result = cursor.fetchone()
            if result and result[0]:
                return result[0]

            # 2. Fall back to file_path lookup (legacy thumbnails)
            cursor.execute(
                "SELECT thumbnail_data FROM thumbnails WHERE file_path = ?",
                (str(file_path),)
            )
            result = cursor.fetchone()
            if result and result[0]:
                return result[0]
    except Exception:
        # Cache lookup is best-effort; fall through to generation.
        pass

    # Only proceed with generation if the file actually exists
    if not file_path.exists():
        return None

    # Get mtime only when we need to generate and cache a new thumbnail
    try:
        file_mtime = file_path.stat().st_mtime
    except OSError:
        file_mtime = 0

    # Generate thumbnail (video frames go through ffmpeg, images through PIL)
    thumbnail_data = None
    if media_type == 'video':
        thumbnail_data = generate_video_thumbnail(file_path, max_size)
    else:
        thumbnail_data = generate_image_thumbnail(file_path, max_size)

    # Cache the thumbnail; failures here are swallowed so the caller still
    # gets the freshly generated bytes.
    if thumbnail_data:
        try:
            from .responses import now_iso8601
            with sqlite3.connect(str(thumb_db_path), timeout=30.0) as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO thumbnails
                    (file_hash, file_path, thumbnail_data, created_at, file_mtime)
                    VALUES (?, ?, ?, ?, ?)
                """, (file_hash, str(file_path), thumbnail_data, now_iso8601(), file_mtime))
                conn.commit()
        except Exception:
            pass

    return thumbnail_data
|
||||
|
||||
|
||||
def get_media_dimensions(file_path: str, width: Optional[int] = None, height: Optional[int] = None) -> Tuple[Optional[int], Optional[int]]:
    """
    Get media dimensions, falling back to the metadata cache if not provided.

    The cache (media_metadata.db) is keyed by a sha256 of the file PATH
    string, not of the file contents.

    Args:
        file_path: Path to media file
        width: Width from file_inventory (may be None)
        height: Height from file_inventory (may be None)

    Returns:
        Tuple of (width, height), or the (possibly None) inputs unchanged
        when the cache has no row or the lookup fails.
    """
    # Fast path: both dimensions already known, no DB access needed.
    if width is not None and height is not None:
        return (width, height)

    try:
        metadata_db_path = settings.PROJECT_ROOT / 'database' / 'media_metadata.db'

        file_hash = hashlib.sha256(file_path.encode()).hexdigest()
        with closing(sqlite3.connect(str(metadata_db_path))) as conn:
            cursor = conn.execute(
                "SELECT width, height FROM media_metadata WHERE file_hash = ?",
                (file_hash,)
            )
            result = cursor.fetchone()

            if result:
                return (result[0], result[1])

    except Exception:
        # Best-effort lookup: fall through and return the caller's values.
        pass

    return (width, height)
|
||||
|
||||
|
||||
def get_media_dimensions_batch(file_paths: List[str]) -> Dict[str, Tuple[int, int]]:
    """
    Batch-fetch cached (width, height) for many files in one query.

    Avoids the N+1 pattern of calling get_media_dimensions per file. Paths
    with no cached row are simply absent from the returned mapping, and any
    database error yields whatever was collected so far (usually empty).

    Args:
        file_paths: File paths to look up.

    Returns:
        Dict mapping file_path -> (width, height).
    """
    if not file_paths:
        return {}

    dims: Dict[str, Tuple[int, int]] = {}

    try:
        db_file = settings.PROJECT_ROOT / 'database' / 'media_metadata.db'

        # The cache is keyed by sha256 of the path string.
        path_by_hash = {
            hashlib.sha256(fp.encode()).hexdigest(): fp for fp in file_paths
        }

        with closing(sqlite3.connect(str(db_file))) as conn:
            marks = ','.join('?' * len(path_by_hash))
            rows = conn.execute(
                f"SELECT file_hash, width, height FROM media_metadata WHERE file_hash IN ({marks})",
                list(path_by_hash.keys())
            ).fetchall()

            for digest, width, height in rows:
                path = path_by_hash.get(digest)
                if path is not None:
                    dims[path] = (width, height)

    except Exception:
        # Best-effort cache: degrade silently on any DB problem.
        pass

    return dims
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
def update_file_path_in_all_tables(db, old_path: str, new_path: str):
    """
    Update file path in all relevant database tables.

    Used when moving files between locations (review -> final, etc.)
    to keep database references consistent. All updates happen in one
    write connection/transaction; any failure is logged and swallowed
    (best-effort — a partial failure does not raise to the caller).

    Args:
        db: UnifiedDatabase instance
        old_path: The old file path to replace
        new_path: The new file path to use
    """
    from modules.universal_logger import get_logger
    logger = get_logger('API')

    try:
        with db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()

            cursor.execute('UPDATE downloads SET file_path = ? WHERE file_path = ?',
                           (new_path, old_path))
            cursor.execute('UPDATE instagram_perceptual_hashes SET file_path = ? WHERE file_path = ?',
                           (new_path, old_path))
            cursor.execute('UPDATE face_recognition_scans SET file_path = ? WHERE file_path = ?',
                           (new_path, old_path))

            # semantic_embeddings is optional (only present when semantic
            # search has been enabled), so tolerate its absence.
            try:
                cursor.execute('UPDATE semantic_embeddings SET file_path = ? WHERE file_path = ?',
                               (new_path, old_path))
            except sqlite3.OperationalError:
                pass  # Table may not exist

            conn.commit()

    except Exception as e:
        logger.warning(f"Failed to update file paths in tables: {e}", module="Database")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FACE RECOGNITION UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
# Cached FaceRecognitionModule singleton to avoid loading InsightFace models on every request.
# Keyed by id(db) so each UnifiedDatabase instance gets its own module.
# NOTE(review): keying by id() keeps entries alive forever and can collide if
# a db object is garbage-collected and its id reused — acceptable while db
# instances are process-lifetime singletons; verify that assumption holds.
_face_module_cache: Dict[int, 'FaceRecognitionModule'] = {}


def get_face_module(db, module_name: str = "FaceAPI"):
    """
    Get or create a cached FaceRecognitionModule instance for the given database.

    Uses singleton pattern to avoid reloading heavy InsightFace models on each request.

    Args:
        db: UnifiedDatabase instance
        module_name: Name to use in log messages

    Returns:
        FaceRecognitionModule instance
    """
    # Imports are local to keep module import cheap until faces are needed.
    from modules.face_recognition_module import FaceRecognitionModule
    from modules.universal_logger import get_logger
    logger = get_logger('API')

    db_id = id(db)
    if db_id not in _face_module_cache:
        logger.info("Creating cached FaceRecognitionModule instance", module=module_name)
        _face_module_cache[db_id] = FaceRecognitionModule(unified_db=db)
    return _face_module_cache[db_id]
|
||||
438
web/backend/duo_manager.py
Normal file
438
web/backend/duo_manager.py
Normal file
@@ -0,0 +1,438 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Duo Manager for Media Downloader
|
||||
Handles Duo Security 2FA Operations
|
||||
|
||||
Based on backup-central's implementation
|
||||
Uses Duo Universal Prompt
|
||||
"""
|
||||
|
||||
import sys
|
||||
import sqlite3
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# Add parent path to allow imports from modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('DuoManager')
|
||||
|
||||
|
||||
class DuoManager:
|
||||
    def __init__(self, db_path: str = None):
        """Initialize the Duo manager.

        Args:
            db_path: Path to the auth SQLite database; defaults to
                <project_root>/database/auth.db.
        """
        if db_path is None:
            db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')

        self.db_path = db_path

        # Duo Configuration (from environment variables)
        self.client_id = os.getenv('DUO_CLIENT_ID')
        self.client_secret = os.getenv('DUO_CLIENT_SECRET')
        self.api_host = os.getenv('DUO_API_HOSTNAME')
        self.redirect_url = os.getenv('DUO_REDIRECT_URL', 'https://md.lic.ad/api/auth/2fa/duo/callback')

        # Duo is usable only when all three required values are present
        # (redirect_url always has a default).
        self.is_configured = bool(self.client_id and self.client_secret and self.api_host)

        # In-memory OAuth state store: state token -> flow metadata dict.
        # NOTE(review): process-local — assumes a single worker process;
        # verify against the deployment model.
        self.state_store = {}

        # Initialize database
        self._init_database()

        # Initialize Duo client if configured
        self.duo_client = None
        if self.is_configured:
            self._init_duo_client()
|
||||
|
||||
    def _init_database(self):
        """Initialize Duo-related database tables.

        Idempotent: ALTER TABLE failures mean the column already exists and
        are deliberately ignored; the audit table uses IF NOT EXISTS.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # Ensure Duo columns exist in users table
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_enabled INTEGER NOT NULL DEFAULT 0")
            except sqlite3.OperationalError:
                pass  # Column already exists

            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_username TEXT")
            except sqlite3.OperationalError:
                pass

            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_enrolled_at TEXT")
            except sqlite3.OperationalError:
                pass

            # Duo audit log — one row per auth-related event (see _log_audit).
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS duo_audit_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    action TEXT NOT NULL,
                    success INTEGER NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    details TEXT,
                    timestamp TEXT NOT NULL
                )
            """)

            conn.commit()
|
||||
|
||||
    def _init_duo_client(self):
        """Initialize Duo Universal Prompt client.

        On any failure (missing package, bad credentials) the manager is
        flipped to unconfigured so callers degrade gracefully.
        """
        try:
            from duo_universal import Client
            self.duo_client = Client(
                client_id=self.client_id,
                client_secret=self.client_secret,
                host=self.api_host,
                redirect_uri=self.redirect_url
            )
            logger.info("Duo Universal Prompt client initialized successfully", module="Duo")
        except Exception as e:
            logger.error(f"Error initializing Duo client: {e}", module="Duo")
            # Disable Duo entirely rather than leaving a half-working client.
            self.is_configured = False
|
||||
|
||||
    def is_duo_configured(self) -> bool:
        """Return True when Duo credentials are present and the client initialized."""
        return self.is_configured
|
||||
|
||||
    def get_configuration_status(self) -> Dict:
        """Get Duo configuration status.

        Returns:
            Dict of booleans per credential (camelCase keys for the
            frontend) plus the API hostname when set. Secrets themselves
            are never exposed, only their presence.
        """
        return {
            'configured': self.is_configured,
            'hasClientId': bool(self.client_id),
            'hasClientSecret': bool(self.client_secret),
            'hasApiHostname': bool(self.api_host),
            'hasRedirectUrl': bool(self.redirect_url),
            'apiHostname': self.api_host if self.api_host else None
        }
|
||||
|
||||
    def generate_state(self, username: str, remember_me: bool = False) -> str:
        """
        Generate a state parameter for Duo OAuth flow.

        The state is stored in-memory with a 10-minute expiry and consumed
        once by verify_state().

        Args:
            username: Username to associate with this state
            remember_me: Whether to remember the user for 30 days

        Returns:
            Random URL-safe state string (32 bytes of entropy)
        """
        state = secrets.token_urlsafe(32)
        self.state_store[state] = {
            'username': username,
            'remember_me': remember_me,
            'created_at': datetime.now(),
            # Short expiry bounds the window for replay of a leaked state.
            'expires_at': datetime.now() + timedelta(minutes=10)
        }
        return state
|
||||
|
||||
    def verify_state(self, state: str) -> Optional[tuple]:
        """
        Verify state parameter and return associated username and remember_me.

        The state is single-use: it is deleted whether it validates or has
        expired, so a replayed callback cannot succeed twice.

        Args:
            state: State parameter from Duo callback

        Returns:
            Tuple of (username, remember_me) if valid, None otherwise
        """
        if state not in self.state_store:
            return None

        state_data = self.state_store[state]

        # Check if expired (10-minute window set by generate_state)
        if datetime.now() > state_data['expires_at']:
            del self.state_store[state]
            return None

        username = state_data['username']
        remember_me = state_data.get('remember_me', False)
        del self.state_store[state]  # One-time use

        return (username, remember_me)
|
||||
|
||||
    def create_auth_url(self, username: str, remember_me: bool = False) -> Dict:
        """
        Create Duo authentication URL for user.

        Maps the local username to the per-user duo_username (when set)
        before building the Universal Prompt URL, and records an audit
        entry either way.

        Args:
            username: Username to authenticate
            remember_me: Whether to remember the user for 30 days

        Returns:
            Dict with authUrl and state

        Raises:
            Exception: When Duo is not configured, or re-raises any error
                from the Duo client after logging it to the audit trail.
        """
        if not self.is_configured or not self.duo_client:
            raise Exception("Duo is not configured")

        try:
            # Generate state
            state = self.generate_state(username, remember_me)

            # Get Duo username for this user (may differ from local username)
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT duo_username FROM users WHERE username = ?", (username,))
                result = cursor.fetchone()

            duo_username = result[0] if result and result[0] else username

            # Create Duo Universal Prompt auth URL
            auth_url = self.duo_client.create_auth_url(duo_username, state)

            self._log_audit(username, 'duo_auth_url_created', True, None, None,
                            'Duo authentication URL created')

            return {
                'authUrl': auth_url,
                'state': state
            }
        except Exception as e:
            self._log_audit(username, 'duo_auth_url_failed', False, None, None,
                            f'Error: {str(e)}')
            raise
|
||||
|
||||
def verify_duo_response(self, duo_code: str, username: str) -> bool:
    """
    Verify Duo authentication response (Universal Prompt)

    Args:
        duo_code: Authorization code from Duo
        username: Expected username

    Returns:
        True if authentication successful
    """
    if not self.is_configured or not self.duo_client:
        return False

    try:
        # Exchange the authorization code for the signed 2FA result token
        token = self.duo_client.exchange_authorization_code_for_2fa_result(
            duo_code,
            username
        )

        allowed = bool(token) and token.get('auth_result', {}).get('result') == 'allow'
        if not allowed:
            self._log_audit(username, 'duo_verify_failed', False, None, None,
                            'Duo authentication denied')
            return False

        authenticated_username = token.get('preferred_username')

        # Resolve the expected Duo-side username for this local account
        with sqlite3.connect(self.db_path) as conn:
            row = conn.execute(
                "SELECT duo_username FROM users WHERE username = ?", (username,)
            ).fetchone()
        duo_username = row[0] if row and row[0] else username

        # The identity Duo authenticated must match the one we expected
        if authenticated_username != duo_username:
            self._log_audit(username, 'duo_verify_failed', False, None, None,
                            f'Username mismatch: expected {duo_username}, got {authenticated_username}')
            return False

        self._log_audit(username, 'duo_verify_success', True, None, None,
                        'Duo authentication successful')
        return True

    except Exception as e:
        self._log_audit(username, 'duo_verify_failed', False, None, None,
                        f'Error: {str(e)}')
        logger.error(f"Duo verification error: {e}", module="Duo")
        return False
|
||||
|
||||
def _get_application_key(self) -> str:
    """Get or generate application key for Duo

    Reads the persisted key if one exists; otherwise generates a new
    URL-safe key (>40 characters, as Duo requires) and persists it with
    owner-only permissions.

    Returns:
        The application key string.
    """
    key_file = Path(__file__).parent.parent.parent / '.duo_application_key'

    if key_file.exists():
        with open(key_file, 'r') as f:
            return f.read().strip()

    # Generate new key (at least 40 characters)
    key = secrets.token_urlsafe(40)
    try:
        # Create the file with mode 0o600 from the start so the secret is
        # never world-readable, even briefly. The previous open()-then-chmod()
        # sequence left a window where default umask permissions applied.
        fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as f:
            f.write(key)
    except Exception as e:
        # Best-effort persistence: the in-memory key is still usable this run
        logger.warning(f"Could not save Duo application key: {e}", module="Duo")

    return key
|
||||
|
||||
def enroll_user(self, username: str, duo_username: Optional[str] = None) -> bool:
    """
    Enroll a user in Duo 2FA

    Args:
        username: Username
        duo_username: Optional Duo username (defaults to username)

    Returns:
        True if successful
    """
    # Default the Duo-side identity to the local account name
    duo_username = username if duo_username is None else duo_username

    try:
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                UPDATE users
                SET duo_enabled = 1, duo_username = ?, duo_enrolled_at = ?
                WHERE username = ?
            """, (duo_username, datetime.now().isoformat(), username))
            conn.commit()

        self._log_audit(username, 'duo_enrolled', True, None, None,
                        f'Duo enrolled with username: {duo_username}')
        return True
    except Exception as e:
        logger.error(f"Error enrolling user in Duo: {e}", module="Duo")
        return False
|
||||
|
||||
def unenroll_user(self, username: str) -> bool:
    """
    Unenroll a user from Duo 2FA

    Args:
        username: Username

    Returns:
        True if successful
    """
    try:
        # Clear every Duo-related column for this account in one statement
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                UPDATE users
                SET duo_enabled = 0, duo_username = NULL, duo_enrolled_at = NULL
                WHERE username = ?
            """, (username,))
            conn.commit()

        self._log_audit(username, 'duo_unenrolled', True, None, None,
                        'Duo unenrolled')
        return True
    except Exception as e:
        logger.error(f"Error unenrolling user from Duo: {e}", module="Duo")
        return False
|
||||
|
||||
def get_duo_status(self, username: str) -> Dict:
    """
    Get Duo status for a user

    Args:
        username: Username

    Returns:
        Dict with enabled, duoUsername, enrolledAt
    """
    with sqlite3.connect(self.db_path) as conn:
        row = conn.execute("""
            SELECT duo_enabled, duo_username, duo_enrolled_at
            FROM users
            WHERE username = ?
        """, (username,)).fetchone()

    if row is None:
        # Unknown user: report a disabled, unenrolled state
        return {'enabled': False, 'duoUsername': None, 'enrolledAt': None}

    enabled, duo_username, enrolled_at = row
    return {
        'enabled': bool(enabled),
        'duoUsername': duo_username,
        'enrolledAt': enrolled_at,
    }
|
||||
|
||||
def preauth_user(self, username: str) -> Dict:
    """
    Perform Duo preauth to check enrollment status

    Args:
        username: Username

    Returns:
        Dict with result and status_msg
    """
    if not self.is_configured or not self.duo_client:
        return {'result': 'error', 'status_msg': 'Duo not configured'}

    try:
        # Get Duo username. get_duo_status() returns duoUsername=None for
        # users with no explicit Duo identity, and dict.get()'s default is
        # NOT applied when the key exists with a None value -- so fall back
        # with `or`. The previous `.get('duoUsername', username)` silently
        # sent username=None to Duo in that case.
        user_info = self.get_duo_status(username)
        duo_username = user_info.get('duoUsername') or username

        # Perform preauth
        preauth_result = self.duo_client.preauth(username=duo_username)

        self._log_audit(username, 'duo_preauth', True, None, None,
                        f'Preauth result: {preauth_result.get("result")}')

        return preauth_result

    except Exception as e:
        self._log_audit(username, 'duo_preauth_failed', False, None, None,
                        f'Error: {str(e)}')
        return {'result': 'error', 'status_msg': 'Failed to check Duo enrollment status'}
|
||||
|
||||
def _log_audit(self, username: str, action: str, success: bool,
|
||||
ip_address: Optional[str], user_agent: Optional[str],
|
||||
details: Optional[str] = None):
|
||||
"""Log Duo audit event"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO duo_audit_log
|
||||
(username, action, success, ip_address, user_agent, details, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (username, action, int(success), ip_address, user_agent, details,
|
||||
datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging Duo audit: {e}", module="Duo")
|
||||
|
||||
def health_check(self) -> Dict:
    """
    Check Duo service health

    Returns:
        Dict with healthy status and details
    """
    if not self.is_configured:
        return {'healthy': False, 'error': 'Duo not configured'}

    try:
        # Simple ping to Duo API
        pong = self.duo_client.ping()
    except Exception as e:
        logger.error(f"Duo health check failed: {e}", module="Duo")
        return {'healthy': False, 'error': 'Duo service unavailable'}
    return {'healthy': True, 'response': pong}
|
||||
1
web/backend/models/__init__.py
Normal file
1
web/backend/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Models module exports
|
||||
477
web/backend/models/api_models.py
Normal file
477
web/backend/models/api_models.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Pydantic Models for API
|
||||
|
||||
All request and response models for API endpoints.
|
||||
Provides validation and documentation for API contracts.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION MODELS
|
||||
# ============================================================================
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for username/password login."""
    username: str = Field(..., min_length=1, max_length=50)
    password: str = Field(..., min_length=1)
    # camelCase to match the frontend JSON payload
    rememberMe: bool = False
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
    """Login result; token/username/role are populated only on success."""
    success: bool
    token: Optional[str] = None
    username: Optional[str] = None
    role: Optional[str] = None
    # token expiry as a string -- presumably ISO-8601; confirm with issuer
    expires_at: Optional[str] = None
    message: Optional[str] = None
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
    """Password change request; the new password must be at least 8 chars."""
    current_password: str = Field(..., min_length=1)
    new_password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
class UserPreferences(BaseModel):
    """Per-user preferences; all fields optional to allow partial updates."""
    theme: Optional[str] = Field(None, pattern=r'^(light|dark|system)$')
    notifications_enabled: Optional[bool] = None
    default_platform: Optional[str] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DOWNLOAD MODELS
|
||||
# ============================================================================
|
||||
|
||||
class DownloadResponse(BaseModel):
    """Single download record as returned by the downloads API."""
    id: int
    platform: str
    source: str
    content_type: Optional[str] = None
    filename: Optional[str] = None
    file_path: Optional[str] = None
    # presumably bytes -- confirm against the record writer
    file_size: Optional[int] = None
    download_date: str
    post_date: Optional[str] = None
    status: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
    """Aggregate download statistics for the dashboard."""
    total_downloads: int
    # platform name -> number of downloads
    by_platform: Dict[str, int]
    total_size: int
    recent_24h: int
    duplicates_prevented: int
    review_queue_count: int
    recycle_bin_count: int = 0
|
||||
|
||||
|
||||
class TriggerRequest(BaseModel):
    """Manual download trigger; omit fields to run for all users/content types."""
    username: Optional[str] = None
    content_types: Optional[List[str]] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PLATFORM MODELS
|
||||
# ============================================================================
|
||||
|
||||
class PlatformStatus(BaseModel):
    """Runtime status of one platform's scheduled downloads."""
    platform: str
    enabled: bool
    last_run: Optional[str] = None
    next_run: Optional[str] = None
    status: str
|
||||
|
||||
|
||||
class PlatformConfig(BaseModel):
    """Per-platform download configuration."""
    enabled: bool = False
    username: Optional[str] = None
    # run interval, bounded to 1 hour .. 1 week (168 h)
    interval_hours: int = Field(24, ge=1, le=168)
    randomize: bool = True
    # jitter window (minutes) applied to each scheduled run
    randomize_minutes: int = Field(30, ge=0, le=180)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONFIGURATION MODELS
|
||||
# ============================================================================
|
||||
|
||||
class PushoverConfig(BaseModel):
    """Pushover notification configuration; keys are fixed 30-char alphanumerics."""
    enabled: bool
    user_key: Optional[str] = Field(None, min_length=30, max_length=30, pattern=r'^[A-Za-z0-9]+$')
    api_token: Optional[str] = Field(None, min_length=30, max_length=30, pattern=r'^[A-Za-z0-9]+$')
    # Pushover priority scale: -2 (lowest) .. 2 (emergency)
    priority: int = Field(0, ge=-2, le=2)
    sound: str = Field("pushover", pattern=r'^[a-z_]+$')
|
||||
|
||||
|
||||
class SchedulerConfig(BaseModel):
    """Global scheduler configuration (interval plus optional jitter)."""
    enabled: bool
    interval_hours: int = Field(24, ge=1, le=168)
    randomize: bool = True
    randomize_minutes: int = Field(30, ge=0, le=180)
|
||||
|
||||
|
||||
class RecycleBinConfig(BaseModel):
    """Recycle bin behavior: location, retention window, and size cap."""
    enabled: bool = True
    path: str = "/opt/immich/recycle"
    retention_days: int = Field(30, ge=1, le=365)
    max_size_gb: int = Field(50, ge=1, le=1000)
    auto_cleanup: bool = True
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel):
    """Free-form configuration update request."""
    config: Dict[str, Any]

    class Config:
        # deliberately accept unknown top-level keys so extended configs validate
        extra = "allow"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HEALTH MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ServiceHealth(BaseModel):
    """Health of one backing service; status is a closed three-value set."""
    status: str = Field(..., pattern=r'^(healthy|unhealthy|unknown)$')
    message: Optional[str] = None
    last_check: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class HealthStatus(BaseModel):
    """Overall application health: per-service reports plus a summary."""
    status: str
    # service name -> its individual health report
    services: Dict[str, ServiceHealth]
    last_check: str
    version: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MEDIA MODELS
|
||||
# ============================================================================
|
||||
|
||||
class MediaItem(BaseModel):
    """One media file as shown in the gallery, with optional face-match info."""
    id: int
    file_path: str
    filename: str
    platform: str
    source: str
    content_type: str
    file_size: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None
    post_date: Optional[str] = None
    download_date: Optional[str] = None
    # best-matching known face name, if recognition ran on this file
    face_match: Optional[str] = None
    face_confidence: Optional[float] = None
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
    """Delete several files at once; permanent=False implies recoverable delete."""
    file_paths: List[str] = Field(..., min_length=1)
    permanent: bool = False
|
||||
|
||||
|
||||
class BatchMoveRequest(BaseModel):
    """Move several files into one destination."""
    file_paths: List[str] = Field(..., min_length=1)
    destination: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# REVIEW MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ReviewItem(BaseModel):
    """Item awaiting manual review, with face-detection results if available."""
    id: int
    file_path: str
    filename: str
    platform: str
    source: str
    content_type: str
    file_size: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None
    detected_faces: Optional[int] = None
    best_match: Optional[str] = None
    match_confidence: Optional[float] = None
|
||||
|
||||
|
||||
class ReviewKeepRequest(BaseModel):
    """Keep a reviewed file: move it to destination, optionally renaming it."""
    file_path: str
    destination: str
    new_name: Optional[str] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FACE RECOGNITION MODELS
|
||||
# ============================================================================
|
||||
|
||||
class FaceReference(BaseModel):
    """A known face used as a recognition reference."""
    id: str
    name: str
    created_at: str
    # number of stored encodings contributing to this reference
    encoding_count: int = 1
    thumbnail: Optional[str] = None
|
||||
|
||||
|
||||
class AddReferenceRequest(BaseModel):
    """Register a new face reference from an existing file."""
    file_path: str
    name: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RECYCLE BIN MODELS
|
||||
# ============================================================================
|
||||
|
||||
class RecycleItem(BaseModel):
    """A file currently held in the recycle bin."""
    id: str
    original_path: str
    original_filename: str
    recycle_path: str
    file_size: Optional[int] = None
    deleted_at: str
    # which view/feature the delete originated from
    deleted_from: str
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class RestoreRequest(BaseModel):
    """Restore an item from the recycle bin."""
    recycle_id: str
    restore_to: Optional[str] = None  # Original path if None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VIDEO DOWNLOAD MODELS
|
||||
# ============================================================================
|
||||
|
||||
class VideoInfoRequest(BaseModel):
    """Probe a video URL for metadata; only http(s) URLs are accepted."""
    url: str = Field(..., pattern=r'^https?://')
|
||||
|
||||
|
||||
class VideoDownloadRequest(BaseModel):
    """Request to download a video; format/quality are downloader hints."""
    url: str = Field(..., pattern=r'^https?://')
    format: Optional[str] = None
    quality: Optional[str] = None
|
||||
|
||||
|
||||
class VideoStatus(BaseModel):
    """Progress/status of an in-flight or finished video download."""
    video_id: str
    platform: str
    status: str
    # completion value -- units (fraction vs. percent) not pinned here; confirm with producer
    progress: Optional[float] = None
    title: Optional[str] = None
    error: Optional[str] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SCHEDULER MODELS
|
||||
# ============================================================================
|
||||
|
||||
class TaskStatus(BaseModel):
    """Status of one scheduled download task."""
    task_id: str
    platform: str
    source: Optional[str] = None
    status: str
    last_run: Optional[str] = None
    next_run: Optional[str] = None
    error_count: int = 0
    last_error: Optional[str] = None
|
||||
|
||||
|
||||
class SchedulerStatus(BaseModel):
    """Scheduler snapshot: running flag, all tasks, and any current activity."""
    running: bool
    tasks: List[TaskStatus]
    current_activity: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION MODELS
|
||||
# ============================================================================
|
||||
|
||||
class NotificationItem(BaseModel):
    """A notification sent for completed downloads."""
    id: int
    platform: str
    source: str
    content_type: str
    message: str
    title: Optional[str] = None
    sent_at: str
    download_count: int
    status: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SEMANTIC SEARCH MODELS
|
||||
# ============================================================================
|
||||
|
||||
class SemanticSearchRequest(BaseModel):
    """Free-text semantic search over the media library."""
    query: str = Field(..., min_length=1, max_length=500)
    limit: int = Field(50, ge=1, le=200)
    # minimum similarity score for a hit to be returned
    threshold: float = Field(0.3, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class SimilarImagesRequest(BaseModel):
    """Find images visually similar to an existing file."""
    file_id: int
    limit: int = Field(20, ge=1, le=100)
    threshold: float = Field(0.5, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TAG MODELS
|
||||
# ============================================================================
|
||||
|
||||
class TagCreate(BaseModel):
    """Create a tag; color is an optional #RRGGBB hex string."""
    name: str = Field(..., min_length=1, max_length=50)
    color: Optional[str] = Field(None, pattern=r'^#[0-9A-Fa-f]{6}$')
|
||||
|
||||
|
||||
class TagUpdate(BaseModel):
    """Partial tag update; omitted fields are left unchanged."""
    name: Optional[str] = Field(None, min_length=1, max_length=50)
    color: Optional[str] = Field(None, pattern=r'^#[0-9A-Fa-f]{6}$')
|
||||
|
||||
|
||||
class BulkTagRequest(BaseModel):
    """Add or remove a set of tags on a set of files in one call."""
    file_ids: List[int] = Field(..., min_length=1)
    tag_ids: List[int] = Field(..., min_length=1)
    operation: str = Field(..., pattern=r'^(add|remove)$')
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# COLLECTION MODELS
|
||||
# ============================================================================
|
||||
|
||||
class CollectionCreate(BaseModel):
    """Create a named collection of media files."""
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
|
||||
|
||||
|
||||
class CollectionUpdate(BaseModel):
    """Partial collection update; omitted fields are left unchanged."""
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
|
||||
|
||||
|
||||
class CollectionBulkAdd(BaseModel):
    """Add several files to a collection at once."""
    file_ids: List[int] = Field(..., min_length=1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SMART FOLDER MODELS
|
||||
# ============================================================================
|
||||
|
||||
class SmartFolderCreate(BaseModel):
    """Create a smart folder defined by a saved query."""
    name: str = Field(..., min_length=1, max_length=100)
    # rule document evaluated server-side; schema owned by the smart-folder engine
    query_rules: Dict[str, Any]
|
||||
|
||||
|
||||
class SmartFolderUpdate(BaseModel):
    """Partial smart folder update; omitted fields are left unchanged."""
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    query_rules: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SCRAPER MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ScraperUpdate(BaseModel):
    """Partial scraper settings update; omitted fields are left unchanged."""
    enabled: Optional[bool] = None
    proxy: Optional[str] = None
    interval_hours: Optional[int] = Field(None, ge=1, le=168)
    settings: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class CookieUpload(BaseModel):
    """Cookie upload for scraper"""
    cookies: str  # JSON string of cookies
    # where the cookies came from, recorded for auditing
    source: str = Field(..., pattern=r'^(browser|manual|extension)$')
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STANDARD RESPONSE MODELS
|
||||
# ============================================================================
|
||||
|
||||
class SuccessResponse(BaseModel):
    """Standard success envelope with a human-readable message."""
    success: bool = True
    message: str = "Operation completed successfully"
|
||||
|
||||
|
||||
class DataResponse(BaseModel):
    """Standard envelope wrapping an arbitrary payload."""
    success: bool = True
    data: Any
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
    """Response carrying only a human-readable message."""
    message: str
|
||||
|
||||
|
||||
class CountResponse(BaseModel):
    """Response reporting how many items an operation affected."""
    message: str
    count: int
|
||||
|
||||
|
||||
class IdResponse(BaseModel):
    """Response returning the ID of a newly created resource."""
    id: int
    message: str = "Resource created successfully"
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
    """Paginated list response base"""
    items: List[Any]
    total: int
    limit: int
    offset: int
    has_more: bool = False

    @classmethod
    def create(cls, items: List[Any], total: int, limit: int, offset: int):
        """Helper to create paginated response, deriving has_more from
        offset + actual page size vs. the total row count."""
        return cls(
            items=items,
            total=total,
            limit=limit,
            offset=offset,
            # true when rows remain beyond this page
            has_more=(offset + len(items)) < total
        )
|
||||
565
web/backend/passkey_manager.py
Normal file
565
web/backend/passkey_manager.py
Normal file
@@ -0,0 +1,565 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Passkey Manager for Media Downloader
|
||||
Handles WebAuthn/Passkey operations
|
||||
|
||||
Based on backup-central's implementation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import sqlite3
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# Add parent path to allow imports from modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('PasskeyManager')
|
||||
|
||||
from webauthn import (
|
||||
generate_registration_options,
|
||||
verify_registration_response,
|
||||
generate_authentication_options,
|
||||
verify_authentication_response,
|
||||
options_to_json
|
||||
)
|
||||
from webauthn.helpers import base64url_to_bytes, bytes_to_base64url
|
||||
from webauthn.helpers.structs import (
|
||||
PublicKeyCredentialDescriptor,
|
||||
UserVerificationRequirement,
|
||||
AttestationConveyancePreference,
|
||||
AuthenticatorSelectionCriteria,
|
||||
ResidentKeyRequirement
|
||||
)
|
||||
|
||||
|
||||
class PasskeyManager:
|
||||
def __init__(self, db_path: str = None):
    """Initialize the passkey manager and ensure its DB tables exist."""
    # Default DB lives alongside the rest of the auth data
    if db_path is None:
        db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')
    self.db_path = db_path

    # WebAuthn Configuration (relying party identity + expected origin)
    self.rp_name = 'Media Downloader'
    self.rp_id = self._get_rp_id()
    self.origin = self._get_origin()

    # Challenge timeout (5 minutes)
    self.challenge_timeout = timedelta(minutes=5)

    # Challenge storage (in-memory for now, could use Redis/DB for production);
    # maps username -> pending challenge record
    self.challenges = {}

    # Initialize database
    self._init_database()
|
||||
|
||||
def _get_rp_id(self) -> str:
|
||||
"""Get Relying Party ID (domain)"""
|
||||
return os.getenv('WEBAUTHN_RP_ID', 'md.lic.ad')
|
||||
|
||||
def _get_origin(self) -> str:
|
||||
"""Get Origin URL"""
|
||||
return os.getenv('WEBAUTHN_ORIGIN', 'https://md.lic.ad')
|
||||
|
||||
def _init_database(self):
    """Initialize passkey-related database tables.

    Creates (if missing) the credential store and audit log, and ensures the
    `users` table has a `passkey_enabled` flag. Safe to call repeatedly.
    """
    with sqlite3.connect(self.db_path) as conn:
        cursor = conn.cursor()

        # Passkeys credentials table: one row per registered authenticator,
        # keyed by the owning username (FK to users).
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS passkey_credentials (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL,
                credential_id TEXT NOT NULL UNIQUE,
                public_key TEXT NOT NULL,
                sign_count INTEGER NOT NULL DEFAULT 0,
                transports TEXT,
                device_name TEXT,
                aaguid TEXT,
                created_at TEXT NOT NULL,
                last_used TEXT,
                FOREIGN KEY (username) REFERENCES users(username) ON DELETE CASCADE
            )
        """)

        # Passkey audit log: append-only trail of passkey operations
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS passkey_audit_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL,
                action TEXT NOT NULL,
                success INTEGER NOT NULL,
                credential_id TEXT,
                ip_address TEXT,
                user_agent TEXT,
                details TEXT,
                timestamp TEXT NOT NULL
            )
        """)

        # Add passkey enabled column to users; ALTER fails harmlessly when
        # the column already exists (poor-man's migration).
        try:
            cursor.execute("ALTER TABLE users ADD COLUMN passkey_enabled INTEGER NOT NULL DEFAULT 0")
        except sqlite3.OperationalError:
            pass

        conn.commit()
|
||||
|
||||
def get_user_credentials(self, username: str) -> List[Dict]:
    """
    Get all credentials for a user

    Args:
        username: Username

    Returns:
        List of credential dicts (credential_id, public_key, sign_count,
        transports, device_name, aaguid, created_at, last_used)
    """
    with sqlite3.connect(self.db_path) as conn:
        cursor = conn.cursor()

        # NOTE(fix): the schema created by _init_database() names these
        # columns `username` and `last_used`; the previous query referenced
        # `user_id` and `last_used_at`, which do not exist, so every lookup
        # failed with sqlite3.OperationalError.
        cursor.execute("""
            SELECT credential_id, public_key, sign_count, device_name,
                   created_at, last_used
            FROM passkey_credentials
            WHERE username = ?
        """, (username,))

        rows = cursor.fetchall()

    return [
        {
            'credential_id': row[0],
            'public_key': row[1],
            'sign_count': row[2],
            # transports/aaguid are stored but not read here; callers treat
            # None as "unspecified"
            'transports': None,
            'device_name': row[3],
            'aaguid': None,
            'created_at': row[4],
            'last_used': row[5],
        }
        for row in rows
    ]
|
||||
|
||||
def generate_registration_options(self, username: str, email: Optional[str] = None) -> Dict:
    """
    Generate registration options for a new passkey

    Excludes the user's already-registered credentials so the browser will
    not re-register the same authenticator, stores the challenge in-memory
    for later verification, and returns the options as a JSON-ready dict.

    Args:
        username: Username
        email: Optional email for display name

    Returns:
        Registration options for client

    Raises:
        Exception: re-raised after audit logging if option generation fails.
    """
    try:
        # Get existing credentials
        existing_credentials = self.get_user_credentials(username)

        # Create exclude credentials list
        exclude_credentials = []
        for cred in existing_credentials:
            exclude_credentials.append(
                PublicKeyCredentialDescriptor(
                    id=base64url_to_bytes(cred['credential_id']),
                    # NOTE(review): get_user_credentials() currently always
                    # reports transports=None -- confirm if that is intended
                    transports=cred.get('transports')
                )
            )

        # Generate registration options
        options = generate_registration_options(
            rp_id=self.rp_id,
            rp_name=self.rp_name,
            user_id=username.encode('utf-8'),
            user_name=username,
            user_display_name=email or username,
            timeout=60000,  # 60 seconds
            attestation=AttestationConveyancePreference.NONE,
            exclude_credentials=exclude_credentials,
            authenticator_selection=AuthenticatorSelectionCriteria(
                resident_key=ResidentKeyRequirement.PREFERRED,
                user_verification=UserVerificationRequirement.PREFERRED
            ),
            supported_pub_key_algs=[-7, -257]  # ES256, RS256
        )

        # Store challenge (one pending registration per username)
        challenge_str = bytes_to_base64url(options.challenge)
        self.challenges[username] = {
            'challenge': challenge_str,
            'type': 'registration',
            'created_at': datetime.now(),
            'expires_at': datetime.now() + self.challenge_timeout
        }

        self._log_audit(username, 'passkey_registration_options_generated', True,
                        None, None, None, 'Registration options generated')

        # Convert to JSON-serializable format
        options_dict = options_to_json(options)

        return json.loads(options_dict)

    except Exception as e:
        self._log_audit(username, 'passkey_registration_options_failed', False,
                        None, None, None, f'Error: {str(e)}')
        raise
|
||||
|
||||
def verify_registration(self, username: str, credential_data: Dict,
                        device_name: Optional[str] = None) -> Dict:
    """
    Verify registration response and store credential

    Args:
        username: Username
        credential_data: Registration response from client
        device_name: Optional device name

    Returns:
        Dict with success status

    Raises:
        Exception: if no registration is pending, the challenge expired or
            has the wrong type, or WebAuthn verification fails.
    """
    try:
        # Get stored challenge
        if username not in self.challenges:
            raise Exception("No registration in progress")

        challenge_data = self.challenges[username]

        # Check if expired
        if datetime.now() > challenge_data['expires_at']:
            del self.challenges[username]
            raise Exception("Challenge expired")

        if challenge_data['type'] != 'registration':
            raise Exception("Invalid challenge type")

        expected_challenge = base64url_to_bytes(challenge_data['challenge'])

        # Verify registration response
        verification = verify_registration_response(
            credential=credential_data,
            expected_challenge=expected_challenge,
            expected_rp_id=self.rp_id,
            expected_origin=self.origin
        )

        # Store credential
        credential_id = bytes_to_base64url(verification.credential_id)
        public_key = bytes_to_base64url(verification.credential_public_key)

        # Extract transports if available
        transports = credential_data.get('response', {}).get('transports', [])

        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # NOTE(fix): the table created in _init_database() keys rows by
            # `username` (there is no `user_id` column), and it has a
            # `transports` column that was extracted above but never stored.
            # Insert under the real column names and persist transports as JSON.
            cursor.execute("""
                INSERT INTO passkey_credentials
                (username, credential_id, public_key, sign_count, transports, device_name, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (
                username,
                credential_id,
                public_key,
                verification.sign_count,
                json.dumps(transports) if transports else None,
                device_name,
                datetime.now().isoformat()
            ))

            # Enable passkey for user
            cursor.execute("""
                UPDATE users SET passkey_enabled = 1
                WHERE username = ?
            """, (username,))

            conn.commit()

        # Clean up challenge (one-time use)
        del self.challenges[username]

        self._log_audit(username, 'passkey_registered', True, credential_id,
                        None, None, f'Passkey registered: {device_name or "Unknown device"}')

        return {
            'success': True,
            'credentialId': credential_id
        }

    except Exception as e:
        logger.error(f"Passkey registration verification failed for {username}: {str(e)}", module="Passkey")
        logger.debug(f"Passkey registration detailed error:", exc_info=True, module="Passkey")
        self._log_audit(username, 'passkey_registration_failed', False, None,
                        None, None, f'Error: {str(e)}')
        raise
|
||||
|
||||
    def generate_authentication_options(self, username: Optional[str] = None) -> Dict:
        """
        Generate authentication options for passkey login.

        Args:
            username: Optional username (if None, allows usernameless login)

        Returns:
            JSON-serializable dict of WebAuthn authentication options for the client.

        Raises:
            Exception: propagated from the WebAuthn library after audit logging.
        """
        try:
            allow_credentials = []

            # If a username was provided, restrict the challenge to that user's
            # registered credentials; otherwise allow_credentials stays empty and
            # the client may use any discoverable credential (usernameless flow).
            if username:
                user_credentials = self.get_user_credentials(username)
                for cred in user_credentials:
                    allow_credentials.append(
                        PublicKeyCredentialDescriptor(
                            id=base64url_to_bytes(cred['credential_id']),
                            transports=cred.get('transports')
                        )
                    )

            # Generate authentication options.
            # NOTE: this bare name resolves to the module-level webauthn helper
            # of the same name, NOT this method (methods are not in function scope),
            # so there is no recursion here.
            options = generate_authentication_options(
                rp_id=self.rp_id,
                timeout=60000,  # 60 seconds
                allow_credentials=allow_credentials if allow_credentials else None,
                user_verification=UserVerificationRequirement.PREFERRED
            )

            # Store the challenge keyed by username ('__anonymous__' for the
            # usernameless flow) so verify_authentication can look it up later.
            challenge_str = bytes_to_base64url(options.challenge)
            self.challenges[username or '__anonymous__'] = {
                'challenge': challenge_str,
                'type': 'authentication',
                'created_at': datetime.now(),
                'expires_at': datetime.now() + self.challenge_timeout
            }

            self._log_audit(username or 'anonymous', 'passkey_authentication_options_generated',
                            True, None, None, None, 'Authentication options generated')

            # Convert to JSON-serializable format for the HTTP response.
            options_dict = options_to_json(options)

            return json.loads(options_dict)

        except Exception as e:
            # Audit the failure, then let the caller translate it into an HTTP error.
            self._log_audit(username or 'anonymous', 'passkey_authentication_options_failed',
                            False, None, None, None, f'Error: {str(e)}')
            raise
|
||||
|
||||
    def verify_authentication(self, username: str, credential_data: Dict) -> Dict:
        """
        Verify a WebAuthn authentication (login) response.

        Args:
            username: Username (may be empty for the usernameless flow)
            credential_data: Authentication response from client

        Returns:
            Dict with success status and the verified username (taken from the
            credential record in the database, not from the caller).

        Raises:
            Exception: on missing/expired/mismatched challenge, unknown
                credential, or signature verification failure (after audit log).
        """
        try:
            # Get the stored challenge; '__anonymous__' mirrors the key used by
            # generate_authentication_options for usernameless login.
            challenge_key = username or '__anonymous__'
            if challenge_key not in self.challenges:
                raise Exception("No authentication in progress")

            challenge_data = self.challenges[challenge_key]

            # Reject and discard expired challenges (replay protection).
            if datetime.now() > challenge_data['expires_at']:
                del self.challenges[challenge_key]
                raise Exception("Challenge expired")

            # A registration challenge must not be usable for authentication.
            if challenge_data['type'] != 'authentication':
                raise Exception("Invalid challenge type")

            expected_challenge = base64url_to_bytes(challenge_data['challenge'])

            # Get the credential ID from the response ('id' preferred, 'rawId' fallback).
            credential_id = credential_data.get('id') or credential_data.get('rawId')
            if not credential_id:
                raise Exception("No credential ID in response")

            # Find the credential in the database.
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                cursor.execute("""
                    SELECT user_id, public_key, sign_count
                    FROM passkey_credentials
                    WHERE credential_id = ?
                """, (credential_id,))

                row = cursor.fetchone()

                if not row:
                    raise Exception("Credential not found")

                verified_username, public_key_b64, current_sign_count = row

                # Verify the signed response against the stored public key and
                # sign count (sign-count check detects cloned authenticators).
                verification = verify_authentication_response(
                    credential=credential_data,
                    expected_challenge=expected_challenge,
                    expected_rp_id=self.rp_id,
                    expected_origin=self.origin,
                    credential_public_key=base64url_to_bytes(public_key_b64),
                    credential_current_sign_count=current_sign_count
                )

                # Persist the new sign count and last-used timestamp.
                cursor.execute("""
                    UPDATE passkey_credentials
                    SET sign_count = ?, last_used_at = ?
                    WHERE credential_id = ?
                """, (verification.new_sign_count, datetime.now().isoformat(), credential_id))

                conn.commit()

            # Clean up the consumed challenge so it cannot be replayed.
            del self.challenges[challenge_key]

            self._log_audit(verified_username, 'passkey_authentication_success', True,
                            credential_id, None, None, 'Passkey authentication successful')

            return {
                'success': True,
                'username': verified_username
            }

        except Exception as e:
            # Audit the failure, then re-raise for the API layer to handle.
            self._log_audit(username or 'anonymous', 'passkey_authentication_failed', False,
                            None, None, None, f'Error: {str(e)}')
            raise
|
||||
|
||||
def remove_credential(self, username: str, credential_id: str) -> bool:
|
||||
"""
|
||||
Remove a passkey credential
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
credential_id: Credential ID to remove
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Debug: List all credentials for this user
|
||||
cursor.execute("SELECT credential_id FROM passkey_credentials WHERE user_id = ?", (username,))
|
||||
all_creds = cursor.fetchall()
|
||||
logger.debug(f"All credentials for {username}: {all_creds}", module="Passkey")
|
||||
logger.debug(f"Looking for credential_id: '{credential_id}'", module="Passkey")
|
||||
|
||||
cursor.execute("""
|
||||
DELETE FROM passkey_credentials
|
||||
WHERE user_id = ? AND credential_id = ?
|
||||
""", (username, credential_id))
|
||||
|
||||
deleted = cursor.rowcount > 0
|
||||
logger.debug(f"Delete operation rowcount: {deleted}", module="Passkey")
|
||||
|
||||
# Check if user has any remaining credentials
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM passkey_credentials
|
||||
WHERE user_id = ?
|
||||
""", (username,))
|
||||
|
||||
remaining = cursor.fetchone()[0]
|
||||
|
||||
# If no credentials left, disable passkey
|
||||
if remaining == 0:
|
||||
cursor.execute("""
|
||||
UPDATE users SET passkey_enabled = 0
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
conn.commit()
|
||||
|
||||
if deleted:
|
||||
self._log_audit(username, 'passkey_removed', True, credential_id,
|
||||
None, None, 'Passkey credential removed')
|
||||
|
||||
return deleted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing passkey for {username}: {e}", module="Passkey")
|
||||
return False
|
||||
|
||||
def get_passkey_status(self, username: str) -> Dict:
|
||||
"""
|
||||
Get passkey status for a user
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
|
||||
Returns:
|
||||
Dict with enabled status and credential count
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM passkey_credentials
|
||||
WHERE user_id = ?
|
||||
""", (username,))
|
||||
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
return {
|
||||
'enabled': count > 0,
|
||||
'credentialCount': count
|
||||
}
|
||||
|
||||
def list_credentials(self, username: str) -> List[Dict]:
|
||||
"""
|
||||
List all credentials for a user (without sensitive data)
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
|
||||
Returns:
|
||||
List of credential info dicts
|
||||
"""
|
||||
credentials = self.get_user_credentials(username)
|
||||
|
||||
# Remove sensitive data
|
||||
safe_credentials = []
|
||||
for cred in credentials:
|
||||
safe_credentials.append({
|
||||
'credentialId': cred['credential_id'],
|
||||
'deviceName': cred.get('device_name', 'Unknown device'),
|
||||
'createdAt': cred['created_at'],
|
||||
'lastUsed': cred.get('last_used'),
|
||||
'transports': cred.get('transports', [])
|
||||
})
|
||||
|
||||
return safe_credentials
|
||||
|
||||
def _log_audit(self, username: str, action: str, success: bool,
|
||||
credential_id: Optional[str], ip_address: Optional[str],
|
||||
user_agent: Optional[str], details: Optional[str] = None):
|
||||
"""Log passkey audit event"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO passkey_audit_log
|
||||
(user_id, action, success, credential_id, ip_address, user_agent, details, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (username, action, int(success), credential_id, ip_address, user_agent, details,
|
||||
datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging passkey audit: {e}", module="Passkey")
|
||||
6
web/backend/requirements.txt
Normal file
6
web/backend/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]>=0.27.0,<0.35.0 # 0.40.0+ has breaking changes
|
||||
pydantic==2.5.3
|
||||
python-multipart==0.0.6
|
||||
websockets==12.0
|
||||
psutil==7.1.3
|
||||
90
web/backend/routers/__init__.py
Normal file
90
web/backend/routers/__init__.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Router module exports
|
||||
|
||||
All API routers for the media-downloader backend.
|
||||
Import routers from here to include them in the main app.
|
||||
|
||||
Usage:
|
||||
from web.backend.routers import (
|
||||
auth_router,
|
||||
health_router,
|
||||
downloads_router,
|
||||
media_router,
|
||||
recycle_router,
|
||||
scheduler_router,
|
||||
video_router,
|
||||
config_router,
|
||||
review_router,
|
||||
face_router,
|
||||
platforms_router,
|
||||
discovery_router,
|
||||
scrapers_router,
|
||||
semantic_router,
|
||||
manual_import_router,
|
||||
stats_router
|
||||
)
|
||||
|
||||
app.include_router(auth_router)
|
||||
app.include_router(health_router)
|
||||
# ... etc
|
||||
"""
|
||||
|
||||
from .auth import router as auth_router
|
||||
from .health import router as health_router
|
||||
from .downloads import router as downloads_router
|
||||
from .media import router as media_router
|
||||
from .recycle import router as recycle_router
|
||||
from .scheduler import router as scheduler_router
|
||||
from .video import router as video_router
|
||||
from .config import router as config_router
|
||||
from .review import router as review_router
|
||||
from .face import router as face_router
|
||||
from .platforms import router as platforms_router
|
||||
from .discovery import router as discovery_router
|
||||
from .scrapers import router as scrapers_router
|
||||
from .semantic import router as semantic_router
|
||||
from .manual_import import router as manual_import_router
|
||||
from .stats import router as stats_router
|
||||
from .celebrity import router as celebrity_router
|
||||
from .video_queue import router as video_queue_router
|
||||
from .maintenance import router as maintenance_router
|
||||
from .files import router as files_router
|
||||
from .appearances import router as appearances_router
|
||||
from .easynews import router as easynews_router
|
||||
from .dashboard import router as dashboard_router
|
||||
from .paid_content import router as paid_content_router
|
||||
from .private_gallery import router as private_gallery_router
|
||||
from .instagram_unified import router as instagram_unified_router
|
||||
from .cloud_backup import router as cloud_backup_router
|
||||
from .press import router as press_router
|
||||
|
||||
__all__ = [
|
||||
'auth_router',
|
||||
'health_router',
|
||||
'downloads_router',
|
||||
'media_router',
|
||||
'recycle_router',
|
||||
'scheduler_router',
|
||||
'video_router',
|
||||
'config_router',
|
||||
'review_router',
|
||||
'face_router',
|
||||
'platforms_router',
|
||||
'discovery_router',
|
||||
'scrapers_router',
|
||||
'semantic_router',
|
||||
'manual_import_router',
|
||||
'stats_router',
|
||||
'celebrity_router',
|
||||
'video_queue_router',
|
||||
'maintenance_router',
|
||||
'files_router',
|
||||
'appearances_router',
|
||||
'easynews_router',
|
||||
'dashboard_router',
|
||||
'paid_content_router',
|
||||
'private_gallery_router',
|
||||
'instagram_unified_router',
|
||||
'cloud_backup_router',
|
||||
'press_router',
|
||||
]
|
||||
3422
web/backend/routers/appearances.py
Normal file
3422
web/backend/routers/appearances.py
Normal file
File diff suppressed because it is too large
Load Diff
217
web/backend/routers/auth.py
Normal file
217
web/backend/routers/auth.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Authentication Router
|
||||
|
||||
Handles all authentication-related endpoints:
|
||||
- Login/Logout
|
||||
- User info
|
||||
- Password changes
|
||||
- User preferences
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, get_app_state
|
||||
from ..core.config import settings
|
||||
from ..core.exceptions import AuthError, handle_exceptions
|
||||
from ..core.responses import to_iso8601, now_iso8601
|
||||
from ..models.api_models import LoginRequest, ChangePasswordRequest, UserPreferences
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["Authentication"])
|
||||
|
||||
# Rate limiter - will be set from main app
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.post("/login")
@limiter.limit("5/minute")
@handle_exceptions
async def login(login_data: LoginRequest, request: Request):
    """
    Authenticate user with username and password.

    Returns JWT token (set both in the body and as an httponly cookie) or a
    2FA challenge descriptor if any second factor is enabled for the user.

    Raises:
        HTTPException 500: auth subsystem not initialized.
        HTTPException 401: unknown user, inactive account, or bad password.
    """
    app_state = get_app_state()

    if not app_state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")

    # Query user from database
    with sqlite3.connect(app_state.auth.db_path) as conn:
        cursor = conn.cursor()

        cursor.execute("""
            SELECT password_hash, role, is_active, totp_enabled, duo_enabled, passkey_enabled
            FROM users WHERE username = ?
        """, (login_data.username,))

        row = cursor.fetchone()

    # Same message for unknown user and wrong password avoids username
    # enumeration. NOTE(review): the distinct "Account is inactive" message
    # below does leak account existence — confirm that is acceptable.
    if not row:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    password_hash, role, is_active, totp_enabled, duo_enabled, passkey_enabled = row

    if not is_active:
        raise HTTPException(status_code=401, detail="Account is inactive")

    if not app_state.auth.verify_password(login_data.password, password_hash):
        # Record the failure (used for lockout/audit by the auth module).
        app_state.auth._record_login_attempt(login_data.username, False)
        raise HTTPException(status_code=401, detail="Invalid credentials")

    # Check which second factors the user has enabled.
    available_methods = []
    if totp_enabled:
        available_methods.append('totp')
    if passkey_enabled:
        available_methods.append('passkey')
    if duo_enabled:
        available_methods.append('duo')

    # If any 2FA method is enabled, do NOT create a session yet — tell the
    # client to complete the second factor first.
    if available_methods:
        return {
            'success': True,
            'require2FA': True,
            'availableMethods': available_methods,
            'username': login_data.username
        }

    # No 2FA - proceed with normal login
    result = app_state.auth._create_session(
        username=login_data.username,
        role=role,
        ip_address=request.client.host if request.client else None,
        remember_me=login_data.rememberMe
    )

    # Create response with cookie
    response = JSONResponse(content=result)

    # Set auth cookie (secure, httponly for security).
    # rememberMe -> 30-day persistent cookie; otherwise session cookie (max_age=None).
    max_age = 30 * 24 * 60 * 60 if login_data.rememberMe else None
    response.set_cookie(
        key="auth_token",
        value=result.get('token'),
        max_age=max_age,
        httponly=True,
        secure=settings.SECURE_COOKIES,
        samesite="lax",
        path="/"
    )

    logger.info(f"User {login_data.username} logged in successfully", module="Auth")
    return response
|
||||
|
||||
|
||||
@router.post("/logout")
@limiter.limit("10/minute")
@handle_exceptions
async def logout(request: Request, current_user: Dict = Depends(get_current_user)):
    """Logout and invalidate session"""
    who = current_user.get('sub', 'unknown')

    payload = {
        "success": True,
        "message": "Logged out successfully",
        "timestamp": now_iso8601()
    }
    response = JSONResponse(content=payload)

    # Expire the auth cookie immediately (max_age=0), using the same
    # attributes it was set with so the browser actually drops it.
    response.set_cookie(
        key="auth_token",
        value="",
        max_age=0,
        httponly=True,
        secure=settings.SECURE_COOKIES,
        samesite="lax",
        path="/"
    )

    logger.info(f"User {who} logged out", module="Auth")
    return response
|
||||
|
||||
|
||||
@router.get("/me")
@limiter.limit("30/minute")
@handle_exceptions
async def get_me(request: Request, current_user: Dict = Depends(get_current_user)):
    """Get current user information"""
    state = get_app_state()

    # 'sub' carries the username in the session token claims.
    info = state.auth.get_user(current_user.get('sub'))
    if not info:
        raise HTTPException(status_code=404, detail="User not found")
    return info
|
||||
|
||||
|
||||
@router.post("/change-password")
@limiter.limit("5/minute")
@handle_exceptions
async def change_password(
    request: Request,
    current_password: str = Body(..., embed=True),
    new_password: str = Body(..., embed=True),
    current_user: Dict = Depends(get_current_user)
):
    """Change user password"""
    state = get_app_state()
    who = current_user.get('sub')
    client_ip = request.client.host if request.client else None

    # Minimal server-side password policy; the auth module performs the
    # current-password check itself.
    if len(new_password) < 8:
        raise HTTPException(status_code=400, detail="Password must be at least 8 characters")

    outcome = state.auth.change_password(who, current_password, new_password, client_ip)
    if not outcome['success']:
        raise HTTPException(status_code=400, detail=outcome.get('error', 'Password change failed'))

    logger.info(f"Password changed for user {who}", module="Auth")
    return {
        "success": True,
        "message": "Password changed successfully",
        "timestamp": now_iso8601()
    }
|
||||
|
||||
|
||||
@router.post("/preferences")
@limiter.limit("10/minute")
@handle_exceptions
async def update_preferences(
    request: Request,
    preferences: dict = Body(...),
    current_user: Dict = Depends(get_current_user)
):
    """Update user preferences (theme, notifications, etc.)"""
    state = get_app_state()
    who = current_user.get('sub')

    # Reject unknown theme values before touching storage.
    if 'theme' in preferences and preferences['theme'] not in ('light', 'dark', 'system'):
        raise HTTPException(status_code=400, detail="Invalid theme value")

    outcome = state.auth.update_preferences(who, preferences)
    if not outcome['success']:
        raise HTTPException(status_code=400, detail=outcome.get('error', 'Failed to update preferences'))

    return {
        "success": True,
        "message": "Preferences updated",
        "preferences": preferences,
        "timestamp": now_iso8601()
    }
|
||||
2728
web/backend/routers/celebrity.py
Normal file
2728
web/backend/routers/celebrity.py
Normal file
File diff suppressed because it is too large
Load Diff
1525
web/backend/routers/cloud_backup.py
Normal file
1525
web/backend/routers/cloud_backup.py
Normal file
File diff suppressed because it is too large
Load Diff
756
web/backend/routers/config.py
Normal file
756
web/backend/routers/config.py
Normal file
@@ -0,0 +1,756 @@
|
||||
"""
|
||||
Config and Logs Router
|
||||
|
||||
Handles configuration and logging operations:
|
||||
- Get/update application configuration
|
||||
- Log viewing (single component, merged)
|
||||
- Notification history and stats
|
||||
- Changelog retrieval
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.config import settings
|
||||
from ..core.exceptions import (
|
||||
handle_exceptions,
|
||||
ValidationError,
|
||||
RecordNotFoundError
|
||||
)
|
||||
from ..core.responses import now_iso8601
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["Configuration"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
LOG_PATH = settings.PROJECT_ROOT / 'logs'
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ConfigUpdate(BaseModel):
    """Request body for PUT /config: a mapping of setting keys to new values."""
    config: Dict
|
||||
|
||||
|
||||
class MergedLogsRequest(BaseModel):
    """Request body for POST /logs/merged."""
    lines: int = 500                  # max number of lines to return per component read
    components: List[str]             # log component names to merge
    around_time: Optional[str] = None  # ISO timestamp to center logs around
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONFIGURATION ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/config")
@limiter.limit("100/minute")
@handle_exceptions
async def get_config(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get current configuration."""
    # Serve live settings straight from the app-wide settings store.
    return get_app_state().settings.get_all()
|
||||
|
||||
|
||||
@router.put("/config")
@limiter.limit("20/minute")
@handle_exceptions
async def update_config(
    request: Request,
    current_user: Dict = Depends(require_admin),
    update: ConfigUpdate = Body(...)
):
    """
    Update configuration (admin only).

    Saves configuration to database and updates in-memory state, then
    broadcasts a best-effort "config_updated" event over websockets.
    """
    app_state = get_app_state()

    if not isinstance(update.config, dict):
        raise ValidationError("Invalid configuration format")

    logger.debug(f"Incoming config keys: {list(update.config.keys())}", module="Config")

    # Save to database. Each top-level key doubles as its own category
    # (category=key) — NOTE(review): confirm this matches the settings schema.
    for key, value in update.config.items():
        app_state.settings.set(key, value, category=key, updated_by='api')

    # Refresh in-memory config so other endpoints see updated values
    app_state.config = app_state.settings.get_all()

    # Broadcast update — best-effort: a broadcast failure must not fail the save.
    try:
        if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
            await app_state.websocket_manager.broadcast({
                "type": "config_updated",
                "timestamp": now_iso8601()
            })
    except Exception as e:
        logger.debug(f"Failed to broadcast config update: {e}", module="Config")

    return {"success": True, "message": "Configuration updated"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LOG ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs")
@limiter.limit("100/minute")
@handle_exceptions
async def get_logs(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    lines: int = 100,
    component: Optional[str] = None
):
    """
    Get recent log entries from the most recent log files.

    Args:
        lines: number of trailing lines to return from the chosen file.
        component: optional component name filter; when omitted, the most
            recently modified log file across all components is used.

    Returns:
        Dict with 'logs', 'available_components', and (on success)
        'current_component' and 'log_file'.
    """
    if not LOG_PATH.exists():
        return {"logs": [], "available_components": []}

    all_log_files = []

    # Find date-stamped logs: YYYYMMDD_component.log or YYYYMMDD_HHMMSS_component.log
    seen_paths = set()
    for log_file in LOG_PATH.glob('*.log'):
        if '_' not in log_file.stem:
            continue
        parts = log_file.stem.split('_')
        if not parts[0].isdigit():
            continue
        try:
            stat_info = log_file.stat()
            # Skip empty files — nothing to show.
            if stat_info.st_size == 0:
                continue
            mtime = stat_info.st_mtime
            # YYYYMMDD_HHMMSS_component.log (3+ parts, first two numeric)
            if len(parts) >= 3 and parts[1].isdigit():
                comp_name = '_'.join(parts[2:])
            # YYYYMMDD_component.log (2+ parts, first numeric)
            elif len(parts) >= 2:
                comp_name = '_'.join(parts[1:])
            else:
                continue
            seen_paths.add(log_file)
            all_log_files.append({
                'path': log_file,
                'mtime': mtime,
                'component': comp_name
            })
        except OSError:
            # File vanished or is unreadable between glob and stat — skip it.
            pass

    # Also check for old-style logs (no date prefix)
    for log_file in LOG_PATH.glob('*.log'):
        if log_file in seen_paths:
            continue
        # Skip date-stamped files already handled above.
        if '_' in log_file.stem and log_file.stem.split('_')[0].isdigit():
            continue
        try:
            stat_info = log_file.stat()
            if stat_info.st_size == 0:
                continue
            mtime = stat_info.st_mtime
            all_log_files.append({
                'path': log_file,
                'mtime': mtime,
                'component': log_file.stem
            })
        except OSError:
            pass

    if not all_log_files:
        return {"logs": [], "available_components": []}

    components = sorted(set(f['component'] for f in all_log_files))

    # Narrow to the requested component, if any.
    if component:
        log_files = [f for f in all_log_files if f['component'] == component]
    else:
        log_files = all_log_files

    if not log_files:
        return {"logs": [], "available_components": components}

    # Serve the tail of the single most recently modified matching file.
    most_recent = max(log_files, key=lambda x: x['mtime'])

    try:
        with open(most_recent['path'], 'r', encoding='utf-8', errors='ignore') as f:
            all_lines = f.readlines()
            recent_lines = all_lines[-lines:]

        return {
            "logs": [line.strip() for line in recent_lines],
            "available_components": components,
            "current_component": most_recent['component'],
            "log_file": str(most_recent['path'].name)
        }
    except Exception as e:
        logger.error(f"Error reading log file: {e}", module="Logs")
        return {"logs": [], "available_components": components, "error": str(e)}
|
||||
|
||||
|
||||
@router.post("/logs/merged")
@limiter.limit("100/minute")
@handle_exceptions
async def get_merged_logs(
    request: Request,
    body: MergedLogsRequest,
    current_user: Dict = Depends(get_current_user)
):
    """
    Get merged log entries from multiple components, sorted by timestamp.

    Reads the tail of the most recent log file for each requested component,
    parses leading timestamps where present, interleaves everything in
    chronological order, and optionally centers the result on
    ``body.around_time``.
    """
    lines = body.lines
    components = body.components

    if not LOG_PATH.exists():
        return {"logs": [], "available_components": [], "selected_components": []}

    all_log_files = []

    # Find date-stamped logs
    for log_file in LOG_PATH.glob('*_*.log'):
        try:
            stat_info = log_file.stat()
            if stat_info.st_size == 0:
                continue
            mtime = stat_info.st_mtime
            parts = log_file.stem.split('_')

            # Check OLD format FIRST (YYYYMMDD_HHMMSS_component.log)
            if len(parts) >= 3 and parts[0].isdigit() and len(parts[0]) == 8 and parts[1].isdigit() and len(parts[1]) == 6:
                comp_name = '_'.join(parts[2:])
                all_log_files.append({
                    'path': log_file,
                    'mtime': mtime,
                    'component': comp_name
                })
            # Then check NEW format (YYYYMMDD_component.log)
            elif len(parts) >= 2 and parts[0].isdigit() and len(parts[0]) == 8:
                comp_name = '_'.join(parts[1:])
                all_log_files.append({
                    'path': log_file,
                    'mtime': mtime,
                    'component': comp_name
                })
        except OSError:
            # File disappeared between glob and stat — ignore.
            pass

    # Also check for old-style logs (no date prefix)
    for log_file in LOG_PATH.glob('*.log'):
        if '_' in log_file.stem and log_file.stem.split('_')[0].isdigit():
            continue
        try:
            stat_info = log_file.stat()
            if stat_info.st_size == 0:
                continue
            mtime = stat_info.st_mtime
            all_log_files.append({
                'path': log_file,
                'mtime': mtime,
                'component': log_file.stem
            })
        except OSError:
            pass

    if not all_log_files:
        return {"logs": [], "available_components": [], "selected_components": []}

    available_components = sorted(set(f['component'] for f in all_log_files))

    # No components selected -> return just the available list.
    if not components or len(components) == 0:
        return {
            "logs": [],
            "available_components": available_components,
            "selected_components": []
        }

    selected_log_files = [f for f in all_log_files if f['component'] in components]

    if not selected_log_files:
        return {
            "logs": [],
            "available_components": available_components,
            "selected_components": components
        }

    all_logs_with_timestamps = []

    # For each selected component, read the tail of its newest file and tag
    # each line with a parsed timestamp (None when unparseable).
    for comp in components:
        comp_files = [f for f in selected_log_files if f['component'] == comp]
        if not comp_files:
            continue

        most_recent = max(comp_files, key=lambda x: x['mtime'])

        try:
            with open(most_recent['path'], 'r', encoding='utf-8', errors='ignore') as f:
                all_lines = f.readlines()
                recent_lines = all_lines[-lines:]

            for line in recent_lines:
                line = line.strip()
                if not line:
                    continue

                # Match leading timestamp with optional fractional seconds.
                timestamp_match = re.match(r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})(?:\.(\d+))?', line)

                if timestamp_match:
                    timestamp_str = timestamp_match.group(1)
                    microseconds = timestamp_match.group(2)
                    try:
                        timestamp = datetime.strptime(timestamp_str, '%Y-%m-%d %H:%M:%S')
                        # Add microseconds if present
                        if microseconds:
                            # Pad or truncate to 6 digits for microseconds
                            microseconds = microseconds[:6].ljust(6, '0')
                            timestamp = timestamp.replace(microsecond=int(microseconds))
                        all_logs_with_timestamps.append({
                            'timestamp': timestamp,
                            'log': line
                        })
                    except ValueError:
                        # Looked like a timestamp but failed to parse — keep
                        # the line, just without ordering information.
                        all_logs_with_timestamps.append({
                            'timestamp': None,
                            'log': line
                        })
                else:
                    all_logs_with_timestamps.append({
                        'timestamp': None,
                        'log': line
                    })
        except Exception as e:
            logger.error(f"Error reading log file {most_recent['path']}: {e}", module="Logs")
            continue

    # Sort by timestamp; lines without one sort to the front via datetime.min.
    sorted_logs = sorted(
        all_logs_with_timestamps,
        key=lambda x: x['timestamp'] if x['timestamp'] is not None else datetime.min
    )

    # If around_time is specified, center the logs around that timestamp
    if body.around_time:
        try:
            # Parse the target timestamp.
            # NOTE(review): the chained replace strips a trailing 'Z' or
            # '+00:00' offset and then parses as a naive datetime — confirm
            # log timestamps are written in the same (local? UTC?) reference.
            target_time = datetime.fromisoformat(body.around_time.replace('Z', '+00:00').replace('+00:00', ''))

            # Find logs within 10 minutes of the target time
            time_window = timedelta(minutes=10)
            filtered_logs = [
                entry for entry in sorted_logs
                if entry['timestamp'] is not None and
                abs((entry['timestamp'] - target_time).total_seconds()) <= time_window.total_seconds()
            ]

            # If we found logs near the target time, use those
            # Otherwise fall back to all logs and try to find the closest ones
            if filtered_logs:
                merged_logs = [entry['log'] for entry in filtered_logs]
            else:
                # Find the closest logs to the target time
                logs_with_diff = [
                    (entry, abs((entry['timestamp'] - target_time).total_seconds()) if entry['timestamp'] else float('inf'))
                    for entry in sorted_logs
                ]
                logs_with_diff.sort(key=lambda x: x[1])
                # Take the closest logs, centered around the target
                closest_logs = logs_with_diff[:lines]
                closest_logs.sort(key=lambda x: x[0]['timestamp'] if x[0]['timestamp'] else datetime.min)
                merged_logs = [entry[0]['log'] for entry in closest_logs]
        except (ValueError, TypeError):
            # If parsing fails, fall back to normal behavior
            merged_logs = [entry['log'] for entry in sorted_logs]
            if len(merged_logs) > lines:
                merged_logs = merged_logs[-lines:]
    else:
        merged_logs = [entry['log'] for entry in sorted_logs]
        if len(merged_logs) > lines:
            merged_logs = merged_logs[-lines:]

    return {
        "logs": merged_logs,
        "available_components": available_components,
        "selected_components": components,
        "total_logs": len(merged_logs)
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NOTIFICATION ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/notifications")
|
||||
@limiter.limit("500/minute")
|
||||
@handle_exceptions
|
||||
async def get_notifications(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
platform: Optional[str] = None,
|
||||
source: Optional[str] = None
|
||||
):
|
||||
"""Get notification history with pagination and filters."""
|
||||
app_state = get_app_state()
|
||||
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
query = """
|
||||
SELECT id, platform, source, content_type, message, title,
|
||||
priority, download_count, sent_at, status, metadata
|
||||
FROM notifications
|
||||
WHERE 1=1
|
||||
"""
|
||||
params = []
|
||||
|
||||
if platform:
|
||||
query += " AND platform = ?"
|
||||
params.append(platform)
|
||||
|
||||
if source:
|
||||
# Handle standardized source names
|
||||
if source == 'YouTube Monitor':
|
||||
query += " AND source = ?"
|
||||
params.append('youtube_monitor')
|
||||
else:
|
||||
query += " AND source = ?"
|
||||
params.append(source)
|
||||
|
||||
# Get total count
|
||||
count_query = query.replace(
|
||||
"SELECT id, platform, source, content_type, message, title, priority, download_count, sent_at, status, metadata",
|
||||
"SELECT COUNT(*)"
|
||||
)
|
||||
cursor.execute(count_query, params)
|
||||
result = cursor.fetchone()
|
||||
total = result[0] if result else 0
|
||||
|
||||
# Add ordering and pagination
|
||||
query += " ORDER BY sent_at DESC LIMIT ? OFFSET ?"
|
||||
params.extend([limit, offset])
|
||||
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
notifications = []
|
||||
for row in rows:
|
||||
notifications.append({
|
||||
'id': row[0],
|
||||
'platform': row[1],
|
||||
'source': row[2],
|
||||
'content_type': row[3],
|
||||
'message': row[4],
|
||||
'title': row[5],
|
||||
'priority': row[6],
|
||||
'download_count': row[7],
|
||||
'sent_at': row[8],
|
||||
'status': row[9],
|
||||
'metadata': json.loads(row[10]) if row[10] else None
|
||||
})
|
||||
|
||||
return {
|
||||
'notifications': notifications,
|
||||
'total': total,
|
||||
'limit': limit,
|
||||
'offset': offset
|
||||
}
|
||||
|
||||
|
||||
@router.get("/notifications/stats")
|
||||
@limiter.limit("500/minute")
|
||||
@handle_exceptions
|
||||
async def get_notification_stats(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get notification statistics."""
|
||||
app_state = get_app_state()
|
||||
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Total sent
|
||||
cursor.execute("SELECT COUNT(*) FROM notifications WHERE status = 'sent'")
|
||||
result = cursor.fetchone()
|
||||
total_sent = result[0] if result else 0
|
||||
|
||||
# Total failed
|
||||
cursor.execute("SELECT COUNT(*) FROM notifications WHERE status = 'failed'")
|
||||
result = cursor.fetchone()
|
||||
total_failed = result[0] if result else 0
|
||||
|
||||
# By platform (consolidate and filter)
|
||||
cursor.execute("""
|
||||
SELECT platform, COUNT(*) as count
|
||||
FROM notifications
|
||||
GROUP BY platform
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
raw_platforms = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# Consolidate similar platforms and exclude system
|
||||
by_platform = {}
|
||||
for platform, count in raw_platforms.items():
|
||||
# Skip system notifications
|
||||
if platform == 'system':
|
||||
continue
|
||||
# Consolidate forum -> forums
|
||||
if platform == 'forum':
|
||||
by_platform['forums'] = by_platform.get('forums', 0) + count
|
||||
# Consolidate fastdl -> instagram (fastdl is an Instagram download method)
|
||||
elif platform == 'fastdl':
|
||||
by_platform['instagram'] = by_platform.get('instagram', 0) + count
|
||||
# Standardize youtube_monitor/youtube_monitors -> youtube
|
||||
elif platform in ('youtube_monitor', 'youtube_monitors'):
|
||||
by_platform['youtube'] = by_platform.get('youtube', 0) + count
|
||||
else:
|
||||
by_platform[platform] = by_platform.get(platform, 0) + count
|
||||
|
||||
# Recent 24h
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM notifications
|
||||
WHERE sent_at >= datetime('now', '-1 day')
|
||||
""")
|
||||
result = cursor.fetchone()
|
||||
recent_24h = result[0] if result else 0
|
||||
|
||||
# Unique sources for filter dropdown
|
||||
cursor.execute("""
|
||||
SELECT DISTINCT source FROM notifications
|
||||
WHERE source IS NOT NULL AND source != ''
|
||||
ORDER BY source
|
||||
""")
|
||||
raw_sources = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
# Standardize source names and track special sources
|
||||
sources = []
|
||||
has_youtube_monitor = False
|
||||
has_log_errors = False
|
||||
for source in raw_sources:
|
||||
# Standardize youtube_monitor -> YouTube Monitor
|
||||
if source == 'youtube_monitor':
|
||||
has_youtube_monitor = True
|
||||
elif source == 'Log Errors':
|
||||
has_log_errors = True
|
||||
else:
|
||||
sources.append(source)
|
||||
|
||||
# Put special sources at the top
|
||||
priority_sources = []
|
||||
if has_youtube_monitor:
|
||||
priority_sources.append('YouTube Monitor')
|
||||
if has_log_errors:
|
||||
priority_sources.append('Log Errors')
|
||||
sources = priority_sources + sources
|
||||
|
||||
return {
|
||||
'total_sent': total_sent,
|
||||
'total_failed': total_failed,
|
||||
'by_platform': by_platform,
|
||||
'recent_24h': recent_24h,
|
||||
'sources': sources
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/notifications/{notification_id}")
|
||||
@limiter.limit("100/minute")
|
||||
@handle_exceptions
|
||||
async def delete_notification(
|
||||
request: Request,
|
||||
notification_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a single notification from history."""
|
||||
app_state = get_app_state()
|
||||
|
||||
with app_state.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if notification exists
|
||||
cursor.execute("SELECT id FROM notifications WHERE id = ?", (notification_id,))
|
||||
if not cursor.fetchone():
|
||||
raise RecordNotFoundError(
|
||||
"Notification not found",
|
||||
{"notification_id": notification_id}
|
||||
)
|
||||
|
||||
# Delete the notification
|
||||
cursor.execute("DELETE FROM notifications WHERE id = ?", (notification_id,))
|
||||
conn.commit()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Notification deleted',
|
||||
'notification_id': notification_id
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CHANGELOG ENDPOINT
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/changelog")
|
||||
@limiter.limit("100/minute")
|
||||
@handle_exceptions
|
||||
async def get_changelog(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get changelog data from JSON file."""
|
||||
changelog_path = settings.PROJECT_ROOT / "data" / "changelog.json"
|
||||
|
||||
if not changelog_path.exists():
|
||||
return {"versions": []}
|
||||
|
||||
with open(changelog_path, 'r') as f:
|
||||
changelog_data = json.load(f)
|
||||
|
||||
return {"versions": changelog_data}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# APPEARANCE CONFIG ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
class AppearanceConfigUpdate(BaseModel):
    """Request body for POST /config/appearance.

    Mirrors the writable columns of the single-row `appearance_config` table.
    """

    tmdb_api_key: Optional[str] = None  # TMDB API key (None leaves it unset)
    tmdb_enabled: bool = True  # Enable TMDB appearance checks
    tmdb_check_interval_hours: int = 12  # Hours between TMDB checks
    notify_new_appearances: bool = True  # Send notifications for new appearances
    notify_days_before: int = 1  # Days before an appearance to notify
    podcast_enabled: bool = False
    radio_enabled: bool = False
    podchaser_client_id: Optional[str] = None  # Podchaser OAuth client id
    podchaser_client_secret: Optional[str] = None  # Podchaser OAuth client secret
    podchaser_api_key: Optional[str] = None
    podchaser_enabled: bool = False
    imdb_enabled: bool = True  # Enable IMDb-based appearance lookups
|
||||
|
||||
|
||||
@router.get("/config/appearance")
|
||||
@limiter.limit("100/minute")
|
||||
@handle_exceptions
|
||||
async def get_appearance_config(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get appearance tracking configuration."""
|
||||
db = get_app_state().db
|
||||
try:
|
||||
with db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT tmdb_api_key, tmdb_enabled, tmdb_check_interval_hours, tmdb_last_check,
|
||||
notify_new_appearances, notify_days_before, podcast_enabled, radio_enabled,
|
||||
podchaser_client_id, podchaser_client_secret, podchaser_api_key,
|
||||
podchaser_enabled, podchaser_last_check, imdb_enabled
|
||||
FROM appearance_config
|
||||
WHERE id = 1
|
||||
''')
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
# Initialize config if not exists
|
||||
cursor.execute('INSERT OR IGNORE INTO appearance_config (id) VALUES (1)')
|
||||
conn.commit()
|
||||
return {
|
||||
"tmdb_api_key": None,
|
||||
"tmdb_enabled": True,
|
||||
"tmdb_check_interval_hours": 12,
|
||||
"tmdb_last_check": None,
|
||||
"notify_new_appearances": True,
|
||||
"notify_days_before": 1,
|
||||
"podcast_enabled": False,
|
||||
"radio_enabled": False,
|
||||
"podchaser_client_id": None,
|
||||
"podchaser_client_secret": None,
|
||||
"podchaser_api_key": None,
|
||||
"podchaser_enabled": False,
|
||||
"podchaser_last_check": None
|
||||
}
|
||||
|
||||
return {
|
||||
"tmdb_api_key": row[0],
|
||||
"tmdb_enabled": bool(row[1]),
|
||||
"tmdb_check_interval_hours": row[2],
|
||||
"tmdb_last_check": row[3],
|
||||
"notify_new_appearances": bool(row[4]),
|
||||
"notify_days_before": row[5],
|
||||
"podcast_enabled": bool(row[6]),
|
||||
"radio_enabled": bool(row[7]),
|
||||
"podchaser_client_id": row[8] if len(row) > 8 else None,
|
||||
"podchaser_client_secret": row[9] if len(row) > 9 else None,
|
||||
"podchaser_api_key": row[10] if len(row) > 10 else None,
|
||||
"podchaser_enabled": bool(row[11]) if len(row) > 11 else False,
|
||||
"podchaser_last_check": row[12] if len(row) > 12 else None,
|
||||
"imdb_enabled": bool(row[13]) if len(row) > 13 else True
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting appearance config: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/config/appearance")
|
||||
@limiter.limit("100/minute")
|
||||
@handle_exceptions
|
||||
async def update_appearance_config(
|
||||
request: Request,
|
||||
config: AppearanceConfigUpdate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update appearance tracking configuration."""
|
||||
db = get_app_state().db
|
||||
try:
|
||||
with db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Update config
|
||||
cursor.execute('''
|
||||
UPDATE appearance_config
|
||||
SET tmdb_api_key = ?,
|
||||
tmdb_enabled = ?,
|
||||
tmdb_check_interval_hours = ?,
|
||||
notify_new_appearances = ?,
|
||||
notify_days_before = ?,
|
||||
podcast_enabled = ?,
|
||||
radio_enabled = ?,
|
||||
podchaser_client_id = ?,
|
||||
podchaser_client_secret = ?,
|
||||
podchaser_api_key = ?,
|
||||
podchaser_enabled = ?,
|
||||
imdb_enabled = ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = 1
|
||||
''', (config.tmdb_api_key, config.tmdb_enabled, config.tmdb_check_interval_hours,
|
||||
config.notify_new_appearances, config.notify_days_before,
|
||||
config.podcast_enabled, config.radio_enabled,
|
||||
config.podchaser_client_id, config.podchaser_client_secret,
|
||||
config.podchaser_api_key, config.podchaser_enabled, config.imdb_enabled))
|
||||
|
||||
conn.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Appearance configuration updated successfully"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating appearance config: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
304
web/backend/routers/dashboard.py
Normal file
304
web/backend/routers/dashboard.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Dashboard API Router
|
||||
|
||||
Provides endpoints for dashboard-specific data like recent items across different locations.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from typing import Dict, Any, Optional
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from ..core.dependencies import get_current_user, get_app_state
|
||||
from ..core.exceptions import handle_exceptions
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/api/dashboard", tags=["dashboard"])
|
||||
logger = get_logger('API')
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get("/recent-items")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_recent_items(
|
||||
request: Request,
|
||||
limit: int = 20,
|
||||
since_id: Optional[int] = None,
|
||||
current_user=Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get NEW items from Media, Review, and Internet Discovery for dashboard cards.
|
||||
|
||||
Uses file_inventory.id for ordering since it monotonically increases with
|
||||
insertion order. download_date from the downloads table is included for
|
||||
display but not used for ordering (batch downloads can interleave timestamps).
|
||||
|
||||
Args:
|
||||
limit: Max items per category
|
||||
since_id: Optional file_inventory ID - only return items with id > this value
|
||||
|
||||
Returns up to `limit` items from each location, sorted by most recently added first.
|
||||
"""
|
||||
app_state = get_app_state()
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Media items (location='final')
|
||||
# ORDER BY fi.id DESC — id is monotonically increasing and reflects insertion order.
|
||||
# download_date is included for display but NOT used for ordering.
|
||||
if since_id:
|
||||
cursor.execute("""
|
||||
SELECT fi.id, fi.file_path, fi.filename, fi.source, fi.platform, fi.content_type,
|
||||
fi.file_size, COALESCE(d.download_date, fi.created_date) as added_at,
|
||||
fi.width, fi.height
|
||||
FROM file_inventory fi
|
||||
LEFT JOIN downloads d ON d.filename = fi.filename
|
||||
WHERE fi.location = 'final'
|
||||
AND fi.id > ?
|
||||
AND (fi.moved_from_review IS NULL OR fi.moved_from_review = 0)
|
||||
AND (fi.from_discovery IS NULL OR fi.from_discovery = 0)
|
||||
ORDER BY fi.id DESC
|
||||
LIMIT ?
|
||||
""", (since_id, limit))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT fi.id, fi.file_path, fi.filename, fi.source, fi.platform, fi.content_type,
|
||||
fi.file_size, COALESCE(d.download_date, fi.created_date) as added_at,
|
||||
fi.width, fi.height
|
||||
FROM file_inventory fi
|
||||
LEFT JOIN downloads d ON d.filename = fi.filename
|
||||
WHERE fi.location = 'final'
|
||||
AND (fi.moved_from_review IS NULL OR fi.moved_from_review = 0)
|
||||
AND (fi.from_discovery IS NULL OR fi.from_discovery = 0)
|
||||
ORDER BY fi.id DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
media_items = []
|
||||
for row in cursor.fetchall():
|
||||
media_items.append({
|
||||
'id': row[0],
|
||||
'file_path': row[1],
|
||||
'filename': row[2],
|
||||
'source': row[3],
|
||||
'platform': row[4],
|
||||
'media_type': row[5],
|
||||
'file_size': row[6],
|
||||
'added_at': row[7],
|
||||
'width': row[8],
|
||||
'height': row[9]
|
||||
})
|
||||
|
||||
# Get total count for new media items
|
||||
if since_id:
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*)
|
||||
FROM file_inventory
|
||||
WHERE location = 'final'
|
||||
AND id > ?
|
||||
AND (moved_from_review IS NULL OR moved_from_review = 0)
|
||||
AND (from_discovery IS NULL OR from_discovery = 0)
|
||||
""", (since_id,))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM file_inventory
|
||||
WHERE location = 'final'
|
||||
AND (moved_from_review IS NULL OR moved_from_review = 0)
|
||||
AND (from_discovery IS NULL OR from_discovery = 0)
|
||||
""")
|
||||
media_count = cursor.fetchone()[0]
|
||||
|
||||
# Review items (location='review')
|
||||
if since_id:
|
||||
cursor.execute("""
|
||||
SELECT f.id, f.file_path, f.filename, f.source, f.platform, f.content_type,
|
||||
f.file_size, COALESCE(d.download_date, f.created_date) as added_at,
|
||||
f.width, f.height,
|
||||
CASE WHEN fr.id IS NOT NULL THEN 1 ELSE 0 END as face_scanned,
|
||||
fr.has_match as face_matched, fr.confidence as face_confidence, fr.matched_person
|
||||
FROM file_inventory f
|
||||
LEFT JOIN downloads d ON d.filename = f.filename
|
||||
LEFT JOIN face_recognition_scans fr ON f.file_path = fr.file_path
|
||||
WHERE f.location = 'review'
|
||||
AND f.id > ?
|
||||
AND (f.moved_from_media IS NULL OR f.moved_from_media = 0)
|
||||
ORDER BY f.id DESC
|
||||
LIMIT ?
|
||||
""", (since_id, limit))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT f.id, f.file_path, f.filename, f.source, f.platform, f.content_type,
|
||||
f.file_size, COALESCE(d.download_date, f.created_date) as added_at,
|
||||
f.width, f.height,
|
||||
CASE WHEN fr.id IS NOT NULL THEN 1 ELSE 0 END as face_scanned,
|
||||
fr.has_match as face_matched, fr.confidence as face_confidence, fr.matched_person
|
||||
FROM file_inventory f
|
||||
LEFT JOIN downloads d ON d.filename = f.filename
|
||||
LEFT JOIN face_recognition_scans fr ON f.file_path = fr.file_path
|
||||
WHERE f.location = 'review'
|
||||
AND (f.moved_from_media IS NULL OR f.moved_from_media = 0)
|
||||
ORDER BY f.id DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
review_items = []
|
||||
for row in cursor.fetchall():
|
||||
face_recognition = None
|
||||
if row[10]: # face_scanned
|
||||
face_recognition = {
|
||||
'scanned': True,
|
||||
'matched': bool(row[11]) if row[11] is not None else False,
|
||||
'confidence': row[12],
|
||||
'matched_person': row[13]
|
||||
}
|
||||
|
||||
review_items.append({
|
||||
'id': row[0],
|
||||
'file_path': row[1],
|
||||
'filename': row[2],
|
||||
'source': row[3],
|
||||
'platform': row[4],
|
||||
'media_type': row[5],
|
||||
'file_size': row[6],
|
||||
'added_at': row[7],
|
||||
'width': row[8],
|
||||
'height': row[9],
|
||||
'face_recognition': face_recognition
|
||||
})
|
||||
|
||||
# Get total count for new review items
|
||||
if since_id:
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*)
|
||||
FROM file_inventory
|
||||
WHERE location = 'review'
|
||||
AND id > ?
|
||||
AND (moved_from_media IS NULL OR moved_from_media = 0)
|
||||
""", (since_id,))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM file_inventory
|
||||
WHERE location = 'review'
|
||||
AND (moved_from_media IS NULL OR moved_from_media = 0)
|
||||
""")
|
||||
review_count = cursor.fetchone()[0]
|
||||
|
||||
# Internet Discovery items (celebrity_discovered_videos with status='new')
|
||||
internet_discovery_items = []
|
||||
internet_discovery_count = 0
|
||||
|
||||
try:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
v.id,
|
||||
v.video_id,
|
||||
v.title,
|
||||
v.thumbnail,
|
||||
v.channel_name,
|
||||
v.platform,
|
||||
v.duration,
|
||||
v.max_resolution,
|
||||
v.status,
|
||||
v.discovered_at,
|
||||
v.url,
|
||||
v.view_count,
|
||||
v.upload_date,
|
||||
c.name as celebrity_name
|
||||
FROM celebrity_discovered_videos v
|
||||
LEFT JOIN celebrity_profiles c ON v.celebrity_id = c.id
|
||||
WHERE v.status = 'new'
|
||||
ORDER BY v.id DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
for row in cursor.fetchall():
|
||||
internet_discovery_items.append({
|
||||
'id': row[0],
|
||||
'video_id': row[1],
|
||||
'title': row[2],
|
||||
'thumbnail': row[3],
|
||||
'channel_name': row[4],
|
||||
'platform': row[5],
|
||||
'duration': row[6],
|
||||
'max_resolution': row[7],
|
||||
'status': row[8],
|
||||
'discovered_at': row[9],
|
||||
'url': row[10],
|
||||
'view_count': row[11],
|
||||
'upload_date': row[12],
|
||||
'celebrity_name': row[13]
|
||||
})
|
||||
|
||||
# Get total count for internet discovery
|
||||
cursor.execute("SELECT COUNT(*) FROM celebrity_discovered_videos WHERE status = 'new'")
|
||||
internet_discovery_count = cursor.fetchone()[0]
|
||||
except Exception as e:
|
||||
# Table might not exist if celebrity feature not used
|
||||
logger.warning(f"Could not fetch internet discovery items: {e}", module="Dashboard")
|
||||
|
||||
return {
|
||||
'media': {
|
||||
'count': media_count,
|
||||
'items': media_items
|
||||
},
|
||||
'review': {
|
||||
'count': review_count,
|
||||
'items': review_items
|
||||
},
|
||||
'internet_discovery': {
|
||||
'count': internet_discovery_count,
|
||||
'items': internet_discovery_items
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/dismissed-cards")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_dismissed_cards(
|
||||
request: Request,
|
||||
user=Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the user's dismissed card IDs."""
|
||||
app_state = get_app_state()
|
||||
user_id = user.get('username', 'default')
|
||||
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT preference_value FROM user_preferences
|
||||
WHERE user_id = ? AND preference_key = 'dashboard_dismissed_cards'
|
||||
""", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row and row[0]:
|
||||
import json
|
||||
return json.loads(row[0])
|
||||
|
||||
return {'media': None, 'review': None, 'internet_discovery': None}
|
||||
|
||||
|
||||
@router.post("/dismissed-cards")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def set_dismissed_cards(
|
||||
request: Request,
|
||||
data: Dict[str, Any],
|
||||
user=Depends(get_current_user)
|
||||
) -> Dict[str, str]:
|
||||
"""Save the user's dismissed card IDs."""
|
||||
import json
|
||||
app_state = get_app_state()
|
||||
user_id = user.get('username', 'default')
|
||||
|
||||
with app_state.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
INSERT INTO user_preferences (user_id, preference_key, preference_value, updated_at)
|
||||
VALUES (?, 'dashboard_dismissed_cards', ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(user_id, preference_key) DO UPDATE SET
|
||||
preference_value = excluded.preference_value,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
""", (user_id, json.dumps(data)))
|
||||
|
||||
return {'status': 'ok'}
|
||||
942
web/backend/routers/discovery.py
Normal file
942
web/backend/routers/discovery.py
Normal file
@@ -0,0 +1,942 @@
|
||||
"""
|
||||
Discovery Router
|
||||
|
||||
Handles discovery, organization and browsing features:
|
||||
- Tags management (CRUD, file tagging, bulk operations)
|
||||
- Smart folders (filter-based virtual folders)
|
||||
- Collections (manual file groupings)
|
||||
- Timeline and activity views
|
||||
- Discovery queue management
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
|
||||
from ..core.responses import message_response, id_response, count_response, offset_paginated
|
||||
from modules.discovery_system import get_discovery_system
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["Discovery"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class TagCreate(BaseModel):
    """Request body for creating a tag."""

    name: str  # Display name of the tag
    parent_id: Optional[int] = None  # Parent tag ID for hierarchical tags
    color: str = '#6366f1'  # Hex color shown in the UI (indigo default)
    icon: Optional[str] = None  # Optional icon identifier
    description: Optional[str] = None  # Optional free-text description
|
||||
|
||||
|
||||
class TagUpdate(BaseModel):
    """Request body for updating a tag; None fields are passed through unchanged."""

    name: Optional[str] = None
    parent_id: Optional[int] = None  # Re-parent the tag
    color: Optional[str] = None  # Hex color
    icon: Optional[str] = None
    description: Optional[str] = None
|
||||
|
||||
|
||||
class BulkTagRequest(BaseModel):
    """Request body for bulk-tagging: apply each tag to each file."""

    file_ids: List[int]  # Files to tag
    tag_ids: List[int]  # Tags to apply
|
||||
|
||||
|
||||
class SmartFolderCreate(BaseModel):
    """Request body for creating a smart folder (filter-based virtual folder)."""

    name: str  # Folder display name
    filters: dict = {}  # Filter criteria defining folder membership
    icon: str = 'folder'  # Icon identifier
    color: str = '#6366f1'  # Hex color (indigo default)
    description: Optional[str] = None
    sort_by: str = 'post_date'  # Field used to order folder contents
    sort_order: str = 'desc'  # 'asc' or 'desc'
|
||||
|
||||
|
||||
class SmartFolderUpdate(BaseModel):
    """Request body for updating a smart folder; None fields are left unchanged."""

    name: Optional[str] = None
    filters: Optional[dict] = None  # Replacement filter criteria
    icon: Optional[str] = None
    color: Optional[str] = None
    description: Optional[str] = None
    sort_by: Optional[str] = None
    sort_order: Optional[str] = None
|
||||
|
||||
|
||||
class CollectionCreate(BaseModel):
    """Request body for creating a collection (manual file grouping)."""

    name: str
    description: Optional[str] = None
    color: str = '#6366f1'  # Hex color (indigo default)
|
||||
|
||||
|
||||
class CollectionUpdate(BaseModel):
    """Request body for updating a collection; None fields are left unchanged."""

    name: Optional[str] = None
    description: Optional[str] = None
    color: Optional[str] = None
    cover_file_id: Optional[int] = None  # File used as the collection cover
|
||||
|
||||
|
||||
class BulkCollectionAdd(BaseModel):
    """Request body for adding multiple files to a collection."""

    file_ids: List[int]
|
||||
|
||||
|
||||
class DiscoveryQueueAdd(BaseModel):
    """Request body for queuing files into the discovery queue."""

    file_ids: List[int]
    priority: int = 0  # Higher values are processed first — TODO confirm ordering
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TAGS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/tags")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_tags(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
parent_id: Optional[int] = Query(None, description="Parent tag ID (null for root, -1 for all)"),
|
||||
include_counts: bool = Query(True, description="Include file counts")
|
||||
):
|
||||
"""Get all tags, optionally filtered by parent."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
tags = discovery.get_tags(parent_id=parent_id, include_counts=include_counts)
|
||||
return {"tags": tags}
|
||||
|
||||
|
||||
@router.get("/tags/{tag_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_tag(
|
||||
request: Request,
|
||||
tag_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a single tag by ID."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
tag = discovery.get_tag(tag_id)
|
||||
if not tag:
|
||||
raise NotFoundError("Tag not found")
|
||||
return tag
|
||||
|
||||
|
||||
@router.post("/tags")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def create_tag(
|
||||
request: Request,
|
||||
tag_data: TagCreate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new tag."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
tag_id = discovery.create_tag(
|
||||
name=tag_data.name,
|
||||
parent_id=tag_data.parent_id,
|
||||
color=tag_data.color,
|
||||
icon=tag_data.icon,
|
||||
description=tag_data.description
|
||||
)
|
||||
if tag_id is None:
|
||||
raise ValidationError("Failed to create tag")
|
||||
return id_response(tag_id, "Tag created successfully")
|
||||
|
||||
|
||||
@router.put("/tags/{tag_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_tag(
|
||||
request: Request,
|
||||
tag_id: int,
|
||||
tag_data: TagUpdate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update a tag."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.update_tag(
|
||||
tag_id=tag_id,
|
||||
name=tag_data.name,
|
||||
color=tag_data.color,
|
||||
icon=tag_data.icon,
|
||||
description=tag_data.description,
|
||||
parent_id=tag_data.parent_id
|
||||
)
|
||||
if not success:
|
||||
raise NotFoundError("Tag not found or update failed")
|
||||
return message_response("Tag updated successfully")
|
||||
|
||||
|
||||
@router.delete("/tags/{tag_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def delete_tag(
|
||||
request: Request,
|
||||
tag_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a tag."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.delete_tag(tag_id)
|
||||
if not success:
|
||||
raise NotFoundError("Tag not found")
|
||||
return message_response("Tag deleted successfully")
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/tags")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_file_tags(
|
||||
request: Request,
|
||||
file_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get all tags for a file."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
tags = discovery.get_file_tags(file_id)
|
||||
return {"tags": tags}
|
||||
|
||||
|
||||
@router.post("/files/{file_id}/tags/{tag_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def tag_file(
|
||||
request: Request,
|
||||
file_id: int,
|
||||
tag_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Add a tag to a file."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.tag_file(file_id, tag_id, created_by=current_user.get('sub'))
|
||||
if not success:
|
||||
raise ValidationError("Failed to tag file")
|
||||
return message_response("Tag added to file")
|
||||
|
||||
|
||||
@router.delete("/files/{file_id}/tags/{tag_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def untag_file(
|
||||
request: Request,
|
||||
file_id: int,
|
||||
tag_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Remove a tag from a file."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.untag_file(file_id, tag_id)
|
||||
if not success:
|
||||
raise NotFoundError("Tag not found on file")
|
||||
return message_response("Tag removed from file")
|
||||
|
||||
|
||||
@router.get("/tags/{tag_id}/files")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_files_by_tag(
|
||||
request: Request,
|
||||
tag_id: int,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0)
|
||||
):
|
||||
"""Get files with a specific tag."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
files, total = discovery.get_files_by_tag(tag_id, limit=limit, offset=offset)
|
||||
return offset_paginated(files, total, limit, offset, key="files")
|
||||
|
||||
|
||||
@router.post("/tags/bulk")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def bulk_tag_files(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Tag multiple files with multiple tags."""
|
||||
data = await request.json()
|
||||
file_ids = data.get('file_ids', [])
|
||||
tag_ids = data.get('tag_ids', [])
|
||||
|
||||
if not file_ids or not tag_ids:
|
||||
raise ValidationError("file_ids and tag_ids required")
|
||||
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
count = discovery.bulk_tag_files(file_ids, tag_ids, created_by=current_user.get('sub'))
|
||||
return count_response(f"Tagged {count} file-tag pairs", count)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SMART FOLDERS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/smart-folders")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_smart_folders(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
include_system: bool = Query(True, description="Include system smart folders")
|
||||
):
|
||||
"""Get all smart folders."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
folders = discovery.get_smart_folders(include_system=include_system)
|
||||
return {"smart_folders": folders}
|
||||
|
||||
|
||||
@router.get("/smart-folders/stats")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_smart_folders_stats(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get file counts and preview thumbnails for all smart folders."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
folders = discovery.get_smart_folders(include_system=True)
|
||||
|
||||
stats = {}
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
for folder in folders:
|
||||
filters = folder.get('filters', {})
|
||||
folder_id = folder['id']
|
||||
|
||||
query = '''
|
||||
SELECT COUNT(*) as count
|
||||
FROM file_inventory fi
|
||||
WHERE fi.location = 'final'
|
||||
'''
|
||||
params = []
|
||||
|
||||
if filters.get('platform'):
|
||||
query += ' AND fi.platform = ?'
|
||||
params.append(filters['platform'])
|
||||
|
||||
if filters.get('media_type'):
|
||||
query += ' AND fi.content_type = ?'
|
||||
params.append(filters['media_type'])
|
||||
|
||||
if filters.get('source'):
|
||||
query += ' AND fi.source = ?'
|
||||
params.append(filters['source'])
|
||||
|
||||
if filters.get('size_min'):
|
||||
query += ' AND fi.file_size >= ?'
|
||||
params.append(filters['size_min'])
|
||||
|
||||
cursor.execute(query, params)
|
||||
count = cursor.fetchone()[0]
|
||||
|
||||
preview_query = '''
|
||||
SELECT fi.file_path, fi.content_type
|
||||
FROM file_inventory fi
|
||||
WHERE fi.location = 'final'
|
||||
'''
|
||||
preview_params = []
|
||||
|
||||
if filters.get('platform'):
|
||||
preview_query += ' AND fi.platform = ?'
|
||||
preview_params.append(filters['platform'])
|
||||
|
||||
if filters.get('media_type'):
|
||||
preview_query += ' AND fi.content_type = ?'
|
||||
preview_params.append(filters['media_type'])
|
||||
|
||||
if filters.get('source'):
|
||||
preview_query += ' AND fi.source = ?'
|
||||
preview_params.append(filters['source'])
|
||||
|
||||
if filters.get('size_min'):
|
||||
preview_query += ' AND fi.file_size >= ?'
|
||||
preview_params.append(filters['size_min'])
|
||||
|
||||
preview_query += ' ORDER BY fi.created_date DESC LIMIT 4'
|
||||
cursor.execute(preview_query, preview_params)
|
||||
|
||||
previews = []
|
||||
for row in cursor.fetchall():
|
||||
previews.append({
|
||||
'file_path': row['file_path'],
|
||||
'content_type': row['content_type']
|
||||
})
|
||||
|
||||
stats[folder_id] = {
|
||||
'count': count,
|
||||
'previews': previews
|
||||
}
|
||||
|
||||
return {"stats": stats}
|
||||
|
||||
|
||||
@router.get("/smart-folders/{folder_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_smart_folder(
|
||||
request: Request,
|
||||
folder_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a single smart folder by ID."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
folder = discovery.get_smart_folder(folder_id=folder_id)
|
||||
if not folder:
|
||||
raise NotFoundError("Smart folder not found")
|
||||
return folder
|
||||
|
||||
|
||||
@router.post("/smart-folders")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def create_smart_folder(
|
||||
request: Request,
|
||||
folder_data: SmartFolderCreate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new smart folder."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
folder_id = discovery.create_smart_folder(
|
||||
name=folder_data.name,
|
||||
filters=folder_data.filters,
|
||||
icon=folder_data.icon,
|
||||
color=folder_data.color,
|
||||
description=folder_data.description,
|
||||
sort_by=folder_data.sort_by,
|
||||
sort_order=folder_data.sort_order
|
||||
)
|
||||
if folder_id is None:
|
||||
raise ValidationError("Failed to create smart folder")
|
||||
return {"id": folder_id, "message": "Smart folder created successfully"}
|
||||
|
||||
|
||||
@router.put("/smart-folders/{folder_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_smart_folder(
|
||||
request: Request,
|
||||
folder_id: int,
|
||||
folder_data: SmartFolderUpdate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update a smart folder (cannot update system folders)."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.update_smart_folder(
|
||||
folder_id=folder_id,
|
||||
name=folder_data.name,
|
||||
filters=folder_data.filters,
|
||||
icon=folder_data.icon,
|
||||
color=folder_data.color,
|
||||
description=folder_data.description,
|
||||
sort_by=folder_data.sort_by,
|
||||
sort_order=folder_data.sort_order
|
||||
)
|
||||
if not success:
|
||||
raise ValidationError("Failed to update smart folder (may be a system folder)")
|
||||
return {"message": "Smart folder updated successfully"}
|
||||
|
||||
|
||||
@router.delete("/smart-folders/{folder_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def delete_smart_folder(
|
||||
request: Request,
|
||||
folder_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a smart folder (cannot delete system folders)."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.delete_smart_folder(folder_id)
|
||||
if not success:
|
||||
raise ValidationError("Failed to delete smart folder (may be a system folder)")
|
||||
return {"message": "Smart folder deleted successfully"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# COLLECTIONS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/collections")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_collections(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get all collections."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
collections = discovery.get_collections()
|
||||
return {"collections": collections}
|
||||
|
||||
|
||||
@router.get("/collections/{collection_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_collection(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a single collection by ID."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
collection = discovery.get_collection(collection_id=collection_id)
|
||||
if not collection:
|
||||
raise NotFoundError("Collection not found")
|
||||
return collection
|
||||
|
||||
|
||||
@router.post("/collections")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def create_collection(
|
||||
request: Request,
|
||||
collection_data: CollectionCreate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new collection."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
collection_id = discovery.create_collection(
|
||||
name=collection_data.name,
|
||||
description=collection_data.description,
|
||||
color=collection_data.color
|
||||
)
|
||||
if collection_id is None:
|
||||
raise ValidationError("Failed to create collection")
|
||||
return {"id": collection_id, "message": "Collection created successfully"}
|
||||
|
||||
|
||||
@router.put("/collections/{collection_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_collection(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
collection_data: CollectionUpdate,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update a collection."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.update_collection(
|
||||
collection_id=collection_id,
|
||||
name=collection_data.name,
|
||||
description=collection_data.description,
|
||||
color=collection_data.color,
|
||||
cover_file_id=collection_data.cover_file_id
|
||||
)
|
||||
if not success:
|
||||
raise NotFoundError("Collection not found")
|
||||
return {"message": "Collection updated successfully"}
|
||||
|
||||
|
||||
@router.delete("/collections/{collection_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def delete_collection(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a collection."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.delete_collection(collection_id)
|
||||
if not success:
|
||||
raise NotFoundError("Collection not found")
|
||||
return {"message": "Collection deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/collections/{collection_id}/files")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_collection_files(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0)
|
||||
):
|
||||
"""Get files in a collection."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
files, total = discovery.get_collection_files(collection_id, limit=limit, offset=offset)
|
||||
return {"files": files, "total": total, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
@router.post("/collections/{collection_id}/files/{file_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def add_to_collection(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
file_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Add a file to a collection."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.add_to_collection(collection_id, file_id, added_by=current_user.get('sub'))
|
||||
if not success:
|
||||
raise ValidationError("Failed to add file to collection")
|
||||
return {"message": "File added to collection"}
|
||||
|
||||
|
||||
@router.delete("/collections/{collection_id}/files/{file_id}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def remove_from_collection(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
file_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Remove a file from a collection."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.remove_from_collection(collection_id, file_id)
|
||||
if not success:
|
||||
raise NotFoundError("File not found in collection")
|
||||
return {"message": "File removed from collection"}
|
||||
|
||||
|
||||
@router.post("/collections/{collection_id}/files/bulk")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def bulk_add_to_collection(
|
||||
request: Request,
|
||||
collection_id: int,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Add multiple files to a collection."""
|
||||
data = await request.json()
|
||||
file_ids = data.get('file_ids', [])
|
||||
|
||||
if not file_ids:
|
||||
raise ValidationError("file_ids required")
|
||||
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
count = discovery.bulk_add_to_collection(collection_id, file_ids, added_by=current_user.get('sub'))
|
||||
return {"message": f"Added {count} files to collection", "count": count}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TIMELINE ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/timeline")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_timeline(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
granularity: str = Query('day', pattern='^(day|week|month|year)$'),
|
||||
date_from: Optional[str] = Query(None, pattern=r'^\d{4}-\d{2}-\d{2}$'),
|
||||
date_to: Optional[str] = Query(None, pattern=r'^\d{4}-\d{2}-\d{2}$'),
|
||||
platform: Optional[str] = Query(None),
|
||||
source: Optional[str] = Query(None)
|
||||
):
|
||||
"""Get timeline aggregation data."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
data = discovery.get_timeline_data(
|
||||
granularity=granularity,
|
||||
date_from=date_from,
|
||||
date_to=date_to,
|
||||
platform=platform,
|
||||
source=source
|
||||
)
|
||||
return {"timeline": data, "granularity": granularity}
|
||||
|
||||
|
||||
@router.get("/timeline/heatmap")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_timeline_heatmap(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
year: Optional[int] = Query(None, ge=2000, le=2100),
|
||||
platform: Optional[str] = Query(None)
|
||||
):
|
||||
"""Get activity heatmap data (file counts per day for a year)."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
heatmap = discovery.get_activity_heatmap(year=year, platform=platform)
|
||||
return {"heatmap": heatmap, "year": year or datetime.now().year}
|
||||
|
||||
|
||||
@router.get("/timeline/on-this-day")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_on_this_day(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
month: Optional[int] = Query(None, ge=1, le=12),
|
||||
day: Optional[int] = Query(None, ge=1, le=31),
|
||||
limit: int = Query(50, ge=1, le=200)
|
||||
):
|
||||
"""Get content from the same day in previous years ('On This Day' feature)."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
files = discovery.get_on_this_day(month=month, day=day, limit=limit)
|
||||
return {"files": files, "count": len(files)}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RECENT ACTIVITY ENDPOINT
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/discovery/recent-activity")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_recent_activity(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = Query(10, ge=1, le=50)
|
||||
):
|
||||
"""Get recent activity across downloads, deletions, and restores."""
|
||||
app_state = get_app_state()
|
||||
|
||||
activity = {
|
||||
'recent_downloads': [],
|
||||
'recent_deleted': [],
|
||||
'recent_restored': [],
|
||||
'recent_moved_to_review': [],
|
||||
'summary': {
|
||||
'downloads_24h': 0,
|
||||
'downloads_7d': 0,
|
||||
'deleted_24h': 0,
|
||||
'deleted_7d': 0
|
||||
}
|
||||
}
|
||||
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Recent downloads
|
||||
cursor.execute('''
|
||||
SELECT
|
||||
fi.id, fi.file_path, fi.filename, fi.platform, fi.source,
|
||||
fi.content_type, fi.file_size, fi.created_date,
|
||||
d.download_date, d.post_date
|
||||
FROM file_inventory fi
|
||||
LEFT JOIN downloads d ON d.file_path = fi.file_path
|
||||
WHERE fi.location = 'final'
|
||||
ORDER BY fi.created_date DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
|
||||
for row in cursor.fetchall():
|
||||
activity['recent_downloads'].append({
|
||||
'id': row['id'],
|
||||
'file_path': row['file_path'],
|
||||
'filename': row['filename'],
|
||||
'platform': row['platform'],
|
||||
'source': row['source'],
|
||||
'content_type': row['content_type'],
|
||||
'file_size': row['file_size'],
|
||||
'timestamp': row['download_date'] or row['created_date'],
|
||||
'action': 'download'
|
||||
})
|
||||
|
||||
# Recent deleted
|
||||
cursor.execute('''
|
||||
SELECT
|
||||
id, original_path, original_filename, recycle_path,
|
||||
file_size, deleted_at, deleted_from, metadata
|
||||
FROM recycle_bin
|
||||
ORDER BY deleted_at DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
|
||||
for row in cursor.fetchall():
|
||||
metadata = {}
|
||||
if row['metadata']:
|
||||
try:
|
||||
metadata = json.loads(row['metadata'])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
activity['recent_deleted'].append({
|
||||
'id': row['id'],
|
||||
'file_path': row['recycle_path'],
|
||||
'original_path': row['original_path'],
|
||||
'filename': row['original_filename'],
|
||||
'platform': metadata.get('platform', 'unknown'),
|
||||
'source': metadata.get('source', ''),
|
||||
'content_type': metadata.get('content_type', 'image'),
|
||||
'file_size': row['file_size'] or 0,
|
||||
'timestamp': row['deleted_at'],
|
||||
'deleted_from': row['deleted_from'],
|
||||
'action': 'delete'
|
||||
})
|
||||
|
||||
# Recent moved to review
|
||||
cursor.execute('''
|
||||
SELECT
|
||||
id, file_path, filename, platform, source,
|
||||
content_type, file_size, created_date
|
||||
FROM file_inventory
|
||||
WHERE location = 'review'
|
||||
ORDER BY created_date DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
|
||||
for row in cursor.fetchall():
|
||||
activity['recent_moved_to_review'].append({
|
||||
'id': row['id'],
|
||||
'file_path': row['file_path'],
|
||||
'filename': row['filename'],
|
||||
'platform': row['platform'],
|
||||
'source': row['source'],
|
||||
'content_type': row['content_type'],
|
||||
'file_size': row['file_size'],
|
||||
'timestamp': row['created_date'],
|
||||
'action': 'review'
|
||||
})
|
||||
|
||||
# Summary stats
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM file_inventory
|
||||
WHERE location = 'final'
|
||||
AND created_date >= datetime('now', '-1 day')
|
||||
''')
|
||||
activity['summary']['downloads_24h'] = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM file_inventory
|
||||
WHERE location = 'final'
|
||||
AND created_date >= datetime('now', '-7 days')
|
||||
''')
|
||||
activity['summary']['downloads_7d'] = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM recycle_bin
|
||||
WHERE deleted_at >= datetime('now', '-1 day')
|
||||
''')
|
||||
activity['summary']['deleted_24h'] = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM recycle_bin
|
||||
WHERE deleted_at >= datetime('now', '-7 days')
|
||||
''')
|
||||
activity['summary']['deleted_7d'] = cursor.fetchone()[0]
|
||||
|
||||
return activity
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DISCOVERY QUEUE ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/discovery/queue/stats")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_queue_stats(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get discovery queue statistics."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
stats = discovery.get_queue_stats()
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/discovery/queue/pending")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_pending_queue(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = Query(100, ge=1, le=1000)
|
||||
):
|
||||
"""Get pending items in the discovery queue."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
items = discovery.get_pending_queue(limit=limit)
|
||||
return {"items": items, "count": len(items)}
|
||||
|
||||
|
||||
@router.post("/discovery/queue/add")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def add_to_queue(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
file_id: int = Body(...),
|
||||
priority: int = Body(0)
|
||||
):
|
||||
"""Add a file to the discovery queue."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
success = discovery.add_to_queue(file_id, priority=priority)
|
||||
if not success:
|
||||
raise ValidationError("Failed to add file to queue")
|
||||
return {"message": "File added to queue"}
|
||||
|
||||
|
||||
@router.post("/discovery/queue/bulk-add")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def bulk_add_to_queue(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Add multiple files to the discovery queue."""
|
||||
data = await request.json()
|
||||
file_ids = data.get('file_ids', [])
|
||||
priority = data.get('priority', 0)
|
||||
|
||||
if not file_ids:
|
||||
raise ValidationError("file_ids required")
|
||||
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
count = discovery.bulk_add_to_queue(file_ids, priority=priority)
|
||||
return {"message": f"Added {count} files to queue", "count": count}
|
||||
|
||||
|
||||
@router.delete("/discovery/queue/clear")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def clear_queue(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Clear the discovery queue."""
|
||||
app_state = get_app_state()
|
||||
discovery = get_discovery_system(app_state.db)
|
||||
count = discovery.clear_queue()
|
||||
return {"message": f"Cleared {count} items from queue", "count": count}
|
||||
1370
web/backend/routers/downloads.py
Normal file
1370
web/backend/routers/downloads.py
Normal file
File diff suppressed because it is too large
Load Diff
478
web/backend/routers/easynews.py
Normal file
478
web/backend/routers/easynews.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Easynews Router
|
||||
|
||||
Handles Easynews integration:
|
||||
- Configuration management (credentials, proxy settings)
|
||||
- Search term management
|
||||
- Manual check triggers
|
||||
- Results browsing and downloads
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.exceptions import handle_exceptions
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api/easynews", tags=["Easynews"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Thread pool for blocking operations
|
||||
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class EasynewsConfigUpdate(BaseModel):
    """Partial update payload for the Easynews configuration.

    Every field is optional; PUT /config forwards only the fields that
    were provided. The masked placeholder returned by GET /config
    ('********') is treated as "unchanged" for the two password fields
    by the handler.
    """
    # Account credentials
    username: Optional[str] = None
    password: Optional[str] = None
    # Monitoring behaviour
    enabled: Optional[bool] = None
    check_interval_hours: Optional[int] = None
    auto_download: Optional[bool] = None
    min_quality: Optional[str] = None
    # Optional proxy used for Easynews traffic
    proxy_enabled: Optional[bool] = None
    proxy_type: Optional[str] = None
    proxy_host: Optional[str] = None
    proxy_port: Optional[int] = None
    proxy_username: Optional[str] = None
    proxy_password: Optional[str] = None
    # Whether to emit notifications for new results
    notifications_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def _get_monitor():
    """Build a fresh EasynewsMonitor bound to the application's database.

    The import is deferred to call time so this router module can load
    even before the monitor module is needed.
    """
    from modules.easynews_monitor import EasynewsMonitor
    state = get_app_state()
    # EasynewsMonitor expects a plain string path, not a pathlib.Path.
    return EasynewsMonitor(str(state.db.db_path))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONFIGURATION ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/config")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_config(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get Easynews configuration (passwords masked)."""
|
||||
monitor = _get_monitor()
|
||||
config = monitor.get_config()
|
||||
|
||||
# Mask password for security
|
||||
if config.get('password'):
|
||||
config['password'] = '********'
|
||||
if config.get('proxy_password'):
|
||||
config['proxy_password'] = '********'
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"config": config
|
||||
}
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def update_config(
|
||||
request: Request,
|
||||
config: EasynewsConfigUpdate,
|
||||
current_user: Dict = Depends(require_admin)
|
||||
):
|
||||
"""Update Easynews configuration."""
|
||||
monitor = _get_monitor()
|
||||
|
||||
# Build update kwargs
|
||||
kwargs = {}
|
||||
if config.username is not None:
|
||||
kwargs['username'] = config.username
|
||||
if config.password is not None and config.password != '********':
|
||||
kwargs['password'] = config.password
|
||||
if config.enabled is not None:
|
||||
kwargs['enabled'] = config.enabled
|
||||
if config.check_interval_hours is not None:
|
||||
kwargs['check_interval_hours'] = config.check_interval_hours
|
||||
if config.auto_download is not None:
|
||||
kwargs['auto_download'] = config.auto_download
|
||||
if config.min_quality is not None:
|
||||
kwargs['min_quality'] = config.min_quality
|
||||
if config.proxy_enabled is not None:
|
||||
kwargs['proxy_enabled'] = config.proxy_enabled
|
||||
if config.proxy_type is not None:
|
||||
kwargs['proxy_type'] = config.proxy_type
|
||||
if config.proxy_host is not None:
|
||||
kwargs['proxy_host'] = config.proxy_host
|
||||
if config.proxy_port is not None:
|
||||
kwargs['proxy_port'] = config.proxy_port
|
||||
if config.proxy_username is not None:
|
||||
kwargs['proxy_username'] = config.proxy_username
|
||||
if config.proxy_password is not None and config.proxy_password != '********':
|
||||
kwargs['proxy_password'] = config.proxy_password
|
||||
if config.notifications_enabled is not None:
|
||||
kwargs['notifications_enabled'] = config.notifications_enabled
|
||||
|
||||
if not kwargs:
|
||||
return {"success": False, "message": "No updates provided"}
|
||||
|
||||
success = monitor.update_config(**kwargs)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"message": "Configuration updated" if success else "Update failed"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
@limiter.limit("5/minute")
|
||||
@handle_exceptions
|
||||
async def test_connection(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Test Easynews connection with current credentials."""
|
||||
monitor = _get_monitor()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(_executor, monitor.test_connection)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CELEBRITY ENDPOINTS (uses tracked celebrities from Appearances)
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/celebrities")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_celebrities(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get all tracked celebrities that will be searched on Easynews."""
|
||||
monitor = _get_monitor()
|
||||
celebrities = monitor.get_celebrities()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"celebrities": celebrities,
|
||||
"count": len(celebrities)
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RESULTS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/results")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_results(
|
||||
request: Request,
|
||||
status: Optional[str] = None,
|
||||
celebrity_id: Optional[int] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get discovered results with optional filters."""
|
||||
monitor = _get_monitor()
|
||||
results = monitor.get_results(
|
||||
status=status,
|
||||
celebrity_id=celebrity_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Get total count
|
||||
total = monitor.get_result_count(status=status)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": results,
|
||||
"count": len(results),
|
||||
"total": total
|
||||
}
|
||||
|
||||
|
||||
@router.post("/results/{result_id}/status")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_result_status(
|
||||
request: Request,
|
||||
result_id: int,
|
||||
status: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update a result's status (e.g., mark as ignored)."""
|
||||
valid_statuses = ['new', 'downloaded', 'ignored', 'failed']
|
||||
if status not in valid_statuses:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {valid_statuses}")
|
||||
|
||||
monitor = _get_monitor()
|
||||
success = monitor.update_result_status(result_id, status)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"message": f"Status updated to {status}" if success else "Update failed"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/results/{result_id}/download")
@limiter.limit("10/minute")
@handle_exceptions
async def download_result(
    request: Request,
    result_id: int,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """Download a discovered result and return the downloader's outcome.

    The blocking download runs on the shared thread pool so the event loop
    stays responsive; the request waits for completion and returns whatever
    the monitor's download routine produces.
    """
    monitor = _get_monitor()

    def do_download():
        return monitor.download_result(result_id)

    # Fix: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(_executor, do_download)

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CHECK ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/status")
@limiter.limit("60/minute")
@handle_exceptions
async def get_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Report the monitor's current check state plus key config flags."""
    mon = _get_monitor()
    cfg = mon.get_config()

    return {
        "success": True,
        "status": mon.get_status(),
        "last_check": cfg.get('last_check'),
        "enabled": cfg.get('enabled', False),
        "has_credentials": cfg.get('has_credentials', False),
        "celebrity_count": mon.get_celebrity_count(),
    }
|
||||
|
||||
|
||||
@router.post("/check")
@limiter.limit("5/minute")
@handle_exceptions
async def trigger_check(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """Trigger a manual check for all tracked celebrities.

    Refuses to start when the monitor reports a check already running;
    otherwise schedules the blocking check on the shared thread pool after
    the response is sent.
    """
    monitor = _get_monitor()
    status = monitor.get_status()

    if status.get('is_running'):
        return {
            "success": False,
            "message": "Check already in progress"
        }

    def do_check():
        return monitor.check_all_celebrities()

    # Fix: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    background_tasks.add_task(loop.run_in_executor, _executor, do_check)

    return {
        "success": True,
        "message": "Check started"
    }
|
||||
|
||||
|
||||
@router.post("/check/{search_id}")
@limiter.limit("5/minute")
@handle_exceptions
async def trigger_single_check(
    request: Request,
    search_id: int,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """Trigger a manual check for a specific search term.

    Raises:
        HTTPException: 404 when the search id does not exist.
    """
    monitor = _get_monitor()
    status = monitor.get_status()

    if status.get('is_running'):
        return {
            "success": False,
            "message": "Check already in progress"
        }

    # Verify search exists before scheduling any work
    search = monitor.get_search(search_id)
    if not search:
        raise HTTPException(status_code=404, detail="Search not found")

    def do_check():
        return monitor.check_single_search(search_id)

    # Fix: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()
    background_tasks.add_task(loop.run_in_executor, _executor, do_check)

    return {
        "success": True,
        "message": f"Check started for: {search['search_term']}"
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SEARCH MANAGEMENT ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
class EasynewsSearchCreate(BaseModel):
    """Request body for creating a saved Easynews search term."""
    search_term: str                    # phrase to search for
    media_type: Optional[str] = 'any'   # content-type filter; 'any' means no filter
    tmdb_id: Optional[int] = None       # optional TMDB match for this term
    tmdb_title: Optional[str] = None    # canonical title from TMDB, if matched
    poster_url: Optional[str] = None    # poster image URL for UI display
|
||||
|
||||
|
||||
class EasynewsSearchUpdate(BaseModel):
    """Partial-update body for a saved search term.

    All fields are optional; only non-None values are applied
    (None cannot be used to clear a stored value).
    """
    search_term: Optional[str] = None
    media_type: Optional[str] = None
    enabled: Optional[bool] = None      # toggle the search on/off
    tmdb_id: Optional[int] = None
    tmdb_title: Optional[str] = None
    poster_url: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/searches")
@limiter.limit("30/minute")
@handle_exceptions
async def get_searches(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """List every saved search term known to the monitor."""
    found = _get_monitor().get_all_searches()

    return {
        "success": True,
        "searches": found,
        "count": len(found),
    }
|
||||
|
||||
|
||||
@router.post("/searches")
@limiter.limit("10/minute")
@handle_exceptions
async def add_search(
    request: Request,
    search: EasynewsSearchCreate,
    current_user: Dict = Depends(get_current_user)
):
    """Create a new saved search term and report its database id."""
    new_id = _get_monitor().add_search(
        search_term=search.search_term,
        media_type=search.media_type,
        tmdb_id=search.tmdb_id,
        tmdb_title=search.tmdb_title,
        poster_url=search.poster_url
    )

    # A falsy id means the monitor rejected or failed to persist the term.
    if not new_id:
        return {
            "success": False,
            "message": "Failed to add search term"
        }
    return {
        "success": True,
        "id": new_id,
        "message": f"Search term '{search.search_term}' added"
    }
|
||||
|
||||
|
||||
@router.put("/searches/{search_id}")
@limiter.limit("10/minute")
@handle_exceptions
async def update_search(
    request: Request,
    search_id: int,
    updates: EasynewsSearchUpdate,
    current_user: Dict = Depends(get_current_user)
):
    """Apply a partial update to an existing search term.

    Only fields explicitly provided (non-None) are forwarded to the
    monitor; an empty update is rejected without touching storage.
    """
    # Collect only the fields the client actually supplied.
    kwargs = {
        field: value
        for field, value in (
            ('search_term', updates.search_term),
            ('media_type', updates.media_type),
            ('enabled', updates.enabled),
            ('tmdb_id', updates.tmdb_id),
            ('tmdb_title', updates.tmdb_title),
            ('poster_url', updates.poster_url),
        )
        if value is not None
    }

    if not kwargs:
        return {"success": False, "message": "No updates provided"}

    success = _get_monitor().update_search(search_id, **kwargs)

    return {
        "success": success,
        "message": "Search updated" if success else "Update failed"
    }
|
||||
|
||||
|
||||
@router.delete("/searches/{search_id}")
@limiter.limit("10/minute")
@handle_exceptions
async def delete_search(
    request: Request,
    search_id: int,
    current_user: Dict = Depends(get_current_user)
):
    """Remove a saved search term by id."""
    removed = _get_monitor().delete_search(search_id)

    return {
        "success": removed,
        "message": "Search deleted" if removed else "Delete failed",
    }
|
||||
1248
web/backend/routers/face.py
Normal file
1248
web/backend/routers/face.py
Normal file
File diff suppressed because it is too large
Load Diff
128
web/backend/routers/files.py
Normal file
128
web/backend/routers/files.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
File serving and thumbnail generation API
|
||||
|
||||
Provides endpoints for:
|
||||
- On-demand thumbnail generation for images and videos
|
||||
- File serving with proper caching headers
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import Response
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import io
|
||||
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user
|
||||
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
|
||||
from ..core.utils import validate_file_path
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/files", tags=["files"])
|
||||
logger = get_logger('FilesRouter')
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get("/thumbnail")
@limiter.limit("300/minute")
@handle_exceptions
async def get_thumbnail(
    request: Request,
    path: str = Query(..., description="File path"),
    current_user: Dict = Depends(get_current_user)
):
    """
    Generate and return thumbnail for image or video

    Args:
        path: Absolute path to file

    Returns:
        JPEG thumbnail (200x200px max, maintains aspect ratio)

    Raises:
        ValidationError: for unsupported file extensions.
    """
    # Validate file is within allowed directories (prevents path traversal)
    file_path = validate_file_path(path, require_exists=True)

    file_ext = file_path.suffix.lower()

    # Generate thumbnail based on type
    if file_ext in ['.jpg', '.jpeg', '.png', '.webp', '.gif', '.heic']:
        # Fix: open the image in a context manager so PIL's lazily-held
        # file handle is always released, even if processing raises.
        # NOTE(review): '.heic' decoding presumably requires a HEIF plugin
        # registered elsewhere — PIL alone cannot open HEIC; verify.
        with Image.open(file_path) as img:
            if file_ext == '.heic':
                img = img.convert('RGB')

            # Create thumbnail in place (maintains aspect ratio)
            img.thumbnail((200, 200), Image.Resampling.LANCZOS)

            # Re-encode as JPEG
            buffer = io.BytesIO()
            if img.mode in ('RGBA', 'LA', 'P'):
                # Flatten transparency onto a white background (JPEG has no alpha)
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background

            img.convert('RGB').save(buffer, format='JPEG', quality=85, optimize=True)
        buffer.seek(0)

        return Response(
            content=buffer.read(),
            media_type="image/jpeg",
            headers={"Cache-Control": "public, max-age=3600"}
        )

    elif file_ext in ['.mp4', '.webm', '.mov', '.avi', '.mkv']:
        # Video thumbnail: extract a single frame via ffmpeg piped to stdout
        result = subprocess.run(
            [
                'ffmpeg',
                '-ss', '00:00:01',       # Seek to 1 second
                '-i', str(file_path),
                '-vframes', '1',         # Extract 1 frame
                '-vf', 'scale=200:-1',   # Scale to 200px width, maintain aspect
                '-f', 'image2pipe',      # Output to pipe
                '-vcodec', 'mjpeg',      # JPEG codec
                'pipe:1'
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=10,
            check=False
        )

        if result.returncode != 0:
            # Retry without seeking (for videos shorter than 1 second);
            # check=True here so a second failure surfaces as an exception.
            result = subprocess.run(
                [
                    'ffmpeg',
                    '-i', str(file_path),
                    '-vframes', '1',
                    '-vf', 'scale=200:-1',
                    '-f', 'image2pipe',
                    '-vcodec', 'mjpeg',
                    'pipe:1'
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=10,
                check=True
            )

        return Response(
            content=result.stdout,
            media_type="image/jpeg",
            headers={"Cache-Control": "public, max-age=3600"}
        )

    else:
        raise ValidationError("Unsupported file type")
|
||||
1054
web/backend/routers/health.py
Normal file
1054
web/backend/routers/health.py
Normal file
File diff suppressed because it is too large
Load Diff
436
web/backend/routers/instagram_unified.py
Normal file
436
web/backend/routers/instagram_unified.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
Instagram Unified Configuration Router
|
||||
|
||||
Provides a single configuration interface for all Instagram scrapers.
|
||||
Manages one central account list with per-account content type toggles,
|
||||
scraper assignments, and auto-generates legacy per-scraper configs on save.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, ValidationError
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api/instagram-unified", tags=["Instagram Unified"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Scraper capability matrix
|
||||
# Scraper capability matrix: which content types each scraper backend can fetch.
# Used both for validating assignments and for filtering during migration.
SCRAPER_CAPABILITIES = {
    'fastdl': {'posts': True, 'stories': True, 'reels': True, 'tagged': False},
    'imginn_api': {'posts': True, 'stories': True, 'reels': False, 'tagged': True},
    'imginn': {'posts': True, 'stories': True, 'reels': False, 'tagged': True},
    'toolzu': {'posts': True, 'stories': True, 'reels': False, 'tagged': False},
    'instagram_client': {'posts': True, 'stories': True, 'reels': True, 'tagged': True},
    'instagram': {'posts': True, 'stories': True, 'reels': False, 'tagged': True},
}

# Human-readable names keyed by scraper id, for UI and error messages.
SCRAPER_LABELS = {
    'fastdl': 'FastDL',
    'imginn_api': 'ImgInn API',
    'imginn': 'ImgInn',
    'toolzu': 'Toolzu',
    'instagram_client': 'Instagram Client',
    'instagram': 'InstaLoader',
}

# Content type categories managed by the unified config.
CONTENT_TYPES = ['posts', 'stories', 'reels', 'tagged']

# Settings keys of the per-scraper configs this router reads/overwrites.
LEGACY_SCRAPER_KEYS = ['fastdl', 'imginn_api', 'imginn', 'toolzu', 'instagram_client', 'instagram']

# Default destination paths used when a legacy config supplies none.
DEFAULT_PATHS = {
    'posts': '/opt/immich/md/social media/instagram/posts',
    'stories': '/opt/immich/md/social media/instagram/stories',
    'reels': '/opt/immich/md/social media/instagram/reels',
    'tagged': '/opt/immich/md/social media/instagram/tagged',
}
|
||||
|
||||
|
||||
class UnifiedConfigUpdate(BaseModel):
    """Request body for saving the unified Instagram configuration."""
    config: Dict[str, Any]  # complete unified config document
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MIGRATION LOGIC
|
||||
# ============================================================================
|
||||
|
||||
def _migrate_from_legacy(app_state) -> Dict[str, Any]:
    """
    Build a unified config from existing per-scraper configs.
    Called on first load when no instagram_unified key exists.

    Args:
        app_state: application state object exposing ``settings`` with a
            ``get(key)`` accessor for stored per-scraper configs.

    Returns:
        A unified config dict with scraper assignments, per-account content
        flags, content-type settings, phrase search, and scraper auth settings.
    """
    settings = app_state.settings

    # Load all legacy configs
    legacy = {}
    for key in LEGACY_SCRAPER_KEYS:
        legacy[key] = settings.get(key) or {}

    # Determine scraper assignments from currently enabled configs
    # Only assign a scraper if it actually supports the content type per capability matrix
    scraper_assignment = {}
    for ct in CONTENT_TYPES:
        assigned = None
        # First enabled+capable scraper in LEGACY_SCRAPER_KEYS order wins.
        for scraper_key in LEGACY_SCRAPER_KEYS:
            cfg = legacy[scraper_key]
            if (cfg.get('enabled') and cfg.get(ct, {}).get('enabled')
                    and SCRAPER_CAPABILITIES.get(scraper_key, {}).get(ct)):
                assigned = scraper_key
                break
        # Fallback defaults if nothing is enabled — pick first capable scraper
        if not assigned:
            if ct in ('posts', 'tagged'):
                assigned = 'imginn_api'
            elif ct in ('stories', 'reels'):
                assigned = 'fastdl'
        scraper_assignment[ct] = assigned

    # Collect all unique usernames from all scrapers
    all_usernames = set()
    scraper_usernames = {}  # track which usernames belong to which scraper

    for scraper_key in LEGACY_SCRAPER_KEYS:
        cfg = legacy[scraper_key]
        usernames = []

        if scraper_key == 'instagram':
            # InstaLoader uses accounts list format
            for acc in cfg.get('accounts', []):
                u = acc.get('username')
                if u:
                    usernames.append(u)
        else:
            usernames = cfg.get('usernames', [])

        # Also include phrase_search usernames
        ps_usernames = cfg.get('phrase_search', {}).get('usernames', [])
        combined = set(usernames) | set(ps_usernames)
        scraper_usernames[scraper_key] = combined
        all_usernames |= combined

    # Build per-account content type flags
    # An account gets a content type enabled if:
    # - The scraper assigned to that content type has this username in its list
    accounts = []
    for username in sorted(all_usernames):
        account = {'username': username}
        for ct in CONTENT_TYPES:
            assigned_scraper = scraper_assignment[ct]
            # Enable if this user is in the assigned scraper's list
            account[ct] = username in scraper_usernames.get(assigned_scraper, set())
        accounts.append(account)

    # Import content type settings from the first enabled scraper that has them
    content_types = {}
    for ct in CONTENT_TYPES:
        ct_config = {'enabled': False, 'days_back': 7, 'destination_path': DEFAULT_PATHS[ct]}
        assigned = scraper_assignment[ct]
        cfg = legacy.get(assigned, {})
        ct_sub = cfg.get(ct, {})
        if ct_sub.get('enabled'):
            ct_config['enabled'] = True
        if ct_sub.get('days_back'):
            ct_config['days_back'] = ct_sub['days_back']
        if ct_sub.get('destination_path'):
            ct_config['destination_path'] = ct_sub['destination_path']
        content_types[ct] = ct_config

    # Import phrase search from first scraper that has it enabled
    phrase_search = {
        'enabled': False,
        'download_all': True,
        'phrases': [],
        'case_sensitive': False,
        'match_all': False,
    }
    for scraper_key in LEGACY_SCRAPER_KEYS:
        ps = legacy[scraper_key].get('phrase_search', {})
        if ps.get('enabled') or ps.get('phrases'):
            phrase_search['enabled'] = ps.get('enabled', False)
            phrase_search['download_all'] = ps.get('download_all', True)
            phrase_search['phrases'] = ps.get('phrases', [])
            phrase_search['case_sensitive'] = ps.get('case_sensitive', False)
            phrase_search['match_all'] = ps.get('match_all', False)
            break

    # Import scraper-specific settings (auth, cookies, etc.)
    scraper_settings = {
        'fastdl': {},
        'imginn_api': {},
        'imginn': {'cookie_file': legacy['imginn'].get('cookie_file', '')},
        'toolzu': {
            'email': legacy['toolzu'].get('email', ''),
            'password': legacy['toolzu'].get('password', ''),
            'cookie_file': legacy['toolzu'].get('cookie_file', ''),
        },
        'instagram_client': {},
        'instagram': {
            'username': legacy['instagram'].get('username', ''),
            'password': legacy['instagram'].get('password', ''),
            'totp_secret': legacy['instagram'].get('totp_secret', ''),
            'session_file': legacy['instagram'].get('session_file', ''),
        },
    }

    # Get global settings from the primary enabled scraper
    check_interval = 8
    run_at_start = False
    user_delay = 20
    for scraper_key in ['fastdl', 'imginn_api']:
        cfg = legacy[scraper_key]
        if cfg.get('enabled'):
            check_interval = cfg.get('check_interval_hours', 8)
            run_at_start = cfg.get('run_at_start', False)
            user_delay = cfg.get('user_delay_seconds', 20)
            break

    return {
        'enabled': True,
        'check_interval_hours': check_interval,
        'run_at_start': run_at_start,
        'user_delay_seconds': user_delay,
        'scraper_assignment': scraper_assignment,
        'content_types': content_types,
        'accounts': accounts,
        'phrase_search': phrase_search,
        'scraper_settings': scraper_settings,
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LEGACY CONFIG GENERATION
|
||||
# ============================================================================
|
||||
|
||||
def _generate_legacy_configs(unified: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """
    From the unified config, generate 6 legacy per-scraper configs
    that the existing scraper modules can consume.

    Args:
        unified: the unified config (shape produced by ``_migrate_from_legacy``).

    Returns:
        Mapping of scraper key -> legacy config dict, one per
        ``LEGACY_SCRAPER_KEYS`` entry.
    """
    scraper_assignment = unified.get('scraper_assignment', {})
    content_types = unified.get('content_types', {})
    accounts = unified.get('accounts', [])
    phrase_search = unified.get('phrase_search', {})
    scraper_settings = unified.get('scraper_settings', {})

    # For each scraper, determine which content types it's assigned to
    scraper_content_types = {key: [] for key in LEGACY_SCRAPER_KEYS}
    for ct, scraper_key in scraper_assignment.items():
        if scraper_key in scraper_content_types:
            scraper_content_types[scraper_key].append(ct)

    # For each scraper, collect usernames that have any of its assigned content types enabled
    scraper_usernames = {key: set() for key in LEGACY_SCRAPER_KEYS}
    for account in accounts:
        username = account.get('username', '')
        if not username:
            continue
        for ct, scraper_key in scraper_assignment.items():
            if account.get(ct, False):
                scraper_usernames[scraper_key].add(username)

    # Build legacy configs
    result = {}
    for scraper_key in LEGACY_SCRAPER_KEYS:
        assigned_cts = scraper_content_types[scraper_key]
        usernames = sorted(scraper_usernames[scraper_key])
        # A scraper is only enabled when the unified config is on AND it has
        # both an assignment and at least one account to work on.
        is_enabled = unified.get('enabled', False) and len(assigned_cts) > 0 and len(usernames) > 0
        extra_settings = scraper_settings.get(scraper_key, {})

        cfg = {
            'enabled': is_enabled,
            'check_interval_hours': unified.get('check_interval_hours', 8),
            'run_at_start': unified.get('run_at_start', False),
        }

        # Add user_delay_seconds for scrapers that support it
        if scraper_key in ('imginn_api', 'instagram_client'):
            cfg['user_delay_seconds'] = unified.get('user_delay_seconds', 20)

        # Add scraper-specific settings
        if scraper_key == 'fastdl':
            cfg['high_res'] = True  # Always use high resolution
        elif scraper_key == 'imginn':
            cfg['cookie_file'] = extra_settings.get('cookie_file', '')
        elif scraper_key == 'toolzu':
            cfg['email'] = extra_settings.get('email', '')
            cfg['password'] = extra_settings.get('password', '')
            cfg['cookie_file'] = extra_settings.get('cookie_file', '')

        # InstaLoader uses accounts format + auth fields at top level
        if scraper_key == 'instagram':
            ig_settings = extra_settings
            cfg['method'] = 'instaloader'
            cfg['username'] = ig_settings.get('username', '')
            cfg['password'] = ig_settings.get('password', '')
            cfg['totp_secret'] = ig_settings.get('totp_secret', '')
            cfg['session_file'] = ig_settings.get('session_file', '')
            cfg['accounts'] = [
                {'username': u, 'check_interval_hours': unified.get('check_interval_hours', 8), 'run_at_start': False}
                for u in usernames
            ]
        else:
            cfg['usernames'] = usernames

        # Content type sub-configs
        for ct in CONTENT_TYPES:
            ct_global = content_types.get(ct, {})
            if ct in assigned_cts:
                cfg[ct] = {
                    'enabled': ct_global.get('enabled', False),
                    'days_back': ct_global.get('days_back', 7),
                    'destination_path': ct_global.get('destination_path', DEFAULT_PATHS.get(ct, '')),
                }
                # Add temp_dir based on scraper key
                cfg[ct]['temp_dir'] = f'temp/{scraper_key}/{ct}'
            else:
                cfg[ct] = {'enabled': False}

        # Phrase search goes on the scraper assigned to posts
        posts_scraper = scraper_assignment.get('posts')
        if scraper_key == posts_scraper and phrase_search.get('enabled'):
            # Collect all usernames that have posts enabled for the phrase search usernames list
            ps_usernames = sorted(scraper_usernames.get(scraper_key, set()))
            cfg['phrase_search'] = {
                'enabled': phrase_search.get('enabled', False),
                'download_all': phrase_search.get('download_all', True),
                'usernames': ps_usernames,
                'phrases': phrase_search.get('phrases', []),
                'case_sensitive': phrase_search.get('case_sensitive', False),
                'match_all': phrase_search.get('match_all', False),
            }
        else:
            # Non-posts scrapers get an explicitly disabled phrase search.
            cfg['phrase_search'] = {
                'enabled': False,
                'usernames': [],
                'phrases': [],
                'case_sensitive': False,
                'match_all': False,
            }

        result[scraper_key] = cfg

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/config")
@handle_exceptions
async def get_config(request: Request, user=Depends(get_current_user)):
    """
    Load unified config. Auto-migrates from legacy configs on first load.
    Returns config (or generated migration preview), hidden_modules, and scraper capabilities.
    """
    app_state = get_app_state()
    stored = app_state.settings.get('instagram_unified')
    hidden_modules = app_state.settings.get('hidden_modules') or []

    # Payload fields shared by both the stored and the migrated response.
    common = {
        'hidden_modules': hidden_modules,
        'scraper_capabilities': SCRAPER_CAPABILITIES,
        'scraper_labels': SCRAPER_LABELS,
    }

    if stored:
        return {'config': stored, 'migrated': False, **common}

    # No unified config yet — generate migration preview (not auto-saved)
    return {'config': _migrate_from_legacy(app_state), 'migrated': True, **common}
|
||||
|
||||
|
||||
@router.put("/config")
@handle_exceptions
async def update_config(request: Request, body: UnifiedConfigUpdate, user=Depends(get_current_user)):
    """
    Save unified config + generate 6 legacy configs.

    Validates every scraper assignment against the capability matrix,
    persists the unified config, regenerates and persists all legacy
    per-scraper configs, refreshes the in-memory config, and broadcasts
    a best-effort WebSocket notification.

    Raises:
        ValidationError: on an unknown content type, unknown scraper,
            or a scraper assigned to a type it does not support.
    """
    app_state = get_app_state()
    config = body.config

    # Validate scraper assignments against capability matrix
    scraper_assignment = config.get('scraper_assignment', {})
    for ct, scraper_key in scraper_assignment.items():
        if ct not in CONTENT_TYPES:
            raise ValidationError(f"Unknown content type: {ct}")
        if scraper_key not in SCRAPER_CAPABILITIES:
            raise ValidationError(f"Unknown scraper: {scraper_key}")
        if not SCRAPER_CAPABILITIES[scraper_key].get(ct):
            raise ValidationError(
                f"Scraper {SCRAPER_LABELS.get(scraper_key, scraper_key)} does not support {ct}"
            )

    # Save unified config
    app_state.settings.set(
        key='instagram_unified',
        value=config,
        category='scrapers',
        description='Unified Instagram configuration',
        updated_by='api'
    )

    # Generate and save 6 legacy configs
    legacy_configs = _generate_legacy_configs(config)
    for scraper_key, legacy_cfg in legacy_configs.items():
        app_state.settings.set(
            key=scraper_key,
            value=legacy_cfg,
            category='scrapers',
            description=f'{SCRAPER_LABELS.get(scraper_key, scraper_key)} configuration (auto-generated)',
            updated_by='api'
        )

    # Refresh in-memory config
    if hasattr(app_state, 'config') and app_state.config is not None:
        app_state.config['instagram_unified'] = config
        for scraper_key, legacy_cfg in legacy_configs.items():
            app_state.config[scraper_key] = legacy_cfg

    # Broadcast config update via WebSocket (best-effort: a failure here
    # must not fail the save, so errors are logged and swallowed)
    try:
        if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
            import asyncio
            asyncio.create_task(app_state.websocket_manager.broadcast({
                'type': 'config_updated',
                'data': {'source': 'instagram_unified'}
            }))
    except Exception as e:
        logger.warning(f"Failed to broadcast config update: {e}", module="InstagramUnified")

    return {
        'success': True,
        'message': 'Instagram configuration saved',
        'legacy_configs_updated': list(legacy_configs.keys()),
    }
|
||||
|
||||
|
||||
@router.get("/capabilities")
@handle_exceptions
async def get_capabilities(request: Request, user=Depends(get_current_user)):
    """Return scraper capability matrix and hidden modules."""
    hidden = get_app_state().settings.get('hidden_modules') or []

    return {
        'scraper_capabilities': SCRAPER_CAPABILITIES,
        'scraper_labels': SCRAPER_LABELS,
        'hidden_modules': hidden,
        'content_types': CONTENT_TYPES,
    }
|
||||
259
web/backend/routers/maintenance.py
Normal file
259
web/backend/routers/maintenance.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Maintenance Router
|
||||
|
||||
Handles database maintenance and cleanup operations:
|
||||
- Scan and remove missing file references
|
||||
- Database integrity checks
|
||||
- Orphaned record cleanup
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
from fastapi import APIRouter, Depends, Request, BackgroundTasks
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, get_app_state
|
||||
from ..core.config import settings
|
||||
from ..core.responses import now_iso8601
|
||||
from ..core.exceptions import handle_exceptions
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('Maintenance')
|
||||
|
||||
router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
# Whitelist of allowed table/column combinations for cleanup operations
|
||||
# This prevents SQL injection by only allowing known-safe identifiers
|
||||
# Whitelist of allowed table/column combinations for cleanup operations
# This prevents SQL injection by only allowing known-safe identifiers
ALLOWED_CLEANUP_TABLES = {
    "file_inventory": "file_path",
    "downloads": "file_path",
    "youtube_downloads": "file_path",
    "video_downloads": "file_path",
    "face_recognition_scans": "file_path",
    "face_recognition_references": "reference_image_path",
    "discovery_scan_queue": "file_path",
    "recycle_bin": "recycle_path",
}

# Pre-built SQL queries for each allowed table (avoids any string interpolation)
# Uses 'id' instead of 'rowid' (PostgreSQL does not have rowid)
# Uses information_schema for table existence checks (PostgreSQL)
# NOTE(review): the '?' placeholder is sqlite-style; presumably the project's
# db layer translates it for PostgreSQL — confirm against the db wrapper.
# The 'delete' entry is deliberately a prefix: the caller appends an id list.
_CLEANUP_QUERIES = {
    table: {
        "check_exists": "SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_name=?",
        "select": f"SELECT id, {col} FROM {table} WHERE {col} IS NOT NULL AND {col} != ''",
        "delete": f"DELETE FROM {table} WHERE id IN ",
    }
    for table, col in ALLOWED_CLEANUP_TABLES.items()
}
|
||||
|
||||
# Store last scan results
|
||||
last_scan_result = None
|
||||
|
||||
|
||||
@router.post("/cleanup/missing-files")
@limiter.limit("5/hour")
@handle_exceptions
async def cleanup_missing_files(
    request: Request,
    background_tasks: BackgroundTasks,
    dry_run: bool = True,
    current_user: Dict = Depends(get_current_user)
):
    """
    Scan all database tables for file references and remove entries for missing files.

    Args:
        dry_run: If True, only report what would be deleted (default: True)

    Returns:
        Summary of files found and removed
    """
    app_state = get_app_state()
    user_id = current_user.get('sub', 'unknown')

    logger.info(f"Database cleanup started by {user_id} (dry_run={dry_run})", module="Maintenance")

    # The scan may be slow, so it runs after the response is sent.
    background_tasks.add_task(
        _cleanup_missing_files_task,
        app_state,
        dry_run,
        user_id,
    )

    return {
        "status": "started",
        "dry_run": dry_run,
        "message": "Cleanup scan started in background. Check /api/maintenance/cleanup/status for progress.",
        "timestamp": now_iso8601(),
    }
|
||||
|
||||
|
||||
@router.get("/cleanup/status")
@limiter.limit("60/minute")
@handle_exceptions
async def get_cleanup_status(request: Request, current_user: Dict = Depends(get_current_user)):
    """Return the result dict of the most recent cleanup scan, if any ran."""
    global last_scan_result

    # A scan result exists once _cleanup_missing_files_task has finished at least once
    if last_scan_result is not None:
        return last_scan_result

    return {
        "status": "no_scan",
        "message": "No cleanup scan has been run yet"
    }
|
||||
|
||||
|
||||
async def _cleanup_missing_files_task(app_state, dry_run: bool, user_id: str):
    """Background task to scan and cleanup missing files.

    Walks every whitelisted table, checks each stored file path against the
    filesystem via _scan_table, and (unless dry_run) deletes rows whose files
    are gone. The outcome is published into the module-level
    ``last_scan_result`` for GET /cleanup/status to read.

    Args:
        app_state: Application state object providing the database handle.
        dry_run: When True, nothing is deleted; counts are reported only.
        user_id: Identifier of the user who triggered the scan (audit trail).
    """
    global last_scan_result

    start_time = datetime.now()

    # Initialize result tracking
    result = {
        "status": "running",
        "started_at": start_time.isoformat(),
        "dry_run": dry_run,
        "user": user_id,
        "tables_scanned": {},
        "total_checked": 0,
        "total_missing": 0,
        "total_removed": 0
    }

    try:
        # One writable connection for the whole scan; the single commit below
        # means a mid-scan failure abandons all per-table deletions together.
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()

            # Define tables and their file path columns
            # NOTE: instagram_perceptual_hashes is excluded because the hash data
            # is valuable for duplicate detection even if the original file is gone
            tables_to_scan = [
                ("file_inventory", "file_path"),
                ("downloads", "file_path"),
                ("youtube_downloads", "file_path"),
                ("video_downloads", "file_path"),
                ("face_recognition_scans", "file_path"),
                ("face_recognition_references", "reference_image_path"),
                ("discovery_scan_queue", "file_path"),
                ("recycle_bin", "recycle_path"),
            ]

            for table_name, column_name in tables_to_scan:
                logger.info(f"Scanning {table_name}.{column_name}...", module="Maintenance")

                # _scan_table validates the pair against the whitelist itself
                table_result = _scan_table(cursor, table_name, column_name, dry_run)
                result["tables_scanned"][table_name] = table_result
                result["total_checked"] += table_result["checked"]
                result["total_missing"] += table_result["missing"]
                result["total_removed"] += table_result["removed"]

            # Commit if not dry run
            if not dry_run:
                conn.commit()
                logger.info(f"Cleanup completed: removed {result['total_removed']} records", module="Maintenance")
            else:
                logger.info(f"Dry run completed: {result['total_missing']} records would be removed", module="Maintenance")

        # Update result
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()

        result.update({
            "status": "completed",
            "completed_at": end_time.isoformat(),
            "duration_seconds": round(duration, 2)
        })

    except Exception as e:
        logger.error(f"Cleanup failed: {e}", module="Maintenance", exc_info=True)
        result.update({
            "status": "failed",
            "error": str(e),
            "completed_at": datetime.now().isoformat()
        })

    # Published only once the scan has settled, so /cleanup/status never
    # observes a partially-filled result dict.
    last_scan_result = result
|
||||
|
||||
|
||||
def _scan_table(cursor, table_name: str, column_name: str, dry_run: bool) -> Dict:
    """Scan a table for missing files and optionally remove them.

    Uses pre-built queries from _CLEANUP_QUERIES to prevent SQL injection.
    Only tables in ALLOWED_CLEANUP_TABLES whitelist are permitted.

    Args:
        cursor: Open database cursor (caller owns the transaction/commit).
        table_name: Table to scan; must be a key of ALLOWED_CLEANUP_TABLES.
        column_name: File-path column; must match the whitelisted column.
        dry_run: When True, only count and report; no rows are deleted.

    Returns:
        Dict with "checked"/"missing"/"removed" counts, up to 100 example
        paths in "missing_files", and an "error" key when validation or the
        scan itself fails (counts reflect progress made before the failure).
    """
    result = {
        "checked": 0,
        "missing": 0,
        "removed": 0,
        "missing_files": []
    }

    # Validate table/column against whitelist to prevent SQL injection
    if table_name not in ALLOWED_CLEANUP_TABLES:
        logger.error(f"Table {table_name} not in allowed whitelist", module="Maintenance")
        result["error"] = f"Table {table_name} not allowed"
        return result

    if ALLOWED_CLEANUP_TABLES[table_name] != column_name:
        logger.error(f"Column {column_name} not allowed for table {table_name}", module="Maintenance")
        result["error"] = f"Column {column_name} not allowed for table {table_name}"
        return result

    # Get pre-built queries for this table (built at module load time, not from user input)
    queries = _CLEANUP_QUERIES[table_name]

    try:
        # Check if table exists using parameterized query
        cursor.execute(queries["check_exists"], (table_name,))
        if not cursor.fetchone():
            logger.warning(f"Table {table_name} does not exist", module="Maintenance")
            return result

        # Get all file paths from table using pre-built query
        cursor.execute(queries["select"])

        rows = cursor.fetchall()
        result["checked"] = len(rows)

        missing_rowids = []

        for rowid, file_path in rows:
            if file_path and not os.path.exists(file_path):
                result["missing"] += 1
                missing_rowids.append(rowid)

                # Only keep first 100 examples to bound the response size
                if len(result["missing_files"]) < 100:
                    result["missing_files"].append(file_path)

        # Remove missing entries if not dry run
        if not dry_run and missing_rowids:
            # Delete in batches of 100 using pre-built query base
            delete_base = queries["delete"]
            for i in range(0, len(missing_rowids), 100):
                batch = missing_rowids[i:i+100]
                placeholders = ','.join('?' * len(batch))
                # The delete_base is pre-built from whitelist, placeholders are just ?
                cursor.execute(f"{delete_base}({placeholders})", batch)
                result["removed"] += len(batch)

        # FIX: report the actual deletion counter for real runs instead of the
        # missing counter, so the log cannot disagree with the returned result.
        logger.info(
            f"  {table_name}: checked={result['checked']}, missing={result['missing']}, "
            f"{'would_remove' if dry_run else 'removed'}={result['missing'] if dry_run else result['removed']}",
            module="Maintenance"
        )

    except Exception as e:
        logger.error(f"Error scanning {table_name}: {e}", module="Maintenance", exc_info=True)
        result["error"] = str(e)

    return result
|
||||
669
web/backend/routers/manual_import.py
Normal file
669
web/backend/routers/manual_import.py
Normal file
@@ -0,0 +1,669 @@
|
||||
"""
|
||||
Manual Import Router
|
||||
|
||||
Handles manual file import operations:
|
||||
- Service configuration
|
||||
- File upload to temp directory
|
||||
- Filename parsing
|
||||
- Processing and moving to final destination (async background processing)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
|
||||
from modules.filename_parser import FilenameParser, get_preset_patterns, parse_with_fallbacks, INSTAGRAM_PATTERNS
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')

router = APIRouter(prefix="/api/manual-import", tags=["Manual Import"])
# Per-route rate limiting, keyed by client IP
limiter = Limiter(key_func=get_remote_address)


# ============================================================================
# JOB TRACKING FOR BACKGROUND PROCESSING
# ============================================================================

# In-memory job tracking (jobs are transient - cleared on restart).
# Every access goes through _jobs_lock because background tasks can run
# on worker threads concurrently with request handlers.
_import_jobs: Dict[str, Dict] = {}
_jobs_lock = Lock()
|
||||
|
||||
|
||||
def get_job_status(job_id: str) -> Optional[Dict]:
    """Look up the tracking record for *job_id*; None when unknown."""
    with _jobs_lock:
        job = _import_jobs.get(job_id, None)
    return job
|
||||
|
||||
|
||||
def update_job_status(job_id: str, updates: Dict):
    """Merge *updates* into an existing job record; no-op for unknown ids."""
    with _jobs_lock:
        job = _import_jobs.get(job_id)
        if job is not None:
            job.update(updates)
||||
|
||||
|
||||
def create_job(job_id: str, total_files: int, service_name: str):
    """Register a fresh, zero-progress job record in the in-memory tracker."""
    # Build the record outside the lock; only the dict insert needs guarding.
    record = {
        'id': job_id,
        'status': 'processing',
        'service_name': service_name,
        'total_files': total_files,
        'processed_files': 0,
        'success_count': 0,
        'failed_count': 0,
        'results': [],
        'current_file': None,
        'started_at': datetime.now().isoformat(),
        'completed_at': None
    }
    with _jobs_lock:
        _import_jobs[job_id] = record
|
||||
|
||||
|
||||
def cleanup_old_jobs():
    """Drop finished jobs whose completion timestamp is more than an hour old."""
    retention_seconds = 3600  # 1 hour
    with _jobs_lock:
        now = datetime.now()
        expired = []
        for job_id, job in _import_jobs.items():
            finished = job.get('completed_at')
            if not finished:
                # Still running (or never completed) -> keep it
                continue
            try:
                age = (now - datetime.fromisoformat(finished)).total_seconds()
            except (ValueError, TypeError):
                # Malformed timestamp: keep the record rather than guess
                continue
            if age > retention_seconds:
                expired.append(job_id)
        # Delete after iteration so the dict is not mutated mid-loop
        for job_id in expired:
            del _import_jobs[job_id]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ParseFilenameRequest(BaseModel):
    # Request body for POST /parse: test one pattern against one filename.
    filename: str
    pattern: str


class FileInfo(BaseModel):
    # One uploaded file awaiting processing. The manual_* fields let the
    # caller override values that could not be parsed from the filename.
    filename: str
    temp_path: str
    manual_datetime: Optional[str] = None
    manual_username: Optional[str] = None


class ProcessFilesRequest(BaseModel):
    # Request body for POST /process: which configured service to use and
    # the file descriptors (dicts shaped like FileInfo) to import.
    service_name: str
    files: List[dict]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def extract_youtube_metadata(video_id: str) -> Optional[Dict]:
    """Fetch selected metadata for a YouTube video via yt-dlp (no download).

    Returns a dict with title/uploader/channel/upload_date/duration/
    view_count/description, or None if yt-dlp fails, times out, or emits
    unparseable output.
    """
    import subprocess
    import json

    cmd = [
        '/opt/media-downloader/venv/bin/yt-dlp',
        '--dump-json',
        '--no-download',
        '--no-warnings',
        f'https://www.youtube.com/watch?v={video_id}'
    ]

    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        if proc.returncode != 0:
            return None

        info = json.loads(proc.stdout)

        # yt-dlp reports upload_date as YYYYMMDD; tolerate absent/bad values
        upload_date = None
        raw_date = info.get('upload_date')
        if raw_date:
            try:
                upload_date = datetime.strptime(raw_date, '%Y%m%d')
            except ValueError:
                pass

        description = info.get('description')
        return {
            'title': info.get('title', ''),
            'uploader': info.get('uploader', info.get('channel', '')),
            'channel': info.get('channel', info.get('uploader', '')),
            'upload_date': upload_date,
            'duration': info.get('duration'),
            'view_count': info.get('view_count'),
            'description': description[:200] if description else None
        }
    except Exception as e:
        logger.warning(f"Failed to extract YouTube metadata for {video_id}: {e}", module="ManualImport")
        return None
|
||||
|
||||
|
||||
def extract_video_id_from_filename(filename: str) -> Optional[str]:
    """Best-effort extraction of an 11-character YouTube video ID from a filename.

    Probes the stem with increasingly permissive patterns, in order:
    "[ID]" in brackets, a trailing "_ID" suffix, a stem that is exactly an
    ID, and finally an ID delimited anywhere in the stem. Returns the first
    match, or None when nothing looks like an ID.
    """
    import re

    stem = Path(filename).stem

    # Ordered from most to least specific; the first hit wins.
    probes = (
        r'\[([A-Za-z0-9_-]{11})\]',                          # [ID] in brackets
        r'_([A-Za-z0-9_-]{11})$',                            # trailing _ID
        r'^([A-Za-z0-9_-]{11})$',                            # stem is exactly an ID
        r'(?:^|[_\-\s])([A-Za-z0-9_-]{11})(?:[_\-\s.]|$)',   # delimited anywhere
    )
    for pattern in probes:
        hit = re.search(pattern, stem)
        if hit:
            return hit.group(1)

    return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/services")
@limiter.limit("60/minute")
@handle_exceptions
async def get_manual_import_services(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Return the manual-import configuration plus available preset patterns."""
    config = get_app_state().settings.get('manual_import')
    patterns = get_preset_patterns()

    if config:
        config['preset_patterns'] = patterns
        return config

    # No configuration stored yet -> report a disabled default setup
    return {
        "enabled": False,
        "temp_dir": "/opt/media-downloader/temp/manual_import",
        "services": [],
        "preset_patterns": patterns
    }
|
||||
|
||||
|
||||
@router.post("/parse")
@limiter.limit("100/minute")
@handle_exceptions
async def parse_filename(
    request: Request,
    body: ParseFilenameRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Parse a filename using a pattern and return extracted metadata."""
    try:
        parsed = FilenameParser(body.pattern).parse(body.filename)
        # Datetimes are converted to ISO-8601 strings for JSON serialization
        dt = parsed['datetime']
        if dt:
            parsed['datetime'] = dt.isoformat()
        return parsed
    except Exception as e:
        logger.error(f"Error parsing filename: {e}", module="ManualImport")
        return {
            "valid": False,
            "error": str(e),
            "username": None,
            "datetime": None,
            "media_id": None
        }
|
||||
|
||||
|
||||
# File upload constants
MAX_FILE_SIZE = 5 * 1024 * 1024 * 1024  # 5GB max file size
MAX_FILENAME_LENGTH = 255  # typical filesystem limit for one name component
# Image/video extensions accepted by the uploader (compared lowercase)
ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.mp4', '.mov', '.avi', '.mkv', '.webm', '.webp', '.heic', '.heif'}
|
||||
|
||||
|
||||
@router.post("/upload")
@limiter.limit("30/minute")
@handle_exceptions
async def upload_files_for_import(
    request: Request,
    files: List[UploadFile] = File(...),
    service_name: str = Form(...),
    current_user: Dict = Depends(get_current_user)
):
    """Upload files to temp directory for manual import.

    Validates filename length and extension, streams each upload to a fresh
    per-session temp directory (bounded by MAX_FILE_SIZE), and pre-parses
    each filename so the UI can show what metadata was recognized.

    Raises:
        ValidationError: manual import disabled, bad filename/extension,
            or a file exceeding MAX_FILE_SIZE.
        NotFoundError: unknown or disabled service.
    """
    app_state = get_app_state()

    config = app_state.settings.get('manual_import')
    if not config or not config.get('enabled'):
        raise ValidationError("Manual import is not enabled")

    services = config.get('services', [])
    service = next((s for s in services if s['name'] == service_name and s.get('enabled', True)), None)
    if not service:
        raise NotFoundError(f"Service '{service_name}' not found or disabled")

    # Fresh per-upload session directory keeps concurrent uploads apart
    session_id = str(uuid.uuid4())[:8]
    temp_base = Path(config.get('temp_dir', '/opt/media-downloader/temp/manual_import'))
    temp_dir = temp_base / session_id
    temp_dir.mkdir(parents=True, exist_ok=True)

    pattern = service.get('filename_pattern', '{username}_{YYYYMMDD}_{HHMMSS}_{id}')
    platform = service.get('platform', 'unknown')

    # Use fallback patterns for Instagram (handles both underscore and dash formats)
    use_fallbacks = platform == 'instagram'
    parser = FilenameParser(pattern) if not use_fallbacks else None

    uploaded_files = []
    for file in files:
        # Sanitize filename - use only the basename to prevent path traversal
        safe_filename = Path(file.filename).name

        # Validate filename length
        if len(safe_filename) > MAX_FILENAME_LENGTH:
            raise ValidationError(f"Filename too long: {safe_filename[:50]}... (max {MAX_FILENAME_LENGTH} chars)")

        # Validate file extension
        file_ext = Path(safe_filename).suffix.lower()
        if file_ext not in ALLOWED_EXTENSIONS:
            raise ValidationError(f"File type not allowed: {file_ext}. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}")

        file_path = temp_dir / safe_filename

        # FIX: stream the upload to disk in 1 MiB chunks instead of buffering
        # the whole file (up to MAX_FILE_SIZE = 5GB) in memory with file.read().
        file_size = 0
        try:
            with open(file_path, 'wb') as f:
                while True:
                    chunk = await file.read(1024 * 1024)
                    if not chunk:
                        break
                    file_size += len(chunk)
                    if file_size > MAX_FILE_SIZE:
                        raise ValidationError(
                            f"File too large: {safe_filename} (exceeds max {MAX_FILE_SIZE / (1024*1024*1024)}GB)"
                        )
                    f.write(chunk)
        except ValidationError:
            # Remove the partial file before propagating the error
            if file_path.exists():
                file_path.unlink()
            raise

        # Parse filename - use fallback patterns for Instagram
        if use_fallbacks:
            parse_result = parse_with_fallbacks(file.filename, INSTAGRAM_PATTERNS)
        else:
            parse_result = parser.parse(file.filename)

        parsed_datetime = None
        if parse_result['datetime']:
            parsed_datetime = parse_result['datetime'].isoformat()

        uploaded_files.append({
            "filename": file.filename,
            "temp_path": str(file_path),
            "size": file_size,
            "parsed": {
                "valid": parse_result['valid'],
                "username": parse_result['username'],
                "datetime": parsed_datetime,
                "media_id": parse_result['media_id'],
                "error": parse_result['error']
            }
        })

    logger.info(f"Uploaded {len(uploaded_files)} files for manual import (service: {service_name})", module="ManualImport")

    return {
        "session_id": session_id,
        "service_name": service_name,
        "files": uploaded_files,
        "temp_dir": str(temp_dir)
    }
|
||||
|
||||
|
||||
def process_files_background(
    job_id: str,
    service_name: str,
    files: List[dict],
    service: dict,
    app_state
):
    """Background task to process imported files.

    For each uploaded file: resolve username/timestamp (manual overrides,
    yt-dlp metadata for YouTube, then filename parsing), move the file into
    the service destination via MoveManager, and record the result in the
    downloads table. Progress is mirrored into the in-memory job tracker
    after every file so the /status endpoint can report live progress.

    Args:
        job_id: Key into the in-memory job tracker (created by create_job).
        service_name: Name of the configured manual-import service.
        files: File descriptors with 'filename', 'temp_path' and optional
            'manual_datetime' / 'manual_username' overrides.
        service: The service's configuration dict (destination, pattern, ...).
        app_state: Application state providing db and event emitter.
    """
    # Deferred imports: hashlib is only needed here, and move_module is a
    # heavier project module not required at router import time.
    import hashlib
    from modules.move_module import MoveManager

    destination = Path(service['destination'])
    destination.mkdir(parents=True, exist_ok=True)
    pattern = service.get('filename_pattern', '{username}_{YYYYMMDD}_{HHMMSS}_{id}')

    platform = service.get('platform', 'unknown')
    content_type = service.get('content_type', 'videos')
    use_ytdlp = service.get('use_ytdlp', False)
    parse_filename = service.get('parse_filename', True)

    # Use fallback patterns for Instagram
    use_fallbacks = platform == 'instagram'
    parser = FilenameParser(pattern) if not use_fallbacks else None

    # Generate session ID for real-time monitoring
    session_id = f"manual_import_{service_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Emit scraper_started event so the live dashboard shows this import
    if app_state.scraper_event_emitter:
        app_state.scraper_event_emitter.emit_scraper_started(
            session_id=session_id,
            platform=platform,
            account=service_name,
            content_type=content_type,
            estimated_count=len(files)
        )

    move_manager = MoveManager(
        unified_db=app_state.db,
        face_recognition_enabled=False,
        notifier=None,
        event_emitter=app_state.scraper_event_emitter
    )

    move_manager.set_session_context(
        platform=platform,
        account=service_name,
        session_id=session_id
    )

    results = []
    success_count = 0
    failed_count = 0

    for idx, file_info in enumerate(files):
        temp_path = Path(file_info['temp_path'])
        filename = file_info['filename']
        manual_datetime_str = file_info.get('manual_datetime')
        manual_username = file_info.get('manual_username')

        # Update job status with current file
        update_job_status(job_id, {
            'current_file': filename,
            'processed_files': idx
        })

        if not temp_path.exists():
            result = {"filename": filename, "status": "error", "error": "File not found in temp directory"}
            results.append(result)
            failed_count += 1
            update_job_status(job_id, {'results': results.copy(), 'failed_count': failed_count})
            continue

        username = 'unknown'
        parsed_datetime = None
        final_filename = filename

        # Use manual values if provided (they take precedence over parsing)
        if not parse_filename or manual_datetime_str or manual_username:
            if manual_username:
                username = manual_username.strip().lower()
            if manual_datetime_str:
                # Accept the HTML datetime-local format first, then full ISO
                try:
                    parsed_datetime = datetime.strptime(manual_datetime_str, '%Y-%m-%dT%H:%M')
                except ValueError:
                    try:
                        parsed_datetime = datetime.fromisoformat(manual_datetime_str)
                    except ValueError:
                        logger.warning(f"Could not parse manual datetime: {manual_datetime_str}", module="ManualImport")

        # Try yt-dlp for YouTube videos
        if use_ytdlp and platform == 'youtube':
            video_id = extract_video_id_from_filename(filename)
            if video_id:
                logger.info(f"Extracting YouTube metadata for video ID: {video_id}", module="ManualImport")
                yt_metadata = extract_youtube_metadata(video_id)
                if yt_metadata:
                    username = yt_metadata.get('channel') or yt_metadata.get('uploader') or 'unknown'
                    # Sanitize to filesystem-safe characters
                    username = "".join(c for c in username if c.isalnum() or c in ' _-').strip().replace(' ', '_')
                    parsed_datetime = yt_metadata.get('upload_date')
                    if yt_metadata.get('title'):
                        title = yt_metadata['title'][:50]
                        title = "".join(c for c in title if c.isalnum() or c in ' _-').strip().replace(' ', '_')
                        ext = Path(filename).suffix
                        final_filename = f"{username}_{parsed_datetime.strftime('%Y%m%d') if parsed_datetime else 'unknown'}_{title}_{video_id}{ext}"

        # Fall back to filename parsing when nothing above produced a username
        if parse_filename and username == 'unknown':
            if use_fallbacks:
                parse_result = parse_with_fallbacks(filename, INSTAGRAM_PATTERNS)
            else:
                parse_result = parser.parse(filename)
            if parse_result['valid']:
                username = parse_result['username'] or 'unknown'
                parsed_datetime = parse_result['datetime']
            elif not use_ytdlp:
                result = {"filename": filename, "status": "error", "error": parse_result['error'] or "Failed to parse filename"}
                results.append(result)
                failed_count += 1
                update_job_status(job_id, {'results': results.copy(), 'failed_count': failed_count})
                continue

        # Destination layout: <service destination>/<username>/<final_filename>
        dest_subdir = destination / username
        dest_subdir.mkdir(parents=True, exist_ok=True)
        dest_path = dest_subdir / final_filename

        move_manager.start_batch(
            platform=platform,
            source=username,
            content_type=content_type
        )

        file_size = temp_path.stat().st_size if temp_path.exists() else 0

        try:
            success = move_manager.move_file(
                source=temp_path,
                destination=dest_path,
                timestamp=parsed_datetime,
                preserve_if_no_timestamp=True,
                content_type=content_type
            )

            move_manager.end_batch()

            if success:
                # Synthetic URL/hash so manual imports fit the downloads schema
                url_hash = hashlib.sha256(f"manual_import:{final_filename}".encode()).hexdigest()

                with app_state.db.get_connection(for_write=True) as conn:
                    cursor = conn.cursor()
                    cursor.execute("""
                        INSERT OR REPLACE INTO downloads
                        (url_hash, url, platform, source, content_type, filename, file_path,
                         file_size, file_hash, post_date, download_date, status, media_id)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'completed', ?)
                    """, (
                        url_hash,
                        f"manual_import://{final_filename}",
                        platform,
                        username,
                        content_type,
                        final_filename,
                        str(dest_path),
                        file_size,
                        None,
                        parsed_datetime.isoformat() if parsed_datetime else None,
                        datetime.now().isoformat(),
                        None
                    ))
                    conn.commit()

                result = {
                    "filename": filename,
                    "status": "success",
                    "destination": str(dest_path),
                    "username": username,
                    "datetime": parsed_datetime.isoformat() if parsed_datetime else None
                }
                results.append(result)
                success_count += 1
            else:
                result = {"filename": filename, "status": "error", "error": "Failed to move file (possibly duplicate)"}
                results.append(result)
                failed_count += 1

        except Exception as e:
            # Ensure the batch is closed even on failure before recording it
            move_manager.end_batch()
            result = {"filename": filename, "status": "error", "error": str(e)}
            results.append(result)
            failed_count += 1

        # Update job status after each file
        update_job_status(job_id, {
            'results': results.copy(),
            'success_count': success_count,
            'failed_count': failed_count,
            'processed_files': idx + 1
        })

    # Clean up temp directory (best-effort; all files shared one session dir)
    try:
        temp_parent = Path(files[0]['temp_path']).parent if files else None
        if temp_parent and temp_parent.exists():
            for f in temp_parent.iterdir():
                f.unlink()
            temp_parent.rmdir()
    except Exception:
        pass

    # Emit scraper_completed event
    if app_state.scraper_event_emitter:
        app_state.scraper_event_emitter.emit_scraper_completed(
            session_id=session_id,
            stats={
                'total_downloaded': len(files),
                'moved': success_count,
                'review': 0,
                'duplicates': 0,
                'failed': failed_count
            }
        )

    # Mark job as complete
    update_job_status(job_id, {
        'status': 'completed',
        'completed_at': datetime.now().isoformat(),
        'current_file': None
    })

    logger.info(f"Manual import complete: {success_count} succeeded, {failed_count} failed", module="ManualImport")

    # Cleanup old jobs
    cleanup_old_jobs()
|
||||
|
||||
|
||||
@router.post("/process")
@limiter.limit("10/minute")
@handle_exceptions
async def process_imported_files(
    request: Request,
    background_tasks: BackgroundTasks,
    body: ProcessFilesRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Process uploaded files in the background - returns immediately with job ID."""
    app_state = get_app_state()

    config = app_state.settings.get('manual_import')
    if not config or not config.get('enabled'):
        raise ValidationError("Manual import is not enabled")

    # Find the first enabled service with the requested name
    service = None
    for candidate in config.get('services', []):
        if candidate['name'] == body.service_name and candidate.get('enabled', True):
            service = candidate
            break
    if not service:
        raise NotFoundError(f"Service '{body.service_name}' not found")

    # Register a trackable job, then hand the heavy lifting to a background task
    job_id = f"import_{uuid.uuid4().hex[:12]}"
    create_job(job_id, len(body.files), body.service_name)

    background_tasks.add_task(
        process_files_background,
        job_id,
        body.service_name,
        body.files,
        service,
        app_state
    )

    logger.info(f"Manual import job {job_id} queued: {len(body.files)} files for {body.service_name}", module="ManualImport")

    return {
        "job_id": job_id,
        "status": "processing",
        "total_files": len(body.files),
        "message": "Processing started in background"
    }
|
||||
|
||||
|
||||
@router.get("/status/{job_id}")
@limiter.limit("120/minute")
@handle_exceptions
async def get_import_job_status(
    request: Request,
    job_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Get the status of a manual import job."""
    record = get_job_status(job_id)
    if not record:
        # Unknown id, or the job was already purged by cleanup_old_jobs
        raise NotFoundError(f"Job '{job_id}' not found")
    return record
|
||||
|
||||
|
||||
@router.delete("/temp")
@limiter.limit("10/minute")
@handle_exceptions
async def clear_temp_directory(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Clear all files from manual import temp directory."""
    config = get_app_state().settings.get('manual_import') or {}
    temp_dir = Path(config.get('temp_dir', '/opt/media-downloader/temp/manual_import'))

    if temp_dir.exists():
        # Wipe everything, then recreate an empty directory for future uploads
        shutil.rmtree(temp_dir)
        temp_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Cleared manual import temp directory", module="ManualImport")
    return {"status": "success", "message": "Temp directory cleared"}
|
||||
|
||||
|
||||
@router.get("/preset-patterns")
@limiter.limit("60/minute")
@handle_exceptions
async def get_preset_filename_patterns(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get available preset filename patterns."""
    patterns = get_preset_patterns()
    return {"patterns": patterns}
|
||||
1404
web/backend/routers/media.py
Normal file
1404
web/backend/routers/media.py
Normal file
File diff suppressed because it is too large
Load Diff
6256
web/backend/routers/paid_content.py
Normal file
6256
web/backend/routers/paid_content.py
Normal file
File diff suppressed because it is too large
Load Diff
1475
web/backend/routers/platforms.py
Normal file
1475
web/backend/routers/platforms.py
Normal file
File diff suppressed because it is too large
Load Diff
1098
web/backend/routers/press.py
Normal file
1098
web/backend/routers/press.py
Normal file
File diff suppressed because it is too large
Load Diff
7629
web/backend/routers/private_gallery.py
Normal file
7629
web/backend/routers/private_gallery.py
Normal file
File diff suppressed because it is too large
Load Diff
602
web/backend/routers/recycle.py
Normal file
602
web/backend/routers/recycle.py
Normal file
@@ -0,0 +1,602 @@
|
||||
"""
|
||||
Recycle Bin Router
|
||||
|
||||
Handles all recycle bin operations:
|
||||
- List deleted files
|
||||
- Recycle bin statistics
|
||||
- Restore files
|
||||
- Permanently delete files
|
||||
- Empty recycle bin
|
||||
- Serve files for preview
|
||||
- Get file metadata
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import sqlite3
|
||||
from typing import Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body, Query, Request
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, get_current_user_media, require_admin, get_app_state
|
||||
from ..core.config import settings
|
||||
from ..core.exceptions import (
|
||||
handle_exceptions,
|
||||
DatabaseError,
|
||||
RecordNotFoundError,
|
||||
MediaFileNotFoundError as CustomFileNotFoundError,
|
||||
FileOperationError
|
||||
)
|
||||
from ..core.responses import now_iso8601
|
||||
from ..core.utils import ThumbnailLRUCache
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api/recycle", tags=["Recycle Bin"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Global thumbnail memory cache for recycle bin (500 items or 100MB max)
|
||||
# Using shared ThumbnailLRUCache from core/utils.py
|
||||
_thumbnail_cache = ThumbnailLRUCache(max_size=500, max_memory_mb=100)
|
||||
|
||||
|
||||
@router.get("/list")
@limiter.limit("100/minute")
@handle_exceptions
async def list_recycle_bin(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    deleted_from: Optional[str] = None,
    platform: Optional[str] = None,
    source: Optional[str] = None,
    search: Optional[str] = None,
    media_type: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    size_min: Optional[int] = None,
    size_max: Optional[int] = None,
    sort_by: str = Query('download_date', pattern='^(deleted_at|file_size|filename|deleted_from|download_date|post_date|confidence)$'),
    sort_order: str = Query('desc', pattern='^(asc|desc)$'),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0)
):
    """
    Return a filtered, sorted, paginated listing of recycled files.

    Args:
        deleted_from: Origin of the deletion (downloads, media, review).
        platform: Platform filter (instagram, tiktok, ...).
        source: Source/username filter.
        search: Substring match on the filename.
        media_type: 'image' or 'video'.
        date_from / date_to: Deletion-date window (YYYY-MM-DD).
        size_min / size_max: File-size bounds in bytes.
        sort_by / sort_order: Ordering column and direction.
        limit / offset: Pagination window.
    """
    db = get_app_state().db
    if not db:
        raise DatabaseError("Database not initialized")

    # All filters are forwarded verbatim to the database layer.
    query_args = dict(
        deleted_from=deleted_from,
        platform=platform,
        source=source,
        search=search,
        media_type=media_type,
        date_from=date_from,
        date_to=date_to,
        size_min=size_min,
        size_max=size_max,
        sort_by=sort_by,
        sort_order=sort_order,
        limit=limit,
        offset=offset,
    )
    listing = db.list_recycle_bin(**query_args)

    return {
        "success": True,
        "items": listing['items'],
        "total": listing['total'],
    }
|
||||
|
||||
|
||||
@router.get("/filters")
@limiter.limit("100/minute")
@handle_exceptions
async def get_recycle_filters(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    platform: Optional[str] = None
):
    """
    Return the filter options (platforms and sources) for the recycle bin.

    Args:
        platform: When given, restrict the returned sources to that platform.
    """
    db = get_app_state().db
    if not db:
        raise DatabaseError("Database not initialized")

    options = db.get_recycle_bin_filters(platform=platform)
    return {
        "success": True,
        "platforms": options['platforms'],
        "sources": options['sources'],
    }
|
||||
|
||||
|
||||
@router.get("/stats")
@limiter.limit("100/minute")
@handle_exceptions
async def get_recycle_bin_stats(request: Request, current_user: Dict = Depends(get_current_user)):
    """
    Return recycle-bin statistics: total count, total size and a
    per-`deleted_from` breakdown, stamped with the snapshot time.
    """
    db = get_app_state().db
    if not db:
        raise DatabaseError("Database not initialized")

    return {
        "success": True,
        "stats": db.get_recycle_bin_stats(),
        "timestamp": now_iso8601(),
    }
|
||||
|
||||
|
||||
@router.post("/restore")
@limiter.limit("20/minute")
@handle_exceptions
async def restore_from_recycle(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    recycle_id: str = Body(..., embed=True)
):
    """
    Move a recycled file back to its original location and re-register it
    in the file_inventory table.
    """
    app_state = get_app_state()
    db = app_state.db
    if not db:
        raise DatabaseError("Database not initialized")

    if not db.restore_from_recycle_bin(recycle_id):
        raise FileOperationError(
            "Failed to restore file",
            {"recycle_id": recycle_id}
        )

    # Notify connected clients; delivery is best-effort.
    try:
        ws = getattr(app_state, 'websocket_manager', None)
        if ws:
            await ws.broadcast({
                "type": "recycle_restore_completed",
                "recycle_id": recycle_id,
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    logger.info(f"Restored file from recycle bin: {recycle_id}", module="Recycle")
    return {
        "success": True,
        "message": "File restored successfully",
        "recycle_id": recycle_id
    }
|
||||
|
||||
|
||||
@router.delete("/delete/{recycle_id}")
@limiter.limit("20/minute")
@handle_exceptions
async def permanently_delete_from_recycle(
    request: Request,
    recycle_id: str,
    current_user: Dict = Depends(require_admin)
):
    """
    Permanently remove a recycled file from disk.

    **Admin only** - irreversible: the physical file is deleted.
    """
    app_state = get_app_state()
    db = app_state.db
    if not db:
        raise DatabaseError("Database not initialized")

    if not db.permanently_delete_from_recycle_bin(recycle_id):
        raise FileOperationError(
            "Failed to delete file",
            {"recycle_id": recycle_id}
        )

    # Notify connected clients; delivery is best-effort.
    try:
        ws = getattr(app_state, 'websocket_manager', None)
        if ws:
            await ws.broadcast({
                "type": "recycle_delete_completed",
                "recycle_id": recycle_id,
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    logger.info(f"Permanently deleted file from recycle: {recycle_id}", module="Recycle")
    return {
        "success": True,
        "message": "File permanently deleted",
        "recycle_id": recycle_id
    }
|
||||
|
||||
|
||||
@router.post("/empty")
@limiter.limit("5/minute")
@handle_exceptions
async def empty_recycle_bin(
    request: Request,
    current_user: Dict = Depends(require_admin),  # destructive: admins only
    older_than_days: Optional[int] = Body(None, embed=True)
):
    """
    Purge the recycle bin.

    Args:
        older_than_days: When set, only files deleted more than this many
            days ago are removed; otherwise everything is purged.
    """
    app_state = get_app_state()
    db = app_state.db
    if not db:
        raise DatabaseError("Database not initialized")

    deleted_count = db.empty_recycle_bin(older_than_days=older_than_days)

    # Notify connected clients; delivery is best-effort.
    try:
        ws = getattr(app_state, 'websocket_manager', None)
        if ws:
            await ws.broadcast({
                "type": "recycle_emptied",
                "deleted_count": deleted_count,
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    logger.info(f"Emptied recycle bin: {deleted_count} files deleted", module="Recycle")
    return {
        "success": True,
        "deleted_count": deleted_count,
        "older_than_days": older_than_days
    }
|
||||
|
||||
|
||||
@router.get("/file/{recycle_id}")
@limiter.limit("5000/minute")
@handle_exceptions
async def get_recycle_file(
    request: Request,
    recycle_id: str,
    thumbnail: bool = False,
    type: Optional[str] = None,
    token: Optional[str] = None,
    current_user: Dict = Depends(get_current_user_media)
):
    """
    Serve a recycled file (or a JPEG thumbnail of it) for preview.

    Args:
        recycle_id: ID of the recycle bin record.
        thumbnail: When True, return a thumbnail instead of the full file.
        type: Media type hint (image/video).
    """
    db = get_app_state().db
    if not db:
        raise DatabaseError("Database not initialized")

    # Fetch the recycle-bin record for this id.
    with db.get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            'SELECT recycle_path, original_path, original_filename, file_hash FROM recycle_bin WHERE id = ?',
            (recycle_id,)
        )
        row = cursor.fetchone()

    if not row:
        raise RecordNotFoundError(
            "File not found in recycle bin",
            {"recycle_id": recycle_id}
        )

    file_path = Path(row['recycle_path'])
    original_path = row['original_path']  # where a legacy thumbnail may be cached
    if not file_path.exists():
        raise CustomFileNotFoundError(
            "Physical file not found",
            {"path": str(file_path)}
        )

    if thumbnail:
        def jpeg_response(data: bytes) -> Response:
            # Thumbnails are keyed by content, so they can be cached aggressively.
            return Response(
                content=data,
                media_type="image/jpeg",
                headers={
                    "Cache-Control": "public, max-age=86400, immutable",
                    "Vary": "Accept-Encoding"
                }
            )

        # Key caches on the content hash so thumbnails survive file moves.
        content_hash = row['file_hash']
        cache_key = content_hash or str(file_path)

        # Tier 1: in-process LRU cache (fastest).
        cached = _thumbnail_cache.get(cache_key)
        if cached:
            return jpeg_response(cached)

        # Tier 2/3: persistent cache, or generate on demand.
        data = _get_or_create_thumbnail(file_path, type or 'image', content_hash, original_path)
        if not data:
            raise FileOperationError("Failed to generate thumbnail")

        # Warm the in-process cache for subsequent requests.
        _thumbnail_cache.put(cache_key, data)
        return jpeg_response(data)

    # Full file: guess the MIME type from the file name.
    mime_type = mimetypes.guess_type(str(file_path))[0] or "application/octet-stream"
    return FileResponse(
        path=str(file_path),
        media_type=mime_type,
        filename=row['original_filename']
    )
|
||||
|
||||
|
||||
@router.get("/metadata/{recycle_id}")
@limiter.limit("5000/minute")
@handle_exceptions
async def get_recycle_metadata(
    request: Request,
    recycle_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """
    Return on-demand metadata for a recycled file: dimensions, size,
    platform, source and (for videos) duration.
    """
    db = get_app_state().db
    if not db:
        raise DatabaseError("Database not initialized")

    # Fetch the recycle-bin record for this id.
    with db.get_connection() as conn:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT recycle_path, original_filename, file_size, original_path, metadata
            FROM recycle_bin WHERE id = ?
        ''', (recycle_id,))
        row = cursor.fetchone()

    if not row:
        raise RecordNotFoundError(
            "File not found in recycle bin",
            {"recycle_id": recycle_id}
        )

    recycle_path = Path(row['recycle_path'])
    if not recycle_path.exists():
        raise CustomFileNotFoundError(
            "Physical file not found",
            {"path": str(recycle_path)}
        )

    # Platform/source live inside the serialized metadata blob; tolerate
    # missing or malformed JSON.
    platform = source = None
    try:
        meta = json.loads(row['metadata']) if row['metadata'] else {}
        platform = meta.get('platform')
        source = meta.get('source')
    except Exception:
        pass

    # Dimensions are probed from the file itself, on demand.
    width, height, duration = _extract_dimensions(recycle_path)

    return {
        "success": True,
        "recycle_id": recycle_id,
        "filename": row['original_filename'],
        "file_size": row['file_size'],
        "platform": platform,
        "source": source,
        "width": width,
        "height": height,
        "duration": duration
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def _get_or_create_thumbnail(file_path: Path, media_type: str, content_hash: Optional[str] = None, original_path: Optional[str] = None) -> Optional[bytes]:
    """
    Return cached thumbnail bytes for a file, generating and caching on miss.

    Uses the same persistent 'thumbnails' SQLite cache as media.py. Lookup is
    two-step for backwards compatibility:
      1. content hash (new method - survives file moves)
      2. original_path (legacy thumbnails cached before the file was moved)

    Args:
        file_path: Current location of the file (inside the recycle bin).
        media_type: 'image' or 'video' (videos get a frame grabbed via ffmpeg).
        content_hash: SHA256 of the file content, used as the cache key.
        original_path: Path the file lived at before being recycled.

    Returns:
        JPEG thumbnail bytes, or None if generation failed.
    """
    from PIL import Image
    import io
    from datetime import datetime

    # --- Cache lookup (best-effort; any failure falls through to generation) ---
    try:
        with sqlite3.connect('thumbnails', timeout=30.0) as conn:
            cursor = conn.cursor()

            # 1. Content hash first (new method - survives file moves).
            if content_hash:
                cursor.execute("SELECT thumbnail_data FROM thumbnails WHERE file_hash = ?", (content_hash,))
                result = cursor.fetchone()
                if result:
                    return result[0]

            # 2. Legacy fallback: thumbnail cached under the pre-move path.
            if original_path:
                cursor.execute("SELECT thumbnail_data FROM thumbnails WHERE file_path = ?", (original_path,))
                result = cursor.fetchone()
                if result:
                    return result[0]
    except Exception:
        pass

    # --- Generation ---
    try:
        if media_type == 'video':
            # Grab a single frame one second into the video via ffmpeg.
            import subprocess
            result = subprocess.run([
                'ffmpeg', '-i', str(file_path),
                '-ss', '00:00:01', '-vframes', '1',
                '-f', 'image2pipe', '-vcodec', 'mjpeg', '-'
            ], capture_output=True, timeout=10)

            if result.returncode != 0:
                return None
            img = Image.open(io.BytesIO(result.stdout))
        else:
            img = Image.open(file_path)

        # JPEG cannot encode alpha or palette modes. BUG FIX: 'LA'
        # (grayscale+alpha) previously slipped through and made the
        # JPEG save below raise, so thumbnailing such images failed.
        if img.mode in ('RGBA', 'LA', 'P'):
            img = img.convert('RGB')

        # Bound the longest side to 300px, preserving aspect ratio.
        img.thumbnail((300, 300), Image.Resampling.LANCZOS)

        output = io.BytesIO()
        img.save(output, format='JPEG', quality=85)
        thumbnail_data = output.getvalue()

        # --- Persist to the cache (optional; failures are ignored) ---
        if thumbnail_data:
            try:
                file_mtime = file_path.stat().st_mtime if file_path.exists() else None
                # NOTE: without a real content hash we key on a hash of the
                # path *string*, which does NOT survive file moves.
                thumb_file_hash = content_hash if content_hash else hashlib.sha256(str(file_path).encode()).hexdigest()
                with sqlite3.connect('thumbnails') as conn:
                    conn.execute("""
                        INSERT OR REPLACE INTO thumbnails
                        (file_hash, file_path, thumbnail_data, created_at, file_mtime)
                        VALUES (?, ?, ?, ?, ?)
                    """, (thumb_file_hash, str(file_path), thumbnail_data, datetime.now().isoformat(), file_mtime))
                    conn.commit()
            except Exception:
                pass  # Caching is optional, don't fail if it doesn't work

        return thumbnail_data

    except Exception as e:
        logger.warning(f"Failed to generate thumbnail: {e}", module="Recycle")
        return None
|
||||
|
||||
|
||||
def _extract_dimensions(file_path: Path) -> tuple:
|
||||
"""
|
||||
Extract dimensions from a media file.
|
||||
|
||||
Returns: (width, height, duration)
|
||||
"""
|
||||
width, height, duration = None, None, None
|
||||
file_ext = file_path.suffix.lower()
|
||||
|
||||
try:
|
||||
if file_ext in ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.heic', '.heif']:
|
||||
from PIL import Image
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
elif file_ext in ['.mp4', '.mov', '.avi', '.mkv', '.webm', '.m4v']:
|
||||
import subprocess
|
||||
result = subprocess.run([
|
||||
'ffprobe', '-v', 'quiet', '-print_format', 'json',
|
||||
'-show_streams', str(file_path)
|
||||
], capture_output=True, text=True, timeout=10)
|
||||
|
||||
if result.returncode == 0:
|
||||
data = json.loads(result.stdout)
|
||||
for stream in data.get('streams', []):
|
||||
if stream.get('codec_type') == 'video':
|
||||
width = stream.get('width')
|
||||
height = stream.get('height')
|
||||
duration_str = stream.get('duration')
|
||||
if duration_str:
|
||||
duration = float(duration_str)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract dimensions: {e}", module="Recycle")
|
||||
|
||||
return width, height, duration
|
||||
1028
web/backend/routers/review.py
Normal file
1028
web/backend/routers/review.py
Normal file
File diff suppressed because it is too large
Load Diff
758
web/backend/routers/scheduler.py
Normal file
758
web/backend/routers/scheduler.py
Normal file
@@ -0,0 +1,758 @@
|
||||
"""
|
||||
Scheduler Router
|
||||
|
||||
Handles all scheduler and service management operations:
|
||||
- Scheduler status and task management
|
||||
- Current activity monitoring
|
||||
- Task pause/resume/skip operations
|
||||
- Service start/stop/restart
|
||||
- Cache builder service management
|
||||
- Dependency updates
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sqlite3
|
||||
import subprocess
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.config import settings
|
||||
from ..core.exceptions import (
|
||||
handle_exceptions,
|
||||
RecordNotFoundError,
|
||||
ServiceError
|
||||
)
|
||||
from ..core.responses import now_iso8601
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api/scheduler", tags=["Scheduler"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Service names
|
||||
SCHEDULER_SERVICE = 'media-downloader.service'
|
||||
CACHE_BUILDER_SERVICE = 'media-cache-builder.service'
|
||||
|
||||
# Valid platform names for subprocess operations (defense in depth)
|
||||
VALID_PLATFORMS = frozenset(['fastdl', 'imginn', 'imginn_api', 'toolzu', 'snapchat', 'tiktok', 'forums', 'coppermine', 'instagram', 'youtube'])
|
||||
|
||||
# Display name mapping for scheduler task_id prefixes
|
||||
PLATFORM_DISPLAY_NAMES = {
|
||||
'fastdl': 'FastDL',
|
||||
'imginn': 'ImgInn',
|
||||
'imginn_api': 'ImgInn API',
|
||||
'toolzu': 'Toolzu',
|
||||
'snapchat': 'Snapchat',
|
||||
'tiktok': 'TikTok',
|
||||
'forums': 'Forums',
|
||||
'forum': 'Forum',
|
||||
'monitor': 'Forum Monitor',
|
||||
'instagram': 'Instagram',
|
||||
'youtube': 'YouTube',
|
||||
'youtube_channel_monitor': 'YouTube Channels',
|
||||
'youtube_monitor': 'YouTube Monitor',
|
||||
'coppermine': 'Coppermine',
|
||||
'paid_content': 'Paid Content',
|
||||
'appearances': 'Appearances',
|
||||
'easynews_monitor': 'Easynews Monitor',
|
||||
'press_monitor': 'Press Monitor',
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SCHEDULER STATUS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_scheduler_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """
    Report detailed scheduler status: live task list (stale and maintenance
    entries filtered out), counts, and the next scheduled run time.
    """
    app_state = get_app_state()

    # Forums currently enabled in config — used to spot stale forum tasks.
    enabled_forums = set()
    forums_config = app_state.settings.get('forums')
    if forums_config and isinstance(forums_config, dict):
        enabled_forums = {
            cfg.get('name')
            for cfg in forums_config.get('configs', [])
            if cfg.get('enabled', False)
        }

    # These platforms must carry a ':username' suffix; bare ids are legacy.
    platforms_requiring_username = {'tiktok', 'instagram', 'imginn', 'imginn_api', 'toolzu', 'snapchat', 'fastdl'}

    with sqlite3.connect('scheduler_state') as sched_conn:
        cursor = sched_conn.cursor()

        cursor.execute("""
            SELECT task_id, last_run, next_run, run_count, status, last_download_count
            FROM scheduler_state
            ORDER BY next_run ASC
        """)
        tasks_raw = cursor.fetchall()

        # Identify stale rows: forum/monitor tasks for disabled forums, and
        # legacy platform ids without a username suffix.
        stale_task_ids = []
        for row in tasks_raw:
            task_id = row[0]
            if task_id.startswith(('forum:', 'monitor:')):
                if task_id.split(':', 1)[1] not in enabled_forums:
                    stale_task_ids.append(task_id)
            elif task_id in platforms_requiring_username:
                stale_task_ids.append(task_id)

        if stale_task_ids:
            for stale_id in stale_task_ids:
                cursor.execute("DELETE FROM scheduler_state WHERE task_id = ?", (stale_id,))
            sched_conn.commit()

        stale = set(stale_task_ids)
        tasks = [
            {
                "task_id": row[0],
                "last_run": row[1],
                "next_run": row[2],
                "run_count": row[3],
                "status": row[4],
                "last_download_count": row[5],
            }
            for row in tasks_raw
            if row[0] not in stale and not row[0].startswith('maintenance:')
        ]

    active_count = sum(1 for t in tasks if t['status'] == 'active')

    # Earliest next_run among active tasks.
    next_run = None
    for task in sorted(tasks, key=lambda t: t['next_run'] or ''):
        if task['status'] == 'active' and task['next_run']:
            next_run = task['next_run']
            break

    return {
        "running": active_count > 0,
        "tasks": tasks,
        "total_tasks": len(tasks),
        "active_tasks": active_count,
        "next_run": next_run
    }
|
||||
|
||||
|
||||
@router.get("/current-activity")
@limiter.limit("100/minute")
@handle_exceptions
async def get_current_activity(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Report what the scheduler is doing right now (real-time status)."""
    app_state = get_app_state()

    # The activity record is only meaningful while the systemd unit runs.
    probe = subprocess.run(
        ['systemctl', 'is-active', SCHEDULER_SERVICE],
        capture_output=True,
        text=True
    )
    if probe.stdout.strip() != 'active':
        return {
            "active": False,
            "scheduler_running": False,
            "task_id": None,
            "platform": None,
            "account": None,
            "start_time": None,
            "status": None
        }

    # Scheduler is up: surface its current activity from the database.
    from modules.activity_status import get_activity_manager
    activity = get_activity_manager(app_state.db).get_current_activity()
    activity["scheduler_running"] = True
    return activity
|
||||
|
||||
|
||||
@router.get("/background-tasks")
@limiter.limit("100/minute")
@handle_exceptions
async def get_background_tasks(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """List every active background task (YouTube monitor, etc.)."""
    app_state = get_app_state()

    from modules.activity_status import get_activity_manager
    manager = get_activity_manager(app_state.db)
    return {"tasks": manager.get_active_background_tasks()}
|
||||
|
||||
|
||||
@router.get("/background-tasks/{task_id}")
@limiter.limit("100/minute")
@handle_exceptions
async def get_background_task(
    task_id: str,
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Report one background task's status; inactive placeholder if unknown."""
    app_state = get_app_state()

    from modules.activity_status import get_activity_manager
    task = get_activity_manager(app_state.db).get_background_task(task_id)
    return task if task else {"active": False, "task_id": task_id}
|
||||
|
||||
|
||||
@router.post("/current-activity/stop")
@limiter.limit("20/minute")
@handle_exceptions
async def stop_current_activity(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Terminate the download task the scheduler is currently running."""
    app_state = get_app_state()
    activity_file = settings.PROJECT_ROOT / 'database' / 'current_activity.json'

    if not activity_file.exists():
        raise RecordNotFoundError("No active task running")

    with open(activity_file, 'r') as f:
        activity_data = json.load(f)

    if not activity_data.get('active'):
        raise RecordNotFoundError("No active task running")

    task_id = activity_data.get('task_id')
    platform = activity_data.get('platform')

    # Security: never feed an unvetted platform string into pgrep
    # (defense in depth).
    if platform and platform not in VALID_PLATFORMS:
        logger.warning(f"Invalid platform in activity file: {platform}", module="Security")
        platform = None

    # Locate the running downloader process(es); without a platform,
    # match any media-downloader process.
    if platform:
        pattern = f'media-downloader\\.py.*--platform.*{re.escape(platform)}'
    else:
        pattern = 'media-downloader\\.py'
    result = subprocess.run(
        ['pgrep', '-f', pattern],
        capture_output=True,
        text=True
    )

    # SIGTERM each matched pid; a pid may have exited already.
    for pid in result.stdout.split():
        try:
            os.kill(int(pid), signal.SIGTERM)
            logger.info(f"Stopped process {pid} for platform {platform}")
        except (ProcessLookupError, ValueError):
            pass

    # Mark the activity file as idle.
    with open(activity_file, 'w') as f:
        json.dump({
            "active": False,
            "task_id": None,
            "platform": None,
            "account": None,
            "start_time": None,
            "status": "stopped"
        }, f)

    # Best-effort broadcast of the stop event.
    try:
        ws = getattr(app_state, 'websocket_manager', None)
        if ws:
            await ws.broadcast({
                "type": "download_stopped",
                "task_id": task_id,
                "platform": platform,
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    return {
        "success": True,
        "message": f"Stopped {platform} download",
        "task_id": task_id
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TASK MANAGEMENT ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/tasks/{task_id}/pause")
@limiter.limit("20/minute")
@handle_exceptions
async def pause_scheduler_task(
    request: Request,
    task_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Mark a scheduler task as paused so it is skipped until resumed."""
    app_state = get_app_state()

    with sqlite3.connect('scheduler_state') as sched_conn:
        cursor = sched_conn.cursor()
        cursor.execute("""
            UPDATE scheduler_state
            SET status = 'paused'
            WHERE task_id = ?
        """, (task_id,))
        sched_conn.commit()
        # rowcount of 0 means no such task.
        if cursor.rowcount == 0:
            raise RecordNotFoundError("Task not found", {"task_id": task_id})

    # Best-effort push notification.
    try:
        ws = getattr(app_state, 'websocket_manager', None)
        if ws:
            await ws.broadcast({
                "type": "scheduler_task_paused",
                "task_id": task_id,
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    return {"success": True, "task_id": task_id, "status": "paused"}
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/resume")
@limiter.limit("20/minute")
@handle_exceptions
async def resume_scheduler_task(
    request: Request,
    task_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Return a paused scheduler task to the active rotation."""
    app_state = get_app_state()

    with sqlite3.connect('scheduler_state') as sched_conn:
        cursor = sched_conn.cursor()
        cursor.execute("""
            UPDATE scheduler_state
            SET status = 'active'
            WHERE task_id = ?
        """, (task_id,))
        sched_conn.commit()
        # rowcount of 0 means no such task.
        if cursor.rowcount == 0:
            raise RecordNotFoundError("Task not found", {"task_id": task_id})

    # Best-effort push notification.
    try:
        ws = getattr(app_state, 'websocket_manager', None)
        if ws:
            await ws.broadcast({
                "type": "scheduler_task_resumed",
                "task_id": task_id,
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    return {"success": True, "task_id": task_id, "status": "active"}
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/skip")
@limiter.limit("20/minute")
@handle_exceptions
async def skip_next_run(
    request: Request,
    task_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Skip the next scheduled run by advancing next_run time.

    Adds one full interval (interval_hours) to the task's stored next_run.

    Raises:
        RecordNotFoundError: If no task with the given task_id exists.
        HTTPException: 400 if the stored next_run is missing or unparseable.
    """
    app_state = get_app_state()

    # FIX: sqlite3's connection context manager only scopes transactions;
    # close explicitly to avoid leaking a file handle per request.
    sched_conn = sqlite3.connect('scheduler_state')
    try:
        cursor = sched_conn.cursor()

        # Get current task info
        cursor.execute("""
            SELECT next_run, interval_hours
            FROM scheduler_state
            WHERE task_id = ?
        """, (task_id,))
        result = cursor.fetchone()
        if not result:
            raise RecordNotFoundError("Task not found", {"task_id": task_id})

        current_next_run, interval_hours = result

        # FIX: a NULL or corrupt next_run previously surfaced as an unhandled
        # TypeError/ValueError (HTTP 500). Report it as a client-visible 400.
        try:
            current_time = datetime.fromisoformat(current_next_run)
        except (TypeError, ValueError):
            raise HTTPException(
                status_code=400,
                detail=f"Task has no valid next_run timestamp: {current_next_run!r}"
            )
        new_next_run = current_time + timedelta(hours=interval_hours)

        # Push the schedule forward by exactly one interval
        cursor.execute("""
            UPDATE scheduler_state
            SET next_run = ?
            WHERE task_id = ?
        """, (new_next_run.isoformat(), task_id))
        sched_conn.commit()
    finally:
        sched_conn.close()

    # Broadcast event (best-effort: never fails the request)
    try:
        if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
            await app_state.websocket_manager.broadcast({
                "type": "scheduler_run_skipped",
                "task_id": task_id,
                "new_next_run": new_next_run.isoformat(),
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    return {
        "success": True,
        "task_id": task_id,
        "skipped_run": current_next_run,
        "new_next_run": new_next_run.isoformat()
    }
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/reschedule")
@limiter.limit("20/minute")
@handle_exceptions
async def reschedule_task(
    request: Request,
    task_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Reschedule a task to a new next_run time.

    Body: {"next_run": "<ISO-8601 datetime>"}

    Raises:
        HTTPException: 400 when next_run is missing or not valid ISO-8601.
        RecordNotFoundError: If no task with the given task_id exists.
    """
    body = await request.json()
    new_next_run = body.get('next_run')
    if not new_next_run:
        raise HTTPException(status_code=400, detail="next_run is required")

    try:
        parsed = datetime.fromisoformat(new_next_run)
    except ValueError:
        # `from None` keeps the parser traceback out of the client-facing error.
        raise HTTPException(status_code=400, detail="Invalid datetime format") from None

    # FIX: sqlite3's connection context manager only scopes transactions;
    # close explicitly to avoid leaking a file handle per request.
    sched_conn = sqlite3.connect('scheduler_state')
    try:
        cursor = sched_conn.cursor()
        cursor.execute(
            "UPDATE scheduler_state SET next_run = ? WHERE task_id = ?",
            (parsed.isoformat(), task_id)
        )
        sched_conn.commit()
        row_count = cursor.rowcount
    finally:
        sched_conn.close()

    if row_count == 0:
        raise RecordNotFoundError("Task not found", {"task_id": task_id})

    # Broadcast event (best-effort: never fails the request)
    try:
        app_state = get_app_state()
        if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
            await app_state.websocket_manager.broadcast({
                "type": "scheduler_task_rescheduled",
                "task_id": task_id,
                "new_next_run": parsed.isoformat(),
                "timestamp": now_iso8601()
            })
    except Exception:
        pass

    return {"success": True, "task_id": task_id, "new_next_run": parsed.isoformat()}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONFIG RELOAD ENDPOINT
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/reload-config")
@limiter.limit("10/minute")
@handle_exceptions
async def reload_scheduler_config(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Reload scheduler config — picks up new/removed accounts and interval changes."""
    state = get_app_state()

    # The scheduler must be attached and alive before a reload makes sense.
    scheduler = getattr(state, 'scheduler', None)
    if scheduler is None:
        raise ServiceError("Scheduler is not running", {"service": SCHEDULER_SERVICE})

    diff = scheduler.reload_scheduled_tasks()

    summary = (
        f"Reload complete: {len(diff['added'])} added, "
        f"{len(diff['removed'])} removed, "
        f"{len(diff['modified'])} modified"
    )
    return {
        "success": True,
        "added": diff['added'],
        "removed": diff['removed'],
        "modified": diff['modified'],
        "message": summary,
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SERVICE MANAGEMENT ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/service/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_scheduler_service_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Check if scheduler service is running."""
    # `systemctl is-active` prints a single state word (active/inactive/failed…).
    proc = subprocess.run(
        ['systemctl', 'is-active', SCHEDULER_SERVICE],
        capture_output=True,
        text=True
    )
    state_text = proc.stdout.strip()

    return {
        "running": state_text == 'active',
        "status": state_text
    }
|
||||
|
||||
|
||||
@router.post("/service/start")
@limiter.limit("20/minute")
@handle_exceptions
async def start_scheduler_service(
    request: Request,
    current_user: Dict = Depends(require_admin)  # Require admin for service operations
):
    """Start the scheduler service. Requires admin privileges."""
    proc = subprocess.run(
        ['sudo', 'systemctl', 'start', SCHEDULER_SERVICE],
        capture_output=True,
        text=True
    )

    # Any non-zero exit from systemctl is surfaced as a service error.
    if proc.returncode:
        raise ServiceError(
            f"Failed to start service: {proc.stderr}",
            {"service": SCHEDULER_SERVICE}
        )

    return {"success": True, "message": "Scheduler service started"}
|
||||
|
||||
|
||||
@router.post("/service/stop")
@limiter.limit("20/minute")
@handle_exceptions
async def stop_scheduler_service(
    request: Request,
    current_user: Dict = Depends(require_admin)  # Require admin for service operations
):
    """Stop the scheduler service. Requires admin privileges."""
    proc = subprocess.run(
        ['sudo', 'systemctl', 'stop', SCHEDULER_SERVICE],
        capture_output=True,
        text=True
    )

    # Any non-zero exit from systemctl is surfaced as a service error.
    if proc.returncode:
        raise ServiceError(
            f"Failed to stop service: {proc.stderr}",
            {"service": SCHEDULER_SERVICE}
        )

    return {"success": True, "message": "Scheduler service stopped"}
|
||||
|
||||
|
||||
@router.post("/service/restart")
@limiter.limit("20/minute")
@handle_exceptions
async def restart_scheduler_service(
    request: Request,
    current_user: Dict = Depends(require_admin)
):
    """Restart the scheduler service. Requires admin privileges."""
    proc = subprocess.run(
        ['sudo', 'systemctl', 'restart', SCHEDULER_SERVICE],
        capture_output=True,
        text=True
    )

    # Any non-zero exit from systemctl is surfaced as a service error.
    if proc.returncode:
        raise ServiceError(
            f"Failed to restart service: {proc.stderr}",
            {"service": SCHEDULER_SERVICE}
        )

    return {"success": True, "message": "Scheduler service restarted"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DEPENDENCY MANAGEMENT ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/dependencies/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_dependencies_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get dependency update status."""
    # Deferred import: the updater pulls in heavy tooling we only need here.
    from modules.dependency_updater import DependencyUpdater

    return DependencyUpdater(scheduler_mode=False).get_update_status()
|
||||
|
||||
|
||||
@router.post("/dependencies/check")
@limiter.limit("20/minute")
@handle_exceptions
async def check_dependencies(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Force check and update all dependencies."""
    state = get_app_state()

    # Deferred imports: only needed for this admin-triggered check.
    from modules.dependency_updater import DependencyUpdater
    from modules.pushover_notifier import create_notifier_from_config

    config = state.settings.get_all()

    # Attach a Pushover notifier only when the feature is switched on.
    notifier = None
    if config.get('pushover', {}).get('enabled'):
        notifier = create_notifier_from_config(config, unified_db=state.db)

    updater = DependencyUpdater(
        config=config.get('dependency_updater', {}),
        pushover_notifier=notifier,
        scheduler_mode=True
    )

    return {
        "success": True,
        "results": updater.force_update_check(),
        "message": "Dependency check completed"
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CACHE BUILDER SERVICE ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/cache-builder/trigger")
@limiter.limit("10/minute")
@handle_exceptions
async def trigger_cache_builder(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Manually trigger the thumbnail cache builder service.

    Raises:
        ServiceError: If systemctl fails or does not respond within 5 seconds.
    """
    try:
        result = subprocess.run(
            ['sudo', 'systemctl', 'start', CACHE_BUILDER_SERVICE],
            capture_output=True,
            text=True,
            timeout=5
        )
    except subprocess.TimeoutExpired:
        # FIX: timeout=5 was set but TimeoutExpired was never handled, so a
        # hung systemctl surfaced as a raw 500. Report it as a service error.
        raise ServiceError(
            "Timed out starting cache builder",
            {"service": CACHE_BUILDER_SERVICE}
        )

    if result.returncode == 0:
        return {"success": True, "message": "Cache builder started successfully"}
    else:
        raise ServiceError(
            f"Failed to start cache builder: {result.stderr}",
            {"service": CACHE_BUILDER_SERVICE}
        )
|
||||
|
||||
|
||||
@router.get("/cache-builder/status")
@limiter.limit("30/minute")
@handle_exceptions
async def get_cache_builder_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get detailed cache builder service status.

    Returns:
        dict with:
        - running: True when systemd reports 'Active: active (running)'
        - inactive: True when systemd reports 'Active: inactive'
        - status_output: first 500 chars of `systemctl status` output
        - last_run / next_run: always None (timer parsing is not implemented)
    """
    result = subprocess.run(
        ['systemctl', 'status', CACHE_BUILDER_SERVICE, '--no-pager'],
        capture_output=True,
        text=True
    )
    status_output = result.stdout

    # FIX: the previous implementation also ran `systemctl list-timers` and
    # looped over its output without ever extracting anything (the loop body
    # was `pass`), so last_run/next_run were always None anyway. The dead
    # subprocess call and loop are removed; the response shape is unchanged.
    # TODO: parse `systemctl list-timers` output to populate last_run/next_run.
    return {
        "running": 'Active: active (running)' in status_output,
        "inactive": 'Active: inactive' in status_output,
        "status_output": status_output[:500],  # Truncate for brevity
        "last_run": None,
        "next_run": None
    }
|
||||
819
web/backend/routers/scrapers.py
Normal file
819
web/backend/routers/scrapers.py
Normal file
@@ -0,0 +1,819 @@
|
||||
"""
|
||||
Scrapers Router
|
||||
|
||||
Handles scraper management and error monitoring:
|
||||
- Scraper configuration (list, get, update)
|
||||
- Cookie management (test connection, upload, clear)
|
||||
- Error tracking (recent, count, dismiss, mark viewed)
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, Body, Depends, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
# Module-level logger on the shared 'API' channel.
logger = get_logger('API')

# Every endpoint in this router is mounted under the /api prefix.
router = APIRouter(prefix="/api", tags=["Scrapers"])
# Per-endpoint rate limiting keyed on the caller's remote address.
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ScraperUpdate(BaseModel):
    """Partial update payload for a scraper; unset fields are left unchanged.

    NOTE(review): the PUT /scrapers/{scraper_id} endpoint in this file reads
    request.json() directly rather than this model — confirm whether this
    model is wired up elsewhere or only documents the accepted fields.
    """
    # Enable or disable the scraper.
    enabled: Optional[bool] = None
    # Whether requests should be routed through the proxy below.
    proxy_enabled: Optional[bool] = None
    # Proxy endpoint URL.
    proxy_url: Optional[str] = None
    # Whether FlareSolverr is required to reach this site.
    flaresolverr_required: Optional[bool] = None
    # Root URL of the scraped site.
    base_url: Optional[str] = None
|
||||
|
||||
|
||||
class CookieUpload(BaseModel):
    """Payload for uploading browser-exported cookies to a scraper.

    NOTE(review): the POST /scrapers/{scraper_id}/cookies endpoint parses
    request.json() directly — confirm whether this model is actually used.
    """
    # Cookie objects as exported by a browser extension (dicts with name/value).
    cookies: List[dict]
    # True merges with the stored cookies; False replaces them.
    merge: bool = True
    # Optional User-Agent string captured alongside the cookies.
    user_agent: Optional[str] = None
|
||||
|
||||
|
||||
class DismissErrors(BaseModel):
    """Payload for dismissing tracked scraper errors."""
    # Specific error row IDs to dismiss; presumably ignored when
    # dismiss_all is True — TODO confirm against the handler.
    error_ids: Optional[List[int]] = None
    # Dismiss every tracked error when True.
    dismiss_all: bool = False
|
||||
|
||||
|
||||
class MarkErrorsViewed(BaseModel):
    """Payload for marking tracked scraper errors as viewed."""
    # Specific error row IDs to mark viewed; presumably ignored when
    # mark_all is True — TODO confirm against the handler.
    error_ids: Optional[List[int]] = None
    # Mark every tracked error as viewed when True.
    mark_all: bool = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SCRAPER ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/scrapers")
@limiter.limit("60/minute")
@handle_exceptions
async def get_scrapers(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    type_filter: Optional[str] = Query(None, alias="type", description="Filter by type")
):
    """Get all scrapers with optional type filter."""
    state = get_app_state()
    scrapers = state.db.get_all_scrapers(type_filter=type_filter)

    hidden = state.config.get('hidden_modules', [])
    if hidden:
        # A scraper disappears from the listing only when EVERY module that
        # uses it has been hidden by the operator.
        related_modules = {
            'instagram': ['instagram', 'instagram_client'],
            'snapchat': ['snapchat', 'snapchat_client'],
            'fastdl': ['fastdl'],
            'imginn': ['imginn'],
            'toolzu': ['toolzu'],
            'tiktok': ['tiktok'],
            'coppermine': ['coppermine'],
        }

        def _visible(entry):
            sid = entry.get('id', '')
            # Every forum scraper belongs to the single 'forums' module.
            mods = ['forums'] if sid.startswith('forum_') else related_modules.get(sid, [])
            return not (mods and all(m in hidden for m in mods))

        scrapers = [s for s in scrapers if _visible(s)]

    # Strip the (potentially huge) raw cookie payload before returning.
    for entry in scrapers:
        entry.pop('cookies_json', None)

    return {"scrapers": scrapers}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PLATFORM CREDENTIALS (UNIFIED COOKIE MANAGEMENT)
|
||||
# ============================================================================
|
||||
|
||||
# Platform definitions for the unified credentials view
# Cookie-based platforms whose credentials live in the scrapers table.
_SCRAPER_PLATFORMS = [
    {'id': 'instagram', 'name': 'Instagram', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
    {'id': 'tiktok', 'name': 'TikTok', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
    {'id': 'snapchat', 'name': 'Snapchat', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
    {'id': 'ytdlp', 'name': 'YouTube', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
    {'id': 'pornhub', 'name': 'PornHub', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
    {'id': 'xhamster', 'name': 'xHamster', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
]

# Token/session platforms handled by the paid-content subsystem.
# 'base_url' is surfaced to the UI alongside credential status.
_PAID_CONTENT_PLATFORMS = [
    {'id': 'onlyfans_direct', 'name': 'OnlyFans', 'type': 'token', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://onlyfans.com'},
    {'id': 'fansly_direct', 'name': 'Fansly', 'type': 'token', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://fansly.com'},
    {'id': 'coomer', 'name': 'Coomer', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://coomer.su'},
    {'id': 'kemono', 'name': 'Kemono', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://kemono.su'},
    {'id': 'twitch', 'name': 'Twitch', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://twitch.tv'},
    {'id': 'bellazon', 'name': 'Bellazon', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://www.bellazon.com'},
]
|
||||
|
||||
|
||||
@router.get("/scrapers/platform-credentials")
@limiter.limit("30/minute")
@handle_exceptions
async def get_platform_credentials(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get aggregated credential status for all platforms + monitoring preferences.

    Aggregates three credential sources into one list:
      1. Scraper platforms — cookies stored as JSON on scraper rows
      2. Paid-content platforms — session/token values from the paid-content adapter
      3. Reddit — cookies encrypted inside the private-gallery store

    Returns:
        {"platforms": [...], "global_monitoring_enabled": bool}
    """
    app_state = get_app_state()
    db = app_state.db
    platforms = []

    def _get_monitoring_flag(platform_id: str) -> bool:
        """Read monitoring preference from settings."""
        # Unset or unreadable settings default to monitoring enabled.
        try:
            val = app_state.settings.get(f"cookie_monitoring:{platform_id}")
            if val is not None:
                return str(val).lower() not in ('false', '0', 'no')
        except Exception:
            pass
        return True

    # 1. Scraper platforms
    for platform_def in _SCRAPER_PLATFORMS:
        scraper = db.get_scraper(platform_def['id'])
        cookies_count = 0
        updated_at = None
        if scraper:
            raw = scraper.get('cookies_json')
            if raw:
                try:
                    data = json.loads(raw)
                    # Stored payload may be a bare list or {"cookies": [...]}.
                    if isinstance(data, list):
                        cookies_count = len(data)
                    elif isinstance(data, dict):
                        c = data.get('cookies', [])
                        cookies_count = len(c) if isinstance(c, list) else 0
                except (json.JSONDecodeError, TypeError):
                    # Unparseable payload is reported as zero cookies.
                    pass
            updated_at = scraper.get('cookies_updated_at')

        monitoring_enabled = _get_monitoring_flag(platform_def['id'])

        platforms.append({
            'id': platform_def['id'],
            'name': platform_def['name'],
            'type': platform_def['type'],
            'source': platform_def['source'],
            'cookies_count': cookies_count,
            'has_credentials': cookies_count > 0,
            'updated_at': updated_at,
            'used_by': platform_def['used_by'],
            'monitoring_enabled': monitoring_enabled,
        })

    # 2. Paid content platforms
    try:
        from modules.paid_content import PaidContentDBAdapter
        paid_db = PaidContentDBAdapter(db)
        paid_services = {svc['id']: svc for svc in paid_db.get_services()}
    except Exception:
        # Paid-content subsystem unavailable — every platform below reports
        # no credentials rather than failing the whole endpoint.
        paid_services = {}

    for platform_def in _PAID_CONTENT_PLATFORMS:
        svc = paid_services.get(platform_def['id'], {})
        session_val = svc.get('session_cookie') or ''
        has_creds = bool(session_val)
        updated_at = svc.get('session_updated_at')

        # Count credentials: for JSON objects count keys, for JSON arrays count items, otherwise 1 if set
        cookies_count = 0
        if has_creds:
            try:
                parsed = json.loads(session_val)
                if isinstance(parsed, dict):
                    cookies_count = len(parsed)
                elif isinstance(parsed, list):
                    cookies_count = len(parsed)
                else:
                    # JSON scalar (e.g. a bare token) counts as one credential.
                    cookies_count = 1
            except (json.JSONDecodeError, TypeError):
                # Non-JSON value (raw token string) counts as one credential.
                cookies_count = 1

        platforms.append({
            'id': platform_def['id'],
            'name': platform_def['name'],
            'type': platform_def['type'],
            'source': platform_def['source'],
            'base_url': platform_def.get('base_url'),
            'cookies_count': cookies_count,
            'has_credentials': has_creds,
            'updated_at': updated_at,
            'used_by': platform_def['used_by'],
            'monitoring_enabled': _get_monitoring_flag(platform_def['id']),
        })

    # 3. Reddit (private gallery)
    reddit_has_creds = False
    reddit_cookies_count = 0
    reddit_locked = True
    try:
        from modules.reddit_community_monitor import RedditCommunityMonitor, REDDIT_MONITOR_KEY_FILE
        from modules.private_gallery_crypto import get_private_gallery_crypto, load_key_from_file
        db_path = str(Path(__file__).parent.parent.parent.parent / 'database' / 'media_downloader.db')
        reddit_monitor = RedditCommunityMonitor(db_path)
        crypto = get_private_gallery_crypto()
        reddit_locked = not crypto.is_initialized()

        # If gallery is locked, try loading crypto from key file (exported on unlock)
        active_crypto = crypto if not reddit_locked else load_key_from_file(REDDIT_MONITOR_KEY_FILE)

        if active_crypto and active_crypto.is_initialized():
            reddit_has_creds = reddit_monitor.has_cookies(active_crypto)
            if reddit_has_creds:
                try:
                    # NOTE(review): reaches into the monitor's private connection
                    # to count stored cookies; consider a public accessor instead.
                    conn = reddit_monitor._get_connection()
                    cursor = conn.cursor()
                    cursor.execute("SELECT value FROM private_media_config WHERE key = 'reddit_monitor_encrypted_cookies'")
                    row = cursor.fetchone()
                    conn.close()
                    if row and row['value']:
                        decrypted = active_crypto.decrypt_field(row['value'])
                        parsed = json.loads(decrypted)
                        if isinstance(parsed, list):
                            reddit_cookies_count = len(parsed)
                except Exception:
                    # Decrypt/parse failed: still report that credentials exist.
                    reddit_cookies_count = 1 if reddit_has_creds else 0
    except Exception:
        # Reddit modules missing or store unreadable — report no credentials.
        pass

    platforms.append({
        'id': 'reddit',
        'name': 'Reddit',
        'type': 'cookies',
        'source': 'private_gallery',
        'base_url': 'https://reddit.com',
        'cookies_count': reddit_cookies_count,
        'has_credentials': reddit_has_creds,
        'gallery_locked': reddit_locked,
        'updated_at': None,
        'used_by': ['Private Gallery'],
        'monitoring_enabled': _get_monitoring_flag('reddit'),
    })

    return {
        'platforms': platforms,
        'global_monitoring_enabled': _get_monitoring_flag('global'),
    }
|
||||
|
||||
|
||||
@router.put("/scrapers/platform-credentials/{platform_id}/monitoring")
@limiter.limit("30/minute")
@handle_exceptions
async def toggle_platform_monitoring(
    request: Request,
    platform_id: str,
    current_user: Dict = Depends(require_admin)
):
    """Toggle health monitoring for a single platform."""
    payload = await request.json()
    enabled = payload.get('enabled', True)

    # Persist the flag as a lowercase string under the cookie_monitoring category.
    state = get_app_state()
    state.settings.set(
        key=f"cookie_monitoring:{platform_id}",
        value=str(enabled).lower(),
        category="cookie_monitoring",
        updated_by=current_user.get('username', 'user')
    )

    verb = 'enabled' if enabled else 'disabled'
    return {
        'success': True,
        'message': f"Monitoring {verb} for {platform_id}",
    }
|
||||
|
||||
|
||||
@router.put("/scrapers/platform-credentials/monitoring")
@limiter.limit("30/minute")
@handle_exceptions
async def toggle_global_monitoring(
    request: Request,
    current_user: Dict = Depends(require_admin)
):
    """Toggle global cookie health monitoring."""
    payload = await request.json()
    enabled = payload.get('enabled', True)

    # The global switch lives under the special 'global' key.
    state = get_app_state()
    state.settings.set(
        key="cookie_monitoring:global",
        value=str(enabled).lower(),
        category="cookie_monitoring",
        updated_by=current_user.get('username', 'user')
    )

    verb = 'enabled' if enabled else 'disabled'
    return {
        'success': True,
        'message': f"Global cookie monitoring {verb}",
    }
|
||||
|
||||
|
||||
@router.get("/scrapers/{scraper_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def get_scraper(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Get a single scraper configuration."""
    state = get_app_state()
    record = state.db.get_scraper(scraper_id)
    if not record:
        raise NotFoundError(f"Scraper '{scraper_id}' not found")

    # Drop the raw cookie payload; expose only a count instead.
    record.pop('cookies_json', None)
    stored = state.db.get_scraper_cookies(scraper_id)
    record['cookies_count'] = len(stored) if stored else 0

    return record
|
||||
|
||||
|
||||
@router.put("/scrapers/{scraper_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def update_scraper(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(require_admin)
):
    """Update scraper settings (proxy, enabled, base_url)."""
    state = get_app_state()
    payload = await request.json()

    # The scraper must already exist; updates never create rows.
    if not state.db.get_scraper(scraper_id):
        raise NotFoundError(f"Scraper '{scraper_id}' not found")

    # The DB layer filters the payload down to updatable columns.
    if not state.db.update_scraper(scraper_id, payload):
        raise ValidationError("No valid fields to update")

    return {"success": True, "message": f"Scraper '{scraper_id}' updated"}
|
||||
|
||||
|
||||
@router.post("/scrapers/{scraper_id}/test")
@limiter.limit("10/minute")
@handle_exceptions
async def test_scraper_connection(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(require_admin)
):
    """
    Test scraper connection via FlareSolverr (if required).
    On success, saves cookies to database.
    For CLI tools (yt-dlp, gallery-dl), tests that the tool is installed and working.

    Every outcome also records a test status ('success'/'failed'/'timeout')
    on the scraper row via update_scraper_test_status.

    Raises:
        NotFoundError: If the scraper does not exist.
        ValidationError: If a non-CLI scraper has no base_url configured.
    """
    import subprocess
    from modules.cloudflare_handler import CloudflareHandler

    app_state = get_app_state()
    scraper = app_state.db.get_scraper(scraper_id)
    if not scraper:
        raise NotFoundError(f"Scraper '{scraper_id}' not found")

    # Handle CLI tools specially - test that they're installed and working
    if scraper.get('type') == 'cli_tool':
        # Known CLI tools and the exact command that proves each one works.
        cli_tests = {
            'ytdlp': {
                'cmd': ['/opt/media-downloader/venv/bin/yt-dlp', '--version'],
                'name': 'yt-dlp'
            },
            'gallerydl': {
                'cmd': ['/opt/media-downloader/venv/bin/gallery-dl', '--version'],
                'name': 'gallery-dl'
            }
        }

        test_config = cli_tests.get(scraper_id)
        if test_config:
            try:
                result = subprocess.run(
                    test_config['cmd'],
                    capture_output=True,
                    text=True,
                    timeout=10
                )
                if result.returncode == 0:
                    # First output line is the version string.
                    version = result.stdout.strip().split('\n')[0]
                    cookies_count = 0
                    # Check if cookies are configured
                    if scraper.get('cookies_json'):
                        try:
                            import json
                            data = json.loads(scraper['cookies_json'])
                            # Support both {"cookies": [...]} and [...] formats
                            if isinstance(data, dict) and 'cookies' in data:
                                cookies = data['cookies']
                            elif isinstance(data, list):
                                cookies = data
                            else:
                                cookies = []
                            cookies_count = len(cookies) if cookies else 0
                        except (json.JSONDecodeError, TypeError, KeyError) as e:
                            # Unparseable cookies only affect the message, not the result.
                            logger.debug(f"Failed to parse cookies for {scraper_id}: {e}")

                    app_state.db.update_scraper_test_status(scraper_id, 'success')
                    msg = f"{test_config['name']} v{version} installed"
                    if cookies_count > 0:
                        msg += f", {cookies_count} cookies configured"
                    return {
                        "success": True,
                        "message": msg
                    }
                else:
                    error_msg = result.stderr.strip() or "Command failed"
                    app_state.db.update_scraper_test_status(scraper_id, 'failed', error_msg)
                    return {
                        "success": False,
                        "message": f"{test_config['name']} error: {error_msg}"
                    }
            except subprocess.TimeoutExpired:
                app_state.db.update_scraper_test_status(scraper_id, 'failed', "Command timed out")
                return {"success": False, "message": "Command timed out"}
            except FileNotFoundError:
                # Binary missing at the expected venv path.
                app_state.db.update_scraper_test_status(scraper_id, 'failed', "Tool not installed")
                return {"success": False, "message": f"{test_config['name']} not installed"}
        else:
            # Unknown CLI tool
            app_state.db.update_scraper_test_status(scraper_id, 'success')
            return {"success": True, "message": "CLI tool registered"}

    # Non-CLI scrapers need a URL to test against.
    base_url = scraper.get('base_url')
    if not base_url:
        raise ValidationError(f"Scraper '{scraper_id}' has no base_url configured")

    # Use the scraper's proxy only when it is both enabled and configured.
    proxy_url = None
    if scraper.get('proxy_enabled') and scraper.get('proxy_url'):
        proxy_url = scraper['proxy_url']

    cf_handler = CloudflareHandler(
        module_name=scraper_id,
        cookie_file=None,
        proxy_url=proxy_url if proxy_url else None,
        flaresolverr_enabled=scraper.get('flaresolverr_required', False)
    )

    if scraper.get('flaresolverr_required'):
        # Cloudflare-protected site: ask FlareSolverr to solve the challenge.
        success = cf_handler.get_cookies_via_flaresolverr(base_url, max_retries=2)

        if success:
            cookies = cf_handler.get_cookies_list()
            user_agent = cf_handler.get_user_agent()

            # Persist the freshly solved cookies for later scraping runs.
            app_state.db.save_scraper_cookies(scraper_id, cookies, user_agent=user_agent)
            app_state.db.update_scraper_test_status(scraper_id, 'success')

            return {
                "success": True,
                "message": f"Connection successful, {len(cookies)} cookies saved",
                "cookies_count": len(cookies)
            }
        else:
            error_msg = "FlareSolverr returned no cookies"
            if proxy_url:
                error_msg += " (check proxy connection)"
            app_state.db.update_scraper_test_status(scraper_id, 'failed', error_msg)
            return {
                "success": False,
                "message": error_msg
            }
    else:
        # Plain HTTP reachability check (no challenge solving).
        try:
            proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None
            response = requests.get(
                base_url,
                timeout=10,
                proxies=proxies,
                headers={'User-Agent': cf_handler.user_agent}
            )

            # Any non-error status (< 400) counts as reachable.
            if response.status_code < 400:
                app_state.db.update_scraper_test_status(scraper_id, 'success')
                return {
                    "success": True,
                    "message": f"Connection successful (HTTP {response.status_code})"
                }
            else:
                app_state.db.update_scraper_test_status(
                    scraper_id, 'failed',
                    f"HTTP {response.status_code}"
                )
                return {
                    "success": False,
                    "message": f"Connection failed with HTTP {response.status_code}"
                }
        except requests.exceptions.Timeout:
            app_state.db.update_scraper_test_status(scraper_id, 'timeout', 'Request timed out')
            return {"success": False, "message": "Connection timed out"}
        except Exception as e:
            app_state.db.update_scraper_test_status(scraper_id, 'failed', str(e))
            return {"success": False, "message": str(e)}
|
||||
|
||||
|
||||
@router.post("/scrapers/{scraper_id}/cookies")
|
||||
@limiter.limit("20/minute")
|
||||
@handle_exceptions
|
||||
async def upload_scraper_cookies(
|
||||
request: Request,
|
||||
scraper_id: str,
|
||||
current_user: Dict = Depends(require_admin)
|
||||
):
|
||||
"""Upload cookies for a scraper (from browser extension export)."""
|
||||
app_state = get_app_state()
|
||||
|
||||
scraper = app_state.db.get_scraper(scraper_id)
|
||||
if not scraper:
|
||||
raise NotFoundError(f"Scraper '{scraper_id}' not found")
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Support both {cookies: [...]} and bare [...] formats
|
||||
if isinstance(body, list):
|
||||
cookies = body
|
||||
merge = True
|
||||
user_agent = None
|
||||
else:
|
||||
cookies = body.get('cookies', [])
|
||||
merge = body.get('merge', True)
|
||||
user_agent = body.get('user_agent')
|
||||
|
||||
if not cookies or not isinstance(cookies, list):
|
||||
raise ValidationError("Invalid cookies format. Expected {cookies: [...]}")
|
||||
|
||||
for i, cookie in enumerate(cookies):
|
||||
if not isinstance(cookie, dict):
|
||||
raise ValidationError(f"Cookie {i} is not an object")
|
||||
if 'name' not in cookie or 'value' not in cookie:
|
||||
raise ValidationError(f"Cookie {i} missing 'name' or 'value'")
|
||||
|
||||
success = app_state.db.save_scraper_cookies(
|
||||
scraper_id, cookies,
|
||||
user_agent=user_agent,
|
||||
merge=merge
|
||||
)
|
||||
|
||||
if success:
|
||||
all_cookies = app_state.db.get_scraper_cookies(scraper_id)
|
||||
count = len(all_cookies) if all_cookies else 0
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"{'Merged' if merge else 'Replaced'} {len(cookies)} cookies (total: {count})",
|
||||
"cookies_count": count
|
||||
}
|
||||
else:
|
||||
raise ValidationError("Failed to save cookies")
|
||||
|
||||
|
||||
@router.delete("/scrapers/{scraper_id}/cookies")
|
||||
@limiter.limit("20/minute")
|
||||
@handle_exceptions
|
||||
async def clear_scraper_cookies(
|
||||
request: Request,
|
||||
scraper_id: str,
|
||||
current_user: Dict = Depends(require_admin)
|
||||
):
|
||||
"""Clear all cookies for a scraper."""
|
||||
app_state = get_app_state()
|
||||
|
||||
scraper = app_state.db.get_scraper(scraper_id)
|
||||
if not scraper:
|
||||
raise NotFoundError(f"Scraper '{scraper_id}' not found")
|
||||
|
||||
success = app_state.db.clear_scraper_cookies(scraper_id)
|
||||
|
||||
return {
|
||||
"success": success,
|
||||
"message": f"Cookies cleared for '{scraper_id}'" if success else "Failed to clear cookies"
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ERROR MONITORING ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/errors/recent")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_recent_errors(
|
||||
request: Request,
|
||||
limit: int = Query(50, ge=1, le=500, description="Maximum number of errors to return"),
|
||||
since_visit: bool = Query(False, description="Only show errors since last dashboard visit (default: show ALL unviewed)"),
|
||||
include_dismissed: bool = Query(False, description="Include dismissed errors"),
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get recent errors from database.
|
||||
|
||||
By default, shows ALL unviewed/undismissed errors regardless of when they occurred.
|
||||
This ensures errors are not missed just because the user visited the dashboard.
|
||||
Errors are recorded in real-time by universal_logger.py.
|
||||
"""
|
||||
app_state = get_app_state()
|
||||
|
||||
# By default, show ALL unviewed errors (since=None)
|
||||
# Only filter by visit time if explicitly requested
|
||||
since = None
|
||||
if since_visit:
|
||||
since = app_state.db.get_last_dashboard_visit()
|
||||
if not since:
|
||||
since = datetime.now() - timedelta(hours=24)
|
||||
|
||||
errors = app_state.db.get_recent_errors(since=since, include_dismissed=include_dismissed, limit=limit)
|
||||
|
||||
return {
|
||||
"errors": errors,
|
||||
"total_count": len(errors),
|
||||
"since": since.isoformat() if since else None,
|
||||
"unviewed_count": app_state.db.get_unviewed_error_count(since=None) # Always count ALL unviewed
|
||||
}
|
||||
|
||||
|
||||
@router.get("/errors/count")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_error_count(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get count of ALL unviewed/undismissed errors.
|
||||
|
||||
Errors are recorded in real-time by universal_logger.py.
|
||||
"""
|
||||
app_state = get_app_state()
|
||||
|
||||
# Count ALL unviewed errors
|
||||
total_unviewed = app_state.db.get_unviewed_error_count(since=None)
|
||||
|
||||
# Count errors since last dashboard visit
|
||||
last_visit = app_state.db.get_last_dashboard_visit()
|
||||
since_last_visit = app_state.db.get_unviewed_error_count(since=last_visit) if last_visit else total_unviewed
|
||||
|
||||
return {
|
||||
"unviewed_count": total_unviewed,
|
||||
"total_recent": total_unviewed,
|
||||
"since_last_visit": since_last_visit
|
||||
}
|
||||
|
||||
|
||||
@router.post("/errors/dismiss")
|
||||
@limiter.limit("20/minute")
|
||||
@handle_exceptions
|
||||
async def dismiss_errors(
|
||||
request: Request,
|
||||
body: Dict = Body(...),
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Dismiss errors by ID or all."""
|
||||
app_state = get_app_state()
|
||||
|
||||
error_ids = body.get("error_ids", [])
|
||||
dismiss_all = body.get("dismiss_all", False)
|
||||
|
||||
if dismiss_all:
|
||||
dismissed = app_state.db.dismiss_errors(dismiss_all=True)
|
||||
elif error_ids:
|
||||
dismissed = app_state.db.dismiss_errors(error_ids=error_ids)
|
||||
else:
|
||||
return {"success": False, "dismissed": 0, "message": "No errors specified"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"dismissed": dismissed,
|
||||
"message": f"Dismissed {dismissed} error(s)"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/errors/mark-viewed")
|
||||
@limiter.limit("20/minute")
|
||||
@handle_exceptions
|
||||
async def mark_errors_viewed(
|
||||
request: Request,
|
||||
body: Dict = Body(...),
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Mark errors as viewed."""
|
||||
app_state = get_app_state()
|
||||
|
||||
error_ids = body.get("error_ids", [])
|
||||
mark_all = body.get("mark_all", False)
|
||||
|
||||
if mark_all:
|
||||
marked = app_state.db.mark_errors_viewed(mark_all=True)
|
||||
elif error_ids:
|
||||
marked = app_state.db.mark_errors_viewed(error_ids=error_ids)
|
||||
else:
|
||||
return {"success": False, "marked": 0}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"marked": marked
|
||||
}
|
||||
|
||||
|
||||
@router.post("/errors/update-visit")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_dashboard_visit(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update the last dashboard visit timestamp."""
|
||||
app_state = get_app_state()
|
||||
success = app_state.db.update_dashboard_visit()
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.get("/logs/context")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_log_context(
|
||||
request: Request,
|
||||
timestamp: str = Query(..., description="ISO timestamp of the error"),
|
||||
module: Optional[str] = Query(None, description="Module name to filter"),
|
||||
minutes_before: int = Query(1, description="Minutes of context before error"),
|
||||
minutes_after: int = Query(1, description="Minutes of context after error"),
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get log lines around a specific timestamp for debugging context."""
|
||||
target_time = datetime.fromisoformat(timestamp)
|
||||
start_time = target_time - timedelta(minutes=minutes_before)
|
||||
end_time = target_time + timedelta(minutes=minutes_after)
|
||||
|
||||
log_dir = Path('/opt/media-downloader/logs')
|
||||
log_pattern = re.compile(
|
||||
r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) '
|
||||
r'\[MediaDownloader\.(\w+)\] '
|
||||
r'\[(\w+)\] '
|
||||
r'\[(\w+)\] '
|
||||
r'(.+)$'
|
||||
)
|
||||
|
||||
date_str = target_time.strftime('%Y%m%d')
|
||||
matching_lines = []
|
||||
|
||||
for log_file in log_dir.glob(f'{date_str}_*.log'):
|
||||
if module and module.lower() not in log_file.stem.lower():
|
||||
continue
|
||||
|
||||
try:
|
||||
lines = log_file.read_text(errors='replace').splitlines()
|
||||
|
||||
for line in lines:
|
||||
match = log_pattern.match(line)
|
||||
if match:
|
||||
timestamp_str, _, log_module, level, message = match.groups()
|
||||
try:
|
||||
line_time = datetime.strptime(timestamp_str, '%Y-%m-%d %H:%M:%S')
|
||||
if start_time <= line_time <= end_time:
|
||||
matching_lines.append({
|
||||
'timestamp': timestamp_str,
|
||||
'module': log_module,
|
||||
'level': level,
|
||||
'message': message,
|
||||
'is_target': abs((line_time - target_time).total_seconds()) < 2
|
||||
})
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
matching_lines.sort(key=lambda x: x['timestamp'])
|
||||
|
||||
return {
|
||||
"context": matching_lines,
|
||||
"target_timestamp": timestamp,
|
||||
"range": {
|
||||
"start": start_time.isoformat(),
|
||||
"end": end_time.isoformat()
|
||||
}
|
||||
}
|
||||
366
web/backend/routers/semantic.py
Normal file
366
web/backend/routers/semantic.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Semantic Search Router
|
||||
|
||||
Handles CLIP-based semantic search operations:
|
||||
- Text-based image/video search
|
||||
- Similar file search
|
||||
- Embedding generation and management
|
||||
- Model settings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, ValidationError
|
||||
from modules.semantic_search import get_semantic_search
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api/semantic", tags=["Semantic Search"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Batch limit for embedding generation
|
||||
EMBEDDING_BATCH_LIMIT = 10000
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class SemanticSearchRequest(BaseModel):
    """Request body for text-based semantic search.

    NOTE(review): the /search endpoint below reads its parameters via
    ``Body(...)`` rather than this model — presumably kept for API clients
    or docs; confirm it is still needed.
    """
    query: str
    limit: int = 50
    platform: Optional[str] = None
    source: Optional[str] = None
    threshold: float = 0.2
|
||||
|
||||
|
||||
class GenerateEmbeddingsRequest(BaseModel):
    """Request body for batch embedding generation.

    NOTE(review): the /generate endpoint reads its parameters via
    ``Body(...)`` rather than this model — confirm it is still needed.
    """
    limit: int = 100
    platform: Optional[str] = None
|
||||
|
||||
|
||||
class SemanticSettingsUpdate(BaseModel):
    """Request body for updating semantic search settings (model/threshold).

    NOTE(review): the /settings POST endpoint reads its parameters via
    ``Body(...)`` rather than this model — confirm it is still needed.
    """
    model: Optional[str] = None
    threshold: Optional[float] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SEARCH ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/stats")
|
||||
@limiter.limit("120/minute")
|
||||
@handle_exceptions
|
||||
async def get_semantic_stats(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get statistics about semantic search embeddings."""
|
||||
app_state = get_app_state()
|
||||
search = get_semantic_search(app_state.db)
|
||||
stats = search.get_embedding_stats()
|
||||
return stats
|
||||
|
||||
|
||||
@router.post("/search")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def semantic_search(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
query: str = Body(..., embed=True),
|
||||
limit: int = Body(50, ge=1, le=200),
|
||||
platform: Optional[str] = Body(None),
|
||||
source: Optional[str] = Body(None),
|
||||
threshold: float = Body(0.2, ge=0.0, le=1.0)
|
||||
):
|
||||
"""Search for images/videos using natural language query."""
|
||||
app_state = get_app_state()
|
||||
search = get_semantic_search(app_state.db)
|
||||
results = search.search_by_text(
|
||||
query=query,
|
||||
limit=limit,
|
||||
platform=platform,
|
||||
source=source,
|
||||
threshold=threshold
|
||||
)
|
||||
return {"results": results, "count": len(results), "query": query}
|
||||
|
||||
|
||||
@router.post("/similar/{file_id}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def find_similar_files(
|
||||
request: Request,
|
||||
file_id: int,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
platform: Optional[str] = Query(None),
|
||||
source: Optional[str] = Query(None),
|
||||
threshold: float = Query(0.5, ge=0.0, le=1.0)
|
||||
):
|
||||
"""Find files similar to a given file."""
|
||||
app_state = get_app_state()
|
||||
search = get_semantic_search(app_state.db)
|
||||
results = search.search_by_file_id(
|
||||
file_id=file_id,
|
||||
limit=limit,
|
||||
platform=platform,
|
||||
source=source,
|
||||
threshold=threshold
|
||||
)
|
||||
return {"results": results, "count": len(results), "source_file_id": file_id}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EMBEDDING GENERATION ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/generate")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def generate_embeddings(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict = Depends(get_current_user),
|
||||
limit: int = Body(100, ge=1, le=1000),
|
||||
platform: Optional[str] = Body(None)
|
||||
):
|
||||
"""Generate CLIP embeddings for files that don't have them yet."""
|
||||
app_state = get_app_state()
|
||||
|
||||
if app_state.indexing_running:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Indexing already in progress",
|
||||
"already_running": True,
|
||||
"status": "already_running"
|
||||
}
|
||||
|
||||
search = get_semantic_search(app_state.db)
|
||||
|
||||
app_state.indexing_running = True
|
||||
app_state.indexing_start_time = time.time()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
manager = getattr(app_state, 'websocket_manager', None)
|
||||
|
||||
def progress_callback(processed: int, total: int, current_file: str):
|
||||
try:
|
||||
if processed % 10 == 0 or processed == total:
|
||||
stats = search.get_embedding_stats()
|
||||
if manager:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
manager.broadcast({
|
||||
"type": "embedding_progress",
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
"current_file": current_file,
|
||||
"total_embeddings": stats.get('total_embeddings', 0),
|
||||
"coverage_percent": stats.get('coverage_percent', 0)
|
||||
}),
|
||||
loop
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def run_generation():
|
||||
try:
|
||||
results = search.generate_embeddings_batch(
|
||||
limit=limit,
|
||||
platform=platform,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
logger.info(f"Embedding generation complete: {results}", module="SemanticSearch")
|
||||
try:
|
||||
stats = search.get_embedding_stats()
|
||||
if manager:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
manager.broadcast({
|
||||
"type": "embedding_complete",
|
||||
"results": results,
|
||||
"total_embeddings": stats.get('total_embeddings', 0),
|
||||
"coverage_percent": stats.get('coverage_percent', 0)
|
||||
}),
|
||||
loop
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding generation failed: {e}", module="SemanticSearch")
|
||||
finally:
|
||||
app_state.indexing_running = False
|
||||
app_state.indexing_start_time = None
|
||||
|
||||
background_tasks.add_task(run_generation)
|
||||
|
||||
return {
|
||||
"message": f"Started embedding generation for up to {limit} files",
|
||||
"status": "processing"
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_semantic_indexing_status(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Check if semantic indexing is currently running."""
|
||||
app_state = get_app_state()
|
||||
elapsed = None
|
||||
if app_state.indexing_running and app_state.indexing_start_time:
|
||||
elapsed = int(time.time() - app_state.indexing_start_time)
|
||||
return {
|
||||
"indexing_running": app_state.indexing_running,
|
||||
"elapsed_seconds": elapsed
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SETTINGS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def get_semantic_settings(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get semantic search settings."""
|
||||
app_state = get_app_state()
|
||||
settings = app_state.settings.get('semantic_search', {})
|
||||
return {
|
||||
"model": settings.get('model', 'clip-ViT-B-32'),
|
||||
"threshold": settings.get('threshold', 0.2)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/settings")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def update_semantic_settings(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict = Depends(require_admin),
|
||||
model: str = Body(None, embed=True),
|
||||
threshold: float = Body(None, embed=True)
|
||||
):
|
||||
"""Update semantic search settings."""
|
||||
app_state = get_app_state()
|
||||
|
||||
current_settings = app_state.settings.get('semantic_search', {}) or {}
|
||||
old_model = current_settings.get('model', 'clip-ViT-B-32') if isinstance(current_settings, dict) else 'clip-ViT-B-32'
|
||||
model_changed = False
|
||||
|
||||
new_settings = dict(current_settings) if isinstance(current_settings, dict) else {}
|
||||
|
||||
if model:
|
||||
valid_models = ['clip-ViT-B-32', 'clip-ViT-B-16', 'clip-ViT-L-14']
|
||||
if model not in valid_models:
|
||||
raise ValidationError(f"Invalid model. Must be one of: {valid_models}")
|
||||
if model != old_model:
|
||||
model_changed = True
|
||||
new_settings['model'] = model
|
||||
|
||||
if threshold is not None:
|
||||
if threshold < 0 or threshold > 1:
|
||||
raise ValidationError("Threshold must be between 0 and 1")
|
||||
new_settings['threshold'] = threshold
|
||||
|
||||
app_state.settings.set('semantic_search', new_settings, category='ai')
|
||||
|
||||
if model_changed:
|
||||
logger.info(f"Model changed from {old_model} to {model}, clearing embeddings", module="SemanticSearch")
|
||||
|
||||
with app_state.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM content_embeddings')
|
||||
deleted = cursor.rowcount
|
||||
logger.info(f"Cleared {deleted} embeddings for model change", module="SemanticSearch")
|
||||
|
||||
def run_reindex_for_model_change():
|
||||
try:
|
||||
search = get_semantic_search(app_state.db, force_reload=True)
|
||||
results = search.generate_embeddings_batch(limit=EMBEDDING_BATCH_LIMIT)
|
||||
logger.info(f"Model change re-index complete: {results}", module="SemanticSearch")
|
||||
except Exception as e:
|
||||
logger.error(f"Model change re-index failed: {e}", module="SemanticSearch")
|
||||
|
||||
background_tasks.add_task(run_reindex_for_model_change)
|
||||
|
||||
logger.info(f"Semantic search settings updated: {new_settings} (re-indexing started)", module="SemanticSearch")
|
||||
return {"success": True, "settings": new_settings, "reindexing": True, "message": f"Model changed to {model}, re-indexing started"}
|
||||
|
||||
logger.info(f"Semantic search settings updated: {new_settings}", module="SemanticSearch")
|
||||
return {"success": True, "settings": new_settings, "reindexing": False}
|
||||
|
||||
|
||||
@router.post("/reindex")
|
||||
@limiter.limit("2/minute")
|
||||
@handle_exceptions
|
||||
async def reindex_embeddings(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict = Depends(require_admin)
|
||||
):
|
||||
"""Clear and regenerate all embeddings."""
|
||||
app_state = get_app_state()
|
||||
|
||||
if app_state.indexing_running:
|
||||
return {"success": False, "message": "Indexing already in progress", "already_running": True}
|
||||
|
||||
search = get_semantic_search(app_state.db)
|
||||
|
||||
with app_state.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM content_embeddings')
|
||||
deleted = cursor.rowcount
|
||||
logger.info(f"Cleared {deleted} embeddings for reindexing", module="SemanticSearch")
|
||||
|
||||
app_state.indexing_running = True
|
||||
app_state.indexing_start_time = time.time()
|
||||
|
||||
def run_reindex():
|
||||
try:
|
||||
results = search.generate_embeddings_batch(limit=EMBEDDING_BATCH_LIMIT)
|
||||
logger.info(f"Reindex complete: {results}", module="SemanticSearch")
|
||||
except Exception as e:
|
||||
logger.error(f"Reindex failed: {e}", module="SemanticSearch")
|
||||
finally:
|
||||
app_state.indexing_running = False
|
||||
app_state.indexing_start_time = None
|
||||
|
||||
background_tasks.add_task(run_reindex)
|
||||
|
||||
return {"success": True, "message": "Re-indexing started in background"}
|
||||
|
||||
|
||||
@router.post("/clear")
|
||||
@limiter.limit("2/minute")
|
||||
@handle_exceptions
|
||||
async def clear_embeddings(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(require_admin)
|
||||
):
|
||||
"""Clear all embeddings."""
|
||||
app_state = get_app_state()
|
||||
|
||||
with app_state.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM content_embeddings')
|
||||
deleted = cursor.rowcount
|
||||
|
||||
logger.info(f"Cleared {deleted} embeddings", module="SemanticSearch")
|
||||
return {"success": True, "deleted": deleted}
|
||||
535
web/backend/routers/stats.py
Normal file
535
web/backend/routers/stats.py
Normal file
@@ -0,0 +1,535 @@
|
||||
"""
|
||||
Stats Router
|
||||
|
||||
Handles statistics, monitoring, settings, and integrations:
|
||||
- Dashboard statistics
|
||||
- Downloader monitoring
|
||||
- Settings management
|
||||
- Immich integration
|
||||
"""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('API')
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["Stats & Monitoring"])
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class SettingUpdate(BaseModel):
    """Request body for updating a single settings entry."""
    # The new value; any JSON-serializable shape is accepted.
    value: dict | list | str | int | float | bool
    # Optional grouping category for the setting.
    category: Optional[str] = None
    # Optional human-readable description of the setting.
    description: Optional[str] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DASHBOARD STATISTICS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/stats/dashboard")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_dashboard_stats(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get comprehensive dashboard statistics."""
|
||||
app_state = get_app_state()
|
||||
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get download counts per platform (combine downloads and video_downloads)
|
||||
cursor.execute("""
|
||||
SELECT platform, SUM(cnt) as count FROM (
|
||||
SELECT platform, COUNT(*) as cnt FROM downloads GROUP BY platform
|
||||
UNION ALL
|
||||
SELECT platform, COUNT(*) as cnt FROM video_downloads GROUP BY platform
|
||||
) GROUP BY platform
|
||||
""")
|
||||
|
||||
platform_data = {}
|
||||
for row in cursor.fetchall():
|
||||
platform = row[0]
|
||||
if platform not in platform_data:
|
||||
platform_data[platform] = {
|
||||
'count': 0,
|
||||
'size_bytes': 0
|
||||
}
|
||||
platform_data[platform]['count'] += row[1]
|
||||
|
||||
# Calculate storage sizes from file_inventory (final + review)
|
||||
cursor.execute("""
|
||||
SELECT platform, COALESCE(SUM(file_size), 0) as total_size
|
||||
FROM file_inventory
|
||||
WHERE location IN ('final', 'review')
|
||||
GROUP BY platform
|
||||
""")
|
||||
|
||||
for row in cursor.fetchall():
|
||||
platform = row[0]
|
||||
if platform not in platform_data:
|
||||
platform_data[platform] = {'count': 0, 'size_bytes': 0}
|
||||
platform_data[platform]['size_bytes'] += row[1]
|
||||
|
||||
# Only show platforms with actual files
|
||||
storage_by_platform = []
|
||||
for platform in sorted(platform_data.keys(), key=lambda p: platform_data[p]['size_bytes'], reverse=True):
|
||||
if platform_data[platform]['size_bytes'] > 0:
|
||||
storage_by_platform.append({
|
||||
'platform': platform,
|
||||
'count': platform_data[platform]['count'],
|
||||
'size_bytes': platform_data[platform]['size_bytes'],
|
||||
'size_mb': round(platform_data[platform]['size_bytes'] / 1024 / 1024, 2)
|
||||
})
|
||||
|
||||
# Downloads per day (last 30 days) - combine downloads and video_downloads
|
||||
cursor.execute("""
|
||||
SELECT date, SUM(count) as count FROM (
|
||||
SELECT DATE(download_date) as date, COUNT(*) as count
|
||||
FROM downloads
|
||||
WHERE download_date >= DATE('now', '-30 days')
|
||||
GROUP BY DATE(download_date)
|
||||
UNION ALL
|
||||
SELECT DATE(download_date) as date, COUNT(*) as count
|
||||
FROM video_downloads
|
||||
WHERE download_date >= DATE('now', '-30 days')
|
||||
GROUP BY DATE(download_date)
|
||||
) GROUP BY date ORDER BY date
|
||||
""")
|
||||
downloads_per_day = [{'date': row[0], 'count': row[1]} for row in cursor.fetchall()]
|
||||
|
||||
# Content type breakdown
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
content_type,
|
||||
COUNT(*) as count
|
||||
FROM downloads
|
||||
WHERE content_type IS NOT NULL
|
||||
GROUP BY content_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
content_types = {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
# Top sources
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
source,
|
||||
platform,
|
||||
COUNT(*) as count
|
||||
FROM downloads
|
||||
WHERE source IS NOT NULL
|
||||
GROUP BY source, platform
|
||||
ORDER BY count DESC
|
||||
LIMIT 10
|
||||
""")
|
||||
top_sources = [{'source': row[0], 'platform': row[1], 'count': row[2]} for row in cursor.fetchall()]
|
||||
|
||||
# Total statistics - use file_inventory for accurate file counts
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
(SELECT COUNT(*) FROM file_inventory WHERE location IN ('final', 'review')) as total_downloads,
|
||||
(SELECT COALESCE(SUM(file_size), 0) FROM file_inventory WHERE location IN ('final', 'review')) as total_size,
|
||||
(SELECT COUNT(DISTINCT source) FROM downloads) +
|
||||
(SELECT COUNT(DISTINCT uploader) FROM video_downloads) as unique_sources,
|
||||
(SELECT COUNT(DISTINCT platform) FROM file_inventory) as platforms_used
|
||||
""")
|
||||
totals = cursor.fetchone()
|
||||
|
||||
# Get recycle bin and review counts separately
|
||||
cursor.execute("SELECT COUNT(*) FROM recycle_bin")
|
||||
recycle_count = cursor.fetchone()[0] or 0
|
||||
|
||||
cursor.execute("SELECT COUNT(*) FROM file_inventory WHERE location = 'review'")
|
||||
review_count = cursor.fetchone()[0] or 0
|
||||
|
||||
# Growth rate - combine downloads and video_downloads
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
(SELECT SUM(CASE WHEN download_date >= DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM downloads) +
|
||||
(SELECT SUM(CASE WHEN download_date >= DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM video_downloads) as this_week,
|
||||
(SELECT SUM(CASE WHEN download_date >= DATE('now', '-14 days') AND download_date < DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM downloads) +
|
||||
(SELECT SUM(CASE WHEN download_date >= DATE('now', '-14 days') AND download_date < DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM video_downloads) as last_week
|
||||
""")
|
||||
growth_row = cursor.fetchone()
|
||||
growth_rate = 0
|
||||
if growth_row and growth_row[1] > 0:
|
||||
growth_rate = round(((growth_row[0] - growth_row[1]) / growth_row[1]) * 100, 1)
|
||||
|
||||
return {
|
||||
'storage_by_platform': storage_by_platform,
|
||||
'downloads_per_day': downloads_per_day,
|
||||
'content_types': content_types,
|
||||
'top_sources': top_sources,
|
||||
'totals': {
|
||||
'total_downloads': totals[0] or 0,
|
||||
'total_size_bytes': totals[1] or 0,
|
||||
'total_size_gb': round((totals[1] or 0) / 1024 / 1024 / 1024, 2),
|
||||
'unique_sources': totals[2] or 0,
|
||||
'platforms_used': totals[3] or 0,
|
||||
'recycle_bin_count': recycle_count,
|
||||
'review_count': review_count
|
||||
},
|
||||
'growth_rate': growth_rate
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FLARESOLVERR HEALTH CHECK
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/health/flaresolverr")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def check_flaresolverr_health(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Check FlareSolverr health status."""
|
||||
app_state = get_app_state()
|
||||
|
||||
flaresolverr_url = "http://localhost:8191/v1"
|
||||
|
||||
try:
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT value FROM settings WHERE key='flaresolverr'")
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
flaresolverr_config = json.loads(result[0])
|
||||
if 'url' in flaresolverr_config:
|
||||
flaresolverr_url = flaresolverr_config['url']
|
||||
except (sqlite3.Error, json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
flaresolverr_url,
|
||||
json={"cmd": "sessions.list"},
|
||||
timeout=5
|
||||
)
|
||||
response_time = round((time.time() - start_time) * 1000, 2)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
'status': 'healthy',
|
||||
'url': flaresolverr_url,
|
||||
'response_time_ms': response_time,
|
||||
'last_check': time.time(),
|
||||
'sessions': response.json().get('sessions', [])
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'status': 'unhealthy',
|
||||
'url': flaresolverr_url,
|
||||
'response_time_ms': response_time,
|
||||
'last_check': time.time(),
|
||||
'error': f"HTTP {response.status_code}: {response.text}"
|
||||
}
|
||||
except requests.exceptions.ConnectionError:
|
||||
return {
|
||||
'status': 'offline',
|
||||
'url': flaresolverr_url,
|
||||
'last_check': time.time(),
|
||||
'error': 'Connection refused - FlareSolverr may not be running'
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
return {
|
||||
'status': 'timeout',
|
||||
'url': flaresolverr_url,
|
||||
'last_check': time.time(),
|
||||
'error': 'Request timed out after 5 seconds'
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'status': 'error',
|
||||
'url': flaresolverr_url,
|
||||
'last_check': time.time(),
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MONITORING ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/monitoring/status")
|
||||
@limiter.limit("100/minute")
|
||||
@handle_exceptions
|
||||
async def get_monitoring_status(
|
||||
request: Request,
|
||||
hours: int = 24,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get downloader monitoring status."""
|
||||
from modules.downloader_monitor import get_monitor
|
||||
|
||||
app_state = get_app_state()
|
||||
monitor = get_monitor(app_state.db, app_state.settings)
|
||||
status = monitor.get_downloader_status(hours=hours)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"downloaders": status,
|
||||
"window_hours": hours
|
||||
}
|
||||
|
||||
|
||||
@router.get("/monitoring/history")
|
||||
@limiter.limit("100/minute")
|
||||
@handle_exceptions
|
||||
async def get_monitoring_history(
|
||||
request: Request,
|
||||
downloader: str = None,
|
||||
limit: int = 100,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get download monitoring history."""
|
||||
app_state = get_app_state()
|
||||
|
||||
with app_state.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if downloader:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
id, downloader, username, timestamp, success,
|
||||
file_count, error_message, alert_sent
|
||||
FROM download_monitor
|
||||
WHERE downloader = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""", (downloader, limit))
|
||||
else:
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
id, downloader, username, timestamp, success,
|
||||
file_count, error_message, alert_sent
|
||||
FROM download_monitor
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""", (limit,))
|
||||
|
||||
history = []
|
||||
for row in cursor.fetchall():
|
||||
history.append({
|
||||
'id': row['id'],
|
||||
'downloader': row['downloader'],
|
||||
'username': row['username'],
|
||||
'timestamp': row['timestamp'],
|
||||
'success': bool(row['success']),
|
||||
'file_count': row['file_count'],
|
||||
'error_message': row['error_message'],
|
||||
'alert_sent': bool(row['alert_sent'])
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"history": history
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/monitoring/history")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def clear_monitoring_history(
|
||||
request: Request,
|
||||
days: int = 30,
|
||||
current_user: Dict = Depends(require_admin)
|
||||
):
|
||||
"""Clear old monitoring logs."""
|
||||
from modules.downloader_monitor import get_monitor
|
||||
|
||||
app_state = get_app_state()
|
||||
monitor = get_monitor(app_state.db, app_state.settings)
|
||||
monitor.clear_old_logs(days=days)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Cleared logs older than {days} days"
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SETTINGS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings/{key}")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_setting(
|
||||
request: Request,
|
||||
key: str,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get a specific setting value."""
|
||||
app_state = get_app_state()
|
||||
value = app_state.settings.get(key)
|
||||
if value is None:
|
||||
raise NotFoundError(f"Setting '{key}' not found")
|
||||
return value
|
||||
|
||||
|
||||
@router.put("/settings/{key}")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_setting(
|
||||
request: Request,
|
||||
key: str,
|
||||
body: Dict,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update a specific setting value."""
|
||||
app_state = get_app_state()
|
||||
|
||||
value = body.get('value')
|
||||
if value is None:
|
||||
raise ValidationError("Missing 'value' in request body")
|
||||
|
||||
app_state.settings.set(
|
||||
key=key,
|
||||
value=value,
|
||||
category=body.get('category'),
|
||||
description=body.get('description'),
|
||||
updated_by=current_user.get('username', 'user')
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Setting '{key}' updated successfully"
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# IMMICH INTEGRATION
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/immich/scan")
|
||||
@limiter.limit("10/minute")
|
||||
@handle_exceptions
|
||||
async def trigger_immich_scan(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Trigger Immich library scan."""
|
||||
app_state = get_app_state()
|
||||
|
||||
immich_config = app_state.settings.get('immich', {})
|
||||
|
||||
if not immich_config.get('enabled'):
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Immich integration is not enabled"
|
||||
}
|
||||
|
||||
api_url = immich_config.get('api_url')
|
||||
api_key = immich_config.get('api_key')
|
||||
library_id = immich_config.get('library_id')
|
||||
|
||||
if not all([api_url, api_key, library_id]):
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Immich configuration incomplete (missing api_url, api_key, or library_id)"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{api_url}/libraries/{library_id}/scan",
|
||||
headers={'X-API-KEY': api_key},
|
||||
timeout=10
|
||||
)
|
||||
|
||||
if response.status_code in [200, 201, 204]:
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully triggered Immich scan for library {library_id}"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Immich scan request failed with status {response.status_code}: {response.text}"
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Failed to connect to Immich: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ERROR MONITORING SETTINGS
|
||||
# ============================================================================
|
||||
|
||||
class ErrorMonitoringSettings(BaseModel):
    """Schema for the 'error_monitoring' settings blob (see GET/PUT
    /error-monitoring/settings); defaults mirror those used by the GET
    endpoint when no value is stored yet."""
    enabled: bool = True                    # master switch for error monitoring
    push_alert_enabled: bool = True         # whether push alerts are sent
    push_alert_delay_hours: int = 24        # delay before a push alert fires
    dashboard_banner_enabled: bool = True   # show error banner on the dashboard
    retention_days: int = 7                 # how long error records are retained
|
||||
|
||||
|
||||
@router.get("/error-monitoring/settings")
|
||||
@limiter.limit("60/minute")
|
||||
@handle_exceptions
|
||||
async def get_error_monitoring_settings(
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Get error monitoring settings."""
|
||||
app_state = get_app_state()
|
||||
|
||||
settings = app_state.settings.get('error_monitoring', {
|
||||
'enabled': True,
|
||||
'push_alert_enabled': True,
|
||||
'push_alert_delay_hours': 24,
|
||||
'dashboard_banner_enabled': True,
|
||||
'retention_days': 7
|
||||
})
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@router.put("/error-monitoring/settings")
|
||||
@limiter.limit("30/minute")
|
||||
@handle_exceptions
|
||||
async def update_error_monitoring_settings(
|
||||
request: Request,
|
||||
settings: ErrorMonitoringSettings,
|
||||
current_user: Dict = Depends(get_current_user)
|
||||
):
|
||||
"""Update error monitoring settings."""
|
||||
app_state = get_app_state()
|
||||
|
||||
app_state.settings.set(
|
||||
key='error_monitoring',
|
||||
value=settings.model_dump(),
|
||||
category='monitoring',
|
||||
description='Error monitoring and alert settings',
|
||||
updated_by=current_user.get('username', 'user')
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Error monitoring settings updated",
|
||||
"settings": settings.model_dump()
|
||||
}
|
||||
1617
web/backend/routers/video.py
Normal file
1617
web/backend/routers/video.py
Normal file
File diff suppressed because it is too large
Load Diff
1884
web/backend/routers/video_queue.py
Normal file
1884
web/backend/routers/video_queue.py
Normal file
File diff suppressed because it is too large
Load Diff
1
web/backend/services/__init__.py
Normal file
1
web/backend/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Services module exports
|
||||
524
web/backend/totp_manager.py
Normal file
524
web/backend/totp_manager.py
Normal file
@@ -0,0 +1,524 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TOTP Manager for Media Downloader
|
||||
Handles Time-based One-Time Password (TOTP) operations for 2FA
|
||||
|
||||
Based on backup-central's implementation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pyotp
|
||||
import qrcode
|
||||
import io
|
||||
import base64
|
||||
import sqlite3
|
||||
import hashlib
|
||||
import bcrypt
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from cryptography.fernet import Fernet
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent path to allow imports from modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('TOTPManager')
|
||||
|
||||
|
||||
class TOTPManager:
|
||||
    def __init__(self, db_path: str = None):
        """Initialize the TOTP manager.

        Loads (or creates) the Fernet key used to encrypt stored TOTP
        secrets, configures the verification rate-limit policy, and makes
        sure the auth database has the required tables/columns.

        Args:
            db_path: path to the auth SQLite database; defaults to
                <repo-root>/database/auth.db.
        """
        if db_path is None:
            db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')

        self.db_path = db_path
        self.issuer = 'Media Downloader'  # shown in authenticator apps
        self.window = 1  # Allow 1 step (30s) tolerance for time skew

        # Encryption key for TOTP secrets (secrets are never stored plaintext)
        self.encryption_key = self._get_encryption_key()
        self.cipher_suite = Fernet(self.encryption_key)

        # Rate limiting configuration: max_attempts failures within
        # rate_limit_window triggers a lockout_duration lockout.
        self.max_attempts = 5
        self.rate_limit_window = timedelta(minutes=15)
        self.lockout_duration = timedelta(minutes=30)

        # Initialize database tables
        self._init_database()
|
||||
|
||||
def _get_encryption_key(self) -> bytes:
|
||||
"""Get or generate encryption key for TOTP secrets"""
|
||||
key_file = Path(__file__).parent.parent.parent / '.totp_encryption_key'
|
||||
|
||||
if key_file.exists():
|
||||
with open(key_file, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
# Generate new key
|
||||
key = Fernet.generate_key()
|
||||
try:
|
||||
with open(key_file, 'wb') as f:
|
||||
f.write(key)
|
||||
import os
|
||||
os.chmod(key_file, 0o600)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not save TOTP encryption key: {e}", module="TOTP")
|
||||
|
||||
return key
|
||||
|
||||
def _init_database(self):
|
||||
"""Initialize TOTP-related database tables"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Ensure TOTP columns exist in users table
|
||||
try:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN totp_secret TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN totp_enabled INTEGER NOT NULL DEFAULT 0")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN totp_enrolled_at TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
# TOTP rate limiting table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS totp_rate_limit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL,
|
||||
ip_address TEXT,
|
||||
attempts INTEGER NOT NULL DEFAULT 1,
|
||||
window_start TEXT NOT NULL,
|
||||
locked_until TEXT,
|
||||
UNIQUE(username, ip_address)
|
||||
)
|
||||
""")
|
||||
|
||||
# TOTP audit log
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS totp_audit_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
details TEXT,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def generate_secret(self, username: str) -> Dict:
|
||||
"""
|
||||
Generate a new TOTP secret and QR code for a user
|
||||
|
||||
Returns:
|
||||
dict with secret, qrCodeDataURL, manualEntryKey
|
||||
"""
|
||||
try:
|
||||
# Generate secret
|
||||
secret = pyotp.random_base32()
|
||||
|
||||
# Create provisioning URI
|
||||
totp = pyotp.TOTP(secret)
|
||||
provisioning_uri = totp.provisioning_uri(
|
||||
name=username,
|
||||
issuer_name=self.issuer
|
||||
)
|
||||
|
||||
# Generate QR code
|
||||
qr = qrcode.QRCode(version=1, box_size=10, border=4)
|
||||
qr.add_data(provisioning_uri)
|
||||
qr.make(fit=True)
|
||||
|
||||
img = qr.make_image(fill_color="black", back_color="white")
|
||||
|
||||
# Convert to base64 data URL
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode()
|
||||
qr_code_data_url = f"data:image/png;base64,{img_str}"
|
||||
|
||||
self._log_audit(username, 'totp_secret_generated', True, None, None,
|
||||
'TOTP secret generated for user')
|
||||
|
||||
return {
|
||||
'secret': secret,
|
||||
'qrCodeDataURL': qr_code_data_url,
|
||||
'manualEntryKey': secret,
|
||||
'otpauthURL': provisioning_uri
|
||||
}
|
||||
except Exception as error:
|
||||
self._log_audit(username, 'totp_secret_generation_failed', False, None, None,
|
||||
f'Error: {str(error)}')
|
||||
raise
|
||||
|
||||
def verify_token(self, secret: str, token: str) -> bool:
|
||||
"""
|
||||
Verify a TOTP token
|
||||
|
||||
Args:
|
||||
secret: Base32 encoded secret
|
||||
token: 6-digit TOTP code
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(token, valid_window=self.window)
|
||||
except Exception as e:
|
||||
logger.error(f"TOTP verification error: {e}", module="TOTP")
|
||||
return False
|
||||
|
||||
def encrypt_secret(self, secret: str) -> str:
|
||||
"""Encrypt TOTP secret for storage"""
|
||||
encrypted = self.cipher_suite.encrypt(secret.encode())
|
||||
return base64.b64encode(encrypted).decode()
|
||||
|
||||
def decrypt_secret(self, encrypted_secret: str) -> str:
|
||||
"""Decrypt TOTP secret from storage"""
|
||||
encrypted = base64.b64decode(encrypted_secret.encode())
|
||||
decrypted = self.cipher_suite.decrypt(encrypted)
|
||||
return decrypted.decode()
|
||||
|
||||
def generate_backup_codes(self, count: int = 10) -> List[str]:
|
||||
"""
|
||||
Generate backup codes for recovery
|
||||
|
||||
Args:
|
||||
count: Number of codes to generate
|
||||
|
||||
Returns:
|
||||
List of formatted backup codes
|
||||
"""
|
||||
codes = []
|
||||
for _ in range(count):
|
||||
# Generate 8-character hex code
|
||||
code = secrets.token_hex(4).upper()
|
||||
# Format as XXXX-XXXX
|
||||
formatted = f"{code[:4]}-{code[4:8]}"
|
||||
codes.append(formatted)
|
||||
return codes
|
||||
|
||||
    def hash_backup_code(self, code: str) -> str:
        """
        Hash a backup code for storage using bcrypt

        Args:
            code: Backup code to hash

        Returns:
            Bcrypt hash of the code (self-describing `$2b$...` string that
            embeds its own per-code salt, so no salt column is needed)
        """
        # Use bcrypt with default work factor (12 rounds)
        return bcrypt.hashpw(code.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
|
||||
|
||||
def verify_backup_code(self, code: str, username: str) -> Tuple[bool, int]:
|
||||
"""
|
||||
Verify a backup code and mark it as used
|
||||
|
||||
Returns:
|
||||
Tuple of (valid, remaining_codes)
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get unused backup codes for user
|
||||
cursor.execute("""
|
||||
SELECT id, code_hash FROM backup_codes
|
||||
WHERE username = ? AND used = 0
|
||||
""", (username,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Check if code matches any unused code
|
||||
for row_id, stored_hash in rows:
|
||||
# Support both bcrypt (new) and SHA256 (legacy) hashes
|
||||
is_match = False
|
||||
|
||||
if stored_hash.startswith('$2b$'):
|
||||
# Bcrypt hash
|
||||
try:
|
||||
is_match = bcrypt.checkpw(code.encode('utf-8'), stored_hash.encode('utf-8'))
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying bcrypt backup code: {e}", module="TOTP")
|
||||
continue
|
||||
else:
|
||||
# Legacy SHA256 hash
|
||||
legacy_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
is_match = (stored_hash == legacy_hash)
|
||||
|
||||
if is_match:
|
||||
# Mark as used
|
||||
cursor.execute("""
|
||||
UPDATE backup_codes
|
||||
SET used = 1, used_at = ?
|
||||
WHERE id = ?
|
||||
""", (datetime.now().isoformat(), row_id))
|
||||
|
||||
conn.commit()
|
||||
|
||||
# Count remaining codes
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM backup_codes
|
||||
WHERE username = ? AND used = 0
|
||||
""", (username,))
|
||||
|
||||
remaining = cursor.fetchone()[0]
|
||||
return True, remaining
|
||||
|
||||
return False, len(rows)
|
||||
|
||||
    def enable_totp(self, username: str, secret: str, backup_codes: List[str]) -> bool:
        """
        Enable TOTP for a user

        Stores the encrypted secret on the user row, replaces any existing
        backup codes with the supplied set (hashed before storage), and
        audits the change. All DB writes share one connection/commit.

        Args:
            username: Username
            secret: TOTP secret (plaintext; encrypted before storage)
            backup_codes: List of backup codes (plaintext; hashed before storage)

        Returns:
            True if successful, False on any error (error is logged)
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Encrypt secret — never stored in plaintext
                encrypted_secret = self.encrypt_secret(secret)

                # Update user
                cursor.execute("""
                    UPDATE users
                    SET totp_secret = ?, totp_enabled = 1, totp_enrolled_at = ?
                    WHERE username = ?
                """, (encrypted_secret, datetime.now().isoformat(), username))

                # Delete old backup codes (re-enrolment invalidates prior codes)
                cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))

                # Insert new backup codes (stored hashed, one row per code)
                for code in backup_codes:
                    code_hash = self.hash_backup_code(code)
                    cursor.execute("""
                        INSERT INTO backup_codes (username, code_hash, created_at)
                        VALUES (?, ?, ?)
                    """, (username, code_hash, datetime.now().isoformat()))

                conn.commit()

            self._log_audit(username, 'totp_enabled', True, None, None,
                            f'TOTP enabled with {len(backup_codes)} backup codes')

            return True
        except Exception as e:
            logger.error(f"Error enabling TOTP: {e}", module="TOTP")
            return False
|
||||
|
||||
def disable_totp(self, username: str) -> bool:
|
||||
"""
|
||||
Disable TOTP for a user
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET totp_secret = NULL, totp_enabled = 0, totp_enrolled_at = NULL
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
# Delete backup codes
|
||||
cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self._log_audit(username, 'totp_disabled', True, None, None, 'TOTP disabled')
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error disabling TOTP: {e}", module="TOTP")
|
||||
return False
|
||||
|
||||
def get_totp_status(self, username: str) -> Dict:
|
||||
"""
|
||||
Get TOTP status for a user
|
||||
|
||||
Returns:
|
||||
dict with enabled, enrolledAt
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT totp_enabled, totp_enrolled_at FROM users
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return {'enabled': False, 'enrolledAt': None}
|
||||
|
||||
return {
|
||||
'enabled': bool(row[0]),
|
||||
'enrolledAt': row[1]
|
||||
}
|
||||
|
||||
    def check_rate_limit(self, username: str, ip_address: str) -> Tuple[bool, int, Optional[str]]:
        """
        Check and enforce rate limiting for TOTP verification

        Each (username, ip_address) pair gets `max_attempts` attempts per
        `rate_limit_window`; exceeding that locks the pair out for
        `lockout_duration`. Every call counts as one attempt (the counter
        is incremented here, before the code is even checked).

        Returns:
            Tuple of (allowed, attempts_remaining, locked_until)
            where locked_until is an ISO timestamp string or None.
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            now = datetime.now()
            window_start = now - self.rate_limit_window  # oldest time still inside the window

            cursor.execute("""
                SELECT attempts, window_start, locked_until
                FROM totp_rate_limit
                WHERE username = ? AND ip_address = ?
            """, (username, ip_address))

            row = cursor.fetchone()

            if not row:
                # First attempt - create record
                cursor.execute("""
                    INSERT INTO totp_rate_limit (username, ip_address, attempts, window_start)
                    VALUES (?, ?, 1, ?)
                """, (username, ip_address, now.isoformat()))
                conn.commit()
                return True, self.max_attempts - 1, None

            attempts, stored_window_start, locked_until = row
            stored_window_start = datetime.fromisoformat(stored_window_start)

            # Check if locked out
            if locked_until:
                locked_until_dt = datetime.fromisoformat(locked_until)
                if locked_until_dt > now:
                    # Still inside the lockout period - deny
                    return False, 0, locked_until
                else:
                    # Lockout expired - reset (delete the record entirely;
                    # this call itself is NOT counted as an attempt)
                    cursor.execute("""
                        DELETE FROM totp_rate_limit
                        WHERE username = ? AND ip_address = ?
                    """, (username, ip_address))
                    conn.commit()
                    return True, self.max_attempts, None

            # Check if window expired
            if stored_window_start < window_start:
                # Reset window - this call becomes attempt #1 of a new window
                cursor.execute("""
                    UPDATE totp_rate_limit
                    SET attempts = 1, window_start = ?, locked_until = NULL
                    WHERE username = ? AND ip_address = ?
                """, (now.isoformat(), username, ip_address))
                conn.commit()
                return True, self.max_attempts - 1, None

            # Increment attempts
            new_attempts = attempts + 1

            if new_attempts >= self.max_attempts:
                # Lock out
                locked_until = (now + self.lockout_duration).isoformat()
                cursor.execute("""
                    UPDATE totp_rate_limit
                    SET attempts = ?, locked_until = ?
                    WHERE username = ? AND ip_address = ?
                """, (new_attempts, locked_until, username, ip_address))
                conn.commit()
                return False, 0, locked_until
            else:
                cursor.execute("""
                    UPDATE totp_rate_limit
                    SET attempts = ?
                    WHERE username = ? AND ip_address = ?
                """, (new_attempts, username, ip_address))
                conn.commit()
                return True, self.max_attempts - new_attempts, None
|
||||
|
||||
def reset_rate_limit(self, username: str, ip_address: str):
|
||||
"""Reset rate limit after successful authentication"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
DELETE FROM totp_rate_limit
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (username, ip_address))
|
||||
|
||||
conn.commit()
|
||||
|
||||
def regenerate_backup_codes(self, username: str) -> List[str]:
|
||||
"""
|
||||
Regenerate backup codes for a user
|
||||
|
||||
Returns:
|
||||
List of new backup codes
|
||||
"""
|
||||
# Generate new codes
|
||||
new_codes = self.generate_backup_codes(10)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Delete old codes
|
||||
cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
|
||||
|
||||
# Insert new codes
|
||||
for code in new_codes:
|
||||
code_hash = self.hash_backup_code(code)
|
||||
cursor.execute("""
|
||||
INSERT INTO backup_codes (username, code_hash, created_at)
|
||||
VALUES (?, ?, ?)
|
||||
""", (username, code_hash, datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self._log_audit(username, 'backup_codes_regenerated', True, None, None,
|
||||
f'Generated {len(new_codes)} new backup codes')
|
||||
|
||||
return new_codes
|
||||
|
||||
def _log_audit(self, username: str, action: str, success: bool,
|
||||
ip_address: Optional[str], user_agent: Optional[str],
|
||||
details: Optional[str] = None):
|
||||
"""Log TOTP audit event"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO totp_audit_log
|
||||
(username, action, success, ip_address, user_agent, details, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (username, action, int(success), ip_address, user_agent, details,
|
||||
datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging TOTP audit: {e}", module="TOTP")
|
||||
792
web/backend/twofa_routes.py
Normal file
792
web/backend/twofa_routes.py
Normal file
@@ -0,0 +1,792 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Two-Factor Authentication API Routes for Media Downloader
|
||||
Handles TOTP, Passkey, and Duo 2FA endpoints
|
||||
|
||||
Based on backup-central's implementation
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Optional, Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.config import settings
|
||||
|
||||
# Add parent path to allow imports from modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('TwoFactorAuth')
|
||||
|
||||
# Import managers
|
||||
from web.backend.totp_manager import TOTPManager
|
||||
from web.backend.duo_manager import DuoManager
|
||||
|
||||
# Passkey manager - optional (requires compatible webauthn library)
|
||||
try:
|
||||
from web.backend.passkey_manager import PasskeyManager
|
||||
PASSKEY_AVAILABLE = True
|
||||
logger.info("Passkey support enabled", module="2FA")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Passkey support disabled: {e}", module="2FA")
|
||||
PasskeyManager = None
|
||||
PASSKEY_AVAILABLE = False
|
||||
|
||||
|
||||
# Pydantic models for request/response
|
||||
|
||||
# Standard response models
|
||||
class StandardResponse(BaseModel):
    """Standard API response with success flag and message"""
    success: bool   # whether the operation succeeded
    message: str    # human-readable status text
|
||||
|
||||
|
||||
class TOTPSetupResponse(BaseModel):
    """Response for POST /totp/setup: enrolment material, or a message on failure."""
    secret: Optional[str] = None           # Base32 TOTP secret (only on success)
    qrCodeDataURL: Optional[str] = None    # QR code as a data: URL for display
    manualEntryKey: Optional[str] = None   # same secret, for manual entry
    message: Optional[str] = None          # status/error text
    success: bool
|
||||
|
||||
|
||||
class TOTPVerifyRequest(BaseModel):
    """Request body for POST /totp/verify."""
    code: str  # 6-digit TOTP code from the authenticator app
|
||||
|
||||
|
||||
class TOTPVerifyResponse(BaseModel):
    """Response for POST /totp/verify; backup codes are returned once on success."""
    success: bool
    backupCodes: Optional[List[str]] = None  # one-time recovery codes (success only)
    message: str
|
||||
|
||||
|
||||
class TOTPLoginVerifyRequest(BaseModel):
    """Request body for the TOTP step of login."""
    username: str
    code: str                               # TOTP code or a backup code
    useBackupCode: Optional[bool] = False   # True when `code` is a backup code
    rememberMe: bool = False                # request an extended session
|
||||
|
||||
|
||||
class TOTPDisableRequest(BaseModel):
    """Request body for POST /totp/disable; requires password AND a current code."""
    password: str  # account password, re-verified before disabling
    code: str      # current TOTP code, proves possession of the authenticator
|
||||
|
||||
|
||||
class PasskeyRegistrationOptionsResponse(BaseModel):
    """Response carrying WebAuthn registration options for the browser."""
    success: bool
    options: Optional[Dict] = None   # WebAuthn PublicKeyCredentialCreationOptions
    message: Optional[str] = None
|
||||
|
||||
|
||||
class PasskeyVerifyRegistrationRequest(BaseModel):
    """Request body to finish passkey registration."""
    credential: Dict                  # attestation response from navigator.credentials.create
    deviceName: Optional[str] = None  # user-chosen label for this passkey
|
||||
|
||||
|
||||
class PasskeyAuthenticationOptionsResponse(BaseModel):
    """Response carrying WebAuthn authentication options for the browser."""
    success: bool
    options: Optional[Dict] = None   # WebAuthn PublicKeyCredentialRequestOptions
    message: Optional[str] = None
|
||||
|
||||
|
||||
class PasskeyVerifyAuthenticationRequest(BaseModel):
    """Request body to finish passkey authentication."""
    credential: Dict                 # assertion response from navigator.credentials.get
    username: Optional[str] = None   # optional hint; passkeys can be discoverable
    rememberMe: bool = False         # request an extended session
|
||||
|
||||
|
||||
class DuoAuthRequest(BaseModel):
    """Request body to start a Duo 2FA flow."""
    username: str
    rememberMe: bool = False  # request an extended session after Duo succeeds
|
||||
|
||||
|
||||
class DuoAuthResponse(BaseModel):
    """Response for starting a Duo flow: where to redirect and the CSRF state."""
    success: bool
    authUrl: Optional[str] = None   # Duo-hosted URL the client should redirect to
    state: Optional[str] = None     # opaque state token, echoed back in the callback
    message: Optional[str] = None
|
||||
|
||||
|
||||
class DuoCallbackRequest(BaseModel):
    """Request body for the Duo redirect callback."""
    state: str                       # must match the state issued at flow start
    duo_code: Optional[str] = None   # authorization code returned by Duo
|
||||
|
||||
|
||||
def create_2fa_router(auth_manager, get_current_user_dependency):
|
||||
"""
|
||||
Create and configure the 2FA router
|
||||
|
||||
Args:
|
||||
auth_manager: AuthManager instance
|
||||
get_current_user_dependency: FastAPI dependency for authentication
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize managers
|
||||
totp_manager = TOTPManager()
|
||||
duo_manager = DuoManager()
|
||||
passkey_manager = PasskeyManager() if PASSKEY_AVAILABLE else None
|
||||
|
||||
# ============================================================================
|
||||
# TOTP Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/totp/setup", response_model=TOTPSetupResponse)
|
||||
async def totp_setup(request: Request, current_user: Dict = Depends(get_current_user_dependency)):
|
||||
"""Generate TOTP secret and QR code"""
|
||||
try:
|
||||
username = current_user.get('sub')
|
||||
|
||||
# Check if already enabled
|
||||
status = totp_manager.get_totp_status(username)
|
||||
if status['enabled']:
|
||||
return TOTPSetupResponse(
|
||||
success=False,
|
||||
message='2FA is already enabled for this account'
|
||||
)
|
||||
|
||||
# Generate secret and QR code
|
||||
result = totp_manager.generate_secret(username)
|
||||
|
||||
# Store in session temporarily (until verified)
|
||||
# Note: In production, use proper session management
|
||||
request.session['totp_setup'] = {
|
||||
'secret': result['secret'],
|
||||
'username': username
|
||||
}
|
||||
|
||||
return TOTPSetupResponse(
|
||||
success=True,
|
||||
secret=result['secret'],
|
||||
qrCodeDataURL=result['qrCodeDataURL'],
|
||||
manualEntryKey=result['manualEntryKey'],
|
||||
message='Scan QR code with your authenticator app'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TOTP setup error: {e}", module="2FA")
|
||||
return TOTPSetupResponse(success=False, message='Failed to generate 2FA setup')
|
||||
|
||||
@router.post("/totp/verify", response_model=TOTPVerifyResponse)
|
||||
async def totp_verify(
|
||||
verify_data: TOTPVerifyRequest,
|
||||
request: Request,
|
||||
current_user: Dict = Depends(get_current_user_dependency)
|
||||
):
|
||||
"""Verify TOTP code and enable 2FA"""
|
||||
try:
|
||||
username = current_user.get('sub')
|
||||
code = verify_data.code
|
||||
|
||||
# Validate code format
|
||||
if not code or len(code) != 6 or not code.isdigit():
|
||||
return TOTPVerifyResponse(
|
||||
success=False,
|
||||
message='Invalid code format. Must be 6 digits.'
|
||||
)
|
||||
|
||||
# Get temporary secret from session
|
||||
totp_setup = request.session.get('totp_setup')
|
||||
if not totp_setup or totp_setup.get('username') != username:
|
||||
return TOTPVerifyResponse(
|
||||
success=False,
|
||||
message='No setup in progress. Please start 2FA setup again.'
|
||||
)
|
||||
|
||||
secret = totp_setup['secret']
|
||||
|
||||
# Verify the code
|
||||
if not totp_manager.verify_token(secret, code):
|
||||
return TOTPVerifyResponse(
|
||||
success=False,
|
||||
message='Invalid verification code. Please try again.'
|
||||
)
|
||||
|
||||
# Generate backup codes
|
||||
backup_codes = totp_manager.generate_backup_codes(10)
|
||||
|
||||
# Enable TOTP
|
||||
totp_manager.enable_totp(username, secret, backup_codes)
|
||||
|
||||
# Clear session
|
||||
request.session.pop('totp_setup', None)
|
||||
|
||||
return TOTPVerifyResponse(
|
||||
success=True,
|
||||
backupCodes=backup_codes,
|
||||
message='2FA enabled successfully! Save your backup codes.'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TOTP verify error: {e}", module="2FA")
|
||||
return TOTPVerifyResponse(success=False, message='Failed to verify code')
|
||||
|
||||
@router.post("/totp/disable")
async def totp_disable(
    disable_data: TOTPDisableRequest,
    request: Request,
    current_user: Dict = Depends(get_current_user_dependency)
):
    """Disable TOTP 2FA for the current user.

    Requires the account password plus a valid current TOTP code (when a
    secret is on record) before turning 2FA off.
    """
    try:
        username = current_user.get('sub')

        # Verify password
        if not auth_manager.get_user(username):
            return {'success': False, 'message': 'User not found'}

        # Read the stored password hash and (encrypted) TOTP secret directly.
        import sqlite3
        with sqlite3.connect(auth_manager.db_path) as conn:
            record = conn.execute(
                "SELECT password_hash, totp_secret FROM users WHERE username = ?",
                (username,),
            ).fetchone()

        if not record:
            return {'success': False, 'message': 'User not found'}

        password_hash, encrypted_secret = record

        if not auth_manager.verify_password(disable_data.password, password_hash):
            return {'success': False, 'message': 'Invalid password'}

        # Also demand a valid current code, so a stolen password alone
        # cannot strip 2FA from the account.
        if encrypted_secret:
            plain_secret = totp_manager.decrypt_secret(encrypted_secret)
            if not totp_manager.verify_token(plain_secret, disable_data.code):
                return {'success': False, 'message': 'Invalid verification code'}

        totp_manager.disable_totp(username)

        return {'success': True, 'message': '2FA has been disabled'}

    except Exception as e:
        logger.error(f"TOTP disable error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to disable 2FA'}
|
||||
|
||||
@router.post("/totp/regenerate-backup-codes")
async def totp_regenerate_backup_codes(
    code: str = Body(..., embed=True),
    current_user: Dict = Depends(get_current_user_dependency)
):
    """Regenerate the user's TOTP backup codes.

    Verifies a current TOTP code, then replaces the existing backup codes
    with a fresh set.
    """
    try:
        username = current_user.get('sub')

        # Look up whether 2FA is on and fetch the encrypted secret.
        import sqlite3
        with sqlite3.connect(auth_manager.db_path) as conn:
            record = conn.execute(
                "SELECT totp_enabled, totp_secret FROM users WHERE username = ?",
                (username,),
            ).fetchone()

        if not record or not record[0]:
            return {'success': False, 'message': '2FA is not enabled'}

        # Require a live code before rotating the recovery codes.
        encrypted_secret = record[1]
        if encrypted_secret:
            plain_secret = totp_manager.decrypt_secret(encrypted_secret)
            if not totp_manager.verify_token(plain_secret, code):
                return {'success': False, 'message': 'Invalid verification code'}

        fresh_codes = totp_manager.regenerate_backup_codes(username)

        return {
            'success': True,
            'backupCodes': fresh_codes,
            'message': 'New backup codes generated'
        }

    except Exception as e:
        logger.error(f"Backup codes regeneration error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to regenerate backup codes'}
|
||||
|
||||
@router.get("/totp/status")
async def totp_status(current_user: Dict = Depends(get_current_user_dependency)):
    """Report the TOTP enrollment status for the current user."""
    try:
        # Merge the manager's status fields into the success envelope.
        info = totp_manager.get_totp_status(current_user.get('sub'))
        return {'success': True, **info}

    except Exception as e:
        logger.error(f"TOTP status error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to get 2FA status'}
|
||||
|
||||
@router.post("/totp/login/verify")
async def totp_login_verify(verify_data: TOTPLoginVerifyRequest, request: Request):
    """Verify TOTP code during login.

    Second step of a 2FA login: rate-limits attempts per user/IP, accepts
    either a TOTP code or a one-time backup code, and on success resets the
    rate limiter, creates a session, and sets the auth cookie.
    """
    try:
        username = verify_data.username
        code = verify_data.code
        use_backup_code = verify_data.useBackupCode

        # Get client IP (None when the transport hides it)
        ip_address = request.client.host if request.client else None

        # Check rate limit before touching any secrets
        allowed, attempts_remaining, locked_until = totp_manager.check_rate_limit(username, ip_address)

        if not allowed:
            return {
                'success': False,
                'message': f'Too many failed attempts. Locked until {locked_until}',
                'lockedUntil': locked_until
            }

        # Get user record (hash, TOTP state/secret, role, email) in one query
        import sqlite3
        with sqlite3.connect(auth_manager.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT password_hash, totp_enabled, totp_secret, role, email
                FROM users WHERE username = ?
            """, (username,))
            row = cursor.fetchone()

        # Deliberately generic message: don't reveal whether the account
        # exists or has TOTP enabled.
        if not row or not row[1]:  # Not TOTP enabled
            return {'success': False, 'message': 'Invalid request'}

        password_hash, totp_enabled, encrypted_secret, role, email = row

        is_valid = False

        # Verify backup code or TOTP
        if use_backup_code:
            # `remaining` (backup codes left) is currently not surfaced
            # to the client.
            is_valid, remaining = totp_manager.verify_backup_code(code, username)
        else:
            if encrypted_secret:
                decrypted_secret = totp_manager.decrypt_secret(encrypted_secret)
                is_valid = totp_manager.verify_token(decrypted_secret, code)

        if not is_valid:
            # NOTE(review): nothing in this handler records the failed
            # attempt, so the "- 1" assumes check_rate_limit() already
            # counted this try — confirm against TOTPManager's contract.
            return {
                'success': False,
                'message': 'Invalid verification code',
                'attemptsRemaining': attempts_remaining - 1
            }

        # Success - reset rate limit and create session
        totp_manager.reset_rate_limit(username, ip_address)

        # Create session with remember_me setting
        session_result = auth_manager._create_session(username, role, ip_address, verify_data.rememberMe)

        # Create response with cookie: persistent (30 days) only for
        # remember-me logins, otherwise a session cookie (max_age=None).
        response = JSONResponse(content=session_result)
        max_age = 30 * 24 * 60 * 60 if verify_data.rememberMe else None
        response.set_cookie(
            key="auth_token",
            value=session_result.get('token'),
            max_age=max_age,
            httponly=True,
            secure=settings.SECURE_COOKIES,
            samesite="lax",
            path="/"
        )

        return response

    except Exception as e:
        logger.error(f"TOTP login verification error: {e}", module="2FA")
        return {'success': False, 'message': 'Verification failed'}
|
||||
|
||||
# ============================================================================
|
||||
# Passkey Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/passkey/registration-options", response_model=PasskeyRegistrationOptionsResponse)
async def passkey_registration_options(current_user: Dict = Depends(get_current_user_dependency)):
    """Produce WebAuthn registration options for the current user."""
    # Feature gate: passkey support can be absent on this deployment.
    if not PASSKEY_AVAILABLE:
        return PasskeyRegistrationOptionsResponse(
            success=False,
            message='Passkey support is not available on this server'
        )

    try:
        opts = passkey_manager.generate_registration_options(
            current_user.get('sub'),
            current_user.get('email'),
        )
        return PasskeyRegistrationOptionsResponse(success=True, options=opts)

    except Exception as e:
        logger.error(f"Passkey registration options error: {e}", module="2FA")
        return PasskeyRegistrationOptionsResponse(
            success=False,
            message='Failed to generate registration options. Please try again.'
        )
|
||||
|
||||
@router.post("/passkey/verify-registration")
async def passkey_verify_registration(
    verify_data: PasskeyVerifyRegistrationRequest,
    current_user: Dict = Depends(get_current_user_dependency)
):
    """Verify and store a new passkey registration for the current user."""
    if not PASSKEY_AVAILABLE:
        return {'success': False, 'message': 'Passkey support is not available on this server'}

    try:
        # Merge the manager's result fields into the success envelope.
        outcome = passkey_manager.verify_registration(
            current_user.get('sub'),
            verify_data.credential,
            verify_data.deviceName,
        )
        return {'success': True, **outcome}

    except Exception as e:
        logger.error(f"Passkey registration verification error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to verify passkey registration. Please try again.'}
|
||||
|
||||
@router.post("/passkey/authentication-options", response_model=PasskeyAuthenticationOptionsResponse)
async def passkey_authentication_options(username: Optional[str] = Body(None, embed=True)):
    """Produce WebAuthn authentication options (username optional)."""
    if not PASSKEY_AVAILABLE:
        return PasskeyAuthenticationOptionsResponse(success=False, message='Passkey support is not available on this server')

    try:
        opts = passkey_manager.generate_authentication_options(username)
        return PasskeyAuthenticationOptionsResponse(success=True, options=opts)

    except Exception as e:
        logger.error(f"Passkey authentication options error: {e}", module="2FA")
        return PasskeyAuthenticationOptionsResponse(
            success=False,
            message='Failed to generate authentication options. Please try again.'
        )
|
||||
|
||||
@router.post("/passkey/verify-authentication")
async def passkey_verify_authentication(
    verify_data: PasskeyVerifyAuthenticationRequest,
    request: Request
):
    """Verify a passkey (WebAuthn) assertion and log the user in.

    On a successful assertion, creates a session and sets the auth cookie;
    on failure, returns the manager's failure result unchanged.
    """
    if not PASSKEY_AVAILABLE:
        return {'success': False, 'message': 'Passkey support is not available on this server'}

    try:
        # Get client IP
        ip_address = request.client.host if request.client else None

        # Verify authentication - use username from request if provided
        result = passkey_manager.verify_authentication(verify_data.username, verify_data.credential)

        if not result['success']:
            return result

        username = result['username']

        # Get user info
        import sqlite3
        with sqlite3.connect(auth_manager.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT role, email FROM users WHERE username = ?", (username,))
            row = cursor.fetchone()

        # Bug fix: previously a missing user row fell through and returned
        # the raw success result with no session or cookie, leaving the
        # client "authenticated" without a token. Fail explicitly instead.
        if not row:
            logger.error(f"Passkey auth succeeded for unknown user: {username}", module="2FA")
            return {'success': False, 'message': 'Failed to verify passkey authentication. Please try again.'}

        role, email = row

        # Create session with remember_me setting
        session_result = auth_manager._create_session(username, role, ip_address, verify_data.rememberMe)

        # Persistent cookie (30 days) only when "remember me" was requested;
        # otherwise a browser-session cookie (max_age=None).
        response = JSONResponse(content=session_result)
        max_age = 30 * 24 * 60 * 60 if verify_data.rememberMe else None
        response.set_cookie(
            key="auth_token",
            value=session_result.get('token'),
            max_age=max_age,
            httponly=True,
            secure=settings.SECURE_COOKIES,
            samesite="lax",
            path="/"
        )

        return response

    except Exception as e:
        logger.error(f"Passkey authentication verification error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to verify passkey authentication. Please try again.'}
|
||||
|
||||
@router.get("/passkey/list")
async def passkey_list(current_user: Dict = Depends(get_current_user_dependency)):
    """Return the current user's registered passkeys."""
    if not PASSKEY_AVAILABLE:
        return {'success': False, 'message': 'Passkey support is not available on this server'}

    try:
        creds = passkey_manager.list_credentials(current_user.get('sub'))
        return {'success': True, 'credentials': creds}

    except Exception as e:
        logger.error(f"Passkey list error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to retrieve passkeys. Please try again.'}
|
||||
|
||||
@router.delete("/passkey/{credential_id}")
async def passkey_remove(
    credential_id: str,
    current_user: Dict = Depends(get_current_user_dependency)
):
    """Delete one of the current user's passkeys by credential id."""
    if not PASSKEY_AVAILABLE:
        return {'success': False, 'message': 'Passkey support is not available on this server'}

    try:
        username = current_user.get('sub')

        # Debug trail for credential-id encoding issues seen in the wild.
        logger.debug(f"Removing passkey - Username: {username}, Credential ID: {credential_id}", module="2FA")
        logger.debug(f"Credential ID length: {len(credential_id)}, type: {type(credential_id)}", module="2FA")

        removed = passkey_manager.remove_credential(username, credential_id)
        if removed:
            return {'success': True, 'message': 'Passkey removed'}
        return {'success': False, 'message': 'Passkey not found'}

    except Exception as e:
        logger.error(f"Passkey remove error: {e}", exc_info=True, module="2FA")
        return {'success': False, 'message': 'Failed to remove passkey. Please try again.'}
|
||||
|
||||
@router.get("/passkey/status")
async def passkey_status(current_user: Dict = Depends(get_current_user_dependency)):
    """Report the passkey enrollment status for the current user."""
    if not PASSKEY_AVAILABLE:
        return {'success': False, 'enabled': False, 'message': 'Passkey support is not available on this server'}

    try:
        info = passkey_manager.get_passkey_status(current_user.get('sub'))
        return {'success': True, **info}

    except Exception as e:
        logger.error(f"Passkey status error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to retrieve passkey status. Please try again.'}
|
||||
|
||||
# ============================================================================
|
||||
# Duo Endpoints
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/duo/auth", response_model=DuoAuthResponse)
async def duo_auth(auth_data: DuoAuthRequest):
    """Start a Duo (OAuth 2.0) authentication flow for the given user."""
    try:
        if not duo_manager.is_duo_configured():
            return DuoAuthResponse(success=False, message='Duo is not configured')

        # rememberMe is passed through so the callback can honor it later.
        flow = duo_manager.create_auth_url(auth_data.username, auth_data.rememberMe)

        return DuoAuthResponse(
            success=True,
            authUrl=flow['authUrl'],
            state=flow['state']
        )

    except Exception as e:
        logger.error(f"Duo auth error: {e}", module="2FA")
        return DuoAuthResponse(success=False, message='Failed to initiate Duo authentication. Please try again.')
|
||||
|
||||
@router.get("/duo/callback")
async def duo_callback(
    state: str,
    duo_code: Optional[str] = None,
    code: Optional[str] = None,
    request: Request = None
):
    """Handle the Duo OAuth 2.0 callback.

    Validates the state token, exchanges the authorization code with Duo,
    and on success creates a session, sets the auth cookie, and redirects
    back to the frontend with the token in the URL.
    """
    # Hoisted: previously imported twice inside separate branches below.
    from fastapi.responses import RedirectResponse

    try:
        # Duo sends 'code' parameter, normalize to duo_code
        if not duo_code and code:
            duo_code = code

        if not duo_code:
            return {'success': False, 'message': 'Missing authorization code'}

        # Verify state and recover the username and remember_me flag
        result = duo_manager.verify_state(state)
        if not result:
            return {'success': False, 'message': 'Invalid or expired state'}

        username, remember_me = result

        # Verify Duo response
        if duo_manager.verify_duo_response(duo_code, username):
            # Bug fix: the signature defaults `request` to None, so guard
            # before dereferencing request.client.
            ip_address = request.client.host if request and request.client else None

            # Get user info
            import sqlite3
            with sqlite3.connect(auth_manager.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT role, email FROM users WHERE username = ?", (username,))
                row = cursor.fetchone()

            if row:
                role, email = row

                # Create session with remember_me setting
                session_result = auth_manager._create_session(username, role, ip_address, remember_me)

                token = session_result.get('token')
                session_id = session_result.get('sessionId')

                # Redirect to frontend with success (token in URL and cookie;
                # the frontend reads and stores the token from the URL)
                redirect_url = f"/?duo_auth=success&token={token}&sessionId={session_id}&username={username}"
                response = RedirectResponse(url=redirect_url, status_code=302)

                # Persistent cookie (30 days) only for remember-me logins
                max_age = 30 * 24 * 60 * 60 if remember_me else None
                response.set_cookie(
                    key="auth_token",
                    value=token,
                    max_age=max_age,
                    httponly=True,
                    secure=settings.SECURE_COOKIES,
                    samesite="lax",
                    path="/"
                )

                return response

        # Verification failed (or user row missing) - redirect with error
        return RedirectResponse(url="/login?duo_auth=failed&error=Duo+verification+failed", status_code=302)

    except Exception as e:
        logger.error(f"Duo callback error: {e}", module="2FA")
        return {'success': False, 'message': 'Duo authentication failed. Please try again.'}
|
||||
|
||||
@router.post("/duo/enroll")
async def duo_enroll(
    duo_username: Optional[str] = Body(None, embed=True),
    current_user: Dict = Depends(get_current_user_dependency)
):
    """Enroll the current user in Duo (optionally under a Duo-side name)."""
    try:
        enrolled = duo_manager.enroll_user(current_user.get('sub'), duo_username)
        if enrolled:
            return {'success': True, 'message': 'Duo enrollment successful'}
        return {'success': False, 'message': 'Duo enrollment failed'}

    except Exception as e:
        logger.error(f"Duo enroll error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to enroll in Duo. Please try again.'}
|
||||
|
||||
@router.post("/duo/unenroll")
async def duo_unenroll(current_user: Dict = Depends(get_current_user_dependency)):
    """Remove the current user's Duo enrollment."""
    try:
        removed = duo_manager.unenroll_user(current_user.get('sub'))
        if removed:
            return {'success': True, 'message': 'Duo unenrollment successful'}
        return {'success': False, 'message': 'Duo unenrollment failed'}

    except Exception as e:
        logger.error(f"Duo unenroll error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to unenroll from Duo. Please try again.'}
|
||||
|
||||
@router.get("/duo/status")
async def duo_status(current_user: Dict = Depends(get_current_user_dependency)):
    """Report the user's Duo enrollment plus server-side configuration state."""
    try:
        user_info = duo_manager.get_duo_status(current_user.get('sub'))
        server_info = duo_manager.get_configuration_status()

        # User-level fields merged alongside the server configuration flag.
        return {
            'success': True,
            **user_info,
            'duoConfigured': server_info['configured']
        }

    except Exception as e:
        logger.error(f"Duo status error: {e}", module="2FA")
        return {'success': False, 'message': 'Failed to retrieve Duo status. Please try again.'}
|
||||
|
||||
# ============================================================================
|
||||
# Combined 2FA Status Endpoint
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/status")
async def twofa_status(current_user: Dict = Depends(get_current_user_dependency)):
    """Aggregate TOTP, passkey, and Duo status for the current user."""
    try:
        username = current_user.get('sub')

        totp_info = totp_manager.get_totp_status(username)

        # Passkeys may be unavailable if the WebAuthn stack isn't installed;
        # report a disabled placeholder in that case.
        if PASSKEY_AVAILABLE:
            passkey_info = passkey_manager.get_passkey_status(username)
            passkey_info['available'] = True
        else:
            passkey_info = {'enabled': False, 'available': False, 'credentialCount': 0}

        duo_info = duo_manager.get_duo_status(username)
        duo_configured = duo_manager.get_configuration_status()['configured']

        # A method counts only when enabled for this user (and, for Duo,
        # also configured server-side).
        methods = []
        if totp_info['enabled']:
            methods.append('totp')
        if PASSKEY_AVAILABLE and passkey_info['enabled']:
            methods.append('passkey')
        if duo_info['enabled'] and duo_configured:
            methods.append('duo')

        return {
            'success': True,
            'totp': totp_info,
            'passkey': passkey_info,
            'duo': {**duo_info, 'duoConfigured': duo_configured},
            'availableMethods': methods,
            'anyEnabled': len(methods) > 0
        }

    except Exception as e:
        logger.error(f"2FA status error: {e}", module="2FA")
        return {'success': False, 'message': str(e)}
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user