1
web/backend/core/__init__.py
Normal file
1
web/backend/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Core module exports
|
||||
121
web/backend/core/config.py
Normal file
121
web/backend/core/config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Unified Configuration Manager
|
||||
|
||||
Provides a single source of truth for all configuration values with a clear hierarchy:
|
||||
1. Environment variables (highest priority)
|
||||
2. .env file
|
||||
3. Database settings
|
||||
4. Hardcoded defaults (lowest priority)
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from functools import lru_cache
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
# This file lives at web/backend/core/config.py, so four .parent hops
# resolve the repository root.
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
ENV_PATH = PROJECT_ROOT / '.env'
if ENV_PATH.exists():
    # Merge .env values into os.environ (existing env vars win by default)
    load_dotenv(ENV_PATH)
|
||||
|
||||
|
||||
class Settings:
    """Centralized configuration settings.

    Values are class attributes resolved once at import time, after this
    module has merged the .env file into the process environment:
    environment variables override the hardcoded defaults given here.
    """

    # Project paths
    PROJECT_ROOT: Path = PROJECT_ROOT
    DB_PATH: Path = PROJECT_ROOT / 'database' / 'media_downloader.db'
    CONFIG_PATH: Path = PROJECT_ROOT / 'config' / 'settings.json'
    LOG_PATH: Path = PROJECT_ROOT / 'logs'
    TEMP_DIR: Path = PROJECT_ROOT / 'temp'
    COOKIES_DIR: Path = PROJECT_ROOT / 'cookies'
    DATA_DIR: Path = PROJECT_ROOT / 'data'
    VENV_BIN: Path = PROJECT_ROOT / 'venv' / 'bin'
    MEDIA_BASE_PATH: Path = Path(os.getenv("MEDIA_BASE_PATH", "/opt/immich/md"))
    REVIEW_PATH: Path = Path(os.getenv("REVIEW_PATH", "/opt/immich/review"))
    RECYCLE_PATH: Path = Path(os.getenv("RECYCLE_PATH", "/opt/immich/recycle"))

    # External tool paths (use venv by default)
    YT_DLP_PATH: Path = VENV_BIN / 'yt-dlp'
    GALLERY_DL_PATH: Path = VENV_BIN / 'gallery-dl'
    PYTHON_PATH: Path = VENV_BIN / 'python3'

    # Database configuration
    DATABASE_BACKEND: str = os.getenv("DATABASE_BACKEND", "sqlite")
    DATABASE_URL: str = os.getenv("DATABASE_URL", "")
    DB_POOL_SIZE: int = int(os.getenv("DB_POOL_SIZE", "20"))
    DB_CONNECTION_TIMEOUT: float = float(os.getenv("DB_CONNECTION_TIMEOUT", "30.0"))

    # Process timeouts (seconds)
    PROCESS_TIMEOUT_SHORT: int = int(os.getenv("PROCESS_TIMEOUT_SHORT", "2"))
    PROCESS_TIMEOUT_MEDIUM: int = int(os.getenv("PROCESS_TIMEOUT_MEDIUM", "10"))
    PROCESS_TIMEOUT_LONG: int = int(os.getenv("PROCESS_TIMEOUT_LONG", "120"))

    # WebSocket configuration
    WEBSOCKET_TIMEOUT: float = float(os.getenv("WEBSOCKET_TIMEOUT", "30.0"))

    # Background task intervals (seconds)
    ACTIVITY_LOG_INTERVAL: int = int(os.getenv("ACTIVITY_LOG_INTERVAL", "300"))
    EMBEDDING_QUEUE_CHECK_INTERVAL: int = int(os.getenv("EMBEDDING_QUEUE_CHECK_INTERVAL", "30"))
    EMBEDDING_BATCH_SIZE: int = int(os.getenv("EMBEDDING_BATCH_SIZE", "10"))
    EMBEDDING_BATCH_LIMIT: int = int(os.getenv("EMBEDDING_BATCH_LIMIT", "50000"))

    # Thumbnail generation
    THUMBNAIL_DB_TIMEOUT: float = float(os.getenv("THUMBNAIL_DB_TIMEOUT", "30.0"))
    THUMBNAIL_SIZE: tuple = (300, 300)

    # Video proxy configuration
    PROXY_FILE_CACHE_DURATION: int = int(os.getenv("PROXY_FILE_CACHE_DURATION", "300"))

    # Redis cache configuration
    REDIS_HOST: str = os.getenv("REDIS_HOST", "127.0.0.1")
    REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
    REDIS_DB: int = int(os.getenv("REDIS_DB", "0"))
    REDIS_TTL: int = int(os.getenv("REDIS_TTL", "300"))

    # API configuration
    API_VERSION: str = "13.13.1"
    API_TITLE: str = "Media Downloader API"
    API_DESCRIPTION: str = "Web API for managing media downloads"

    # CORS configuration
    # BUGFIX: strip surrounding whitespace and drop empty entries, so an
    # env value like "http://a.example, http://b.example," produces usable
    # origins — a raw split(",") would yield " http://b.example" and "",
    # which never match a browser Origin header.
    ALLOWED_ORIGINS: list = [
        origin.strip()
        for origin in os.getenv(
            "ALLOWED_ORIGINS",
            "http://localhost:5173,http://localhost:3000,http://127.0.0.1:5173,http://127.0.0.1:3000"
        ).split(",")
        if origin.strip()
    ]

    # Security
    SECURE_COOKIES: bool = os.getenv("SECURE_COOKIES", "false").lower() == "true"
    SESSION_SECRET_KEY: str = os.getenv("SESSION_SECRET_KEY", "")
    CSRF_SECRET_KEY: str = os.getenv("CSRF_SECRET_KEY", "")

    # Rate limiting
    RATE_LIMIT_DEFAULT: str = os.getenv("RATE_LIMIT_DEFAULT", "100/minute")
    RATE_LIMIT_STRICT: str = os.getenv("RATE_LIMIT_STRICT", "10/minute")

    @classmethod
    def get(cls, key: str, default: Any = None) -> Any:
        """Get a configuration value by key, returning *default* when absent."""
        return getattr(cls, key, default)

    @classmethod
    def get_path(cls, key: str) -> Path:
        """Get a path configuration value, ensuring it's a Path object.

        Raises:
            ValueError: if *key* is not a known configuration attribute.
        """
        value = getattr(cls, key, None)
        if value is None:
            raise ValueError(f"Configuration key {key} not found")
        if isinstance(value, Path):
            return value
        return Path(value)
