Files
media-downloader/modules/semantic_search.py
Todd 0d7b2b1aab Initial commit
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-29 22:42:55 -04:00

729 lines
25 KiB
Python

#!/usr/bin/env python3
"""
Semantic Search Module using CLIP
Provides image/video similarity search and natural language search capabilities
"""
import os
import struct
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from pathlib import Path
from PIL import Image
import threading
import queue
from datetime import datetime
from modules.universal_logger import get_logger
logger = get_logger('SemanticSearch')
# Global model instance (lazy loaded). All three names are guarded by
# _model_lock so concurrent callers share a single SentenceTransformer load.
_clip_model = None
# Name of the currently loaded model; compared against config to detect changes.
_clip_model_name = None
_model_lock = threading.Lock()
def get_configured_model_name() -> str:
    """Return the CLIP model name configured in application settings.

    Reads the 'semantic_search' settings dict via SettingsManager from the
    shared application database. Falls back to 'clip-ViT-B-32' when the
    settings are missing, malformed, or unreadable.

    Returns:
        Model name string usable with sentence-transformers.
    """
    default_model = 'clip-ViT-B-32'
    try:
        from modules.settings_manager import SettingsManager
        # Use the correct database path (shared app DB next to this package).
        # NOTE: the redundant local `from pathlib import Path` was removed;
        # Path is already imported at module level.
        db_path = Path(__file__).parent.parent / 'database' / 'media_downloader.db'
        settings_manager = SettingsManager(str(db_path))
        semantic_settings = settings_manager.get('semantic_search', {})
        if isinstance(semantic_settings, dict):
            model = semantic_settings.get('model', default_model)
            logger.info(f"Configured CLIP model: {model}")
            return model
        return default_model
    except Exception as e:
        logger.error(f"Failed to get configured model: {e}")
        return default_model
def get_clip_model(model_name: str = None):
    """Get or (re)load the shared CLIP model (thread-safe singleton).

    Args:
        model_name: Explicit sentence-transformers model name; when None,
            the name is read from application settings.

    Returns:
        The loaded SentenceTransformer instance.

    Raises:
        Exception: Propagated from SentenceTransformer when loading fails.
    """
    global _clip_model, _clip_model_name
    if model_name is None:
        model_name = get_configured_model_name()
    # Fix: hold the lock across BOTH the staleness check and the (re)load.
    # The old code checked/reset the globals outside the lock, so two threads
    # could interleave a reset with a load and briefly race on _clip_model.
    with _model_lock:
        # Check if we need to reload (model changed)
        if _clip_model is not None and _clip_model_name != model_name:
            logger.info(f"Model changed from {_clip_model_name} to {model_name}, reloading...")
            _clip_model = None
            _clip_model_name = None
        if _clip_model is None:
            logger.info(f"Loading CLIP model ({model_name})...")
            try:
                from sentence_transformers import SentenceTransformer
                _clip_model = SentenceTransformer(model_name)
                _clip_model_name = model_name
                logger.info(f"CLIP model {model_name} loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load CLIP model: {e}")
                raise
        return _clip_model
def embedding_to_bytes(embedding: np.ndarray) -> bytes:
    """Serialize a numpy embedding as a raw float32 blob for DB storage."""
    as_float32 = embedding.astype(np.float32)
    return as_float32.tobytes()
