Initial commit

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Todd
2026-03-29 22:42:55 -04:00
commit 0d7b2b1aab
389 changed files with 280296 additions and 0 deletions

762
web/backend/api.py Normal file
View File

@@ -0,0 +1,762 @@
#!/usr/bin/env python3
"""
Media Downloader Web API
FastAPI backend that integrates with the existing media-downloader system
"""
# Suppress pkg_resources deprecation warning from face_recognition_models
import warnings
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*")
import sys
import os
# Bootstrap database backend (must be before any database imports)
sys.path.insert(0, str(__import__('pathlib').Path(__file__).parent.parent.parent))
import modules.db_bootstrap # noqa: E402,F401
import json
import asyncio
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.sessions import SessionMiddleware
from starlette_csrf import CSRFMiddleware
import sqlite3
import secrets
import re
# Redis cache manager
from cache_manager import cache_manager, invalidate_download_cache
# Rate limiting
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
# Add parent directory to path to import media-downloader modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from modules.unified_database import UnifiedDatabase
from modules.scheduler import DownloadScheduler
from modules.settings_manager import SettingsManager
from modules.universal_logger import get_logger
from modules.discovery_system import get_discovery_system
from modules.semantic_search import get_semantic_search
from web.backend.auth_manager import AuthManager
from web.backend.core.dependencies import AppState
# Initialize universal logger for API
logger = get_logger('API')
# ============================================================================
# CONFIGURATION (use centralized settings from core/config.py)
# ============================================================================
from web.backend.core.config import settings
# Legacy aliases for backward compatibility (prefer using settings.* directly)
# NOTE: these are snapshots taken at import time; later changes to settings.*
# are not reflected here.
PROJECT_ROOT = settings.PROJECT_ROOT
DB_PATH = settings.DB_PATH
CONFIG_PATH = settings.CONFIG_PATH
LOG_PATH = settings.LOG_PATH
TEMP_DIR = settings.TEMP_DIR
DB_POOL_SIZE = settings.DB_POOL_SIZE
DB_CONNECTION_TIMEOUT = settings.DB_CONNECTION_TIMEOUT
PROCESS_TIMEOUT_SHORT = settings.PROCESS_TIMEOUT_SHORT
PROCESS_TIMEOUT_MEDIUM = settings.PROCESS_TIMEOUT_MEDIUM
PROCESS_TIMEOUT_LONG = settings.PROCESS_TIMEOUT_LONG
WEBSOCKET_TIMEOUT = settings.WEBSOCKET_TIMEOUT
ACTIVITY_LOG_INTERVAL = settings.ACTIVITY_LOG_INTERVAL
EMBEDDING_QUEUE_CHECK_INTERVAL = settings.EMBEDDING_QUEUE_CHECK_INTERVAL
EMBEDDING_BATCH_SIZE = settings.EMBEDDING_BATCH_SIZE
EMBEDDING_BATCH_LIMIT = settings.EMBEDDING_BATCH_LIMIT
THUMBNAIL_DB_TIMEOUT = settings.THUMBNAIL_DB_TIMEOUT
# ============================================================================
# PYDANTIC MODELS (imported from models/api_models.py to avoid duplication)
# ============================================================================
from web.backend.models.api_models import (
    DownloadResponse,
    StatsResponse,
    PlatformStatus,
    TriggerRequest,
    PushoverConfig,
    SchedulerConfig,
    ConfigUpdate,
    HealthStatus,
    LoginRequest,
)
# ============================================================================
# AUTHENTICATION DEPENDENCIES (imported from core/dependencies.py)
# ============================================================================
from web.backend.core.dependencies import (
    get_current_user,
    require_admin,
    get_app_state,
)
# ============================================================================
# GLOBAL STATE
# ============================================================================
# Use the shared singleton from core/dependencies (used by all routers)
app_state = get_app_state()
# ============================================================================
# LIFESPAN MANAGEMENT
# ============================================================================
def init_thumbnail_db():
    """Initialize the thumbnail cache database (idempotent).

    Creates ``database/thumbnails.db`` at the project root together with the
    ``thumbnails`` table and its path index if they do not already exist.

    Returns:
        Path: filesystem location of the thumbnail cache database.
    """
    thumb_db_path = Path(__file__).parent.parent.parent / 'database' / 'thumbnails.db'
    thumb_db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(thumb_db_path))
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS thumbnails (
                file_hash TEXT PRIMARY KEY,
                file_path TEXT NOT NULL,
                thumbnail_data BLOB NOT NULL,
                created_at TEXT NOT NULL,
                file_mtime REAL NOT NULL
            )
        """)
        # Index supports invalidation lookups by source file path
        conn.execute("CREATE INDEX IF NOT EXISTS idx_file_path ON thumbnails(file_path)")
        conn.commit()
    finally:
        # Always release the connection, even when DDL fails
        # (the original version leaked it on any exception)
        conn.close()
    return thumb_db_path
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and cleanup app resources.

    Everything before ``yield`` runs at startup: database pool, auth, router
    registration, settings, background tasks. Everything after runs at
    shutdown: cancel tasks, close shared clients, scheduler and DB.
    """
    # Startup
    logger.info("🚀 Starting Media Downloader API...", module="Core")
    # Initialize database with larger pool for API workers
    app_state.db = UnifiedDatabase(str(DB_PATH), use_pool=True, pool_size=DB_POOL_SIZE)
    logger.info(f"✓ Connected to database: {DB_PATH} (pool_size={DB_POOL_SIZE})", module="Core")
    # Reset any stuck downloads from previous API run
    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE video_download_queue
                SET status = 'pending', progress = 0, started_at = NULL
                WHERE status = 'downloading'
            ''')
            reset_count = cursor.rowcount
            conn.commit()
            if reset_count > 0:
                logger.info(f"✓ Reset {reset_count} stuck download(s) to pending", module="Core")
    except Exception as e:
        logger.warning(f"Could not reset stuck downloads: {e}", module="Core")
    # Clear any stale background tasks from previous API run
    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE background_task_status
                SET active = 0
                WHERE active = 1
            ''')
            cleared_count = cursor.rowcount
            conn.commit()
            if cleared_count > 0:
                logger.info(f"✓ Cleared {cleared_count} stale background task(s)", module="Core")
    except Exception as e:
        # Table might not exist yet, that's fine
        # NOTE(review): `e` is captured but deliberately unused here
        pass
    # Initialize face recognition semaphore (limit to 1 concurrent process)
    app_state.face_recognition_semaphore = asyncio.Semaphore(1)
    logger.info("✓ Face recognition semaphore initialized (max 1 concurrent)", module="Core")
    # Initialize authentication manager
    app_state.auth = AuthManager()
    logger.info("✓ Authentication manager initialized", module="Core")
    # Initialize and include 2FA router
    # (create_2fa_router is imported near the bottom of this module; that is
    # safe because lifespan only runs after the module has fully imported)
    twofa_router = create_2fa_router(app_state.auth, get_current_user)
    app.include_router(twofa_router, prefix="/api/auth/2fa", tags=["2FA"])
    logger.info("✓ 2FA routes initialized", module="Core")
    # Include modular API routers
    from web.backend.routers import (
        auth_router,
        health_router,
        downloads_router,
        media_router,
        recycle_router,
        scheduler_router,
        video_router,
        config_router,
        review_router,
        face_router,
        platforms_router,
        discovery_router,
        scrapers_router,
        semantic_router,
        manual_import_router,
        stats_router,
        celebrity_router,
        video_queue_router,
        maintenance_router,
        files_router,
        appearances_router,
        easynews_router,
        dashboard_router,
        paid_content_router,
        private_gallery_router,
        instagram_unified_router,
        cloud_backup_router,
        press_router
    )
    # Include all modular routers
    # Note: websocket_manager is set later after ConnectionManager is created
    app.include_router(auth_router)
    app.include_router(health_router)
    app.include_router(downloads_router)
    app.include_router(media_router)
    app.include_router(recycle_router)
    app.include_router(scheduler_router)
    app.include_router(video_router)
    app.include_router(config_router)
    app.include_router(review_router)
    app.include_router(face_router)
    app.include_router(platforms_router)
    app.include_router(discovery_router)
    app.include_router(scrapers_router)
    app.include_router(semantic_router)
    app.include_router(manual_import_router)
    app.include_router(stats_router)
    app.include_router(celebrity_router)
    app.include_router(video_queue_router)
    app.include_router(maintenance_router)
    app.include_router(files_router)
    app.include_router(appearances_router)
    app.include_router(easynews_router)
    app.include_router(dashboard_router)
    app.include_router(paid_content_router)
    app.include_router(private_gallery_router)
    app.include_router(instagram_unified_router)
    app.include_router(cloud_backup_router)
    app.include_router(press_router)
    logger.info("✓ All 28 modular routers registered", module="Core")
    # Initialize settings manager
    app_state.settings = SettingsManager(str(DB_PATH))
    logger.info("✓ Settings manager initialized", module="Core")
    # Check if settings need to be migrated from JSON
    existing_settings = app_state.settings.get_all()
    if not existing_settings or len(existing_settings) == 0:
        logger.info("⚠ No settings in database, migrating from JSON...", module="Core")
        if CONFIG_PATH.exists():
            app_state.settings.migrate_from_json(str(CONFIG_PATH))
            logger.info("✓ Settings migrated from JSON to database", module="Core")
        else:
            logger.info("⚠ No JSON config file found", module="Core")
    # Load configuration from database
    app_state.config = app_state.settings.get_all()
    if app_state.config:
        logger.info(f"✓ Loaded configuration from database", module="Core")
    else:
        logger.info("⚠ No configuration in database - please configure via web interface", module="Core")
        app_state.config = {}
    # Initialize scheduler (for status checking only - actual scheduler runs as separate service)
    app_state.scheduler = DownloadScheduler(config_path=None, unified_db=app_state.db, settings_manager=app_state.settings)
    logger.info("✓ Scheduler connected (status monitoring only)", module="Core")
    # Initialize thumbnail database
    init_thumbnail_db()
    logger.info("✓ Thumbnail cache initialized", module="Core")
    # Initialize recycle bin settings with defaults if not exists
    if not app_state.settings.get('recycle_bin'):
        recycle_bin_defaults = {
            'enabled': True,
            'path': '/opt/immich/recycle',
            'retention_days': 30,
            'max_size_gb': 50,
            'auto_cleanup': True
        }
        app_state.settings.set('recycle_bin', recycle_bin_defaults, category='recycle_bin', description='Recycle bin configuration', updated_by='system')
        logger.info("✓ Recycle bin settings initialized with defaults", module="Core")
    logger.info("✓ Media Downloader API ready!", module="Core")
    # Start periodic WAL checkpoint task
    async def periodic_wal_checkpoint():
        """Run WAL checkpoint every 5 minutes to prevent WAL file growth"""
        while True:
            await asyncio.sleep(300)  # 5 minutes
            try:
                if app_state.db:
                    app_state.db.checkpoint()
                    logger.debug("WAL checkpoint completed", module="Core")
            except Exception as e:
                logger.warning(f"WAL checkpoint failed: {e}", module="Core")
    checkpoint_task = asyncio.create_task(periodic_wal_checkpoint())
    logger.info("✓ WAL checkpoint task started (every 5 minutes)", module="Core")
    # Start discovery queue processor
    async def process_discovery_queue():
        """Background task to process discovery scan queue (embeddings)"""
        while True:
            try:
                await asyncio.sleep(30)  # Check queue every 30 seconds
                if not app_state.db:
                    continue
                # Process embeddings first (limit batch size for memory efficiency)
                pending_embeddings = app_state.db.get_pending_discovery_scans(limit=EMBEDDING_BATCH_SIZE, scan_type='embedding')
                for item in pending_embeddings:
                    try:
                        queue_id = item['id']
                        file_id = item['file_id']
                        file_path = item['file_path']
                        # Mark as started
                        app_state.db.mark_discovery_scan_started(queue_id)
                        # Check if file still exists
                        from pathlib import Path
                        if not Path(file_path).exists():
                            # Vanished file: treat as done so it leaves the queue
                            app_state.db.mark_discovery_scan_completed(queue_id)
                            continue
                        # Check if embedding already exists
                        with app_state.db.get_connection() as conn:
                            cursor = conn.cursor()
                            cursor.execute('SELECT 1 FROM content_embeddings WHERE file_id = ?', (file_id,))
                            if cursor.fetchone():
                                # Already has embedding
                                app_state.db.mark_discovery_scan_completed(queue_id)
                                continue
                            # Get content type
                            cursor.execute('SELECT content_type FROM file_inventory WHERE id = ?', (file_id,))
                            row = cursor.fetchone()
                            content_type = row['content_type'] if row else 'image'
                        # Generate embedding (CPU intensive - run in executor)
                        def generate():
                            semantic = get_semantic_search(app_state.db)
                            return semantic.generate_embedding_for_file(file_id, file_path, content_type)
                        loop = asyncio.get_running_loop()
                        success = await loop.run_in_executor(None, generate)
                        if success:
                            app_state.db.mark_discovery_scan_completed(queue_id)
                            logger.debug(f"Generated embedding for file_id {file_id}", module="Discovery")
                        else:
                            app_state.db.mark_discovery_scan_failed(queue_id, "Embedding generation returned False")
                    except Exception as e:
                        logger.warning(f"Failed to process embedding queue item {item.get('id')}: {e}", module="Discovery")
                        app_state.db.mark_discovery_scan_failed(item.get('id', 0), str(e))
                    # Small delay between items to avoid CPU spikes
                    await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Discovery queue processor error: {e}", module="Discovery")
                await asyncio.sleep(60)  # Wait longer on errors
    discovery_task = asyncio.create_task(process_discovery_queue())
    logger.info("✓ Discovery queue processor started (embeddings, checks every 30s)", module="Core")
    # Start periodic quality recheck for Fansly videos (hourly)
    async def periodic_quality_recheck():
        """Recheck flagged Fansly attachments for higher quality every hour"""
        await asyncio.sleep(300)  # Wait 5 min after startup before first check
        while True:
            try:
                from web.backend.routers.paid_content import _auto_quality_recheck_background
                await _auto_quality_recheck_background()
            except Exception as e:
                logger.warning(f"Periodic quality recheck failed: {e}", module="Core")
            await asyncio.sleep(3600)  # 1 hour
    quality_recheck_task = asyncio.create_task(periodic_quality_recheck())
    logger.info("✓ Quality recheck task started (hourly)", module="Core")
    # Auto-start download queue if enabled
    try:
        queue_settings = app_state.settings.get('video_queue')
        if queue_settings and queue_settings.get('auto_start_on_restart'):
            # Check if there are pending items
            with app_state.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM video_download_queue WHERE status = 'pending'")
                pending_count = cursor.fetchone()[0]
            if pending_count > 0:
                # Import the queue processor and start it
                from web.backend.routers.video_queue import queue_processor, run_queue_processor
                if not queue_processor.is_running:
                    queue_processor.start()
                    asyncio.create_task(run_queue_processor(app_state))
                    logger.info(f"✓ Auto-started download queue with {pending_count} pending items", module="Core")
            else:
                logger.info("✓ Auto-start enabled but no pending items in queue", module="Core")
    except Exception as e:
        logger.warning(f"Could not auto-start download queue: {e}", module="Core")
    yield
    # Shutdown
    logger.info("🛑 Shutting down Media Downloader API...", module="Core")
    checkpoint_task.cancel()
    discovery_task.cancel()
    quality_recheck_task.cancel()
    # Await each cancelled task so cancellation completes before closing resources
    try:
        await checkpoint_task
    except asyncio.CancelledError:
        pass
    try:
        await discovery_task
    except asyncio.CancelledError:
        pass
    try:
        await quality_recheck_task
    except asyncio.CancelledError:
        pass
    # Close shared HTTP client
    try:
        from web.backend.core.http_client import http_client
        await http_client.close()
    except Exception:
        pass
    if app_state.scheduler:
        app_state.scheduler.stop()
    if app_state.db:
        app_state.db.close()
    logger.info("✓ Cleanup complete", module="Core")
# ============================================================================
# FASTAPI APP
# ============================================================================
app = FastAPI(
    title="Media Downloader API",
    description="Web API for managing media downloads from Instagram, TikTok, Snapchat, and Forums",
    version="13.13.1",
    lifespan=lifespan
)
# CORS middleware - allow frontend origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
    allow_headers=[
        "Authorization",
        "Content-Type",
        "Accept",
        "Origin",
        "X-Requested-With",
        "X-CSRFToken",
        "Cache-Control",
        "Range",  # needed for media streaming / partial content requests
    ],
    max_age=600,  # Cache preflight requests for 10 minutes
)
# GZip compression for responses > 1KB (70% smaller API payloads)
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Session middleware for 2FA setup flow
# Persist generated key to file to avoid invalidating sessions on restart
def _load_session_secret():
secret_file = Path(__file__).parent.parent.parent / '.session_secret'
if os.getenv("SESSION_SECRET_KEY"):
return os.getenv("SESSION_SECRET_KEY")
if secret_file.exists():
with open(secret_file, 'r') as f:
secret = f.read().strip()
if len(secret) >= 32:
return secret
new_secret = secrets.token_urlsafe(32)
try:
fd = os.open(str(secret_file), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, 'w') as f:
f.write(new_secret)
except Exception:
pass # Fall back to in-memory only
return new_secret
# Resolved once at import time; reused below as the CSRF fallback secret
SESSION_SECRET_KEY = _load_session_secret()
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)
# CSRF protection middleware - enabled for state-changing requests
# Note: All endpoints are also protected by JWT authentication for double security
from starlette.types import ASGIApp, Receive, Scope, Send
class CSRFMiddlewareHTTPOnly:
    """Apply CSRF protection to HTTP traffic only.

    WebSocket connections authenticate separately (token/cookie in the
    endpoint), so this ASGI wrapper routes "websocket" scopes straight to the
    inner app and sends everything else through the wrapped CSRF middleware.
    """

    def __init__(self, app: ASGIApp, csrf_middleware_class, **kwargs):
        self.app = app
        self.csrf_middleware = csrf_middleware_class(app, **kwargs)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "websocket":
            # HTTP requests get full CSRF validation
            await self.csrf_middleware(scope, receive, send)
        else:
            # WebSocket upgrades bypass CSRF (separate auth applies)
            await self.app(scope, receive, send)
# Fall back to the session secret when no dedicated CSRF secret is configured
CSRF_SECRET_KEY = settings.CSRF_SECRET_KEY or SESSION_SECRET_KEY
app.add_middleware(
    CSRFMiddlewareHTTPOnly,
    csrf_middleware_class=CSRFMiddleware,
    secret=CSRF_SECRET_KEY,
    cookie_name="csrftoken",
    header_name="X-CSRFToken",
    cookie_secure=settings.SECURE_COOKIES,
    cookie_httponly=False,  # JS needs to read for SPA
    cookie_samesite="lax",  # CSRF protection
    exempt_urls=[
        # API endpoints are protected against CSRF via multiple layers:
        # 1. Primary: JWT in Authorization header (cannot be forged cross-origin)
        # 2. Secondary: Cookie auth uses samesite='lax' which blocks cross-origin POSTs
        # 3. Media endpoints (cookie-only) are GET requests, safe from CSRF
        # This makes additional CSRF token validation redundant for /api/* routes
        re.compile(r"^/api/.*"),
        re.compile(r"^/ws$"),  # WebSocket endpoint (has separate auth via cookie)
    ]
)
# Rate limiting with Redis persistence
# Falls back to in-memory if Redis unavailable
redis_uri = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
try:
    # NOTE(review): Limiter() may not connect to Redis eagerly, so this
    # except branch might never trigger for an unreachable server — verify
    # the fallback actually fires when Redis is down.
    limiter = Limiter(key_func=get_remote_address, storage_uri=redis_uri)
    logger.info("Rate limiter using Redis storage", module="RateLimit")
except Exception as e:
    logger.warning(f"Redis rate limiting unavailable, using in-memory: {e}", module="RateLimit")
    limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Import and include 2FA routes
from web.backend.twofa_routes import create_2fa_router
# ============================================================================
# WEBSOCKET CONNECTION MANAGER
# ============================================================================
class ConnectionManager:
    """Tracks active WebSocket connections and broadcasts JSON messages.

    The running event loop is captured on first connect so broadcast_sync()
    can schedule broadcasts safely from non-async worker threads.
    """

    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # Set on first connect; used by broadcast_sync() from worker threads
        self._loop = None

    async def connect(self, websocket: WebSocket):
        """Accept a WebSocket and register it for broadcasts."""
        await websocket.accept()
        self.active_connections.append(websocket)
        # Capture event loop for thread-safe broadcasts
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

    def disconnect(self, websocket: WebSocket):
        """Unregister a connection; tolerates double-removal."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass  # Already removed

    async def broadcast(self, message: dict):
        """Send *message* to every client, dropping connections that fail."""
        logger.info(f"ConnectionManager: Broadcasting to {len(self.active_connections)} clients", module="WebSocket")
        dead_connections = []
        for connection in self.active_connections:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.debug(f"ConnectionManager: Removing dead connection: {e}", module="WebSocket")
                dead_connections.append(connection)
        # Reuse disconnect() so removal semantics live in one place
        # (previously duplicated the try/remove/except ValueError pattern here)
        for connection in dead_connections:
            self.disconnect(connection)

    def broadcast_sync(self, message: dict):
        """Thread-safe broadcast for use in background tasks (sync threads)"""
        msg_type = message.get('type', 'unknown')
        if not self.active_connections:
            logger.info(f"broadcast_sync: No active WebSocket connections for {msg_type}", module="WebSocket")
            return
        logger.info(f"broadcast_sync: Sending {msg_type} to {len(self.active_connections)} clients", module="WebSocket")
        try:
            # Probe for a running loop: success means we're already in async
            # context and can schedule a task directly (result intentionally
            # unused - the call only serves as the context check)
            asyncio.get_running_loop()
            asyncio.create_task(self.broadcast(message))
            logger.info(f"broadcast_sync: Created task for {msg_type}", module="WebSocket")
        except RuntimeError:
            # No running loop - we're in a thread, use thread-safe call
            if self._loop and self._loop.is_running():
                asyncio.run_coroutine_threadsafe(self.broadcast(message), self._loop)
                logger.info(f"broadcast_sync: Used threadsafe call for {msg_type}", module="WebSocket")
            else:
                logger.warning(f"broadcast_sync: No event loop available for {msg_type}", module="WebSocket")
# Single shared connection manager for the process (requires workers=1)
manager = ConnectionManager()
# Store websocket manager in app_state for modular routers to access
app_state.websocket_manager = manager
# Initialize scraper event emitter for real-time monitoring
from modules.scraper_event_emitter import ScraperEventEmitter
app_state.scraper_event_emitter = ScraperEventEmitter(websocket_manager=manager, app_state=app_state)
logger.info("Scraper event emitter initialized for real-time monitoring", module="WebSocket")
# ============================================================================
# WEBSOCKET ENDPOINT
# ============================================================================
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates.

    Authentication priority:
    1. Cookie (most secure, set by browser automatically)
    2. Sec-WebSocket-Protocol header with 'auth.<token>' (secure, not logged)
    3. Query parameter (fallback, may be logged in access logs)

    Closes with code 4001 when no token is supplied or verification fails.
    """
    auth_token = None
    # Priority 1: Try cookie (most common for browsers)
    auth_token = websocket.cookies.get('auth_token')
    # Priority 2: Check Sec-WebSocket-Protocol header for "auth.<token>" subprotocol
    # This is more secure than query params as it's not logged in access logs
    if not auth_token:
        protocols = websocket.headers.get('sec-websocket-protocol', '')
        for protocol in protocols.split(','):
            protocol = protocol.strip()
            if protocol.startswith('auth.'):
                auth_token = protocol[5:]  # Remove 'auth.' prefix
                break
    # Priority 3: Query parameter fallback (least preferred, logged in access logs)
    if not auth_token:
        auth_token = websocket.query_params.get('token')
    if not auth_token:
        await websocket.close(code=4001, reason="Authentication required")
        logger.warning("WebSocket: Connection rejected - no auth cookie or token", module="WebSocket")
        return
    # Verify the token
    payload = app_state.auth.verify_session(auth_token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid or expired token")
        logger.warning("WebSocket: Connection rejected - invalid token", module="WebSocket")
        return
    username = payload.get('sub', 'unknown')
    # Accept with the auth subprotocol if that was used
    protocols = websocket.headers.get('sec-websocket-protocol', '')
    if any(p.strip().startswith('auth.') for p in protocols.split(',')):
        # Mirrors manager.connect() but accepts with the negotiated
        # subprotocol, which connect() does not support
        await websocket.accept(subprotocol='auth')
        manager.active_connections.append(websocket)
        if manager._loop is None:
            manager._loop = asyncio.get_running_loop()
    else:
        await manager.connect(websocket)
    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "timestamp": datetime.now().isoformat(),
            "user": username
        })
        logger.info(f"WebSocket: Client connected (user: {username})", module="WebSocket")
        while True:
            # Keep connection alive with ping/pong
            # Use raw receive() instead of receive_text() for better proxy compatibility
            try:
                message = await asyncio.wait_for(websocket.receive(), timeout=30.0)
                msg_type = message.get("type", "unknown")
                if msg_type == "websocket.receive":
                    text = message.get("text")
                    if text:
                        # Echo back client messages
                        await websocket.send_json({
                            "type": "echo",
                            "message": text
                        })
                elif msg_type == "websocket.disconnect":
                    # Client closed connection
                    code = message.get("code", "unknown")
                    logger.info(f"WebSocket: Client {username} sent disconnect (code={code})", module="WebSocket")
                    break
            except asyncio.TimeoutError:
                # No client traffic for 30s: send ping to keep connection alive
                try:
                    await websocket.send_json({
                        "type": "ping",
                        "timestamp": datetime.now().isoformat()
                    })
                except Exception as e:
                    # Connection lost
                    logger.info(f"WebSocket: Ping failed for {username}: {type(e).__name__}: {e}", module="WebSocket")
                    break
    except WebSocketDisconnect as e:
        logger.info(f"WebSocket: {username} disconnected (code={e.code})", module="WebSocket")
    except Exception as e:
        logger.error(f"WebSocket error for {username}: {type(e).__name__}: {e}", module="WebSocket")
    finally:
        # Always unregister, whatever path ended the session
        manager.disconnect(websocket)
        logger.info(f"WebSocket: {username} removed from active connections ({len(manager.active_connections)} remaining)", module="WebSocket")
# ============================================================================
# MAIN ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    import uvicorn
    # Development entry point - run the ASGI app directly with uvicorn
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # Single worker needed for WebSocket broadcasts to work correctly
        timeout_keep_alive=30,
        log_level="info",
        limit_concurrency=1000,  # Allow up to 1000 concurrent connections
        backlog=2048  # Queue size for pending connections
    )

565
web/backend/auth_manager.py Normal file
View File

@@ -0,0 +1,565 @@
#!/usr/bin/env python3
"""
Authentication Manager for Media Downloader
Handles user authentication and sessions
"""
import os
import sqlite3
import secrets
import hashlib
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List, Tuple
from passlib.context import CryptContext
from jose import jwt, JWTError
from pathlib import Path
# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT Configuration
def _load_jwt_secret():
"""Load JWT secret from file, environment, or generate new one"""
# Try to load from file first
secret_file = Path(__file__).parent.parent.parent / '.jwt_secret'
if secret_file.exists():
with open(secret_file, 'r') as f:
secret = f.read().strip()
# Validate secret has minimum length (at least 32 chars)
if len(secret) >= 32:
return secret
# If too short, regenerate
# Fallback to environment variable
if "JWT_SECRET_KEY" in os.environ:
secret = os.environ["JWT_SECRET_KEY"]
if len(secret) >= 32:
return secret
# Generate a cryptographically strong secret (384 bits / 48 bytes)
new_secret = secrets.token_urlsafe(48)
try:
with open(secret_file, 'w') as f:
f.write(new_secret)
os.chmod(secret_file, 0o600)
except Exception:
# Log warning but continue - in-memory secret will invalidate on restart
import logging
logging.getLogger(__name__).warning(
"Could not save JWT secret to file. Tokens will be invalidated on restart."
)
return new_secret
# Resolved once at import time; all tokens are signed with this key
SECRET_KEY = _load_jwt_secret()
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24  # 24 hours (default)
ACCESS_TOKEN_REMEMBER_MINUTES = 60 * 24 * 30  # 30 days (remember me)
class AuthManager:
def __init__(self, db_path: str = None):
if db_path is None:
db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')
self.db_path = db_path
# Rate limiting
self.login_attempts = {}
self.max_attempts = 5
self.lockout_duration = timedelta(minutes=15)
self._init_database()
self._create_default_user()
    def _init_database(self):
        """Initialize the authentication database schema (idempotent).

        Tables: users, backup_codes, sessions, login_attempts, auth_audit.
        Also migrates pre-existing installs by ALTER TABLE-ing the Duo 2FA
        columns onto the users table.
        """
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        with sqlite3.connect(self.db_path, timeout=30.0) as conn:
            cursor = conn.cursor()
            # Users table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    username TEXT PRIMARY KEY,
                    password_hash TEXT NOT NULL,
                    role TEXT NOT NULL DEFAULT 'viewer',
                    email TEXT,
                    is_active INTEGER NOT NULL DEFAULT 1,
                    totp_secret TEXT,
                    totp_enabled INTEGER NOT NULL DEFAULT 0,
                    duo_enabled INTEGER NOT NULL DEFAULT 0,
                    duo_username TEXT,
                    created_at TEXT NOT NULL,
                    last_login TEXT,
                    preferences TEXT
                )
            """)
            # Add duo_enabled column to existing table if it doesn't exist
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_enabled INTEGER NOT NULL DEFAULT 0")
            except sqlite3.OperationalError:
                pass  # Column already exists
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_username TEXT")
            except sqlite3.OperationalError:
                pass  # Column already exists
            # Backup codes table (hashed one-time 2FA recovery codes)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS backup_codes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    code_hash TEXT NOT NULL,
                    used INTEGER NOT NULL DEFAULT 0,
                    used_at TEXT,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY (username) REFERENCES users(username) ON DELETE CASCADE
                )
            """)
            # Sessions table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_token TEXT PRIMARY KEY,
                    username TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    FOREIGN KEY (username) REFERENCES users(username) ON DELETE CASCADE
                )
            """)
            # Login attempts table (for rate limiting)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS login_attempts (
                    username TEXT PRIMARY KEY,
                    attempts INTEGER NOT NULL DEFAULT 0,
                    last_attempt TEXT NOT NULL
                )
            """)
            # Audit log
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS auth_audit (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    action TEXT NOT NULL,
                    success INTEGER NOT NULL,
                    ip_address TEXT,
                    details TEXT,
                    timestamp TEXT NOT NULL
                )
            """)
            conn.commit()
    def _create_default_user(self):
        """Create default admin user if no users exist.

        The password comes from the ADMIN_PASSWORD env var when set;
        otherwise a random one is generated and written to ``.admin_password``
        (mode 0o600) rather than being logged.
        """
        import logging
        _logger = logging.getLogger(__name__)
        with sqlite3.connect(self.db_path, timeout=30.0) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM users")
            if cursor.fetchone()[0] == 0:
                # Get password from environment or generate secure random one
                default_password = os.environ.get("ADMIN_PASSWORD")
                password_generated = False
                if not default_password:
                    # Generate a secure random password (16 chars, URL-safe)
                    default_password = secrets.token_urlsafe(16)
                    password_generated = True
                password_hash = pwd_context.hash(default_password)
                cursor.execute("""
                    INSERT INTO users (username, password_hash, role, email, created_at, preferences)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    'admin',
                    password_hash,
                    'admin',
                    'admin@localhost',
                    datetime.now().isoformat(),
                    '{"theme": "dark"}'
                ))
                conn.commit()
                # Security: Never log the password - write to secure file instead
                if password_generated:
                    password_file = Path(__file__).parent.parent.parent / '.admin_password'
                    try:
                        with open(password_file, 'w') as f:
                            f.write(f"Admin password (delete after first login):\n{default_password}\n")
                        os.chmod(password_file, 0o600)
                        _logger.info("Created default admin user. Password saved to .admin_password")
                    except Exception:
                        # If we can't write file, show masked hint only
                        _logger.info("Created default admin user. Set ADMIN_PASSWORD env var to control password.")
                else:
                    _logger.info("Created default admin user with password from ADMIN_PASSWORD")
def verify_password(self, plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(self, password: str) -> str:
    """Return the salted hash for *password* using the configured scheme."""
    digest = pwd_context.hash(password)
    return digest
def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Build a signed JWT containing *data* plus an ``exp`` expiry claim.

    When *expires_delta* is not given, the configured default lifetime
    (ACCESS_TOKEN_EXPIRE_MINUTES) is used.
    """
    lifetime = expires_delta if expires_delta else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {**data, "exp": datetime.now(timezone.utc) + lifetime}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(self, token: str) -> Optional[Dict]:
    """Decode *token* and return its payload, or None if invalid/expired."""
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        # Signature mismatch, malformed token, or past its exp claim.
        return None
def check_rate_limit(self, username: str) -> Tuple[bool, Optional[str]]:
    """Check if user is rate limited.

    Returns:
        (allowed, error_message): ``allowed`` is False while the account is
        locked out after too many failed attempts; ``error_message`` then
        reports the remaining lockout time in minutes.
    """
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        cursor.execute("""
            SELECT attempts, last_attempt FROM login_attempts
            WHERE username = ?
        """, (username,))
        row = cursor.fetchone()
        if not row:
            # No recorded failures -> not limited
            return True, None
        attempts, last_attempt_str = row
        last_attempt = datetime.fromisoformat(last_attempt_str)
        if attempts >= self.max_attempts:
            time_since = datetime.now() - last_attempt
            if time_since < self.lockout_duration:
                remaining = self.lockout_duration - time_since
                return False, f"Account locked. Try again in {int(remaining.total_seconds() / 60)} minutes"
            else:
                # Reset after lockout period
                self._reset_login_attempts(username)
                return True, None
        # Some failures recorded but still below the max-attempt threshold
        return True, None
def _reset_login_attempts(self, username: str):
    """Remove any recorded failed-login counter for *username*."""
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        conn.execute("DELETE FROM login_attempts WHERE username = ?", (username,))
        conn.commit()
def _record_login_attempt(self, username: str, success: bool):
    """Record a login attempt.

    A successful login wipes the failure counter; a failure upserts a row,
    incrementing the attempt count and refreshing the ``last_attempt``
    timestamp consumed by check_rate_limit().
    """
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        if success:
            # Reset attempts on successful login
            cursor.execute("DELETE FROM login_attempts WHERE username = ?", (username,))
        else:
            # Increment attempts
            now = datetime.now().isoformat()
            cursor.execute("""
                INSERT INTO login_attempts (username, attempts, last_attempt)
                VALUES (?, 1, ?)
                ON CONFLICT(username) DO UPDATE SET
                    attempts = login_attempts.attempts + 1,
                    last_attempt = EXCLUDED.last_attempt
            """, (username, now))
        conn.commit()
def authenticate(self, username: str, password: str, ip_address: str = None, remember_me: bool = False) -> Dict:
    """Authenticate a user with username and password.

    Args:
        username: login name.
        password: plaintext password to verify.
        ip_address: client IP, recorded in the audit log / session row.
        remember_me: when True, the issued session uses the longer lifetime.

    Returns:
        On success, the session payload from _create_session(). On failure a
        dict with ``success=False``, an ``error`` message, and ``requires_2fa``
        (always False — 2FA was removed — kept so every failure payload has a
        consistent shape for the frontend).
    """
    # Lockout check comes first so a rate-limited account never reaches
    # password verification.
    allowed, error_msg = self.check_rate_limit(username)
    if not allowed:
        return {'success': False, 'error': error_msg, 'requires_2fa': False}
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        cursor.execute("""
            SELECT password_hash, role, is_active
            FROM users WHERE username = ?
        """, (username,))
        row = cursor.fetchone()
    if not row:
        # Same generic message as a wrong password so usernames can't be enumerated.
        self._record_login_attempt(username, False)
        self._log_audit(username, 'login_failed', False, ip_address, 'Invalid username')
        return {'success': False, 'error': 'Invalid credentials', 'requires_2fa': False}
    password_hash, role, is_active = row
    if not is_active:
        self._log_audit(username, 'login_failed', False, ip_address, 'Account inactive')
        return {'success': False, 'error': 'Account is inactive', 'requires_2fa': False}
    if not self.verify_password(password, password_hash):
        self._record_login_attempt(username, False)
        self._log_audit(username, 'login_failed', False, ip_address, 'Invalid password')
        return {'success': False, 'error': 'Invalid credentials', 'requires_2fa': False}
    # Password is correct, create session (2FA removed)
    return self._create_session(username, role, ip_address, remember_me)
def _create_session(self, username: str, role: str, ip_address: str = None, remember_me: bool = False) -> Dict:
    """Create a new session and return token - matches backup-central format.

    Issues a JWT, stores it in the ``sessions`` table (the JWT string itself
    is the session key checked by verify_session), stamps ``last_login``,
    audits the login, and returns the login payload.

    Args:
        username: the authenticated user.
        role: user's role, embedded in the JWT claims.
        ip_address: client IP stored with the session row.
        remember_me: when True, use the extended token lifetime.
    """
    import json
    import uuid
    # Create JWT token
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        # Get user email and preferences
        cursor.execute("SELECT email, preferences FROM users WHERE username = ?", (username,))
        row = cursor.fetchone()
        email = row[0] if row and row[0] else None
        # Parse preferences JSON (malformed JSON silently falls back to {})
        preferences = {}
        if row and row[1]:
            try:
                preferences = json.loads(row[1])
            except Exception:
                pass
        # Use longer expiration if remember_me is enabled
        expire_minutes = ACCESS_TOKEN_REMEMBER_MINUTES if remember_me else ACCESS_TOKEN_EXPIRE_MINUTES
        token_data = {"sub": username, "role": role, "email": email}
        token = self.create_access_token(token_data, expires_delta=timedelta(minutes=expire_minutes))
        # Generate session ID
        # NOTE(review): sessionId is returned to the client but never persisted;
        # only the JWT (session_token) is checked by verify_session — confirm
        # this matches the backup-central contract.
        sessionId = str(uuid.uuid4())
        now = datetime.now().isoformat()
        expires_at = (datetime.now() + timedelta(minutes=expire_minutes)).isoformat()
        cursor.execute("""
            INSERT INTO sessions (session_token, username, created_at, expires_at, ip_address)
            VALUES (?, ?, ?, ?, ?)
        """, (token, username, now, expires_at, ip_address))
        # Update last login
        cursor.execute("""
            UPDATE users SET last_login = ? WHERE username = ?
        """, (now, username))
        conn.commit()
        self._log_audit(username, 'login_success', True, ip_address)
        # Return structure matching backup-central (with preferences)
        return {
            'success': True,
            'token': token,
            'sessionId': sessionId,
            'user': {
                'username': username,
                'role': role,
                'email': email,
                'preferences': preferences
            }
        }
def _log_audit(self, username: str, action: str, success: bool, ip_address: str = None, details: str = None):
    """Append one row to the auth_audit trail."""
    audit_row = (username, action, int(success), ip_address, details, datetime.now().isoformat())
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        conn.execute("""
            INSERT INTO auth_audit (username, action, success, ip_address, details, timestamp)
            VALUES (?, ?, ?, ?, ?, ?)
        """, audit_row)
        conn.commit()
def get_user(self, username: str) -> Optional[Dict]:
    """Fetch a user's profile row as a dict, or None when the user is absent."""
    import json
    query = """
        SELECT username, role, email, is_active, totp_enabled, duo_enabled, duo_username, created_at, last_login, preferences
        FROM users WHERE username = ?
    """
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        row = conn.execute(query, (username,)).fetchone()
    if row is None:
        return None
    # Stored preferences are JSON text; malformed data degrades to {}.
    try:
        prefs = json.loads(row[9]) if row[9] else {}
    except Exception:
        prefs = {}
    return {
        'username': row[0],
        'role': row[1],
        'email': row[2],
        'is_active': bool(row[3]),
        'totp_enabled': bool(row[4]),
        'duo_enabled': bool(row[5]),
        'duo_username': row[6],
        'created_at': row[7],
        'last_login': row[8],
        'preferences': prefs,
    }
def verify_session(self, token: str) -> Optional[Dict]:
    """Verify a session token.

    The JWT must decode cleanly AND a matching row must exist in the
    ``sessions`` table (so logout / password change can revoke tokens that
    are otherwise still cryptographically valid).

    Returns:
        The decoded JWT payload, or None when the token is invalid, has no
        session row, or the session has expired.
    """
    payload = self.verify_token(token)
    if not payload:
        return None
    username = payload.get("sub")
    if not username:
        return None
    # Check if session exists and is valid
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        cursor.execute("""
            SELECT expires_at FROM sessions
            WHERE session_token = ? AND username = ?
        """, (token, username))
        row = cursor.fetchone()
        if not row:
            return None
        expires_at = datetime.fromisoformat(row[0])
        if expires_at < datetime.now():
            # Session expired - remove the stale row so the sessions table
            # doesn't grow unbounded with dead entries.
            cursor.execute("DELETE FROM sessions WHERE session_token = ?", (token,))
            conn.commit()
            return None
    return payload
def logout(self, token: str):
    """Invalidate a session by deleting its row from the sessions table."""
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        conn.execute("DELETE FROM sessions WHERE session_token = ?", (token,))
        conn.commit()
# ============================================================================
# DUO 2FA METHODS (Duo Universal Prompt)
# ============================================================================
# Duo authentication is now handled via DuoManager using Universal Prompt
# The flow is: login → redirect to Duo → callback → complete login
# All Duo operations are delegated to self.duo_manager
def change_password(self, username: str, current_password: str, new_password: str, ip_address: str = None) -> Dict:
    """Change user password.

    Verifies the current password, enforces a minimum length on the new
    one, stores the new hash and revokes every existing session for the
    user so a stolen token dies with the old password.

    Returns:
        {'success': True, 'sessions_invalidated': n} on success, otherwise
        {'success': False, 'error': reason}.
    """
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        # Get current password hash
        cursor.execute("SELECT password_hash FROM users WHERE username = ?", (username,))
        row = cursor.fetchone()
        if not row:
            return {'success': False, 'error': 'User not found'}
        current_hash = row[0]
        # Verify current password
        if not self.verify_password(current_password, current_hash):
            self._log_audit(username, 'password_change_failed', False, ip_address, 'Invalid current password')
            return {'success': False, 'error': 'Current password is incorrect'}
        # Validate new password (minimum 8 characters)
        if len(new_password) < 8:
            return {'success': False, 'error': 'New password must be at least 8 characters'}
        # Hash new password
        new_hash = self.get_password_hash(new_password)
        # Update password
        cursor.execute("""
            UPDATE users SET password_hash = ?
            WHERE username = ?
        """, (new_hash, username))
        # Security: Invalidate ALL existing sessions for this user
        # This ensures compromised sessions are revoked when password changes
        cursor.execute("""
            DELETE FROM sessions WHERE username = ?
        """, (username,))
        sessions_invalidated = cursor.rowcount
        conn.commit()
        self._log_audit(username, 'password_changed', True, ip_address,
                        f'Invalidated {sessions_invalidated} sessions')
        return {'success': True, 'sessions_invalidated': sessions_invalidated}
def update_preferences(self, username: str, preferences: dict) -> Dict:
    """Merge *preferences* into the user's stored preference JSON.

    Existing keys not present in *preferences* are preserved; supplied keys
    overwrite. Returns the merged preference dict on success.
    """
    import json
    with sqlite3.connect(self.db_path, timeout=30.0) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT preferences FROM users WHERE username = ?", (username,))
        row = cursor.fetchone()
        if not row:
            return {'success': False, 'error': 'User not found'}
        # Start from whatever is stored; malformed JSON degrades to {}.
        try:
            merged = json.loads(row[0]) if row[0] else {}
        except Exception:
            merged = {}
        merged.update(preferences)
        cursor.execute("""
            UPDATE users SET preferences = ?
            WHERE username = ?
        """, (json.dumps(merged), username))
        conn.commit()
    return {'success': True, 'preferences': merged}

View File

@@ -0,0 +1,224 @@
"""
Redis-based cache manager for Media Downloader API
Provides caching for expensive queries with configurable TTL
"""
import json
import sys
from pathlib import Path
from typing import Any, Optional, Callable
from functools import wraps
import redis
from redis.exceptions import RedisError
# Add parent path to allow imports from modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from modules.universal_logger import get_logger
from web.backend.core.config import settings
logger = get_logger('CacheManager')
class CacheManager:
    """Redis cache manager with automatic connection handling.

    Every operation degrades gracefully: when Redis is unreachable the
    manager behaves as a no-op cache, so callers never need to special-case
    an unavailable backend.
    """

    def __init__(self, host: str = '127.0.0.1', port: int = 6379, db: int = 0, ttl: int = 300):
        """
        Initialize cache manager

        Args:
            host: Redis host (default: 127.0.0.1)
            port: Redis port (default: 6379)
            db: Redis database number (default: 0)
            ttl: Default TTL in seconds (default: 300 = 5 minutes)
        """
        self.host = host
        self.port = port
        self.db = db
        self.default_ttl = ttl
        self._redis = None
        self._connect()

    def _connect(self):
        """Connect to Redis with error handling; on failure caching is disabled."""
        try:
            self._redis = redis.Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                decode_responses=True,
                socket_connect_timeout=2,
                socket_timeout=2
            )
            # Test connection
            self._redis.ping()
            logger.info(f"Connected to Redis at {self.host}:{self.port}", module="Redis")
        except RedisError as e:
            logger.warning(f"Redis connection failed: {e}. Caching disabled.", module="Redis")
            self._redis = None

    @property
    def is_available(self) -> bool:
        """Check if Redis is available (live PING, not just a cached flag)."""
        if self._redis is None:
            return False
        try:
            self._redis.ping()
            return True
        except RedisError:
            return False

    @staticmethod
    def _make_key(key_prefix: str, func_name: str, args: tuple, kwargs: dict) -> str:
        """
        Build a deterministic cache key for a function call.

        Uses an MD5 digest of the repr of the arguments. The builtin hash()
        must NOT be used here: string hashing is salted per process
        (PYTHONHASHSEED), so hash()-based keys differ between workers and
        across restarts, defeating a shared Redis cache.
        """
        import hashlib
        fingerprint = repr((args, sorted(kwargs.items())))
        digest = hashlib.md5(fingerprint.encode('utf-8')).hexdigest()
        return f"{key_prefix}:{func_name}:{digest}"

    def get(self, key: str) -> Optional[Any]:
        """
        Get value from cache

        Args:
            key: Cache key

        Returns:
            Cached value (deserialized from JSON) or None if not found/error
        """
        if not self.is_available:
            return None
        try:
            value = self._redis.get(key)
            if value is None:
                return None
            return json.loads(value)
        except (RedisError, json.JSONDecodeError) as e:
            logger.warning(f"Cache get error for key '{key}': {e}", module="Cache")
            return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None):
        """
        Set value in cache

        Args:
            key: Cache key
            value: Value to cache (will be JSON serialized)
            ttl: TTL in seconds (default: use default_ttl)
        """
        if not self.is_available:
            return
        ttl = ttl if ttl is not None else self.default_ttl
        try:
            serialized = json.dumps(value)
            self._redis.setex(key, ttl, serialized)
        except (RedisError, TypeError) as e:
            logger.warning(f"Cache set error for key '{key}': {e}", module="Cache")

    def delete(self, key: str):
        """
        Delete key from cache

        Args:
            key: Cache key to delete
        """
        if not self.is_available:
            return
        try:
            self._redis.delete(key)
        except RedisError as e:
            logger.warning(f"Cache delete error for key '{key}': {e}", module="Cache")

    def clear(self, pattern: str = "*"):
        """
        Clear cache keys matching pattern

        Args:
            pattern: Redis key pattern (default: "*" clears all)
        """
        if not self.is_available:
            return
        try:
            # SCAN instead of KEYS: KEYS is O(N) and blocks the Redis event
            # loop on large databases, while SCAN iterates incrementally.
            keys = list(self._redis.scan_iter(match=pattern))
            if keys:
                self._redis.delete(*keys)
                logger.info(f"Cleared {len(keys)} cache keys matching '{pattern}'", module="Cache")
        except RedisError as e:
            logger.warning(f"Cache clear error for pattern '{pattern}': {e}", module="Cache")

    def cached(self, key_prefix: str, ttl: Optional[int] = None):
        """
        Decorator for caching function results

        Args:
            key_prefix: Prefix for cache key (full key includes function args)
            ttl: TTL in seconds (default: use default_ttl)

        Example:
            @cache_manager.cached('stats', ttl=300)
            def get_download_stats(platform: str, days: int):
                # Expensive query
                return stats
        """
        def decorator(func: Callable) -> Callable:
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # Stable, process-independent cache key (see _make_key)
                cache_key = self._make_key(key_prefix, func.__name__, args, kwargs)
                cached_result = self.get(cache_key)
                if cached_result is not None:
                    logger.debug(f"Cache HIT: {cache_key}", module="Cache")
                    return cached_result
                # Cache miss - execute function and store the result
                logger.debug(f"Cache MISS: {cache_key}", module="Cache")
                result = await func(*args, **kwargs)
                self.set(cache_key, result, ttl)
                return result

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                cache_key = self._make_key(key_prefix, func.__name__, args, kwargs)
                cached_result = self.get(cache_key)
                if cached_result is not None:
                    logger.debug(f"Cache HIT: {cache_key}", module="Cache")
                    return cached_result
                logger.debug(f"Cache MISS: {cache_key}", module="Cache")
                result = func(*args, **kwargs)
                self.set(cache_key, result, ttl)
                return result

            # Return appropriate wrapper based on function type
            import asyncio
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper
        return decorator
# Global cache manager instance (use centralized config)
# Module-level singleton: consumers do `from cache_manager import cache_manager`
# and share one Redis connection pool across the process.
cache_manager = CacheManager(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    ttl=settings.REDIS_TTL
)
def invalidate_download_cache():
    """Drop every cache entry belonging to the download domain."""
    # Each prefix corresponds to one family of cached API responses.
    for pattern in ("downloads:*", "stats:*", "filters:*"):
        cache_manager.clear(pattern)
    logger.info("Invalidated download-related caches", module="Cache")

View File

@@ -0,0 +1 @@
# Core module exports

121
web/backend/core/config.py Normal file
View File

@@ -0,0 +1,121 @@
"""
Unified Configuration Manager
Provides a single source of truth for all configuration values with a clear hierarchy:
1. Environment variables (highest priority)
2. .env file
3. Database settings
4. Hardcoded defaults (lowest priority)
"""
import os
from pathlib import Path
from typing import Any, Dict, Optional
from functools import lru_cache
from dotenv import load_dotenv
# Load environment variables from .env file
# Resolve the repository root relative to this file (web/backend/core/ -> root)
# so the .env is found regardless of the process's working directory.
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
ENV_PATH = PROJECT_ROOT / '.env'
if ENV_PATH.exists():
    load_dotenv(ENV_PATH)
class Settings:
    """Centralized configuration settings.

    Values are resolved at import time with the hierarchy: environment
    variable -> .env file (loaded above) -> hardcoded default.
    """
    # Project paths
    PROJECT_ROOT: Path = PROJECT_ROOT
    DB_PATH: Path = PROJECT_ROOT / 'database' / 'media_downloader.db'
    CONFIG_PATH: Path = PROJECT_ROOT / 'config' / 'settings.json'
    LOG_PATH: Path = PROJECT_ROOT / 'logs'
    TEMP_DIR: Path = PROJECT_ROOT / 'temp'
    COOKIES_DIR: Path = PROJECT_ROOT / 'cookies'
    DATA_DIR: Path = PROJECT_ROOT / 'data'
    VENV_BIN: Path = PROJECT_ROOT / 'venv' / 'bin'
    MEDIA_BASE_PATH: Path = Path(os.getenv("MEDIA_BASE_PATH", "/opt/immich/md"))
    REVIEW_PATH: Path = Path(os.getenv("REVIEW_PATH", "/opt/immich/review"))
    RECYCLE_PATH: Path = Path(os.getenv("RECYCLE_PATH", "/opt/immich/recycle"))
    # External tool paths (use venv by default)
    YT_DLP_PATH: Path = VENV_BIN / 'yt-dlp'
    GALLERY_DL_PATH: Path = VENV_BIN / 'gallery-dl'
    PYTHON_PATH: Path = VENV_BIN / 'python3'
    # Database configuration
    DATABASE_BACKEND: str = os.getenv("DATABASE_BACKEND", "sqlite")
    DATABASE_URL: str = os.getenv("DATABASE_URL", "")
    DB_POOL_SIZE: int = int(os.getenv("DB_POOL_SIZE", "20"))
    DB_CONNECTION_TIMEOUT: float = float(os.getenv("DB_CONNECTION_TIMEOUT", "30.0"))
    # Process timeouts (seconds)
    PROCESS_TIMEOUT_SHORT: int = int(os.getenv("PROCESS_TIMEOUT_SHORT", "2"))
    PROCESS_TIMEOUT_MEDIUM: int = int(os.getenv("PROCESS_TIMEOUT_MEDIUM", "10"))
    PROCESS_TIMEOUT_LONG: int = int(os.getenv("PROCESS_TIMEOUT_LONG", "120"))
    # WebSocket configuration
    WEBSOCKET_TIMEOUT: float = float(os.getenv("WEBSOCKET_TIMEOUT", "30.0"))
    # Background task intervals (seconds)
    ACTIVITY_LOG_INTERVAL: int = int(os.getenv("ACTIVITY_LOG_INTERVAL", "300"))
    EMBEDDING_QUEUE_CHECK_INTERVAL: int = int(os.getenv("EMBEDDING_QUEUE_CHECK_INTERVAL", "30"))
    EMBEDDING_BATCH_SIZE: int = int(os.getenv("EMBEDDING_BATCH_SIZE", "10"))
    EMBEDDING_BATCH_LIMIT: int = int(os.getenv("EMBEDDING_BATCH_LIMIT", "50000"))
    # Thumbnail generation
    THUMBNAIL_DB_TIMEOUT: float = float(os.getenv("THUMBNAIL_DB_TIMEOUT", "30.0"))
    THUMBNAIL_SIZE: tuple = (300, 300)
    # Video proxy configuration
    PROXY_FILE_CACHE_DURATION: int = int(os.getenv("PROXY_FILE_CACHE_DURATION", "300"))
    # Redis cache configuration
    REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1")
    REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
    REDIS_DB: int = int(os.getenv("REDIS_DB", "0"))
    REDIS_TTL: int = int(os.getenv("REDIS_TTL", "300"))
    # API configuration
    API_VERSION: str = "13.13.1"
    API_TITLE: str = "Media Downloader API"
    API_DESCRIPTION: str = "Web API for managing media downloads"
    # CORS configuration
    # Origins are stripped so a value like "a, b" doesn't yield " b" (which
    # would never match a browser Origin header); empty fragments from
    # trailing commas are dropped.
    ALLOWED_ORIGINS: list = [
        origin.strip()
        for origin in os.getenv(
            "ALLOWED_ORIGINS",
            "http://localhost:5173,http://localhost:3000,http://127.0.0.1:5173,http://127.0.0.1:3000"
        ).split(",")
        if origin.strip()
    ]
    # Security
    SECURE_COOKIES: bool = os.getenv("SECURE_COOKIES", "false").lower() == "true"
    SESSION_SECRET_KEY: str = os.getenv("SESSION_SECRET_KEY", "")
    CSRF_SECRET_KEY: str = os.getenv("CSRF_SECRET_KEY", "")
    # Rate limiting
    RATE_LIMIT_DEFAULT: str = os.getenv("RATE_LIMIT_DEFAULT", "100/minute")
    RATE_LIMIT_STRICT: str = os.getenv("RATE_LIMIT_STRICT", "10/minute")

    @classmethod
    def get(cls, key: str, default: Any = None) -> Any:
        """Get a configuration value by key, with an optional default."""
        return getattr(cls, key, default)

    @classmethod
    def get_path(cls, key: str) -> Path:
        """Get a path configuration value, ensuring it's a Path object.

        Raises:
            ValueError: if *key* is not a known configuration attribute.
        """
        value = getattr(cls, key, None)
        if value is None:
            raise ValueError(f"Configuration key {key} not found")
        if isinstance(value, Path):
            return value
        return Path(value)
@lru_cache()
def get_settings() -> Settings:
    """Get cached settings instance (lru_cache makes this an effective singleton)."""
    return Settings()


# Convenience exports
# Module-level instance so callers can simply `from ...core.config import settings`.
settings = get_settings()

View File

@@ -0,0 +1,206 @@
"""
Shared Dependencies
Provides dependency injection for FastAPI routes.
All authentication, database access, and common dependencies are defined here.
"""
import sys
from pathlib import Path
from typing import Dict, Optional, List
from datetime import datetime
from fastapi import Depends, HTTPException, Query, Request, WebSocket
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from modules.universal_logger import get_logger
logger = get_logger('API')
# Security
# auto_error=False: a missing Authorization header yields None instead of an
# immediate error, letting the dependencies below fall back to cookie or
# query-parameter tokens.
security = HTTPBearer(auto_error=False)
# ============================================================================
# GLOBAL STATE (imported from main app)
# ============================================================================
class AppState:
    """Global application state - singleton pattern.

    One instance per process: created lazily on the first AppState() call
    and returned unchanged thereafter. Service handles start as None and
    are assigned by the main application's startup code.
    """
    _instance = None  # the singleton instance, created on first __new__

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on *every* AppState() call; skip re-initialization
        # after the first so assigned service handles are never clobbered.
        if self._initialized:
            return
        self.db = None  # database handle (served via get_database)
        self.config: Dict = {}  # loaded configuration values
        self.scheduler = None  # scheduler instance (served via get_scheduler)
        self.websocket_manager = None  # ConnectionManager instance for broadcasts
        self.scraper_event_emitter = None  # ScraperEventEmitter for real-time scraping monitor
        self.active_scraper_sessions: Dict = {}  # Track active scraping sessions for monitor page
        self.running_platform_downloads: Dict = {}  # Track running platform downloads {platform: process}
        self.download_tasks: Dict = {}  # in-flight download task handles
        self.auth = None  # auth manager used by the authentication dependencies
        self.settings = None  # settings manager (served via get_settings_manager)
        self.face_recognition_semaphore = None  # limits concurrent face-recognition work
        self.indexing_running: bool = False  # True while a media index pass is active
        self.indexing_start_time = None
        self.review_rescan_running: bool = False
        self.review_rescan_progress: Dict = {"current": 0, "total": 0, "current_file": ""}
        self._initialized = True
def get_app_state() -> AppState:
    """Return the process-wide AppState singleton (constructor is idempotent)."""
    return AppState()
# ============================================================================
# AUTHENTICATION DEPENDENCIES
# ============================================================================
async def get_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict:
    """
    Dependency to get current authenticated user from JWT token.
    Supports both Authorization header and cookie-based authentication
    (header takes precedence).
    """
    state = get_app_state()
    # Prefer the bearer header; fall back to the auth_token cookie.
    auth_token = credentials.credentials if credentials else request.cookies.get('auth_token')
    if not auth_token:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if not state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")
    session_payload = state.auth.verify_session(auth_token)
    if not session_payload:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return session_payload
async def get_current_user_optional(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict]:
    """Like get_current_user, but yields None instead of raising 401/500."""
    try:
        user = await get_current_user(request, credentials)
    except HTTPException:
        user = None
    return user
async def get_current_user_media(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    token: Optional[str] = Query(None)
) -> Dict:
    """
    Authentication for media endpoints.
    Supports header, cookie, and query parameter tokens (for img/video tags).
    Priority: 1) Authorization header, 2) Cookie, 3) Query parameter
    """
    state = get_app_state()
    # Resolve the token by source priority; the first present source wins.
    if credentials:
        auth_token = credentials.credentials
    elif 'auth_token' in request.cookies:
        auth_token = request.cookies.get('auth_token')
    elif token:
        auth_token = token
    else:
        auth_token = None
    if not auth_token:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if not state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")
    session_payload = state.auth.verify_session(auth_token)
    if not session_payload:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return session_payload
async def require_admin(
    request: Request,
    current_user: Dict = Depends(get_current_user)
) -> Dict:
    """Dependency that rejects non-admin users with 403 (401 on bad session)."""
    state = get_app_state()
    username = current_user.get('sub')
    if not username:
        raise HTTPException(status_code=401, detail="Invalid user session")
    # The JWT alone isn't trusted for role checks; re-read the user record.
    user_info = state.auth.get_user(username)
    if not user_info:
        raise HTTPException(status_code=401, detail="User not found")
    if user_info.get('role') != 'admin':
        logger.warning(f"User {username} attempted admin operation without admin role", module="Auth")
        raise HTTPException(status_code=403, detail="Admin role required")
    return current_user
# ============================================================================
# DATABASE DEPENDENCIES
# ============================================================================
def get_database():
    """FastAPI dependency: return the shared database handle (500 if unset)."""
    state = get_app_state()
    if not state.db:
        raise HTTPException(status_code=500, detail="Database not initialized")
    return state.db
def get_settings_manager():
    """FastAPI dependency: return the settings manager (500 if unset)."""
    state = get_app_state()
    if not state.settings:
        raise HTTPException(status_code=500, detail="Settings manager not initialized")
    return state.settings
# ============================================================================
# UTILITY DEPENDENCIES
# ============================================================================
def get_scheduler():
    """Return the scheduler instance (may be None if not started)."""
    return get_app_state().scheduler
def get_face_semaphore():
    """Return the semaphore limiting concurrent face-recognition operations."""
    return get_app_state().face_recognition_semaphore

View File

@@ -0,0 +1,346 @@
"""
Custom Exception Classes
Provides specific exception types for better error handling and reporting.
Replaces broad 'except Exception' handlers with specific, meaningful exceptions.
"""
from typing import Optional, Dict, Any
from fastapi import HTTPException
# ============================================================================
# BASE EXCEPTIONS
# ============================================================================
class MediaDownloaderError(Exception):
    """Base exception for all Media Downloader errors.

    Carries a human-readable *message* plus an optional *details* mapping
    that is serialized into API error responses via to_dict().
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        self.message = message
        self.details = {} if details is None else details
        super().__init__(message)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the shape embedded in HTTP error bodies."""
        return {
            "error": type(self).__name__,
            "message": self.message,
            "details": self.details,
        }
# ============================================================================
# DATABASE EXCEPTIONS
# ============================================================================
class DatabaseError(MediaDownloaderError):
    """Database operation failed (base class for all DB-related errors)."""
    pass


class DatabaseConnectionError(DatabaseError):
    """Failed to connect to database"""
    pass


class DatabaseQueryError(DatabaseError):
    """Database query failed"""
    pass


class RecordNotFoundError(DatabaseError):
    """Requested record not found in database"""
    pass


# Alias for generic "not found" scenarios
NotFoundError = RecordNotFoundError


class DuplicateRecordError(DatabaseError):
    """Attempted to insert duplicate record"""
    pass


# ============================================================================
# DOWNLOAD EXCEPTIONS
# ============================================================================
class DownloadError(MediaDownloaderError):
    """Download operation failed (base class for platform/download errors)."""
    pass


class NetworkError(DownloadError):
    """Network request failed"""
    pass


class RateLimitError(DownloadError):
    """Rate limit exceeded"""
    pass


class AuthenticationError(DownloadError):
    """Authentication failed for external service"""
    pass


class PlatformUnavailableError(DownloadError):
    """Platform or service is unavailable"""
    pass


class ContentNotFoundError(DownloadError):
    """Requested content not found on platform"""
    pass


class InvalidURLError(DownloadError):
    """Invalid or malformed URL"""
    pass
# ============================================================================
# FILE OPERATION EXCEPTIONS
# ============================================================================
class FileOperationError(MediaDownloaderError):
    """File operation failed (base class for filesystem-related errors)."""
    pass


class MediaFileNotFoundError(FileOperationError):
    """File not found on filesystem"""
    pass


class FileAccessError(FileOperationError):
    """Cannot access file (permissions, etc.)"""
    pass


class FileHashError(FileOperationError):
    """File hash computation failed"""
    pass


class FileMoveError(FileOperationError):
    """Failed to move file"""
    pass


class ThumbnailError(FileOperationError):
    """Thumbnail generation failed"""
    pass


# ============================================================================
# VALIDATION EXCEPTIONS
# ============================================================================
class ValidationError(MediaDownloaderError):
    """Input validation failed (base class for bad-request style errors)."""
    pass


class InvalidParameterError(ValidationError):
    """Invalid parameter value"""
    pass


class MissingParameterError(ValidationError):
    """Required parameter missing"""
    pass


class PathTraversalError(ValidationError):
    """Path traversal attempt detected"""
    pass
# ============================================================================
# AUTHENTICATION/AUTHORIZATION EXCEPTIONS
# ============================================================================
class AuthError(MediaDownloaderError):
    """Authentication or authorization error (base class)."""
    pass


class TokenExpiredError(AuthError):
    """Authentication token has expired"""
    pass


class InvalidTokenError(AuthError):
    """Authentication token is invalid"""
    pass


class InsufficientPermissionsError(AuthError):
    """User lacks required permissions"""
    pass


# ============================================================================
# SERVICE EXCEPTIONS
# ============================================================================
class ServiceError(MediaDownloaderError):
    """External service error (base class for upstream-service failures)."""
    pass


class FlareSolverrError(ServiceError):
    """FlareSolverr service error"""
    pass


class RedisError(ServiceError):
    """Redis cache error"""
    pass


class SchedulerError(ServiceError):
    """Scheduler service error"""
    pass


# ============================================================================
# FACE RECOGNITION EXCEPTIONS
# ============================================================================
class FaceRecognitionError(MediaDownloaderError):
    """Face recognition operation failed (base class)."""
    pass


class NoFaceDetectedError(FaceRecognitionError):
    """No face detected in image"""
    pass


class FaceEncodingError(FaceRecognitionError):
    """Failed to encode face"""
    pass
# ============================================================================
# HTTP EXCEPTION HELPERS
# ============================================================================
def to_http_exception(error: MediaDownloaderError) -> HTTPException:
    """Convert a MediaDownloaderError to an HTTPException.

    The error's MRO is walked so the most specific mapped ancestor always
    wins, regardless of the dict's insertion order (the previous isinstance
    scan silently depended on subclasses being listed before their parents).
    """
    # Map exception types to HTTP status codes
    status_map = {
        # 400 Bad Request
        InvalidParameterError: 400,
        MissingParameterError: 400,
        InvalidURLError: 400,
        ValidationError: 400,
        # 401 Unauthorized
        TokenExpiredError: 401,
        InvalidTokenError: 401,
        AuthenticationError: 401,
        AuthError: 401,
        # 403 Forbidden
        InsufficientPermissionsError: 403,
        PathTraversalError: 403,
        FileAccessError: 403,
        # 404 Not Found
        RecordNotFoundError: 404,
        MediaFileNotFoundError: 404,
        ContentNotFoundError: 404,
        # 409 Conflict
        DuplicateRecordError: 409,
        # 429 Too Many Requests
        RateLimitError: 429,
        # 500 Internal Server Error
        DatabaseConnectionError: 500,
        DatabaseQueryError: 500,
        DatabaseError: 500,
        # 502 Bad Gateway
        FlareSolverrError: 502,
        ServiceError: 502,
        NetworkError: 502,
        # 503 Service Unavailable
        PlatformUnavailableError: 503,
        SchedulerError: 503,
    }
    # Walk the error's class hierarchy from most- to least-specific; the
    # first mapped class encountered is by construction the best match.
    status_code = 500  # Default to internal server error
    for klass in type(error).__mro__:
        if klass in status_map:
            status_code = status_map[klass]
            break
    return HTTPException(
        status_code=status_code,
        detail=error.to_dict()
    )
# ============================================================================
# EXCEPTION HANDLER DECORATOR
# ============================================================================
from functools import wraps
from typing import Callable, TypeVar
import asyncio
T = TypeVar('T')
def handle_exceptions(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator to convert MediaDownloaderError exceptions to HTTPException.

    Use this on route handlers for consistent error responses. Works for
    both sync and async handlers:

    - HTTPException raised by the handler passes through untouched.
    - MediaDownloaderError is mapped via :func:`to_http_exception`.
    - Any other exception is logged and surfaced as a generic 500 payload.

    Unlike the previous version, the raised HTTPException is chained to the
    original exception (``raise ... from e``) so tracebacks keep the cause,
    and the identical sync/async bodies share one conversion helper.
    """
    def _convert(exc: Exception) -> HTTPException:
        # Shared conversion/logging path for both wrappers.
        if isinstance(exc, MediaDownloaderError):
            return to_http_exception(exc)
        # Log unexpected exceptions before masking them as a 500.
        from modules.universal_logger import get_logger
        get_logger('API').error(
            f"Unexpected error in {func.__name__}: {exc}", module="Core"
        )
        return HTTPException(
            status_code=500,
            detail={"error": "InternalError", "message": str(exc)}
        )

    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except HTTPException:
                raise  # Already shaped as an HTTP response
            except Exception as e:
                raise _convert(e) from e
        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except HTTPException:
            raise
        except Exception as e:
            raise _convert(e) from e
    return sync_wrapper

View File

@@ -0,0 +1,377 @@
"""
Async HTTP Client
Provides async HTTP client to replace synchronous requests library in FastAPI endpoints.
This prevents blocking the event loop and improves concurrency.
Usage:
from web.backend.core.http_client import http_client, get_http_client
# In async endpoint
async def my_endpoint():
async with get_http_client() as client:
response = await client.get("https://api.example.com/data")
return response.json()
# Or use the singleton
response = await http_client.get("https://api.example.com/data")
"""
import httpx
from typing import Optional, Dict, Any, Union
from contextlib import asynccontextmanager
import asyncio
from .config import settings
from .exceptions import NetworkError, ServiceError
# Default timeouts (seconds) for typical API calls: short connect/write/pool,
# a moderate read window for normal responses.
DEFAULT_TIMEOUT = httpx.Timeout(
    connect=10.0,
    read=30.0,
    write=10.0,
    pool=10.0
)
# Longer timeout for slow operations (large payloads / slow upstreams).
# NOTE(review): exported for callers that pass it to AsyncHTTPClient(timeout=...);
# not referenced in this file itself — confirm external usage before removing.
LONG_TIMEOUT = httpx.Timeout(
    connect=10.0,
    read=120.0,
    write=30.0,
    pool=10.0
)
class AsyncHTTPClient:
    """
    Async HTTP client wrapper with retry logic and error handling.

    Features:
    - Connection pooling (lazily created, shared httpx.AsyncClient)
    - Automatic retries with exponential backoff (1s, 2s, 4s, ...)
    - Timeout handling (client-wide default, optional per-request override)
    - Error conversion to NetworkError / ServiceError

    Fixes over the previous version:
    - A 5xx response now raises ServiceError (it was previously swallowed by
      the generic ``except Exception`` inside the retry loop and re-wrapped
      as NetworkError, losing the error type).
    - Per-request ``timeout=None`` is no longer forwarded to httpx, where an
      explicit ``None`` DISABLES the client-level timeout entirely.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: Optional[httpx.Timeout] = None,
        headers: Optional[Dict[str, str]] = None,
        max_retries: int = 3
    ):
        """
        Args:
            base_url: Optional base URL prefixed to relative request URLs.
            timeout: Default timeout config (falls back to DEFAULT_TIMEOUT).
            headers: Default headers sent with every request.
            max_retries: Attempts per request before giving up (min 1).
        """
        self.base_url = base_url
        self.timeout = timeout or DEFAULT_TIMEOUT
        self.headers = headers or {}
        self.max_retries = max_retries
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily (re)create the underlying pooled httpx client."""
        if self._client is None or self._client.is_closed:
            client_kwargs = {
                "timeout": self.timeout,
                "headers": self.headers,
                "follow_redirects": True,
                "limits": httpx.Limits(
                    max_keepalive_connections=20,
                    max_connections=100,
                    keepalive_expiry=30.0
                )
            }
            # httpx rejects base_url=None, so only pass it when set.
            if self.base_url is not None:
                client_kwargs["base_url"] = self.base_url
            self._client = httpx.AsyncClient(**client_kwargs)
        return self._client

    async def close(self):
        """Close the client connection pool (safe to call repeatedly)."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def _request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> httpx.Response:
        """
        Make an HTTP request with retry logic.

        Connection errors, timeouts and 5xx responses are retried with
        exponential backoff; 4xx responses are returned to the caller
        immediately without retrying.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: URL to request
            **kwargs: Additional arguments forwarded to httpx

        Returns:
            httpx.Response (success or 4xx)

        Raises:
            NetworkError: On network-level failures after all retries
            ServiceError: On persistent 5xx responses
        """
        client = await self._get_client()
        last_error: Optional[Exception] = None
        # max(1, ...) guards against max_retries <= 0, which previously
        # reached ``raise last_error`` with last_error still None.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                response = await client.request(method, url, **kwargs)
            except httpx.ConnectError as e:
                last_error = NetworkError(
                    f"Connection failed: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            except httpx.TimeoutException as e:
                last_error = NetworkError(
                    f"Request timed out: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            except httpx.HTTPStatusError as e:
                last_error = ServiceError(
                    f"HTTP error: {e}",
                    {"url": url, "status": e.response.status_code}
                )
                # Client errors (4xx) are not transient; fail immediately.
                if e.response.status_code < 500:
                    raise last_error
            except Exception as e:
                last_error = NetworkError(
                    f"Request failed: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            else:
                if response.status_code >= 500:
                    # Record the 5xx OUTSIDE the try block so it is not
                    # swallowed by the generic handler above; retry it.
                    last_error = ServiceError(
                        f"Server error: {response.status_code}",
                        {"url": url, "status": response.status_code}
                    )
                else:
                    # Success or 4xx: hand back to the caller, no retry.
                    return response
            # Exponential backoff before the next attempt.
            if attempt < attempts - 1:
                await asyncio.sleep(2 ** attempt)
        raise last_error

    # Convenience methods
    async def get(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> httpx.Response:
        """Make GET request. ``timeout`` (seconds) overrides the client default."""
        if timeout is not None:
            # Only override when explicitly provided; forwarding None would
            # disable the client-level timeout in httpx.
            kwargs["timeout"] = httpx.Timeout(timeout)
        return await self._request(
            "GET",
            url,
            params=params,
            headers=headers,
            **kwargs
        )

    async def post(
        self,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> httpx.Response:
        """Make POST request. ``timeout`` (seconds) overrides the client default."""
        if timeout is not None:
            kwargs["timeout"] = httpx.Timeout(timeout)
        return await self._request(
            "POST",
            url,
            data=data,
            json=json,
            headers=headers,
            **kwargs
        )

    async def put(
        self,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make PUT request."""
        return await self._request(
            "PUT",
            url,
            data=data,
            json=json,
            headers=headers,
            **kwargs
        )

    async def delete(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make DELETE request."""
        return await self._request(
            "DELETE",
            url,
            headers=headers,
            **kwargs
        )

    async def head(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make HEAD request."""
        return await self._request(
            "HEAD",
            url,
            headers=headers,
            **kwargs
        )
# Singleton client for general use
http_client = AsyncHTTPClient()
@asynccontextmanager
async def get_http_client(
    base_url: Optional[str] = None,
    timeout: Optional[httpx.Timeout] = None,
    headers: Optional[Dict[str, str]] = None
):
    """
    Async context manager yielding a short-lived AsyncHTTPClient.

    The client's connection pool is always closed on exit, even when the
    body raises.

    Usage:
        async with get_http_client(base_url="https://api.example.com") as client:
            response = await client.get("/endpoint")
    """
    temp_client = AsyncHTTPClient(base_url=base_url, timeout=timeout, headers=headers)
    try:
        yield temp_client
    finally:
        # Guaranteed cleanup regardless of how the block exits.
        await temp_client.close()
# Helper functions for common patterns
async def fetch_json(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """
    GET *url* via the shared singleton client and return the parsed JSON body.

    Args:
        url: URL to fetch
        params: Optional query parameters
        headers: Optional request headers

    Returns:
        Parsed JSON response
    """
    resp = await http_client.get(url, params=params, headers=headers)
    return resp.json()
async def post_json(
    url: str,
    data: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """
    POST *data* as a JSON body to *url* and return the parsed JSON response.

    Args:
        url: URL to post to
        data: JSON-serializable payload
        headers: Optional request headers

    Returns:
        Parsed JSON response
    """
    resp = await http_client.post(url, json=data, headers=headers)
    return resp.json()
async def check_url_accessible(url: str, timeout: float = 5.0) -> bool:
    """
    Probe *url* with a HEAD request.

    Args:
        url: URL to check
        timeout: Request timeout in seconds

    Returns:
        True when the request succeeds with a non-error (< 400) status;
        False on any error or exception.
    """
    try:
        head_resp = await http_client.head(url, timeout=timeout)
    except Exception:
        return False
    return head_resp.status_code < 400
async def download_file(
    url: str,
    save_path: str,
    headers: Optional[Dict[str, str]] = None,
    chunk_size: int = 8192
) -> int:
    """
    Stream a file from *url* to *save_path*.

    On any failure after the output file has been opened, the partially
    written file is removed so callers never mistake a truncated download
    for a complete one (the previous version left partial files behind).

    Args:
        url: URL to download
        save_path: Path to save file
        headers: Request headers
        chunk_size: Size of chunks for streaming

    Returns:
        Number of bytes downloaded

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.

    NOTE(review): the file writes here are synchronous and briefly block the
    event loop; acceptable for small files — revisit for very large ones.
    """
    from pathlib import Path
    client = await http_client._get_client()
    total_bytes = 0
    opened_output = False
    try:
        async with client.stream("GET", url, headers=headers) as response:
            response.raise_for_status()
            # Ensure directory exists only once we know the request succeeded.
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, 'wb') as f:
                opened_output = True
                async for chunk in response.aiter_bytes(chunk_size):
                    f.write(chunk)
                    total_bytes += len(chunk)
    except Exception:
        if opened_output:
            # Best-effort removal of the partial file before re-raising.
            try:
                Path(save_path).unlink()
            except OSError:
                pass
        raise
    return total_bytes

View File

@@ -0,0 +1,311 @@
"""
Standardized Response Format
Provides consistent response structures for all API endpoints.
All responses follow a standardized format for error handling and data delivery.
"""
from typing import Any, Dict, List, Optional, TypeVar, Generic
from pydantic import BaseModel, Field
from datetime import datetime, timezone
T = TypeVar('T')
# ============================================================================
# STANDARD RESPONSE MODELS
# ============================================================================
class ErrorDetail(BaseModel):
    """Standard error detail structure embedded in error responses."""
    error: str = Field(..., description="Error type/code")
    message: str = Field(..., description="Human-readable error message")
    details: Optional[Dict[str, Any]] = Field(None, description="Additional error context")
    # UTC timestamp in ISO 8601 with a trailing "Z" (e.g. "2025-01-01T00:00:00Z")
    timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"))
class SuccessResponse(BaseModel):
    """Standard success envelope: always truthy, optional message and payload."""
    success: bool = True
    message: Optional[str] = None
    data: Optional[Any] = None
class PaginatedResponse(BaseModel):
    """Standard page-number based pagination envelope."""
    items: List[Any]          # items for the current page
    total: int                # total matching items across all pages
    page: int                 # 1-based page number
    page_size: int            # items per page
    has_more: bool            # True when further pages exist
class ListResponse(BaseModel):
    """Standard list response: the items plus their count."""
    items: List[Any]
    count: int
# ============================================================================
# RESPONSE BUILDERS
# ============================================================================
def success(data: Any = None, message: str = None) -> Dict[str, Any]:
    """Build a standard success envelope, omitting message/data when absent.

    ``data`` is included whenever it is not None (falsy values like 0 or []
    are still returned); ``message`` only when truthy.
    """
    payload: Dict[str, Any] = {"success": True}
    if message:
        payload["message"] = message
    if data is not None:
        payload["data"] = data
    return payload
def error(
    error_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Build a standard error envelope with a UTC ISO-8601 'Z' timestamp."""
    stamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    return {
        "success": False,
        "error": error_type,
        "message": message,
        "details": details if details else {},
        "timestamp": stamp,
    }
def paginated(
    items: List[Any],
    total: int,
    page: int,
    page_size: int
) -> Dict[str, Any]:
    """Build a page-number based pagination envelope.

    ``has_more`` is True when items remain beyond the current page
    (i.e. ``page * page_size < total``).
    """
    return {
        "items": items,
        "total": total,
        "page": page,
        "page_size": page_size,
        "has_more": page * page_size < total,
    }
def list_response(items: List[Any], key: str = "items") -> Dict[str, Any]:
    """Wrap *items* under *key* together with the item count."""
    return {key: items, "count": len(items)}
def message_response(message: str) -> Dict[str, str]:
    """Wrap a plain status message in the standard shape."""
    return {"message": message}
def count_response(message: str, count: int) -> Dict[str, Any]:
    """Wrap a status message together with the number of affected items."""
    return {"message": message, "count": count}
def id_response(resource_id: int, message: str = "Resource created successfully") -> Dict[str, Any]:
    """Wrap a newly created resource ID with a confirmation message."""
    return {"id": resource_id, "message": message}
def offset_paginated(
    items: List[Any],
    total: int,
    limit: int,
    offset: int,
    key: str = "items",
    include_timestamp: bool = False,
    **extra_fields
) -> Dict[str, Any]:
    """
    Build an offset/limit paginated response (the standard list format
    used across the existing routers).

    Args:
        items: Items for the current page
        total: Total count of all matching items
        limit: Page size (items per page)
        offset: Starting position of this page
        key: Key name under which *items* is returned (default "items")
        include_timestamp: When True, add a UTC "timestamp" field
        **extra_fields: Extra top-level fields merged into the response

    Returns:
        Dict with the standard offset-pagination structure.

    Example:
        return offset_paginated(
            items=media_items,
            total=total_count,
            limit=50,
            offset=0,
            key="media",
            stats={"images": 100, "videos": 50}
        )
    """
    result: Dict[str, Any] = {
        key: items,
        "total": total,
        "limit": limit,
        "offset": offset,
    }
    if include_timestamp:
        result["timestamp"] = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    # Merge any caller-supplied extras last.
    result.update(extra_fields)
    return result
def batch_operation_response(
    succeeded: bool,
    processed: List[Any] = None,
    errors: List[Dict[str, Any]] = None,
    message: str = None
) -> Dict[str, Any]:
    """
    Build a standardized response for batch operations (batch delete,
    batch move, etc.).

    Args:
        succeeded: Overall success status
        processed: Successfully processed items (key omitted when empty)
        errors: Per-item error details (key omitted when empty)
        message: Optional human-readable summary (omitted when empty)

    Returns:
        Dict with counts, a UTC timestamp and the optional item lists.

    Example:
        return batch_operation_response(
            succeeded=True,
            processed=["/path/to/file1.jpg", "/path/to/file2.jpg"],
            errors=[{"file": "/path/to/file3.jpg", "error": "File not found"}]
        )
    """
    done = processed if processed else []
    failed = errors if errors else []
    result: Dict[str, Any] = {
        "success": succeeded,
        "processed_count": len(done),
        "error_count": len(failed),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    }
    if done:
        result["processed"] = done
    if failed:
        result["errors"] = failed
    if message:
        result["message"] = message
    return result
# ============================================================================
# DATE/TIME UTILITIES
# ============================================================================
def to_iso8601(dt: datetime) -> Optional[str]:
    """
    Render *dt* as UTC ISO 8601, e.g. "2025-12-04T10:30:00Z".

    Naive datetimes are assumed to already be UTC; aware datetimes are
    converted. Sub-second precision is dropped. Returns None for None.
    """
    if dt is None:
        return None
    aware = dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt.astimezone(timezone.utc)
    return aware.strftime("%Y-%m-%dT%H:%M:%SZ")
def from_iso8601(date_string: str) -> Optional[datetime]:
    """
    Parse an ISO 8601 (or close) date string to a timezone-aware datetime.

    Tries a list of known formats first, then falls back to
    ``datetime.fromisoformat``. Naive results from EITHER path are assumed
    to be UTC and normalized accordingly — the fromisoformat fallback
    previously returned naive datetimes, inconsistent with the strptime
    path. Returns None for empty input or when nothing parses.
    """
    if not date_string:
        return None
    # Common formats to try, most specific first.
    formats = [
        "%Y-%m-%dT%H:%M:%SZ",        # ISO with Z
        "%Y-%m-%dT%H:%M:%S.%fZ",     # ISO with microseconds and Z
        "%Y-%m-%dT%H:%M:%S",         # ISO without Z
        "%Y-%m-%dT%H:%M:%S.%f",      # ISO with microseconds
        "%Y-%m-%d %H:%M:%S",         # Space separator
        "%Y-%m-%d",                  # Date only
    ]
    for fmt in formats:
        try:
            dt = datetime.strptime(date_string, fmt)
        except ValueError:
            continue
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        return dt
    # Fallback: fromisoformat handles additional variants (Python 3.7+).
    try:
        dt = datetime.fromisoformat(date_string.replace('Z', '+00:00'))
    except ValueError:
        return None
    # Bug fix: normalize naive datetimes from the fallback path to UTC,
    # consistent with the strptime path above.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt
def format_timestamp(
    timestamp: Optional[str],
    output_format: str = "iso8601"
) -> Optional[str]:
    """
    Re-format a timestamp string.

    Args:
        timestamp: Input timestamp string (any format from_iso8601 accepts)
        output_format: "iso8601", "unix", or a strftime format string

    Returns:
        The formatted string; the original input when parsing fails;
        None for empty input.
    """
    if not timestamp:
        return None
    parsed = from_iso8601(timestamp)
    if parsed is None:
        return timestamp  # leave unparseable input untouched
    if output_format == "iso8601":
        return to_iso8601(parsed)
    if output_format == "unix":
        return str(int(parsed.timestamp()))
    return parsed.strftime(output_format)
def now_iso8601() -> str:
    """Return the current UTC time in ISO 8601 format with a trailing Z."""
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

582
web/backend/core/utils.py Normal file
View File

@@ -0,0 +1,582 @@
"""
Shared Utility Functions
Common helper functions used across multiple routers.
"""
import io
import sqlite3
import hashlib
import subprocess
from collections import OrderedDict
from contextlib import closing
from pathlib import Path
from threading import Lock
from typing import Dict, List, Optional, Tuple, Union
from fastapi import HTTPException
from PIL import Image
from .config import settings
# ============================================================================
# THUMBNAIL LRU CACHE
# ============================================================================
class ThumbnailLRUCache:
    """Thread-safe LRU cache for thumbnail binary data.

    Caps both entry count and total memory; single items above 1 MB are
    never cached. Used by the media.py and recycle.py routers to avoid
    repeated SQLite thumbnail lookups.
    """

    # Single entries above this size are not worth caching.
    _SINGLE_ITEM_LIMIT = 1024 * 1024

    def __init__(self, max_size: int = 500, max_memory_mb: int = 100):
        self._cache: OrderedDict[str, bytes] = OrderedDict()
        self._lock = Lock()
        self._max_size = max_size
        self._max_memory = max_memory_mb * 1024 * 1024  # bytes
        self._current_memory = 0

    def get(self, key: str) -> Optional[bytes]:
        """Return cached bytes for *key* (marking it most-recent), or None."""
        with self._lock:
            data = self._cache.get(key)
            if data is None:
                return None
            self._cache.move_to_end(key)  # promote to most recently used
            return data

    def put(self, key: str, data: bytes) -> None:
        """Insert *data* under *key*, evicting LRU entries as needed."""
        size = len(data)
        if size > self._SINGLE_ITEM_LIMIT:
            return
        with self._lock:
            # Replace any existing entry first so accounting stays exact.
            previous = self._cache.pop(key, None)
            if previous is not None:
                self._current_memory -= len(previous)
            # Evict least-recently-used entries until both limits allow the add.
            while self._cache and (
                len(self._cache) >= self._max_size
                or self._current_memory + size > self._max_memory
            ):
                _, evicted = self._cache.popitem(last=False)
                self._current_memory -= len(evicted)
            self._cache[key] = data
            self._current_memory += size

    def clear(self) -> None:
        """Drop all entries and reset memory accounting."""
        with self._lock:
            self._cache.clear()
            self._current_memory = 0
# ============================================================================
# SQL FILTER CONSTANTS
# ============================================================================
# Valid media file filters (excluding phrase checks, must have valid extension)
# Used by downloads, health, and analytics endpoints
MEDIA_FILTERS = """
(filename NOT LIKE '%_phrase_checked_%' OR filename IS NULL)
AND (file_path IS NOT NULL AND file_path != '' OR platform = 'forums')
AND (LENGTH(filename) > 20 OR filename LIKE '%_%_%')
AND (
filename LIKE '%.jpg' OR filename LIKE '%.jpeg' OR
filename LIKE '%.png' OR filename LIKE '%.gif' OR
filename LIKE '%.heic' OR filename LIKE '%.heif' OR
filename LIKE '%.mp4' OR filename LIKE '%.mov' OR
filename LIKE '%.webm' OR filename LIKE '%.m4a' OR
filename LIKE '%.mp3' OR filename LIKE '%.avi' OR
filename LIKE '%.mkv' OR filename LIKE '%.flv'
)
"""
# ============================================================================
# PATH VALIDATION
# ============================================================================
# Allowed base paths for file operations (path-traversal allowlist used by
# validate_file_path).
# NOTE(review): deployment-specific absolute paths are hard-coded here;
# consider moving them into settings alongside the first three entries.
ALLOWED_PATHS = [
    settings.MEDIA_BASE_PATH,
    settings.REVIEW_PATH,
    settings.RECYCLE_PATH,
    Path('/opt/media-downloader/temp/manual_import'),
    Path('/opt/immich/paid'),
    Path('/opt/immich/el'),
    Path('/opt/immich/elv'),
]
def validate_file_path(
    file_path: str,
    allowed_bases: Optional[List[Path]] = None,
    require_exists: bool = False
) -> Path:
    """
    Resolve *file_path* and verify it lies under an allowed base directory.

    Guards against path traversal (e.g. "../../etc/passwd").

    Args:
        file_path: Candidate path from the client
        allowed_bases: Allowed base directories (defaults to ALLOWED_PATHS)
        require_exists: When True, also require the file to exist

    Returns:
        Fully resolved Path object

    Raises:
        HTTPException: 403 when outside allowed directories, 404 when the
            file is missing (require_exists), 400 for unresolvable paths.
    """
    bases = ALLOWED_PATHS if allowed_bases is None else allowed_bases
    try:
        resolved = Path(file_path).resolve()
        permitted = False
        for base in bases:
            try:
                resolved.relative_to(base.resolve())
            except ValueError:
                continue  # not under this base; try the next one
            permitted = True
            break
        if not permitted:
            raise HTTPException(status_code=403, detail="Access denied")
        if require_exists and not resolved.exists():
            raise HTTPException(status_code=404, detail="File not found")
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file path")
    return resolved
# ============================================================================
# QUERY FILTER BUILDER
# ============================================================================
def build_media_filter_query(
    platform: Optional[str] = None,
    source: Optional[str] = None,
    media_type: Optional[str] = None,
    location: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    table_alias: str = "fi"
) -> Tuple[str, List]:
    """
    Build a parameterized SQL WHERE fragment for common media queries.

    Centralizes the filter building previously duplicated across the
    media.py, downloads.py, and review.py routers.

    Args:
        platform: Filter by platform (e.g. 'instagram', 'tiktok')
        source: Filter by source (e.g. 'stories', 'posts')
        media_type: Filter by media type ('image', 'video', 'all')
        location: Filter by location ('final', 'review', 'recycle')
        date_from: Start date filter (ISO format)
        date_to: End date filter (ISO format)
        table_alias: SQL table alias (default 'fi' for file_inventory)

    Returns:
        Tuple of (SQL WHERE clause conditions, list of parameters);
        ("1=1", []) when no filters are given so callers can AND it
        unconditionally.

    Example:
        conditions, params = build_media_filter_query(platform="instagram", media_type="video")
        query = f"SELECT * FROM file_inventory fi WHERE {conditions}"
        cursor.execute(query, params)
    """
    # Guard against SQL injection through the alias: only a plain
    # identifier is interpolated into the query text.
    alias = table_alias if table_alias.isidentifier() else "fi"

    clauses: List[str] = []
    args: List = []

    if platform:
        clauses.append(f"{alias}.platform = ?")
        args.append(platform)
    if source:
        clauses.append(f"{alias}.source = ?")
        args.append(source)
    if media_type and media_type != 'all':
        clauses.append(f"{alias}.media_type = ?")
        args.append(media_type)
    if location:
        clauses.append(f"{alias}.location = ?")
        args.append(location)

    # Effective date: post_date from downloads takes precedence over
    # file_inventory's created_date.
    effective_date = f"""
        DATE(COALESCE(
            (SELECT MAX(d.post_date) FROM downloads d WHERE d.file_path = {alias}.file_path),
            {alias}.created_date
        ))"""
    if date_from:
        clauses.append(f"{effective_date} >= ?")
        args.append(date_from)
    if date_to:
        clauses.append(f"{effective_date} <= ?")
        args.append(date_to)

    return (" AND ".join(clauses) if clauses else "1=1"), args
def build_platform_list_filter(
    platforms: Optional[List[str]] = None,
    table_alias: str = "fi"
) -> Tuple[str, List[str]]:
    """
    Build an ``alias.platform IN (?, ...)`` condition for multiple platforms.

    Args:
        platforms: List of platform names to filter by
        table_alias: SQL table alias (sanitized; falls back to 'fi' when
            not a plain identifier — consistent with
            build_media_filter_query, which already had this guard)

    Returns:
        Tuple of (SQL condition string, list of parameters);
        ("1=1", []) when no platforms are given so callers can AND it
        unconditionally.
    """
    if not platforms:
        return "1=1", []
    # Same injection guard as build_media_filter_query: only a plain
    # identifier may be interpolated into the query text.
    if not table_alias.isidentifier():
        table_alias = "fi"
    placeholders = ",".join(["?"] * len(platforms))
    return f"{table_alias}.platform IN ({placeholders})", platforms
# ============================================================================
# THUMBNAIL GENERATION
# ============================================================================
def generate_image_thumbnail(file_path: Path, max_size: Tuple[int, int] = (300, 300)) -> Optional[bytes]:
    """
    Generate a JPEG thumbnail for an image file.

    Images with transparency (RGBA/LA/P modes) are flattened onto a white
    background because JPEG has no alpha channel.

    Args:
        file_path: Path to image file
        max_size: Maximum thumbnail dimensions (width, height)

    Returns:
        JPEG bytes, or None on any failure (best-effort helper).
    """
    try:
        # Context manager closes the underlying file handle — the previous
        # version leaked it until garbage collection.
        with Image.open(file_path) as img:
            img.thumbnail(max_size, Image.Resampling.LANCZOS)
            # Flatten transparency onto white before saving as JPEG.
            if img.mode in ('RGBA', 'LA', 'P'):
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=85)
        return buffer.getvalue()
    except Exception:
        return None
def generate_video_thumbnail(file_path: Path, max_size: Tuple[int, int] = (300, 300)) -> Optional[bytes]:
    """
    Extract a frame from a video with ffmpeg and return it as JPEG bytes.

    Tries the frame at t=1s first (skips black intro frames), then falls
    back to the very first frame.

    Args:
        file_path: Path to video file
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes, or None when ffmpeg fails, times out, or produces no
        output (best-effort helper).
    """
    for seek in ('00:00:01.000', '00:00:00.000'):
        cmd = [
            'ffmpeg',
            '-ss', seek,
            '-i', str(file_path),
            '-vframes', '1',
            '-f', 'image2pipe',
            '-vcodec', 'mjpeg',
            '-'
        ]
        try:
            proc = subprocess.run(cmd, capture_output=True, timeout=30)
            if proc.returncode != 0 or not proc.stdout:
                continue
            # Resize the extracted frame and re-encode as JPEG.
            frame = Image.open(io.BytesIO(proc.stdout))
            frame.thumbnail(max_size, Image.Resampling.LANCZOS)
            out = io.BytesIO()
            frame.save(out, format='JPEG', quality=85)
            return out.getvalue()
        except Exception:
            continue  # try the next seek position
    return None
def get_or_create_thumbnail(
    file_path: Union[str, Path],
    media_type: str,
    content_hash: Optional[str] = None,
    max_size: Tuple[int, int] = (300, 300)
) -> Optional[bytes]:
    """
    Get thumbnail from cache or generate and cache it.

    Uses the thumbnails.db schema: file_hash (PK), file_path, thumbnail_data,
    created_at, file_mtime.

    Lookup strategy:
    1. Try content_hash against file_hash column (survives file moves)
    2. Fall back to file_path lookup (legacy thumbnails)
    3. Generate and cache if not found

    Args:
        file_path: Path to media file
        media_type: 'image' or 'video'
        content_hash: Optional pre-computed hash (computed from path if not provided)
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes or None if generation fails
    """
    file_path = Path(file_path)
    thumb_db_path = settings.PROJECT_ROOT / 'database' / 'thumbnails.db'
    # Compute hash if not provided.
    # NOTE(review): despite the column name, this default is a SHA-256 of the
    # PATH string, not of the file contents — such entries do NOT survive
    # moves unless the caller supplies a true content hash.
    file_hash = content_hash if content_hash else hashlib.sha256(str(file_path).encode()).hexdigest()
    # Try to get from cache (skip mtime check — downloaded media files don't change)
    try:
        with sqlite3.connect(str(thumb_db_path), timeout=30.0) as conn:
            cursor = conn.cursor()
            # 1. Try file_hash lookup (primary key)
            cursor.execute(
                "SELECT thumbnail_data FROM thumbnails WHERE file_hash = ?",
                (file_hash,)
            )
            result = cursor.fetchone()
            if result and result[0]:
                return result[0]
            # 2. Fall back to file_path lookup (legacy thumbnails)
            cursor.execute(
                "SELECT thumbnail_data FROM thumbnails WHERE file_path = ?",
                (str(file_path),)
            )
            result = cursor.fetchone()
            if result and result[0]:
                return result[0]
    except Exception:
        pass  # any DB error is treated as a cache miss; fall through to generation
    # Only proceed with generation if the file actually exists
    if not file_path.exists():
        return None
    # Get mtime only when we need to generate and cache a new thumbnail
    try:
        file_mtime = file_path.stat().st_mtime
    except OSError:
        file_mtime = 0  # stored for schema completeness; not used for invalidation here
    # Generate thumbnail
    thumbnail_data = None
    if media_type == 'video':
        thumbnail_data = generate_video_thumbnail(file_path, max_size)
    else:
        thumbnail_data = generate_image_thumbnail(file_path, max_size)
    # Cache the thumbnail (best-effort: a caching failure never fails the request)
    if thumbnail_data:
        try:
            from .responses import now_iso8601
            with sqlite3.connect(str(thumb_db_path), timeout=30.0) as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO thumbnails
                    (file_hash, file_path, thumbnail_data, created_at, file_mtime)
                    VALUES (?, ?, ?, ?, ?)
                """, (file_hash, str(file_path), thumbnail_data, now_iso8601(), file_mtime))
                conn.commit()
        except Exception:
            pass
    return thumbnail_data
def get_media_dimensions(file_path: str, width: int = None, height: int = None) -> Tuple[Optional[int], Optional[int]]:
    """
    Return (width, height) for a media file.

    When both dimensions are already known they are returned unchanged;
    otherwise the media_metadata cache (keyed by SHA-256 of the path string)
    is consulted. Falls back to the passed-in values on any failure.

    Args:
        file_path: Path to media file
        width: Width from file_inventory (may be None)
        height: Height from file_inventory (may be None)

    Returns:
        Tuple of (width, height), or (None, None) if not available.
    """
    if width is not None and height is not None:
        return (width, height)
    try:
        db_file = settings.PROJECT_ROOT / 'database' / 'media_metadata.db'
        path_key = hashlib.sha256(file_path.encode()).hexdigest()
        with closing(sqlite3.connect(str(db_file))) as conn:
            row = conn.execute(
                "SELECT width, height FROM media_metadata WHERE file_hash = ?",
                (path_key,)
            ).fetchone()
            if row:
                return (row[0], row[1])
    except Exception:
        pass  # best-effort: fall back to whatever the caller supplied
    return (width, height)
def get_media_dimensions_batch(file_paths: List[str]) -> Dict[str, Tuple[int, int]]:
    """
    Batch lookup of cached (width, height) for many file paths.

    Issues one IN-query against media_metadata instead of N single lookups
    (avoids the N+1 query problem). Paths with no cached entry are simply
    absent from the result.

    Args:
        file_paths: List of file paths to look up

    Returns:
        Dict mapping file_path -> (width, height)
    """
    if not file_paths:
        return {}
    dims: Dict[str, Tuple[int, int]] = {}
    try:
        db_file = settings.PROJECT_ROOT / 'database' / 'media_metadata.db'
        # Reverse map: path-hash -> original path.
        lookup = {hashlib.sha256(p.encode()).hexdigest(): p for p in file_paths}
        placeholders = ','.join('?' * len(lookup))
        with closing(sqlite3.connect(str(db_file))) as conn:
            rows = conn.execute(
                f"SELECT file_hash, width, height FROM media_metadata WHERE file_hash IN ({placeholders})",
                list(lookup)
            ).fetchall()
        for row_hash, row_width, row_height in rows:
            path = lookup.get(row_hash)
            if path is not None:
                dims[path] = (row_width, row_height)
    except Exception:
        pass  # best-effort: return whatever was resolved so far
    return dims
# ============================================================================
# DATABASE UTILITIES
# ============================================================================
def update_file_path_in_all_tables(db, old_path: str, new_path: str):
    """
    Rewrite *old_path* to *new_path* in every table that stores file paths.

    Keeps downloads, perceptual-hash, face-scan and (when present) semantic
    embedding records consistent when a file is moved between locations
    (review -> final, etc.). Failures are logged as warnings, never raised.

    Args:
        db: UnifiedDatabase instance
        old_path: The old file path to replace
        new_path: The new file path to use
    """
    from modules.universal_logger import get_logger
    log = get_logger('API')
    statements = [
        'UPDATE downloads SET file_path = ? WHERE file_path = ?',
        'UPDATE instagram_perceptual_hashes SET file_path = ? WHERE file_path = ?',
        'UPDATE face_recognition_scans SET file_path = ? WHERE file_path = ?',
    ]
    try:
        with db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            for stmt in statements:
                cursor.execute(stmt, (new_path, old_path))
            try:
                cursor.execute(
                    'UPDATE semantic_embeddings SET file_path = ? WHERE file_path = ?',
                    (new_path, old_path)
                )
            except sqlite3.OperationalError:
                pass  # Table may not exist
            conn.commit()
    except Exception as e:
        log.warning(f"Failed to update file paths in tables: {e}", module="Database")
# ============================================================================
# FACE RECOGNITION UTILITIES
# ============================================================================
# Cached FaceRecognitionModule singleton to avoid loading InsightFace models on every request.
# NOTE(review): keyed by id(db) — if a db object is garbage-collected, a later
# object can reuse the same id and receive the stale module (which also keeps
# the old db alive). Fine while the app uses one long-lived UnifiedDatabase.
_face_module_cache: Dict[int, 'FaceRecognitionModule'] = {}
def get_face_module(db, module_name: str = "FaceAPI"):
    """
    Get or create a cached FaceRecognitionModule instance for the given database.

    Uses singleton pattern to avoid reloading heavy InsightFace models on each request.

    Args:
        db: UnifiedDatabase instance
        module_name: Name to use in log messages

    Returns:
        FaceRecognitionModule instance
    """
    from modules.face_recognition_module import FaceRecognitionModule
    from modules.universal_logger import get_logger
    logger = get_logger('API')
    db_id = id(db)  # identity-keyed: one module per live db object
    if db_id not in _face_module_cache:
        logger.info("Creating cached FaceRecognitionModule instance", module=module_name)
        _face_module_cache[db_id] = FaceRecognitionModule(unified_db=db)
    return _face_module_cache[db_id]

438
web/backend/duo_manager.py Normal file
View File

@@ -0,0 +1,438 @@
#!/usr/bin/env python3
"""
Duo Manager for Media Downloader
Handles Duo Security 2FA Operations
Based on backup-central's implementation
Uses Duo Universal Prompt
"""
import sys
import sqlite3
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict
from pathlib import Path
import os
# Add parent path to allow imports from modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from modules.universal_logger import get_logger
logger = get_logger('DuoManager')
class DuoManager:
    """
    Duo Security 2FA manager using the Duo Universal Prompt (OAuth
    authorization-code flow).

    Credentials come from the DUO_CLIENT_ID, DUO_CLIENT_SECRET,
    DUO_API_HOSTNAME and DUO_REDIRECT_URL environment variables. OAuth state
    is held in-memory (self.state_store), so in-flight logins do not survive
    a process restart.
    """

    def __init__(self, db_path: str = None):
        """
        Args:
            db_path: Path to the SQLite auth database; defaults to
                <repo root>/database/auth.db.
        """
        if db_path is None:
            db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')
        self.db_path = db_path
        # Duo Configuration (from environment variables)
        self.client_id = os.getenv('DUO_CLIENT_ID')
        self.client_secret = os.getenv('DUO_CLIENT_SECRET')
        self.api_host = os.getenv('DUO_API_HOSTNAME')
        self.redirect_url = os.getenv('DUO_REDIRECT_URL', 'https://md.lic.ad/api/auth/2fa/duo/callback')
        # Duo is usable only when all three credentials are present
        self.is_configured = bool(self.client_id and self.client_secret and self.api_host)
        # State store for the OAuth flow: state token -> flow metadata
        self.state_store = {}
        # Ensure auth DB tables/columns exist
        self._init_database()
        # Initialize Duo client only when fully configured
        self.duo_client = None
        if self.is_configured:
            self._init_duo_client()

    def _init_database(self):
        """Initialize Duo-related database tables and columns (idempotent)."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Ensure Duo columns exist in users table. SQLite has no
            # "ADD COLUMN IF NOT EXISTS", so an OperationalError is treated
            # as "already present".
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_enabled INTEGER NOT NULL DEFAULT 0")
            except sqlite3.OperationalError:
                pass  # Column already exists
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_username TEXT")
            except sqlite3.OperationalError:
                pass
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN duo_enrolled_at TEXT")
            except sqlite3.OperationalError:
                pass
            # Duo audit log (written by _log_audit)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS duo_audit_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    action TEXT NOT NULL,
                    success INTEGER NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    details TEXT,
                    timestamp TEXT NOT NULL
                )
            """)
            conn.commit()

    def _init_duo_client(self):
        """Initialize the Duo Universal Prompt client; disables Duo on failure."""
        try:
            from duo_universal import Client
            self.duo_client = Client(
                client_id=self.client_id,
                client_secret=self.client_secret,
                host=self.api_host,
                redirect_uri=self.redirect_url
            )
            logger.info("Duo Universal Prompt client initialized successfully", module="Duo")
        except Exception as e:
            logger.error(f"Error initializing Duo client: {e}", module="Duo")
            # Fail closed: pretend Duo is unconfigured rather than half-working
            self.is_configured = False

    def is_duo_configured(self) -> bool:
        """Check if Duo is properly configured"""
        return self.is_configured

    def get_configuration_status(self) -> Dict:
        """Get Duo configuration status (booleans only; never leaks secrets)."""
        return {
            'configured': self.is_configured,
            'hasClientId': bool(self.client_id),
            'hasClientSecret': bool(self.client_secret),
            'hasApiHostname': bool(self.api_host),
            'hasRedirectUrl': bool(self.redirect_url),
            'apiHostname': self.api_host if self.api_host else None
        }

    def generate_state(self, username: str, remember_me: bool = False) -> str:
        """
        Generate a state parameter for the Duo OAuth flow.

        Args:
            username: Username to associate with this state
            remember_me: Whether to remember the user for 30 days

        Returns:
            Random state string (valid for 10 minutes, single use)
        """
        now = datetime.now()
        # Prune expired states so abandoned login attempts don't accumulate
        # in memory forever (previously only consumed states were removed)
        for stale in [s for s, data in self.state_store.items() if now > data['expires_at']]:
            del self.state_store[stale]
        state = secrets.token_urlsafe(32)
        self.state_store[state] = {
            'username': username,
            'remember_me': remember_me,
            'created_at': now,
            'expires_at': now + timedelta(minutes=10)
        }
        return state

    def verify_state(self, state: str) -> Optional[tuple]:
        """
        Verify state parameter and return associated username and remember_me.

        Args:
            state: State parameter from Duo callback

        Returns:
            Tuple of (username, remember_me) if valid, None otherwise.
            A valid state is consumed (one-time use).
        """
        if state not in self.state_store:
            return None
        state_data = self.state_store[state]
        # Check if expired
        if datetime.now() > state_data['expires_at']:
            del self.state_store[state]
            return None
        username = state_data['username']
        remember_me = state_data.get('remember_me', False)
        del self.state_store[state]  # One-time use
        return (username, remember_me)

    def create_auth_url(self, username: str, remember_me: bool = False) -> Dict:
        """
        Create Duo authentication URL for user.

        Args:
            username: Username to authenticate
            remember_me: Whether to remember the user for 30 days

        Returns:
            Dict with authUrl and state

        Raises:
            Exception: if Duo is not configured or the Duo client fails
        """
        if not self.is_configured or not self.duo_client:
            raise Exception("Duo is not configured")
        try:
            # Generate state (ties the Duo callback to this login attempt)
            state = self.generate_state(username, remember_me)
            # Get Duo username for this user (falls back to local username)
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT duo_username FROM users WHERE username = ?", (username,))
                result = cursor.fetchone()
                duo_username = result[0] if result and result[0] else username
            # Create Duo Universal Prompt auth URL
            auth_url = self.duo_client.create_auth_url(duo_username, state)
            self._log_audit(username, 'duo_auth_url_created', True, None, None,
                            'Duo authentication URL created')
            return {
                'authUrl': auth_url,
                'state': state
            }
        except Exception as e:
            self._log_audit(username, 'duo_auth_url_failed', False, None, None,
                            f'Error: {str(e)}')
            raise

    def verify_duo_response(self, duo_code: str, username: str) -> bool:
        """
        Verify Duo authentication response (Universal Prompt).

        Args:
            duo_code: Authorization code from Duo
            username: Expected username

        Returns:
            True if authentication successful
        """
        if not self.is_configured or not self.duo_client:
            return False
        try:
            # Exchange authorization code for 2FA result
            decoded_token = self.duo_client.exchange_authorization_code_for_2fa_result(
                duo_code,
                username
            )
            # Check authentication result
            if decoded_token and decoded_token.get('auth_result', {}).get('result') == 'allow':
                authenticated_username = decoded_token.get('preferred_username')
                # Get Duo username for this user
                with sqlite3.connect(self.db_path) as conn:
                    cursor = conn.cursor()
                    cursor.execute("SELECT duo_username FROM users WHERE username = ?", (username,))
                    result = cursor.fetchone()
                    duo_username = result[0] if result and result[0] else username
                # Guard against a token issued for a different Duo user
                if authenticated_username == duo_username:
                    self._log_audit(username, 'duo_verify_success', True, None, None,
                                    'Duo authentication successful')
                    return True
                else:
                    self._log_audit(username, 'duo_verify_failed', False, None, None,
                                    f'Username mismatch: expected {duo_username}, got {authenticated_username}')
                    return False
            else:
                self._log_audit(username, 'duo_verify_failed', False, None, None,
                                'Duo authentication denied')
                return False
        except Exception as e:
            self._log_audit(username, 'duo_verify_failed', False, None, None,
                            f'Error: {str(e)}')
            logger.error(f"Duo verification error: {e}", module="Duo")
            return False

    def _get_application_key(self) -> str:
        """
        Get or generate application key for Duo.

        NOTE(review): the application key ("akey") was only required by the
        legacy Duo WebSDK v2 flow; nothing in the Universal Prompt path calls
        this — candidate for removal once confirmed unused.
        """
        key_file = Path(__file__).parent.parent.parent / '.duo_application_key'
        if key_file.exists():
            with open(key_file, 'r') as f:
                return f.read().strip()
        # Generate new key (at least 40 characters)
        key = secrets.token_urlsafe(40)
        try:
            with open(key_file, 'w') as f:
                f.write(key)
            os.chmod(key_file, 0o600)  # owner-only: the key is a secret
        except Exception as e:
            logger.warning(f"Could not save Duo application key: {e}", module="Duo")
        return key

    def enroll_user(self, username: str, duo_username: Optional[str] = None) -> bool:
        """
        Enroll a user in Duo 2FA.

        Args:
            username: Username
            duo_username: Optional Duo username (defaults to username)

        Returns:
            True if successful
        """
        if duo_username is None:
            duo_username = username
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE users
                    SET duo_enabled = 1, duo_username = ?, duo_enrolled_at = ?
                    WHERE username = ?
                """, (duo_username, datetime.now().isoformat(), username))
                conn.commit()
            self._log_audit(username, 'duo_enrolled', True, None, None,
                            f'Duo enrolled with username: {duo_username}')
            return True
        except Exception as e:
            logger.error(f"Error enrolling user in Duo: {e}", module="Duo")
            return False

    def unenroll_user(self, username: str) -> bool:
        """
        Unenroll a user from Duo 2FA.

        Args:
            username: Username

        Returns:
            True if successful
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE users
                    SET duo_enabled = 0, duo_username = NULL, duo_enrolled_at = NULL
                    WHERE username = ?
                """, (username,))
                conn.commit()
            self._log_audit(username, 'duo_unenrolled', True, None, None,
                            'Duo unenrolled')
            return True
        except Exception as e:
            logger.error(f"Error unenrolling user from Duo: {e}", module="Duo")
            return False

    def get_duo_status(self, username: str) -> Dict:
        """
        Get Duo status for a user.

        Args:
            username: Username

        Returns:
            Dict with enabled, duoUsername, enrolledAt
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT duo_enabled, duo_username, duo_enrolled_at
                FROM users
                WHERE username = ?
            """, (username,))
            row = cursor.fetchone()
            if not row:
                return {'enabled': False, 'duoUsername': None, 'enrolledAt': None}
            return {
                'enabled': bool(row[0]),
                'duoUsername': row[1],
                'enrolledAt': row[2]
            }

    def preauth_user(self, username: str) -> Dict:
        """
        Perform Duo preauth to check enrollment status.

        Args:
            username: Username

        Returns:
            Dict with result and status_msg
        """
        if not self.is_configured or not self.duo_client:
            return {'result': 'error', 'status_msg': 'Duo not configured'}
        try:
            # Get Duo username; fall back to the local username when the user
            # has no Duo mapping (the key may be present with a None value,
            # so `or` is required rather than a dict.get default)
            user_info = self.get_duo_status(username)
            duo_username = user_info.get('duoUsername') or username
            # NOTE(review): duo_universal's Client does not document a
            # preauth() method (preauth belongs to the legacy Auth API);
            # if absent, this raises and we fall through to the error branch.
            preauth_result = self.duo_client.preauth(username=duo_username)
            self._log_audit(username, 'duo_preauth', True, None, None,
                            f'Preauth result: {preauth_result.get("result")}')
            return preauth_result
        except Exception as e:
            self._log_audit(username, 'duo_preauth_failed', False, None, None,
                            f'Error: {str(e)}')
            return {'result': 'error', 'status_msg': 'Failed to check Duo enrollment status'}

    def _log_audit(self, username: str, action: str, success: bool,
                   ip_address: Optional[str], user_agent: Optional[str],
                   details: Optional[str] = None):
        """Log a Duo audit event (best-effort: failures are logged, not raised)."""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO duo_audit_log
                    (username, action, success, ip_address, user_agent, details, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (username, action, int(success), ip_address, user_agent, details,
                      datetime.now().isoformat()))
                conn.commit()
        except Exception as e:
            logger.error(f"Error logging Duo audit: {e}", module="Duo")

    def health_check(self) -> Dict:
        """
        Check Duo service health.

        Returns:
            Dict with healthy status and details
        """
        if not self.is_configured:
            return {'healthy': False, 'error': 'Duo not configured'}
        try:
            # duo_universal's Client exposes health_check(); the previous
            # code called a nonexistent ping(), so this endpoint always
            # reported Duo as unavailable even when it was healthy
            response = self.duo_client.health_check()
            return {'healthy': True, 'response': response}
        except Exception as e:
            logger.error(f"Duo health check failed: {e}", module="Duo")
            return {'healthy': False, 'error': 'Duo service unavailable'}

View File

@@ -0,0 +1 @@
# Models module exports

View File

@@ -0,0 +1,477 @@
"""
Pydantic Models for API
All request and response models for API endpoints.
Provides validation and documentation for API contracts.
"""
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field, field_validator
from datetime import datetime
# ============================================================================
# AUTHENTICATION MODELS
# ============================================================================
class LoginRequest(BaseModel):
    """Login request model"""
    username: str = Field(..., min_length=1, max_length=50)
    password: str = Field(..., min_length=1)
    # camelCase kept deliberately to match the frontend JSON payload
    rememberMe: bool = False
class LoginResponse(BaseModel):
    """Login response model"""
    success: bool
    # Fields below are populated only on successful login
    token: Optional[str] = None
    username: Optional[str] = None
    role: Optional[str] = None
    expires_at: Optional[str] = None
    message: Optional[str] = None
class ChangePasswordRequest(BaseModel):
    """Password change request"""
    current_password: str = Field(..., min_length=1)
    # New passwords must be at least 8 characters
    new_password: str = Field(..., min_length=8)
class UserPreferences(BaseModel):
    """User preferences model (all fields optional; only provided ones change)"""
    theme: Optional[str] = Field(None, pattern=r'^(light|dark|system)$')
    notifications_enabled: Optional[bool] = None
    default_platform: Optional[str] = None
# ============================================================================
# DOWNLOAD MODELS
# ============================================================================
class DownloadResponse(BaseModel):
    """Single download record"""
    id: int
    platform: str
    source: str
    content_type: Optional[str] = None
    filename: Optional[str] = None
    file_path: Optional[str] = None
    file_size: Optional[int] = None  # bytes
    download_date: str
    post_date: Optional[str] = None
    status: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
class StatsResponse(BaseModel):
    """Statistics response"""
    total_downloads: int
    by_platform: Dict[str, int]
    total_size: int  # bytes
    recent_24h: int
    duplicates_prevented: int
    review_queue_count: int
    recycle_bin_count: int = 0
class TriggerRequest(BaseModel):
    """Manual download trigger request"""
    # None means "all configured sources" / "all content types"
    username: Optional[str] = None
    content_types: Optional[List[str]] = None
# ============================================================================
# PLATFORM MODELS
# ============================================================================
class PlatformStatus(BaseModel):
    """Platform status model"""
    platform: str
    enabled: bool
    last_run: Optional[str] = None
    next_run: Optional[str] = None
    status: str
class PlatformConfig(BaseModel):
    """Platform configuration"""
    enabled: bool = False
    username: Optional[str] = None
    # Run interval is bounded to 1 hour .. 1 week
    interval_hours: int = Field(24, ge=1, le=168)
    randomize: bool = True
    randomize_minutes: int = Field(30, ge=0, le=180)
# ============================================================================
# CONFIGURATION MODELS
# ============================================================================
class PushoverConfig(BaseModel):
    """Pushover notification configuration"""
    enabled: bool
    # Pushover user keys and API tokens are exactly 30 alphanumeric chars
    user_key: Optional[str] = Field(None, min_length=30, max_length=30, pattern=r'^[A-Za-z0-9]+$')
    api_token: Optional[str] = Field(None, min_length=30, max_length=30, pattern=r'^[A-Za-z0-9]+$')
    priority: int = Field(0, ge=-2, le=2)
    sound: str = Field("pushover", pattern=r'^[a-z_]+$')
class SchedulerConfig(BaseModel):
    """Scheduler configuration"""
    enabled: bool
    interval_hours: int = Field(24, ge=1, le=168)
    randomize: bool = True
    randomize_minutes: int = Field(30, ge=0, le=180)
class RecycleBinConfig(BaseModel):
    """Recycle bin configuration"""
    enabled: bool = True
    path: str = "/opt/immich/recycle"
    retention_days: int = Field(30, ge=1, le=365)
    max_size_gb: int = Field(50, ge=1, le=1000)
    auto_cleanup: bool = True
class ConfigUpdate(BaseModel):
    """Configuration update request.

    Wraps an arbitrary settings mapping; extra top-level fields are allowed
    so clients can post new settings without a model change.
    """
    config: Dict[str, Any]
    # Pydantic v2 configuration style (this module already uses v2's
    # field_validator); replaces the deprecated inner `class Config`.
    model_config = {"extra": "allow"}
# ============================================================================
# HEALTH MODELS
# ============================================================================
class ServiceHealth(BaseModel):
    """Individual service health status"""
    status: str = Field(..., pattern=r'^(healthy|unhealthy|unknown)$')
    message: Optional[str] = None
    last_check: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
class HealthStatus(BaseModel):
    """Overall health status"""
    status: str
    # Keyed by service name
    services: Dict[str, ServiceHealth]
    last_check: str
    version: str
# ============================================================================
# MEDIA MODELS
# ============================================================================
class MediaItem(BaseModel):
    """Media item in gallery"""
    id: int
    file_path: str
    filename: str
    platform: str
    source: str
    content_type: str
    file_size: Optional[int] = None  # bytes
    width: Optional[int] = None
    height: Optional[int] = None
    post_date: Optional[str] = None
    download_date: Optional[str] = None
    # Face-recognition results, when a scan has been run for this file
    face_match: Optional[str] = None
    face_confidence: Optional[float] = None
class BatchDeleteRequest(BaseModel):
    """Batch delete request"""
    file_paths: List[str] = Field(..., min_length=1)
    # False sends files to the recycle bin instead of deleting outright
    permanent: bool = False
class BatchMoveRequest(BaseModel):
    """Batch move request"""
    file_paths: List[str] = Field(..., min_length=1)
    destination: str
# ============================================================================
# REVIEW MODELS
# ============================================================================
class ReviewItem(BaseModel):
    """Review queue item"""
    id: int
    file_path: str
    filename: str
    platform: str
    source: str
    content_type: str
    file_size: Optional[int] = None
    width: Optional[int] = None
    height: Optional[int] = None
    detected_faces: Optional[int] = None
    best_match: Optional[str] = None
    match_confidence: Optional[float] = None
class ReviewKeepRequest(BaseModel):
    """Keep review item request"""
    file_path: str
    destination: str
    # Optional rename applied while moving out of the review queue
    new_name: Optional[str] = None
# ============================================================================
# FACE RECOGNITION MODELS
# ============================================================================
class FaceReference(BaseModel):
    """Face reference model"""
    id: str
    name: str
    created_at: str
    encoding_count: int = 1
    # Base64-encoded preview, when available
    thumbnail: Optional[str] = None
class AddReferenceRequest(BaseModel):
    """Add face reference request"""
    file_path: str
    name: str
# ============================================================================
# RECYCLE BIN MODELS
# ============================================================================
class RecycleItem(BaseModel):
    """Recycle bin item"""
    id: str
    original_path: str
    original_filename: str
    recycle_path: str
    file_size: Optional[int] = None
    deleted_at: str
    deleted_from: str
    metadata: Optional[Dict[str, Any]] = None
class RestoreRequest(BaseModel):
    """Restore from recycle bin request"""
    recycle_id: str
    restore_to: Optional[str] = None  # Original path if None
# ============================================================================
# VIDEO DOWNLOAD MODELS
# ============================================================================
class VideoInfoRequest(BaseModel):
    """Video info request"""
    # Only the scheme is validated here; the downloader validates the rest
    url: str = Field(..., pattern=r'^https?://')
class VideoDownloadRequest(BaseModel):
    """Video download request"""
    url: str = Field(..., pattern=r'^https?://')
    format: Optional[str] = None
    quality: Optional[str] = None
class VideoStatus(BaseModel):
    """Video download status"""
    video_id: str
    platform: str
    status: str
    # Fractional progress while downloading, when reported
    progress: Optional[float] = None
    title: Optional[str] = None
    error: Optional[str] = None
# ============================================================================
# SCHEDULER MODELS
# ============================================================================
class TaskStatus(BaseModel):
    """Scheduler task status"""
    task_id: str
    platform: str
    source: Optional[str] = None
    status: str
    last_run: Optional[str] = None
    next_run: Optional[str] = None
    error_count: int = 0
    last_error: Optional[str] = None
class SchedulerStatus(BaseModel):
    """Overall scheduler status"""
    running: bool
    tasks: List[TaskStatus]
    current_activity: Optional[Dict[str, Any]] = None
# ============================================================================
# NOTIFICATION MODELS
# ============================================================================
class NotificationItem(BaseModel):
    """Notification item"""
    id: int
    platform: str
    source: str
    content_type: str
    message: str
    title: Optional[str] = None
    sent_at: str
    download_count: int
    status: str
# ============================================================================
# SEMANTIC SEARCH MODELS
# ============================================================================
class SemanticSearchRequest(BaseModel):
    """Semantic search request"""
    query: str = Field(..., min_length=1, max_length=500)
    limit: int = Field(50, ge=1, le=200)
    # Minimum similarity score for a hit to be returned
    threshold: float = Field(0.3, ge=0.0, le=1.0)
class SimilarImagesRequest(BaseModel):
    """Find similar images request"""
    file_id: int
    limit: int = Field(20, ge=1, le=100)
    threshold: float = Field(0.5, ge=0.0, le=1.0)
# ============================================================================
# TAG MODELS
# ============================================================================
class TagCreate(BaseModel):
    """Create tag request"""
    name: str = Field(..., min_length=1, max_length=50)
    # Hex RGB, e.g. #1A2B3C
    color: Optional[str] = Field(None, pattern=r'^#[0-9A-Fa-f]{6}$')
class TagUpdate(BaseModel):
    """Update tag request"""
    name: Optional[str] = Field(None, min_length=1, max_length=50)
    color: Optional[str] = Field(None, pattern=r'^#[0-9A-Fa-f]{6}$')
class BulkTagRequest(BaseModel):
    """Bulk tag operation request"""
    file_ids: List[int] = Field(..., min_length=1)
    tag_ids: List[int] = Field(..., min_length=1)
    operation: str = Field(..., pattern=r'^(add|remove)$')
# ============================================================================
# COLLECTION MODELS
# ============================================================================
class CollectionCreate(BaseModel):
    """Create collection request"""
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
class CollectionUpdate(BaseModel):
    """Update collection request"""
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
class CollectionBulkAdd(BaseModel):
    """Bulk add files to collection"""
    file_ids: List[int] = Field(..., min_length=1)
# ============================================================================
# SMART FOLDER MODELS
# ============================================================================
class SmartFolderCreate(BaseModel):
    """Create smart folder request"""
    name: str = Field(..., min_length=1, max_length=100)
    # Free-form rule document interpreted by the smart-folder engine
    query_rules: Dict[str, Any]
class SmartFolderUpdate(BaseModel):
    """Update smart folder request"""
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    query_rules: Optional[Dict[str, Any]] = None
# ============================================================================
# SCRAPER MODELS
# ============================================================================
class ScraperUpdate(BaseModel):
    """Update scraper settings"""
    enabled: Optional[bool] = None
    proxy: Optional[str] = None
    interval_hours: Optional[int] = Field(None, ge=1, le=168)
    settings: Optional[Dict[str, Any]] = None
class CookieUpload(BaseModel):
    """Cookie upload for scraper"""
    cookies: str  # JSON string of cookies
    source: str = Field(..., pattern=r'^(browser|manual|extension)$')
# ============================================================================
# STANDARD RESPONSE MODELS
# ============================================================================
class SuccessResponse(BaseModel):
    """Standard success response"""
    success: bool = True
    message: str = "Operation completed successfully"
class DataResponse(BaseModel):
    """Standard data response wrapper"""
    success: bool = True
    data: Any
class MessageResponse(BaseModel):
    """Response with just a message"""
    message: str
class CountResponse(BaseModel):
    """Response with count of affected items"""
    message: str
    count: int
class IdResponse(BaseModel):
    """Response with created resource ID"""
    id: int
    message: str = "Resource created successfully"
class PaginatedResponse(BaseModel):
    """Paginated list response base"""
    items: List[Any]
    total: int
    limit: int
    offset: int
    has_more: bool = False
    @classmethod
    def create(cls, items: List[Any], total: int, limit: int, offset: int):
        """Helper to create paginated response"""
        return cls(
            items=items,
            total=total,
            limit=limit,
            offset=offset,
            # More pages remain when this page doesn't reach the total
            has_more=(offset + len(items)) < total
        )

View File

@@ -0,0 +1,565 @@
#!/usr/bin/env python3
"""
Passkey Manager for Media Downloader
Handles WebAuthn/Passkey operations
Based on backup-central's implementation
"""
import sys
import sqlite3
import json
from datetime import datetime, timedelta
from typing import Optional, Dict, List
from pathlib import Path
import os
# Add parent path to allow imports from modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from modules.universal_logger import get_logger
logger = get_logger('PasskeyManager')
from webauthn import (
generate_registration_options,
verify_registration_response,
generate_authentication_options,
verify_authentication_response,
options_to_json
)
from webauthn.helpers import base64url_to_bytes, bytes_to_base64url
from webauthn.helpers.structs import (
PublicKeyCredentialDescriptor,
UserVerificationRequirement,
AttestationConveyancePreference,
AuthenticatorSelectionCriteria,
ResidentKeyRequirement
)
class PasskeyManager:
def __init__(self, db_path: str = None):
if db_path is None:
db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')
self.db_path = db_path
# WebAuthn Configuration
self.rp_name = 'Media Downloader'
self.rp_id = self._get_rp_id()
self.origin = self._get_origin()
# Challenge timeout (5 minutes)
self.challenge_timeout = timedelta(minutes=5)
# Challenge storage (in-memory for now, could use Redis/DB for production)
self.challenges = {} # username → challenge mapping
# Initialize database
self._init_database()
def _get_rp_id(self) -> str:
"""Get Relying Party ID (domain)"""
return os.getenv('WEBAUTHN_RP_ID', 'md.lic.ad')
def _get_origin(self) -> str:
"""Get Origin URL"""
return os.getenv('WEBAUTHN_ORIGIN', 'https://md.lic.ad')
def _init_database(self):
"""Initialize passkey-related database tables"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Passkeys credentials table
cursor.execute("""
CREATE TABLE IF NOT EXISTS passkey_credentials (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
credential_id TEXT NOT NULL UNIQUE,
public_key TEXT NOT NULL,
sign_count INTEGER NOT NULL DEFAULT 0,
transports TEXT,
device_name TEXT,
aaguid TEXT,
created_at TEXT NOT NULL,
last_used TEXT,
FOREIGN KEY (username) REFERENCES users(username) ON DELETE CASCADE
)
""")
# Passkey audit log
cursor.execute("""
CREATE TABLE IF NOT EXISTS passkey_audit_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL,
action TEXT NOT NULL,
success INTEGER NOT NULL,
credential_id TEXT,
ip_address TEXT,
user_agent TEXT,
details TEXT,
timestamp TEXT NOT NULL
)
""")
# Add passkey enabled column to users
try:
cursor.execute("ALTER TABLE users ADD COLUMN passkey_enabled INTEGER NOT NULL DEFAULT 0")
except sqlite3.OperationalError:
pass
conn.commit()
    def get_user_credentials(self, username: str) -> List[Dict]:
        """
        Get all credentials for a user
        Args:
            username: Username
        Returns:
            List of credential dicts
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # NOTE: the user_id column holds the username string
            # (see verify_registration's INSERT)
            cursor.execute("""
                SELECT credential_id, public_key, sign_count, device_name,
                       created_at, last_used_at
                FROM passkey_credentials
                WHERE user_id = ?
            """, (username,))
            rows = cursor.fetchall()
            credentials = []
            for row in rows:
                credentials.append({
                    'credential_id': row[0],
                    'public_key': row[1],
                    'sign_count': row[2],
                    # transports/aaguid are not persisted by verify_registration,
                    # so they are always reported as None here — TODO confirm intended
                    'transports': None,
                    'device_name': row[3],
                    'aaguid': None,
                    'created_at': row[4],
                    'last_used': row[5]
                })
            return credentials
    def generate_registration_options(self, username: str, email: Optional[str] = None) -> Dict:
        """
        Generate registration options for a new passkey
        Args:
            username: Username
            email: Optional email for display name
        Returns:
            Registration options for client (JSON-serializable dict)
        """
        try:
            # Get existing credentials
            existing_credentials = self.get_user_credentials(username)
            # Exclude already-registered credentials so the authenticator
            # refuses to create a duplicate passkey for this user
            exclude_credentials = []
            for cred in existing_credentials:
                exclude_credentials.append(
                    PublicKeyCredentialDescriptor(
                        id=base64url_to_bytes(cred['credential_id']),
                        # NOTE(review): get_user_credentials always returns
                        # transports=None (not persisted) — confirm intended
                        transports=cred.get('transports')
                    )
                )
            # Generate registration options
            options = generate_registration_options(
                rp_id=self.rp_id,
                rp_name=self.rp_name,
                user_id=username.encode('utf-8'),
                user_name=username,
                user_display_name=email or username,
                timeout=60000,  # 60 seconds
                attestation=AttestationConveyancePreference.NONE,
                exclude_credentials=exclude_credentials,
                authenticator_selection=AuthenticatorSelectionCriteria(
                    resident_key=ResidentKeyRequirement.PREFERRED,
                    user_verification=UserVerificationRequirement.PREFERRED
                ),
                supported_pub_key_algs=[-7, -257]  # ES256, RS256
            )
            # Store challenge keyed by username; verify_registration consumes it.
            # Only one registration ceremony can be in flight per user.
            challenge_str = bytes_to_base64url(options.challenge)
            self.challenges[username] = {
                'challenge': challenge_str,
                'type': 'registration',
                'created_at': datetime.now(),
                'expires_at': datetime.now() + self.challenge_timeout
            }
            self._log_audit(username, 'passkey_registration_options_generated', True,
                            None, None, None, 'Registration options generated')
            # Convert to JSON-serializable format
            options_dict = options_to_json(options)
            return json.loads(options_dict)
        except Exception as e:
            self._log_audit(username, 'passkey_registration_options_failed', False,
                            None, None, None, f'Error: {str(e)}')
            raise
    def verify_registration(self, username: str, credential_data: Dict,
                            device_name: Optional[str] = None) -> Dict:
        """
        Verify registration response and store credential
        Args:
            username: Username
            credential_data: Registration response from client
            device_name: Optional device name
        Returns:
            Dict with success status
        Raises:
            Exception: when no/expired/mismatched challenge exists or
                verification fails (also audit-logged)
        """
        try:
            # Get stored challenge (placed by generate_registration_options)
            if username not in self.challenges:
                raise Exception("No registration in progress")
            challenge_data = self.challenges[username]
            # Check if expired
            if datetime.now() > challenge_data['expires_at']:
                del self.challenges[username]
                raise Exception("Challenge expired")
            if challenge_data['type'] != 'registration':
                raise Exception("Invalid challenge type")
            expected_challenge = base64url_to_bytes(challenge_data['challenge'])
            # Verify registration response
            verification = verify_registration_response(
                credential=credential_data,
                expected_challenge=expected_challenge,
                expected_rp_id=self.rp_id,
                expected_origin=self.origin
            )
            # Store credential
            credential_id = bytes_to_base64url(verification.credential_id)
            public_key = bytes_to_base64url(verification.credential_public_key)
            # Extract transports if available
            # NOTE(review): transports is extracted but never persisted
            # (the INSERT below does not include it) — confirm intended
            transports = credential_data.get('response', {}).get('transports', [])
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # user_id stores the username string (matches get_user_credentials)
                cursor.execute("""
                    INSERT INTO passkey_credentials
                    (user_id, credential_id, public_key, sign_count, device_name, created_at)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    username,
                    credential_id,
                    public_key,
                    verification.sign_count,
                    device_name,
                    datetime.now().isoformat()
                ))
                # Enable passkey for user
                cursor.execute("""
                    UPDATE users SET passkey_enabled = 1
                    WHERE username = ?
                """, (username,))
                conn.commit()
            # Clean up challenge (one-time use)
            del self.challenges[username]
            self._log_audit(username, 'passkey_registered', True, credential_id,
                            None, None, f'Passkey registered: {device_name or "Unknown device"}')
            return {
                'success': True,
                'credentialId': credential_id
            }
        except Exception as e:
            logger.error(f"Passkey registration verification failed for {username}: {str(e)}", module="Passkey")
            logger.debug(f"Passkey registration detailed error:", exc_info=True, module="Passkey")
            self._log_audit(username, 'passkey_registration_failed', False, None,
                            None, None, f'Error: {str(e)}')
            raise
    def generate_authentication_options(self, username: Optional[str] = None) -> Dict:
        """
        Generate authentication options for passkey login
        Args:
            username: Optional username (if None, allows usernameless login)
        Returns:
            Authentication options for client (JSON-serializable dict)
        """
        try:
            allow_credentials = []
            # If username provided, restrict login to that user's credentials;
            # otherwise leave empty for a discoverable-credential (usernameless) login
            if username:
                user_credentials = self.get_user_credentials(username)
                for cred in user_credentials:
                    allow_credentials.append(
                        PublicKeyCredentialDescriptor(
                            id=base64url_to_bytes(cred['credential_id']),
                            # NOTE(review): always None — transports are not persisted
                            transports=cred.get('transports')
                        )
                    )
            # Generate authentication options
            options = generate_authentication_options(
                rp_id=self.rp_id,
                timeout=60000,  # 60 seconds
                allow_credentials=allow_credentials if allow_credentials else None,
                user_verification=UserVerificationRequirement.PREFERRED
            )
            # Store challenge; usernameless flows share the '__anonymous__' slot.
            # NOTE(review): concurrent anonymous logins overwrite each other's
            # challenge — consider keying by a session/request id
            challenge_str = bytes_to_base64url(options.challenge)
            self.challenges[username or '__anonymous__'] = {
                'challenge': challenge_str,
                'type': 'authentication',
                'created_at': datetime.now(),
                'expires_at': datetime.now() + self.challenge_timeout
            }
            self._log_audit(username or 'anonymous', 'passkey_authentication_options_generated',
                            True, None, None, None, 'Authentication options generated')
            # Convert to JSON-serializable format
            options_dict = options_to_json(options)
            return json.loads(options_dict)
        except Exception as e:
            self._log_audit(username or 'anonymous', 'passkey_authentication_options_failed',
                            False, None, None, None, f'Error: {str(e)}')
            raise
def verify_authentication(self, username: str, credential_data: Dict) -> Dict:
"""
Verify authentication response
Args:
username: Username
credential_data: Authentication response from client
Returns:
Dict with success status and verified username
"""
try:
# Get stored challenge
challenge_key = username or '__anonymous__'
if challenge_key not in self.challenges:
raise Exception("No authentication in progress")
challenge_data = self.challenges[challenge_key]
# Check if expired
if datetime.now() > challenge_data['expires_at']:
del self.challenges[challenge_key]
raise Exception("Challenge expired")
if challenge_data['type'] != 'authentication':
raise Exception("Invalid challenge type")
expected_challenge = base64url_to_bytes(challenge_data['challenge'])
# Get credential ID from response
credential_id = credential_data.get('id') or credential_data.get('rawId')
if not credential_id:
raise Exception("No credential ID in response")
# Find credential in database
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT user_id, public_key, sign_count
FROM passkey_credentials
WHERE credential_id = ?
""", (credential_id,))
row = cursor.fetchone()
if not row:
raise Exception("Credential not found")
verified_username, public_key_b64, current_sign_count = row
# Verify authentication response
verification = verify_authentication_response(
credential=credential_data,
expected_challenge=expected_challenge,
expected_rp_id=self.rp_id,
expected_origin=self.origin,
credential_public_key=base64url_to_bytes(public_key_b64),
credential_current_sign_count=current_sign_count
)
# Update sign count and last used
cursor.execute("""
UPDATE passkey_credentials
SET sign_count = ?, last_used_at = ?
WHERE credential_id = ?
""", (verification.new_sign_count, datetime.now().isoformat(), credential_id))
conn.commit()
# Clean up challenge
del self.challenges[challenge_key]
self._log_audit(verified_username, 'passkey_authentication_success', True,
credential_id, None, None, 'Passkey authentication successful')
return {
'success': True,
'username': verified_username
}
except Exception as e:
self._log_audit(username or 'anonymous', 'passkey_authentication_failed', False,
None, None, None, f'Error: {str(e)}')
raise
def remove_credential(self, username: str, credential_id: str) -> bool:
"""
Remove a passkey credential
Args:
username: Username
credential_id: Credential ID to remove
Returns:
True if successful
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Debug: List all credentials for this user
cursor.execute("SELECT credential_id FROM passkey_credentials WHERE user_id = ?", (username,))
all_creds = cursor.fetchall()
logger.debug(f"All credentials for {username}: {all_creds}", module="Passkey")
logger.debug(f"Looking for credential_id: '{credential_id}'", module="Passkey")
cursor.execute("""
DELETE FROM passkey_credentials
WHERE user_id = ? AND credential_id = ?
""", (username, credential_id))
deleted = cursor.rowcount > 0
logger.debug(f"Delete operation rowcount: {deleted}", module="Passkey")
# Check if user has any remaining credentials
cursor.execute("""
SELECT COUNT(*) FROM passkey_credentials
WHERE user_id = ?
""", (username,))
remaining = cursor.fetchone()[0]
# If no credentials left, disable passkey
if remaining == 0:
cursor.execute("""
UPDATE users SET passkey_enabled = 0
WHERE username = ?
""", (username,))
conn.commit()
if deleted:
self._log_audit(username, 'passkey_removed', True, credential_id,
None, None, 'Passkey credential removed')
return deleted
except Exception as e:
logger.error(f"Error removing passkey for {username}: {e}", module="Passkey")
return False
def get_passkey_status(self, username: str) -> Dict:
"""
Get passkey status for a user
Args:
username: Username
Returns:
Dict with enabled status and credential count
"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT COUNT(*) FROM passkey_credentials
WHERE user_id = ?
""", (username,))
count = cursor.fetchone()[0]
return {
'enabled': count > 0,
'credentialCount': count
}
def list_credentials(self, username: str) -> List[Dict]:
"""
List all credentials for a user (without sensitive data)
Args:
username: Username
Returns:
List of credential info dicts
"""
credentials = self.get_user_credentials(username)
# Remove sensitive data
safe_credentials = []
for cred in credentials:
safe_credentials.append({
'credentialId': cred['credential_id'],
'deviceName': cred.get('device_name', 'Unknown device'),
'createdAt': cred['created_at'],
'lastUsed': cred.get('last_used'),
'transports': cred.get('transports', [])
})
return safe_credentials
def _log_audit(self, username: str, action: str, success: bool,
credential_id: Optional[str], ip_address: Optional[str],
user_agent: Optional[str], details: Optional[str] = None):
"""Log passkey audit event"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO passkey_audit_log
(user_id, action, success, credential_id, ip_address, user_agent, details, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (username, action, int(success), credential_id, ip_address, user_agent, details,
datetime.now().isoformat()))
conn.commit()
except Exception as e:
logger.error(f"Error logging passkey audit: {e}", module="Passkey")

View File

@@ -0,0 +1,6 @@
fastapi==0.109.0
uvicorn[standard]>=0.27.0,<0.35.0 # 0.40.0+ has breaking changes
pydantic==2.5.3
python-multipart==0.0.6
websockets==12.0
psutil==7.1.3

View File

@@ -0,0 +1,90 @@
"""
Router module exports
All API routers for the media-downloader backend.
Import routers from here to include them in the main app.
Usage:
from web.backend.routers import (
auth_router,
health_router,
downloads_router,
media_router,
recycle_router,
scheduler_router,
video_router,
config_router,
review_router,
face_router,
platforms_router,
discovery_router,
scrapers_router,
semantic_router,
manual_import_router,
stats_router
)
app.include_router(auth_router)
app.include_router(health_router)
# ... etc
"""
from .auth import router as auth_router
from .health import router as health_router
from .downloads import router as downloads_router
from .media import router as media_router
from .recycle import router as recycle_router
from .scheduler import router as scheduler_router
from .video import router as video_router
from .config import router as config_router
from .review import router as review_router
from .face import router as face_router
from .platforms import router as platforms_router
from .discovery import router as discovery_router
from .scrapers import router as scrapers_router
from .semantic import router as semantic_router
from .manual_import import router as manual_import_router
from .stats import router as stats_router
from .celebrity import router as celebrity_router
from .video_queue import router as video_queue_router
from .maintenance import router as maintenance_router
from .files import router as files_router
from .appearances import router as appearances_router
from .easynews import router as easynews_router
from .dashboard import router as dashboard_router
from .paid_content import router as paid_content_router
from .private_gallery import router as private_gallery_router
from .instagram_unified import router as instagram_unified_router
from .cloud_backup import router as cloud_backup_router
from .press import router as press_router
__all__ = [
'auth_router',
'health_router',
'downloads_router',
'media_router',
'recycle_router',
'scheduler_router',
'video_router',
'config_router',
'review_router',
'face_router',
'platforms_router',
'discovery_router',
'scrapers_router',
'semantic_router',
'manual_import_router',
'stats_router',
'celebrity_router',
'video_queue_router',
'maintenance_router',
'files_router',
'appearances_router',
'easynews_router',
'dashboard_router',
'paid_content_router',
'private_gallery_router',
'instagram_unified_router',
'cloud_backup_router',
'press_router',
]

File diff suppressed because it is too large Load Diff

217
web/backend/routers/auth.py Normal file
View File

@@ -0,0 +1,217 @@
"""
Authentication Router
Handles all authentication-related endpoints:
- Login/Logout
- User info
- Password changes
- User preferences
"""
import sqlite3
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException, Body, Request
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, get_app_state
from ..core.config import settings
from ..core.exceptions import AuthError, handle_exceptions
from ..core.responses import to_iso8601, now_iso8601
from ..models.api_models import LoginRequest, ChangePasswordRequest, UserPreferences
from modules.universal_logger import get_logger
logger = get_logger('API')
router = APIRouter(prefix="/api/auth", tags=["Authentication"])
# Rate limiter - will be set from main app
limiter = Limiter(key_func=get_remote_address)
@router.post("/login")
@limiter.limit("5/minute")
@handle_exceptions
async def login(login_data: LoginRequest, request: Request):
"""
Authenticate user with username and password.
Returns JWT token or 2FA challenge if 2FA is enabled.
"""
app_state = get_app_state()
if not app_state.auth:
raise HTTPException(status_code=500, detail="Authentication not initialized")
# Query user from database
with sqlite3.connect(app_state.auth.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT password_hash, role, is_active, totp_enabled, duo_enabled, passkey_enabled
FROM users WHERE username = ?
""", (login_data.username,))
row = cursor.fetchone()
if not row:
raise HTTPException(status_code=401, detail="Invalid credentials")
password_hash, role, is_active, totp_enabled, duo_enabled, passkey_enabled = row
if not is_active:
raise HTTPException(status_code=401, detail="Account is inactive")
if not app_state.auth.verify_password(login_data.password, password_hash):
app_state.auth._record_login_attempt(login_data.username, False)
raise HTTPException(status_code=401, detail="Invalid credentials")
# Check if user has any 2FA methods enabled
available_methods = []
if totp_enabled:
available_methods.append('totp')
if passkey_enabled:
available_methods.append('passkey')
if duo_enabled:
available_methods.append('duo')
# If user has 2FA enabled, return require2FA flag
if available_methods:
return {
'success': True,
'require2FA': True,
'availableMethods': available_methods,
'username': login_data.username
}
# No 2FA - proceed with normal login
result = app_state.auth._create_session(
username=login_data.username,
role=role,
ip_address=request.client.host if request.client else None,
remember_me=login_data.rememberMe
)
# Create response with cookie
response = JSONResponse(content=result)
# Set auth cookie (secure, httponly for security)
max_age = 30 * 24 * 60 * 60 if login_data.rememberMe else None
response.set_cookie(
key="auth_token",
value=result.get('token'),
max_age=max_age,
httponly=True,
secure=settings.SECURE_COOKIES,
samesite="lax",
path="/"
)
logger.info(f"User {login_data.username} logged in successfully", module="Auth")
return response
@router.post("/logout")
@limiter.limit("10/minute")
@handle_exceptions
async def logout(request: Request, current_user: Dict = Depends(get_current_user)):
"""Logout and invalidate session"""
username = current_user.get('sub', 'unknown')
response = JSONResponse(content={
"success": True,
"message": "Logged out successfully",
"timestamp": now_iso8601()
})
# Clear auth cookie
response.set_cookie(
key="auth_token",
value="",
max_age=0,
httponly=True,
secure=settings.SECURE_COOKIES,
samesite="lax",
path="/"
)
logger.info(f"User {username} logged out", module="Auth")
return response
@router.get("/me")
@limiter.limit("30/minute")
@handle_exceptions
async def get_me(request: Request, current_user: Dict = Depends(get_current_user)):
"""Get current user information"""
app_state = get_app_state()
username = current_user.get('sub')
user_info = app_state.auth.get_user(username)
if not user_info:
raise HTTPException(status_code=404, detail="User not found")
return user_info
@router.post("/change-password")
@limiter.limit("5/minute")
@handle_exceptions
async def change_password(
request: Request,
current_password: str = Body(..., embed=True),
new_password: str = Body(..., embed=True),
current_user: Dict = Depends(get_current_user)
):
"""Change user password"""
app_state = get_app_state()
username = current_user.get('sub')
ip_address = request.client.host if request.client else None
# Validate new password
if len(new_password) < 8:
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
result = app_state.auth.change_password(username, current_password, new_password, ip_address)
if not result['success']:
raise HTTPException(status_code=400, detail=result.get('error', 'Password change failed'))
logger.info(f"Password changed for user {username}", module="Auth")
return {
"success": True,
"message": "Password changed successfully",
"timestamp": now_iso8601()
}
@router.post("/preferences")
@limiter.limit("10/minute")
@handle_exceptions
async def update_preferences(
request: Request,
preferences: dict = Body(...),
current_user: Dict = Depends(get_current_user)
):
"""Update user preferences (theme, notifications, etc.)"""
app_state = get_app_state()
username = current_user.get('sub')
# Validate theme if provided
if 'theme' in preferences:
if preferences['theme'] not in ('light', 'dark', 'system'):
raise HTTPException(status_code=400, detail="Invalid theme value")
result = app_state.auth.update_preferences(username, preferences)
if not result['success']:
raise HTTPException(status_code=400, detail=result.get('error', 'Failed to update preferences'))
return {
"success": True,
"message": "Preferences updated",
"preferences": preferences,
"timestamp": now_iso8601()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,756 @@
"""
Config and Logs Router
Handles configuration and logging operations:
- Get/update application configuration
- Log viewing (single component, merged)
- Notification history and stats
- Changelog retrieval
"""
import json
import re
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, require_admin, get_app_state
from ..core.config import settings
from ..core.exceptions import (
handle_exceptions,
ValidationError,
RecordNotFoundError
)
from ..core.responses import now_iso8601
from modules.universal_logger import get_logger
logger = get_logger('API')
router = APIRouter(prefix="/api", tags=["Configuration"])
limiter = Limiter(key_func=get_remote_address)
LOG_PATH = settings.PROJECT_ROOT / 'logs'
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class ConfigUpdate(BaseModel):
config: Dict
class MergedLogsRequest(BaseModel):
lines: int = 500
components: List[str]
around_time: Optional[str] = None # ISO timestamp to center logs around
# ============================================================================
# CONFIGURATION ENDPOINTS
# ============================================================================
@router.get("/config")
@limiter.limit("100/minute")
@handle_exceptions
async def get_config(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get current configuration."""
app_state = get_app_state()
return app_state.settings.get_all()
@router.put("/config")
@limiter.limit("20/minute")
@handle_exceptions
async def update_config(
request: Request,
current_user: Dict = Depends(require_admin),
update: ConfigUpdate = Body(...)
):
"""
Update configuration (admin only).
Saves configuration to database and updates in-memory state.
"""
app_state = get_app_state()
if not isinstance(update.config, dict):
raise ValidationError("Invalid configuration format")
logger.debug(f"Incoming config keys: {list(update.config.keys())}", module="Config")
# Save to database
for key, value in update.config.items():
app_state.settings.set(key, value, category=key, updated_by='api')
# Refresh in-memory config so other endpoints see updated values
app_state.config = app_state.settings.get_all()
# Broadcast update
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "config_updated",
"timestamp": now_iso8601()
})
except Exception as e:
logger.debug(f"Failed to broadcast config update: {e}", module="Config")
return {"success": True, "message": "Configuration updated"}
# ============================================================================
# LOG ENDPOINTS
# ============================================================================
@router.get("/logs")
@limiter.limit("100/minute")
@handle_exceptions
async def get_logs(
request: Request,
current_user: Dict = Depends(get_current_user),
lines: int = 100,
component: Optional[str] = None
):
"""Get recent log entries from the most recent log files."""
if not LOG_PATH.exists():
return {"logs": [], "available_components": []}
all_log_files = []
# Find date-stamped logs: YYYYMMDD_component.log or YYYYMMDD_HHMMSS_component.log
seen_paths = set()
for log_file in LOG_PATH.glob('*.log'):
if '_' not in log_file.stem:
continue
parts = log_file.stem.split('_')
if not parts[0].isdigit():
continue
try:
stat_info = log_file.stat()
if stat_info.st_size == 0:
continue
mtime = stat_info.st_mtime
# YYYYMMDD_HHMMSS_component.log (3+ parts, first two numeric)
if len(parts) >= 3 and parts[1].isdigit():
comp_name = '_'.join(parts[2:])
# YYYYMMDD_component.log (2+ parts, first numeric)
elif len(parts) >= 2:
comp_name = '_'.join(parts[1:])
else:
continue
seen_paths.add(log_file)
all_log_files.append({
'path': log_file,
'mtime': mtime,
'component': comp_name
})
except OSError:
pass
# Also check for old-style logs (no date prefix)
for log_file in LOG_PATH.glob('*.log'):
if log_file in seen_paths:
continue
if '_' in log_file.stem and log_file.stem.split('_')[0].isdigit():
continue
try:
stat_info = log_file.stat()
if stat_info.st_size == 0:
continue
mtime = stat_info.st_mtime
all_log_files.append({
'path': log_file,
'mtime': mtime,
'component': log_file.stem
})
except OSError:
pass
if not all_log_files:
return {"logs": [], "available_components": []}
components = sorted(set(f['component'] for f in all_log_files))
if component:
log_files = [f for f in all_log_files if f['component'] == component]
else:
log_files = all_log_files
if not log_files:
return {"logs": [], "available_components": components}
most_recent = max(log_files, key=lambda x: x['mtime'])
try:
with open(most_recent['path'], 'r', encoding='utf-8', errors='ignore') as f:
all_lines = f.readlines()
recent_lines = all_lines[-lines:]
return {
"logs": [line.strip() for line in recent_lines],
"available_components": components,
"current_component": most_recent['component'],
"log_file": str(most_recent['path'].name)
}
except Exception as e:
logger.error(f"Error reading log file: {e}", module="Logs")
return {"logs": [], "available_components": components, "error": str(e)}
@router.post("/logs/merged")
@limiter.limit("100/minute")
@handle_exceptions
async def get_merged_logs(
request: Request,
body: MergedLogsRequest,
current_user: Dict = Depends(get_current_user)
):
"""Get merged log entries from multiple components, sorted by timestamp."""
lines = body.lines
components = body.components
if not LOG_PATH.exists():
return {"logs": [], "available_components": [], "selected_components": []}
all_log_files = []
# Find date-stamped logs
for log_file in LOG_PATH.glob('*_*.log'):
try:
stat_info = log_file.stat()
if stat_info.st_size == 0:
continue
mtime = stat_info.st_mtime
parts = log_file.stem.split('_')
# Check OLD format FIRST (YYYYMMDD_HHMMSS_component.log)
if len(parts) >= 3 and parts[0].isdigit() and len(parts[0]) == 8 and parts[1].isdigit() and len(parts[1]) == 6:
comp_name = '_'.join(parts[2:])
all_log_files.append({
'path': log_file,
'mtime': mtime,
'component': comp_name
})
# Then check NEW format (YYYYMMDD_component.log)
elif len(parts) >= 2 and parts[0].isdigit() and len(parts[0]) == 8:
comp_name = '_'.join(parts[1:])
all_log_files.append({
'path': log_file,
'mtime': mtime,
'component': comp_name
})
except OSError:
pass
# Also check for old-style logs
for log_file in LOG_PATH.glob('*.log'):
if '_' in log_file.stem and log_file.stem.split('_')[0].isdigit():
continue
try:
stat_info = log_file.stat()
if stat_info.st_size == 0:
continue
mtime = stat_info.st_mtime
all_log_files.append({
'path': log_file,
'mtime': mtime,
'component': log_file.stem
})
except OSError:
pass
if not all_log_files:
return {"logs": [], "available_components": [], "selected_components": []}
available_components = sorted(set(f['component'] for f in all_log_files))
if not components or len(components) == 0:
return {
"logs": [],
"available_components": available_components,
"selected_components": []
}
selected_log_files = [f for f in all_log_files if f['component'] in components]
if not selected_log_files:
return {
"logs": [],
"available_components": available_components,
"selected_components": components
}
all_logs_with_timestamps = []
for comp in components:
comp_files = [f for f in selected_log_files if f['component'] == comp]
if not comp_files:
continue
most_recent = max(comp_files, key=lambda x: x['mtime'])
try:
with open(most_recent['path'], 'r', encoding='utf-8', errors='ignore') as f:
all_lines = f.readlines()
recent_lines = all_lines[-lines:]
for line in recent_lines:
line = line.strip()
if not line:
continue
# Match timestamp with optional microseconds
timestamp_match = re.match(r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})(?:\.(\d+))?', line)
if timestamp_match:
timestamp_str = timestamp_match.group(1)
microseconds = timestamp_match.group(2)
try:
timestamp = datetime.strptime(timestamp_str, '%Y-%m-%d %H:%M:%S')
# Add microseconds if present
if microseconds:
# Pad or truncate to 6 digits for microseconds
microseconds = microseconds[:6].ljust(6, '0')
timestamp = timestamp.replace(microsecond=int(microseconds))
all_logs_with_timestamps.append({
'timestamp': timestamp,
'log': line
})
except ValueError:
all_logs_with_timestamps.append({
'timestamp': None,
'log': line
})
else:
all_logs_with_timestamps.append({
'timestamp': None,
'log': line
})
except Exception as e:
logger.error(f"Error reading log file {most_recent['path']}: {e}", module="Logs")
continue
# Sort by timestamp
sorted_logs = sorted(
all_logs_with_timestamps,
key=lambda x: x['timestamp'] if x['timestamp'] is not None else datetime.min
)
# If around_time is specified, center the logs around that timestamp
if body.around_time:
try:
# Parse the target timestamp
target_time = datetime.fromisoformat(body.around_time.replace('Z', '+00:00').replace('+00:00', ''))
# Find logs within 10 minutes of the target time
time_window = timedelta(minutes=10)
filtered_logs = [
entry for entry in sorted_logs
if entry['timestamp'] is not None and
abs((entry['timestamp'] - target_time).total_seconds()) <= time_window.total_seconds()
]
# If we found logs near the target time, use those
# Otherwise fall back to all logs and try to find the closest ones
if filtered_logs:
merged_logs = [entry['log'] for entry in filtered_logs]
else:
# Find the closest logs to the target time
logs_with_diff = [
(entry, abs((entry['timestamp'] - target_time).total_seconds()) if entry['timestamp'] else float('inf'))
for entry in sorted_logs
]
logs_with_diff.sort(key=lambda x: x[1])
# Take the closest logs, centered around the target
closest_logs = logs_with_diff[:lines]
closest_logs.sort(key=lambda x: x[0]['timestamp'] if x[0]['timestamp'] else datetime.min)
merged_logs = [entry[0]['log'] for entry in closest_logs]
except (ValueError, TypeError):
# If parsing fails, fall back to normal behavior
merged_logs = [entry['log'] for entry in sorted_logs]
if len(merged_logs) > lines:
merged_logs = merged_logs[-lines:]
else:
merged_logs = [entry['log'] for entry in sorted_logs]
if len(merged_logs) > lines:
merged_logs = merged_logs[-lines:]
return {
"logs": merged_logs,
"available_components": available_components,
"selected_components": components,
"total_logs": len(merged_logs)
}
# ============================================================================
# NOTIFICATION ENDPOINTS
# ============================================================================
@router.get("/notifications")
@limiter.limit("500/minute")
@handle_exceptions
async def get_notifications(
request: Request,
current_user: Dict = Depends(get_current_user),
limit: int = 50,
offset: int = 0,
platform: Optional[str] = None,
source: Optional[str] = None
):
"""Get notification history with pagination and filters."""
app_state = get_app_state()
with app_state.db.get_connection() as conn:
cursor = conn.cursor()
query = """
SELECT id, platform, source, content_type, message, title,
priority, download_count, sent_at, status, metadata
FROM notifications
WHERE 1=1
"""
params = []
if platform:
query += " AND platform = ?"
params.append(platform)
if source:
# Handle standardized source names
if source == 'YouTube Monitor':
query += " AND source = ?"
params.append('youtube_monitor')
else:
query += " AND source = ?"
params.append(source)
# Get total count
count_query = query.replace(
"SELECT id, platform, source, content_type, message, title, priority, download_count, sent_at, status, metadata",
"SELECT COUNT(*)"
)
cursor.execute(count_query, params)
result = cursor.fetchone()
total = result[0] if result else 0
# Add ordering and pagination
query += " ORDER BY sent_at DESC LIMIT ? OFFSET ?"
params.extend([limit, offset])
cursor.execute(query, params)
rows = cursor.fetchall()
notifications = []
for row in rows:
notifications.append({
'id': row[0],
'platform': row[1],
'source': row[2],
'content_type': row[3],
'message': row[4],
'title': row[5],
'priority': row[6],
'download_count': row[7],
'sent_at': row[8],
'status': row[9],
'metadata': json.loads(row[10]) if row[10] else None
})
return {
'notifications': notifications,
'total': total,
'limit': limit,
'offset': offset
}
@router.get("/notifications/stats")
@limiter.limit("500/minute")
@handle_exceptions
async def get_notification_stats(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get notification statistics."""
app_state = get_app_state()
with app_state.db.get_connection() as conn:
cursor = conn.cursor()
# Total sent
cursor.execute("SELECT COUNT(*) FROM notifications WHERE status = 'sent'")
result = cursor.fetchone()
total_sent = result[0] if result else 0
# Total failed
cursor.execute("SELECT COUNT(*) FROM notifications WHERE status = 'failed'")
result = cursor.fetchone()
total_failed = result[0] if result else 0
# By platform (consolidate and filter)
cursor.execute("""
SELECT platform, COUNT(*) as count
FROM notifications
GROUP BY platform
ORDER BY count DESC
""")
raw_platforms = {row[0]: row[1] for row in cursor.fetchall()}
# Consolidate similar platforms and exclude system
by_platform = {}
for platform, count in raw_platforms.items():
# Skip system notifications
if platform == 'system':
continue
# Consolidate forum -> forums
if platform == 'forum':
by_platform['forums'] = by_platform.get('forums', 0) + count
# Consolidate fastdl -> instagram (fastdl is an Instagram download method)
elif platform == 'fastdl':
by_platform['instagram'] = by_platform.get('instagram', 0) + count
# Standardize youtube_monitor/youtube_monitors -> youtube
elif platform in ('youtube_monitor', 'youtube_monitors'):
by_platform['youtube'] = by_platform.get('youtube', 0) + count
else:
by_platform[platform] = by_platform.get(platform, 0) + count
# Recent 24h
cursor.execute("""
SELECT COUNT(*) FROM notifications
WHERE sent_at >= datetime('now', '-1 day')
""")
result = cursor.fetchone()
recent_24h = result[0] if result else 0
# Unique sources for filter dropdown
cursor.execute("""
SELECT DISTINCT source FROM notifications
WHERE source IS NOT NULL AND source != ''
ORDER BY source
""")
raw_sources = [row[0] for row in cursor.fetchall()]
# Standardize source names and track special sources
sources = []
has_youtube_monitor = False
has_log_errors = False
for source in raw_sources:
# Standardize youtube_monitor -> YouTube Monitor
if source == 'youtube_monitor':
has_youtube_monitor = True
elif source == 'Log Errors':
has_log_errors = True
else:
sources.append(source)
# Put special sources at the top
priority_sources = []
if has_youtube_monitor:
priority_sources.append('YouTube Monitor')
if has_log_errors:
priority_sources.append('Log Errors')
sources = priority_sources + sources
return {
'total_sent': total_sent,
'total_failed': total_failed,
'by_platform': by_platform,
'recent_24h': recent_24h,
'sources': sources
}
@router.delete("/notifications/{notification_id}")
@limiter.limit("100/minute")
@handle_exceptions
async def delete_notification(
request: Request,
notification_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Delete a single notification from history."""
app_state = get_app_state()
with app_state.db.get_connection(for_write=True) as conn:
cursor = conn.cursor()
# Check if notification exists
cursor.execute("SELECT id FROM notifications WHERE id = ?", (notification_id,))
if not cursor.fetchone():
raise RecordNotFoundError(
"Notification not found",
{"notification_id": notification_id}
)
# Delete the notification
cursor.execute("DELETE FROM notifications WHERE id = ?", (notification_id,))
conn.commit()
return {
'success': True,
'message': 'Notification deleted',
'notification_id': notification_id
}
# ============================================================================
# CHANGELOG ENDPOINT
# ============================================================================
@router.get("/changelog")
@limiter.limit("100/minute")
@handle_exceptions
async def get_changelog(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get changelog data from JSON file."""
changelog_path = settings.PROJECT_ROOT / "data" / "changelog.json"
if not changelog_path.exists():
return {"versions": []}
with open(changelog_path, 'r') as f:
changelog_data = json.load(f)
return {"versions": changelog_data}
# ============================================================================
# APPEARANCE CONFIG ENDPOINTS
# ============================================================================
class AppearanceConfigUpdate(BaseModel):
tmdb_api_key: Optional[str] = None
tmdb_enabled: bool = True
tmdb_check_interval_hours: int = 12
notify_new_appearances: bool = True
notify_days_before: int = 1
podcast_enabled: bool = False
radio_enabled: bool = False
podchaser_client_id: Optional[str] = None
podchaser_client_secret: Optional[str] = None
podchaser_api_key: Optional[str] = None
podchaser_enabled: bool = False
imdb_enabled: bool = True
@router.get("/config/appearance")
@limiter.limit("100/minute")
@handle_exceptions
async def get_appearance_config(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get appearance tracking configuration."""
db = get_app_state().db
try:
with db.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT tmdb_api_key, tmdb_enabled, tmdb_check_interval_hours, tmdb_last_check,
notify_new_appearances, notify_days_before, podcast_enabled, radio_enabled,
podchaser_client_id, podchaser_client_secret, podchaser_api_key,
podchaser_enabled, podchaser_last_check, imdb_enabled
FROM appearance_config
WHERE id = 1
''')
row = cursor.fetchone()
if not row:
# Initialize config if not exists
cursor.execute('INSERT OR IGNORE INTO appearance_config (id) VALUES (1)')
conn.commit()
return {
"tmdb_api_key": None,
"tmdb_enabled": True,
"tmdb_check_interval_hours": 12,
"tmdb_last_check": None,
"notify_new_appearances": True,
"notify_days_before": 1,
"podcast_enabled": False,
"radio_enabled": False,
"podchaser_client_id": None,
"podchaser_client_secret": None,
"podchaser_api_key": None,
"podchaser_enabled": False,
"podchaser_last_check": None
}
return {
"tmdb_api_key": row[0],
"tmdb_enabled": bool(row[1]),
"tmdb_check_interval_hours": row[2],
"tmdb_last_check": row[3],
"notify_new_appearances": bool(row[4]),
"notify_days_before": row[5],
"podcast_enabled": bool(row[6]),
"radio_enabled": bool(row[7]),
"podchaser_client_id": row[8] if len(row) > 8 else None,
"podchaser_client_secret": row[9] if len(row) > 9 else None,
"podchaser_api_key": row[10] if len(row) > 10 else None,
"podchaser_enabled": bool(row[11]) if len(row) > 11 else False,
"podchaser_last_check": row[12] if len(row) > 12 else None,
"imdb_enabled": bool(row[13]) if len(row) > 13 else True
}
except Exception as e:
logger.error(f"Error getting appearance config: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/config/appearance")
@limiter.limit("100/minute")
@handle_exceptions
async def update_appearance_config(
request: Request,
config: AppearanceConfigUpdate,
current_user: Dict = Depends(get_current_user)
):
"""Update appearance tracking configuration."""
db = get_app_state().db
try:
with db.get_connection(for_write=True) as conn:
cursor = conn.cursor()
# Update config
cursor.execute('''
UPDATE appearance_config
SET tmdb_api_key = ?,
tmdb_enabled = ?,
tmdb_check_interval_hours = ?,
notify_new_appearances = ?,
notify_days_before = ?,
podcast_enabled = ?,
radio_enabled = ?,
podchaser_client_id = ?,
podchaser_client_secret = ?,
podchaser_api_key = ?,
podchaser_enabled = ?,
imdb_enabled = ?,
updated_at = CURRENT_TIMESTAMP
WHERE id = 1
''', (config.tmdb_api_key, config.tmdb_enabled, config.tmdb_check_interval_hours,
config.notify_new_appearances, config.notify_days_before,
config.podcast_enabled, config.radio_enabled,
config.podchaser_client_id, config.podchaser_client_secret,
config.podchaser_api_key, config.podchaser_enabled, config.imdb_enabled))
conn.commit()
return {
"success": True,
"message": "Appearance configuration updated successfully"
}
except Exception as e:
logger.error(f"Error updating appearance config: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,304 @@
"""
Dashboard API Router
Provides endpoints for dashboard-specific data like recent items across different locations.
"""
import json
from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, Request
from slowapi import Limiter
from slowapi.util import get_remote_address

from ..core.dependencies import get_current_user, get_app_state
from ..core.exceptions import handle_exceptions
from modules.universal_logger import get_logger
# Router is mounted by the main application; all paths live under /api/dashboard.
router = APIRouter(prefix="/api/dashboard", tags=["dashboard"])
logger = get_logger('API')
# Per-client rate limiting keyed on the caller's remote address.
limiter = Limiter(key_func=get_remote_address)
@router.get("/recent-items")
@limiter.limit("60/minute")
@handle_exceptions
async def get_recent_items(
request: Request,
limit: int = 20,
since_id: Optional[int] = None,
current_user=Depends(get_current_user)
) -> Dict[str, Any]:
"""
Get NEW items from Media, Review, and Internet Discovery for dashboard cards.
Uses file_inventory.id for ordering since it monotonically increases with
insertion order. download_date from the downloads table is included for
display but not used for ordering (batch downloads can interleave timestamps).
Args:
limit: Max items per category
since_id: Optional file_inventory ID - only return items with id > this value
Returns up to `limit` items from each location, sorted by most recently added first.
"""
app_state = get_app_state()
with app_state.db.get_connection() as conn:
cursor = conn.cursor()
# Media items (location='final')
# ORDER BY fi.id DESC — id is monotonically increasing and reflects insertion order.
# download_date is included for display but NOT used for ordering.
if since_id:
cursor.execute("""
SELECT fi.id, fi.file_path, fi.filename, fi.source, fi.platform, fi.content_type,
fi.file_size, COALESCE(d.download_date, fi.created_date) as added_at,
fi.width, fi.height
FROM file_inventory fi
LEFT JOIN downloads d ON d.filename = fi.filename
WHERE fi.location = 'final'
AND fi.id > ?
AND (fi.moved_from_review IS NULL OR fi.moved_from_review = 0)
AND (fi.from_discovery IS NULL OR fi.from_discovery = 0)
ORDER BY fi.id DESC
LIMIT ?
""", (since_id, limit))
else:
cursor.execute("""
SELECT fi.id, fi.file_path, fi.filename, fi.source, fi.platform, fi.content_type,
fi.file_size, COALESCE(d.download_date, fi.created_date) as added_at,
fi.width, fi.height
FROM file_inventory fi
LEFT JOIN downloads d ON d.filename = fi.filename
WHERE fi.location = 'final'
AND (fi.moved_from_review IS NULL OR fi.moved_from_review = 0)
AND (fi.from_discovery IS NULL OR fi.from_discovery = 0)
ORDER BY fi.id DESC
LIMIT ?
""", (limit,))
media_items = []
for row in cursor.fetchall():
media_items.append({
'id': row[0],
'file_path': row[1],
'filename': row[2],
'source': row[3],
'platform': row[4],
'media_type': row[5],
'file_size': row[6],
'added_at': row[7],
'width': row[8],
'height': row[9]
})
# Get total count for new media items
if since_id:
cursor.execute("""
SELECT COUNT(*)
FROM file_inventory
WHERE location = 'final'
AND id > ?
AND (moved_from_review IS NULL OR moved_from_review = 0)
AND (from_discovery IS NULL OR from_discovery = 0)
""", (since_id,))
else:
cursor.execute("""
SELECT COUNT(*) FROM file_inventory
WHERE location = 'final'
AND (moved_from_review IS NULL OR moved_from_review = 0)
AND (from_discovery IS NULL OR from_discovery = 0)
""")
media_count = cursor.fetchone()[0]
# Review items (location='review')
if since_id:
cursor.execute("""
SELECT f.id, f.file_path, f.filename, f.source, f.platform, f.content_type,
f.file_size, COALESCE(d.download_date, f.created_date) as added_at,
f.width, f.height,
CASE WHEN fr.id IS NOT NULL THEN 1 ELSE 0 END as face_scanned,
fr.has_match as face_matched, fr.confidence as face_confidence, fr.matched_person
FROM file_inventory f
LEFT JOIN downloads d ON d.filename = f.filename
LEFT JOIN face_recognition_scans fr ON f.file_path = fr.file_path
WHERE f.location = 'review'
AND f.id > ?
AND (f.moved_from_media IS NULL OR f.moved_from_media = 0)
ORDER BY f.id DESC
LIMIT ?
""", (since_id, limit))
else:
cursor.execute("""
SELECT f.id, f.file_path, f.filename, f.source, f.platform, f.content_type,
f.file_size, COALESCE(d.download_date, f.created_date) as added_at,
f.width, f.height,
CASE WHEN fr.id IS NOT NULL THEN 1 ELSE 0 END as face_scanned,
fr.has_match as face_matched, fr.confidence as face_confidence, fr.matched_person
FROM file_inventory f
LEFT JOIN downloads d ON d.filename = f.filename
LEFT JOIN face_recognition_scans fr ON f.file_path = fr.file_path
WHERE f.location = 'review'
AND (f.moved_from_media IS NULL OR f.moved_from_media = 0)
ORDER BY f.id DESC
LIMIT ?
""", (limit,))
review_items = []
for row in cursor.fetchall():
face_recognition = None
if row[10]: # face_scanned
face_recognition = {
'scanned': True,
'matched': bool(row[11]) if row[11] is not None else False,
'confidence': row[12],
'matched_person': row[13]
}
review_items.append({
'id': row[0],
'file_path': row[1],
'filename': row[2],
'source': row[3],
'platform': row[4],
'media_type': row[5],
'file_size': row[6],
'added_at': row[7],
'width': row[8],
'height': row[9],
'face_recognition': face_recognition
})
# Get total count for new review items
if since_id:
cursor.execute("""
SELECT COUNT(*)
FROM file_inventory
WHERE location = 'review'
AND id > ?
AND (moved_from_media IS NULL OR moved_from_media = 0)
""", (since_id,))
else:
cursor.execute("""
SELECT COUNT(*) FROM file_inventory
WHERE location = 'review'
AND (moved_from_media IS NULL OR moved_from_media = 0)
""")
review_count = cursor.fetchone()[0]
# Internet Discovery items (celebrity_discovered_videos with status='new')
internet_discovery_items = []
internet_discovery_count = 0
try:
cursor.execute("""
SELECT
v.id,
v.video_id,
v.title,
v.thumbnail,
v.channel_name,
v.platform,
v.duration,
v.max_resolution,
v.status,
v.discovered_at,
v.url,
v.view_count,
v.upload_date,
c.name as celebrity_name
FROM celebrity_discovered_videos v
LEFT JOIN celebrity_profiles c ON v.celebrity_id = c.id
WHERE v.status = 'new'
ORDER BY v.id DESC
LIMIT ?
""", (limit,))
for row in cursor.fetchall():
internet_discovery_items.append({
'id': row[0],
'video_id': row[1],
'title': row[2],
'thumbnail': row[3],
'channel_name': row[4],
'platform': row[5],
'duration': row[6],
'max_resolution': row[7],
'status': row[8],
'discovered_at': row[9],
'url': row[10],
'view_count': row[11],
'upload_date': row[12],
'celebrity_name': row[13]
})
# Get total count for internet discovery
cursor.execute("SELECT COUNT(*) FROM celebrity_discovered_videos WHERE status = 'new'")
internet_discovery_count = cursor.fetchone()[0]
except Exception as e:
# Table might not exist if celebrity feature not used
logger.warning(f"Could not fetch internet discovery items: {e}", module="Dashboard")
return {
'media': {
'count': media_count,
'items': media_items
},
'review': {
'count': review_count,
'items': review_items
},
'internet_discovery': {
'count': internet_discovery_count,
'items': internet_discovery_items
}
}
@router.get("/dismissed-cards")
@limiter.limit("60/minute")
@handle_exceptions
async def get_dismissed_cards(
request: Request,
user=Depends(get_current_user)
) -> Dict[str, Any]:
"""Get the user's dismissed card IDs."""
app_state = get_app_state()
user_id = user.get('username', 'default')
with app_state.db.get_connection() as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT preference_value FROM user_preferences
WHERE user_id = ? AND preference_key = 'dashboard_dismissed_cards'
""", (user_id,))
row = cursor.fetchone()
if row and row[0]:
import json
return json.loads(row[0])
return {'media': None, 'review': None, 'internet_discovery': None}
@router.post("/dismissed-cards")
@limiter.limit("30/minute")
@handle_exceptions
async def set_dismissed_cards(
request: Request,
data: Dict[str, Any],
user=Depends(get_current_user)
) -> Dict[str, str]:
"""Save the user's dismissed card IDs."""
import json
app_state = get_app_state()
user_id = user.get('username', 'default')
with app_state.db.get_connection(for_write=True) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO user_preferences (user_id, preference_key, preference_value, updated_at)
VALUES (?, 'dashboard_dismissed_cards', ?, CURRENT_TIMESTAMP)
ON CONFLICT(user_id, preference_key) DO UPDATE SET
preference_value = excluded.preference_value,
updated_at = CURRENT_TIMESTAMP
""", (user_id, json.dumps(data)))
return {'status': 'ok'}

View File

@@ -0,0 +1,942 @@
"""
Discovery Router
Handles discovery, organization and browsing features:
- Tags management (CRUD, file tagging, bulk operations)
- Smart folders (filter-based virtual folders)
- Collections (manual file groupings)
- Timeline and activity views
- Discovery queue management
"""
import json
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, Query, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, get_app_state
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
from ..core.responses import message_response, id_response, count_response, offset_paginated
from modules.discovery_system import get_discovery_system
from modules.universal_logger import get_logger
logger = get_logger('API')
# All discovery/organization endpoints live under /api; each route below
# adds its own sub-path (e.g. /api/tags, /api/collections).
router = APIRouter(prefix="/api", tags=["Discovery"])
# Per-client rate limiting keyed on the caller's remote address.
limiter = Limiter(key_func=get_remote_address)
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class TagCreate(BaseModel):
    """Request body for creating a tag."""
    name: str
    parent_id: Optional[int] = None  # optional parent for hierarchical tags
    color: str = '#6366f1'           # hex color used by the UI
    icon: Optional[str] = None
    description: Optional[str] = None
class TagUpdate(BaseModel):
    """Request body for updating a tag.

    All fields optional; presumably None fields are left unchanged by
    DiscoverySystem.update_tag — confirm against that implementation.
    """
    name: Optional[str] = None
    parent_id: Optional[int] = None
    color: Optional[str] = None
    icon: Optional[str] = None
    description: Optional[str] = None
class BulkTagRequest(BaseModel):
    """Body shape for bulk tagging.

    NOTE(review): not referenced by the endpoints visible in this module —
    bulk_tag_files parses the raw JSON body instead; confirm before removing.
    """
    file_ids: List[int]
    tag_ids: List[int]
class SmartFolderCreate(BaseModel):
    """Request body for creating a smart (filter-based virtual) folder."""
    name: str
    filters: dict = {}  # filter criteria, stored as-is with the folder
    icon: str = 'folder'
    color: str = '#6366f1'
    description: Optional[str] = None
    sort_by: str = 'post_date'
    sort_order: str = 'desc'
class SmartFolderUpdate(BaseModel):
    """Request body for updating a smart folder; all fields optional."""
    name: Optional[str] = None
    filters: Optional[dict] = None
    icon: Optional[str] = None
    color: Optional[str] = None
    description: Optional[str] = None
    sort_by: Optional[str] = None
    sort_order: Optional[str] = None
class CollectionCreate(BaseModel):
    """Request body for creating a manual file collection."""
    name: str
    description: Optional[str] = None
    color: str = '#6366f1'
class CollectionUpdate(BaseModel):
    """Request body for updating a collection; all fields optional."""
    name: Optional[str] = None
    description: Optional[str] = None
    color: Optional[str] = None
    cover_file_id: Optional[int] = None  # file used as the collection cover
class BulkCollectionAdd(BaseModel):
    """Body shape for bulk add-to-collection.

    NOTE(review): not referenced by the endpoints visible in this module —
    bulk_add_to_collection parses the raw JSON body instead; confirm before removing.
    """
    file_ids: List[int]
class DiscoveryQueueAdd(BaseModel):
    """Request body for adding files to the discovery queue.

    Presumably consumed by queue endpoints later in this module (outside
    this chunk) — confirm.
    """
    file_ids: List[int]
    priority: int = 0
# ============================================================================
# TAGS ENDPOINTS
# ============================================================================
@router.get("/tags")
@limiter.limit("60/minute")
@handle_exceptions
async def get_tags(
request: Request,
current_user: Dict = Depends(get_current_user),
parent_id: Optional[int] = Query(None, description="Parent tag ID (null for root, -1 for all)"),
include_counts: bool = Query(True, description="Include file counts")
):
"""Get all tags, optionally filtered by parent."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
tags = discovery.get_tags(parent_id=parent_id, include_counts=include_counts)
return {"tags": tags}
@router.get("/tags/{tag_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def get_tag(
request: Request,
tag_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Get a single tag by ID."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
tag = discovery.get_tag(tag_id)
if not tag:
raise NotFoundError("Tag not found")
return tag
@router.post("/tags")
@limiter.limit("30/minute")
@handle_exceptions
async def create_tag(
request: Request,
tag_data: TagCreate,
current_user: Dict = Depends(get_current_user)
):
"""Create a new tag."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
tag_id = discovery.create_tag(
name=tag_data.name,
parent_id=tag_data.parent_id,
color=tag_data.color,
icon=tag_data.icon,
description=tag_data.description
)
if tag_id is None:
raise ValidationError("Failed to create tag")
return id_response(tag_id, "Tag created successfully")
@router.put("/tags/{tag_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def update_tag(
request: Request,
tag_id: int,
tag_data: TagUpdate,
current_user: Dict = Depends(get_current_user)
):
"""Update a tag."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.update_tag(
tag_id=tag_id,
name=tag_data.name,
color=tag_data.color,
icon=tag_data.icon,
description=tag_data.description,
parent_id=tag_data.parent_id
)
if not success:
raise NotFoundError("Tag not found or update failed")
return message_response("Tag updated successfully")
@router.delete("/tags/{tag_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def delete_tag(
request: Request,
tag_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Delete a tag."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.delete_tag(tag_id)
if not success:
raise NotFoundError("Tag not found")
return message_response("Tag deleted successfully")
@router.get("/files/{file_id}/tags")
@limiter.limit("60/minute")
@handle_exceptions
async def get_file_tags(
request: Request,
file_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Get all tags for a file."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
tags = discovery.get_file_tags(file_id)
return {"tags": tags}
@router.post("/files/{file_id}/tags/{tag_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def tag_file(
request: Request,
file_id: int,
tag_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Add a tag to a file."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.tag_file(file_id, tag_id, created_by=current_user.get('sub'))
if not success:
raise ValidationError("Failed to tag file")
return message_response("Tag added to file")
@router.delete("/files/{file_id}/tags/{tag_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def untag_file(
request: Request,
file_id: int,
tag_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Remove a tag from a file."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.untag_file(file_id, tag_id)
if not success:
raise NotFoundError("Tag not found on file")
return message_response("Tag removed from file")
@router.get("/tags/{tag_id}/files")
@limiter.limit("60/minute")
@handle_exceptions
async def get_files_by_tag(
request: Request,
tag_id: int,
current_user: Dict = Depends(get_current_user),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0)
):
"""Get files with a specific tag."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
files, total = discovery.get_files_by_tag(tag_id, limit=limit, offset=offset)
return offset_paginated(files, total, limit, offset, key="files")
@router.post("/tags/bulk")
@limiter.limit("30/minute")
@handle_exceptions
async def bulk_tag_files(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Tag multiple files with multiple tags."""
data = await request.json()
file_ids = data.get('file_ids', [])
tag_ids = data.get('tag_ids', [])
if not file_ids or not tag_ids:
raise ValidationError("file_ids and tag_ids required")
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
count = discovery.bulk_tag_files(file_ids, tag_ids, created_by=current_user.get('sub'))
return count_response(f"Tagged {count} file-tag pairs", count)
# ============================================================================
# SMART FOLDERS ENDPOINTS
# ============================================================================
@router.get("/smart-folders")
@limiter.limit("60/minute")
@handle_exceptions
async def get_smart_folders(
request: Request,
current_user: Dict = Depends(get_current_user),
include_system: bool = Query(True, description="Include system smart folders")
):
"""Get all smart folders."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
folders = discovery.get_smart_folders(include_system=include_system)
return {"smart_folders": folders}
@router.get("/smart-folders/stats")
@limiter.limit("30/minute")
@handle_exceptions
async def get_smart_folders_stats(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get file counts and preview thumbnails for all smart folders."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
folders = discovery.get_smart_folders(include_system=True)
stats = {}
with app_state.db.get_connection() as conn:
cursor = conn.cursor()
for folder in folders:
filters = folder.get('filters', {})
folder_id = folder['id']
query = '''
SELECT COUNT(*) as count
FROM file_inventory fi
WHERE fi.location = 'final'
'''
params = []
if filters.get('platform'):
query += ' AND fi.platform = ?'
params.append(filters['platform'])
if filters.get('media_type'):
query += ' AND fi.content_type = ?'
params.append(filters['media_type'])
if filters.get('source'):
query += ' AND fi.source = ?'
params.append(filters['source'])
if filters.get('size_min'):
query += ' AND fi.file_size >= ?'
params.append(filters['size_min'])
cursor.execute(query, params)
count = cursor.fetchone()[0]
preview_query = '''
SELECT fi.file_path, fi.content_type
FROM file_inventory fi
WHERE fi.location = 'final'
'''
preview_params = []
if filters.get('platform'):
preview_query += ' AND fi.platform = ?'
preview_params.append(filters['platform'])
if filters.get('media_type'):
preview_query += ' AND fi.content_type = ?'
preview_params.append(filters['media_type'])
if filters.get('source'):
preview_query += ' AND fi.source = ?'
preview_params.append(filters['source'])
if filters.get('size_min'):
preview_query += ' AND fi.file_size >= ?'
preview_params.append(filters['size_min'])
preview_query += ' ORDER BY fi.created_date DESC LIMIT 4'
cursor.execute(preview_query, preview_params)
previews = []
for row in cursor.fetchall():
previews.append({
'file_path': row['file_path'],
'content_type': row['content_type']
})
stats[folder_id] = {
'count': count,
'previews': previews
}
return {"stats": stats}
@router.get("/smart-folders/{folder_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def get_smart_folder(
request: Request,
folder_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Get a single smart folder by ID."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
folder = discovery.get_smart_folder(folder_id=folder_id)
if not folder:
raise NotFoundError("Smart folder not found")
return folder
@router.post("/smart-folders")
@limiter.limit("30/minute")
@handle_exceptions
async def create_smart_folder(
request: Request,
folder_data: SmartFolderCreate,
current_user: Dict = Depends(get_current_user)
):
"""Create a new smart folder."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
folder_id = discovery.create_smart_folder(
name=folder_data.name,
filters=folder_data.filters,
icon=folder_data.icon,
color=folder_data.color,
description=folder_data.description,
sort_by=folder_data.sort_by,
sort_order=folder_data.sort_order
)
if folder_id is None:
raise ValidationError("Failed to create smart folder")
return {"id": folder_id, "message": "Smart folder created successfully"}
@router.put("/smart-folders/{folder_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def update_smart_folder(
request: Request,
folder_id: int,
folder_data: SmartFolderUpdate,
current_user: Dict = Depends(get_current_user)
):
"""Update a smart folder (cannot update system folders)."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.update_smart_folder(
folder_id=folder_id,
name=folder_data.name,
filters=folder_data.filters,
icon=folder_data.icon,
color=folder_data.color,
description=folder_data.description,
sort_by=folder_data.sort_by,
sort_order=folder_data.sort_order
)
if not success:
raise ValidationError("Failed to update smart folder (may be a system folder)")
return {"message": "Smart folder updated successfully"}
@router.delete("/smart-folders/{folder_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def delete_smart_folder(
request: Request,
folder_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Delete a smart folder (cannot delete system folders)."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.delete_smart_folder(folder_id)
if not success:
raise ValidationError("Failed to delete smart folder (may be a system folder)")
return {"message": "Smart folder deleted successfully"}
# ============================================================================
# COLLECTIONS ENDPOINTS
# ============================================================================
@router.get("/collections")
@limiter.limit("60/minute")
@handle_exceptions
async def get_collections(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get all collections."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
collections = discovery.get_collections()
return {"collections": collections}
@router.get("/collections/{collection_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def get_collection(
request: Request,
collection_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Get a single collection by ID."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
collection = discovery.get_collection(collection_id=collection_id)
if not collection:
raise NotFoundError("Collection not found")
return collection
@router.post("/collections")
@limiter.limit("30/minute")
@handle_exceptions
async def create_collection(
request: Request,
collection_data: CollectionCreate,
current_user: Dict = Depends(get_current_user)
):
"""Create a new collection."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
collection_id = discovery.create_collection(
name=collection_data.name,
description=collection_data.description,
color=collection_data.color
)
if collection_id is None:
raise ValidationError("Failed to create collection")
return {"id": collection_id, "message": "Collection created successfully"}
@router.put("/collections/{collection_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def update_collection(
request: Request,
collection_id: int,
collection_data: CollectionUpdate,
current_user: Dict = Depends(get_current_user)
):
"""Update a collection."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.update_collection(
collection_id=collection_id,
name=collection_data.name,
description=collection_data.description,
color=collection_data.color,
cover_file_id=collection_data.cover_file_id
)
if not success:
raise NotFoundError("Collection not found")
return {"message": "Collection updated successfully"}
@router.delete("/collections/{collection_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def delete_collection(
request: Request,
collection_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Delete a collection."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.delete_collection(collection_id)
if not success:
raise NotFoundError("Collection not found")
return {"message": "Collection deleted successfully"}
@router.get("/collections/{collection_id}/files")
@limiter.limit("60/minute")
@handle_exceptions
async def get_collection_files(
request: Request,
collection_id: int,
current_user: Dict = Depends(get_current_user),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0)
):
"""Get files in a collection."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
files, total = discovery.get_collection_files(collection_id, limit=limit, offset=offset)
return {"files": files, "total": total, "limit": limit, "offset": offset}
@router.post("/collections/{collection_id}/files/{file_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def add_to_collection(
request: Request,
collection_id: int,
file_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Add a file to a collection."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.add_to_collection(collection_id, file_id, added_by=current_user.get('sub'))
if not success:
raise ValidationError("Failed to add file to collection")
return {"message": "File added to collection"}
@router.delete("/collections/{collection_id}/files/{file_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def remove_from_collection(
request: Request,
collection_id: int,
file_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Remove a file from a collection."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
success = discovery.remove_from_collection(collection_id, file_id)
if not success:
raise NotFoundError("File not found in collection")
return {"message": "File removed from collection"}
@router.post("/collections/{collection_id}/files/bulk")
@limiter.limit("30/minute")
@handle_exceptions
async def bulk_add_to_collection(
request: Request,
collection_id: int,
current_user: Dict = Depends(get_current_user)
):
"""Add multiple files to a collection."""
data = await request.json()
file_ids = data.get('file_ids', [])
if not file_ids:
raise ValidationError("file_ids required")
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
count = discovery.bulk_add_to_collection(collection_id, file_ids, added_by=current_user.get('sub'))
return {"message": f"Added {count} files to collection", "count": count}
# ============================================================================
# TIMELINE ENDPOINTS
# ============================================================================
@router.get("/timeline")
@limiter.limit("60/minute")
@handle_exceptions
async def get_timeline(
request: Request,
current_user: Dict = Depends(get_current_user),
granularity: str = Query('day', pattern='^(day|week|month|year)$'),
date_from: Optional[str] = Query(None, pattern=r'^\d{4}-\d{2}-\d{2}$'),
date_to: Optional[str] = Query(None, pattern=r'^\d{4}-\d{2}-\d{2}$'),
platform: Optional[str] = Query(None),
source: Optional[str] = Query(None)
):
"""Get timeline aggregation data."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
data = discovery.get_timeline_data(
granularity=granularity,
date_from=date_from,
date_to=date_to,
platform=platform,
source=source
)
return {"timeline": data, "granularity": granularity}
@router.get("/timeline/heatmap")
@limiter.limit("60/minute")
@handle_exceptions
async def get_timeline_heatmap(
request: Request,
current_user: Dict = Depends(get_current_user),
year: Optional[int] = Query(None, ge=2000, le=2100),
platform: Optional[str] = Query(None)
):
"""Get activity heatmap data (file counts per day for a year)."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
heatmap = discovery.get_activity_heatmap(year=year, platform=platform)
return {"heatmap": heatmap, "year": year or datetime.now().year}
@router.get("/timeline/on-this-day")
@limiter.limit("60/minute")
@handle_exceptions
async def get_on_this_day(
request: Request,
current_user: Dict = Depends(get_current_user),
month: Optional[int] = Query(None, ge=1, le=12),
day: Optional[int] = Query(None, ge=1, le=31),
limit: int = Query(50, ge=1, le=200)
):
"""Get content from the same day in previous years ('On This Day' feature)."""
app_state = get_app_state()
discovery = get_discovery_system(app_state.db)
files = discovery.get_on_this_day(month=month, day=day, limit=limit)
return {"files": files, "count": len(files)}
# ============================================================================
# RECENT ACTIVITY ENDPOINT
# ============================================================================
@router.get("/discovery/recent-activity")
@limiter.limit("60/minute")
@handle_exceptions
async def get_recent_activity(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    limit: int = Query(10, ge=1, le=50)
):
    """Get recent activity across downloads, deletions, and restores.

    Returns a dict with activity lists (``recent_downloads``,
    ``recent_deleted``, ``recent_moved_to_review``) each capped at *limit*
    rows, plus a ``summary`` of 24-hour / 7-day download and delete counts.

    NOTE(review): ``recent_restored`` is initialized but never populated by
    any query below — confirm whether restore tracking exists elsewhere.

    Args:
        limit: Max rows per activity list (1-50, default 10).
    """
    app_state = get_app_state()
    # Response skeleton; each query below fills in one section.
    activity = {
        'recent_downloads': [],
        'recent_deleted': [],
        'recent_restored': [],
        'recent_moved_to_review': [],
        'summary': {
            'downloads_24h': 0,
            'downloads_7d': 0,
            'deleted_24h': 0,
            'deleted_7d': 0
        }
    }
    # Rows are accessed by column name below — assumes the connection uses
    # a name-indexable row factory (e.g. sqlite3.Row); confirm in db layer.
    with app_state.db.get_connection() as conn:
        cursor = conn.cursor()
        # Recent downloads
        # NOTE(review): LEFT JOIN can emit one row per matching downloads
        # entry — assumes downloads.file_path is unique; confirm schema.
        cursor.execute('''
            SELECT
                fi.id, fi.file_path, fi.filename, fi.platform, fi.source,
                fi.content_type, fi.file_size, fi.created_date,
                d.download_date, d.post_date
            FROM file_inventory fi
            LEFT JOIN downloads d ON d.file_path = fi.file_path
            WHERE fi.location = 'final'
            ORDER BY fi.created_date DESC
            LIMIT ?
        ''', (limit,))
        for row in cursor.fetchall():
            activity['recent_downloads'].append({
                'id': row['id'],
                'file_path': row['file_path'],
                'filename': row['filename'],
                'platform': row['platform'],
                'source': row['source'],
                'content_type': row['content_type'],
                'file_size': row['file_size'],
                # Prefer the download timestamp; fall back to inventory date
                # when no downloads row joined.
                'timestamp': row['download_date'] or row['created_date'],
                'action': 'download'
            })
        # Recent deleted
        cursor.execute('''
            SELECT
                id, original_path, original_filename, recycle_path,
                file_size, deleted_at, deleted_from, metadata
            FROM recycle_bin
            ORDER BY deleted_at DESC
            LIMIT ?
        ''', (limit,))
        for row in cursor.fetchall():
            # Metadata is stored as a JSON string; tolerate missing/corrupt
            # blobs by falling back to an empty dict.
            metadata = {}
            if row['metadata']:
                try:
                    metadata = json.loads(row['metadata'])
                except (json.JSONDecodeError, TypeError):
                    pass
            activity['recent_deleted'].append({
                'id': row['id'],
                'file_path': row['recycle_path'],
                'original_path': row['original_path'],
                'filename': row['original_filename'],
                'platform': metadata.get('platform', 'unknown'),
                'source': metadata.get('source', ''),
                'content_type': metadata.get('content_type', 'image'),
                'file_size': row['file_size'] or 0,
                'timestamp': row['deleted_at'],
                'deleted_from': row['deleted_from'],
                'action': 'delete'
            })
        # Recent moved to review
        cursor.execute('''
            SELECT
                id, file_path, filename, platform, source,
                content_type, file_size, created_date
            FROM file_inventory
            WHERE location = 'review'
            ORDER BY created_date DESC
            LIMIT ?
        ''', (limit,))
        for row in cursor.fetchall():
            activity['recent_moved_to_review'].append({
                'id': row['id'],
                'file_path': row['file_path'],
                'filename': row['filename'],
                'platform': row['platform'],
                'source': row['source'],
                'content_type': row['content_type'],
                'file_size': row['file_size'],
                'timestamp': row['created_date'],
                'action': 'review'
            })
        # Summary stats
        # Uses SQLite datetime() relative modifiers for the 24h/7d windows.
        cursor.execute('''
            SELECT COUNT(*) FROM file_inventory
            WHERE location = 'final'
            AND created_date >= datetime('now', '-1 day')
        ''')
        activity['summary']['downloads_24h'] = cursor.fetchone()[0]
        cursor.execute('''
            SELECT COUNT(*) FROM file_inventory
            WHERE location = 'final'
            AND created_date >= datetime('now', '-7 days')
        ''')
        activity['summary']['downloads_7d'] = cursor.fetchone()[0]
        cursor.execute('''
            SELECT COUNT(*) FROM recycle_bin
            WHERE deleted_at >= datetime('now', '-1 day')
        ''')
        activity['summary']['deleted_24h'] = cursor.fetchone()[0]
        cursor.execute('''
            SELECT COUNT(*) FROM recycle_bin
            WHERE deleted_at >= datetime('now', '-7 days')
        ''')
        activity['summary']['deleted_7d'] = cursor.fetchone()[0]
    return activity
# ============================================================================
# DISCOVERY QUEUE ENDPOINTS
# ============================================================================
@router.get("/discovery/queue/stats")
@limiter.limit("60/minute")
@handle_exceptions
async def get_queue_stats(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Return aggregate statistics for the discovery queue."""
    state = get_app_state()
    return get_discovery_system(state.db).get_queue_stats()
@router.get("/discovery/queue/pending")
@limiter.limit("60/minute")
@handle_exceptions
async def get_pending_queue(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    limit: int = Query(100, ge=1, le=1000)
):
    """Return up to *limit* pending items from the discovery queue."""
    state = get_app_state()
    pending = get_discovery_system(state.db).get_pending_queue(limit=limit)
    return {"items": pending, "count": len(pending)}
@router.post("/discovery/queue/add")
@limiter.limit("30/minute")
@handle_exceptions
async def add_to_queue(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    file_id: int = Body(...),
    priority: int = Body(0)
):
    """Queue a single file for discovery processing.

    Raises ValidationError when the backend rejects the file.
    """
    state = get_app_state()
    discovery = get_discovery_system(state.db)
    if not discovery.add_to_queue(file_id, priority=priority):
        raise ValidationError("Failed to add file to queue")
    return {"message": "File added to queue"}
@router.post("/discovery/queue/bulk-add")
@limiter.limit("10/minute")
@handle_exceptions
async def bulk_add_to_queue(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Add multiple files to the discovery queue.

    Expects a JSON body of the form ``{"file_ids": [int, ...], "priority": int}``
    (``priority`` optional, default 0).

    Raises:
        ValidationError: when ``file_ids`` is missing, empty, or not a list.
    """
    data = await request.json()
    file_ids = data.get('file_ids', [])
    priority = data.get('priority', 0)
    # Reject malformed payloads (e.g. a string or dict for file_ids) up
    # front instead of failing opaquely inside the discovery layer.
    if not isinstance(file_ids, list) or not file_ids:
        raise ValidationError("file_ids required")
    app_state = get_app_state()
    discovery = get_discovery_system(app_state.db)
    count = discovery.bulk_add_to_queue(file_ids, priority=priority)
    return {"message": f"Added {count} files to queue", "count": count}
@router.delete("/discovery/queue/clear")
@limiter.limit("10/minute")
@handle_exceptions
async def clear_queue(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Remove every item from the discovery queue."""
    state = get_app_state()
    removed = get_discovery_system(state.db).clear_queue()
    return {"message": f"Cleared {removed} items from queue", "count": removed}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,478 @@
"""
Easynews Router
Handles Easynews integration:
- Configuration management (credentials, proxy settings)
- Search term management
- Manual check triggers
- Results browsing and downloads
"""
import asyncio
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, require_admin, get_app_state
from ..core.exceptions import handle_exceptions
from modules.universal_logger import get_logger
# Module-level singletons for this router.
logger = get_logger('API')
router = APIRouter(prefix="/api/easynews", tags=["Easynews"])
# Rate limiting keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)
# Thread pool for blocking operations
# Capped at 2 workers so concurrent checks/downloads stay bounded.
_executor = ThreadPoolExecutor(max_workers=2)
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class EasynewsConfigUpdate(BaseModel):
    """Partial-update body for Easynews configuration.

    All fields are optional; only non-None fields are applied. Password
    fields sent as the mask value '********' are ignored by the endpoint.
    """
    # Easynews account credentials
    username: Optional[str] = None
    password: Optional[str] = None
    # Monitoring behavior
    enabled: Optional[bool] = None
    check_interval_hours: Optional[int] = None
    auto_download: Optional[bool] = None
    min_quality: Optional[str] = None
    # Optional proxy settings
    proxy_enabled: Optional[bool] = None
    proxy_type: Optional[str] = None
    proxy_host: Optional[str] = None
    proxy_port: Optional[int] = None
    proxy_username: Optional[str] = None
    proxy_password: Optional[str] = None
    notifications_enabled: Optional[bool] = None
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _get_monitor():
    """Construct an Easynews monitor bound to the application's database."""
    from modules.easynews_monitor import EasynewsMonitor
    state = get_app_state()
    # EasynewsMonitor expects a plain string path, not a pathlib.Path
    return EasynewsMonitor(str(state.db.db_path))
# ============================================================================
# CONFIGURATION ENDPOINTS
# ============================================================================
@router.get("/config")
@limiter.limit("30/minute")
@handle_exceptions
async def get_config(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get Easynews configuration (passwords masked)."""
    cfg = _get_monitor().get_config()
    # Never send stored secrets back to the client — replace any present
    # password with a fixed mask.
    for secret_key in ('password', 'proxy_password'):
        if cfg.get(secret_key):
            cfg[secret_key] = '********'
    return {
        "success": True,
        "config": cfg
    }
@router.put("/config")
@limiter.limit("10/minute")
@handle_exceptions
async def update_config(
    request: Request,
    config: EasynewsConfigUpdate,
    current_user: Dict = Depends(require_admin)
):
    """Update Easynews configuration.

    Only fields that were actually provided (non-None) are forwarded to
    the monitor. Password fields equal to the UI mask '********' are
    treated as "unchanged" and skipped.
    """
    monitor = _get_monitor()
    # Plain fields: copy through whenever a value was supplied.
    plain_fields = (
        'username', 'enabled', 'check_interval_hours', 'auto_download',
        'min_quality', 'proxy_enabled', 'proxy_type', 'proxy_host',
        'proxy_port', 'proxy_username', 'notifications_enabled',
    )
    kwargs = {}
    for name in plain_fields:
        value = getattr(config, name)
        if value is not None:
            kwargs[name] = value
    # Secret fields: additionally ignore the masked placeholder the GET
    # endpoint returns, so re-submitting a fetched config is harmless.
    for name in ('password', 'proxy_password'):
        value = getattr(config, name)
        if value is not None and value != '********':
            kwargs[name] = value
    if not kwargs:
        return {"success": False, "message": "No updates provided"}
    success = monitor.update_config(**kwargs)
    return {
        "success": success,
        "message": "Configuration updated" if success else "Update failed"
    }
@router.post("/test")
@limiter.limit("5/minute")
@handle_exceptions
async def test_connection(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Test Easynews connection with current credentials.

    The blocking network probe runs in the module thread pool so the
    event loop stays responsive.
    """
    monitor = _get_monitor()
    # get_running_loop() is the supported call inside a coroutine;
    # get_event_loop() here has been deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(_executor, monitor.test_connection)
    return result
# ============================================================================
# CELEBRITY ENDPOINTS (uses tracked celebrities from Appearances)
# ============================================================================
@router.get("/celebrities")
@limiter.limit("30/minute")
@handle_exceptions
async def get_celebrities(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get all tracked celebrities that will be searched on Easynews."""
    tracked = _get_monitor().get_celebrities()
    return {
        "success": True,
        "celebrities": tracked,
        "count": len(tracked)
    }
# ============================================================================
# RESULTS ENDPOINTS
# ============================================================================
@router.get("/results")
@limiter.limit("30/minute")
@handle_exceptions
async def get_results(
    request: Request,
    status: Optional[str] = None,
    celebrity_id: Optional[int] = None,
    limit: int = 100,
    offset: int = 0,
    current_user: Dict = Depends(get_current_user)
):
    """Get discovered results with optional filters."""
    monitor = _get_monitor()
    page = monitor.get_results(
        status=status,
        celebrity_id=celebrity_id,
        limit=limit,
        offset=offset,
    )
    # Total matches the status filter only, for pagination on the client.
    return {
        "success": True,
        "results": page,
        "count": len(page),
        "total": monitor.get_result_count(status=status)
    }
@router.post("/results/{result_id}/status")
@limiter.limit("30/minute")
@handle_exceptions
async def update_result_status(
    request: Request,
    result_id: int,
    status: str,
    current_user: Dict = Depends(get_current_user)
):
    """Update a result's status (e.g., mark as ignored)."""
    valid_statuses = ['new', 'downloaded', 'ignored', 'failed']
    if status not in valid_statuses:
        raise HTTPException(status_code=400, detail=f"Invalid status. Must be one of: {valid_statuses}")
    ok = _get_monitor().update_result_status(result_id, status)
    return {
        "success": ok,
        "message": f"Status updated to {status}" if ok else "Update failed"
    }
@router.post("/results/{result_id}/download")
@limiter.limit("10/minute")
@handle_exceptions
async def download_result(
    request: Request,
    result_id: int,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """Start downloading a result.

    The blocking download runs in the shared thread pool and its result is
    returned directly. ``background_tasks`` is unused but kept for
    signature compatibility with existing callers/tests.
    """
    monitor = _get_monitor()
    # get_event_loop() inside a coroutine is deprecated (Python 3.10+);
    # get_running_loop() is the supported replacement. Passing the args to
    # run_in_executor also removes the need for a wrapper closure.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(_executor, monitor.download_result, result_id)
    return result
# ============================================================================
# CHECK ENDPOINTS
# ============================================================================
@router.get("/status")
@limiter.limit("60/minute")
@handle_exceptions
async def get_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get current check status."""
    monitor = _get_monitor()
    check_status = monitor.get_status()
    cfg = monitor.get_config()
    return {
        "success": True,
        "status": check_status,
        "last_check": cfg.get('last_check'),
        "enabled": cfg.get('enabled', False),
        "has_credentials": cfg.get('has_credentials', False),
        "celebrity_count": monitor.get_celebrity_count(),
    }
@router.post("/check")
@limiter.limit("5/minute")
@handle_exceptions
async def trigger_check(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """Trigger a manual check for all tracked celebrities.

    Refuses to start a new check while one is running; otherwise the check
    is scheduled after the response via background_tasks, executing on the
    bounded module thread pool.
    """
    monitor = _get_monitor()
    status = monitor.get_status()
    if status.get('is_running'):
        return {
            "success": False,
            "message": "Check already in progress"
        }
    def do_check():
        return monitor.check_all_celebrities()
    # get_event_loop() is deprecated inside coroutines (Python 3.10+);
    # capture the running loop for the background submission instead.
    loop = asyncio.get_running_loop()
    background_tasks.add_task(loop.run_in_executor, _executor, do_check)
    return {
        "success": True,
        "message": "Check started"
    }
@router.post("/check/{search_id}")
@limiter.limit("5/minute")
@handle_exceptions
async def trigger_single_check(
    request: Request,
    search_id: int,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user)
):
    """Trigger a manual check for a specific search term.

    Returns 404 when the search does not exist and refuses to start while
    another check is running.
    """
    monitor = _get_monitor()
    status = monitor.get_status()
    if status.get('is_running'):
        return {
            "success": False,
            "message": "Check already in progress"
        }
    # Verify search exists
    search = monitor.get_search(search_id)
    if not search:
        raise HTTPException(status_code=404, detail="Search not found")
    def do_check():
        return monitor.check_single_search(search_id)
    # get_event_loop() is deprecated inside coroutines (Python 3.10+);
    # use the running loop to submit the background job.
    loop = asyncio.get_running_loop()
    background_tasks.add_task(loop.run_in_executor, _executor, do_check)
    return {
        "success": True,
        "message": f"Check started for: {search['search_term']}"
    }
# ============================================================================
# SEARCH MANAGEMENT ENDPOINTS
# ============================================================================
class EasynewsSearchCreate(BaseModel):
    """Request body for creating a saved Easynews search term."""
    # The term submitted to Easynews
    search_term: str
    media_type: Optional[str] = 'any'
    # Optional TMDB association for display/poster metadata
    tmdb_id: Optional[int] = None
    tmdb_title: Optional[str] = None
    poster_url: Optional[str] = None
class EasynewsSearchUpdate(BaseModel):
    """Partial-update body for a saved search; only non-None fields apply."""
    search_term: Optional[str] = None
    media_type: Optional[str] = None
    enabled: Optional[bool] = None
    # Optional TMDB association for display/poster metadata
    tmdb_id: Optional[int] = None
    tmdb_title: Optional[str] = None
    poster_url: Optional[str] = None
@router.get("/searches")
@limiter.limit("30/minute")
@handle_exceptions
async def get_searches(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get all saved search terms."""
    saved = _get_monitor().get_all_searches()
    return {
        "success": True,
        "searches": saved,
        "count": len(saved)
    }
@router.post("/searches")
@limiter.limit("10/minute")
@handle_exceptions
async def add_search(
    request: Request,
    search: EasynewsSearchCreate,
    current_user: Dict = Depends(get_current_user)
):
    """Add a new search term."""
    new_id = _get_monitor().add_search(
        search_term=search.search_term,
        media_type=search.media_type,
        tmdb_id=search.tmdb_id,
        tmdb_title=search.tmdb_title,
        poster_url=search.poster_url
    )
    # A falsy id means the monitor could not persist the search.
    if not new_id:
        return {
            "success": False,
            "message": "Failed to add search term"
        }
    return {
        "success": True,
        "id": new_id,
        "message": f"Search term '{search.search_term}' added"
    }
@router.put("/searches/{search_id}")
@limiter.limit("10/minute")
@handle_exceptions
async def update_search(
    request: Request,
    search_id: int,
    updates: EasynewsSearchUpdate,
    current_user: Dict = Depends(get_current_user)
):
    """Update an existing search term.

    Only fields the client actually supplied (non-None) are forwarded.
    """
    changes = {}
    for field in ('search_term', 'media_type', 'enabled',
                  'tmdb_id', 'tmdb_title', 'poster_url'):
        value = getattr(updates, field)
        if value is not None:
            changes[field] = value
    if not changes:
        return {"success": False, "message": "No updates provided"}
    ok = _get_monitor().update_search(search_id, **changes)
    return {
        "success": ok,
        "message": "Search updated" if ok else "Update failed"
    }
@router.delete("/searches/{search_id}")
@limiter.limit("10/minute")
@handle_exceptions
async def delete_search(
    request: Request,
    search_id: int,
    current_user: Dict = Depends(get_current_user)
):
    """Delete a search term."""
    ok = _get_monitor().delete_search(search_id)
    return {
        "success": ok,
        "message": "Search deleted" if ok else "Delete failed"
    }

1248
web/backend/routers/face.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,128 @@
"""
File serving and thumbnail generation API
Provides endpoints for:
- On-demand thumbnail generation for images and videos
- File serving with proper caching headers
"""
from typing import Dict
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import Response
from PIL import Image
from pathlib import Path
import subprocess
import io
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
from ..core.utils import validate_file_path
from modules.universal_logger import get_logger
# Module-level singletons for this router.
router = APIRouter(prefix="/files", tags=["files"])
logger = get_logger('FilesRouter')
# Rate limiting keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)
@router.get("/thumbnail")
@limiter.limit("300/minute")
@handle_exceptions
async def get_thumbnail(
    request: Request,
    path: str = Query(..., description="File path"),
    current_user: Dict = Depends(get_current_user)
):
    """
    Generate and return thumbnail for image or video
    Args:
        path: Absolute path to file
    Returns:
        JPEG thumbnail (200x200px max, maintains aspect ratio)
    Raises:
        ValidationError: when the file extension is not a supported image/video type
    """
    # Validate file is within allowed directories (prevents path traversal)
    file_path = validate_file_path(path, require_exists=True)
    file_ext = file_path.suffix.lower()
    # Generate thumbnail based on type
    if file_ext in ['.jpg', '.jpeg', '.png', '.webp', '.gif', '.heic']:
        # Image thumbnail with PIL. Using the context manager guarantees
        # the underlying file handle is closed even when decoding or
        # saving raises (the original leaked the handle on error).
        buffer = io.BytesIO()
        with Image.open(file_path) as img:
            # Convert HEIC if needed (requires a HEIC-capable PIL plugin
            # to be registered — presumably pillow-heif; confirm in deps)
            if file_ext == '.heic':
                img = img.convert('RGB')
            # Create thumbnail in place (maintains aspect ratio)
            img.thumbnail((200, 200), Image.Resampling.LANCZOS)
            if img.mode in ('RGBA', 'LA', 'P'):
                # Convert transparency to white background before JPEG,
                # which has no alpha channel
                background = Image.new('RGB', img.size, (255, 255, 255))
                if img.mode == 'P':
                    img = img.convert('RGBA')
                background.paste(img, mask=img.split()[-1] if img.mode in ('RGBA', 'LA') else None)
                img = background
            img.convert('RGB').save(buffer, format='JPEG', quality=85, optimize=True)
        buffer.seek(0)
        return Response(
            content=buffer.read(),
            media_type="image/jpeg",
            headers={"Cache-Control": "public, max-age=3600"}
        )
    elif file_ext in ['.mp4', '.webm', '.mov', '.avi', '.mkv']:
        # Video thumbnail with ffmpeg: grab one frame at t=1s
        result = subprocess.run(
            [
                'ffmpeg',
                '-ss', '00:00:01',  # Seek to 1 second
                '-i', str(file_path),
                '-vframes', '1',  # Extract 1 frame
                '-vf', 'scale=200:-1',  # Scale to 200px width, maintain aspect
                '-f', 'image2pipe',  # Output to pipe
                '-vcodec', 'mjpeg',  # JPEG codec
                'pipe:1'
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=10,
            check=False
        )
        if result.returncode != 0:
            # Try without seeking (for very short videos); check=True so a
            # second failure raises and is surfaced by @handle_exceptions
            result = subprocess.run(
                [
                    'ffmpeg',
                    '-i', str(file_path),
                    '-vframes', '1',
                    '-vf', 'scale=200:-1',
                    '-f', 'image2pipe',
                    '-vcodec', 'mjpeg',
                    'pipe:1'
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=10,
                check=True
            )
        return Response(
            content=result.stdout,
            media_type="image/jpeg",
            headers={"Cache-Control": "public, max-age=3600"}
        )
    else:
        raise ValidationError("Unsupported file type")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,436 @@
"""
Instagram Unified Configuration Router
Provides a single configuration interface for all Instagram scrapers.
Manages one central account list with per-account content type toggles,
scraper assignments, and auto-generates legacy per-scraper configs on save.
"""
import copy
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, get_app_state
from ..core.exceptions import handle_exceptions, ValidationError
from modules.universal_logger import get_logger
# Module-level singletons for this router.
logger = get_logger('API')
router = APIRouter(prefix="/api/instagram-unified", tags=["Instagram Unified"])
# Rate limiting keyed by the client's remote address.
limiter = Limiter(key_func=get_remote_address)
# Scraper capability matrix
# Maps scraper key -> which content types it can fetch. Used both to
# validate scraper assignments on save and to pick fallbacks on migration.
SCRAPER_CAPABILITIES = {
    'fastdl': {'posts': True, 'stories': True, 'reels': True, 'tagged': False},
    'imginn_api': {'posts': True, 'stories': True, 'reels': False, 'tagged': True},
    'imginn': {'posts': True, 'stories': True, 'reels': False, 'tagged': True},
    'toolzu': {'posts': True, 'stories': True, 'reels': False, 'tagged': False},
    'instagram_client': {'posts': True, 'stories': True, 'reels': True, 'tagged': True},
    'instagram': {'posts': True, 'stories': True, 'reels': False, 'tagged': True},
}
# Human-readable scraper names for UI display
SCRAPER_LABELS = {
    'fastdl': 'FastDL',
    'imginn_api': 'ImgInn API',
    'imginn': 'ImgInn',
    'toolzu': 'Toolzu',
    'instagram_client': 'Instagram Client',
    'instagram': 'InstaLoader',
}
# Content categories managed by the unified config
CONTENT_TYPES = ['posts', 'stories', 'reels', 'tagged']
# Settings keys under which the legacy per-scraper configs live
LEGACY_SCRAPER_KEYS = ['fastdl', 'imginn_api', 'imginn', 'toolzu', 'instagram_client', 'instagram']
# Default destination paths
DEFAULT_PATHS = {
    'posts': '/opt/immich/md/social media/instagram/posts',
    'stories': '/opt/immich/md/social media/instagram/stories',
    'reels': '/opt/immich/md/social media/instagram/reels',
    'tagged': '/opt/immich/md/social media/instagram/tagged',
}
class UnifiedConfigUpdate(BaseModel):
    """Request body wrapping the full unified Instagram configuration."""
    # Free-form unified config dict; structure is validated in the endpoint
    config: Dict[str, Any]
# ============================================================================
# MIGRATION LOGIC
# ============================================================================
def _migrate_from_legacy(app_state) -> Dict[str, Any]:
    """
    Build a unified config from existing per-scraper configs.
    Called on first load when no instagram_unified key exists.

    The result is a preview only — the caller decides whether to persist
    it. Steps: (1) pick one scraper per content type from whatever is
    currently enabled, (2) merge every scraper's username list into one
    account list with per-content-type flags, (3) carry over content-type
    settings, phrase search, scraper-specific auth, and global cadence.
    """
    settings = app_state.settings
    # Load all legacy configs
    legacy = {}
    for key in LEGACY_SCRAPER_KEYS:
        legacy[key] = settings.get(key) or {}
    # Determine scraper assignments from currently enabled configs
    # Only assign a scraper if it actually supports the content type per capability matrix
    scraper_assignment = {}
    for ct in CONTENT_TYPES:
        assigned = None
        # First enabled + capable scraper wins (LEGACY_SCRAPER_KEYS order)
        for scraper_key in LEGACY_SCRAPER_KEYS:
            cfg = legacy[scraper_key]
            if (cfg.get('enabled') and cfg.get(ct, {}).get('enabled')
                    and SCRAPER_CAPABILITIES.get(scraper_key, {}).get(ct)):
                assigned = scraper_key
                break
        # Fallback defaults if nothing is enabled — pick first capable scraper
        if not assigned:
            if ct in ('posts', 'tagged'):
                assigned = 'imginn_api'
            elif ct in ('stories', 'reels'):
                assigned = 'fastdl'
        scraper_assignment[ct] = assigned
    # Collect all unique usernames from all scrapers
    all_usernames = set()
    scraper_usernames = {}  # track which usernames belong to which scraper
    for scraper_key in LEGACY_SCRAPER_KEYS:
        cfg = legacy[scraper_key]
        usernames = []
        if scraper_key == 'instagram':
            # InstaLoader uses accounts list format
            for acc in cfg.get('accounts', []):
                u = acc.get('username')
                if u:
                    usernames.append(u)
        else:
            usernames = cfg.get('usernames', [])
        # Also include phrase_search usernames
        ps_usernames = cfg.get('phrase_search', {}).get('usernames', [])
        combined = set(usernames) | set(ps_usernames)
        scraper_usernames[scraper_key] = combined
        all_usernames |= combined
    # Build per-account content type flags
    # An account gets a content type enabled if:
    # - The scraper assigned to that content type has this username in its list
    accounts = []
    for username in sorted(all_usernames):
        account = {'username': username}
        for ct in CONTENT_TYPES:
            assigned_scraper = scraper_assignment[ct]
            # Enable if this user is in the assigned scraper's list
            account[ct] = username in scraper_usernames.get(assigned_scraper, set())
        accounts.append(account)
    # Import content type settings from the first enabled scraper that has them
    content_types = {}
    for ct in CONTENT_TYPES:
        # Defaults apply when the assigned scraper had the type disabled
        ct_config = {'enabled': False, 'days_back': 7, 'destination_path': DEFAULT_PATHS[ct]}
        assigned = scraper_assignment[ct]
        cfg = legacy.get(assigned, {})
        ct_sub = cfg.get(ct, {})
        if ct_sub.get('enabled'):
            ct_config['enabled'] = True
            if ct_sub.get('days_back'):
                ct_config['days_back'] = ct_sub['days_back']
            if ct_sub.get('destination_path'):
                ct_config['destination_path'] = ct_sub['destination_path']
        content_types[ct] = ct_config
    # Import phrase search from first scraper that has it enabled
    phrase_search = {
        'enabled': False,
        'download_all': True,
        'phrases': [],
        'case_sensitive': False,
        'match_all': False,
    }
    for scraper_key in LEGACY_SCRAPER_KEYS:
        ps = legacy[scraper_key].get('phrase_search', {})
        if ps.get('enabled') or ps.get('phrases'):
            phrase_search['enabled'] = ps.get('enabled', False)
            phrase_search['download_all'] = ps.get('download_all', True)
            phrase_search['phrases'] = ps.get('phrases', [])
            phrase_search['case_sensitive'] = ps.get('case_sensitive', False)
            phrase_search['match_all'] = ps.get('match_all', False)
            break
    # Import scraper-specific settings (auth, cookies, etc.)
    scraper_settings = {
        'fastdl': {},
        'imginn_api': {},
        'imginn': {'cookie_file': legacy['imginn'].get('cookie_file', '')},
        'toolzu': {
            'email': legacy['toolzu'].get('email', ''),
            'password': legacy['toolzu'].get('password', ''),
            'cookie_file': legacy['toolzu'].get('cookie_file', ''),
        },
        'instagram_client': {},
        'instagram': {
            'username': legacy['instagram'].get('username', ''),
            'password': legacy['instagram'].get('password', ''),
            'totp_secret': legacy['instagram'].get('totp_secret', ''),
            'session_file': legacy['instagram'].get('session_file', ''),
        },
    }
    # Get global settings from the primary enabled scraper
    check_interval = 8
    run_at_start = False
    user_delay = 20
    for scraper_key in ['fastdl', 'imginn_api']:
        cfg = legacy[scraper_key]
        if cfg.get('enabled'):
            check_interval = cfg.get('check_interval_hours', 8)
            run_at_start = cfg.get('run_at_start', False)
            user_delay = cfg.get('user_delay_seconds', 20)
            break
    return {
        'enabled': True,
        'check_interval_hours': check_interval,
        'run_at_start': run_at_start,
        'user_delay_seconds': user_delay,
        'scraper_assignment': scraper_assignment,
        'content_types': content_types,
        'accounts': accounts,
        'phrase_search': phrase_search,
        'scraper_settings': scraper_settings,
    }
# ============================================================================
# LEGACY CONFIG GENERATION
# ============================================================================
def _generate_legacy_configs(unified: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """
    From the unified config, generate 6 legacy per-scraper configs
    that the existing scraper modules can consume.

    Inverse of `_migrate_from_legacy`: distributes accounts back to each
    scraper based on scraper_assignment, copies cadence/auth settings, and
    routes phrase search to the scraper assigned to posts.
    """
    scraper_assignment = unified.get('scraper_assignment', {})
    content_types = unified.get('content_types', {})
    accounts = unified.get('accounts', [])
    phrase_search = unified.get('phrase_search', {})
    scraper_settings = unified.get('scraper_settings', {})
    # For each scraper, determine which content types it's assigned to
    scraper_content_types = {key: [] for key in LEGACY_SCRAPER_KEYS}
    for ct, scraper_key in scraper_assignment.items():
        if scraper_key in scraper_content_types:
            scraper_content_types[scraper_key].append(ct)
    # For each scraper, collect usernames that have any of its assigned content types enabled
    scraper_usernames = {key: set() for key in LEGACY_SCRAPER_KEYS}
    for account in accounts:
        username = account.get('username', '')
        if not username:
            continue
        for ct, scraper_key in scraper_assignment.items():
            if account.get(ct, False):
                scraper_usernames[scraper_key].add(username)
    # Build legacy configs
    result = {}
    for scraper_key in LEGACY_SCRAPER_KEYS:
        assigned_cts = scraper_content_types[scraper_key]
        usernames = sorted(scraper_usernames[scraper_key])
        # A scraper is only enabled when the unified toggle is on AND it
        # has both an assignment and at least one account to work on.
        is_enabled = unified.get('enabled', False) and len(assigned_cts) > 0 and len(usernames) > 0
        extra_settings = scraper_settings.get(scraper_key, {})
        cfg = {
            'enabled': is_enabled,
            'check_interval_hours': unified.get('check_interval_hours', 8),
            'run_at_start': unified.get('run_at_start', False),
        }
        # Add user_delay_seconds for scrapers that support it
        if scraper_key in ('imginn_api', 'instagram_client'):
            cfg['user_delay_seconds'] = unified.get('user_delay_seconds', 20)
        # Add scraper-specific settings
        if scraper_key == 'fastdl':
            cfg['high_res'] = True  # Always use high resolution
        elif scraper_key == 'imginn':
            cfg['cookie_file'] = extra_settings.get('cookie_file', '')
        elif scraper_key == 'toolzu':
            cfg['email'] = extra_settings.get('email', '')
            cfg['password'] = extra_settings.get('password', '')
            cfg['cookie_file'] = extra_settings.get('cookie_file', '')
        # InstaLoader uses accounts format + auth fields at top level
        if scraper_key == 'instagram':
            ig_settings = extra_settings
            cfg['method'] = 'instaloader'
            cfg['username'] = ig_settings.get('username', '')
            cfg['password'] = ig_settings.get('password', '')
            cfg['totp_secret'] = ig_settings.get('totp_secret', '')
            cfg['session_file'] = ig_settings.get('session_file', '')
            cfg['accounts'] = [
                {'username': u, 'check_interval_hours': unified.get('check_interval_hours', 8), 'run_at_start': False}
                for u in usernames
            ]
        else:
            cfg['usernames'] = usernames
        # Content type sub-configs
        for ct in CONTENT_TYPES:
            ct_global = content_types.get(ct, {})
            if ct in assigned_cts:
                cfg[ct] = {
                    'enabled': ct_global.get('enabled', False),
                    'days_back': ct_global.get('days_back', 7),
                    'destination_path': ct_global.get('destination_path', DEFAULT_PATHS.get(ct, '')),
                }
                # Add temp_dir based on scraper key
                cfg[ct]['temp_dir'] = f'temp/{scraper_key}/{ct}'
            else:
                cfg[ct] = {'enabled': False}
        # Phrase search goes on the scraper assigned to posts
        posts_scraper = scraper_assignment.get('posts')
        if scraper_key == posts_scraper and phrase_search.get('enabled'):
            # Collect all usernames that have posts enabled for the phrase search usernames list
            ps_usernames = sorted(scraper_usernames.get(scraper_key, set()))
            cfg['phrase_search'] = {
                'enabled': phrase_search.get('enabled', False),
                'download_all': phrase_search.get('download_all', True),
                'usernames': ps_usernames,
                'phrases': phrase_search.get('phrases', []),
                'case_sensitive': phrase_search.get('case_sensitive', False),
                'match_all': phrase_search.get('match_all', False),
            }
        else:
            cfg['phrase_search'] = {
                'enabled': False,
                'usernames': [],
                'phrases': [],
                'case_sensitive': False,
                'match_all': False,
            }
        result[scraper_key] = cfg
    return result
# ============================================================================
# ENDPOINTS
# ============================================================================
@router.get("/config")
@handle_exceptions
async def get_config(request: Request, user=Depends(get_current_user)):
    """
    Load the unified Instagram config.

    When no unified config has been saved yet, a migration preview is
    generated from the legacy per-scraper configs; the preview is NOT
    persisted automatically.

    Returns:
        Dict with 'config', 'migrated' flag, 'hidden_modules', and the
        scraper capability matrix/labels for the UI.
    """
    app_state = get_app_state()
    hidden = app_state.settings.get('hidden_modules') or []

    # Fields common to both the saved-config and migration-preview responses.
    payload = {
        'hidden_modules': hidden,
        'scraper_capabilities': SCRAPER_CAPABILITIES,
        'scraper_labels': SCRAPER_LABELS,
    }

    saved = app_state.settings.get('instagram_unified')
    if saved:
        payload.update({'config': saved, 'migrated': False})
    else:
        # First load: build a preview from the legacy configs without saving.
        payload.update({'config': _migrate_from_legacy(app_state), 'migrated': True})
    return payload
@router.put("/config")
@handle_exceptions
async def update_config(request: Request, body: UnifiedConfigUpdate, user=Depends(get_current_user)):
    """
    Persist the unified config and regenerate the per-scraper legacy configs.

    Raises:
        ValidationError: when a content type is unknown, a scraper is
            unknown, or the assignment violates the capability matrix.
    """
    app_state = get_app_state()
    config = body.config

    # Reject invalid assignments before anything is written.
    for content_type, scraper in config.get('scraper_assignment', {}).items():
        if content_type not in CONTENT_TYPES:
            raise ValidationError(f"Unknown content type: {content_type}")
        if scraper not in SCRAPER_CAPABILITIES:
            raise ValidationError(f"Unknown scraper: {scraper}")
        if not SCRAPER_CAPABILITIES[scraper].get(content_type):
            raise ValidationError(
                f"Scraper {SCRAPER_LABELS.get(scraper, scraper)} does not support {content_type}"
            )

    # Persist the unified config itself.
    app_state.settings.set(
        key='instagram_unified',
        value=config,
        category='scrapers',
        description='Unified Instagram configuration',
        updated_by='api'
    )

    # Derive the legacy scraper configs and persist each one.
    legacy_configs = _generate_legacy_configs(config)
    for scraper, legacy in legacy_configs.items():
        app_state.settings.set(
            key=scraper,
            value=legacy,
            category='scrapers',
            description=f'{SCRAPER_LABELS.get(scraper, scraper)} configuration (auto-generated)',
            updated_by='api'
        )

    # Keep the in-memory config mirror in sync with what was just saved.
    if getattr(app_state, 'config', None) is not None:
        app_state.config['instagram_unified'] = config
        for scraper, legacy in legacy_configs.items():
            app_state.config[scraper] = legacy

    # Best-effort client notification; a failure only logs a warning.
    try:
        if getattr(app_state, 'websocket_manager', None):
            import asyncio
            asyncio.create_task(app_state.websocket_manager.broadcast({
                'type': 'config_updated',
                'data': {'source': 'instagram_unified'}
            }))
    except Exception as e:
        logger.warning(f"Failed to broadcast config update: {e}", module="InstagramUnified")

    return {
        'success': True,
        'message': 'Instagram configuration saved',
        'legacy_configs_updated': list(legacy_configs.keys()),
    }
@router.get("/capabilities")
@handle_exceptions
async def get_capabilities(request: Request, user=Depends(get_current_user)):
    """Return the scraper capability matrix, labels, content types, and hidden modules."""
    hidden = get_app_state().settings.get('hidden_modules') or []
    return {
        'scraper_capabilities': SCRAPER_CAPABILITIES,
        'scraper_labels': SCRAPER_LABELS,
        'hidden_modules': hidden,
        'content_types': CONTENT_TYPES,
    }

View File

@@ -0,0 +1,259 @@
"""
Maintenance Router
Handles database maintenance and cleanup operations:
- Scan and remove missing file references
- Database integrity checks
- Orphaned record cleanup
"""
import os
from pathlib import Path
from datetime import datetime
from typing import Dict, List
from fastapi import APIRouter, Depends, Request, BackgroundTasks
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, get_app_state
from ..core.config import settings
from ..core.responses import now_iso8601
from ..core.exceptions import handle_exceptions
from modules.universal_logger import get_logger
logger = get_logger('Maintenance')
router = APIRouter(prefix="/api/maintenance", tags=["Maintenance"])
limiter = Limiter(key_func=get_remote_address)
# Whitelist of table -> file-path-column pairs the cleanup endpoints may
# touch. Restricting identifiers to this closed set is what prevents SQL
# injection: no table or column name ever originates from user input.
ALLOWED_CLEANUP_TABLES = {
    "file_inventory": "file_path",
    "downloads": "file_path",
    "youtube_downloads": "file_path",
    "video_downloads": "file_path",
    "face_recognition_scans": "file_path",
    "face_recognition_references": "reference_image_path",
    "discovery_scan_queue": "file_path",
    "recycle_bin": "recycle_path",
}


def _build_cleanup_queries():
    """Assemble the per-table SQL once, at import time, from the whitelist.

    Uses 'id' instead of 'rowid' (PostgreSQL has no rowid) and
    information_schema for the table-existence check.
    """
    queries = {}
    for table, col in ALLOWED_CLEANUP_TABLES.items():
        queries[table] = {
            "check_exists": "SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_name=?",
            "select": f"SELECT id, {col} FROM {table} WHERE {col} IS NOT NULL AND {col} != ''",
            "delete": f"DELETE FROM {table} WHERE id IN ",
        }
    return queries


# Pre-built SQL for every whitelisted table; handlers never interpolate
# identifiers themselves.
_CLEANUP_QUERIES = _build_cleanup_queries()

# Result dict of the most recent cleanup scan (None until one has run).
last_scan_result = None
@router.post("/cleanup/missing-files")
@limiter.limit("5/hour")
@handle_exceptions
async def cleanup_missing_files(
    request: Request,
    background_tasks: BackgroundTasks,
    dry_run: bool = True,
    current_user: Dict = Depends(get_current_user)
):
    """
    Kick off a background scan that removes DB rows referencing missing files.

    Args:
        dry_run: When True (the default) only report what would be deleted.

    Returns:
        Acknowledgement payload; results are polled via
        /api/maintenance/cleanup/status.
    """
    app_state = get_app_state()
    user_id = current_user.get('sub', 'unknown')
    logger.info(f"Database cleanup started by {user_id} (dry_run={dry_run})", module="Maintenance")

    # The scan walks every whitelisted table, so keep it off the request path.
    background_tasks.add_task(_cleanup_missing_files_task, app_state, dry_run, user_id)

    return {
        "status": "started",
        "dry_run": dry_run,
        "message": "Cleanup scan started in background. Check /api/maintenance/cleanup/status for progress.",
        "timestamp": now_iso8601()
    }
@router.get("/cleanup/status")
@limiter.limit("60/minute")
@handle_exceptions
async def get_cleanup_status(request: Request, current_user: Dict = Depends(get_current_user)):
    """Return the status and results of the most recent cleanup scan."""
    if last_scan_result is not None:
        return last_scan_result
    # Nothing has run since process start (results are in-memory only).
    return {
        "status": "no_scan",
        "message": "No cleanup scan has been run yet"
    }
async def _cleanup_missing_files_task(app_state, dry_run: bool, user_id: str):
    """Background task: scan whitelisted tables and purge rows whose file is gone.

    Publishes progress/results through the module-level ``last_scan_result``
    so the /cleanup/status endpoint can report them.

    Args:
        app_state: Application state providing the database handle.
        dry_run: When True, count missing rows but delete nothing.
        user_id: Who triggered the scan (recorded for auditing).
    """
    global last_scan_result
    start_time = datetime.now()

    result = {
        "status": "running",
        "started_at": start_time.isoformat(),
        "dry_run": dry_run,
        "user": user_id,
        "tables_scanned": {},
        "total_checked": 0,
        "total_missing": 0,
        "total_removed": 0
    }

    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()

            # Iterate the whitelist directly instead of a duplicated hardcoded
            # list, so this task and _scan_table's validation can never drift
            # apart.
            # NOTE: instagram_perceptual_hashes is deliberately absent from
            # the whitelist because hash data stays valuable for duplicate
            # detection even after the original file is gone.
            for table_name, column_name in ALLOWED_CLEANUP_TABLES.items():
                logger.info(f"Scanning {table_name}.{column_name}...", module="Maintenance")
                table_result = _scan_table(cursor, table_name, column_name, dry_run)
                result["tables_scanned"][table_name] = table_result
                result["total_checked"] += table_result["checked"]
                result["total_missing"] += table_result["missing"]
                result["total_removed"] += table_result["removed"]

            # Deletions only take effect outside dry-run mode.
            if not dry_run:
                conn.commit()
                logger.info(f"Cleanup completed: removed {result['total_removed']} records", module="Maintenance")
            else:
                logger.info(f"Dry run completed: {result['total_missing']} records would be removed", module="Maintenance")

        end_time = datetime.now()
        result.update({
            "status": "completed",
            "completed_at": end_time.isoformat(),
            "duration_seconds": round((end_time - start_time).total_seconds(), 2)
        })
    except Exception as e:
        logger.error(f"Cleanup failed: {e}", module="Maintenance", exc_info=True)
        result.update({
            "status": "failed",
            "error": str(e),
            "completed_at": datetime.now().isoformat()
        })

    last_scan_result = result
def _scan_table(cursor, table_name: str, column_name: str, dry_run: bool) -> Dict:
    """Scan a table for missing files and optionally remove them.

    Uses pre-built queries from _CLEANUP_QUERIES to prevent SQL injection.
    Only tables in ALLOWED_CLEANUP_TABLES whitelist are permitted.

    Args:
        cursor: Open DB cursor; the caller owns the transaction/commit.
        table_name: Table to scan; must be a key of ALLOWED_CLEANUP_TABLES.
        column_name: File-path column; must match the whitelisted column.
        dry_run: When True, only count missing rows; nothing is deleted.

    Returns:
        Dict with 'checked', 'missing', 'removed', up to 100 example
        'missing_files', and an 'error' key when validation or SQL failed.
    """
    result = {
        "checked": 0,
        "missing": 0,
        "removed": 0,
        "missing_files": []
    }
    # Validate table/column against whitelist to prevent SQL injection
    if table_name not in ALLOWED_CLEANUP_TABLES:
        logger.error(f"Table {table_name} not in allowed whitelist", module="Maintenance")
        result["error"] = f"Table {table_name} not allowed"
        return result
    if ALLOWED_CLEANUP_TABLES[table_name] != column_name:
        logger.error(f"Column {column_name} not allowed for table {table_name}", module="Maintenance")
        result["error"] = f"Column {column_name} not allowed for table {table_name}"
        return result
    # Get pre-built queries for this table (built at module load time, not from user input)
    queries = _CLEANUP_QUERIES[table_name]
    try:
        # Check if table exists using parameterized query
        # NOTE(review): the '?' paramstyle combined with information_schema
        # implies a driver/shim translating placeholders for PostgreSQL —
        # confirm the configured backend accepts 'qmark' style.
        cursor.execute(queries["check_exists"], (table_name,))
        if not cursor.fetchone():
            logger.warning(f"Table {table_name} does not exist", module="Maintenance")
            return result
        # Get all file paths from table using pre-built query
        cursor.execute(queries["select"])
        rows = cursor.fetchall()
        result["checked"] = len(rows)
        missing_rowids = []
        for rowid, file_path in rows:
            # A row is "missing" when its recorded path no longer exists on disk.
            if file_path and not os.path.exists(file_path):
                result["missing"] += 1
                missing_rowids.append(rowid)
                # Only keep first 100 examples
                if len(result["missing_files"]) < 100:
                    result["missing_files"].append(file_path)
        # Remove missing entries if not dry run
        if not dry_run and missing_rowids:
            # Delete in batches of 100 using pre-built query base
            delete_base = queries["delete"]
            for i in range(0, len(missing_rowids), 100):
                batch = missing_rowids[i:i+100]
                placeholders = ','.join('?' * len(batch))
                # The delete_base is pre-built from whitelist, placeholders are just ?
                cursor.execute(f"{delete_base}({placeholders})", batch)
                result["removed"] += len(batch)
        # In non-dry-run 'removed' equals 'missing' unless an execute failed,
        # in which case the except branch below records the error instead.
        logger.info(
            f" {table_name}: checked={result['checked']}, missing={result['missing']}, "
            f"{'would_remove' if dry_run else 'removed'}={result['missing']}",
            module="Maintenance"
        )
    except Exception as e:
        logger.error(f"Error scanning {table_name}: {e}", module="Maintenance", exc_info=True)
        result["error"] = str(e)
    return result

View File

@@ -0,0 +1,669 @@
"""
Manual Import Router
Handles manual file import operations:
- Service configuration
- File upload to temp directory
- Filename parsing
- Processing and moving to final destination (async background processing)
"""
import asyncio
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Dict, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, Request, UploadFile
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, get_app_state
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
from modules.filename_parser import FilenameParser, get_preset_patterns, parse_with_fallbacks, INSTAGRAM_PATTERNS
from modules.universal_logger import get_logger
logger = get_logger('API')
router = APIRouter(prefix="/api/manual-import", tags=["Manual Import"])
limiter = Limiter(key_func=get_remote_address)
# ============================================================================
# JOB TRACKING FOR BACKGROUND PROCESSING
# ============================================================================
# In-memory job tracking (jobs are transient - cleared on restart)
# Maps job_id -> status dict (see create_job for the full shape).
_import_jobs: Dict[str, Dict] = {}
# Guards every read/write of _import_jobs: the dict is touched from request
# handlers and from the background processing task.
_jobs_lock = Lock()
def get_job_status(job_id: str) -> Optional[Dict]:
    """Return the status dict for an import job, or None if unknown."""
    with _jobs_lock:
        job = _import_jobs.get(job_id)
    return job
def update_job_status(job_id: str, updates: Dict):
    """Merge *updates* into an existing job's status dict (no-op if unknown)."""
    with _jobs_lock:
        job = _import_jobs.get(job_id)
        if job is not None:
            job.update(updates)
def create_job(job_id: str, total_files: int, service_name: str):
    """Register a fresh import job in the 'processing' state."""
    initial_state = {
        'id': job_id,
        'status': 'processing',
        'service_name': service_name,
        'total_files': total_files,
        'processed_files': 0,
        'success_count': 0,
        'failed_count': 0,
        'results': [],
        'current_file': None,
        'started_at': datetime.now().isoformat(),
        'completed_at': None
    }
    with _jobs_lock:
        _import_jobs[job_id] = initial_state
def cleanup_old_jobs():
    """Drop completed jobs whose completion timestamp is over an hour old."""
    max_age_seconds = 3600  # 1 hour
    with _jobs_lock:
        now = datetime.now()
        stale = []
        for job_id, job in _import_jobs.items():
            finished_at = job.get('completed_at')
            if not finished_at:
                continue  # still running: never reap
            try:
                if (now - datetime.fromisoformat(finished_at)).total_seconds() > max_age_seconds:
                    stale.append(job_id)
            except (ValueError, TypeError):
                # Unparseable timestamp: leave the job alone.
                continue
        # Deleting after iteration avoids mutating the dict mid-loop.
        for job_id in stale:
            del _import_jobs[job_id]
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class ParseFilenameRequest(BaseModel):
    """Request body for /parse: test one filename against one pattern."""
    filename: str  # Raw filename to parse
    pattern: str   # Pattern string consumed by FilenameParser (see preset patterns)
class FileInfo(BaseModel):
    """A previously-uploaded temp file plus optional manual metadata overrides."""
    filename: str                          # Original filename as uploaded
    temp_path: str                         # Path inside the upload session's temp directory
    manual_datetime: Optional[str] = None  # User-supplied timestamp override ('%Y-%m-%dT%H:%M' or ISO-8601)
    manual_username: Optional[str] = None  # User-supplied username override
class ProcessFilesRequest(BaseModel):
    """Request body for /process: which service to use and which files to import."""
    service_name: str  # Must match an enabled service in the manual_import config
    files: List[dict]  # Entries shaped like FileInfo, kept as raw dicts
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def extract_youtube_metadata(video_id: str) -> Optional[Dict]:
    """Fetch metadata for a YouTube video via yt-dlp (no download).

    Returns a dict with title/uploader/channel/upload_date/duration/
    view_count/description (truncated to 200 chars), or None on any failure.
    """
    import subprocess
    import json

    cmd = [
        '/opt/media-downloader/venv/bin/yt-dlp',
        '--dump-json',
        '--no-download',
        '--no-warnings',
        f'https://www.youtube.com/watch?v={video_id}'
    ]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        if proc.returncode != 0:
            return None
        metadata = json.loads(proc.stdout)

        # yt-dlp reports upload_date as YYYYMMDD; tolerate missing/garbled values.
        upload_date = None
        raw_date = metadata.get('upload_date')
        if raw_date:
            try:
                upload_date = datetime.strptime(raw_date, '%Y%m%d')
            except ValueError:
                upload_date = None

        description = metadata.get('description')
        return {
            'title': metadata.get('title', ''),
            'uploader': metadata.get('uploader', metadata.get('channel', '')),
            'channel': metadata.get('channel', metadata.get('uploader', '')),
            'upload_date': upload_date,
            'duration': metadata.get('duration'),
            'view_count': metadata.get('view_count'),
            'description': description[:200] if description else None
        }
    except Exception as e:
        logger.warning(f"Failed to extract YouTube metadata for {video_id}: {e}", module="ManualImport")
        return None
def extract_video_id_from_filename(filename: str) -> Optional[str]:
    """Best-effort extraction of an 11-character YouTube video ID from a filename.

    Tries, in order: bracketed [ID], trailing _ID, the bare ID as the whole
    stem, then an ID delimited anywhere in the stem. Returns None if no
    candidate is found.
    """
    import re

    stem = Path(filename).stem
    # Bracketed form, e.g. "Title [dQw4w9WgXcQ].mp4"
    if m := re.search(r'\[([A-Za-z0-9_-]{11})\]', stem):
        return m.group(1)
    # Trailing "_<id>" form
    if m := re.search(r'_([A-Za-z0-9_-]{11})$', stem):
        return m.group(1)
    # The stem IS the ID
    if re.match(r'^[A-Za-z0-9_-]{11}$', stem):
        return stem
    # ID delimited by separators somewhere inside the stem
    if m := re.search(r'(?:^|[_\-\s])([A-Za-z0-9_-]{11})(?:[_\-\s.]|$)', stem):
        return m.group(1)
    return None
# ============================================================================
# ENDPOINTS
# ============================================================================
@router.get("/services")
@limiter.limit("60/minute")
@handle_exceptions
async def get_manual_import_services(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Return the manual-import service config with preset patterns attached."""
    config = get_app_state().settings.get('manual_import')
    if not config:
        # Nothing configured yet: report a disabled default setup.
        return {
            "enabled": False,
            "temp_dir": "/opt/media-downloader/temp/manual_import",
            "services": [],
            "preset_patterns": get_preset_patterns()
        }
    config['preset_patterns'] = get_preset_patterns()
    return config
@router.post("/parse")
@limiter.limit("100/minute")
@handle_exceptions
async def parse_filename(
    request: Request,
    body: ParseFilenameRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Parse a filename with the given pattern and return extracted metadata."""
    try:
        parsed = FilenameParser(body.pattern).parse(body.filename)
        # datetime objects are not JSON-serializable; hand back ISO-8601 text.
        if parsed['datetime']:
            parsed['datetime'] = parsed['datetime'].isoformat()
        return parsed
    except Exception as e:
        # Any parser failure is reported as a structured invalid result
        # rather than an HTTP error, so the UI can show it inline.
        logger.error(f"Error parsing filename: {e}", module="ManualImport")
        return {
            "valid": False,
            "error": str(e),
            "username": None,
            "datetime": None,
            "media_id": None
        }
# File upload validation limits.
MAX_FILE_SIZE = 5 * 1024 * 1024 * 1024  # 5GB max file size
MAX_FILENAME_LENGTH = 255
# Media extensions accepted for manual import (lowercase, leading dot).
ALLOWED_EXTENSIONS = {
    '.jpg', '.jpeg', '.png', '.gif',
    '.mp4', '.mov', '.avi', '.mkv', '.webm',
    '.webp', '.heic', '.heif',
}
@router.post("/upload")
@limiter.limit("30/minute")
@handle_exceptions
async def upload_files_for_import(
    request: Request,
    files: List[UploadFile] = File(...),
    service_name: str = Form(...),
    current_user: Dict = Depends(get_current_user)
):
    """Upload files to temp directory for manual import.

    Validates each file (name length, extension, size), stores it under a
    fresh per-session temp directory, and returns per-file filename-parse
    results so the UI can preview extracted metadata before /process.

    Raises:
        ValidationError: manual import disabled, filename too long,
            disallowed extension, or file over MAX_FILE_SIZE.
        NotFoundError: unknown or disabled service.
    """
    app_state = get_app_state()
    config = app_state.settings.get('manual_import')
    if not config or not config.get('enabled'):
        raise ValidationError("Manual import is not enabled")
    services = config.get('services', [])
    service = next((s for s in services if s['name'] == service_name and s.get('enabled', True)), None)
    if not service:
        raise NotFoundError(f"Service '{service_name}' not found or disabled")
    # Each upload batch gets its own short random session directory.
    session_id = str(uuid.uuid4())[:8]
    temp_base = Path(config.get('temp_dir', '/opt/media-downloader/temp/manual_import'))
    temp_dir = temp_base / session_id
    temp_dir.mkdir(parents=True, exist_ok=True)
    pattern = service.get('filename_pattern', '{username}_{YYYYMMDD}_{HHMMSS}_{id}')
    platform = service.get('platform', 'unknown')
    # Use fallback patterns for Instagram (handles both underscore and dash formats)
    use_fallbacks = platform == 'instagram'
    parser = FilenameParser(pattern) if not use_fallbacks else None
    uploaded_files = []
    for file in files:
        # Sanitize filename - use only the basename to prevent path traversal
        safe_filename = Path(file.filename).name
        # Validate filename length
        # NOTE(review): a validation failure here aborts the whole request but
        # leaves previously-written files in temp_dir until /process or
        # DELETE /temp cleans them up — confirm that's intended.
        if len(safe_filename) > MAX_FILENAME_LENGTH:
            raise ValidationError(f"Filename too long: {safe_filename[:50]}... (max {MAX_FILENAME_LENGTH} chars)")
        # Validate file extension
        file_ext = Path(safe_filename).suffix.lower()
        if file_ext not in ALLOWED_EXTENSIONS:
            raise ValidationError(f"File type not allowed: {file_ext}. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}")
        file_path = temp_dir / safe_filename
        # NOTE(review): the whole upload is buffered in memory before the size
        # check; with MAX_FILE_SIZE at 5GB this can spike RSS — consider
        # streaming in chunks.
        content = await file.read()
        # Validate file size
        if len(content) > MAX_FILE_SIZE:
            raise ValidationError(f"File too large: {safe_filename} ({len(content) / (1024*1024*1024):.2f}GB, max {MAX_FILE_SIZE / (1024*1024*1024)}GB)")
        with open(file_path, 'wb') as f:
            f.write(content)
        # Parse filename - use fallback patterns for Instagram
        # NOTE(review): parses the original file.filename rather than
        # safe_filename — equivalent unless the client sent a path; confirm.
        if use_fallbacks:
            parse_result = parse_with_fallbacks(file.filename, INSTAGRAM_PATTERNS)
        else:
            parse_result = parser.parse(file.filename)
        parsed_datetime = None
        if parse_result['datetime']:
            parsed_datetime = parse_result['datetime'].isoformat()
        uploaded_files.append({
            "filename": file.filename,
            "temp_path": str(file_path),
            "size": len(content),
            "parsed": {
                "valid": parse_result['valid'],
                "username": parse_result['username'],
                "datetime": parsed_datetime,
                "media_id": parse_result['media_id'],
                "error": parse_result['error']
            }
        })
    logger.info(f"Uploaded {len(uploaded_files)} files for manual import (service: {service_name})", module="ManualImport")
    return {
        "session_id": session_id,
        "service_name": service_name,
        "files": uploaded_files,
        "temp_dir": str(temp_dir)
    }
def process_files_background(
    job_id: str,
    service_name: str,
    files: List[dict],
    service: dict,
    app_state
):
    """Background task to process imported files.

    For each uploaded temp file: resolve username/timestamp (manual
    overrides, yt-dlp metadata for YouTube, or filename parsing), move it
    into the service's destination via MoveManager, and record the result in
    the downloads table. Progress is published through update_job_status so
    the /status/{job_id} endpoint can poll it.

    Args:
        job_id: Key previously registered with create_job.
        service_name: Name of the manual-import service definition.
        files: Entries shaped like FileInfo (filename, temp_path, overrides).
        service: The service's config dict (destination, pattern, flags).
        app_state: Application state (db, scraper_event_emitter).
    """
    import hashlib
    from modules.move_module import MoveManager
    destination = Path(service['destination'])
    destination.mkdir(parents=True, exist_ok=True)
    pattern = service.get('filename_pattern', '{username}_{YYYYMMDD}_{HHMMSS}_{id}')
    platform = service.get('platform', 'unknown')
    content_type = service.get('content_type', 'videos')
    use_ytdlp = service.get('use_ytdlp', False)
    parse_filename = service.get('parse_filename', True)
    # Use fallback patterns for Instagram
    use_fallbacks = platform == 'instagram'
    parser = FilenameParser(pattern) if not use_fallbacks else None
    # Generate session ID for real-time monitoring
    session_id = f"manual_import_{service_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # Emit scraper_started event so the live dashboard shows this import
    if app_state.scraper_event_emitter:
        app_state.scraper_event_emitter.emit_scraper_started(
            session_id=session_id,
            platform=platform,
            account=service_name,
            content_type=content_type,
            estimated_count=len(files)
        )
    # Face recognition is deliberately off for manual imports.
    move_manager = MoveManager(
        unified_db=app_state.db,
        face_recognition_enabled=False,
        notifier=None,
        event_emitter=app_state.scraper_event_emitter
    )
    move_manager.set_session_context(
        platform=platform,
        account=service_name,
        session_id=session_id
    )
    results = []
    success_count = 0
    failed_count = 0
    for idx, file_info in enumerate(files):
        temp_path = Path(file_info['temp_path'])
        filename = file_info['filename']
        manual_datetime_str = file_info.get('manual_datetime')
        manual_username = file_info.get('manual_username')
        # Update job status with current file
        update_job_status(job_id, {
            'current_file': filename,
            'processed_files': idx
        })
        if not temp_path.exists():
            result = {"filename": filename, "status": "error", "error": "File not found in temp directory"}
            results.append(result)
            failed_count += 1
            update_job_status(job_id, {'results': results.copy(), 'failed_count': failed_count})
            continue
        username = 'unknown'
        parsed_datetime = None
        final_filename = filename
        # Use manual values if provided (they take precedence over parsing)
        if not parse_filename or manual_datetime_str or manual_username:
            if manual_username:
                username = manual_username.strip().lower()
            if manual_datetime_str:
                # Accept the UI's '%Y-%m-%dT%H:%M' first, then full ISO-8601.
                try:
                    parsed_datetime = datetime.strptime(manual_datetime_str, '%Y-%m-%dT%H:%M')
                except ValueError:
                    try:
                        parsed_datetime = datetime.fromisoformat(manual_datetime_str)
                    except ValueError:
                        logger.warning(f"Could not parse manual datetime: {manual_datetime_str}", module="ManualImport")
        # Try yt-dlp for YouTube videos: derive username/date/title from metadata
        if use_ytdlp and platform == 'youtube':
            video_id = extract_video_id_from_filename(filename)
            if video_id:
                logger.info(f"Extracting YouTube metadata for video ID: {video_id}", module="ManualImport")
                yt_metadata = extract_youtube_metadata(video_id)
                if yt_metadata:
                    username = yt_metadata.get('channel') or yt_metadata.get('uploader') or 'unknown'
                    # Strip characters unsafe for filenames/directories.
                    username = "".join(c for c in username if c.isalnum() or c in ' _-').strip().replace(' ', '_')
                    parsed_datetime = yt_metadata.get('upload_date')
                    if yt_metadata.get('title'):
                        title = yt_metadata['title'][:50]
                        title = "".join(c for c in title if c.isalnum() or c in ' _-').strip().replace(' ', '_')
                        ext = Path(filename).suffix
                        final_filename = f"{username}_{parsed_datetime.strftime('%Y%m%d') if parsed_datetime else 'unknown'}_{title}_{video_id}{ext}"
        # Fall back to filename parsing when nothing above produced a username
        if parse_filename and username == 'unknown':
            if use_fallbacks:
                parse_result = parse_with_fallbacks(filename, INSTAGRAM_PATTERNS)
            else:
                parse_result = parser.parse(filename)
            if parse_result['valid']:
                username = parse_result['username'] or 'unknown'
                parsed_datetime = parse_result['datetime']
            elif not use_ytdlp:
                # Without yt-dlp as a fallback, an unparseable name is fatal.
                result = {"filename": filename, "status": "error", "error": parse_result['error'] or "Failed to parse filename"}
                results.append(result)
                failed_count += 1
                update_job_status(job_id, {'results': results.copy(), 'failed_count': failed_count})
                continue
        # Files are grouped by username under the service destination.
        dest_subdir = destination / username
        dest_subdir.mkdir(parents=True, exist_ok=True)
        dest_path = dest_subdir / final_filename
        # NOTE(review): a new MoveManager batch is opened per file (and closed
        # in both the success and error paths below) — confirm per-file
        # batching is intended rather than one batch for the whole job.
        move_manager.start_batch(
            platform=platform,
            source=username,
            content_type=content_type
        )
        file_size = temp_path.stat().st_size if temp_path.exists() else 0
        try:
            success = move_manager.move_file(
                source=temp_path,
                destination=dest_path,
                timestamp=parsed_datetime,
                preserve_if_no_timestamp=True,
                content_type=content_type
            )
            move_manager.end_batch()
            if success:
                # Synthetic URL hash keyed on the final filename so re-imports
                # of the same name upsert rather than duplicate.
                url_hash = hashlib.sha256(f"manual_import:{final_filename}".encode()).hexdigest()
                with app_state.db.get_connection(for_write=True) as conn:
                    cursor = conn.cursor()
                    cursor.execute("""
                        INSERT OR REPLACE INTO downloads
                        (url_hash, url, platform, source, content_type, filename, file_path,
                         file_size, file_hash, post_date, download_date, status, media_id)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'completed', ?)
                    """, (
                        url_hash,
                        f"manual_import://{final_filename}",
                        platform,
                        username,
                        content_type,
                        final_filename,
                        str(dest_path),
                        file_size,
                        None,
                        parsed_datetime.isoformat() if parsed_datetime else None,
                        datetime.now().isoformat(),
                        None
                    ))
                    conn.commit()
                result = {
                    "filename": filename,
                    "status": "success",
                    "destination": str(dest_path),
                    "username": username,
                    "datetime": parsed_datetime.isoformat() if parsed_datetime else None
                }
                results.append(result)
                success_count += 1
            else:
                result = {"filename": filename, "status": "error", "error": "Failed to move file (possibly duplicate)"}
                results.append(result)
                failed_count += 1
        except Exception as e:
            # Close the batch even on failure so the manager isn't left open.
            move_manager.end_batch()
            result = {"filename": filename, "status": "error", "error": str(e)}
            results.append(result)
            failed_count += 1
        # Update job status after each file
        update_job_status(job_id, {
            'results': results.copy(),
            'success_count': success_count,
            'failed_count': failed_count,
            'processed_files': idx + 1
        })
    # Clean up temp directory (best effort; nested dirs would make iterdir
    # unlink fail, which the bare except below swallows)
    try:
        temp_parent = Path(files[0]['temp_path']).parent if files else None
        if temp_parent and temp_parent.exists():
            for f in temp_parent.iterdir():
                f.unlink()
            temp_parent.rmdir()
    except Exception:
        pass
    # Emit scraper_completed event
    if app_state.scraper_event_emitter:
        app_state.scraper_event_emitter.emit_scraper_completed(
            session_id=session_id,
            stats={
                'total_downloaded': len(files),
                'moved': success_count,
                'review': 0,
                'duplicates': 0,
                'failed': failed_count
            }
        )
    # Mark job as complete
    update_job_status(job_id, {
        'status': 'completed',
        'completed_at': datetime.now().isoformat(),
        'current_file': None
    })
    logger.info(f"Manual import complete: {success_count} succeeded, {failed_count} failed", module="ManualImport")
    # Cleanup old jobs
    cleanup_old_jobs()
@router.post("/process")
@limiter.limit("10/minute")
@handle_exceptions
async def process_imported_files(
    request: Request,
    background_tasks: BackgroundTasks,
    body: ProcessFilesRequest,
    current_user: Dict = Depends(get_current_user)
):
    """Queue background processing of uploaded files; returns a job ID immediately."""
    app_state = get_app_state()
    config = app_state.settings.get('manual_import')
    if not config or not config.get('enabled'):
        raise ValidationError("Manual import is not enabled")

    # Locate the requested (and enabled) service definition.
    service = None
    for candidate in config.get('services', []):
        if candidate['name'] == body.service_name and candidate.get('enabled', True):
            service = candidate
            break
    if service is None:
        raise NotFoundError(f"Service '{body.service_name}' not found")

    # Register the job before queuing so the status endpoint can see it
    # immediately.
    job_id = f"import_{uuid.uuid4().hex[:12]}"
    create_job(job_id, len(body.files), body.service_name)
    background_tasks.add_task(
        process_files_background,
        job_id,
        body.service_name,
        body.files,
        service,
        app_state
    )
    logger.info(f"Manual import job {job_id} queued: {len(body.files)} files for {body.service_name}", module="ManualImport")
    return {
        "job_id": job_id,
        "status": "processing",
        "total_files": len(body.files),
        "message": "Processing started in background"
    }
@router.get("/status/{job_id}")
@limiter.limit("120/minute")
@handle_exceptions
async def get_import_job_status(
    request: Request,
    job_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Return the status of a manual import job (404 if unknown/expired)."""
    job = get_job_status(job_id)
    if job is not None:
        return job
    raise NotFoundError(f"Job '{job_id}' not found")
@router.delete("/temp")
@limiter.limit("10/minute")
@handle_exceptions
async def clear_temp_directory(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Clear all files from the manual-import temp directory."""
    config = get_app_state().settings.get('manual_import')
    default_dir = '/opt/media-downloader/temp/manual_import'
    temp_dir = Path(config.get('temp_dir', default_dir)) if config else Path(default_dir)

    # Drop the whole tree, then recreate an empty directory in its place.
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
        temp_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cleared manual import temp directory", module="ManualImport")
    return {"status": "success", "message": "Temp directory cleared"}
@router.get("/preset-patterns")
@limiter.limit("60/minute")
@handle_exceptions
async def get_preset_filename_patterns(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Return the available preset filename-parsing patterns."""
    patterns = get_preset_patterns()
    return {"patterns": patterns}

1404
web/backend/routers/media.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

1098
web/backend/routers/press.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,602 @@
"""
Recycle Bin Router
Handles all recycle bin operations:
- List deleted files
- Recycle bin statistics
- Restore files
- Permanently delete files
- Empty recycle bin
- Serve files for preview
- Get file metadata
"""
import hashlib
import json
import mimetypes
import sqlite3
from typing import Dict, Optional
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Body, Query, Request
from fastapi.responses import FileResponse, Response
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, get_current_user_media, require_admin, get_app_state
from ..core.config import settings
from ..core.exceptions import (
handle_exceptions,
DatabaseError,
RecordNotFoundError,
MediaFileNotFoundError as CustomFileNotFoundError,
FileOperationError
)
from ..core.responses import now_iso8601
from ..core.utils import ThumbnailLRUCache
from modules.universal_logger import get_logger
logger = get_logger('API')
# Router and per-client rate limiter for all recycle-bin endpoints.
router = APIRouter(prefix="/api/recycle", tags=["Recycle Bin"])
limiter = Limiter(key_func=get_remote_address)
# Global thumbnail memory cache for recycle bin (500 items or 100MB max)
# Using shared ThumbnailLRUCache from core/utils.py
_thumbnail_cache = ThumbnailLRUCache(max_size=500, max_memory_mb=100)
@router.get("/list")
@limiter.limit("100/minute")
@handle_exceptions
async def list_recycle_bin(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    deleted_from: Optional[str] = None,
    platform: Optional[str] = None,
    source: Optional[str] = None,
    search: Optional[str] = None,
    media_type: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    size_min: Optional[int] = None,
    size_max: Optional[int] = None,
    sort_by: str = Query('download_date', pattern='^(deleted_at|file_size|filename|deleted_from|download_date|post_date|confidence)$'),
    sort_order: str = Query('desc', pattern='^(asc|desc)$'),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0)
):
    """
    List files in the recycle bin.

    Args:
        deleted_from: Filter by origin (downloads, media, review).
        platform: Filter by platform (instagram, tiktok, etc.).
        source: Filter by source/username.
        search: Substring search in filename.
        media_type: Filter by type (image, video).
        date_from / date_to: Deletion-date range (YYYY-MM-DD).
        size_min / size_max: File-size bounds in bytes.
        sort_by / sort_order: Sorting (validated by the Query regex patterns).
        limit / offset: Pagination.
    """
    db = get_app_state().db
    if not db:
        raise DatabaseError("Database not initialized")

    # All filtering/sorting/pagination is delegated to the DB layer; the
    # sort fields were already constrained by the Query patterns above.
    page = db.list_recycle_bin(
        deleted_from=deleted_from,
        platform=platform,
        source=source,
        search=search,
        media_type=media_type,
        date_from=date_from,
        date_to=date_to,
        size_min=size_min,
        size_max=size_max,
        sort_by=sort_by,
        sort_order=sort_order,
        limit=limit,
        offset=offset
    )
    return {
        "success": True,
        "items": page['items'],
        "total": page['total']
    }
@router.get("/filters")
@limiter.limit("100/minute")
@handle_exceptions
async def get_recycle_filters(
request: Request,
current_user: Dict = Depends(get_current_user),
platform: Optional[str] = None
):
"""
Get available filter options for recycle bin.
Args:
platform: If provided, only return sources for this platform
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
filters = db.get_recycle_bin_filters(platform=platform)
return {
"success": True,
"platforms": filters['platforms'],
"sources": filters['sources']
}
@router.get("/stats")
@limiter.limit("100/minute")
@handle_exceptions
async def get_recycle_bin_stats(request: Request, current_user: Dict = Depends(get_current_user)):
"""
Get recycle bin statistics.
Returns total count, total size, and breakdown by deleted_from source.
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
stats = db.get_recycle_bin_stats()
return {
"success": True,
"stats": stats,
"timestamp": now_iso8601()
}
@router.post("/restore")
@limiter.limit("20/minute")
@handle_exceptions
async def restore_from_recycle(
request: Request,
current_user: Dict = Depends(get_current_user),
recycle_id: str = Body(..., embed=True)
):
"""
Restore a file from recycle bin to its original location.
The file will be moved back to its original path and re-registered
in the file_inventory table.
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
success = db.restore_from_recycle_bin(recycle_id)
if success:
# Broadcast update to connected clients
try:
# app_state already retrieved above, use it for websocket broadcast
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "recycle_restore_completed",
"recycle_id": recycle_id,
"timestamp": now_iso8601()
})
except Exception:
pass # Broadcasting is optional
logger.info(f"Restored file from recycle bin: {recycle_id}", module="Recycle")
return {
"success": True,
"message": "File restored successfully",
"recycle_id": recycle_id
}
else:
raise FileOperationError(
"Failed to restore file",
{"recycle_id": recycle_id}
)
@router.delete("/delete/{recycle_id}")
@limiter.limit("20/minute")
@handle_exceptions
async def permanently_delete_from_recycle(
request: Request,
recycle_id: str,
current_user: Dict = Depends(require_admin)
):
"""
Permanently delete a file from recycle bin.
**Admin only** - This action cannot be undone. The file will be
removed from disk permanently.
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
success = db.permanently_delete_from_recycle_bin(recycle_id)
if success:
# Broadcast update
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "recycle_delete_completed",
"recycle_id": recycle_id,
"timestamp": now_iso8601()
})
except Exception:
pass
logger.info(f"Permanently deleted file from recycle: {recycle_id}", module="Recycle")
return {
"success": True,
"message": "File permanently deleted",
"recycle_id": recycle_id
}
else:
raise FileOperationError(
"Failed to delete file",
{"recycle_id": recycle_id}
)
@router.post("/empty")
@limiter.limit("5/minute")
@handle_exceptions
async def empty_recycle_bin(
request: Request,
current_user: Dict = Depends(require_admin), # Require admin for destructive operation
older_than_days: Optional[int] = Body(None, embed=True)
):
"""
Empty recycle bin.
Args:
older_than_days: Only delete files older than X days.
If not specified, all files are deleted.
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
deleted_count = db.empty_recycle_bin(older_than_days=older_than_days)
# Broadcast update
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "recycle_emptied",
"deleted_count": deleted_count,
"timestamp": now_iso8601()
})
except Exception:
pass
logger.info(f"Emptied recycle bin: {deleted_count} files deleted", module="Recycle")
return {
"success": True,
"deleted_count": deleted_count,
"older_than_days": older_than_days
}
@router.get("/file/{recycle_id}")
@limiter.limit("5000/minute")
@handle_exceptions
async def get_recycle_file(
request: Request,
recycle_id: str,
thumbnail: bool = False,
type: Optional[str] = None,
token: Optional[str] = None,
current_user: Dict = Depends(get_current_user_media)
):
"""
Serve a file from recycle bin for preview.
Args:
recycle_id: ID of the recycle bin record
thumbnail: If True, return a thumbnail instead of the full file
type: Media type hint (image/video)
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
# Get recycle bin record
with db.get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
'SELECT recycle_path, original_path, original_filename, file_hash FROM recycle_bin WHERE id = ?',
(recycle_id,)
)
row = cursor.fetchone()
if not row:
raise RecordNotFoundError(
"File not found in recycle bin",
{"recycle_id": recycle_id}
)
file_path = Path(row['recycle_path'])
original_path = row['original_path'] # Path where thumbnail was originally cached
if not file_path.exists():
raise CustomFileNotFoundError(
"Physical file not found",
{"path": str(file_path)}
)
# If thumbnail requested, use 3-tier caching
# Use content hash as cache key so thumbnails survive file moves
if thumbnail:
content_hash = row['file_hash']
cache_key = content_hash if content_hash else str(file_path)
# 1. Check in-memory LRU cache first (fastest)
thumbnail_data = _thumbnail_cache.get(cache_key)
if thumbnail_data:
return Response(
content=thumbnail_data,
media_type="image/jpeg",
headers={
"Cache-Control": "public, max-age=86400, immutable",
"Vary": "Accept-Encoding"
}
)
# 2. Get from database cache or generate on-demand
# Pass content hash and original_path for fallback lookup
thumbnail_data = _get_or_create_thumbnail(file_path, type or 'image', content_hash, original_path)
if not thumbnail_data:
raise FileOperationError("Failed to generate thumbnail")
# 3. Add to in-memory cache for faster subsequent requests
_thumbnail_cache.put(cache_key, thumbnail_data)
return Response(
content=thumbnail_data,
media_type="image/jpeg",
headers={
"Cache-Control": "public, max-age=86400, immutable",
"Vary": "Accept-Encoding"
}
)
# Otherwise serve full file
mime_type, _ = mimetypes.guess_type(str(file_path))
if not mime_type:
mime_type = "application/octet-stream"
return FileResponse(
path=str(file_path),
media_type=mime_type,
filename=row['original_filename']
)
@router.get("/metadata/{recycle_id}")
@limiter.limit("5000/minute")
@handle_exceptions
async def get_recycle_metadata(
request: Request,
recycle_id: str,
current_user: Dict = Depends(get_current_user)
):
"""
Get metadata for a recycle bin file.
Returns dimensions, size, platform, source, and other metadata.
This is fetched on-demand for performance.
"""
app_state = get_app_state()
db = app_state.db
if not db:
raise DatabaseError("Database not initialized")
# Get recycle bin record
with db.get_connection() as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT recycle_path, original_filename, file_size, original_path, metadata
FROM recycle_bin WHERE id = ?
''', (recycle_id,))
row = cursor.fetchone()
if not row:
raise RecordNotFoundError(
"File not found in recycle bin",
{"recycle_id": recycle_id}
)
recycle_path = Path(row['recycle_path'])
if not recycle_path.exists():
raise CustomFileNotFoundError(
"Physical file not found",
{"path": str(recycle_path)}
)
# Parse metadata for platform/source info
platform, source = None, None
try:
metadata = json.loads(row['metadata']) if row['metadata'] else {}
platform = metadata.get('platform')
source = metadata.get('source')
except Exception:
pass
# Get dimensions dynamically
width, height, duration = _extract_dimensions(recycle_path)
return {
"success": True,
"recycle_id": recycle_id,
"filename": row['original_filename'],
"file_size": row['file_size'],
"platform": platform,
"source": source,
"width": width,
"height": height,
"duration": duration
}
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _get_or_create_thumbnail(file_path: Path, media_type: str, content_hash: str = None, original_path: str = None) -> Optional[bytes]:
    """
    Get or create a JPEG thumbnail for a file.

    Uses the same caching system as media.py for consistency, with a 2-step
    lookup for backwards compatibility:
      1. content hash (new method - survives file moves)
      2. original_path (legacy thumbnails cached before the move)

    Args:
        file_path: Path to the file (current location in recycle bin)
        media_type: 'image' or 'video'
        content_hash: Optional content hash (SHA256 of file content) used as
            the cache key.
        original_path: Optional original file path before moving to recycle
            bin, used for legacy cache lookups.

    Returns:
        JPEG thumbnail bytes, or None if generation failed.
    """
    from PIL import Image
    import io
    from datetime import datetime
    # --- Cache lookup (best-effort; fall through to generation on error) ---
    try:
        with sqlite3.connect('thumbnails', timeout=30.0) as conn:
            cursor = conn.cursor()
            # 1. Try content hash first (new method - survives file moves)
            if content_hash:
                cursor.execute("SELECT thumbnail_data FROM thumbnails WHERE file_hash = ?", (content_hash,))
                result = cursor.fetchone()
                if result:
                    return result[0]
            # 2. Fall back to original_path lookup (legacy thumbnails)
            if original_path:
                cursor.execute("SELECT thumbnail_data FROM thumbnails WHERE file_path = ?", (original_path,))
                result = cursor.fetchone()
                if result:
                    return result[0]
    except Exception:
        pass
    # --- Generate thumbnail ---
    thumbnail_data = None
    try:
        if media_type == 'video':
            # For videos, extract a single frame at t=1s via ffmpeg.
            import subprocess
            result = subprocess.run([
                'ffmpeg', '-i', str(file_path),
                '-ss', '00:00:01', '-vframes', '1',
                '-f', 'image2pipe', '-vcodec', 'mjpeg', '-'
            ], capture_output=True, timeout=10)
            if result.returncode == 0:
                img = Image.open(io.BytesIO(result.stdout))
            else:
                return None
        else:
            img = Image.open(file_path)
        # FIX: JPEG cannot encode alpha/palette/high-bit modes. The previous
        # check only converted RGBA and P, so images in modes such as 'LA',
        # 'PA' or 'I;16' raised on save. Convert anything that isn't already
        # a JPEG-safe mode.
        if img.mode not in ('RGB', 'L', 'CMYK'):
            img = img.convert('RGB')
        # Create thumbnail (bounded box, preserves aspect ratio)
        img.thumbnail((300, 300), Image.Resampling.LANCZOS)
        # Save to bytes
        output = io.BytesIO()
        img.save(output, format='JPEG', quality=85)
        thumbnail_data = output.getvalue()
        # Cache the generated thumbnail (optional; never fail the request)
        if thumbnail_data:
            try:
                file_mtime = file_path.stat().st_mtime if file_path.exists() else None
                # NOTE(review): when content_hash is absent this stores a hash
                # of the PATH string (not the file content) in the file_hash
                # column. Kept as-is for compatibility with existing cache
                # rows — confirm whether a true content hash is intended.
                thumb_file_hash = content_hash if content_hash else hashlib.sha256(str(file_path).encode()).hexdigest()
                with sqlite3.connect('thumbnails') as conn:
                    conn.execute("""
                        INSERT OR REPLACE INTO thumbnails
                        (file_hash, file_path, thumbnail_data, created_at, file_mtime)
                        VALUES (?, ?, ?, ?, ?)
                    """, (thumb_file_hash, str(file_path), thumbnail_data, datetime.now().isoformat(), file_mtime))
                    conn.commit()
            except Exception:
                pass  # Caching is optional, don't fail if it doesn't work
        return thumbnail_data
    except Exception as e:
        logger.warning(f"Failed to generate thumbnail: {e}", module="Recycle")
        return None
def _extract_dimensions(file_path: Path) -> tuple:
"""
Extract dimensions from a media file.
Returns: (width, height, duration)
"""
width, height, duration = None, None, None
file_ext = file_path.suffix.lower()
try:
if file_ext in ['.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.heic', '.heif']:
from PIL import Image
with Image.open(file_path) as img:
width, height = img.size
elif file_ext in ['.mp4', '.mov', '.avi', '.mkv', '.webm', '.m4v']:
import subprocess
result = subprocess.run([
'ffprobe', '-v', 'quiet', '-print_format', 'json',
'-show_streams', str(file_path)
], capture_output=True, text=True, timeout=10)
if result.returncode == 0:
data = json.loads(result.stdout)
for stream in data.get('streams', []):
if stream.get('codec_type') == 'video':
width = stream.get('width')
height = stream.get('height')
duration_str = stream.get('duration')
if duration_str:
duration = float(duration_str)
break
except Exception as e:
logger.warning(f"Failed to extract dimensions: {e}", module="Recycle")
return width, height, duration

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,758 @@
"""
Scheduler Router
Handles all scheduler and service management operations:
- Scheduler status and task management
- Current activity monitoring
- Task pause/resume/skip operations
- Service start/stop/restart
- Cache builder service management
- Dependency updates
"""
import json
import os
import re
import signal
import sqlite3
import subprocess
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict
from fastapi import APIRouter, Depends, HTTPException, Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, require_admin, get_app_state
from ..core.config import settings
from ..core.exceptions import (
handle_exceptions,
RecordNotFoundError,
ServiceError
)
from ..core.responses import now_iso8601
from modules.universal_logger import get_logger
logger = get_logger('API')
# All routes in this module are mounted under /api/scheduler.
router = APIRouter(prefix="/api/scheduler", tags=["Scheduler"])
# Per-client rate limiting keyed on the caller's remote address.
limiter = Limiter(key_func=get_remote_address)
# Service names (systemd units managed by this router)
SCHEDULER_SERVICE = 'media-downloader.service'
CACHE_BUILDER_SERVICE = 'media-cache-builder.service'
# Valid platform names for subprocess operations (defense in depth)
VALID_PLATFORMS = frozenset(['fastdl', 'imginn', 'imginn_api', 'toolzu', 'snapchat', 'tiktok', 'forums', 'coppermine', 'instagram', 'youtube'])
# Display name mapping for scheduler task_id prefixes
PLATFORM_DISPLAY_NAMES = {
    'fastdl': 'FastDL',
    'imginn': 'ImgInn',
    'imginn_api': 'ImgInn API',
    'toolzu': 'Toolzu',
    'snapchat': 'Snapchat',
    'tiktok': 'TikTok',
    'forums': 'Forums',
    'forum': 'Forum',
    'monitor': 'Forum Monitor',
    'instagram': 'Instagram',
    'youtube': 'YouTube',
    'youtube_channel_monitor': 'YouTube Channels',
    'youtube_monitor': 'YouTube Monitor',
    'coppermine': 'Coppermine',
    'paid_content': 'Paid Content',
    'appearances': 'Appearances',
    'easynews_monitor': 'Easynews Monitor',
    'press_monitor': 'Press Monitor',
}
# ============================================================================
# SCHEDULER STATUS ENDPOINTS
# ============================================================================
@router.get("/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_scheduler_status(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get detailed scheduler status including all tasks."""
app_state = get_app_state()
# Get enabled forums from config to filter scheduler tasks
enabled_forums = set()
forums_config = app_state.settings.get('forums')
if forums_config and isinstance(forums_config, dict):
for forum_cfg in forums_config.get('configs', []):
if forum_cfg.get('enabled', False):
enabled_forums.add(forum_cfg.get('name'))
with sqlite3.connect('scheduler_state') as sched_conn:
cursor = sched_conn.cursor()
# Get all tasks
cursor.execute("""
SELECT task_id, last_run, next_run, run_count, status, last_download_count
FROM scheduler_state
ORDER BY next_run ASC
""")
tasks_raw = cursor.fetchall()
# Clean up stale forum/monitor entries
stale_task_ids = []
# Platforms that should always have :username suffix
platforms_requiring_username = {'tiktok', 'instagram', 'imginn', 'imginn_api', 'toolzu', 'snapchat', 'fastdl'}
for row in tasks_raw:
task_id = row[0]
if task_id.startswith('forum:') or task_id.startswith('monitor:'):
forum_name = task_id.split(':', 1)[1]
if forum_name not in enabled_forums:
stale_task_ids.append(task_id)
# Clean up legacy platform entries without :username suffix
elif task_id in platforms_requiring_username:
stale_task_ids.append(task_id)
# Delete stale entries
if stale_task_ids:
for stale_id in stale_task_ids:
cursor.execute("DELETE FROM scheduler_state WHERE task_id = ?", (stale_id,))
sched_conn.commit()
tasks = []
for row in tasks_raw:
task_id = row[0]
# Skip stale and maintenance tasks
if task_id in stale_task_ids:
continue
if task_id.startswith('maintenance:'):
continue
tasks.append({
"task_id": task_id,
"last_run": row[1],
"next_run": row[2],
"run_count": row[3],
"status": row[4],
"last_download_count": row[5]
})
# Count active tasks
active_count = sum(1 for t in tasks if t['status'] == 'active')
# Get next run time
next_run = None
for task in sorted(tasks, key=lambda t: t['next_run'] or ''):
if task['status'] == 'active' and task['next_run']:
next_run = task['next_run']
break
return {
"running": active_count > 0,
"tasks": tasks,
"total_tasks": len(tasks),
"active_tasks": active_count,
"next_run": next_run
}
@router.get("/current-activity")
@limiter.limit("100/minute")
@handle_exceptions
async def get_current_activity(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get current scheduler activity for real-time status."""
app_state = get_app_state()
# Check if scheduler service is running
result = subprocess.run(
['systemctl', 'is-active', SCHEDULER_SERVICE],
capture_output=True,
text=True
)
scheduler_running = result.stdout.strip() == 'active'
if not scheduler_running:
return {
"active": False,
"scheduler_running": False,
"task_id": None,
"platform": None,
"account": None,
"start_time": None,
"status": None
}
# Get current activity from database
from modules.activity_status import get_activity_manager
activity_manager = get_activity_manager(app_state.db)
activity = activity_manager.get_current_activity()
activity["scheduler_running"] = True
return activity
@router.get("/background-tasks")
@limiter.limit("100/minute")
@handle_exceptions
async def get_background_tasks(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get all active background tasks (YouTube monitor, etc.) for real-time status."""
app_state = get_app_state()
from modules.activity_status import get_activity_manager
activity_manager = get_activity_manager(app_state.db)
tasks = activity_manager.get_active_background_tasks()
return {"tasks": tasks}
@router.get("/background-tasks/{task_id}")
@limiter.limit("100/minute")
@handle_exceptions
async def get_background_task(
task_id: str,
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get a specific background task status."""
app_state = get_app_state()
from modules.activity_status import get_activity_manager
activity_manager = get_activity_manager(app_state.db)
task = activity_manager.get_background_task(task_id)
if not task:
return {"active": False, "task_id": task_id}
return task
@router.post("/current-activity/stop")
@limiter.limit("20/minute")
@handle_exceptions
async def stop_current_activity(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Stop the currently running download task."""
app_state = get_app_state()
activity_file = settings.PROJECT_ROOT / 'database' / 'current_activity.json'
if not activity_file.exists():
raise RecordNotFoundError("No active task running")
with open(activity_file, 'r') as f:
activity_data = json.load(f)
if not activity_data.get('active'):
raise RecordNotFoundError("No active task running")
task_id = activity_data.get('task_id')
platform = activity_data.get('platform')
# Security: Validate platform before using in subprocess (defense in depth)
if platform and platform not in VALID_PLATFORMS:
logger.warning(f"Invalid platform in activity file: {platform}", module="Security")
platform = None
# Find and kill the process
if platform:
result = subprocess.run(
['pgrep', '-f', f'media-downloader\\.py.*--platform.*{re.escape(platform)}'],
capture_output=True,
text=True
)
else:
# Fallback: find any media-downloader process
result = subprocess.run(
['pgrep', '-f', 'media-downloader\\.py'],
capture_output=True,
text=True
)
if result.stdout.strip():
pids = [p.strip() for p in result.stdout.strip().split('\n') if p.strip()]
for pid in pids:
try:
os.kill(int(pid), signal.SIGTERM)
logger.info(f"Stopped process {pid} for platform {platform}")
except (ProcessLookupError, ValueError):
pass
# Clear the current activity
inactive_state = {
"active": False,
"task_id": None,
"platform": None,
"account": None,
"start_time": None,
"status": "stopped"
}
with open(activity_file, 'w') as f:
json.dump(inactive_state, f)
# Broadcast stop event
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "download_stopped",
"task_id": task_id,
"platform": platform,
"timestamp": now_iso8601()
})
except Exception:
pass
return {
"success": True,
"message": f"Stopped {platform} download",
"task_id": task_id
}
# ============================================================================
# TASK MANAGEMENT ENDPOINTS
# ============================================================================
@router.post("/tasks/{task_id}/pause")
@limiter.limit("20/minute")
@handle_exceptions
async def pause_scheduler_task(
request: Request,
task_id: str,
current_user: Dict = Depends(get_current_user)
):
"""Pause a specific scheduler task."""
app_state = get_app_state()
with sqlite3.connect('scheduler_state') as sched_conn:
cursor = sched_conn.cursor()
cursor.execute("""
UPDATE scheduler_state
SET status = 'paused'
WHERE task_id = ?
""", (task_id,))
sched_conn.commit()
row_count = cursor.rowcount
if row_count == 0:
raise RecordNotFoundError("Task not found", {"task_id": task_id})
# Broadcast event
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "scheduler_task_paused",
"task_id": task_id,
"timestamp": now_iso8601()
})
except Exception:
pass
return {"success": True, "task_id": task_id, "status": "paused"}
@router.post("/tasks/{task_id}/resume")
@limiter.limit("20/minute")
@handle_exceptions
async def resume_scheduler_task(
request: Request,
task_id: str,
current_user: Dict = Depends(get_current_user)
):
"""Resume a paused scheduler task."""
app_state = get_app_state()
with sqlite3.connect('scheduler_state') as sched_conn:
cursor = sched_conn.cursor()
cursor.execute("""
UPDATE scheduler_state
SET status = 'active'
WHERE task_id = ?
""", (task_id,))
sched_conn.commit()
row_count = cursor.rowcount
if row_count == 0:
raise RecordNotFoundError("Task not found", {"task_id": task_id})
# Broadcast event
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "scheduler_task_resumed",
"task_id": task_id,
"timestamp": now_iso8601()
})
except Exception:
pass
return {"success": True, "task_id": task_id, "status": "active"}
@router.post("/tasks/{task_id}/skip")
@limiter.limit("20/minute")
@handle_exceptions
async def skip_next_run(
request: Request,
task_id: str,
current_user: Dict = Depends(get_current_user)
):
"""Skip the next scheduled run by advancing next_run time."""
app_state = get_app_state()
with sqlite3.connect('scheduler_state') as sched_conn:
cursor = sched_conn.cursor()
# Get current task info
cursor.execute("""
SELECT next_run, interval_hours
FROM scheduler_state
WHERE task_id = ?
""", (task_id,))
result = cursor.fetchone()
if not result:
raise RecordNotFoundError("Task not found", {"task_id": task_id})
current_next_run, interval_hours = result
# Calculate new next_run time
current_time = datetime.fromisoformat(current_next_run)
new_next_run = current_time + timedelta(hours=interval_hours)
# Update the next_run time
cursor.execute("""
UPDATE scheduler_state
SET next_run = ?
WHERE task_id = ?
""", (new_next_run.isoformat(), task_id))
sched_conn.commit()
# Broadcast event
try:
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "scheduler_run_skipped",
"task_id": task_id,
"new_next_run": new_next_run.isoformat(),
"timestamp": now_iso8601()
})
except Exception:
pass
return {
"success": True,
"task_id": task_id,
"skipped_run": current_next_run,
"new_next_run": new_next_run.isoformat()
}
@router.post("/tasks/{task_id}/reschedule")
@limiter.limit("20/minute")
@handle_exceptions
async def reschedule_task(
request: Request,
task_id: str,
current_user: Dict = Depends(get_current_user)
):
"""Reschedule a task to a new next_run time."""
body = await request.json()
new_next_run = body.get('next_run')
if not new_next_run:
raise HTTPException(status_code=400, detail="next_run is required")
try:
parsed = datetime.fromisoformat(new_next_run)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid datetime format")
with sqlite3.connect('scheduler_state') as sched_conn:
cursor = sched_conn.cursor()
cursor.execute(
"UPDATE scheduler_state SET next_run = ? WHERE task_id = ?",
(parsed.isoformat(), task_id)
)
sched_conn.commit()
if cursor.rowcount == 0:
raise RecordNotFoundError("Task not found", {"task_id": task_id})
# Broadcast event
try:
app_state = get_app_state()
if hasattr(app_state, 'websocket_manager') and app_state.websocket_manager:
await app_state.websocket_manager.broadcast({
"type": "scheduler_task_rescheduled",
"task_id": task_id,
"new_next_run": parsed.isoformat(),
"timestamp": now_iso8601()
})
except Exception:
pass
return {"success": True, "task_id": task_id, "new_next_run": parsed.isoformat()}
# ============================================================================
# CONFIG RELOAD ENDPOINT
# ============================================================================
@router.post("/reload-config")
@limiter.limit("10/minute")
@handle_exceptions
async def reload_scheduler_config(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Reload scheduler config — picks up new/removed accounts and interval changes."""
app_state = get_app_state()
if not hasattr(app_state, 'scheduler') or app_state.scheduler is None:
raise ServiceError("Scheduler is not running", {"service": SCHEDULER_SERVICE})
result = app_state.scheduler.reload_scheduled_tasks()
return {
"success": True,
"added": result['added'],
"removed": result['removed'],
"modified": result['modified'],
"message": (
f"Reload complete: {len(result['added'])} added, "
f"{len(result['removed'])} removed, "
f"{len(result['modified'])} modified"
)
}
# ============================================================================
# SERVICE MANAGEMENT ENDPOINTS
# ============================================================================
@router.get("/service/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_scheduler_service_status(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Check if scheduler service is running."""
result = subprocess.run(
['systemctl', 'is-active', SCHEDULER_SERVICE],
capture_output=True,
text=True
)
is_running = result.stdout.strip() == 'active'
return {
"running": is_running,
"status": result.stdout.strip()
}
@router.post("/service/start")
@limiter.limit("20/minute")
@handle_exceptions
async def start_scheduler_service(
request: Request,
current_user: Dict = Depends(require_admin) # Require admin for service operations
):
"""Start the scheduler service. Requires admin privileges."""
result = subprocess.run(
['sudo', 'systemctl', 'start', SCHEDULER_SERVICE],
capture_output=True,
text=True
)
if result.returncode != 0:
raise ServiceError(
f"Failed to start service: {result.stderr}",
{"service": SCHEDULER_SERVICE}
)
return {"success": True, "message": "Scheduler service started"}
@router.post("/service/stop")
@limiter.limit("20/minute")
@handle_exceptions
async def stop_scheduler_service(
request: Request,
current_user: Dict = Depends(require_admin) # Require admin for service operations
):
"""Stop the scheduler service. Requires admin privileges."""
result = subprocess.run(
['sudo', 'systemctl', 'stop', SCHEDULER_SERVICE],
capture_output=True,
text=True
)
if result.returncode != 0:
raise ServiceError(
f"Failed to stop service: {result.stderr}",
{"service": SCHEDULER_SERVICE}
)
return {"success": True, "message": "Scheduler service stopped"}
@router.post("/service/restart")
@limiter.limit("20/minute")
@handle_exceptions
async def restart_scheduler_service(
request: Request,
current_user: Dict = Depends(require_admin)
):
"""Restart the scheduler service. Requires admin privileges."""
result = subprocess.run(
['sudo', 'systemctl', 'restart', SCHEDULER_SERVICE],
capture_output=True,
text=True
)
if result.returncode != 0:
raise ServiceError(
f"Failed to restart service: {result.stderr}",
{"service": SCHEDULER_SERVICE}
)
return {"success": True, "message": "Scheduler service restarted"}
# ============================================================================
# DEPENDENCY MANAGEMENT ENDPOINTS
# ============================================================================
@router.get("/dependencies/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_dependencies_status(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Get dependency update status."""
from modules.dependency_updater import DependencyUpdater
updater = DependencyUpdater(scheduler_mode=False)
status = updater.get_update_status()
return status
@router.post("/dependencies/check")
@limiter.limit("20/minute")
@handle_exceptions
async def check_dependencies(
request: Request,
current_user: Dict = Depends(get_current_user)
):
"""Force check and update all dependencies."""
app_state = get_app_state()
from modules.dependency_updater import DependencyUpdater
from modules.pushover_notifier import create_notifier_from_config
# Get pushover config
pushover = None
config = app_state.settings.get_all()
if config.get('pushover', {}).get('enabled'):
pushover = create_notifier_from_config(config, unified_db=app_state.db)
updater = DependencyUpdater(
config=config.get('dependency_updater', {}),
pushover_notifier=pushover,
scheduler_mode=True
)
results = updater.force_update_check()
return {
"success": True,
"results": results,
"message": "Dependency check completed"
}
# ============================================================================
# CACHE BUILDER SERVICE ENDPOINTS
# ============================================================================
@router.post("/cache-builder/trigger")
@limiter.limit("10/minute")
@handle_exceptions
async def trigger_cache_builder(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Manually trigger the thumbnail cache builder service.

    Returns:
        dict with ``success`` and a human-readable ``message``.

    Raises:
        ServiceError: if systemctl reports failure, or does not respond
            within the 5-second timeout.
    """
    try:
        result = subprocess.run(
            ['sudo', 'systemctl', 'start', CACHE_BUILDER_SERVICE],
            capture_output=True,
            text=True,
            timeout=5
        )
    except subprocess.TimeoutExpired:
        # Previously propagated as an unhandled 500; surface as a
        # structured service error instead.
        raise ServiceError(
            "Timed out starting cache builder",
            {"service": CACHE_BUILDER_SERVICE}
        )
    if result.returncode == 0:
        return {"success": True, "message": "Cache builder started successfully"}
    raise ServiceError(
        f"Failed to start cache builder: {result.stderr}",
        {"service": CACHE_BUILDER_SERVICE}
    )
@router.get("/cache-builder/status")
@limiter.limit("30/minute")
@handle_exceptions
async def get_cache_builder_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get detailed cache builder service status.

    Queries systemd for the unit's state and reports whether it is running.
    ``last_run``/``next_run`` are placeholders kept for API compatibility;
    timer introspection is not implemented yet. (The previous version
    spawned an extra ``systemctl list-timers`` process and parsed its
    output without ever using the result — that dead code is removed.)
    """
    result = subprocess.run(
        ['systemctl', 'status', CACHE_BUILDER_SERVICE, '--no-pager'],
        capture_output=True,
        text=True
    )
    status_output = result.stdout
    return {
        "running": 'Active: active (running)' in status_output,
        "inactive": 'Active: inactive' in status_output,
        "status_output": status_output[:500],  # Truncate for brevity
        "last_run": None,  # TODO: parse `systemctl list-timers` when needed
        "next_run": None
    }

View File

@@ -0,0 +1,819 @@
"""
Scrapers Router
Handles scraper management and error monitoring:
- Scraper configuration (list, get, update)
- Cookie management (test connection, upload, clear)
- Error tracking (recent, count, dismiss, mark viewed)
"""
import json
import re
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional
import requests
from fastapi import APIRouter, Body, Depends, Query, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, require_admin, get_app_state
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
from modules.universal_logger import get_logger
logger = get_logger('API')
router = APIRouter(prefix="/api", tags=["Scrapers"])
limiter = Limiter(key_func=get_remote_address)
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class ScraperUpdate(BaseModel):
    """Partial update payload for a scraper; fields left as None are unchanged."""
    enabled: Optional[bool] = None
    proxy_enabled: Optional[bool] = None
    proxy_url: Optional[str] = None
    flaresolverr_required: Optional[bool] = None
    base_url: Optional[str] = None
class CookieUpload(BaseModel):
    """Cookie upload payload (e.g. from a browser-extension export)."""
    cookies: List[dict]
    # When True, merge with existing cookies instead of replacing them.
    merge: bool = True
    user_agent: Optional[str] = None
class DismissErrors(BaseModel):
    """Request body for dismissing errors, either by explicit IDs or all at once."""
    error_ids: Optional[List[int]] = None
    dismiss_all: bool = False
class MarkErrorsViewed(BaseModel):
    """Request body for marking errors as viewed, by explicit IDs or all at once."""
    error_ids: Optional[List[int]] = None
    mark_all: bool = False
# ============================================================================
# SCRAPER ENDPOINTS
# ============================================================================
@router.get("/scrapers")
@limiter.limit("60/minute")
@handle_exceptions
async def get_scrapers(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    type_filter: Optional[str] = Query(None, alias="type", description="Filter by type")
):
    """Get all scrapers with optional type filter."""
    app_state = get_app_state()
    scrapers = app_state.db.get_all_scrapers(type_filter=type_filter)

    hidden_modules = app_state.config.get('hidden_modules', [])
    if hidden_modules:
        # Map each scraper ID to the modules that depend on it; a scraper is
        # hidden only when EVERY related module is hidden.
        scraper_to_modules = {
            'instagram': ['instagram', 'instagram_client'],
            'snapchat': ['snapchat', 'snapchat_client'],
            'fastdl': ['fastdl'],
            'imginn': ['imginn'],
            'toolzu': ['toolzu'],
            'tiktok': ['tiktok'],
            'coppermine': ['coppermine'],
        }
        hidden = set(hidden_modules)

        def _is_hidden(entry):
            sid = entry.get('id', '')
            # Forum scrapers all hang off the single 'forums' module.
            related = ['forums'] if sid.startswith('forum_') else scraper_to_modules.get(sid, [])
            return bool(related) and hidden.issuperset(related)

        scrapers = [s for s in scrapers if not _is_hidden(s)]

    # Strip the raw cookie payloads before returning (too large for the UI).
    for entry in scrapers:
        entry.pop('cookies_json', None)
    return {"scrapers": scrapers}
# ============================================================================
# PLATFORM CREDENTIALS (UNIFIED COOKIE MANAGEMENT)
# ============================================================================
# Platform definitions for the unified credentials view
_SCRAPER_PLATFORMS = [
{'id': 'instagram', 'name': 'Instagram', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
{'id': 'tiktok', 'name': 'TikTok', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
{'id': 'snapchat', 'name': 'Snapchat', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
{'id': 'ytdlp', 'name': 'YouTube', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
{'id': 'pornhub', 'name': 'PornHub', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
{'id': 'xhamster', 'name': 'xHamster', 'type': 'cookies', 'source': 'scraper', 'used_by': ['Scheduler']},
]
_PAID_CONTENT_PLATFORMS = [
{'id': 'onlyfans_direct', 'name': 'OnlyFans', 'type': 'token', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://onlyfans.com'},
{'id': 'fansly_direct', 'name': 'Fansly', 'type': 'token', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://fansly.com'},
{'id': 'coomer', 'name': 'Coomer', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://coomer.su'},
{'id': 'kemono', 'name': 'Kemono', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://kemono.su'},
{'id': 'twitch', 'name': 'Twitch', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://twitch.tv'},
{'id': 'bellazon', 'name': 'Bellazon', 'type': 'session', 'source': 'paid_content', 'used_by': ['Paid Content'], 'base_url': 'https://www.bellazon.com'},
]
@router.get("/scrapers/platform-credentials")
@limiter.limit("30/minute")
@handle_exceptions
async def get_platform_credentials(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get aggregated credential status for all platforms + monitoring preferences.

    Combines three credential sources into one ``platforms`` list:
      1. scraper rows in the main DB (cookie JSON blobs),
      2. paid-content services (session/token strings, possibly JSON),
      3. the Reddit private-gallery monitor (encrypted cookies).
    Each entry reports a credential count, a has_credentials flag, the last
    update time (when known), and the per-platform monitoring preference.
    """
    app_state = get_app_state()
    db = app_state.db
    platforms = []
    def _get_monitoring_flag(platform_id: str) -> bool:
        """Read monitoring preference from settings."""
        # Unset or unreadable values default to monitoring enabled.
        try:
            val = app_state.settings.get(f"cookie_monitoring:{platform_id}")
            if val is not None:
                return str(val).lower() not in ('false', '0', 'no')
        except Exception:
            pass
        return True
    # 1. Scraper platforms
    for platform_def in _SCRAPER_PLATFORMS:
        scraper = db.get_scraper(platform_def['id'])
        cookies_count = 0
        updated_at = None
        if scraper:
            raw = scraper.get('cookies_json')
            if raw:
                try:
                    data = json.loads(raw)
                    # Stored either as a bare list or as {"cookies": [...]};
                    # count entries in either shape.
                    if isinstance(data, list):
                        cookies_count = len(data)
                    elif isinstance(data, dict):
                        c = data.get('cookies', [])
                        cookies_count = len(c) if isinstance(c, list) else 0
                except (json.JSONDecodeError, TypeError):
                    pass
            updated_at = scraper.get('cookies_updated_at')
        monitoring_enabled = _get_monitoring_flag(platform_def['id'])
        platforms.append({
            'id': platform_def['id'],
            'name': platform_def['name'],
            'type': platform_def['type'],
            'source': platform_def['source'],
            'cookies_count': cookies_count,
            'has_credentials': cookies_count > 0,
            'updated_at': updated_at,
            'used_by': platform_def['used_by'],
            'monitoring_enabled': monitoring_enabled,
        })
    # 2. Paid content platforms
    try:
        from modules.paid_content import PaidContentDBAdapter
        paid_db = PaidContentDBAdapter(db)
        paid_services = {svc['id']: svc for svc in paid_db.get_services()}
    except Exception:
        # Paid-content module unavailable: report those platforms as empty.
        paid_services = {}
    for platform_def in _PAID_CONTENT_PLATFORMS:
        svc = paid_services.get(platform_def['id'], {})
        session_val = svc.get('session_cookie') or ''
        has_creds = bool(session_val)
        updated_at = svc.get('session_updated_at')
        # Count credentials: for JSON objects count keys, for JSON arrays count items, otherwise 1 if set
        cookies_count = 0
        if has_creds:
            try:
                parsed = json.loads(session_val)
                if isinstance(parsed, dict):
                    cookies_count = len(parsed)
                elif isinstance(parsed, list):
                    cookies_count = len(parsed)
                else:
                    cookies_count = 1
            except (json.JSONDecodeError, TypeError):
                # Plain (non-JSON) token string counts as one credential.
                cookies_count = 1
        platforms.append({
            'id': platform_def['id'],
            'name': platform_def['name'],
            'type': platform_def['type'],
            'source': platform_def['source'],
            'base_url': platform_def.get('base_url'),
            'cookies_count': cookies_count,
            'has_credentials': has_creds,
            'updated_at': updated_at,
            'used_by': platform_def['used_by'],
            'monitoring_enabled': _get_monitoring_flag(platform_def['id']),
        })
    # 3. Reddit (private gallery)
    reddit_has_creds = False
    reddit_cookies_count = 0
    reddit_locked = True
    try:
        from modules.reddit_community_monitor import RedditCommunityMonitor, REDDIT_MONITOR_KEY_FILE
        from modules.private_gallery_crypto import get_private_gallery_crypto, load_key_from_file
        db_path = str(Path(__file__).parent.parent.parent.parent / 'database' / 'media_downloader.db')
        reddit_monitor = RedditCommunityMonitor(db_path)
        crypto = get_private_gallery_crypto()
        reddit_locked = not crypto.is_initialized()
        # If gallery is locked, try loading crypto from key file (exported on unlock)
        active_crypto = crypto if not reddit_locked else load_key_from_file(REDDIT_MONITOR_KEY_FILE)
        if active_crypto and active_crypto.is_initialized():
            reddit_has_creds = reddit_monitor.has_cookies(active_crypto)
            if reddit_has_creds:
                try:
                    conn = reddit_monitor._get_connection()
                    cursor = conn.cursor()
                    cursor.execute("SELECT value FROM private_media_config WHERE key = 'reddit_monitor_encrypted_cookies'")
                    row = cursor.fetchone()
                    conn.close()
                    if row and row['value']:
                        decrypted = active_crypto.decrypt_field(row['value'])
                        parsed = json.loads(decrypted)
                        if isinstance(parsed, list):
                            reddit_cookies_count = len(parsed)
                except Exception:
                    # Decrypt/parse failed; we still know credentials exist.
                    reddit_cookies_count = 1 if reddit_has_creds else 0
    except Exception:
        # Reddit monitor unavailable — report as locked with no credentials.
        pass
    platforms.append({
        'id': 'reddit',
        'name': 'Reddit',
        'type': 'cookies',
        'source': 'private_gallery',
        'base_url': 'https://reddit.com',
        'cookies_count': reddit_cookies_count,
        'has_credentials': reddit_has_creds,
        'gallery_locked': reddit_locked,
        'updated_at': None,
        'used_by': ['Private Gallery'],
        'monitoring_enabled': _get_monitoring_flag('reddit'),
    })
    return {
        'platforms': platforms,
        'global_monitoring_enabled': _get_monitoring_flag('global'),
    }
@router.put("/scrapers/platform-credentials/{platform_id}/monitoring")
@limiter.limit("30/minute")
@handle_exceptions
async def toggle_platform_monitoring(
    request: Request,
    platform_id: str,
    current_user: Dict = Depends(require_admin)
):
    """Toggle health monitoring for a single platform."""
    state = get_app_state()
    payload = await request.json()
    enabled = payload.get('enabled', True)
    # Stored as a lowercase string flag under a per-platform key.
    state.settings.set(
        key=f"cookie_monitoring:{platform_id}",
        value=str(enabled).lower(),
        category="cookie_monitoring",
        updated_by=current_user.get('username', 'user')
    )
    verb = 'enabled' if enabled else 'disabled'
    return {
        'success': True,
        'message': f"Monitoring {verb} for {platform_id}",
    }
@router.put("/scrapers/platform-credentials/monitoring")
@limiter.limit("30/minute")
@handle_exceptions
async def toggle_global_monitoring(
    request: Request,
    current_user: Dict = Depends(require_admin)
):
    """Toggle global cookie health monitoring."""
    state = get_app_state()
    enabled = (await request.json()).get('enabled', True)
    # The 'global' key acts as the master switch for all platforms.
    state.settings.set(
        key="cookie_monitoring:global",
        value=str(enabled).lower(),
        category="cookie_monitoring",
        updated_by=current_user.get('username', 'user')
    )
    verb = 'enabled' if enabled else 'disabled'
    return {
        'success': True,
        'message': f"Global cookie monitoring {verb}",
    }
@router.get("/scrapers/{scraper_id}")
@limiter.limit("60/minute")
@handle_exceptions
async def get_scraper(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(get_current_user)
):
    """Get a single scraper configuration."""
    app_state = get_app_state()
    record = app_state.db.get_scraper(scraper_id)
    if not record:
        raise NotFoundError(f"Scraper '{scraper_id}' not found")
    # Replace the raw cookie blob with a simple count for the UI.
    record.pop('cookies_json', None)
    stored = app_state.db.get_scraper_cookies(scraper_id)
    record['cookies_count'] = len(stored) if stored else 0
    return record
@router.put("/scrapers/{scraper_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def update_scraper(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(require_admin)
):
    """Update scraper settings (proxy, enabled, base_url)."""
    app_state = get_app_state()
    updates = await request.json()
    if not app_state.db.get_scraper(scraper_id):
        raise NotFoundError(f"Scraper '{scraper_id}' not found")
    # The DB layer ignores unknown keys; a falsy result means nothing matched.
    if not app_state.db.update_scraper(scraper_id, updates):
        raise ValidationError("No valid fields to update")
    return {"success": True, "message": f"Scraper '{scraper_id}' updated"}
@router.post("/scrapers/{scraper_id}/test")
@limiter.limit("10/minute")
@handle_exceptions
async def test_scraper_connection(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(require_admin)
):
    """
    Test scraper connection via FlareSolverr (if required).
    On success, saves cookies to database.
    For CLI tools (yt-dlp, gallery-dl), tests that the tool is installed and working.

    Every outcome (success/failed/timeout) is recorded via
    update_scraper_test_status so the UI can show each scraper's last result.
    """
    import subprocess
    from modules.cloudflare_handler import CloudflareHandler
    app_state = get_app_state()
    scraper = app_state.db.get_scraper(scraper_id)
    if not scraper:
        raise NotFoundError(f"Scraper '{scraper_id}' not found")
    # Handle CLI tools specially - test that they're installed and working
    if scraper.get('type') == 'cli_tool':
        cli_tests = {
            'ytdlp': {
                'cmd': ['/opt/media-downloader/venv/bin/yt-dlp', '--version'],
                'name': 'yt-dlp'
            },
            'gallerydl': {
                'cmd': ['/opt/media-downloader/venv/bin/gallery-dl', '--version'],
                'name': 'gallery-dl'
            }
        }
        test_config = cli_tests.get(scraper_id)
        if test_config:
            try:
                result = subprocess.run(
                    test_config['cmd'],
                    capture_output=True,
                    text=True,
                    timeout=10
                )
                if result.returncode == 0:
                    # First stdout line is the version string.
                    version = result.stdout.strip().split('\n')[0]
                    cookies_count = 0
                    # Check if cookies are configured
                    if scraper.get('cookies_json'):
                        try:
                            import json
                            data = json.loads(scraper['cookies_json'])
                            # Support both {"cookies": [...]} and [...] formats
                            if isinstance(data, dict) and 'cookies' in data:
                                cookies = data['cookies']
                            elif isinstance(data, list):
                                cookies = data
                            else:
                                cookies = []
                            cookies_count = len(cookies) if cookies else 0
                        except (json.JSONDecodeError, TypeError, KeyError) as e:
                            # Cookie count is informational only; don't fail the test.
                            logger.debug(f"Failed to parse cookies for {scraper_id}: {e}")
                    app_state.db.update_scraper_test_status(scraper_id, 'success')
                    msg = f"{test_config['name']} v{version} installed"
                    if cookies_count > 0:
                        msg += f", {cookies_count} cookies configured"
                    return {
                        "success": True,
                        "message": msg
                    }
                else:
                    error_msg = result.stderr.strip() or "Command failed"
                    app_state.db.update_scraper_test_status(scraper_id, 'failed', error_msg)
                    return {
                        "success": False,
                        "message": f"{test_config['name']} error: {error_msg}"
                    }
            except subprocess.TimeoutExpired:
                app_state.db.update_scraper_test_status(scraper_id, 'failed', "Command timed out")
                return {"success": False, "message": "Command timed out"}
            except FileNotFoundError:
                app_state.db.update_scraper_test_status(scraper_id, 'failed', "Tool not installed")
                return {"success": False, "message": f"{test_config['name']} not installed"}
        else:
            # Unknown CLI tool
            app_state.db.update_scraper_test_status(scraper_id, 'success')
            return {"success": True, "message": "CLI tool registered"}
    base_url = scraper.get('base_url')
    if not base_url:
        raise ValidationError(f"Scraper '{scraper_id}' has no base_url configured")
    proxy_url = None
    if scraper.get('proxy_enabled') and scraper.get('proxy_url'):
        proxy_url = scraper['proxy_url']
    cf_handler = CloudflareHandler(
        module_name=scraper_id,
        cookie_file=None,
        proxy_url=proxy_url if proxy_url else None,
        flaresolverr_enabled=scraper.get('flaresolverr_required', False)
    )
    if scraper.get('flaresolverr_required'):
        # Cloudflare-protected site: solve the challenge via FlareSolverr and
        # persist the resulting cookies for later scraper runs.
        success = cf_handler.get_cookies_via_flaresolverr(base_url, max_retries=2)
        if success:
            cookies = cf_handler.get_cookies_list()
            user_agent = cf_handler.get_user_agent()
            app_state.db.save_scraper_cookies(scraper_id, cookies, user_agent=user_agent)
            app_state.db.update_scraper_test_status(scraper_id, 'success')
            return {
                "success": True,
                "message": f"Connection successful, {len(cookies)} cookies saved",
                "cookies_count": len(cookies)
            }
        else:
            error_msg = "FlareSolverr returned no cookies"
            if proxy_url:
                error_msg += " (check proxy connection)"
            app_state.db.update_scraper_test_status(scraper_id, 'failed', error_msg)
            return {
                "success": False,
                "message": error_msg
            }
    else:
        # Plain site: a simple GET (optionally through the proxy) suffices.
        try:
            proxies = {"http": proxy_url, "https": proxy_url} if proxy_url else None
            response = requests.get(
                base_url,
                timeout=10,
                proxies=proxies,
                headers={'User-Agent': cf_handler.user_agent}
            )
            if response.status_code < 400:
                app_state.db.update_scraper_test_status(scraper_id, 'success')
                return {
                    "success": True,
                    "message": f"Connection successful (HTTP {response.status_code})"
                }
            else:
                app_state.db.update_scraper_test_status(
                    scraper_id, 'failed',
                    f"HTTP {response.status_code}"
                )
                return {
                    "success": False,
                    "message": f"Connection failed with HTTP {response.status_code}"
                }
        except requests.exceptions.Timeout:
            app_state.db.update_scraper_test_status(scraper_id, 'timeout', 'Request timed out')
            return {"success": False, "message": "Connection timed out"}
        except Exception as e:
            app_state.db.update_scraper_test_status(scraper_id, 'failed', str(e))
            return {"success": False, "message": str(e)}
@router.post("/scrapers/{scraper_id}/cookies")
@limiter.limit("20/minute")
@handle_exceptions
async def upload_scraper_cookies(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(require_admin)
):
    """Upload cookies for a scraper (from browser extension export)."""
    app_state = get_app_state()
    if not app_state.db.get_scraper(scraper_id):
        raise NotFoundError(f"Scraper '{scraper_id}' not found")
    payload = await request.json()
    # Accept either a bare cookie list or a {cookies, merge, user_agent} object.
    if isinstance(payload, list):
        cookies, merge, user_agent = payload, True, None
    else:
        cookies = payload.get('cookies', [])
        merge = payload.get('merge', True)
        user_agent = payload.get('user_agent')
    if not cookies or not isinstance(cookies, list):
        raise ValidationError("Invalid cookies format. Expected {cookies: [...]}")
    for idx, item in enumerate(cookies):
        if not isinstance(item, dict):
            raise ValidationError(f"Cookie {idx} is not an object")
        if 'name' not in item or 'value' not in item:
            raise ValidationError(f"Cookie {idx} missing 'name' or 'value'")
    saved = app_state.db.save_scraper_cookies(
        scraper_id, cookies,
        user_agent=user_agent,
        merge=merge
    )
    if not saved:
        raise ValidationError("Failed to save cookies")
    stored = app_state.db.get_scraper_cookies(scraper_id)
    total = len(stored) if stored else 0
    return {
        "success": True,
        "message": f"{'Merged' if merge else 'Replaced'} {len(cookies)} cookies (total: {total})",
        "cookies_count": total
    }
@router.delete("/scrapers/{scraper_id}/cookies")
@limiter.limit("20/minute")
@handle_exceptions
async def clear_scraper_cookies(
    request: Request,
    scraper_id: str,
    current_user: Dict = Depends(require_admin)
):
    """Clear all cookies for a scraper."""
    app_state = get_app_state()
    if not app_state.db.get_scraper(scraper_id):
        raise NotFoundError(f"Scraper '{scraper_id}' not found")
    cleared = app_state.db.clear_scraper_cookies(scraper_id)
    message = f"Cookies cleared for '{scraper_id}'" if cleared else "Failed to clear cookies"
    return {"success": cleared, "message": message}
# ============================================================================
# ERROR MONITORING ENDPOINTS
# ============================================================================
@router.get("/errors/recent")
@limiter.limit("30/minute")
@handle_exceptions
async def get_recent_errors(
    request: Request,
    limit: int = Query(50, ge=1, le=500, description="Maximum number of errors to return"),
    since_visit: bool = Query(False, description="Only show errors since last dashboard visit (default: show ALL unviewed)"),
    include_dismissed: bool = Query(False, description="Include dismissed errors"),
    current_user: Dict = Depends(get_current_user)
):
    """Get recent errors from database.

    By default, shows ALL unviewed/undismissed errors regardless of when they
    occurred, so errors are not missed just because the dashboard was visited.
    Errors are recorded in real-time by universal_logger.py.
    """
    app_state = get_app_state()
    # since stays None (no time filter) unless the caller opts in.
    since = None
    if since_visit:
        since = app_state.db.get_last_dashboard_visit() or (datetime.now() - timedelta(hours=24))
    rows = app_state.db.get_recent_errors(since=since, include_dismissed=include_dismissed, limit=limit)
    return {
        "errors": rows,
        "total_count": len(rows),
        "since": since.isoformat() if since else None,
        # Always count ALL unviewed errors, ignoring any time filter.
        "unviewed_count": app_state.db.get_unviewed_error_count(since=None)
    }
@router.get("/errors/count")
@limiter.limit("60/minute")
@handle_exceptions
async def get_error_count(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get count of ALL unviewed/undismissed errors.

    Errors are recorded in real-time by universal_logger.py.
    """
    app_state = get_app_state()
    total = app_state.db.get_unviewed_error_count(since=None)
    last_visit = app_state.db.get_last_dashboard_visit()
    # With no recorded visit, everything counts as "since last visit".
    recent = app_state.db.get_unviewed_error_count(since=last_visit) if last_visit else total
    return {
        "unviewed_count": total,
        "total_recent": total,
        "since_last_visit": recent
    }
@router.post("/errors/dismiss")
@limiter.limit("20/minute")
@handle_exceptions
async def dismiss_errors(
    request: Request,
    body: Dict = Body(...),
    current_user: Dict = Depends(get_current_user)
):
    """Dismiss errors by ID or all."""
    app_state = get_app_state()
    ids = body.get("error_ids", [])
    if body.get("dismiss_all", False):
        count = app_state.db.dismiss_errors(dismiss_all=True)
    elif ids:
        count = app_state.db.dismiss_errors(error_ids=ids)
    else:
        return {"success": False, "dismissed": 0, "message": "No errors specified"}
    return {
        "success": True,
        "dismissed": count,
        "message": f"Dismissed {count} error(s)"
    }
@router.post("/errors/mark-viewed")
@limiter.limit("20/minute")
@handle_exceptions
async def mark_errors_viewed(
    request: Request,
    body: Dict = Body(...),
    current_user: Dict = Depends(get_current_user)
):
    """Mark errors as viewed."""
    app_state = get_app_state()
    ids = body.get("error_ids", [])
    if body.get("mark_all", False):
        count = app_state.db.mark_errors_viewed(mark_all=True)
    elif ids:
        count = app_state.db.mark_errors_viewed(error_ids=ids)
    else:
        return {"success": False, "marked": 0}
    return {
        "success": True,
        "marked": count
    }
@router.post("/errors/update-visit")
@limiter.limit("30/minute")
@handle_exceptions
async def update_dashboard_visit(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Update the last dashboard visit timestamp."""
    # The DB call returns a success flag which we pass straight through.
    return {"success": get_app_state().db.update_dashboard_visit()}
@router.get("/logs/context")
@limiter.limit("30/minute")
@handle_exceptions
async def get_log_context(
    request: Request,
    timestamp: str = Query(..., description="ISO timestamp of the error"),
    module: Optional[str] = Query(None, description="Module name to filter"),
    minutes_before: int = Query(1, ge=0, le=60, description="Minutes of context before error"),
    minutes_after: int = Query(1, ge=0, le=60, description="Minutes of context after error"),
    current_user: Dict = Depends(get_current_user)
):
    """Get log lines around a specific timestamp for debugging context.

    Scans that day's log files for lines inside the requested window and
    returns them sorted by timestamp; lines within 2 seconds of the target
    are flagged ``is_target``.

    Raises:
        ValidationError: if ``timestamp`` is not a valid ISO-8601 string.
    """
    try:
        target_time = datetime.fromisoformat(timestamp)
    except ValueError:
        # Previously an invalid timestamp surfaced as an unhandled 500.
        raise ValidationError(f"Invalid ISO timestamp: {timestamp!r}")
    start_time = target_time - timedelta(minutes=minutes_before)
    end_time = target_time + timedelta(minutes=minutes_after)
    log_dir = Path('/opt/media-downloader/logs')
    # Format: "YYYY-mm-dd HH:MM:SS [MediaDownloader.<mod>] [<mod>] [<level>] msg"
    log_pattern = re.compile(
        r'^(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) '
        r'\[MediaDownloader\.(\w+)\] '
        r'\[(\w+)\] '
        r'\[(\w+)\] '
        r'(.+)$'
    )
    date_str = target_time.strftime('%Y%m%d')
    matching_lines = []
    for log_file in log_dir.glob(f'{date_str}_*.log'):
        if module and module.lower() not in log_file.stem.lower():
            continue
        try:
            for line in log_file.read_text(errors='replace').splitlines():
                match = log_pattern.match(line)
                if not match:
                    continue
                timestamp_str, _, log_module, level, message = match.groups()
                try:
                    line_time = datetime.strptime(timestamp_str, '%Y-%m-%d %H:%M:%S')
                except ValueError:
                    continue
                if start_time <= line_time <= end_time:
                    matching_lines.append({
                        'timestamp': timestamp_str,
                        'module': log_module,
                        'level': level,
                        'message': message,
                        'is_target': abs((line_time - target_time).total_seconds()) < 2
                    })
        except Exception:
            # Unreadable log file: skip it rather than fail the request.
            continue
    matching_lines.sort(key=lambda x: x['timestamp'])
    return {
        "context": matching_lines,
        "target_timestamp": timestamp,
        "range": {
            "start": start_time.isoformat(),
            "end": end_time.isoformat()
        }
    }

View File

@@ -0,0 +1,366 @@
"""
Semantic Search Router
Handles CLIP-based semantic search operations:
- Text-based image/video search
- Similar file search
- Embedding generation and management
- Model settings
"""
import asyncio
import time
from typing import Dict, Optional
from fastapi import APIRouter, BackgroundTasks, Body, Depends, Query, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, require_admin, get_app_state
from ..core.exceptions import handle_exceptions, ValidationError
from modules.semantic_search import get_semantic_search
from modules.universal_logger import get_logger
logger = get_logger('API')
router = APIRouter(prefix="/api/semantic", tags=["Semantic Search"])
limiter = Limiter(key_func=get_remote_address)
# Batch limit for embedding generation
EMBEDDING_BATCH_LIMIT = 10000
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class SemanticSearchRequest(BaseModel):
    """Text query for CLIP-based semantic search, with optional filters."""
    query: str
    limit: int = 50
    platform: Optional[str] = None
    source: Optional[str] = None
    # Minimum similarity score for a result to be included.
    threshold: float = 0.2
class GenerateEmbeddingsRequest(BaseModel):
    """Request body for a batch embedding-generation run."""
    limit: int = 100
    platform: Optional[str] = None
class SemanticSettingsUpdate(BaseModel):
    """Partial update for semantic-search settings; None fields are unchanged."""
    model: Optional[str] = None
    threshold: Optional[float] = None
# ============================================================================
# SEARCH ENDPOINTS
# ============================================================================
@router.get("/stats")
@limiter.limit("120/minute")
@handle_exceptions
async def get_semantic_stats(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get statistics about semantic search embeddings."""
    engine = get_semantic_search(get_app_state().db)
    return engine.get_embedding_stats()
@router.post("/search")
@limiter.limit("30/minute")
@handle_exceptions
async def semantic_search(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    query: str = Body(..., embed=True),
    limit: int = Body(50, ge=1, le=200),
    platform: Optional[str] = Body(None),
    source: Optional[str] = Body(None),
    threshold: float = Body(0.2, ge=0.0, le=1.0)
):
    """Search for images/videos using natural language query."""
    engine = get_semantic_search(get_app_state().db)
    matches = engine.search_by_text(
        query=query,
        limit=limit,
        platform=platform,
        source=source,
        threshold=threshold
    )
    return {"results": matches, "count": len(matches), "query": query}
@router.post("/similar/{file_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def find_similar_files(
    request: Request,
    file_id: int,
    current_user: Dict = Depends(get_current_user),
    limit: int = Query(50, ge=1, le=200),
    platform: Optional[str] = Query(None),
    source: Optional[str] = Query(None),
    threshold: float = Query(0.5, ge=0.0, le=1.0)
):
    """Find files similar to a given file."""
    engine = get_semantic_search(get_app_state().db)
    matches = engine.search_by_file_id(
        file_id=file_id,
        limit=limit,
        platform=platform,
        source=source,
        threshold=threshold
    )
    return {"results": matches, "count": len(matches), "source_file_id": file_id}
# ============================================================================
# EMBEDDING GENERATION ENDPOINTS
# ============================================================================
@router.post("/generate")
@limiter.limit("10/minute")
@handle_exceptions
async def generate_embeddings(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user),
    limit: int = Body(100, ge=1, le=1000),
    platform: Optional[str] = Body(None)
):
    """Generate CLIP embeddings for files that don't have them yet.

    Kicks off a background task and returns immediately; progress and
    completion events are pushed to clients over the websocket manager.
    Only one generation run may be active at a time, guarded by
    ``app_state.indexing_running``.
    """
    app_state = get_app_state()
    if app_state.indexing_running:
        return {
            "success": False,
            "message": "Indexing already in progress",
            "already_running": True,
            "status": "already_running"
        }
    search = get_semantic_search(app_state.db)
    app_state.indexing_running = True
    app_state.indexing_start_time = time.time()
    # Capture the running loop so worker-thread callbacks can schedule
    # websocket broadcasts back onto it.
    loop = asyncio.get_event_loop()
    manager = getattr(app_state, 'websocket_manager', None)
    def progress_callback(processed: int, total: int, current_file: str):
        # Called from the worker; throttled to every 10th file (and the last).
        try:
            if processed % 10 == 0 or processed == total:
                stats = search.get_embedding_stats()
                if manager:
                    asyncio.run_coroutine_threadsafe(
                        manager.broadcast({
                            "type": "embedding_progress",
                            "processed": processed,
                            "total": total,
                            "current_file": current_file,
                            "total_embeddings": stats.get('total_embeddings', 0),
                            "coverage_percent": stats.get('coverage_percent', 0)
                        }),
                        loop
                    )
        except Exception:
            # Progress reporting is best-effort; never abort generation.
            pass
    def run_generation():
        # Background task body; always clears the running flag on exit.
        try:
            results = search.generate_embeddings_batch(
                limit=limit,
                platform=platform,
                progress_callback=progress_callback
            )
            logger.info(f"Embedding generation complete: {results}", module="SemanticSearch")
            try:
                stats = search.get_embedding_stats()
                if manager:
                    asyncio.run_coroutine_threadsafe(
                        manager.broadcast({
                            "type": "embedding_complete",
                            "results": results,
                            "total_embeddings": stats.get('total_embeddings', 0),
                            "coverage_percent": stats.get('coverage_percent', 0)
                        }),
                        loop
                    )
            except Exception:
                pass
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}", module="SemanticSearch")
        finally:
            app_state.indexing_running = False
            app_state.indexing_start_time = None
    background_tasks.add_task(run_generation)
    return {
        "message": f"Started embedding generation for up to {limit} files",
        "status": "processing"
    }
@router.get("/status")
@limiter.limit("60/minute")
@handle_exceptions
async def get_semantic_indexing_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Check if semantic indexing is currently running."""
    app_state = get_app_state()
    running = app_state.indexing_running
    started = app_state.indexing_start_time
    # Elapsed time is only meaningful while a run is in flight.
    elapsed = int(time.time() - started) if (running and started) else None
    return {
        "indexing_running": running,
        "elapsed_seconds": elapsed
    }
# ============================================================================
# SETTINGS ENDPOINTS
# ============================================================================
@router.get("/settings")
@limiter.limit("30/minute")
@handle_exceptions
async def get_semantic_settings(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get semantic search settings.

    Returns the configured model name and similarity threshold, falling
    back to defaults when unset.
    """
    app_state = get_app_state()
    settings = app_state.settings.get('semantic_search', {})
    # The stored value may not be a dict (update_semantic_settings guards
    # for the same case); fall back to defaults instead of raising.
    if not isinstance(settings, dict):
        settings = {}
    return {
        "model": settings.get('model', 'clip-ViT-B-32'),
        "threshold": settings.get('threshold', 0.2)
    }
@router.post("/settings")
@limiter.limit("10/minute")
@handle_exceptions
async def update_semantic_settings(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(require_admin),
    model: str = Body(None, embed=True),
    threshold: float = Body(None, embed=True)
):
    """Update semantic search settings.

    Accepts an optional CLIP model name and/or a similarity threshold in
    [0, 1]. Changing the model invalidates all stored embeddings (they live
    in a model-specific vector space), so they are cleared and a background
    re-index is started.

    Raises:
        ValidationError: on an unknown model name or out-of-range threshold.
    """
    app_state = get_app_state()
    current_settings = app_state.settings.get('semantic_search', {}) or {}
    old_model = current_settings.get('model', 'clip-ViT-B-32') if isinstance(current_settings, dict) else 'clip-ViT-B-32'
    model_changed = False
    new_settings = dict(current_settings) if isinstance(current_settings, dict) else {}
    if model:
        valid_models = ['clip-ViT-B-32', 'clip-ViT-B-16', 'clip-ViT-L-14']
        if model not in valid_models:
            raise ValidationError(f"Invalid model. Must be one of: {valid_models}")
        if model != old_model:
            model_changed = True
        new_settings['model'] = model
    if threshold is not None:
        if threshold < 0 or threshold > 1:
            raise ValidationError("Threshold must be between 0 and 1")
        new_settings['threshold'] = threshold
    app_state.settings.set('semantic_search', new_settings, category='ai')
    if model_changed:
        logger.info(f"Model changed from {old_model} to {model}, clearing embeddings", module="SemanticSearch")
        # Old embeddings are incompatible with the new model's vector space.
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM content_embeddings')
            deleted = cursor.rowcount
            logger.info(f"Cleared {deleted} embeddings for model change", module="SemanticSearch")
        # Mark indexing as running so /status reports it and /reindex's
        # already-running guard works. Previously this path never set the
        # flag, allowing a concurrent re-index to be started.
        app_state.indexing_running = True
        app_state.indexing_start_time = time.time()
        def run_reindex_for_model_change():
            try:
                search = get_semantic_search(app_state.db, force_reload=True)
                results = search.generate_embeddings_batch(limit=EMBEDDING_BATCH_LIMIT)
                logger.info(f"Model change re-index complete: {results}", module="SemanticSearch")
            except Exception as e:
                logger.error(f"Model change re-index failed: {e}", module="SemanticSearch")
            finally:
                app_state.indexing_running = False
                app_state.indexing_start_time = None
        background_tasks.add_task(run_reindex_for_model_change)
        logger.info(f"Semantic search settings updated: {new_settings} (re-indexing started)", module="SemanticSearch")
        return {"success": True, "settings": new_settings, "reindexing": True, "message": f"Model changed to {model}, re-indexing started"}
    logger.info(f"Semantic search settings updated: {new_settings}", module="SemanticSearch")
    return {"success": True, "settings": new_settings, "reindexing": False}
@router.post("/reindex")
@limiter.limit("2/minute")
@handle_exceptions
async def reindex_embeddings(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(require_admin)
):
    """Drop every stored embedding and rebuild them in a background task."""
    app_state = get_app_state()
    # Refuse to start a second job while one is already running.
    if app_state.indexing_running:
        return {"success": False, "message": "Indexing already in progress", "already_running": True}
    search = get_semantic_search(app_state.db)
    with app_state.db.get_connection(for_write=True) as conn:
        cur = conn.cursor()
        cur.execute('DELETE FROM content_embeddings')
        removed = cur.rowcount
        logger.info(f"Cleared {removed} embeddings for reindexing", module="SemanticSearch")
    app_state.indexing_running = True
    app_state.indexing_start_time = time.time()
    def run_reindex():
        try:
            outcome = search.generate_embeddings_batch(limit=EMBEDDING_BATCH_LIMIT)
            logger.info(f"Reindex complete: {outcome}", module="SemanticSearch")
        except Exception as exc:
            logger.error(f"Reindex failed: {exc}", module="SemanticSearch")
        finally:
            # Always clear the running flag, even on failure.
            app_state.indexing_running = False
            app_state.indexing_start_time = None
    background_tasks.add_task(run_reindex)
    return {"success": True, "message": "Re-indexing started in background"}
@router.post("/clear")
@limiter.limit("2/minute")
@handle_exceptions
async def clear_embeddings(
    request: Request,
    current_user: Dict = Depends(require_admin)
):
    """Delete all rows from content_embeddings and report how many were removed."""
    app_state = get_app_state()
    with app_state.db.get_connection(for_write=True) as conn:
        cur = conn.cursor()
        cur.execute('DELETE FROM content_embeddings')
        removed = cur.rowcount
        logger.info(f"Cleared {removed} embeddings", module="SemanticSearch")
    return {"success": True, "deleted": removed}

View File

@@ -0,0 +1,535 @@
"""
Stats Router
Handles statistics, monitoring, settings, and integrations:
- Dashboard statistics
- Downloader monitoring
- Settings management
- Immich integration
"""
import json
import sqlite3
import time
from typing import Dict, Optional
import requests
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from ..core.dependencies import get_current_user, require_admin, get_app_state
from ..core.exceptions import handle_exceptions, NotFoundError, ValidationError
from modules.universal_logger import get_logger
logger = get_logger('API')
router = APIRouter(prefix="/api", tags=["Stats & Monitoring"])
limiter = Limiter(key_func=get_remote_address)
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class SettingUpdate(BaseModel):
    """Request body for updating a stored setting value."""
    value: dict | list | str | int | float | bool  # new JSON-serializable setting value
    category: Optional[str] = None  # optional settings category label
    description: Optional[str] = None  # optional human-readable description
# ============================================================================
# DASHBOARD STATISTICS
# ============================================================================
@router.get("/stats/dashboard")
@limiter.limit("60/minute")
@handle_exceptions
async def get_dashboard_stats(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get comprehensive dashboard statistics.

    Aggregates: per-platform counts and storage, daily download history
    (last 30 days), content-type breakdown, top 10 sources, overall totals
    (using file_inventory for accurate on-disk counts), recycle-bin/review
    counts, and a week-over-week growth rate.
    """
    app_state = get_app_state()
    with app_state.db.get_connection() as conn:
        cursor = conn.cursor()
        # Get download counts per platform (combine downloads and video_downloads)
        cursor.execute("""
            SELECT platform, SUM(cnt) as count FROM (
                SELECT platform, COUNT(*) as cnt FROM downloads GROUP BY platform
                UNION ALL
                SELECT platform, COUNT(*) as cnt FROM video_downloads GROUP BY platform
            ) GROUP BY platform
        """)
        platform_data = {}
        for row in cursor.fetchall():
            platform = row[0]
            if platform not in platform_data:
                platform_data[platform] = {
                    'count': 0,
                    'size_bytes': 0
                }
            platform_data[platform]['count'] += row[1]
        # Calculate storage sizes from file_inventory (final + review)
        cursor.execute("""
            SELECT platform, COALESCE(SUM(file_size), 0) as total_size
            FROM file_inventory
            WHERE location IN ('final', 'review')
            GROUP BY platform
        """)
        for row in cursor.fetchall():
            platform = row[0]
            if platform not in platform_data:
                platform_data[platform] = {'count': 0, 'size_bytes': 0}
            platform_data[platform]['size_bytes'] += row[1]
        # Only show platforms with actual files, largest first
        storage_by_platform = []
        for platform in sorted(platform_data.keys(), key=lambda p: platform_data[p]['size_bytes'], reverse=True):
            if platform_data[platform]['size_bytes'] > 0:
                storage_by_platform.append({
                    'platform': platform,
                    'count': platform_data[platform]['count'],
                    'size_bytes': platform_data[platform]['size_bytes'],
                    'size_mb': round(platform_data[platform]['size_bytes'] / 1024 / 1024, 2)
                })
        # Downloads per day (last 30 days) - combine downloads and video_downloads
        cursor.execute("""
            SELECT date, SUM(count) as count FROM (
                SELECT DATE(download_date) as date, COUNT(*) as count
                FROM downloads
                WHERE download_date >= DATE('now', '-30 days')
                GROUP BY DATE(download_date)
                UNION ALL
                SELECT DATE(download_date) as date, COUNT(*) as count
                FROM video_downloads
                WHERE download_date >= DATE('now', '-30 days')
                GROUP BY DATE(download_date)
            ) GROUP BY date ORDER BY date
        """)
        downloads_per_day = [{'date': row[0], 'count': row[1]} for row in cursor.fetchall()]
        # Content type breakdown
        cursor.execute("""
            SELECT
                content_type,
                COUNT(*) as count
            FROM downloads
            WHERE content_type IS NOT NULL
            GROUP BY content_type
            ORDER BY count DESC
        """)
        content_types = {row[0]: row[1] for row in cursor.fetchall()}
        # Top sources
        cursor.execute("""
            SELECT
                source,
                platform,
                COUNT(*) as count
            FROM downloads
            WHERE source IS NOT NULL
            GROUP BY source, platform
            ORDER BY count DESC
            LIMIT 10
        """)
        top_sources = [{'source': row[0], 'platform': row[1], 'count': row[2]} for row in cursor.fetchall()]
        # Total statistics - use file_inventory for accurate file counts
        cursor.execute("""
            SELECT
                (SELECT COUNT(*) FROM file_inventory WHERE location IN ('final', 'review')) as total_downloads,
                (SELECT COALESCE(SUM(file_size), 0) FROM file_inventory WHERE location IN ('final', 'review')) as total_size,
                (SELECT COUNT(DISTINCT source) FROM downloads) +
                (SELECT COUNT(DISTINCT uploader) FROM video_downloads) as unique_sources,
                (SELECT COUNT(DISTINCT platform) FROM file_inventory) as platforms_used
        """)
        totals = cursor.fetchone()
        # Get recycle bin and review counts separately
        cursor.execute("SELECT COUNT(*) FROM recycle_bin")
        recycle_count = cursor.fetchone()[0] or 0
        cursor.execute("SELECT COUNT(*) FROM file_inventory WHERE location = 'review'")
        review_count = cursor.fetchone()[0] or 0
        # Growth rate - combine downloads and video_downloads
        cursor.execute("""
            SELECT
                (SELECT SUM(CASE WHEN download_date >= DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM downloads) +
                (SELECT SUM(CASE WHEN download_date >= DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM video_downloads) as this_week,
                (SELECT SUM(CASE WHEN download_date >= DATE('now', '-14 days') AND download_date < DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM downloads) +
                (SELECT SUM(CASE WHEN download_date >= DATE('now', '-14 days') AND download_date < DATE('now', '-7 days') THEN 1 ELSE 0 END) FROM video_downloads) as last_week
            """)
        growth_row = cursor.fetchone()
        # SUM() over an empty table yields NULL (None in Python), so coalesce
        # before comparing/subtracting — otherwise a fresh install raises
        # "TypeError: '>' not supported between instances of 'NoneType' and 'int'".
        growth_rate = 0
        if growth_row:
            this_week = growth_row[0] or 0
            last_week = growth_row[1] or 0
            if last_week > 0:
                growth_rate = round(((this_week - last_week) / last_week) * 100, 1)
        return {
            'storage_by_platform': storage_by_platform,
            'downloads_per_day': downloads_per_day,
            'content_types': content_types,
            'top_sources': top_sources,
            'totals': {
                'total_downloads': totals[0] or 0,
                'total_size_bytes': totals[1] or 0,
                'total_size_gb': round((totals[1] or 0) / 1024 / 1024 / 1024, 2),
                'unique_sources': totals[2] or 0,
                'platforms_used': totals[3] or 0,
                'recycle_bin_count': recycle_count,
                'review_count': review_count
            },
            'growth_rate': growth_rate
        }
# ============================================================================
# FLARESOLVERR HEALTH CHECK
# ============================================================================
@router.get("/health/flaresolverr")
@limiter.limit("60/minute")
@handle_exceptions
async def check_flaresolverr_health(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Check FlareSolverr health status.

    Resolves the FlareSolverr URL from the settings table (falling back to
    the default localhost endpoint), issues a 'sessions.list' command, and
    returns a dict whose 'status' is one of: 'healthy', 'unhealthy',
    'offline', 'timeout', or 'error'. Connectivity failures never raise;
    they are encoded in the returned 'status'/'error' fields.
    """
    app_state = get_app_state()
    # Default endpoint used when no override is stored in settings.
    flaresolverr_url = "http://localhost:8191/v1"
    try:
        with app_state.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT value FROM settings WHERE key='flaresolverr'")
            result = cursor.fetchone()
            if result:
                flaresolverr_config = json.loads(result[0])
                if 'url' in flaresolverr_config:
                    flaresolverr_url = flaresolverr_config['url']
    except (sqlite3.Error, json.JSONDecodeError, KeyError):
        # Config lookup is best-effort; on any failure keep the default URL.
        pass
    start_time = time.time()
    try:
        # 'sessions.list' is a cheap command that exercises the full API path.
        response = requests.post(
            flaresolverr_url,
            json={"cmd": "sessions.list"},
            timeout=5
        )
        response_time = round((time.time() - start_time) * 1000, 2)
        if response.status_code == 200:
            return {
                'status': 'healthy',
                'url': flaresolverr_url,
                'response_time_ms': response_time,
                'last_check': time.time(),
                'sessions': response.json().get('sessions', [])
            }
        else:
            return {
                'status': 'unhealthy',
                'url': flaresolverr_url,
                'response_time_ms': response_time,
                'last_check': time.time(),
                'error': f"HTTP {response.status_code}: {response.text}"
            }
    except requests.exceptions.ConnectionError:
        return {
            'status': 'offline',
            'url': flaresolverr_url,
            'last_check': time.time(),
            'error': 'Connection refused - FlareSolverr may not be running'
        }
    except requests.exceptions.Timeout:
        return {
            'status': 'timeout',
            'url': flaresolverr_url,
            'last_check': time.time(),
            'error': 'Request timed out after 5 seconds'
        }
    except Exception as e:
        # Catch-all so a health probe can never take the endpoint down.
        return {
            'status': 'error',
            'url': flaresolverr_url,
            'last_check': time.time(),
            'error': str(e)
        }
# ============================================================================
# MONITORING ENDPOINTS
# ============================================================================
@router.get("/monitoring/status")
@limiter.limit("100/minute")
@handle_exceptions
async def get_monitoring_status(
    request: Request,
    hours: int = 24,
    current_user: Dict = Depends(get_current_user)
):
    """Return per-downloader health status for the given time window (hours)."""
    from modules.downloader_monitor import get_monitor
    state = get_app_state()
    downloader_status = get_monitor(state.db, state.settings).get_downloader_status(hours=hours)
    return {
        "success": True,
        "downloaders": downloader_status,
        "window_hours": hours,
    }
@router.get("/monitoring/history")
@limiter.limit("100/minute")
@handle_exceptions
async def get_monitoring_history(
    request: Request,
    downloader: Optional[str] = None,
    limit: int = 100,
    current_user: Dict = Depends(get_current_user)
):
    """Get download monitoring history.

    Args:
        downloader: Optional downloader name to filter by; all when omitted.
        limit: Maximum number of rows returned (newest first).

    Returns:
        {"success": True, "history": [...]} with one dict per monitor entry.
    """
    app_state = get_app_state()
    with app_state.db.get_connection() as conn:
        cursor = conn.cursor()
        if downloader:
            cursor.execute("""
                SELECT
                    id, downloader, username, timestamp, success,
                    file_count, error_message, alert_sent
                FROM download_monitor
                WHERE downloader = ?
                ORDER BY timestamp DESC
                LIMIT ?
            """, (downloader, limit))
        else:
            cursor.execute("""
                SELECT
                    id, downloader, username, timestamp, success,
                    file_count, error_message, alert_sent
                FROM download_monitor
                ORDER BY timestamp DESC
                LIMIT ?
            """, (limit,))
        history = []
        # NOTE(review): row['id'] access assumes the connection returns
        # dict-style rows (e.g. sqlite3.Row) — confirm in UnifiedDatabase.
        for row in cursor.fetchall():
            history.append({
                'id': row['id'],
                'downloader': row['downloader'],
                'username': row['username'],
                'timestamp': row['timestamp'],
                'success': bool(row['success']),
                'file_count': row['file_count'],
                'error_message': row['error_message'],
                'alert_sent': bool(row['alert_sent'])
            })
    return {
        "success": True,
        "history": history
    }
@router.delete("/monitoring/history")
@limiter.limit("10/minute")
@handle_exceptions
async def clear_monitoring_history(
    request: Request,
    days: int = 30,
    current_user: Dict = Depends(require_admin)
):
    """Purge monitoring log entries older than the given number of days."""
    from modules.downloader_monitor import get_monitor
    state = get_app_state()
    get_monitor(state.db, state.settings).clear_old_logs(days=days)
    return {
        "success": True,
        "message": f"Cleared logs older than {days} days",
    }
# ============================================================================
# SETTINGS ENDPOINTS
# ============================================================================
@router.get("/settings/{key}")
@limiter.limit("60/minute")
@handle_exceptions
async def get_setting(
    request: Request,
    key: str,
    current_user: Dict = Depends(get_current_user)
):
    """Fetch a single setting by key; raises NotFoundError when absent."""
    stored = get_app_state().settings.get(key)
    if stored is None:
        raise NotFoundError(f"Setting '{key}' not found")
    return stored
@router.put("/settings/{key}")
@limiter.limit("30/minute")
@handle_exceptions
async def update_setting(
    request: Request,
    key: str,
    body: Dict,
    current_user: Dict = Depends(get_current_user)
):
    """Store a new value (plus optional category/description) for a setting."""
    new_value = body.get('value')
    if new_value is None:
        raise ValidationError("Missing 'value' in request body")
    state = get_app_state()
    state.settings.set(
        key=key,
        value=new_value,
        category=body.get('category'),
        description=body.get('description'),
        updated_by=current_user.get('username', 'user'),
    )
    return {
        "success": True,
        "message": f"Setting '{key}' updated successfully",
    }
# ============================================================================
# IMMICH INTEGRATION
# ============================================================================
@router.post("/immich/scan")
@limiter.limit("10/minute")
@handle_exceptions
async def trigger_immich_scan(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Ask the configured Immich server to rescan the configured library."""
    config = get_app_state().settings.get('immich', {})
    if not config.get('enabled'):
        return {
            "success": False,
            "message": "Immich integration is not enabled"
        }
    api_url = config.get('api_url')
    api_key = config.get('api_key')
    library_id = config.get('library_id')
    # All three values are required to build the scan request.
    if not (api_url and api_key and library_id):
        return {
            "success": False,
            "message": "Immich configuration incomplete (missing api_url, api_key, or library_id)"
        }
    try:
        response = requests.post(
            f"{api_url}/libraries/{library_id}/scan",
            headers={'X-API-KEY': api_key},
            timeout=10
        )
        if response.status_code in (200, 201, 204):
            return {
                "success": True,
                "message": f"Successfully triggered Immich scan for library {library_id}"
            }
        return {
            "success": False,
            "message": f"Immich scan request failed with status {response.status_code}: {response.text}"
        }
    except requests.exceptions.RequestException as e:
        return {
            "success": False,
            "message": f"Failed to connect to Immich: {str(e)}"
        }
# ============================================================================
# ERROR MONITORING SETTINGS
# ============================================================================
class ErrorMonitoringSettings(BaseModel):
    """Configuration for error monitoring and alerting."""
    enabled: bool = True  # master switch for error monitoring
    push_alert_enabled: bool = True  # send push alerts for errors
    push_alert_delay_hours: int = 24  # hours to wait before alerting
    dashboard_banner_enabled: bool = True  # show error banner on dashboard
    retention_days: int = 7  # days to keep error records
@router.get("/error-monitoring/settings")
@limiter.limit("60/minute")
@handle_exceptions
async def get_error_monitoring_settings(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Return error-monitoring settings, falling back to defaults when unset."""
    defaults = {
        'enabled': True,
        'push_alert_enabled': True,
        'push_alert_delay_hours': 24,
        'dashboard_banner_enabled': True,
        'retention_days': 7,
    }
    return get_app_state().settings.get('error_monitoring', defaults)
@router.put("/error-monitoring/settings")
@limiter.limit("30/minute")
@handle_exceptions
async def update_error_monitoring_settings(
    request: Request,
    settings: ErrorMonitoringSettings,
    current_user: Dict = Depends(get_current_user)
):
    """Persist new error-monitoring settings and echo them back."""
    payload = settings.model_dump()
    get_app_state().settings.set(
        key='error_monitoring',
        value=payload,
        category='monitoring',
        description='Error monitoring and alert settings',
        updated_by=current_user.get('username', 'user'),
    )
    return {
        "success": True,
        "message": "Error monitoring settings updated",
        "settings": payload,
    }

1617
web/backend/routers/video.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
# Services module exports

524
web/backend/totp_manager.py Normal file
View File

@@ -0,0 +1,524 @@
#!/usr/bin/env python3
"""
TOTP Manager for Media Downloader
Handles Time-based One-Time Password (TOTP) operations for 2FA
Based on backup-central's implementation
"""
import sys
import pyotp
import qrcode
import io
import base64
import sqlite3
import hashlib
import bcrypt
import secrets
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Tuple
from cryptography.fernet import Fernet
from pathlib import Path
# Add parent path to allow imports from modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from modules.universal_logger import get_logger
logger = get_logger('TOTPManager')
class TOTPManager:
    """Manages TOTP (RFC 6238) two-factor authentication for users.

    Responsibilities:
    - Secret generation with QR-code provisioning URIs
    - Token verification with a small time-skew window
    - Encrypted secret storage (Fernet) and backup-code management (bcrypt)
    - Per-user/IP rate limiting with lockout, plus an audit log
    """

    def __init__(self, db_path: str = None):
        """Initialize the manager against the auth database at db_path."""
        if db_path is None:
            db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')
        self.db_path = db_path
        self.issuer = 'Media Downloader'
        self.window = 1  # Allow 1 step (30s) tolerance for time skew
        # Encryption key for TOTP secrets
        self.encryption_key = self._get_encryption_key()
        self.cipher_suite = Fernet(self.encryption_key)
        # Rate limiting configuration
        self.max_attempts = 5
        self.rate_limit_window = timedelta(minutes=15)
        self.lockout_duration = timedelta(minutes=30)
        # Initialize database tables
        self._init_database()

    def _get_encryption_key(self) -> bytes:
        """Get or generate encryption key for TOTP secrets"""
        key_file = Path(__file__).parent.parent.parent / '.totp_encryption_key'
        if key_file.exists():
            with open(key_file, 'rb') as f:
                return f.read()
        # Generate new key
        key = Fernet.generate_key()
        try:
            with open(key_file, 'wb') as f:
                f.write(key)
            import os
            os.chmod(key_file, 0o600)  # owner-only: the key decrypts all secrets
        except Exception as e:
            # Key persists only in memory for this process; secrets encrypted
            # now will be unreadable after restart — surfaced as a warning.
            logger.warning(f"Could not save TOTP encryption key: {e}", module="TOTP")
        return key

    def _init_database(self):
        """Initialize TOTP-related database tables"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Ensure TOTP columns exist in users table
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN totp_secret TEXT")
            except sqlite3.OperationalError:
                pass  # Column already exists
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN totp_enabled INTEGER NOT NULL DEFAULT 0")
            except sqlite3.OperationalError:
                pass
            try:
                cursor.execute("ALTER TABLE users ADD COLUMN totp_enrolled_at TEXT")
            except sqlite3.OperationalError:
                pass
            # Backup codes table — read/written by enable_totp,
            # verify_backup_code and regenerate_backup_codes. Previously this
            # table was never created here, so a fresh database raised
            # sqlite3.OperationalError on first backup-code operation.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS backup_codes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    code_hash TEXT NOT NULL,
                    used INTEGER NOT NULL DEFAULT 0,
                    used_at TEXT,
                    created_at TEXT NOT NULL
                )
            """)
            # TOTP rate limiting table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS totp_rate_limit (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    ip_address TEXT,
                    attempts INTEGER NOT NULL DEFAULT 1,
                    window_start TEXT NOT NULL,
                    locked_until TEXT,
                    UNIQUE(username, ip_address)
                )
            """)
            # TOTP audit log
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS totp_audit_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL,
                    action TEXT NOT NULL,
                    success INTEGER NOT NULL,
                    ip_address TEXT,
                    user_agent TEXT,
                    details TEXT,
                    timestamp TEXT NOT NULL
                )
            """)
            conn.commit()

    def generate_secret(self, username: str) -> Dict:
        """
        Generate a new TOTP secret and QR code for a user
        Returns:
            dict with secret, qrCodeDataURL, manualEntryKey, otpauthURL
        """
        try:
            # Generate secret
            secret = pyotp.random_base32()
            # Create provisioning URI
            totp = pyotp.TOTP(secret)
            provisioning_uri = totp.provisioning_uri(
                name=username,
                issuer_name=self.issuer
            )
            # Generate QR code
            qr = qrcode.QRCode(version=1, box_size=10, border=4)
            qr.add_data(provisioning_uri)
            qr.make(fit=True)
            img = qr.make_image(fill_color="black", back_color="white")
            # Convert to base64 data URL
            buffer = io.BytesIO()
            img.save(buffer, format='PNG')
            img_str = base64.b64encode(buffer.getvalue()).decode()
            qr_code_data_url = f"data:image/png;base64,{img_str}"
            self._log_audit(username, 'totp_secret_generated', True, None, None,
                          'TOTP secret generated for user')
            return {
                'secret': secret,
                'qrCodeDataURL': qr_code_data_url,
                'manualEntryKey': secret,
                'otpauthURL': provisioning_uri
            }
        except Exception as error:
            self._log_audit(username, 'totp_secret_generation_failed', False, None, None,
                          f'Error: {str(error)}')
            raise

    def verify_token(self, secret: str, token: str) -> bool:
        """
        Verify a TOTP token
        Args:
            secret: Base32 encoded secret
            token: 6-digit TOTP code
        Returns:
            True if valid, False otherwise
        """
        try:
            totp = pyotp.TOTP(secret)
            return totp.verify(token, valid_window=self.window)
        except Exception as e:
            logger.error(f"TOTP verification error: {e}", module="TOTP")
            return False

    def encrypt_secret(self, secret: str) -> str:
        """Encrypt TOTP secret for storage"""
        encrypted = self.cipher_suite.encrypt(secret.encode())
        return base64.b64encode(encrypted).decode()

    def decrypt_secret(self, encrypted_secret: str) -> str:
        """Decrypt TOTP secret from storage"""
        encrypted = base64.b64decode(encrypted_secret.encode())
        decrypted = self.cipher_suite.decrypt(encrypted)
        return decrypted.decode()

    def generate_backup_codes(self, count: int = 10) -> List[str]:
        """
        Generate backup codes for recovery
        Args:
            count: Number of codes to generate
        Returns:
            List of formatted backup codes (XXXX-XXXX hex)
        """
        codes = []
        for _ in range(count):
            # Generate 8-character hex code
            code = secrets.token_hex(4).upper()
            # Format as XXXX-XXXX
            formatted = f"{code[:4]}-{code[4:8]}"
            codes.append(formatted)
        return codes

    def hash_backup_code(self, code: str) -> str:
        """
        Hash a backup code for storage using bcrypt
        Args:
            code: Backup code to hash
        Returns:
            Bcrypt hash of the code
        """
        # Use bcrypt with default work factor (12 rounds)
        return bcrypt.hashpw(code.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')

    def verify_backup_code(self, code: str, username: str) -> Tuple[bool, int]:
        """
        Verify a backup code and mark it as used
        Returns:
            Tuple of (valid, remaining_codes)
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Get unused backup codes for user
            cursor.execute("""
                SELECT id, code_hash FROM backup_codes
                WHERE username = ? AND used = 0
            """, (username,))
            rows = cursor.fetchall()
            # Check if code matches any unused code
            for row_id, stored_hash in rows:
                # Support both bcrypt (new) and SHA256 (legacy) hashes
                is_match = False
                if stored_hash.startswith('$2b$'):
                    # Bcrypt hash
                    try:
                        is_match = bcrypt.checkpw(code.encode('utf-8'), stored_hash.encode('utf-8'))
                    except Exception as e:
                        logger.error(f"Error verifying bcrypt backup code: {e}", module="TOTP")
                        continue
                else:
                    # Legacy SHA256 hash
                    legacy_hash = hashlib.sha256(code.encode()).hexdigest()
                    is_match = (stored_hash == legacy_hash)
                if is_match:
                    # Mark as used — each backup code is single-use
                    cursor.execute("""
                        UPDATE backup_codes
                        SET used = 1, used_at = ?
                        WHERE id = ?
                    """, (datetime.now().isoformat(), row_id))
                    conn.commit()
                    # Count remaining codes
                    cursor.execute("""
                        SELECT COUNT(*) FROM backup_codes
                        WHERE username = ? AND used = 0
                    """, (username,))
                    remaining = cursor.fetchone()[0]
                    return True, remaining
            return False, len(rows)

    def enable_totp(self, username: str, secret: str, backup_codes: List[str]) -> bool:
        """
        Enable TOTP for a user
        Args:
            username: Username
            secret: TOTP secret (stored encrypted)
            backup_codes: List of backup codes (stored hashed)
        Returns:
            True if successful
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                # Encrypt secret
                encrypted_secret = self.encrypt_secret(secret)
                # Update user
                cursor.execute("""
                    UPDATE users
                    SET totp_secret = ?, totp_enabled = 1, totp_enrolled_at = ?
                    WHERE username = ?
                """, (encrypted_secret, datetime.now().isoformat(), username))
                # Delete old backup codes
                cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
                # Insert new backup codes
                for code in backup_codes:
                    code_hash = self.hash_backup_code(code)
                    cursor.execute("""
                        INSERT INTO backup_codes (username, code_hash, created_at)
                        VALUES (?, ?, ?)
                    """, (username, code_hash, datetime.now().isoformat()))
                conn.commit()
            self._log_audit(username, 'totp_enabled', True, None, None,
                          f'TOTP enabled with {len(backup_codes)} backup codes')
            return True
        except Exception as e:
            logger.error(f"Error enabling TOTP: {e}", module="TOTP")
            return False

    def disable_totp(self, username: str) -> bool:
        """
        Disable TOTP for a user
        Returns:
            True if successful
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    UPDATE users
                    SET totp_secret = NULL, totp_enabled = 0, totp_enrolled_at = NULL
                    WHERE username = ?
                """, (username,))
                # Delete backup codes
                cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
                conn.commit()
            self._log_audit(username, 'totp_disabled', True, None, None, 'TOTP disabled')
            return True
        except Exception as e:
            logger.error(f"Error disabling TOTP: {e}", module="TOTP")
            return False

    def get_totp_status(self, username: str) -> Dict:
        """
        Get TOTP status for a user
        Returns:
            dict with enabled, enrolledAt
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT totp_enabled, totp_enrolled_at FROM users
                WHERE username = ?
            """, (username,))
            row = cursor.fetchone()
            if not row:
                return {'enabled': False, 'enrolledAt': None}
            return {
                'enabled': bool(row[0]),
                'enrolledAt': row[1]
            }

    def check_rate_limit(self, username: str, ip_address: str) -> Tuple[bool, int, Optional[str]]:
        """
        Check and enforce rate limiting for TOTP verification
        Returns:
            Tuple of (allowed, attempts_remaining, locked_until)
        """
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            now = datetime.now()
            window_start = now - self.rate_limit_window
            cursor.execute("""
                SELECT attempts, window_start, locked_until
                FROM totp_rate_limit
                WHERE username = ? AND ip_address = ?
            """, (username, ip_address))
            row = cursor.fetchone()
            if not row:
                # First attempt - create record
                cursor.execute("""
                    INSERT INTO totp_rate_limit (username, ip_address, attempts, window_start)
                    VALUES (?, ?, 1, ?)
                """, (username, ip_address, now.isoformat()))
                conn.commit()
                return True, self.max_attempts - 1, None
            attempts, stored_window_start, locked_until = row
            stored_window_start = datetime.fromisoformat(stored_window_start)
            # Check if locked out
            if locked_until:
                locked_until_dt = datetime.fromisoformat(locked_until)
                if locked_until_dt > now:
                    return False, 0, locked_until
                else:
                    # Lockout expired - reset
                    cursor.execute("""
                        DELETE FROM totp_rate_limit
                        WHERE username = ? AND ip_address = ?
                    """, (username, ip_address))
                    conn.commit()
                    return True, self.max_attempts, None
            # Check if window expired
            if stored_window_start < window_start:
                # Reset window
                cursor.execute("""
                    UPDATE totp_rate_limit
                    SET attempts = 1, window_start = ?, locked_until = NULL
                    WHERE username = ? AND ip_address = ?
                """, (now.isoformat(), username, ip_address))
                conn.commit()
                return True, self.max_attempts - 1, None
            # Increment attempts
            new_attempts = attempts + 1
            if new_attempts >= self.max_attempts:
                # Lock out
                locked_until = (now + self.lockout_duration).isoformat()
                cursor.execute("""
                    UPDATE totp_rate_limit
                    SET attempts = ?, locked_until = ?
                    WHERE username = ? AND ip_address = ?
                """, (new_attempts, locked_until, username, ip_address))
                conn.commit()
                return False, 0, locked_until
            else:
                cursor.execute("""
                    UPDATE totp_rate_limit
                    SET attempts = ?
                    WHERE username = ? AND ip_address = ?
                """, (new_attempts, username, ip_address))
                conn.commit()
                return True, self.max_attempts - new_attempts, None

    def reset_rate_limit(self, username: str, ip_address: str):
        """Reset rate limit after successful authentication"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                DELETE FROM totp_rate_limit
                WHERE username = ? AND ip_address = ?
            """, (username, ip_address))
            conn.commit()

    def regenerate_backup_codes(self, username: str) -> List[str]:
        """
        Regenerate backup codes for a user (invalidates all existing codes)
        Returns:
            List of new backup codes
        """
        # Generate new codes
        new_codes = self.generate_backup_codes(10)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            # Delete old codes
            cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
            # Insert new codes
            for code in new_codes:
                code_hash = self.hash_backup_code(code)
                cursor.execute("""
                    INSERT INTO backup_codes (username, code_hash, created_at)
                    VALUES (?, ?, ?)
                """, (username, code_hash, datetime.now().isoformat()))
            conn.commit()
        self._log_audit(username, 'backup_codes_regenerated', True, None, None,
                      f'Generated {len(new_codes)} new backup codes')
        return new_codes

    def _log_audit(self, username: str, action: str, success: bool,
                   ip_address: Optional[str], user_agent: Optional[str],
                   details: Optional[str] = None):
        """Log TOTP audit event (best-effort; never raises)"""
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO totp_audit_log
                    (username, action, success, ip_address, user_agent, details, timestamp)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (username, action, int(success), ip_address, user_agent, details,
                      datetime.now().isoformat()))
                conn.commit()
        except Exception as e:
            logger.error(f"Error logging TOTP audit: {e}", module="TOTP")

792
web/backend/twofa_routes.py Normal file
View File

@@ -0,0 +1,792 @@
#!/usr/bin/env python3
"""
Two-Factor Authentication API Routes for Media Downloader
Handles TOTP, Passkey, and Duo 2FA endpoints
Based on backup-central's implementation
"""
import sys
from pathlib import Path
from fastapi import APIRouter, Depends, Request, Body
from fastapi.responses import JSONResponse
from typing import Optional, Dict, List
from pydantic import BaseModel
from core.config import settings
# Add parent path to allow imports from modules
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from modules.universal_logger import get_logger
logger = get_logger('TwoFactorAuth')
# Import managers
from web.backend.totp_manager import TOTPManager
from web.backend.duo_manager import DuoManager
# Passkey manager - optional (requires compatible webauthn library)
try:
from web.backend.passkey_manager import PasskeyManager
PASSKEY_AVAILABLE = True
logger.info("Passkey support enabled", module="2FA")
except ImportError as e:
logger.warning(f"Passkey support disabled: {e}", module="2FA")
PasskeyManager = None
PASSKEY_AVAILABLE = False
# Pydantic models for request/response
# Standard response models
class StandardResponse(BaseModel):
    """Standard API response with success flag and message"""
    success: bool
    message: str

class TOTPSetupResponse(BaseModel):
    """Response to /totp/setup: secret material for authenticator enrollment."""
    success: bool
    secret: Optional[str] = None          # TOTP secret (populated on success only)
    qrCodeDataURL: Optional[str] = None   # QR code as a data: URL for scanning
    manualEntryKey: Optional[str] = None  # secret formatted for manual entry
    message: Optional[str] = None

class TOTPVerifyRequest(BaseModel):
    """Body for /totp/verify: the 6-digit code from the authenticator app."""
    code: str

class TOTPVerifyResponse(BaseModel):
    """Response to /totp/verify; backupCodes is returned once, on enablement."""
    success: bool
    backupCodes: Optional[List[str]] = None
    message: str

class TOTPLoginVerifyRequest(BaseModel):
    """Body for /totp/login/verify (second factor during login)."""
    username: str
    code: str
    useBackupCode: Optional[bool] = False  # True when `code` is a backup code
    rememberMe: bool = False

class TOTPDisableRequest(BaseModel):
    """Body for /totp/disable: requires password plus a current TOTP code."""
    password: str
    code: str

class PasskeyRegistrationOptionsResponse(BaseModel):
    """WebAuthn credential-creation options handed to the browser."""
    success: bool
    options: Optional[Dict] = None
    message: Optional[str] = None

class PasskeyVerifyRegistrationRequest(BaseModel):
    """Body for /passkey/verify-registration: attestation from the browser."""
    credential: Dict
    deviceName: Optional[str] = None

class PasskeyAuthenticationOptionsResponse(BaseModel):
    """WebAuthn credential-request options handed to the browser."""
    success: bool
    options: Optional[Dict] = None
    message: Optional[str] = None

class PasskeyVerifyAuthenticationRequest(BaseModel):
    """Body for /passkey/verify-authentication: assertion from the browser."""
    credential: Dict
    username: Optional[str] = None  # optional hint; presumably omitted for discoverable credentials — TODO confirm
    rememberMe: bool = False

class DuoAuthRequest(BaseModel):
    """Body for /duo/auth: start a Duo OAuth flow for this user."""
    username: str
    rememberMe: bool = False

class DuoAuthResponse(BaseModel):
    """Response to /duo/auth: redirect URL plus opaque state token."""
    success: bool
    authUrl: Optional[str] = None
    state: Optional[str] = None
    message: Optional[str] = None

class DuoCallbackRequest(BaseModel):
    """Body for the Duo callback: state plus the authorization code."""
    state: str
    duo_code: Optional[str] = None
def create_2fa_router(auth_manager, get_current_user_dependency):
    """
    Create and configure the 2FA router

    Wires up TOTP, passkey (WebAuthn) and Duo endpoints. Passkey routes
    degrade gracefully when the webauthn import failed (PASSKEY_AVAILABLE
    is False).

    Args:
        auth_manager: AuthManager instance
        get_current_user_dependency: FastAPI dependency for authentication

    Returns:
        Configured APIRouter
    """
    router = APIRouter()
    # Initialize managers
    totp_manager = TOTPManager()
    duo_manager = DuoManager()
    # Passkeys are optional; keep a None manager when the import failed.
    passkey_manager = PasskeyManager() if PASSKEY_AVAILABLE else None
# ============================================================================
# TOTP Endpoints
# ============================================================================
@router.post("/totp/setup", response_model=TOTPSetupResponse)
async def totp_setup(request: Request, current_user: Dict = Depends(get_current_user_dependency)):
"""Generate TOTP secret and QR code"""
try:
username = current_user.get('sub')
# Check if already enabled
status = totp_manager.get_totp_status(username)
if status['enabled']:
return TOTPSetupResponse(
success=False,
message='2FA is already enabled for this account'
)
# Generate secret and QR code
result = totp_manager.generate_secret(username)
# Store in session temporarily (until verified)
# Note: In production, use proper session management
request.session['totp_setup'] = {
'secret': result['secret'],
'username': username
}
return TOTPSetupResponse(
success=True,
secret=result['secret'],
qrCodeDataURL=result['qrCodeDataURL'],
manualEntryKey=result['manualEntryKey'],
message='Scan QR code with your authenticator app'
)
except Exception as e:
logger.error(f"TOTP setup error: {e}", module="2FA")
return TOTPSetupResponse(success=False, message='Failed to generate 2FA setup')
    @router.post("/totp/verify", response_model=TOTPVerifyResponse)
    async def totp_verify(
        verify_data: TOTPVerifyRequest,
        request: Request,
        current_user: Dict = Depends(get_current_user_dependency)
    ):
        """
        Verify TOTP code and enable 2FA.

        Completes the enrollment started by /totp/setup: checks the code
        against the secret parked in the session, then persists the secret
        and returns one-time backup codes.
        """
        try:
            username = current_user.get('sub')
            code = verify_data.code
            # Validate code format before touching the session or database.
            if not code or len(code) != 6 or not code.isdigit():
                return TOTPVerifyResponse(
                    success=False,
                    message='Invalid code format. Must be 6 digits.'
                )
            # Get temporary secret from session; it must belong to the
            # currently authenticated user.
            totp_setup = request.session.get('totp_setup')
            if not totp_setup or totp_setup.get('username') != username:
                return TOTPVerifyResponse(
                    success=False,
                    message='No setup in progress. Please start 2FA setup again.'
                )
            secret = totp_setup['secret']
            # Verify the code against the pending (not yet stored) secret.
            if not totp_manager.verify_token(secret, code):
                return TOTPVerifyResponse(
                    success=False,
                    message='Invalid verification code. Please try again.'
                )
            # Generate backup codes
            backup_codes = totp_manager.generate_backup_codes(10)
            # Enable TOTP (persists the secret and the backup-code hashes)
            totp_manager.enable_totp(username, secret, backup_codes)
            # Clear session so the temporary secret cannot be reused.
            request.session.pop('totp_setup', None)
            return TOTPVerifyResponse(
                success=True,
                backupCodes=backup_codes,
                message='2FA enabled successfully! Save your backup codes.'
            )
        except Exception as e:
            logger.error(f"TOTP verify error: {e}", module="2FA")
            return TOTPVerifyResponse(success=False, message='Failed to verify code')
    @router.post("/totp/disable")
    async def totp_disable(
        disable_data: TOTPDisableRequest,
        request: Request,
        current_user: Dict = Depends(get_current_user_dependency)
    ):
        """
        Disable TOTP 2FA.

        Requires re-authentication: the account password AND a currently
        valid TOTP code (when a stored secret exists).
        """
        try:
            username = current_user.get('sub')
            password = disable_data.password
            code = disable_data.code
            # Verify password
            user_info = auth_manager.get_user(username)
            if not user_info:
                return {'success': False, 'message': 'User not found'}
            # Get user password hash from database
            import sqlite3
            with sqlite3.connect(auth_manager.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT password_hash, totp_secret FROM users WHERE username = ?", (username,))
                row = cursor.fetchone()
                if not row:
                    return {'success': False, 'message': 'User not found'}
                password_hash, encrypted_secret = row
                if not auth_manager.verify_password(password, password_hash):
                    return {'success': False, 'message': 'Invalid password'}
                # Verify current TOTP code
                # NOTE(review): when no stored secret exists the code check is
                # skipped and 2FA is disabled on password alone — confirm this
                # is intended (e.g. for accounts in a half-enrolled state).
                if encrypted_secret:
                    decrypted_secret = totp_manager.decrypt_secret(encrypted_secret)
                    if not totp_manager.verify_token(decrypted_secret, code):
                        return {'success': False, 'message': 'Invalid verification code'}
            # Disable TOTP
            totp_manager.disable_totp(username)
            return {'success': True, 'message': '2FA has been disabled'}
        except Exception as e:
            logger.error(f"TOTP disable error: {e}", module="2FA")
            return {'success': False, 'message': 'Failed to disable 2FA'}
    @router.post("/totp/regenerate-backup-codes")
    async def totp_regenerate_backup_codes(
        code: str = Body(..., embed=True),
        current_user: Dict = Depends(get_current_user_dependency)
    ):
        """
        Regenerate backup codes.

        Requires a currently valid TOTP code; old backup codes are
        invalidated by the manager.
        """
        try:
            username = current_user.get('sub')
            # Get user
            import sqlite3
            with sqlite3.connect(auth_manager.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT totp_enabled, totp_secret FROM users WHERE username = ?", (username,))
                row = cursor.fetchone()
                if not row or not row[0]:
                    return {'success': False, 'message': '2FA is not enabled'}
                # Verify current code
                # NOTE(review): as in /totp/disable, the code check is skipped
                # when no stored secret exists — confirm that is intended.
                encrypted_secret = row[1]
                if encrypted_secret:
                    decrypted_secret = totp_manager.decrypt_secret(encrypted_secret)
                    if not totp_manager.verify_token(decrypted_secret, code):
                        return {'success': False, 'message': 'Invalid verification code'}
            # Generate new backup codes
            backup_codes = totp_manager.regenerate_backup_codes(username)
            return {
                'success': True,
                'backupCodes': backup_codes,
                'message': 'New backup codes generated'
            }
        except Exception as e:
            logger.error(f"Backup codes regeneration error: {e}", module="2FA")
            return {'success': False, 'message': 'Failed to regenerate backup codes'}
@router.get("/totp/status")
async def totp_status(current_user: Dict = Depends(get_current_user_dependency)):
"""Get TOTP status for current user"""
try:
username = current_user.get('sub')
status = totp_manager.get_totp_status(username)
return {'success': True, **status}
except Exception as e:
logger.error(f"TOTP status error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to get 2FA status'}
    @router.post("/totp/login/verify")
    async def totp_login_verify(verify_data: TOTPLoginVerifyRequest, request: Request):
        """
        Verify TOTP code during login (second factor).

        Unauthenticated endpoint: assumes the first factor (password) was
        already verified by the login flow that sent the user here. On
        success creates a session and sets the auth_token cookie.
        """
        try:
            username = verify_data.username
            code = verify_data.code
            use_backup_code = verify_data.useBackupCode
            # Get client IP
            ip_address = request.client.host if request.client else None
            # Check rate limit before doing any verification work.
            allowed, attempts_remaining, locked_until = totp_manager.check_rate_limit(username, ip_address)
            if not allowed:
                return {
                    'success': False,
                    'message': f'Too many failed attempts. Locked until {locked_until}',
                    'lockedUntil': locked_until
                }
            # Get user
            import sqlite3
            with sqlite3.connect(auth_manager.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    SELECT password_hash, totp_enabled, totp_secret, role, email
                    FROM users WHERE username = ?
                """, (username,))
                row = cursor.fetchone()
                # Generic message: do not reveal whether the user exists or
                # has TOTP enabled.
                if not row or not row[1]:  # Not TOTP enabled
                    return {'success': False, 'message': 'Invalid request'}
                # password_hash is fetched but not re-checked here (first
                # factor already done); totp_enabled/email likewise unused.
                password_hash, totp_enabled, encrypted_secret, role, email = row
            is_valid = False
            # Verify backup code or TOTP
            if use_backup_code:
                # `remaining` (unused) is the count of backup codes left.
                is_valid, remaining = totp_manager.verify_backup_code(code, username)
            else:
                if encrypted_secret:
                    decrypted_secret = totp_manager.decrypt_secret(encrypted_secret)
                    is_valid = totp_manager.verify_token(decrypted_secret, code)
            if not is_valid:
                # NOTE(review): the -1 assumes this failed attempt is recorded
                # inside the manager (check_rate_limit/verify helpers) — confirm,
                # otherwise attemptsRemaining drifts from the real counter.
                return {
                    'success': False,
                    'message': 'Invalid verification code',
                    'attemptsRemaining': attempts_remaining - 1
                }
            # Success - reset rate limit and create session
            totp_manager.reset_rate_limit(username, ip_address)
            # Create session with remember_me setting
            session_result = auth_manager._create_session(username, role, ip_address, verify_data.rememberMe)
            # Create response with cookie; session cookie (no max_age) unless
            # the user asked to be remembered for 30 days.
            response = JSONResponse(content=session_result)
            max_age = 30 * 24 * 60 * 60 if verify_data.rememberMe else None
            response.set_cookie(
                key="auth_token",
                value=session_result.get('token'),
                max_age=max_age,
                httponly=True,
                secure=settings.SECURE_COOKIES,
                samesite="lax",
                path="/"
            )
            return response
        except Exception as e:
            logger.error(f"TOTP login verification error: {e}", module="2FA")
            return {'success': False, 'message': 'Verification failed'}
# ============================================================================
# Passkey Endpoints
# ============================================================================
@router.post("/passkey/registration-options", response_model=PasskeyRegistrationOptionsResponse)
async def passkey_registration_options(current_user: Dict = Depends(get_current_user_dependency)):
"""Generate passkey registration options"""
if not PASSKEY_AVAILABLE:
return PasskeyRegistrationOptionsResponse(
success=False,
message='Passkey support is not available on this server'
)
try:
username = current_user.get('sub')
email = current_user.get('email')
options = passkey_manager.generate_registration_options(username, email)
return PasskeyRegistrationOptionsResponse(
success=True,
options=options
)
except Exception as e:
logger.error(f"Passkey registration options error: {e}", module="2FA")
return PasskeyRegistrationOptionsResponse(
success=False,
message='Failed to generate registration options. Please try again.'
)
@router.post("/passkey/verify-registration")
async def passkey_verify_registration(
verify_data: PasskeyVerifyRegistrationRequest,
current_user: Dict = Depends(get_current_user_dependency)
):
"""Verify passkey registration"""
if not PASSKEY_AVAILABLE:
return {'success': False, 'message': 'Passkey support is not available on this server'}
try:
username = current_user.get('sub')
result = passkey_manager.verify_registration(
username,
verify_data.credential,
verify_data.deviceName
)
return {'success': True, **result}
except Exception as e:
logger.error(f"Passkey registration verification error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to verify passkey registration. Please try again.'}
@router.post("/passkey/authentication-options", response_model=PasskeyAuthenticationOptionsResponse)
async def passkey_authentication_options(username: Optional[str] = Body(None, embed=True)):
"""Generate passkey authentication options"""
if not PASSKEY_AVAILABLE:
return PasskeyAuthenticationOptionsResponse(success=False, message='Passkey support is not available on this server')
try:
options = passkey_manager.generate_authentication_options(username)
return PasskeyAuthenticationOptionsResponse(
success=True,
options=options
)
except Exception as e:
logger.error(f"Passkey authentication options error: {e}", module="2FA")
return PasskeyAuthenticationOptionsResponse(
success=False,
message='Failed to generate authentication options. Please try again.'
)
@router.post("/passkey/verify-authentication")
async def passkey_verify_authentication(
verify_data: PasskeyVerifyAuthenticationRequest,
request: Request
):
"""Verify passkey authentication"""
if not PASSKEY_AVAILABLE:
return {'success': False, 'message': 'Passkey support is not available on this server'}
try:
# Get client IP
ip_address = request.client.host if request.client else None
# Verify authentication - use username from request if provided
result = passkey_manager.verify_authentication(verify_data.username, verify_data.credential)
if result['success']:
username = result['username']
# Get user info
import sqlite3
with sqlite3.connect(auth_manager.db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT role, email FROM users WHERE username = ?", (username,))
row = cursor.fetchone()
if row:
role, email = row
# Create session with remember_me setting
session_result = auth_manager._create_session(username, role, ip_address, verify_data.rememberMe)
# Create response with cookie
response = JSONResponse(content=session_result)
max_age = 30 * 24 * 60 * 60 if verify_data.rememberMe else None
response.set_cookie(
key="auth_token",
value=session_result.get('token'),
max_age=max_age,
httponly=True,
secure=settings.SECURE_COOKIES,
samesite="lax",
path="/"
)
return response
return result
except Exception as e:
logger.error(f"Passkey authentication verification error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to verify passkey authentication. Please try again.'}
@router.get("/passkey/list")
async def passkey_list(current_user: Dict = Depends(get_current_user_dependency)):
"""List user's passkeys"""
if not PASSKEY_AVAILABLE:
return {'success': False, 'message': 'Passkey support is not available on this server'}
try:
username = current_user.get('sub')
credentials = passkey_manager.list_credentials(username)
return {'success': True, 'credentials': credentials}
except Exception as e:
logger.error(f"Passkey list error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to retrieve passkeys. Please try again.'}
@router.delete("/passkey/{credential_id}")
async def passkey_remove(
credential_id: str,
current_user: Dict = Depends(get_current_user_dependency)
):
"""Remove a passkey"""
if not PASSKEY_AVAILABLE:
return {'success': False, 'message': 'Passkey support is not available on this server'}
try:
username = current_user.get('sub')
logger.debug(f"Removing passkey - Username: {username}, Credential ID: {credential_id}", module="2FA")
logger.debug(f"Credential ID length: {len(credential_id)}, type: {type(credential_id)}", module="2FA")
success = passkey_manager.remove_credential(username, credential_id)
if success:
return {'success': True, 'message': 'Passkey removed'}
else:
return {'success': False, 'message': 'Passkey not found'}
except Exception as e:
logger.error(f"Passkey remove error: {e}", exc_info=True, module="2FA")
return {'success': False, 'message': 'Failed to remove passkey. Please try again.'}
@router.get("/passkey/status")
async def passkey_status(current_user: Dict = Depends(get_current_user_dependency)):
"""Get passkey status"""
if not PASSKEY_AVAILABLE:
return {'success': False, 'enabled': False, 'message': 'Passkey support is not available on this server'}
try:
username = current_user.get('sub')
status = passkey_manager.get_passkey_status(username)
return {'success': True, **status}
except Exception as e:
logger.error(f"Passkey status error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to retrieve passkey status. Please try again.'}
# ============================================================================
# Duo Endpoints
# ============================================================================
@router.post("/duo/auth", response_model=DuoAuthResponse)
async def duo_auth(auth_data: DuoAuthRequest):
"""Initiate Duo authentication"""
try:
if not duo_manager.is_duo_configured():
return DuoAuthResponse(
success=False,
message='Duo is not configured'
)
username = auth_data.username
# Create auth URL (pass rememberMe to preserve for callback)
result = duo_manager.create_auth_url(username, auth_data.rememberMe)
return DuoAuthResponse(
success=True,
authUrl=result['authUrl'],
state=result['state']
)
except Exception as e:
logger.error(f"Duo auth error: {e}", module="2FA")
return DuoAuthResponse(success=False, message='Failed to initiate Duo authentication. Please try again.')
@router.get("/duo/callback")
async def duo_callback(
state: str,
duo_code: Optional[str] = None,
code: Optional[str] = None,
request: Request = None
):
"""Handle Duo callback (OAuth 2.0)"""
try:
# Duo sends 'code' parameter, normalize to duo_code
if not duo_code and code:
duo_code = code
if not duo_code:
return {'success': False, 'message': 'Missing authorization code'}
# Verify state and get username and remember_me
result = duo_manager.verify_state(state)
if not result:
return {'success': False, 'message': 'Invalid or expired state'}
username, remember_me = result
# Verify Duo response
if duo_manager.verify_duo_response(duo_code, username):
# Get client IP
ip_address = request.client.host if request.client else None
# Get user info
import sqlite3
with sqlite3.connect(auth_manager.db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT role, email FROM users WHERE username = ?", (username,))
row = cursor.fetchone()
if row:
role, email = row
# Create session with remember_me setting
session_result = auth_manager._create_session(username, role, ip_address, remember_me)
# Redirect to frontend with success (token in URL and cookie)
from fastapi.responses import RedirectResponse
token = session_result.get('token')
session_id = session_result.get('sessionId')
# Create redirect response with token in URL (frontend reads and stores it)
redirect_url = f"/?duo_auth=success&token={token}&sessionId={session_id}&username={username}"
response = RedirectResponse(url=redirect_url, status_code=302)
# Set auth cookie with proper max_age
max_age = 30 * 24 * 60 * 60 if remember_me else None
response.set_cookie(
key="auth_token",
value=token,
max_age=max_age,
httponly=True,
secure=settings.SECURE_COOKIES,
samesite="lax",
path="/"
)
return response
# Verification failed - redirect with error
from fastapi.responses import RedirectResponse
return RedirectResponse(url="/login?duo_auth=failed&error=Duo+verification+failed", status_code=302)
except Exception as e:
logger.error(f"Duo callback error: {e}", module="2FA")
return {'success': False, 'message': 'Duo authentication failed. Please try again.'}
@router.post("/duo/enroll")
async def duo_enroll(
duo_username: Optional[str] = Body(None, embed=True),
current_user: Dict = Depends(get_current_user_dependency)
):
"""Enroll user in Duo"""
try:
username = current_user.get('sub')
success = duo_manager.enroll_user(username, duo_username)
if success:
return {'success': True, 'message': 'Duo enrollment successful'}
else:
return {'success': False, 'message': 'Duo enrollment failed'}
except Exception as e:
logger.error(f"Duo enroll error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to enroll in Duo. Please try again.'}
@router.post("/duo/unenroll")
async def duo_unenroll(current_user: Dict = Depends(get_current_user_dependency)):
"""Unenroll user from Duo"""
try:
username = current_user.get('sub')
success = duo_manager.unenroll_user(username)
if success:
return {'success': True, 'message': 'Duo unenrollment successful'}
else:
return {'success': False, 'message': 'Duo unenrollment failed'}
except Exception as e:
logger.error(f"Duo unenroll error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to unenroll from Duo. Please try again.'}
@router.get("/duo/status")
async def duo_status(current_user: Dict = Depends(get_current_user_dependency)):
"""Get Duo status"""
try:
username = current_user.get('sub')
status = duo_manager.get_duo_status(username)
config_status = duo_manager.get_configuration_status()
return {
'success': True,
**status,
'duoConfigured': config_status['configured']
}
except Exception as e:
logger.error(f"Duo status error: {e}", module="2FA")
return {'success': False, 'message': 'Failed to retrieve Duo status. Please try again.'}
# ============================================================================
# Combined 2FA Status Endpoint
# ============================================================================
@router.get("/status")
async def twofa_status(current_user: Dict = Depends(get_current_user_dependency)):
"""Get complete 2FA status for user"""
try:
username = current_user.get('sub')
totp_status = totp_manager.get_totp_status(username)
# Get passkey status and add availability flag
if PASSKEY_AVAILABLE:
passkey_status = passkey_manager.get_passkey_status(username)
passkey_status['available'] = True
else:
passkey_status = {'enabled': False, 'available': False, 'credentialCount': 0}
duo_status = duo_manager.get_duo_status(username)
duo_config_status = duo_manager.get_configuration_status()
# Determine available methods
available_methods = []
if totp_status['enabled']:
available_methods.append('totp')
if PASSKEY_AVAILABLE and passkey_status['enabled']:
available_methods.append('passkey')
if duo_status['enabled'] and duo_config_status['configured']:
available_methods.append('duo')
return {
'success': True,
'totp': totp_status,
'passkey': passkey_status,
'duo': {**duo_status, 'duoConfigured': duo_config_status['configured']},
'availableMethods': available_methods,
'anyEnabled': len(available_methods) > 0
}
except Exception as e:
logger.error(f"2FA status error: {e}", module="2FA")
return {'success': False, 'message': str(e)}
return router