762
web/backend/api.py
Normal file
762
web/backend/api.py
Normal file
@@ -0,0 +1,762 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Media Downloader Web API
|
||||
FastAPI backend that integrates with the existing media-downloader system
|
||||
"""
|
||||
|
||||
# Suppress pkg_resources deprecation warning from face_recognition_models
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*")
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Bootstrap database backend (must be before any database imports)
|
||||
sys.path.insert(0, str(__import__('pathlib').Path(__file__).parent.parent.parent))
|
||||
import modules.db_bootstrap # noqa: E402,F401
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette_csrf import CSRFMiddleware
|
||||
import sqlite3
|
||||
import secrets
|
||||
import re
|
||||
|
||||
# Redis cache manager
|
||||
from cache_manager import cache_manager, invalidate_download_cache
|
||||
|
||||
# Rate limiting
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
|
||||
# Add parent directory to path to import media-downloader modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from modules.unified_database import UnifiedDatabase
|
||||
from modules.scheduler import DownloadScheduler
|
||||
from modules.settings_manager import SettingsManager
|
||||
from modules.universal_logger import get_logger
|
||||
from modules.discovery_system import get_discovery_system
|
||||
from modules.semantic_search import get_semantic_search
|
||||
from web.backend.auth_manager import AuthManager
|
||||
from web.backend.core.dependencies import AppState
|
||||
|
||||
# Initialize universal logger for API
logger = get_logger('API')

# ============================================================================
# CONFIGURATION (use centralized settings from core/config.py)
# ============================================================================
from web.backend.core.config import settings

# Legacy aliases for backward compatibility (prefer using settings.* directly).
# These are read by code within this module and possibly by older callers;
# new code should import `settings` instead.
PROJECT_ROOT = settings.PROJECT_ROOT
DB_PATH = settings.DB_PATH
CONFIG_PATH = settings.CONFIG_PATH
LOG_PATH = settings.LOG_PATH
TEMP_DIR = settings.TEMP_DIR
DB_POOL_SIZE = settings.DB_POOL_SIZE
DB_CONNECTION_TIMEOUT = settings.DB_CONNECTION_TIMEOUT
PROCESS_TIMEOUT_SHORT = settings.PROCESS_TIMEOUT_SHORT
PROCESS_TIMEOUT_MEDIUM = settings.PROCESS_TIMEOUT_MEDIUM
PROCESS_TIMEOUT_LONG = settings.PROCESS_TIMEOUT_LONG
WEBSOCKET_TIMEOUT = settings.WEBSOCKET_TIMEOUT
ACTIVITY_LOG_INTERVAL = settings.ACTIVITY_LOG_INTERVAL
EMBEDDING_QUEUE_CHECK_INTERVAL = settings.EMBEDDING_QUEUE_CHECK_INTERVAL
EMBEDDING_BATCH_SIZE = settings.EMBEDDING_BATCH_SIZE
EMBEDDING_BATCH_LIMIT = settings.EMBEDDING_BATCH_LIMIT
THUMBNAIL_DB_TIMEOUT = settings.THUMBNAIL_DB_TIMEOUT

# ============================================================================
# PYDANTIC MODELS (imported from models/api_models.py to avoid duplication)
# ============================================================================

from web.backend.models.api_models import (
    DownloadResponse,
    StatsResponse,
    PlatformStatus,
    TriggerRequest,
    PushoverConfig,
    SchedulerConfig,
    ConfigUpdate,
    HealthStatus,
    LoginRequest,
)

# ============================================================================
# AUTHENTICATION DEPENDENCIES (imported from core/dependencies.py)
# ============================================================================

from web.backend.core.dependencies import (
    get_current_user,
    require_admin,
    get_app_state,
)

# ============================================================================
# GLOBAL STATE
# ============================================================================

# Use the shared singleton from core/dependencies (used by all routers).
# Attributes (db, auth, settings, scheduler, config, websocket_manager, ...)
# are populated during lifespan() startup and by module-level wiring below.
app_state = get_app_state()
|
||||
|
||||
# ============================================================================
|
||||
# LIFESPAN MANAGEMENT
|
||||
# ============================================================================
|
||||
|
||||
def init_thumbnail_db(db_path=None):
    """Initialize the thumbnail cache database.

    Creates the SQLite file (and its parent directory) used to cache
    generated thumbnails, with a ``thumbnails`` table keyed by file hash
    and an index on ``file_path`` for path lookups.

    Args:
        db_path: Optional override for the database location. Defaults to
            ``<project-root>/database/thumbnails.db`` (three levels up from
            this module), matching the original hard-coded behaviour.

    Returns:
        Path: Location of the thumbnail database file.
    """
    if db_path is None:
        db_path = Path(__file__).parent.parent.parent / 'database' / 'thumbnails.db'
    thumb_db_path = Path(db_path)
    thumb_db_path.parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(str(thumb_db_path))
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS thumbnails (
                file_hash TEXT PRIMARY KEY,
                file_path TEXT NOT NULL,
                thumbnail_data BLOB NOT NULL,
                created_at TEXT NOT NULL,
                file_mtime REAL NOT NULL
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_file_path ON thumbnails(file_path)")
        conn.commit()
    finally:
        # Previously the connection leaked if schema creation raised;
        # always release it.
        conn.close()
    return thumb_db_path
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and cleanup app resources.

    Startup (order matters): connect the DB pool, reset stale queue/task
    state, create the auth manager, register the 2FA + 28 modular routers,
    load settings (migrating from JSON on first run), attach the scheduler
    handle, init the thumbnail cache, then launch the long-running
    background tasks (WAL checkpoint, embedding queue, quality recheck,
    optional download-queue auto-start).

    Shutdown (after ``yield``): cancel background tasks, close the shared
    HTTP client, stop the scheduler and close the database.
    """
    # Startup
    logger.info("🚀 Starting Media Downloader API...", module="Core")

    # Initialize database with larger pool for API workers
    app_state.db = UnifiedDatabase(str(DB_PATH), use_pool=True, pool_size=DB_POOL_SIZE)
    logger.info(f"✓ Connected to database: {DB_PATH} (pool_size={DB_POOL_SIZE})", module="Core")

    # Reset any stuck downloads from previous API run
    # (rows left in 'downloading' after a crash would otherwise never resume)
    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE video_download_queue
                SET status = 'pending', progress = 0, started_at = NULL
                WHERE status = 'downloading'
            ''')
            reset_count = cursor.rowcount
            conn.commit()
            if reset_count > 0:
                logger.info(f"✓ Reset {reset_count} stuck download(s) to pending", module="Core")
    except Exception as e:
        logger.warning(f"Could not reset stuck downloads: {e}", module="Core")

    # Clear any stale background tasks from previous API run
    try:
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE background_task_status
                SET active = 0
                WHERE active = 1
            ''')
            cleared_count = cursor.rowcount
            conn.commit()
            if cleared_count > 0:
                logger.info(f"✓ Cleared {cleared_count} stale background task(s)", module="Core")
    except Exception as e:
        # Table might not exist yet, that's fine
        pass

    # Initialize face recognition semaphore (limit to 1 concurrent process)
    app_state.face_recognition_semaphore = asyncio.Semaphore(1)
    logger.info("✓ Face recognition semaphore initialized (max 1 concurrent)", module="Core")

    # Initialize authentication manager
    app_state.auth = AuthManager()
    logger.info("✓ Authentication manager initialized", module="Core")

    # Initialize and include 2FA router.
    # NOTE: create_2fa_router is imported near the bottom of this module;
    # that works because lifespan only runs after the module is fully loaded.
    twofa_router = create_2fa_router(app_state.auth, get_current_user)
    app.include_router(twofa_router, prefix="/api/auth/2fa", tags=["2FA"])
    logger.info("✓ 2FA routes initialized", module="Core")

    # Include modular API routers
    from web.backend.routers import (
        auth_router,
        health_router,
        downloads_router,
        media_router,
        recycle_router,
        scheduler_router,
        video_router,
        config_router,
        review_router,
        face_router,
        platforms_router,
        discovery_router,
        scrapers_router,
        semantic_router,
        manual_import_router,
        stats_router,
        celebrity_router,
        video_queue_router,
        maintenance_router,
        files_router,
        appearances_router,
        easynews_router,
        dashboard_router,
        paid_content_router,
        private_gallery_router,
        instagram_unified_router,
        cloud_backup_router,
        press_router
    )

    # Include all modular routers
    # Note: websocket_manager is set later after ConnectionManager is created
    app.include_router(auth_router)
    app.include_router(health_router)
    app.include_router(downloads_router)
    app.include_router(media_router)
    app.include_router(recycle_router)
    app.include_router(scheduler_router)
    app.include_router(video_router)
    app.include_router(config_router)
    app.include_router(review_router)
    app.include_router(face_router)
    app.include_router(platforms_router)
    app.include_router(discovery_router)
    app.include_router(scrapers_router)
    app.include_router(semantic_router)
    app.include_router(manual_import_router)
    app.include_router(stats_router)
    app.include_router(celebrity_router)
    app.include_router(video_queue_router)
    app.include_router(maintenance_router)
    app.include_router(files_router)
    app.include_router(appearances_router)
    app.include_router(easynews_router)
    app.include_router(dashboard_router)
    app.include_router(paid_content_router)
    app.include_router(private_gallery_router)
    app.include_router(instagram_unified_router)
    app.include_router(cloud_backup_router)
    app.include_router(press_router)

    logger.info("✓ All 28 modular routers registered", module="Core")

    # Initialize settings manager
    app_state.settings = SettingsManager(str(DB_PATH))
    logger.info("✓ Settings manager initialized", module="Core")

    # Check if settings need to be migrated from JSON (first run / empty DB)
    existing_settings = app_state.settings.get_all()
    if not existing_settings or len(existing_settings) == 0:
        logger.info("⚠ No settings in database, migrating from JSON...", module="Core")
        if CONFIG_PATH.exists():
            app_state.settings.migrate_from_json(str(CONFIG_PATH))
            logger.info("✓ Settings migrated from JSON to database", module="Core")
        else:
            logger.info("⚠ No JSON config file found", module="Core")

    # Load configuration from database
    app_state.config = app_state.settings.get_all()
    if app_state.config:
        logger.info(f"✓ Loaded configuration from database", module="Core")
    else:
        logger.info("⚠ No configuration in database - please configure via web interface", module="Core")
        app_state.config = {}

    # Initialize scheduler (for status checking only - actual scheduler runs as separate service)
    app_state.scheduler = DownloadScheduler(config_path=None, unified_db=app_state.db, settings_manager=app_state.settings)
    logger.info("✓ Scheduler connected (status monitoring only)", module="Core")

    # Initialize thumbnail database
    init_thumbnail_db()
    logger.info("✓ Thumbnail cache initialized", module="Core")

    # Initialize recycle bin settings with defaults if not exists
    if not app_state.settings.get('recycle_bin'):
        recycle_bin_defaults = {
            'enabled': True,
            'path': '/opt/immich/recycle',
            'retention_days': 30,
            'max_size_gb': 50,
            'auto_cleanup': True
        }
        app_state.settings.set('recycle_bin', recycle_bin_defaults, category='recycle_bin', description='Recycle bin configuration', updated_by='system')
        logger.info("✓ Recycle bin settings initialized with defaults", module="Core")

    logger.info("✓ Media Downloader API ready!", module="Core")

    # Start periodic WAL checkpoint task
    async def periodic_wal_checkpoint():
        """Run WAL checkpoint every 5 minutes to prevent WAL file growth"""
        while True:
            await asyncio.sleep(300)  # 5 minutes
            try:
                if app_state.db:
                    app_state.db.checkpoint()
                    logger.debug("WAL checkpoint completed", module="Core")
            except Exception as e:
                logger.warning(f"WAL checkpoint failed: {e}", module="Core")

    checkpoint_task = asyncio.create_task(periodic_wal_checkpoint())
    logger.info("✓ WAL checkpoint task started (every 5 minutes)", module="Core")

    # Start discovery queue processor
    async def process_discovery_queue():
        """Background task to process discovery scan queue (embeddings)"""
        while True:
            try:
                await asyncio.sleep(30)  # Check queue every 30 seconds

                if not app_state.db:
                    continue

                # Process embeddings first (limit batch size for memory efficiency)
                pending_embeddings = app_state.db.get_pending_discovery_scans(limit=EMBEDDING_BATCH_SIZE, scan_type='embedding')

                for item in pending_embeddings:
                    try:
                        queue_id = item['id']
                        file_id = item['file_id']
                        file_path = item['file_path']

                        # Mark as started
                        app_state.db.mark_discovery_scan_started(queue_id)

                        # Check if file still exists; vanished files are
                        # treated as completed so they leave the queue.
                        from pathlib import Path
                        if not Path(file_path).exists():
                            app_state.db.mark_discovery_scan_completed(queue_id)
                            continue

                        # Check if embedding already exists
                        with app_state.db.get_connection() as conn:
                            cursor = conn.cursor()
                            cursor.execute('SELECT 1 FROM content_embeddings WHERE file_id = ?', (file_id,))
                            if cursor.fetchone():
                                # Already has embedding
                                app_state.db.mark_discovery_scan_completed(queue_id)
                                continue

                            # Get content type
                            cursor.execute('SELECT content_type FROM file_inventory WHERE id = ?', (file_id,))
                            row = cursor.fetchone()
                            content_type = row['content_type'] if row else 'image'

                        # Generate embedding (CPU intensive - run in executor)
                        def generate():
                            semantic = get_semantic_search(app_state.db)
                            return semantic.generate_embedding_for_file(file_id, file_path, content_type)

                        loop = asyncio.get_running_loop()
                        success = await loop.run_in_executor(None, generate)

                        if success:
                            app_state.db.mark_discovery_scan_completed(queue_id)
                            logger.debug(f"Generated embedding for file_id {file_id}", module="Discovery")
                        else:
                            app_state.db.mark_discovery_scan_failed(queue_id, "Embedding generation returned False")

                    except Exception as e:
                        logger.warning(f"Failed to process embedding queue item {item.get('id')}: {e}", module="Discovery")
                        app_state.db.mark_discovery_scan_failed(item.get('id', 0), str(e))

                    # Small delay between items to avoid CPU spikes
                    await asyncio.sleep(0.5)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning(f"Discovery queue processor error: {e}", module="Discovery")
                await asyncio.sleep(60)  # Wait longer on errors

    discovery_task = asyncio.create_task(process_discovery_queue())
    logger.info("✓ Discovery queue processor started (embeddings, checks every 30s)", module="Core")

    # Start periodic quality recheck for Fansly videos (hourly)
    async def periodic_quality_recheck():
        """Recheck flagged Fansly attachments for higher quality every hour"""
        await asyncio.sleep(300)  # Wait 5 min after startup before first check
        while True:
            try:
                from web.backend.routers.paid_content import _auto_quality_recheck_background
                await _auto_quality_recheck_background()
            except Exception as e:
                logger.warning(f"Periodic quality recheck failed: {e}", module="Core")
            await asyncio.sleep(3600)  # 1 hour

    quality_recheck_task = asyncio.create_task(periodic_quality_recheck())
    logger.info("✓ Quality recheck task started (hourly)", module="Core")

    # Auto-start download queue if enabled
    try:
        queue_settings = app_state.settings.get('video_queue')
        if queue_settings and queue_settings.get('auto_start_on_restart'):
            # Check if there are pending items
            with app_state.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM video_download_queue WHERE status = 'pending'")
                pending_count = cursor.fetchone()[0]

            if pending_count > 0:
                # Import the queue processor and start it
                from web.backend.routers.video_queue import queue_processor, run_queue_processor
                if not queue_processor.is_running:
                    queue_processor.start()
                    asyncio.create_task(run_queue_processor(app_state))
                    logger.info(f"✓ Auto-started download queue with {pending_count} pending items", module="Core")
            else:
                logger.info("✓ Auto-start enabled but no pending items in queue", module="Core")
    except Exception as e:
        logger.warning(f"Could not auto-start download queue: {e}", module="Core")

    yield

    # Shutdown: cancel background tasks, then await each so cancellation
    # completes before closing shared resources they might still touch.
    logger.info("🛑 Shutting down Media Downloader API...", module="Core")
    checkpoint_task.cancel()
    discovery_task.cancel()
    quality_recheck_task.cancel()
    try:
        await checkpoint_task
    except asyncio.CancelledError:
        pass
    try:
        await discovery_task
    except asyncio.CancelledError:
        pass
    try:
        await quality_recheck_task
    except asyncio.CancelledError:
        pass
    # Close shared HTTP client
    try:
        from web.backend.core.http_client import http_client
        await http_client.close()
    except Exception:
        pass
    if app_state.scheduler:
        app_state.scheduler.stop()
    if app_state.db:
        app_state.db.close()
    logger.info("✓ Cleanup complete", module="Core")