def bytes_to_embedding(data: bytes) -> np.ndarray:
    """Deserialize a raw float32 blob from the database into a numpy vector."""
    vector = np.frombuffer(data, dtype=np.float32)
    return vector
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Calculate cosine similarity between two embeddings.

    Returns 0.0 when either vector has zero norm; the old code divided by
    zero there, yielding NaN and poisoning downstream similarity sorting.
    """
    denom = float(np.linalg.norm(a)) * float(np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)
class SemanticSearch:
    """Semantic search engine using CLIP embeddings.

    Embeddings are generated with a CLIP model (via sentence-transformers),
    stored as float32 blobs in the content_embeddings table, and compared
    with cosine similarity for text, image, and similar-file search.
    """

    SUPPORTED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp'}
    SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.m4v'}

    def __init__(self, unified_db):
        """
        Initialize Semantic Search

        Args:
            unified_db: UnifiedDatabase instance
        """
        self.db = unified_db
        self.logger = get_logger('SemanticSearch')
        self._model = None  # lazily loaded CLIP model, see `model` property

    @property
    def model(self):
        """Lazily load and cache the shared CLIP model."""
        if self._model is None:
            self._model = get_clip_model()
        return self._model

    def _media_kind(self, file_path: str, content_type: str = None) -> Optional[str]:
        """Classify a file as 'image', 'video', or None (unsupported).

        Shared dispatch used by generate_embedding_for_file,
        generate_embeddings_batch, and search_by_file_id so all three agree.

        Args:
            file_path: Path whose extension is checked
            content_type: Optional MIME-like hint (e.g. 'image/jpeg')

        Returns:
            'image', 'video', or None when neither hint nor extension matches.
        """
        ct = (content_type or '').lower()
        ext = Path(file_path).suffix.lower()
        if 'image' in ct or ext in self.SUPPORTED_IMAGE_EXTENSIONS:
            return 'image'
        if 'video' in ct or ext in self.SUPPORTED_VIDEO_EXTENSIONS:
            return 'video'
        return None

    def get_image_embedding(self, image_path: str) -> Optional[np.ndarray]:
        """
        Generate CLIP embedding for an image

        Args:
            image_path: Path to the image file

        Returns:
            Embedding vector or None on error
        """
        try:
            # Load and preprocess image
            with Image.open(image_path) as image:
                # CLIP expects RGB; convert palette/greyscale/alpha modes.
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                # Generate embedding
                embedding = self.model.encode(image, convert_to_numpy=True)
                return embedding
        except Exception as e:
            self.logger.debug(f"Failed to get embedding for {image_path}: {e}")
            return None

    def get_video_frame_embedding(self, video_path: str, frame_position: float = 0.1) -> Optional[np.ndarray]:
        """
        Generate CLIP embedding for a video by extracting a frame

        Args:
            video_path: Path to the video file
            frame_position: Position in video (0-1) to extract frame from

        Returns:
            Embedding vector or None on error
        """
        # Try cv2 first, fall back to ffmpeg for codecs cv2 can't handle (e.g. AV1)
        image = self._extract_frame_cv2(video_path, frame_position)
        if image is None:
            image = self._extract_frame_ffmpeg(video_path, frame_position)
        if image is None:
            return None
        try:
            embedding = self.model.encode(image, convert_to_numpy=True)
            return embedding
        except Exception as e:
            self.logger.debug(f"Failed to encode video frame for {video_path}: {e}")
            return None
        finally:
            # Clean up image to prevent memory leaks
            try:
                image.close()
            except Exception:
                pass

    def _extract_frame_cv2(self, video_path: str, frame_position: float) -> Optional["Image.Image"]:
        """Extract a frame with OpenCV; returns a PIL image or None on failure."""
        try:
            import cv2
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                return None
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                cap.release()
                return None
            target_frame = int(total_frames * frame_position)
            cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
            ret, frame = cap.read()
            cap.release()
            if not ret:
                return None
            # OpenCV decodes to BGR; CLIP/PIL expect RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            return Image.fromarray(frame_rgb)
        except Exception as e:
            self.logger.debug(f"cv2 frame extraction failed for {video_path}: {e}")
            return None

    def _extract_frame_ffmpeg(self, video_path: str, frame_position: float) -> Optional["Image.Image"]:
        """Extract frame using ffmpeg (fallback for codecs cv2 can't handle)"""
        try:
            import subprocess
            import tempfile
            # Get video duration
            probe_cmd = [
                'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
                '-of', 'default=noprint_wrappers=1:nokey=1', video_path
            ]
            result = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
            if result.returncode != 0:
                return None
            duration = float(result.stdout.strip())
            seek_time = duration * frame_position
            # Extract frame to temp file
            with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
                tmp_path = tmp.name
            try:
                extract_cmd = [
                    'ffmpeg', '-y', '-ss', str(seek_time), '-i', video_path,
                    '-vframes', '1', '-q:v', '2', tmp_path
                ]
                result = subprocess.run(extract_cmd, capture_output=True, timeout=30)
                if result.returncode != 0 or not os.path.exists(tmp_path):
                    return None
                image = Image.open(tmp_path)
                image.load()  # Load into memory before deleting file
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                return image
            finally:
                # Fix: the temp file previously leaked when ffmpeg failed or
                # timed out; always remove it, best-effort.
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass  # Best effort cleanup of temp file
        except Exception as e:
            self.logger.debug(f"ffmpeg frame extraction failed for {video_path}: {e}")
            return None

    def get_text_embedding(self, text: str) -> Optional[np.ndarray]:
        """
        Generate CLIP embedding for text query

        Args:
            text: Text query

        Returns:
            Embedding vector or None on error
        """
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return embedding
        except Exception as e:
            self.logger.error(f"Failed to get text embedding: {e}")
            return None

    def store_embedding(self, file_id: int, embedding: np.ndarray) -> bool:
        """
        Store embedding in database

        Args:
            file_id: File inventory ID
            embedding: Embedding vector

        Returns:
            Success status
        """
        try:
            embedding_bytes = embedding_to_bytes(embedding)
            # Fix: record the model that actually produced this embedding.
            # The old code hard-coded 'clip-ViT-B-32' even though the model is
            # configurable, mislabeling vectors from other models.
            model_name = _clip_model_name or get_configured_model_name()
            with self.db.get_connection(for_write=True) as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT OR REPLACE INTO content_embeddings
                    (file_id, embedding, embedding_model, embedding_version, created_date)
                    VALUES (?, ?, ?, 1, CURRENT_TIMESTAMP)
                ''', (file_id, embedding_bytes, model_name))
            return True
        except Exception as e:
            self.logger.error(f"Failed to store embedding for file {file_id}: {e}")
            return False

    def get_embedding(self, file_id: int) -> Optional[np.ndarray]:
        """
        Get stored embedding from database

        Args:
            file_id: File inventory ID

        Returns:
            Embedding vector or None
        """
        try:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute('''
                    SELECT embedding FROM content_embeddings WHERE file_id = ?
                ''', (file_id,))
                row = cursor.fetchone()
                if row and row['embedding']:
                    return bytes_to_embedding(row['embedding'])
                return None
        except Exception as e:
            self.logger.error(f"Failed to get embedding for file {file_id}: {e}")
            return None

    def delete_embedding(self, file_id: int) -> bool:
        """
        Delete embedding for a file

        Args:
            file_id: File inventory ID

        Returns:
            True if deleted, False otherwise
        """
        try:
            with self.db.get_connection(for_write=True) as conn:
                cursor = conn.cursor()
                cursor.execute('DELETE FROM content_embeddings WHERE file_id = ?', (file_id,))
                if cursor.rowcount > 0:
                    self.logger.debug(f"Deleted embedding for file_id {file_id}")
                    return True
                return False
        except Exception as e:
            self.logger.error(f"Failed to delete embedding for file {file_id}: {e}")
            return False

    def delete_embedding_by_path(self, file_path: str) -> bool:
        """
        Delete embedding for a file by its path

        Args:
            file_path: File path

        Returns:
            True if deleted, False otherwise
        """
        try:
            with self.db.get_connection(for_write=True) as conn:
                cursor = conn.cursor()
                # First get the file_id
                cursor.execute('SELECT id FROM file_inventory WHERE file_path = ?', (file_path,))
                row = cursor.fetchone()
                if row:
                    cursor.execute('DELETE FROM content_embeddings WHERE file_id = ?', (row['id'],))
                    if cursor.rowcount > 0:
                        self.logger.debug(f"Deleted embedding for {file_path}")
                        return True
                return False
        except Exception as e:
            self.logger.error(f"Failed to delete embedding for {file_path}: {e}")
            return False

    def generate_embedding_for_file(self, file_id: int, file_path: str, content_type: str = None) -> bool:
        """
        Generate and store embedding for a single file

        Args:
            file_id: File inventory ID
            file_path: Path to the file
            content_type: Optional content type ('image' or 'video')

        Returns:
            True if embedding generated and stored successfully
        """
        try:
            if not os.path.exists(file_path):
                self.logger.debug(f"File not found for embedding: {file_path}")
                return False
            # Shared image/video dispatch (was duplicated with the batch path).
            kind = self._media_kind(file_path, content_type)
            embedding = None
            if kind == 'image':
                embedding = self.get_image_embedding(file_path)
            elif kind == 'video':
                embedding = self.get_video_frame_embedding(file_path)
            if embedding is not None and self.store_embedding(file_id, embedding):
                self.logger.debug(f"Generated embedding for file_id {file_id}: {Path(file_path).name}")
                return True
            return False
        except Exception as e:
            self.logger.error(f"Failed to generate embedding for file {file_id}: {e}")
            return False

    def get_embedding_stats(self) -> Dict:
        """Get statistics about embeddings in the database"""
        try:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                # Total embeddings for files in 'final' location only
                # (excludes embeddings for files moved to recycle bin or review)
                cursor.execute('''
                    SELECT COUNT(*) FROM content_embeddings ce
                    JOIN file_inventory fi ON ce.file_id = fi.id
                    WHERE fi.location = 'final'
                ''')
                total_embeddings = cursor.fetchone()[0]
                # Total files in final location
                cursor.execute("SELECT COUNT(*) FROM file_inventory WHERE location = 'final'")
                total_files = cursor.fetchone()[0]
                # Files without embeddings
                cursor.execute('''
                    SELECT COUNT(*) FROM file_inventory fi
                    WHERE fi.location = 'final'
                    AND NOT EXISTS (SELECT 1 FROM content_embeddings ce WHERE ce.file_id = fi.id)
                ''')
                missing_embeddings = cursor.fetchone()[0]
                return {
                    'total_embeddings': total_embeddings,
                    'total_files': total_files,
                    'missing_embeddings': missing_embeddings,
                    'coverage_percent': round((total_embeddings / total_files * 100) if total_files > 0 else 0, 2)
                }
        except Exception as e:
            self.logger.error(f"Failed to get embedding stats: {e}")
            return {}

    def generate_embeddings_batch(self, limit: int = 100, platform: str = None,
                                  progress_callback=None) -> Dict:
        """
        Generate embeddings for files that don't have them yet

        Args:
            limit: Maximum files to process
            platform: Filter by platform
            progress_callback: Optional callback(processed, total, current_file)

        Returns:
            Dict with success/error counts
        """
        results = {'processed': 0, 'success': 0, 'errors': 0, 'skipped': 0}
        try:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                # Get files without embeddings
                query = '''
                    SELECT fi.id, fi.file_path, fi.content_type, fi.filename
                    FROM file_inventory fi
                    WHERE fi.location = 'final'
                    AND NOT EXISTS (SELECT 1 FROM content_embeddings ce WHERE ce.file_id = fi.id)
                '''
                params = []
                if platform:
                    query += ' AND fi.platform = ?'
                    params.append(platform)
                query += ' LIMIT ?'
                params.append(limit)
                cursor.execute(query, params)
                files = cursor.fetchall()
            total = len(files)
            self.logger.info(f"Processing {total} files for embedding generation")
            for i, file_row in enumerate(files):
                file_id = file_row['id']
                file_path = file_row['file_path']
                content_type = file_row['content_type'] or ''
                filename = file_row['filename'] or ''
                results['processed'] += 1
                if progress_callback:
                    progress_callback(i + 1, total, filename)
                # Skip if file doesn't exist
                if not os.path.exists(file_path):
                    results['skipped'] += 1
                    continue
                # Determine file type via the shared classifier
                kind = self._media_kind(file_path, content_type)
                if kind is None:
                    results['skipped'] += 1
                    continue
                if kind == 'image':
                    embedding = self.get_image_embedding(file_path)
                else:
                    embedding = self.get_video_frame_embedding(file_path)
                if embedding is not None and self.store_embedding(file_id, embedding):
                    results['success'] += 1
                else:
                    results['errors'] += 1
            self.logger.info(f"Embedding generation complete: {results}")
            return results
        except Exception as e:
            self.logger.error(f"Failed to generate embeddings batch: {e}")
            return results

    def search_by_text(self, query: str, limit: int = 50, platform: str = None,
                       source: str = None, threshold: float = 0.2) -> List[Dict]:
        """
        Search for images/videos using natural language

        Args:
            query: Natural language search query
            limit: Maximum results
            platform: Filter by platform
            source: Filter by source
            threshold: Minimum similarity score (0-1)

        Returns:
            List of files with similarity scores
        """
        try:
            # Get text embedding
            query_embedding = self.get_text_embedding(query)
            if query_embedding is None:
                return []
            return self._search_by_embedding(query_embedding, limit, platform, source, threshold)
        except Exception as e:
            self.logger.error(f"Text search failed: {e}")
            return []

    def search_by_image(self, image_path: str, limit: int = 50, platform: str = None,
                        source: str = None, threshold: float = 0.5) -> List[Dict]:
        """
        Find similar images to a given image

        Args:
            image_path: Path to query image
            limit: Maximum results
            platform: Filter by platform
            source: Filter by source
            threshold: Minimum similarity score (0-1)

        Returns:
            List of similar files with scores
        """
        try:
            # Get image embedding
            query_embedding = self.get_image_embedding(image_path)
            if query_embedding is None:
                return []
            return self._search_by_embedding(query_embedding, limit, platform, source, threshold)
        except Exception as e:
            self.logger.error(f"Image search failed: {e}")
            return []

    def search_by_file_id(self, file_id: int, limit: int = 50, platform: str = None,
                          source: str = None, threshold: float = 0.5) -> List[Dict]:
        """
        Find similar files to a file already in the database

        Args:
            file_id: File inventory ID
            limit: Maximum results
            platform: Filter by platform
            source: Filter by source
            threshold: Minimum similarity score (0-1)

        Returns:
            List of similar files with scores
        """
        try:
            # Get existing embedding
            query_embedding = self.get_embedding(file_id)
            if query_embedding is None:
                # Not indexed yet -- try to generate an embedding on the fly.
                with self.db.get_connection() as conn:
                    cursor = conn.cursor()
                    cursor.execute('SELECT file_path, content_type FROM file_inventory WHERE id = ?', (file_id,))
                    row = cursor.fetchone()
                if row:
                    file_path = row['file_path']
                    # Fix: the old code always used the image path here, which
                    # failed for video files; dispatch on media type instead.
                    kind = self._media_kind(file_path, row['content_type'])
                    if kind == 'video':
                        query_embedding = self.get_video_frame_embedding(file_path)
                    else:
                        query_embedding = self.get_image_embedding(file_path)
            if query_embedding is None:
                return []
            results = self._search_by_embedding(query_embedding, limit + 1, platform, source, threshold)
            # Remove the query file itself from results
            return [r for r in results if r['id'] != file_id][:limit]
        except Exception as e:
            self.logger.error(f"Similar file search failed: {e}")
            return []

    def _search_by_embedding(self, query_embedding: np.ndarray, limit: int,
                             platform: str = None, source: str = None,
                             threshold: float = 0.2) -> List[Dict]:
        """
        Internal search using embedding vector

        Args:
            query_embedding: Query embedding vector
            limit: Maximum results
            platform: Filter by platform
            source: Filter by source
            threshold: Minimum similarity score

        Returns:
            List of files with similarity scores, sorted by score
        """
        try:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                # Build query to get all embeddings (with optional filters)
                query = '''
                    SELECT ce.file_id, ce.embedding, fi.file_path, fi.filename,
                           fi.platform, fi.source, fi.content_type, fi.file_size
                    FROM content_embeddings ce
                    JOIN file_inventory fi ON fi.id = ce.file_id
                    WHERE fi.location = 'final'
                '''
                params = []
                if platform:
                    query += ' AND fi.platform = ?'
                    params.append(platform)
                if source:
                    query += ' AND fi.source = ?'
                    params.append(source)
                cursor.execute(query, params)
                results = []
                # Brute-force scan: similarity is computed in Python for every
                # stored embedding, then filtered by threshold.
                for row in cursor.fetchall():
                    embedding = bytes_to_embedding(row['embedding'])
                    similarity = cosine_similarity(query_embedding, embedding)
                    if similarity >= threshold:
                        results.append({
                            'id': row['file_id'],
                            'file_path': row['file_path'],
                            'filename': row['filename'],
                            'platform': row['platform'],
                            'source': row['source'],
                            'content_type': row['content_type'],
                            'file_size': row['file_size'],
                            'similarity': round(similarity, 4)
                        })
                # Sort by similarity descending
                results.sort(key=lambda x: x['similarity'], reverse=True)
                return results[:limit]
        except Exception as e:
            self.logger.error(f"Embedding search failed: {e}")
            return []
# Global instance (lazy initialization): created on first call to
# get_semantic_search() so importing this module stays cheap.
_semantic_search = None
def reset_clip_model():
    """Drop the cached CLIP model so the next access reloads it with fresh config."""
    global _clip_model, _clip_model_name
    with _model_lock:
        _clip_model_name = None
        _clip_model = None
        logger.info("CLIP model cache cleared, will reload on next use")
def get_semantic_search(unified_db=None, force_reload=False):
    """Return the shared SemanticSearch instance, creating it on first use.

    Args:
        unified_db: Database instance to use
        force_reload: If True, recreate the instance (useful when model config changes)
    """
    global _semantic_search
    # Guard clause: reuse the cached instance unless a rebuild was requested.
    if _semantic_search is not None and not force_reload:
        return _semantic_search
    if force_reload:
        # Also reset the CLIP model so it reloads with new config
        reset_clip_model()
    db = unified_db
    if db is None:
        from modules.unified_database import UnifiedDatabase
        db = UnifiedDatabase()
    _semantic_search = SemanticSearch(db)
    return _semantic_search