|
||||
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance, created once and cached."""
    return Settings()


# Convenience exports
# Module-level singleton so callers can simply `from ...config import settings`.
settings = get_settings()
|
||||
206
web/backend/core/dependencies.py
Normal file
206
web/backend/core/dependencies.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Shared Dependencies
|
||||
|
||||
Provides dependency injection for FastAPI routes.
|
||||
All authentication, database access, and common dependencies are defined here.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Depends, HTTPException, Query, Request, WebSocket
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
|
||||
# Add parent directory to path so top-level project packages (e.g. `modules`)
# are importable regardless of how the FastAPI app is launched.
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

from modules.universal_logger import get_logger

logger = get_logger('API')

# Security
# auto_error=False makes a missing Authorization header yield None instead of
# an immediate error, so the cookie/query fallbacks below can run.
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# GLOBAL STATE (imported from main app)
|
||||
# ============================================================================
|
||||
|
||||
class AppState:
    """Global application state - singleton pattern.

    Holds references populated during FastAPI startup (database, auth,
    scheduler, ...) so route modules can share them without importing the
    main app module.
    """
    # Singleton storage; set on first instantiation.
    _instance = None

    def __new__(cls):
        """Return the single shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Checked by __init__ so repeated AppState() calls do not reset
            # state that startup code already populated.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize attributes exactly once; subsequent calls are no-ops."""
        if self._initialized:
            return
        self.db = None  # database handle; set during app startup
        self.config: Dict = {}
        self.scheduler = None
        self.websocket_manager = None  # ConnectionManager instance for broadcasts
        self.scraper_event_emitter = None  # ScraperEventEmitter for real-time scraping monitor
        self.active_scraper_sessions: Dict = {}  # Track active scraping sessions for monitor page
        self.running_platform_downloads: Dict = {}  # Track running platform downloads {platform: process}
        self.download_tasks: Dict = {}  # in-flight download task records
        self.auth = None  # auth manager; required by the auth dependencies in this module
        self.settings = None  # settings manager; see get_settings_manager()
        self.face_recognition_semaphore = None  # caps concurrent face-recognition work
        self.indexing_running: bool = False
        self.indexing_start_time = None
        self.review_rescan_running: bool = False
        self.review_rescan_progress: Dict = {"current": 0, "total": 0, "current_file": ""}
        self._initialized = True
|
||||
|
||||
|
||||
def get_app_state() -> AppState:
    """Return the process-wide AppState singleton (constructed on first use)."""
    return AppState()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION DEPENDENCIES
|
||||
# ============================================================================
|
||||
|
||||
async def get_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict:
    """
    Resolve the current authenticated user from a JWT token.

    Token lookup order: 1) Authorization header, 2) `auth_token` cookie.
    Raises 401 when no valid token is present and 500 when the auth
    subsystem has not been initialized yet.
    """
    app_state = get_app_state()

    # Header takes priority over the cookie fallback.
    if credentials:
        auth_token = credentials.credentials
    else:
        auth_token = request.cookies.get('auth_token')

    if not auth_token:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if not app_state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")

    payload = app_state.auth.verify_session(auth_token)
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid or expired token")

    return payload
|
||||
|
||||
|
||||
async def get_current_user_optional(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict]:
    """Optional authentication dependency: the resolved user, or None when the
    request carries no valid credentials (any auth failure is swallowed)."""
    try:
        user = await get_current_user(request, credentials)
    except HTTPException:
        return None
    return user
|
||||
|
||||
|
||||
async def get_current_user_media(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    token: Optional[str] = Query(None)
) -> Dict:
    """
    Authentication for media endpoints.

    Accepts the token from three places, in priority order:
    1) Authorization header, 2) `auth_token` cookie, 3) `token` query
    parameter (kept for backward compatibility so plain <img>/<video>
    tags can authenticate).
    """
    app_state = get_app_state()

    def _resolve_token() -> Optional[str]:
        # Header wins, then cookie, then the legacy query parameter.
        if credentials:
            return credentials.credentials
        if 'auth_token' in request.cookies:
            return request.cookies.get('auth_token')
        return token

    auth_token = _resolve_token()
    if not auth_token:
        raise HTTPException(status_code=401, detail="Not authenticated")
    if not app_state.auth:
        raise HTTPException(status_code=500, detail="Authentication not initialized")

    payload = app_state.auth.verify_session(auth_token)
    if not payload:
        raise HTTPException(status_code=401, detail="Invalid or expired token")

    return payload
|
||||
|
||||
|
||||
async def require_admin(
    request: Request,
    current_user: Dict = Depends(get_current_user)
) -> Dict:
    """
    Gate sensitive operations behind the admin role.

    Raises 401 when the session payload or user record is unusable and
    403 when the stored role is anything other than 'admin'.
    """
    app_state = get_app_state()

    username = current_user.get('sub')
    if not username:
        raise HTTPException(status_code=401, detail="Invalid user session")

    # The role lives on the stored user record, not in the JWT payload.
    user_info = app_state.auth.get_user(username)
    if not user_info:
        raise HTTPException(status_code=401, detail="User not found")

    role = user_info.get('role')
    if role != 'admin':
        logger.warning(f"User {username} attempted admin operation without admin role", module="Auth")
        raise HTTPException(status_code=403, detail="Admin role required")

    return current_user
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE DEPENDENCIES
|
||||
# ============================================================================
|
||||
|
||||
def get_database():
    """Return the shared database handle, or raise HTTP 500 when startup
    has not initialized it yet."""
    db = get_app_state().db
    if not db:
        raise HTTPException(status_code=500, detail="Database not initialized")
    return db
|
||||
|
||||
|
||||
def get_settings_manager():
    """Return the shared settings manager, or raise HTTP 500 when startup
    has not initialized it yet."""
    manager = get_app_state().settings
    if not manager:
        raise HTTPException(status_code=500, detail="Settings manager not initialized")
    return manager
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# UTILITY DEPENDENCIES
|
||||
# ============================================================================
|
||||
|
||||
def get_scheduler():
    """Return the scheduler instance (may be None before startup completes)."""
    return get_app_state().scheduler
|
||||
|
||||
|
||||
def get_face_semaphore():
    """Return the semaphore limiting concurrent face-recognition operations
    (may be None before startup completes)."""
    return get_app_state().face_recognition_semaphore
|
||||
346
web/backend/core/exceptions.py
Normal file
346
web/backend/core/exceptions.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Custom Exception Classes
|
||||
|
||||
Provides specific exception types for better error handling and reporting.
|
||||
Replaces broad 'except Exception' handlers with specific, meaningful exceptions.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BASE EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class MediaDownloaderError(Exception):
    """Base exception for all Media Downloader errors.

    Carries a human-readable message plus an optional structured `details`
    mapping that is surfaced in API error payloads via to_dict().
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        self.details = details or {}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the error into the standard API error payload shape."""
        return {
            "error": type(self).__name__,
            "message": self.message,
            "details": self.details,
        }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class DatabaseError(MediaDownloaderError):
    """Database operation failed"""


class DatabaseConnectionError(DatabaseError):
    """Failed to connect to database"""


class DatabaseQueryError(DatabaseError):
    """Database query failed"""


class RecordNotFoundError(DatabaseError):
    """Requested record not found in database"""


# Alias for generic "not found" scenarios
NotFoundError = RecordNotFoundError


class DuplicateRecordError(DatabaseError):
    """Attempted to insert duplicate record"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DOWNLOAD EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class DownloadError(MediaDownloaderError):
    """Download operation failed"""


class NetworkError(DownloadError):
    """Network request failed"""


class RateLimitError(DownloadError):
    """Rate limit exceeded"""


class AuthenticationError(DownloadError):
    """Authentication failed for external service"""


class PlatformUnavailableError(DownloadError):
    """Platform or service is unavailable"""


class ContentNotFoundError(DownloadError):
    """Requested content not found on platform"""


class InvalidURLError(DownloadError):
    """Invalid or malformed URL"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FILE OPERATION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class FileOperationError(MediaDownloaderError):
    """File operation failed"""