|
||||
|
||||
# ============================================================================
|
||||
# FASTAPI APP
|
||||
# ============================================================================
|
||||
|
||||
# The FastAPI application; all routers are attached in lifespan() at startup.
app = FastAPI(
    title="Media Downloader API",
    description="Web API for managing media downloads from Instagram, TikTok, Snapchat, and Forums",
    version="13.13.1",
    lifespan=lifespan
)

# CORS middleware - allow frontend origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
    allow_headers=[
        "Authorization",
        "Content-Type",
        "Accept",
        "Origin",
        "X-Requested-With",
        "X-CSRFToken",
        "Cache-Control",
        "Range",  # needed for media seeking/partial content requests
    ],
    max_age=600,  # Cache preflight requests for 10 minutes
)

# GZip compression for responses > 1KB (70% smaller API payloads)
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# Session middleware for 2FA setup flow
|
||||
# Persist generated key to file to avoid invalidating sessions on restart
|
||||
def _load_session_secret():
|
||||
secret_file = Path(__file__).parent.parent.parent / '.session_secret'
|
||||
if os.getenv("SESSION_SECRET_KEY"):
|
||||
return os.getenv("SESSION_SECRET_KEY")
|
||||
if secret_file.exists():
|
||||
with open(secret_file, 'r') as f:
|
||||
secret = f.read().strip()
|
||||
if len(secret) >= 32:
|
||||
return secret
|
||||
new_secret = secrets.token_urlsafe(32)
|
||||
try:
|
||||
fd = os.open(str(secret_file), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, 'w') as f:
|
||||
f.write(new_secret)
|
||||
except Exception:
|
||||
pass # Fall back to in-memory only
|
||||
return new_secret
|
||||
|
||||
# Session middleware (used by the 2FA setup flow); the key is persisted so
# restarting the API does not invalidate in-flight sessions.
SESSION_SECRET_KEY = _load_session_secret()
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)

