366
web/backend/routers/semantic.py
Normal file
366
web/backend/routers/semantic.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Semantic Search Router
|
||||
|
||||
Handles CLIP-based semantic search operations:
|
||||
- Text-based image/video search
|
||||
- Similar file search
|
||||
- Embedding generation and management
|
||||
- Model settings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, Depends, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from ..core.dependencies import get_current_user, require_admin, get_app_state
|
||||
from ..core.exceptions import handle_exceptions, ValidationError
|
||||
from modules.semantic_search import get_semantic_search
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
# Shared logger for all endpoints in this router.
logger = get_logger('API')

# All routes below are mounted under /api/semantic.
router = APIRouter(prefix="/api/semantic", tags=["Semantic Search"])
# Per-client rate limiter keyed on the remote IP address.
limiter = Limiter(key_func=get_remote_address)

# Batch limit for embedding generation
# Upper bound on files processed in a single full re-index pass
# (used by /reindex and the model-change re-index in /settings).
EMBEDDING_BATCH_LIMIT = 10000
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# PYDANTIC MODELS
|
||||
# ============================================================================
|
||||
|
||||
class SemanticSearchRequest(BaseModel):
    """Request body for text-based semantic search.

    NOTE(review): the /search endpoint currently declares these fields as
    individual Body(...) parameters instead of using this model — confirm
    whether this model is still referenced elsewhere.
    """
    # Natural-language query text.
    query: str
    # Maximum number of results to return.
    limit: int = 50
    # Optional platform filter.
    platform: Optional[str] = None
    # Optional source filter.
    source: Optional[str] = None
    # Minimum similarity score (0.0–1.0) for a match.
    threshold: float = 0.2
|
||||
|
||||
|
||||
class GenerateEmbeddingsRequest(BaseModel):
    """Request body for the embedding-generation endpoint.

    NOTE(review): /generate currently reads these values via Body(...)
    parameters rather than this model — confirm it is still used.
    """
    # Maximum number of files to embed in this batch.
    limit: int = 100
    # Optional platform filter restricting which files are embedded.
    platform: Optional[str] = None
|
||||
|
||||
|
||||
class SemanticSettingsUpdate(BaseModel):
    """Request body for updating semantic-search settings.

    NOTE(review): POST /settings currently reads these values via Body(...)
    parameters rather than this model — confirm it is still used.
    """
    # CLIP model name; changing it invalidates stored embeddings.
    model: Optional[str] = None
    # Default similarity threshold (expected range 0–1).
    threshold: Optional[float] = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SEARCH ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/stats")