class MediaFileNotFoundError(FileOperationError):
    """File not found on filesystem"""


class FileAccessError(FileOperationError):
    """Cannot access file (permissions, etc.)"""


class FileHashError(FileOperationError):
    """File hash computation failed"""


class FileMoveError(FileOperationError):
    """Failed to move file"""


class ThumbnailError(FileOperationError):
    """Thumbnail generation failed"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VALIDATION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class ValidationError(MediaDownloaderError):
    """Input validation failed"""


class InvalidParameterError(ValidationError):
    """Invalid parameter value"""


class MissingParameterError(ValidationError):
    """Required parameter missing"""


class PathTraversalError(ValidationError):
    """Path traversal attempt detected"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AUTHENTICATION/AUTHORIZATION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class AuthError(MediaDownloaderError):
    """Authentication or authorization error"""


class TokenExpiredError(AuthError):
    """Authentication token has expired"""


class InvalidTokenError(AuthError):
    """Authentication token is invalid"""


class InsufficientPermissionsError(AuthError):
    """User lacks required permissions"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SERVICE EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class ServiceError(MediaDownloaderError):
    """External service error"""


class FlareSolverrError(ServiceError):
    """FlareSolverr service error"""


class RedisError(ServiceError):
    """Redis cache error"""


class SchedulerError(ServiceError):
    """Scheduler service error"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FACE RECOGNITION EXCEPTIONS
|
||||
# ============================================================================
|
||||
|
||||
class FaceRecognitionError(MediaDownloaderError):
    """Face recognition operation failed"""


class NoFaceDetectedError(FaceRecognitionError):
    """No face detected in image"""


class FaceEncodingError(FaceRecognitionError):
    """Failed to encode face"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HTTP EXCEPTION HELPERS
|
||||
# ============================================================================
|
||||
|
||||
def to_http_exception(error: MediaDownloaderError) -> HTTPException:
    """Convert a MediaDownloaderError to an HTTPException.

    The status code is chosen by walking the error's class MRO and taking
    the first class present in the map, so the *most specific* registered
    ancestor always wins. (The previous implementation scanned the dict in
    insertion order with isinstance(), which silently depended on subclasses
    being listed before their parents.)

    Args:
        error: the domain exception to translate.

    Returns:
        HTTPException whose detail is error.to_dict() and whose status code
        reflects the error category (default 500 for unmapped types).
    """

    # Map exception types to HTTP status codes. Order no longer matters:
    # specificity is resolved via the MRO walk below.
    status_map = {
        # 400 Bad Request
        InvalidParameterError: 400,
        MissingParameterError: 400,
        InvalidURLError: 400,
        ValidationError: 400,

        # 401 Unauthorized
        TokenExpiredError: 401,
        InvalidTokenError: 401,
        AuthenticationError: 401,
        AuthError: 401,

        # 403 Forbidden
        InsufficientPermissionsError: 403,
        PathTraversalError: 403,
        FileAccessError: 403,

        # 404 Not Found
        RecordNotFoundError: 404,
        MediaFileNotFoundError: 404,
        ContentNotFoundError: 404,

        # 409 Conflict
        DuplicateRecordError: 409,

        # 429 Too Many Requests
        RateLimitError: 429,

        # 500 Internal Server Error
        DatabaseConnectionError: 500,
        DatabaseQueryError: 500,
        DatabaseError: 500,

        # 502 Bad Gateway
        FlareSolverrError: 502,
        ServiceError: 502,
        NetworkError: 502,

        # 503 Service Unavailable
        PlatformUnavailableError: 503,
        SchedulerError: 503,
    }

    # Walk from the concrete class up through its ancestors: the first hit
    # is by construction the most specific mapped type.
    status_code = 500  # Default to internal server error
    for klass in type(error).__mro__:
        if klass in status_map:
            status_code = status_map[klass]
            break

    return HTTPException(
        status_code=status_code,
        detail=error.to_dict()
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EXCEPTION HANDLER DECORATOR
|
||||
# ============================================================================
|
||||
|
||||
from functools import wraps
from typing import Callable, TypeVar
import asyncio

T = TypeVar('T')


def handle_exceptions(func: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator to convert MediaDownloaderError exceptions to HTTPException.

    Use this on route handlers for consistent error responses. Works with
    both sync and async handlers. HTTPExceptions raised by the handler pass
    through untouched; anything else is logged and wrapped as a 500.
    """

    def _wrap_unexpected(exc: Exception) -> HTTPException:
        # Log the unexpected exception, then wrap it in a generic 500.
        # (Shared by both wrappers; previously duplicated inline.)
        from modules.universal_logger import get_logger
        get_logger('API').error(f"Unexpected error in {func.__name__}: {exc}", module="Core")
        return HTTPException(
            status_code=500,
            detail={"error": "InternalError", "message": str(exc)}
        )

    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except MediaDownloaderError as e:
                raise to_http_exception(e)
            except HTTPException:
                raise  # Let HTTPException pass through
            except Exception as e:
                raise _wrap_unexpected(e)
        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except MediaDownloaderError as e:
            raise to_http_exception(e)
        except HTTPException:
            raise
        except Exception as e:
            raise _wrap_unexpected(e)
    return sync_wrapper
|
||||
377
web/backend/core/http_client.py
Normal file
377
web/backend/core/http_client.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Async HTTP Client
|
||||
|
||||
Provides async HTTP client to replace synchronous requests library in FastAPI endpoints.
|
||||
This prevents blocking the event loop and improves concurrency.
|
||||
|
||||
Usage:
|
||||
from web.backend.core.http_client import http_client, get_http_client
|
||||
|
||||
# In async endpoint
|
||||
async def my_endpoint():
|
||||
async with get_http_client() as client:
|
||||
response = await client.get("https://api.example.com/data")
|
||||
return response.json()
|
||||
|
||||
# Or use the singleton
|
||||
response = await http_client.get("https://api.example.com/data")
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
from .config import settings
|
||||
from .exceptions import NetworkError, ServiceError
|
||||
|
||||
|
||||
# Default timeouts
# httpx.Timeout takes one value per phase: connect, read, write, and the
# time allowed to acquire a connection from the pool.
DEFAULT_TIMEOUT = httpx.Timeout(
    connect=10.0,
    read=30.0,
    write=10.0,
    pool=10.0
)