# CSRF protection middleware - enabled for state-changing requests
# Note: All endpoints are also protected by JWT authentication for double security
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
class CSRFMiddlewareHTTPOnly:
    """Apply CSRF checks to HTTP traffic only.

    WebSocket upgrade requests carry no CSRF token and authenticate
    separately, so they are routed around the CSRF middleware while all
    other (HTTP) traffic passes through it.
    """

    def __init__(self, app: ASGIApp, csrf_middleware_class, **kwargs):
        # Keep both the bare app (WebSocket path) and the CSRF-wrapped
        # app (HTTP path); kwargs are forwarded to the CSRF middleware.
        self.app = app
        self.csrf_middleware = csrf_middleware_class(app, **kwargs)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # Skip CSRF for WebSocket connections (uses separate auth)
        if scope["type"] == "websocket":
            await self.app(scope, receive, send)
            return
        # Apply CSRF for HTTP requests
        await self.csrf_middleware(scope, receive, send)
|
||||
|
||||
# CSRF signing key: prefer the configured key, fall back to the session key.
CSRF_SECRET_KEY = settings.CSRF_SECRET_KEY or SESSION_SECRET_KEY

app.add_middleware(
    CSRFMiddlewareHTTPOnly,
    csrf_middleware_class=CSRFMiddleware,
    secret=CSRF_SECRET_KEY,
    cookie_name="csrftoken",
    header_name="X-CSRFToken",
    cookie_secure=settings.SECURE_COOKIES,
    cookie_httponly=False,  # JS needs to read for SPA
    cookie_samesite="lax",  # CSRF protection
    exempt_urls=[
        # API endpoints are protected against CSRF via multiple layers:
        # 1. Primary: JWT in Authorization header (cannot be forged cross-origin)
        # 2. Secondary: Cookie auth uses samesite='lax' which blocks cross-origin POSTs
        # 3. Media endpoints (cookie-only) are GET requests, safe from CSRF
        # This makes additional CSRF token validation redundant for /api/* routes
        re.compile(r"^/api/.*"),
        re.compile(r"^/ws$"),  # WebSocket endpoint (has separate auth via cookie)
    ]
)