@limiter.limit("120/minute")
@handle_exceptions
async def get_semantic_stats(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get statistics about semantic search embeddings."""
    state = get_app_state()
    engine = get_semantic_search(state.db)
    return engine.get_embedding_stats()
|
||||
|
||||
|
||||
@router.post("/search")
@limiter.limit("30/minute")
@handle_exceptions
async def semantic_search(
    request: Request,
    current_user: Dict = Depends(get_current_user),
    query: str = Body(..., embed=True),
    limit: int = Body(50, ge=1, le=200),
    platform: Optional[str] = Body(None),
    source: Optional[str] = Body(None),
    threshold: float = Body(0.2, ge=0.0, le=1.0)
):
    """Search for images/videos using natural language query."""
    state = get_app_state()
    engine = get_semantic_search(state.db)
    # Delegate the actual vector search to the semantic-search module.
    hits = engine.search_by_text(
        query=query,
        limit=limit,
        platform=platform,
        source=source,
        threshold=threshold,
    )
    return {"results": hits, "count": len(hits), "query": query}
|
||||
|
||||
|
||||
@router.post("/similar/{file_id}")
@limiter.limit("30/minute")
@handle_exceptions
async def find_similar_files(
    request: Request,
    file_id: int,
    current_user: Dict = Depends(get_current_user),
    limit: int = Query(50, ge=1, le=200),
    platform: Optional[str] = Query(None),
    source: Optional[str] = Query(None),
    threshold: float = Query(0.5, ge=0.0, le=1.0)
):
    """Find files similar to a given file."""
    state = get_app_state()
    engine = get_semantic_search(state.db)
    # Similarity search keyed on the embedding of an existing file.
    matches = engine.search_by_file_id(
        file_id=file_id,
        limit=limit,
        platform=platform,
        source=source,
        threshold=threshold,
    )
    return {"results": matches, "count": len(matches), "source_file_id": file_id}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EMBEDDING GENERATION ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.post("/generate")
@limiter.limit("10/minute")
@handle_exceptions
async def generate_embeddings(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(get_current_user),
    limit: int = Body(100, ge=1, le=1000),
    platform: Optional[str] = Body(None)
):
    """Generate CLIP embeddings for files that don't have them yet.

    Schedules a background batch job and returns immediately.  Progress and
    completion events are pushed over the websocket manager when one is
    attached to app state.  Refuses to start while another indexing job is
    already running.
    """
    app_state = get_app_state()

    if app_state.indexing_running:
        return {
            "success": False,
            "message": "Indexing already in progress",
            "already_running": True,
            "status": "already_running"
        }

    search = get_semantic_search(app_state.db)

    # Flag the shared indexing state before scheduling so /status and any
    # concurrent start requests observe the job immediately.
    app_state.indexing_running = True
    app_state.indexing_start_time = time.time()

    # Capture the running loop so the worker thread can schedule websocket
    # broadcasts back onto it.  get_running_loop() replaces the
    # get_event_loop() call used previously, which is deprecated inside a
    # coroutine since Python 3.10.
    loop = asyncio.get_running_loop()
    manager = getattr(app_state, 'websocket_manager', None)

    def progress_callback(processed: int, total: int, current_file: str):
        """Broadcast progress on every 10th file (and on the final file)."""
        try:
            if processed % 10 == 0 or processed == total:
                stats = search.get_embedding_stats()
                if manager:
                    asyncio.run_coroutine_threadsafe(
                        manager.broadcast({
                            "type": "embedding_progress",
                            "processed": processed,
                            "total": total,
                            "current_file": current_file,
                            "total_embeddings": stats.get('total_embeddings', 0),
                            "coverage_percent": stats.get('coverage_percent', 0)
                        }),
                        loop
                    )
        except Exception:
            # Best-effort: progress reporting must never abort generation.
            pass

    def run_generation():
        """Background worker: run the batch, broadcast completion, reset flags."""
        try:
            results = search.generate_embeddings_batch(
                limit=limit,
                platform=platform,
                progress_callback=progress_callback
            )
            logger.info(f"Embedding generation complete: {results}", module="SemanticSearch")
            try:
                stats = search.get_embedding_stats()
                if manager:
                    asyncio.run_coroutine_threadsafe(
                        manager.broadcast({
                            "type": "embedding_complete",
                            "results": results,
                            "total_embeddings": stats.get('total_embeddings', 0),
                            "coverage_percent": stats.get('coverage_percent', 0)
                        }),
                        loop
                    )
            except Exception:
                # Completion broadcast is best-effort as well.
                pass
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}", module="SemanticSearch")
        finally:
            # Always release the indexing lock, even on failure.
            app_state.indexing_running = False
            app_state.indexing_start_time = None

    background_tasks.add_task(run_generation)

    return {
        "message": f"Started embedding generation for up to {limit} files",
        "status": "processing"
    }
|
||||
|
||||
|
||||
@router.get("/status")
@limiter.limit("60/minute")
@handle_exceptions
async def get_semantic_indexing_status(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Check if semantic indexing is currently running."""
    state = get_app_state()
    running = state.indexing_running
    started = state.indexing_start_time
    # Elapsed time is only meaningful while a job is active.
    elapsed = int(time.time() - started) if running and started else None
    return {
        "indexing_running": running,
        "elapsed_seconds": elapsed
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SETTINGS ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/settings")
@limiter.limit("30/minute")
@handle_exceptions
async def get_semantic_settings(
    request: Request,
    current_user: Dict = Depends(get_current_user)
):
    """Get semantic search settings."""
    state = get_app_state()
    cfg = state.settings.get('semantic_search', {})
    # Fall back to the built-in defaults when a key is missing.
    return {
        "model": cfg.get('model', 'clip-ViT-B-32'),
        "threshold": cfg.get('threshold', 0.2)
    }
|
||||
|
||||
|
||||
@router.post("/settings")
@limiter.limit("10/minute")
@handle_exceptions
async def update_semantic_settings(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(require_admin),
    # Optional[...] fixes the previous mis-annotation (`str = Body(None)`),
    # which declared a None default on a non-optional type.
    model: Optional[str] = Body(None, embed=True),
    threshold: Optional[float] = Body(None, embed=True)
):
    """Update semantic search settings (admin only).

    Supported fields:
    - model: CLIP model name; changing it clears all stored embeddings
      (they are model-specific) and starts a background re-index.
    - threshold: default similarity threshold in [0, 1].

    Raises:
        ValidationError: on an unknown model name or an out-of-range threshold.
    """
    app_state = get_app_state()

    current_settings = app_state.settings.get('semantic_search', {}) or {}
    old_model = current_settings.get('model', 'clip-ViT-B-32') if isinstance(current_settings, dict) else 'clip-ViT-B-32'
    model_changed = False

    # Copy so we never mutate the object returned by the settings store.
    new_settings = dict(current_settings) if isinstance(current_settings, dict) else {}

    if model:
        valid_models = ['clip-ViT-B-32', 'clip-ViT-B-16', 'clip-ViT-L-14']
        if model not in valid_models:
            raise ValidationError(f"Invalid model. Must be one of: {valid_models}")
        if model != old_model:
            model_changed = True
        new_settings['model'] = model

    if threshold is not None:
        if threshold < 0 or threshold > 1:
            raise ValidationError("Threshold must be between 0 and 1")
        new_settings['threshold'] = threshold

    app_state.settings.set('semantic_search', new_settings, category='ai')

    if model_changed:
        logger.info(f"Model changed from {old_model} to {model}, clearing embeddings", module="SemanticSearch")

        # Embeddings produced by the old model are incompatible with the new
        # one, so they are dropped wholesale before re-indexing.
        with app_state.db.get_connection(for_write=True) as conn:
            cursor = conn.cursor()
            cursor.execute('DELETE FROM content_embeddings')
            deleted = cursor.rowcount
            logger.info(f"Cleared {deleted} embeddings for model change", module="SemanticSearch")

        # Consistency fix: mark the shared indexing state like /generate and
        # /reindex do.  Previously this path skipped the flag, so /status
        # reported idle and a concurrent /generate could start mid-re-index.
        app_state.indexing_running = True
        app_state.indexing_start_time = time.time()

        def run_reindex_for_model_change():
            """Background worker: rebuild embeddings with the new model."""
            try:
                # force_reload=True so the engine is re-created with the new model.
                search = get_semantic_search(app_state.db, force_reload=True)
                results = search.generate_embeddings_batch(limit=EMBEDDING_BATCH_LIMIT)
                logger.info(f"Model change re-index complete: {results}", module="SemanticSearch")
            except Exception as e:
                logger.error(f"Model change re-index failed: {e}", module="SemanticSearch")
            finally:
                # Always release the indexing lock, even on failure.
                app_state.indexing_running = False
                app_state.indexing_start_time = None

        background_tasks.add_task(run_reindex_for_model_change)

        logger.info(f"Semantic search settings updated: {new_settings} (re-indexing started)", module="SemanticSearch")
        return {"success": True, "settings": new_settings, "reindexing": True, "message": f"Model changed to {model}, re-indexing started"}

    logger.info(f"Semantic search settings updated: {new_settings}", module="SemanticSearch")
    return {"success": True, "settings": new_settings, "reindexing": False}
|
||||
|
||||
|
||||
@router.post("/reindex")
@limiter.limit("2/minute")
@handle_exceptions
async def reindex_embeddings(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: Dict = Depends(require_admin)
):
    """Clear and regenerate all embeddings."""
    state = get_app_state()

    if state.indexing_running:
        return {"success": False, "message": "Indexing already in progress", "already_running": True}

    engine = get_semantic_search(state.db)

    # Drop every stored embedding before regenerating from scratch.
    with state.db.get_connection(for_write=True) as conn:
        cur = conn.cursor()
        cur.execute('DELETE FROM content_embeddings')
        deleted = cur.rowcount
        logger.info(f"Cleared {deleted} embeddings for reindexing", module="SemanticSearch")

    # Mark the shared indexing state before scheduling the worker.
    state.indexing_running = True
    state.indexing_start_time = time.time()

    def run_reindex():
        """Background worker: regenerate embeddings, then reset state flags."""
        try:
            outcome = engine.generate_embeddings_batch(limit=EMBEDDING_BATCH_LIMIT)
            logger.info(f"Reindex complete: {outcome}", module="SemanticSearch")
        except Exception as e:
            logger.error(f"Reindex failed: {e}", module="SemanticSearch")
        finally:
            state.indexing_running = False
            state.indexing_start_time = None

    background_tasks.add_task(run_reindex)

    return {"success": True, "message": "Re-indexing started in background"}
|
||||
|
||||
|
||||
@router.post("/clear")
@limiter.limit("2/minute")
@handle_exceptions
async def clear_embeddings(
    request: Request,
    current_user: Dict = Depends(require_admin)
):
    """Clear all embeddings."""
    state = get_app_state()

    # Delete every row; rowcount tells us how many were removed.
    with state.db.get_connection(for_write=True) as conn:
        cur = conn.cursor()
        cur.execute('DELETE FROM content_embeddings')
        removed = cur.rowcount

    logger.info(f"Cleared {removed} embeddings", module="SemanticSearch")
    return {"success": True, "deleted": removed}
|
||||
Reference in New Issue
Block a user