# Longer timeout for slow operations
# (read window extended to two minutes; write to 30s)
LONG_TIMEOUT = httpx.Timeout(
    connect=10.0,
    read=120.0,
    write=30.0,
    pool=10.0
)
|
||||
|
||||
|
||||
class AsyncHTTPClient:
    """
    Async HTTP client wrapper with retry logic and error handling.

    Features:
    - Connection pooling
    - Automatic retries with exponential backoff
    - Timeout handling
    - Error conversion to custom exceptions (NetworkError / ServiceError)
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        timeout: Optional[httpx.Timeout] = None,
        headers: Optional[Dict[str, str]] = None,
        max_retries: int = 3
    ):
        self.base_url = base_url
        self.timeout = timeout or DEFAULT_TIMEOUT
        self.headers = headers or {}
        self.max_retries = max_retries
        # Created lazily so construction at import time is safe.
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the underlying httpx client (recreated if closed)."""
        if self._client is None or self._client.is_closed:
            client_kwargs = {
                "timeout": self.timeout,
                "headers": self.headers,
                "follow_redirects": True,
                "limits": httpx.Limits(
                    max_keepalive_connections=20,
                    max_connections=100,
                    keepalive_expiry=30.0
                )
            }
            # Only add base_url if it's not None
            if self.base_url is not None:
                client_kwargs["base_url"] = self.base_url

            self._client = httpx.AsyncClient(**client_kwargs)
        return self._client

    async def close(self):
        """Close the client connection"""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
        self._client = None

    async def _request(
        self,
        method: str,
        url: str,
        **kwargs
    ) -> httpx.Response:
        """
        Make an HTTP request with retry logic.

        5xx responses and transport failures are retried with exponential
        backoff (1s, 2s, ...); 4xx responses are returned to the caller
        unchanged.

        Args:
            method: HTTP method (GET, POST, etc.)
            url: URL to request
            **kwargs: Additional arguments for httpx

        Returns:
            httpx.Response

        Raises:
            NetworkError: On network failures
            ServiceError: On HTTP 5xx errors after all retries
        """
        client = await self._get_client()
        last_error = None

        for attempt in range(self.max_retries):
            try:
                response = await client.request(method, url, **kwargs)

                # 5xx becomes a retriable ServiceError; anything else
                # (including 4xx) is handed back to the caller as-is.
                if response.status_code >= 500:
                    raise ServiceError(
                        f"Server error: {response.status_code}",
                        {"url": url, "status": response.status_code}
                    )
                return response

            except ServiceError as e:
                # BUGFIX: the ServiceError raised above was previously caught
                # by the generic `except Exception` and re-wrapped as a
                # NetworkError, losing its type and details.
                last_error = e
            except httpx.ConnectError as e:
                last_error = NetworkError(
                    f"Connection failed: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            except httpx.TimeoutException as e:
                last_error = NetworkError(
                    f"Request timed out: {e}",
                    {"url": url, "attempt": attempt + 1}
                )
            except httpx.HTTPStatusError as e:
                last_error = ServiceError(
                    f"HTTP error: {e}",
                    {"url": url, "status": e.response.status_code}
                )
                # Don't retry client errors
                if e.response.status_code < 500:
                    raise last_error
            except Exception as e:
                last_error = NetworkError(
                    f"Request failed: {e}",
                    {"url": url, "attempt": attempt + 1}
                )

            # Wait before retry (exponential backoff)
            if attempt < self.max_retries - 1:
                await asyncio.sleep(2 ** attempt)

        raise last_error

    # Convenience methods

    async def get(
        self,
        url: str,
        params: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> httpx.Response:
        """Make GET request.

        A per-call `timeout` overrides the client default for this request
        only; when omitted the client-level timeout applies.
        """
        # BUGFIX: previously `timeout=None` was always forwarded to httpx,
        # which *disables* timeouts instead of using the client default.
        # Only pass the kwarg when the caller supplied an override.
        if timeout:
            kwargs["timeout"] = httpx.Timeout(timeout)
        return await self._request(
            "GET",
            url,
            params=params,
            headers=headers,
            **kwargs
        )

    async def post(
        self,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = None,
        **kwargs
    ) -> httpx.Response:
        """Make POST request.

        A per-call `timeout` overrides the client default for this request
        only; when omitted the client-level timeout applies.
        """
        # Same timeout handling as get(): omit the kwarg unless overridden.
        if timeout:
            kwargs["timeout"] = httpx.Timeout(timeout)
        return await self._request(
            "POST",
            url,
            data=data,
            json=json,
            headers=headers,
            **kwargs
        )

    async def put(
        self,
        url: str,
        data: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make PUT request"""
        return await self._request(
            "PUT",
            url,
            data=data,
            json=json,
            headers=headers,
            **kwargs
        )

    async def delete(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make DELETE request"""
        return await self._request(
            "DELETE",
            url,
            headers=headers,
            **kwargs
        )

    async def head(
        self,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> httpx.Response:
        """Make HEAD request"""
        return await self._request(
            "HEAD",
            url,
            headers=headers,
            **kwargs
        )
|
||||
|
||||
|
||||
# Singleton client for general use
http_client = AsyncHTTPClient()


@asynccontextmanager
async def get_http_client(
    base_url: Optional[str] = None,
    timeout: Optional[httpx.Timeout] = None,
    headers: Optional[Dict[str, str]] = None
):
    """
    Context manager yielding a short-lived AsyncHTTPClient that is always
    closed on exit, even when the body raises.

    Usage:
        async with get_http_client(base_url="https://api.example.com") as client:
            response = await client.get("/endpoint")
    """
    temp_client = AsyncHTTPClient(base_url=base_url, timeout=timeout, headers=headers)
    try:
        yield temp_client
    finally:
        await temp_client.close()
|
||||
|
||||
|
||||
# Helper functions for common patterns
|
||||
|
||||
async def fetch_json(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """
    GET *url* via the singleton client and return the decoded JSON body.

    Args:
        url: URL to fetch
        params: Query parameters
        headers: Request headers

    Returns:
        Parsed JSON response
    """
    resp = await http_client.get(url, params=params, headers=headers)
    return resp.json()
|
||||
|
||||
|
||||
async def post_json(
    url: str,
    data: Dict[str, Any],
    headers: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    """
    POST *data* as a JSON body via the singleton client and return the
    decoded JSON response.

    Args:
        url: URL to post to
        data: JSON data to send
        headers: Request headers

    Returns:
        Parsed JSON response
    """
    resp = await http_client.post(url, json=data, headers=headers)
    return resp.json()
|
||||
|
||||
|
||||
async def check_url_accessible(url: str, timeout: float = 5.0) -> bool:
    """
    Probe a URL with a HEAD request via the shared client.

    Any exception (DNS failure, timeout, connection refused, ...) is
    treated as "not accessible" rather than propagated.

    Args:
        url: URL to check
        timeout: Request timeout in seconds

    Returns:
        True when the HEAD request completes with a status code below 400.
    """
    try:
        head_reply = await http_client.head(url, timeout=timeout)
        ok = head_reply.status_code < 400
    except Exception:
        return False
    return ok
|
||||
|
||||
|
||||
async def download_file(
    url: str,
    save_path: str,
    headers: Optional[Dict[str, str]] = None,
    chunk_size: int = 8192
) -> int:
    """
    Stream a file from *url* to *save_path* using the shared client.

    Parent directories of the destination are created as needed.

    Args:
        url: URL to download
        save_path: Destination file path
        headers: Optional request headers
        chunk_size: Streaming chunk size in bytes

    Returns:
        Total number of bytes written to disk.
    """
    from pathlib import Path

    destination = Path(save_path)
    client = await http_client._get_client()
    bytes_written = 0

    async with client.stream("GET", url, headers=headers) as response:
        response.raise_for_status()

        # Create the destination directory tree before writing
        destination.parent.mkdir(parents=True, exist_ok=True)

        with open(destination, 'wb') as out:
            async for piece in response.aiter_bytes(chunk_size):
                out.write(piece)
                bytes_written += len(piece)

    return bytes_written
|
||||
311
web/backend/core/responses.py
Normal file
311
web/backend/core/responses.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Standardized Response Format
|
||||
|
||||
Provides consistent response structures for all API endpoints.
|
||||
All responses follow a standardized format for error handling and data delivery.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Generic
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STANDARD RESPONSE MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ErrorDetail(BaseModel):
    """Standard error detail structure for error responses.

    ``timestamp`` is generated at model instantiation as an ISO 8601 UTC
    string; ``isoformat()`` emits "+00:00" which is normalized to "Z".
    """
    error: str = Field(..., description="Error type/code")
    message: str = Field(..., description="Human-readable error message")
    details: Optional[Dict[str, Any]] = Field(None, description="Additional error context")
    timestamp: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"))
|
||||
|
||||
|
||||
class SuccessResponse(BaseModel):
    """Standard success envelope.

    ``success`` defaults to True, so a bare ``SuccessResponse()`` is a
    valid "ok" reply; ``message`` and ``data`` are optional extras.
    """
    success: bool = True
    # Optional human-readable status message
    message: Optional[str] = None
    # Optional payload of any JSON-serializable shape
    data: Optional[Any] = None
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel):
    """Standard page-number-based paginated response shape.

    Matches the dicts produced by the paginated() builder in this module,
    which computes ``has_more`` as ``page * page_size < total``.
    """
    items: List[Any]
    total: int       # total matching items across all pages
    page: int        # current page number
    page_size: int   # items per page
    has_more: bool   # True when further pages remain
|
||||
|
||||
|
||||
class ListResponse(BaseModel):
    """Standard list response: the items plus their count."""
    items: List[Any]
    count: int  # expected to equal len(items); kept explicit for clients
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RESPONSE BUILDERS
|
||||
# ============================================================================
|
||||
|
||||
def success(data: Any = None, message: Optional[str] = None) -> Dict[str, Any]:
    """Build a success response dict.

    Args:
        data: Optional payload. Included whenever it is not None, so falsy
            values such as 0, "" or [] are still delivered.
        message: Optional human-readable message; included only when truthy.

    Returns:
        Dict containing at least {"success": True}.
    """
    # FIX: message defaulted to None but was annotated plain `str`;
    # annotate as Optional[str] to match the actual contract.
    response: Dict[str, Any] = {"success": True}
    if message:
        response["message"] = message
    if data is not None:
        response["data"] = data
    return response
|
||||
|
||||
|
||||
def error(
    error_type: str,
    message: str,
    details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Build a standardized error response dict.

    Args:
        error_type: Machine-readable error type/code
        message: Human-readable error message
        details: Optional extra context; an empty dict is emitted when absent

    Returns:
        Dict with success=False, the error fields, and a UTC "Z" timestamp.
    """
    stamped_at = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    body: Dict[str, Any] = {
        "success": False,
        "error": error_type,
        "message": message,
    }
    body["details"] = details or {}
    body["timestamp"] = stamped_at
    return body
|
||||
|
||||
|
||||
def paginated(
    items: List[Any],
    total: int,
    page: int,
    page_size: int
) -> Dict[str, Any]:
    """Build a page-number-based paginated response.

    ``has_more`` is True while the items consumed so far (page * page_size)
    are fewer than the total.
    """
    remaining = total - page * page_size
    return {
        "items": items,
        "total": total,
        "page": page,
        "page_size": page_size,
        "has_more": remaining > 0,
    }
|
||||
|
||||
|
||||
def list_response(items: List[Any], key: str = "items") -> Dict[str, Any]:
    """Build a simple list response under a custom key, with an item count."""
    payload: Dict[str, Any] = {key: items}
    payload["count"] = len(items)
    return payload
|
||||
|
||||
|
||||
def message_response(message: str) -> Dict[str, str]:
    """Wrap *message* in the standard single-message response shape."""
    return dict(message=message)
|
||||
|
||||
|
||||
def count_response(message: str, count: int) -> Dict[str, Any]:
    """Build a response reporting how many items an operation affected."""
    return dict(message=message, count=count)
|
||||
|
||||
|
||||
def id_response(resource_id: int, message: str = "Resource created successfully") -> Dict[str, Any]:
    """Build a creation response carrying the new resource's ID."""
    return dict(id=resource_id, message=message)
|
||||
|
||||
|
||||
def offset_paginated(
    items: List[Any],
    total: int,
    limit: int,
    offset: int,
    key: str = "items",
    include_timestamp: bool = False,
    **extra_fields
) -> Dict[str, Any]:
    """
    Build an offset/limit-based paginated response — the standard list
    format used across the existing routers.

    Args:
        items: Items for the current page
        total: Total count of all matching items
        limit: Page size (items per page)
        offset: Starting position of this page
        key: Key under which the items are placed (default "items")
        include_timestamp: Add a UTC "Z" timestamp field when True
        **extra_fields: Arbitrary additional top-level fields

    Returns:
        Dict with the paginated response structure.

    Example:
        return offset_paginated(
            items=media_items,
            total=total_count,
            limit=50,
            offset=0,
            key="media",
            stats={"images": 100, "videos": 50}
        )
    """
    payload: Dict[str, Any] = {key: items, "total": total, "limit": limit, "offset": offset}

    if include_timestamp:
        payload["timestamp"] = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    # Caller-supplied extras are merged last and may override base fields.
    payload.update(extra_fields)
    return payload
|
||||
|
||||
|
||||
def batch_operation_response(
    succeeded: bool,
    processed: Optional[List[Any]] = None,
    errors: Optional[List[Dict[str, Any]]] = None,
    message: Optional[str] = None
) -> Dict[str, Any]:
    """
    Build a response for batch operations (batch delete, batch move, etc.).

    Provides a standardized format for operations affecting multiple items.
    The ``processed``/``errors``/``message`` fields are only included when
    non-empty; the counts and timestamp are always present.

    Args:
        succeeded: Overall success status
        processed: Successfully processed items (None treated as empty)
        errors: Error details for failed items (None treated as empty)
        message: Optional message

    Returns:
        Dict with the batch operation response structure.

    Example:
        return batch_operation_response(
            succeeded=True,
            processed=["/path/to/file1.jpg", "/path/to/file2.jpg"],
            errors=[{"file": "/path/to/file3.jpg", "error": "File not found"}]
        )
    """
    # FIX: parameters default to None but were annotated as non-Optional
    # List/str; annotations now match the actual contract.
    processed = processed or []
    errors = errors or []

    response: Dict[str, Any] = {
        "success": succeeded,
        "processed_count": len(processed),
        "error_count": len(errors),
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    }

    if processed:
        response["processed"] = processed

    if errors:
        response["errors"] = errors

    if message:
        response["message"] = message

    return response
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATE/TIME UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def to_iso8601(dt: datetime) -> Optional[str]:
    """
    Convert a datetime to the API's standard ISO 8601 UTC format,
    e.g. "2025-12-04T10:30:00Z".

    Naive datetimes are assumed to already be in UTC; aware datetimes are
    converted to UTC first. Returns None for None input.
    """
    if dt is None:
        return None

    utc_dt = dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt.astimezone(timezone.utc)
    return utc_dt.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def from_iso8601(date_string: str) -> Optional[datetime]:
    """
    Parse an ISO 8601-ish date string into a timezone-aware datetime.

    Tries a list of common explicit formats first, then falls back to
    datetime.fromisoformat() (with "Z" mapped to "+00:00"). Every
    successful parse returns an aware datetime: naive results are assumed
    to be UTC. Returns None for empty input or when nothing matches.
    """
    if not date_string:
        return None

    # Common formats to try, most specific first
    formats = [
        "%Y-%m-%dT%H:%M:%SZ",        # ISO with Z
        "%Y-%m-%dT%H:%M:%S.%fZ",     # ISO with microseconds and Z
        "%Y-%m-%dT%H:%M:%S",         # ISO without Z
        "%Y-%m-%dT%H:%M:%S.%f",      # ISO with microseconds
        "%Y-%m-%d %H:%M:%S",         # Space separator
        "%Y-%m-%d",                  # Date only
    ]

    for fmt in formats:
        try:
            dt = datetime.strptime(date_string, fmt)
        except ValueError:
            continue
        # strptime with these formats always yields naive datetimes
        return dt.replace(tzinfo=timezone.utc)

    # Fallback: fromisoformat handles offsets ("+02:00") and partial times
    try:
        dt = datetime.fromisoformat(date_string.replace('Z', '+00:00'))
    except ValueError:
        return None

    # BUGFIX: this path previously leaked naive datetimes (e.g. for
    # "2025-01-02T03:04"); normalize to UTC so callers always get an
    # aware datetime, consistent with the strptime paths above.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt
|
||||
|
||||
|
||||
def format_timestamp(
    timestamp: Optional[str],
    output_format: str = "iso8601"
) -> Optional[str]:
    """
    Re-format a timestamp string.

    Args:
        timestamp: Input timestamp string
        output_format: "iso8601", "unix", or an arbitrary strftime pattern

    Returns:
        The reformatted timestamp; the original string when it cannot be
        parsed; None for empty/None input.
    """
    if not timestamp:
        return None

    parsed = from_iso8601(timestamp)
    if parsed is None:
        # Unparseable input is passed through unchanged
        return timestamp

    if output_format == "iso8601":
        return to_iso8601(parsed)
    if output_format == "unix":
        return str(int(parsed.timestamp()))
    return parsed.strftime(output_format)
|
||||
|
||||
|
||||
def now_iso8601() -> str:
    """Get current UTC time in ISO 8601 format (e.g. "2025-12-04T10:30:00Z")."""
    return to_iso8601(datetime.now(timezone.utc))
|
||||
582
web/backend/core/utils.py
Normal file
582
web/backend/core/utils.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
Shared Utility Functions
|
||||
|
||||
Common helper functions used across multiple routers.
|
||||
"""
|
||||
|
||||
import io
|
||||
import sqlite3
|
||||
import hashlib
|
||||
import subprocess
|
||||
from collections import OrderedDict
|
||||
from contextlib import closing
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from PIL import Image
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# THUMBNAIL LRU CACHE
|
||||
# ============================================================================
|
||||
|
||||
class ThumbnailLRUCache:
    """Thread-safe LRU cache for thumbnail binary data.

    Bounded both by entry count and by total byte size. Avoids SQLite
    lookups for frequently accessed thumbnails; used by the media.py and
    recycle.py routers.
    """

    def __init__(self, max_size: int = 500, max_memory_mb: int = 100):
        self._cache: OrderedDict[str, bytes] = OrderedDict()
        self._lock = Lock()
        self._max_size = max_size
        self._max_memory = max_memory_mb * 1024 * 1024  # byte budget
        self._current_memory = 0

    def get(self, key: str) -> Optional[bytes]:
        """Return the cached bytes for *key* (marking it most-recent), or None."""
        with self._lock:
            if key not in self._cache:
                return None
            self._cache.move_to_end(key)
            return self._cache[key]

    def put(self, key: str, data: bytes) -> None:
        """Insert *data* under *key*, evicting least-recently-used entries as needed."""
        with self._lock:
            size = len(data)

            # Refuse to cache single items over 1MB
            if size > 1024 * 1024:
                return

            # Replacing an existing entry: release its memory first
            previous = self._cache.pop(key, None)
            if previous is not None:
                self._current_memory -= len(previous)

            # Evict from the LRU end until both limits are satisfied
            while self._cache and (
                len(self._cache) >= self._max_size
                or self._current_memory + size > self._max_memory
            ):
                _, evicted = self._cache.popitem(last=False)
                self._current_memory -= len(evicted)

            self._cache[key] = data
            self._current_memory += size

    def clear(self) -> None:
        """Drop every cached entry and reset the memory accounting."""
        with self._lock:
            self._cache.clear()
            self._current_memory = 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SQL FILTER CONSTANTS
|
||||
# ============================================================================
|
||||
|
||||
# Valid media file filters (excluding phrase checks, must have valid extension)
|
||||
# Used by downloads, health, and analytics endpoints
|
||||
MEDIA_FILTERS = """
|
||||
(filename NOT LIKE '%_phrase_checked_%' OR filename IS NULL)
|
||||
AND (file_path IS NOT NULL AND file_path != '' OR platform = 'forums')
|
||||
AND (LENGTH(filename) > 20 OR filename LIKE '%_%_%')
|
||||
AND (
|
||||
filename LIKE '%.jpg' OR filename LIKE '%.jpeg' OR
|
||||
filename LIKE '%.png' OR filename LIKE '%.gif' OR
|
||||
filename LIKE '%.heic' OR filename LIKE '%.heif' OR
|
||||
filename LIKE '%.mp4' OR filename LIKE '%.mov' OR
|
||||
filename LIKE '%.webm' OR filename LIKE '%.m4a' OR
|
||||
filename LIKE '%.mp3' OR filename LIKE '%.avi' OR
|
||||
filename LIKE '%.mkv' OR filename LIKE '%.flv'
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PATH VALIDATION
|
||||
# ============================================================================
|
||||
|
||||
# Allowed base paths for file operations
|
||||
# Whitelist of base directories that file operations may touch; consumed by
# validate_file_path() as its default allow-list.
ALLOWED_PATHS = [
    settings.MEDIA_BASE_PATH,
    settings.REVIEW_PATH,
    settings.RECYCLE_PATH,
    # Hard-coded extra roots; NOTE(review): consider moving these into settings
    Path('/opt/media-downloader/temp/manual_import'),
    Path('/opt/immich/paid'),
    Path('/opt/immich/el'),
    Path('/opt/immich/elv'),
]
|
||||
|
||||
|
||||
def validate_file_path(
    file_path: str,
    allowed_bases: Optional[List[Path]] = None,
    require_exists: bool = False
) -> Path:
    """
    Resolve *file_path* and ensure it lives under one of the allowed base
    directories. Prevents path traversal attacks.

    Args:
        file_path: Path to validate
        allowed_bases: Permitted base paths (defaults to ALLOWED_PATHS)
        require_exists: If True, also verify the file exists

    Returns:
        Fully resolved Path object

    Raises:
        HTTPException: 403 when outside allowed directories, 404 when
        require_exists fails, 400 when the path cannot be resolved.
    """
    bases = ALLOWED_PATHS if allowed_bases is None else allowed_bases

    requested_path = Path(file_path)

    try:
        resolved = requested_path.resolve()

        def _inside(base: Path) -> bool:
            # relative_to raises ValueError when resolved is not under base
            try:
                resolved.relative_to(base.resolve())
            except ValueError:
                return False
            return True

        if not any(_inside(base) for base in bases):
            raise HTTPException(status_code=403, detail="Access denied")

        if require_exists and not resolved.exists():
            raise HTTPException(status_code=404, detail="File not found")

    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid file path")

    return resolved
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# QUERY FILTER BUILDER
|
||||
# ============================================================================
|
||||
|
||||
def build_media_filter_query(
    platform: Optional[str] = None,
    source: Optional[str] = None,
    media_type: Optional[str] = None,
    location: Optional[str] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    table_alias: str = "fi"
) -> Tuple[str, List]:
    """
    Build the SQL WHERE-clause fragment and parameter list for common
    media queries.

    Centralizes the filter logic previously duplicated across the media,
    downloads and review routers. Date filters compare against the newest
    downloads.post_date for the file when present, else the inventory
    created_date.

    Args:
        platform: Filter by platform (e.g. 'instagram', 'tiktok')
        source: Filter by source (e.g. 'stories', 'posts')
        media_type: Filter by media type ('image', 'video'; 'all' disables)
        location: Filter by location ('final', 'review', 'recycle')
        date_from: Start date filter (ISO format), inclusive
        date_to: End date filter (ISO format), inclusive
        table_alias: SQL table alias; non-identifier values fall back to
            'fi' since the alias is interpolated directly into SQL

    Returns:
        Tuple of (WHERE-clause conditions, parameter list); "1=1" when no
        filter is active.

    Example:
        conditions, params = build_media_filter_query(platform="instagram", media_type="video")
        cursor.execute(f"SELECT * FROM file_inventory fi WHERE {conditions}", params)
    """
    # The alias is interpolated into SQL, so it must be a plain identifier.
    if not table_alias.isidentifier():
        table_alias = "fi"

    conditions: List[str] = []
    params: List = []

    # Simple equality filters as (value, clause) pairs; falsy values skip.
    equality_filters = (
        (platform, f"{table_alias}.platform = ?"),
        (source, f"{table_alias}.source = ?"),
        (media_type if media_type != 'all' else None, f"{table_alias}.media_type = ?"),
        (location, f"{table_alias}.location = ?"),
    )
    for value, clause in equality_filters:
        if value:
            conditions.append(clause)
            params.append(value)

    if date_from:
        # Use COALESCE to handle both post_date from downloads and created_date from file_inventory
        conditions.append(f"""
            DATE(COALESCE(
                (SELECT MAX(d.post_date) FROM downloads d WHERE d.file_path = {table_alias}.file_path),
                {table_alias}.created_date
            )) >= ?
        """)
        params.append(date_from)

    if date_to:
        conditions.append(f"""
            DATE(COALESCE(
                (SELECT MAX(d.post_date) FROM downloads d WHERE d.file_path = {table_alias}.file_path),
                {table_alias}.created_date
            )) <= ?
        """)
        params.append(date_to)

    return (" AND ".join(conditions) if conditions else "1=1"), params
|
||||
|
||||
|
||||
def build_platform_list_filter(
    platforms: Optional[List[str]] = None,
    table_alias: str = "fi"
) -> Tuple[str, List[str]]:
    """
    Build a SQL IN clause for filtering by multiple platforms.

    Args:
        platforms: Platform names to filter by (None/empty yields the
            no-op condition "1=1")
        table_alias: SQL table alias; must be a valid identifier, otherwise
            falls back to "fi" (matching build_media_filter_query)

    Returns:
        Tuple of (SQL condition string, list of parameters)
    """
    # FIX: validate the alias like build_media_filter_query does — it is
    # interpolated directly into SQL, so an unchecked value was an
    # injection vector and an inconsistency between the two builders.
    if not table_alias.isidentifier():
        table_alias = "fi"

    if not platforms:
        return "1=1", []

    placeholders = ",".join(["?"] * len(platforms))
    return f"{table_alias}.platform IN ({placeholders})", platforms
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# THUMBNAIL GENERATION
|
||||
# ============================================================================
|
||||
|
||||
def generate_image_thumbnail(file_path: Path, max_size: Tuple[int, int] = (300, 300)) -> Optional[bytes]:
    """
    Generate a JPEG thumbnail for an image file.

    Transparent/palette images are flattened onto a white background before
    encoding, since JPEG has no alpha channel.

    Args:
        file_path: Path to image file
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes, or None if the image cannot be processed.
    """
    try:
        image = Image.open(file_path)
        image.thumbnail(max_size, Image.Resampling.LANCZOS)

        if image.mode in ('RGBA', 'LA', 'P'):
            canvas = Image.new('RGB', image.size, (255, 255, 255))
            if image.mode == 'P':
                image = image.convert('RGBA')
            # Use the alpha channel as the paste mask when one exists
            alpha = image.split()[-1] if image.mode in ('RGBA', 'LA') else None
            canvas.paste(image, mask=alpha)
            image = canvas

        out = io.BytesIO()
        image.save(out, format='JPEG', quality=85)
        return out.getvalue()
    except Exception:
        return None
|
||||
|
||||
|
||||
def generate_video_thumbnail(file_path: Path, max_size: Tuple[int, int] = (300, 300)) -> Optional[bytes]:
    """
    Generate a JPEG thumbnail for a video by extracting one frame via ffmpeg.

    Tries the frame at t=1s first, then falls back to the very first frame.

    Args:
        file_path: Path to video file
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes, or None if ffmpeg fails or produces no output.
    """
    for seek in ('00:00:01.000', '00:00:00.000'):
        try:
            proc = subprocess.run(
                [
                    'ffmpeg',
                    '-ss', seek,
                    '-i', str(file_path),
                    '-vframes', '1',
                    '-f', 'image2pipe',
                    '-vcodec', 'mjpeg',
                    '-',
                ],
                capture_output=True,
                timeout=30,
            )
            if proc.returncode != 0 or not proc.stdout:
                continue

            # Resize the extracted frame down to the thumbnail bounds
            frame = Image.open(io.BytesIO(proc.stdout))
            frame.thumbnail(max_size, Image.Resampling.LANCZOS)

            out = io.BytesIO()
            frame.save(out, format='JPEG', quality=85)
            return out.getvalue()
        except Exception:
            continue

    return None
|
||||
|
||||
|
||||
def get_or_create_thumbnail(
    file_path: Union[str, Path],
    media_type: str,
    content_hash: Optional[str] = None,
    max_size: Tuple[int, int] = (300, 300)
) -> Optional[bytes]:
    """
    Get thumbnail from cache or generate and cache it.

    Uses the thumbnails.db schema: file_hash (PK), file_path, thumbnail_data,
    created_at, file_mtime.

    Lookup strategy:
        1. Try content_hash against file_hash column (survives file moves)
        2. Fall back to file_path lookup (legacy thumbnails)
        3. Generate and cache if not found

    All database errors are swallowed: a cache failure degrades to
    regeneration (reads) or a lost cache entry (writes), never an exception.

    Args:
        file_path: Path to media file
        media_type: 'image' or 'video'
        content_hash: Optional pre-computed hash (when absent, the hash is
            derived from the path STRING, not the file contents)
        max_size: Maximum thumbnail dimensions

    Returns:
        JPEG bytes or None if generation fails
    """
    file_path = Path(file_path)
    thumb_db_path = settings.PROJECT_ROOT / 'database' / 'thumbnails.db'

    # Compute hash if not provided — note this hashes the path string
    file_hash = content_hash if content_hash else hashlib.sha256(str(file_path).encode()).hexdigest()

    # Try to get from cache (skip mtime check — downloaded media files don't change)
    try:
        with sqlite3.connect(str(thumb_db_path), timeout=30.0) as conn:
            cursor = conn.cursor()

            # 1. Try file_hash lookup (primary key)
            cursor.execute(
                "SELECT thumbnail_data FROM thumbnails WHERE file_hash = ?",
                (file_hash,)
            )
            result = cursor.fetchone()
            if result and result[0]:
                return result[0]

            # 2. Fall back to file_path lookup (legacy thumbnails)
            cursor.execute(
                "SELECT thumbnail_data FROM thumbnails WHERE file_path = ?",
                (str(file_path),)
            )
            result = cursor.fetchone()
            if result and result[0]:
                return result[0]
    except Exception:
        pass  # best-effort cache read; fall through to generation

    # Only proceed with generation if the file actually exists
    if not file_path.exists():
        return None

    # Get mtime only when we need to generate and cache a new thumbnail
    try:
        file_mtime = file_path.stat().st_mtime
    except OSError:
        file_mtime = 0

    # Generate thumbnail (videos go through ffmpeg, everything else PIL)
    thumbnail_data = None
    if media_type == 'video':
        thumbnail_data = generate_video_thumbnail(file_path, max_size)
    else:
        thumbnail_data = generate_image_thumbnail(file_path, max_size)

    # Cache the thumbnail (best-effort; failures are ignored)
    if thumbnail_data:
        try:
            from .responses import now_iso8601
            with sqlite3.connect(str(thumb_db_path), timeout=30.0) as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO thumbnails
                    (file_hash, file_path, thumbnail_data, created_at, file_mtime)
                    VALUES (?, ?, ?, ?, ?)
                """, (file_hash, str(file_path), thumbnail_data, now_iso8601(), file_mtime))
                conn.commit()
        except Exception:
            pass

    return thumbnail_data
|
||||
|
||||
|
||||
def get_media_dimensions(file_path: str, width: Optional[int] = None, height: Optional[int] = None) -> Tuple[Optional[int], Optional[int]]:
    """
    Get media dimensions, falling back to the metadata cache if not provided.

    The lookup never raises: any cache error falls through to returning the
    inputs unchanged.

    Args:
        file_path: Path to media file
        width: Width from file_inventory (may be None)
        height: Height from file_inventory (may be None)

    Returns:
        Tuple of (width, height), or (None, None) if not available
    """
    if width is not None and height is not None:
        return (width, height)

    try:
        metadata_db_path = settings.PROJECT_ROOT / 'database' / 'media_metadata.db'

        # Cache key is the SHA-256 of the path STRING, not the file contents
        file_hash = hashlib.sha256(file_path.encode()).hexdigest()
        with closing(sqlite3.connect(str(metadata_db_path))) as conn:
            cursor = conn.execute(
                "SELECT width, height FROM media_metadata WHERE file_hash = ?",
                (file_hash,)
            )
            result = cursor.fetchone()

            if result:
                return (result[0], result[1])

    except Exception:
        pass  # best-effort cache lookup

    # Fall back to whatever the caller supplied (possibly (None, None))
    return (width, height)
|
||||
|
||||
|
||||
def get_media_dimensions_batch(file_paths: List[str]) -> Dict[str, Tuple[int, int]]:
    """
    Batch-fetch media dimensions for many files in one query.

    Avoids the N+1 problem of calling get_media_dimensions per file. Files
    without cached metadata are simply absent from the result; any database
    error yields a (possibly empty) partial result rather than an exception.

    Args:
        file_paths: List of file paths to look up

    Returns:
        Dict mapping file_path -> (width, height)
    """
    if not file_paths:
        return {}

    dimensions: Dict[str, Tuple[int, int]] = {}

    try:
        metadata_db_path = settings.PROJECT_ROOT / 'database' / 'media_metadata.db'

        # The metadata cache is keyed by SHA-256 of the path string
        hash_to_path = {
            hashlib.sha256(path.encode()).hexdigest(): path
            for path in file_paths
        }

        with closing(sqlite3.connect(str(metadata_db_path))) as conn:
            marks = ','.join('?' * len(hash_to_path))
            rows = conn.execute(
                f"SELECT file_hash, width, height FROM media_metadata WHERE file_hash IN ({marks})",
                list(hash_to_path.keys())
            ).fetchall()

        for digest, width, height in rows:
            mapped_path = hash_to_path.get(digest)
            if mapped_path is not None:
                dimensions[mapped_path] = (width, height)

    except Exception:
        pass  # best-effort: return whatever was gathered

    return dimensions
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DATABASE UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
def update_file_path_in_all_tables(db, old_path: str, new_path: str):
    """
    Update file path in all relevant database tables.

    Used when moving files between locations (review -> final, etc.)
    to keep database references consistent.

    Best-effort: any failure is logged as a warning and swallowed rather
    than raised.

    Args:
        db: UnifiedDatabase instance
        old_path: The old file path to replace
        new_path: The new file path to use
    """
    from modules.universal_logger import get_logger
    logger = get_logger('API')

    try:
        with db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()

            # Tables that are expected to exist
            cursor.execute('UPDATE downloads SET file_path = ? WHERE file_path = ?',
                          (new_path, old_path))
            cursor.execute('UPDATE instagram_perceptual_hashes SET file_path = ? WHERE file_path = ?',
                          (new_path, old_path))
            cursor.execute('UPDATE face_recognition_scans SET file_path = ? WHERE file_path = ?',
                          (new_path, old_path))

            try:
                cursor.execute('UPDATE semantic_embeddings SET file_path = ? WHERE file_path = ?',
                              (new_path, old_path))
            except sqlite3.OperationalError:
                pass  # Table may not exist

            conn.commit()

    except Exception as e:
        logger.warning(f"Failed to update file paths in tables: {e}", module="Database")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FACE RECOGNITION UTILITIES
|
||||
# ============================================================================
|
||||
|
||||
# Cached FaceRecognitionModule singleton to avoid loading InsightFace models on every request.
# Keyed by id(db), so each live database object gets its own module instance.
_face_module_cache: Dict[int, 'FaceRecognitionModule'] = {}
|
||||
|
||||
|
||||
def get_face_module(db, module_name: str = "FaceAPI"):
    """
    Return the cached FaceRecognitionModule for *db*, creating it on first use.

    Keeping one module per database object avoids reloading the heavy
    InsightFace models on every request.

    Args:
        db: UnifiedDatabase instance
        module_name: Name to use in log messages

    Returns:
        FaceRecognitionModule instance
    """
    from modules.face_recognition_module import FaceRecognitionModule
    from modules.universal_logger import get_logger
    logger = get_logger('API')

    cache_key = id(db)
    if cache_key not in _face_module_cache:
        logger.info("Creating cached FaceRecognitionModule instance", module=module_name)
        _face_module_cache[cache_key] = FaceRecognitionModule(unified_db=db)
    return _face_module_cache[cache_key]
|
||||
Reference in New Issue
Block a user