# Rate limiting with Redis persistence
# Falls back to in-memory if Redis unavailable
# NOTE(review): constructing Limiter may not eagerly connect to Redis, so a
# dead Redis might only surface later — confirm against slowapi behaviour.
redis_uri = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}"
try:
    limiter = Limiter(key_func=get_remote_address, storage_uri=redis_uri)
    logger.info("Rate limiter using Redis storage", module="RateLimit")
except Exception as e:
    logger.warning(f"Redis rate limiting unavailable, using in-memory: {e}", module="RateLimit")
    limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Import and include 2FA routes (the router itself is created during lifespan)
from web.backend.twofa_routes import create_2fa_router
|
||||
|
||||
# ============================================================================
|
||||
# WEBSOCKET CONNECTION MANAGER
|
||||
# ============================================================================
|
||||
|
||||
class ConnectionManager:
    """Tracks active WebSocket connections and broadcasts JSON messages.

    The first accepted connection captures the running event loop so that
    ``broadcast_sync`` can schedule broadcasts from worker threads via
    ``run_coroutine_threadsafe``.
    """

    def __init__(self):
        # All currently-connected client sockets
        self.active_connections: List[WebSocket] = []
        # Event loop captured on first connect; required for thread-safe broadcasts
        self._loop = None

    async def connect(self, websocket: WebSocket):
        """Accept a new WebSocket connection and register it."""
        await websocket.accept()
        self.active_connections.append(websocket)
        # Capture event loop for thread-safe broadcasts
        if self._loop is None:
            self._loop = asyncio.get_running_loop()

    def disconnect(self, websocket: WebSocket):
        """Unregister a connection; safe to call for already-removed sockets."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass  # Already removed

    async def broadcast(self, message: dict):
        """Send *message* to every connected client, pruning dead connections."""
        logger.info(f"ConnectionManager: Broadcasting to {len(self.active_connections)} clients", module="WebSocket")

        dead_connections = []
        for connection in self.active_connections:
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.debug(f"ConnectionManager: Removing dead connection: {e}", module="WebSocket")
                dead_connections.append(connection)

        # Remove dead connections; disconnect() already tolerates sockets
        # that were removed concurrently, so no duplicated try/except here.
        for connection in dead_connections:
            self.disconnect(connection)

    def broadcast_sync(self, message: dict):
        """Thread-safe broadcast for use in background tasks (sync threads)"""
        msg_type = message.get('type', 'unknown')

        if not self.active_connections:
            logger.info(f"broadcast_sync: No active WebSocket connections for {msg_type}", module="WebSocket")
            return

        logger.info(f"broadcast_sync: Sending {msg_type} to {len(self.active_connections)} clients", module="WebSocket")

        try:
            # Probe for a running loop: if one exists we are in async
            # context and can schedule the broadcast as a task directly.
            asyncio.get_running_loop()
            asyncio.create_task(self.broadcast(message))
            logger.info(f"broadcast_sync: Created task for {msg_type}", module="WebSocket")
        except RuntimeError:
            # No running loop - we're in a thread, use thread-safe call
            if self._loop and self._loop.is_running():
                asyncio.run_coroutine_threadsafe(self.broadcast(message), self._loop)
                logger.info(f"broadcast_sync: Used threadsafe call for {msg_type}", module="WebSocket")
            else:
                logger.warning(f"broadcast_sync: No event loop available for {msg_type}", module="WebSocket")
|
||||
|
||||
# Module-level singleton used by the /ws endpoint below.
manager = ConnectionManager()

# Store websocket manager in app_state for modular routers to access
app_state.websocket_manager = manager

# Initialize scraper event emitter for real-time monitoring
from modules.scraper_event_emitter import ScraperEventEmitter
app_state.scraper_event_emitter = ScraperEventEmitter(websocket_manager=manager, app_state=app_state)
logger.info("Scraper event emitter initialized for real-time monitoring", module="WebSocket")
|
||||
|
||||
# ============================================================================
|
||||
# WEBSOCKET ENDPOINT
|
||||
# ============================================================================
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time updates.

    Authentication priority:
    1. Cookie (most secure, set by browser automatically)
    2. Sec-WebSocket-Protocol header with 'auth.<token>' (secure, not logged)
    3. Query parameter (fallback, may be logged in access logs)

    After authentication the handler echoes client text messages back and
    sends a ping whenever 30s pass without client traffic; the connection
    is always removed from the manager on exit.
    """
    auth_token = None

    # Priority 1: Try cookie (most common for browsers)
    auth_token = websocket.cookies.get('auth_token')

    # Priority 2: Check Sec-WebSocket-Protocol header for "auth.<token>" subprotocol
    # This is more secure than query params as it's not logged in access logs
    if not auth_token:
        protocols = websocket.headers.get('sec-websocket-protocol', '')
        for protocol in protocols.split(','):
            protocol = protocol.strip()
            if protocol.startswith('auth.'):
                auth_token = protocol[5:]  # Remove 'auth.' prefix
                break

    # Priority 3: Query parameter fallback (least preferred, logged in access logs)
    if not auth_token:
        auth_token = websocket.query_params.get('token')

    if not auth_token:
        # 4001 is an application-defined close code for auth failure
        await websocket.close(code=4001, reason="Authentication required")
        logger.warning("WebSocket: Connection rejected - no auth cookie or token", module="WebSocket")
        return

    # Verify the token
    payload = app_state.auth.verify_session(auth_token)
    if not payload:
        await websocket.close(code=4001, reason="Invalid or expired token")
        logger.warning("WebSocket: Connection rejected - invalid token", module="WebSocket")
        return

    username = payload.get('sub', 'unknown')

    # Accept with the auth subprotocol if that was used.
    # This bypasses manager.connect() because accept() must be called with
    # the negotiated subprotocol, so registration is done by hand here.
    protocols = websocket.headers.get('sec-websocket-protocol', '')
    if any(p.strip().startswith('auth.') for p in protocols.split(',')):
        await websocket.accept(subprotocol='auth')
        manager.active_connections.append(websocket)
        if manager._loop is None:
            manager._loop = asyncio.get_running_loop()
    else:
        await manager.connect(websocket)
    try:
        # Send initial connection message
        await websocket.send_json({
            "type": "connected",
            "timestamp": datetime.now().isoformat(),
            "user": username
        })
        logger.info(f"WebSocket: Client connected (user: {username})", module="WebSocket")

        while True:
            # Keep connection alive with ping/pong
            # Use raw receive() instead of receive_text() for better proxy compatibility
            try:
                message = await asyncio.wait_for(websocket.receive(), timeout=30.0)
                msg_type = message.get("type", "unknown")

                if msg_type == "websocket.receive":
                    text = message.get("text")
                    if text:
                        # Echo back client messages
                        await websocket.send_json({
                            "type": "echo",
                            "message": text
                        })
                elif msg_type == "websocket.disconnect":
                    # Client closed connection
                    code = message.get("code", "unknown")
                    logger.info(f"WebSocket: Client {username} sent disconnect (code={code})", module="WebSocket")
                    break

            except asyncio.TimeoutError:
                # No client traffic for 30s: send ping to keep connection alive
                try:
                    await websocket.send_json({
                        "type": "ping",
                        "timestamp": datetime.now().isoformat()
                    })
                except Exception as e:
                    # Connection lost
                    logger.info(f"WebSocket: Ping failed for {username}: {type(e).__name__}: {e}", module="WebSocket")
                    break
    except WebSocketDisconnect as e:
        logger.info(f"WebSocket: {username} disconnected (code={e.code})", module="WebSocket")
    except Exception as e:
        logger.error(f"WebSocket error for {username}: {type(e).__name__}: {e}", module="WebSocket")
    finally:
        # Always deregister, whatever path ended the session
        manager.disconnect(websocket)
        logger.info(f"WebSocket: {username} removed from active connections ({len(manager.active_connections)} remaining)", module="WebSocket")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN ENTRY POINT
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn
    # Binds on all interfaces; in production this is presumably fronted by a
    # reverse proxy — confirm deployment topology before exposing directly.
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8000,
        workers=1,  # Single worker needed for WebSocket broadcasts to work correctly
        timeout_keep_alive=30,
        log_level="info",
        limit_concurrency=1000,  # Allow up to 1000 concurrent connections
        backlog=2048  # Queue size for pending connections
    )
|
||||
|
||||
Reference in New Issue
Block a user