728
modules/semantic_search.py
Normal file
728
modules/semantic_search.py
Normal file
@@ -0,0 +1,728 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Semantic Search Module using CLIP
|
||||
Provides image/video similarity search and natural language search capabilities
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import threading
|
||||
import queue
|
||||
from datetime import datetime
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('SemanticSearch')
|
||||
|
||||
# Global model instance (lazy loaded)
|
||||
_clip_model = None
|
||||
_clip_model_name = None
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_configured_model_name() -> str:
    """Return the configured CLIP model name, defaulting to 'clip-ViT-B-32'.

    Reads the 'semantic_search' settings dict from the application settings
    database; falls back to the default on any error or malformed settings.

    Returns:
        Model name string suitable for SentenceTransformer().
    """
    default_model = 'clip-ViT-B-32'
    try:
        from modules.settings_manager import SettingsManager
        # Path is already imported at module level; the settings DB lives
        # next to the modules/ directory.
        db_path = Path(__file__).parent.parent / 'database' / 'media_downloader.db'
        settings_manager = SettingsManager(str(db_path))
        semantic_settings = settings_manager.get('semantic_search', {})
        if isinstance(semantic_settings, dict):
            model = semantic_settings.get('model', default_model)
            logger.info(f"Configured CLIP model: {model}")
            return model
        return default_model
    except Exception as e:
        logger.error(f"Failed to get configured model: {e}")
        return default_model
|
||||
|
||||
|
||||
def get_clip_model(model_name: str = None):
    """Get or load the CLIP model (thread-safe singleton).

    Args:
        model_name: Explicit model to load; when None, the configured
            model name from settings is used.

    Returns:
        The shared SentenceTransformer instance.

    Raises:
        Exception: Propagated if the model fails to load.
    """
    global _clip_model, _clip_model_name

    if model_name is None:
        model_name = get_configured_model_name()

    # Fast path: already loaded with the requested model.
    if _clip_model is not None and _clip_model_name == model_name:
        return _clip_model

    # The original code reset the cache and loaded the model under two
    # separate lock acquisitions, leaving an unlocked gap where another
    # thread could observe/reload the stale model. Do both under one lock.
    with _model_lock:
        if _clip_model is not None and _clip_model_name != model_name:
            logger.info(f"Model changed from {_clip_model_name} to {model_name}, reloading...")
            _clip_model = None
            _clip_model_name = None

        if _clip_model is None:
            logger.info(f"Loading CLIP model ({model_name})...")
            try:
                from sentence_transformers import SentenceTransformer
                _clip_model = SentenceTransformer(model_name)
                _clip_model_name = model_name
                logger.info(f"CLIP model {model_name} loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load CLIP model: {e}")
                raise

    return _clip_model
|
||||
|
||||
|
||||
def embedding_to_bytes(embedding: np.ndarray) -> bytes:
    """Serialize an embedding vector to raw float32 bytes for DB storage."""
    as_f32 = np.asarray(embedding, dtype=np.float32)
    return as_f32.tobytes()
|
||||
|
||||
|
||||
def bytes_to_embedding(data: bytes) -> np.ndarray:
    """Reconstruct a float32 embedding vector from its raw byte form.

    Note: the returned array is a read-only view over the input buffer.
    """
    vector = np.frombuffer(data, dtype=np.float32)
    return vector
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of two vectors, in [-1, 1].

    Returns 0.0 when either vector has zero magnitude — similarity is
    undefined there, and the naive expression would divide by zero and
    produce NaN plus a runtime warning.

    Args:
        a: First vector.
        b: Second vector (same length as a).

    Returns:
        Cosine similarity as a Python float.
    """
    denom = float(np.linalg.norm(a)) * float(np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)
|
||||
|
||||
|
||||
class SemanticSearch:
    """CLIP-embedding based similarity and natural-language search engine."""

    SUPPORTED_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp'}
    SUPPORTED_VIDEO_EXTENSIONS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.m4v'}

    def __init__(self, unified_db):
        """
        Initialize Semantic Search

        Args:
            unified_db: UnifiedDatabase instance
        """
        self._model = None  # CLIP model; populated lazily via the `model` property
        self.logger = get_logger('SemanticSearch')
        self.db = unified_db

    @property
    def model(self):
        """The shared CLIP model, loaded on first access."""
        if self._model is None:
            self._model = get_clip_model()
        return self._model
|
||||
|
||||
def get_image_embedding(self, image_path: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Generate CLIP embedding for an image
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
|
||||
Returns:
|
||||
Embedding vector or None on error
|
||||
"""
|
||||
try:
|
||||
# Load and preprocess image
|
||||
with Image.open(image_path) as image:
|
||||
# Convert to RGB if necessary
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
# Generate embedding
|
||||
embedding = self.model.encode(image, convert_to_numpy=True)
|
||||
|
||||
return embedding
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Failed to get embedding for {image_path}: {e}")
|
||||
return None
|
||||
|
||||
def get_video_frame_embedding(self, video_path: str, frame_position: float = 0.1) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Generate CLIP embedding for a video by extracting a frame
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file
|
||||
frame_position: Position in video (0-1) to extract frame from
|
||||
|
||||
Returns:
|
||||
Embedding vector or None on error
|
||||
"""
|
||||
# Try cv2 first, fall back to ffmpeg for codecs cv2 can't handle (e.g. AV1)
|
||||
image = self._extract_frame_cv2(video_path, frame_position)
|
||||
if image is None:
|
||||
image = self._extract_frame_ffmpeg(video_path, frame_position)
|
||||
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
embedding = self.model.encode(image, convert_to_numpy=True)
|
||||
return embedding
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Failed to encode video frame for {video_path}: {e}")
|
||||
return None
|
||||
finally:
|
||||
# Clean up image to prevent memory leaks
|
||||
if image is not None:
|
||||
try:
|
||||
image.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _extract_frame_cv2(self, video_path: str, frame_position: float) -> Optional[Image.Image]:
|
||||
"""Extract frame using OpenCV"""
|
||||
try:
|
||||
import cv2
|
||||
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
if not cap.isOpened():
|
||||
return None
|
||||
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
if total_frames <= 0:
|
||||
cap.release()
|
||||
return None
|
||||
|
||||
target_frame = int(total_frames * frame_position)
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
|
||||
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
|
||||
if not ret:
|
||||
return None
|
||||
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
return Image.fromarray(frame_rgb)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"cv2 frame extraction failed for {video_path}: {e}")
|
||||
return None
|
||||
|
||||
def _extract_frame_ffmpeg(self, video_path: str, frame_position: float) -> Optional[Image.Image]:
|
||||
"""Extract frame using ffmpeg (fallback for codecs cv2 can't handle)"""
|
||||
try:
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
# Get video duration
|
||||
probe_cmd = [
|
||||
'ffprobe', '-v', 'error', '-show_entries', 'format=duration',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1', video_path
|
||||
]
|
||||
result = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
|
||||
duration = float(result.stdout.strip())
|
||||
seek_time = duration * frame_position
|
||||
|
||||
# Extract frame to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
extract_cmd = [
|
||||
'ffmpeg', '-y', '-ss', str(seek_time), '-i', video_path,
|
||||
'-vframes', '1', '-q:v', '2', tmp_path
|
||||
]
|
||||
result = subprocess.run(extract_cmd, capture_output=True, timeout=30)
|
||||
|
||||
if result.returncode != 0 or not os.path.exists(tmp_path):
|
||||
return None
|
||||
|
||||
image = Image.open(tmp_path)
|
||||
image.load() # Load into memory before deleting file
|
||||
|
||||
# Clean up temp file
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass # Best effort cleanup of temp file
|
||||
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
return image
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"ffmpeg frame extraction failed for {video_path}: {e}")
|
||||
return None
|
||||
|
||||
def get_text_embedding(self, text: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Generate CLIP embedding for text query
|
||||
|
||||
Args:
|
||||
text: Text query
|
||||
|
||||
Returns:
|
||||
Embedding vector or None on error
|
||||
"""
|
||||
try:
|
||||
embedding = self.model.encode(text, convert_to_numpy=True)
|
||||
return embedding
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get text embedding: {e}")
|
||||
return None
|
||||
|
||||
def store_embedding(self, file_id: int, embedding: np.ndarray) -> bool:
|
||||
"""
|
||||
Store embedding in database
|
||||
|
||||
Args:
|
||||
file_id: File inventory ID
|
||||
embedding: Embedding vector
|
||||
|
||||
Returns:
|
||||
Success status
|
||||
"""
|
||||
try:
|
||||
embedding_bytes = embedding_to_bytes(embedding)
|
||||
|
||||
with self.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO content_embeddings
|
||||
(file_id, embedding, embedding_model, embedding_version, created_date)
|
||||
VALUES (?, ?, 'clip-ViT-B-32', 1, CURRENT_TIMESTAMP)
|
||||
''', (file_id, embedding_bytes))
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to store embedding for file {file_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_embedding(self, file_id: int) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Get stored embedding from database
|
||||
|
||||
Args:
|
||||
file_id: File inventory ID
|
||||
|
||||
Returns:
|
||||
Embedding vector or None
|
||||
"""
|
||||
try:
|
||||
with self.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT embedding FROM content_embeddings WHERE file_id = ?
|
||||
''', (file_id,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
if row and row['embedding']:
|
||||
return bytes_to_embedding(row['embedding'])
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get embedding for file {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_embedding(self, file_id: int) -> bool:
|
||||
"""
|
||||
Delete embedding for a file
|
||||
|
||||
Args:
|
||||
file_id: File inventory ID
|
||||
|
||||
Returns:
|
||||
True if deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
with self.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM content_embeddings WHERE file_id = ?', (file_id,))
|
||||
if cursor.rowcount > 0:
|
||||
self.logger.debug(f"Deleted embedding for file_id {file_id}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to delete embedding for file {file_id}: {e}")
|
||||
return False
|
||||
|
||||
def delete_embedding_by_path(self, file_path: str) -> bool:
|
||||
"""
|
||||
Delete embedding for a file by its path
|
||||
|
||||
Args:
|
||||
file_path: File path
|
||||
|
||||
Returns:
|
||||
True if deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
with self.db.get_connection(for_write=True) as conn:
|
||||
cursor = conn.cursor()
|
||||
# First get the file_id
|
||||
cursor.execute('SELECT id FROM file_inventory WHERE file_path = ?', (file_path,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
cursor.execute('DELETE FROM content_embeddings WHERE file_id = ?', (row['id'],))
|
||||
if cursor.rowcount > 0:
|
||||
self.logger.debug(f"Deleted embedding for {file_path}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to delete embedding for {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def generate_embedding_for_file(self, file_id: int, file_path: str, content_type: str = None) -> bool:
|
||||
"""
|
||||
Generate and store embedding for a single file
|
||||
|
||||
Args:
|
||||
file_id: File inventory ID
|
||||
file_path: Path to the file
|
||||
content_type: Optional content type ('image' or 'video')
|
||||
|
||||
Returns:
|
||||
True if embedding generated and stored successfully
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
self.logger.debug(f"File not found for embedding: {file_path}")
|
||||
return False
|
||||
|
||||
ext = Path(file_path).suffix.lower()
|
||||
|
||||
# Determine file type
|
||||
if content_type:
|
||||
is_image = 'image' in content_type.lower()
|
||||
is_video = 'video' in content_type.lower()
|
||||
else:
|
||||
is_image = ext in self.SUPPORTED_IMAGE_EXTENSIONS
|
||||
is_video = ext in self.SUPPORTED_VIDEO_EXTENSIONS
|
||||
|
||||
embedding = None
|
||||
if is_image:
|
||||
embedding = self.get_image_embedding(file_path)
|
||||
elif is_video:
|
||||
embedding = self.get_video_frame_embedding(file_path)
|
||||
|
||||
if embedding is not None:
|
||||
if self.store_embedding(file_id, embedding):
|
||||
self.logger.debug(f"Generated embedding for file_id {file_id}: {Path(file_path).name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to generate embedding for file {file_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_embedding_stats(self) -> Dict:
|
||||
"""Get statistics about embeddings in the database"""
|
||||
try:
|
||||
with self.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Total embeddings for files in 'final' location only
|
||||
# (excludes embeddings for files moved to recycle bin or review)
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM content_embeddings ce
|
||||
JOIN file_inventory fi ON ce.file_id = fi.id
|
||||
WHERE fi.location = 'final'
|
||||
''')
|
||||
total_embeddings = cursor.fetchone()[0]
|
||||
|
||||
# Total files in final location
|
||||
cursor.execute("SELECT COUNT(*) FROM file_inventory WHERE location = 'final'")
|
||||
total_files = cursor.fetchone()[0]
|
||||
|
||||
# Files without embeddings
|
||||
cursor.execute('''
|
||||
SELECT COUNT(*) FROM file_inventory fi
|
||||
WHERE fi.location = 'final'
|
||||
AND NOT EXISTS (SELECT 1 FROM content_embeddings ce WHERE ce.file_id = fi.id)
|
||||
''')
|
||||
missing_embeddings = cursor.fetchone()[0]
|
||||
|
||||
return {
|
||||
'total_embeddings': total_embeddings,
|
||||
'total_files': total_files,
|
||||
'missing_embeddings': missing_embeddings,
|
||||
'coverage_percent': round((total_embeddings / total_files * 100) if total_files > 0 else 0, 2)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get embedding stats: {e}")
|
||||
return {}
|
||||
|
||||
    def generate_embeddings_batch(self, limit: int = 100, platform: str = None,
                                  progress_callback=None) -> Dict:
        """
        Generate embeddings for files that don't have them yet

        Selects up to `limit` files in the 'final' location that lack a row in
        content_embeddings, classifies each by extension/content_type, and
        stores the resulting embedding. Missing and unsupported files are
        counted as skipped; failed generation or storage counts as an error.

        Args:
            limit: Maximum files to process
            platform: Filter by platform
            progress_callback: Optional callback(processed, total, current_file)

        Returns:
            Dict with 'processed', 'success', 'errors', 'skipped' counts
            (partial counts are returned even when the batch aborts on error)
        """
        results = {'processed': 0, 'success': 0, 'errors': 0, 'skipped': 0}

        try:
            with self.db.get_connection() as conn:
                cursor = conn.cursor()

                # Get files without embeddings
                query = '''
                    SELECT fi.id, fi.file_path, fi.content_type, fi.filename
                    FROM file_inventory fi
                    WHERE fi.location = 'final'
                    AND NOT EXISTS (SELECT 1 FROM content_embeddings ce WHERE ce.file_id = fi.id)
                '''
                params = []

                if platform:
                    query += ' AND fi.platform = ?'
                    params.append(platform)

                query += ' LIMIT ?'
                params.append(limit)

                cursor.execute(query, params)
                files = cursor.fetchall()

                total = len(files)
                self.logger.info(f"Processing {total} files for embedding generation")

                # NOTE(review): the read connection stays open while each
                # embedding is computed below, which can be slow per file —
                # confirm the DB layer tolerates long-lived readers.
                for i, file_row in enumerate(files):
                    file_id = file_row['id']
                    file_path = file_row['file_path']
                    content_type = file_row['content_type'] or ''
                    filename = file_row['filename'] or ''

                    results['processed'] += 1

                    if progress_callback:
                        progress_callback(i + 1, total, filename)

                    # Skip if file doesn't exist
                    if not os.path.exists(file_path):
                        results['skipped'] += 1
                        continue

                    # Determine file type
                    ext = Path(file_path).suffix.lower()

                    embedding = None

                    if ext in self.SUPPORTED_IMAGE_EXTENSIONS or 'image' in content_type.lower():
                        embedding = self.get_image_embedding(file_path)
                    elif ext in self.SUPPORTED_VIDEO_EXTENSIONS or 'video' in content_type.lower():
                        embedding = self.get_video_frame_embedding(file_path)
                    else:
                        # Neither an image nor a video we can embed
                        results['skipped'] += 1
                        continue

                    if embedding is not None:
                        if self.store_embedding(file_id, embedding):
                            results['success'] += 1
                        else:
                            results['errors'] += 1
                    else:
                        results['errors'] += 1

            self.logger.info(f"Embedding generation complete: {results}")
            return results

        except Exception as e:
            self.logger.error(f"Failed to generate embeddings batch: {e}")
            return results
|
||||
|
||||
def search_by_text(self, query: str, limit: int = 50, platform: str = None,
|
||||
source: str = None, threshold: float = 0.2) -> List[Dict]:
|
||||
"""
|
||||
Search for images/videos using natural language
|
||||
|
||||
Args:
|
||||
query: Natural language search query
|
||||
limit: Maximum results
|
||||
platform: Filter by platform
|
||||
source: Filter by source
|
||||
threshold: Minimum similarity score (0-1)
|
||||
|
||||
Returns:
|
||||
List of files with similarity scores
|
||||
"""
|
||||
try:
|
||||
# Get text embedding
|
||||
query_embedding = self.get_text_embedding(query)
|
||||
if query_embedding is None:
|
||||
return []
|
||||
|
||||
return self._search_by_embedding(query_embedding, limit, platform, source, threshold)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Text search failed: {e}")
|
||||
return []
|
||||
|
||||
def search_by_image(self, image_path: str, limit: int = 50, platform: str = None,
|
||||
source: str = None, threshold: float = 0.5) -> List[Dict]:
|
||||
"""
|
||||
Find similar images to a given image
|
||||
|
||||
Args:
|
||||
image_path: Path to query image
|
||||
limit: Maximum results
|
||||
platform: Filter by platform
|
||||
source: Filter by source
|
||||
threshold: Minimum similarity score (0-1)
|
||||
|
||||
Returns:
|
||||
List of similar files with scores
|
||||
"""
|
||||
try:
|
||||
# Get image embedding
|
||||
query_embedding = self.get_image_embedding(image_path)
|
||||
if query_embedding is None:
|
||||
return []
|
||||
|
||||
return self._search_by_embedding(query_embedding, limit, platform, source, threshold)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Image search failed: {e}")
|
||||
return []
|
||||
|
||||
def search_by_file_id(self, file_id: int, limit: int = 50, platform: str = None,
|
||||
source: str = None, threshold: float = 0.5) -> List[Dict]:
|
||||
"""
|
||||
Find similar files to a file already in the database
|
||||
|
||||
Args:
|
||||
file_id: File inventory ID
|
||||
limit: Maximum results
|
||||
platform: Filter by platform
|
||||
source: Filter by source
|
||||
threshold: Minimum similarity score (0-1)
|
||||
|
||||
Returns:
|
||||
List of similar files with scores
|
||||
"""
|
||||
try:
|
||||
# Get existing embedding
|
||||
query_embedding = self.get_embedding(file_id)
|
||||
|
||||
if query_embedding is None:
|
||||
# Try to generate it
|
||||
with self.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT file_path FROM file_inventory WHERE id = ?', (file_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
query_embedding = self.get_image_embedding(row['file_path'])
|
||||
|
||||
if query_embedding is None:
|
||||
return []
|
||||
|
||||
results = self._search_by_embedding(query_embedding, limit + 1, platform, source, threshold)
|
||||
|
||||
# Remove the query file itself from results
|
||||
return [r for r in results if r['id'] != file_id][:limit]
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Similar file search failed: {e}")
|
||||
return []
|
||||
|
||||
def _search_by_embedding(self, query_embedding: np.ndarray, limit: int,
|
||||
platform: str = None, source: str = None,
|
||||
threshold: float = 0.2) -> List[Dict]:
|
||||
"""
|
||||
Internal search using embedding vector
|
||||
|
||||
Args:
|
||||
query_embedding: Query embedding vector
|
||||
limit: Maximum results
|
||||
platform: Filter by platform
|
||||
source: Filter by source
|
||||
threshold: Minimum similarity score
|
||||
|
||||
Returns:
|
||||
List of files with similarity scores, sorted by score
|
||||
"""
|
||||
try:
|
||||
with self.db.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Build query to get all embeddings (with optional filters)
|
||||
query = '''
|
||||
SELECT ce.file_id, ce.embedding, fi.file_path, fi.filename,
|
||||
fi.platform, fi.source, fi.content_type, fi.file_size
|
||||
FROM content_embeddings ce
|
||||
JOIN file_inventory fi ON fi.id = ce.file_id
|
||||
WHERE fi.location = 'final'
|
||||
'''
|
||||
params = []
|
||||
|
||||
if platform:
|
||||
query += ' AND fi.platform = ?'
|
||||
params.append(platform)
|
||||
if source:
|
||||
query += ' AND fi.source = ?'
|
||||
params.append(source)
|
||||
|
||||
cursor.execute(query, params)
|
||||
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
embedding = bytes_to_embedding(row['embedding'])
|
||||
similarity = cosine_similarity(query_embedding, embedding)
|
||||
|
||||
if similarity >= threshold:
|
||||
results.append({
|
||||
'id': row['file_id'],
|
||||
'file_path': row['file_path'],
|
||||
'filename': row['filename'],
|
||||
'platform': row['platform'],
|
||||
'source': row['source'],
|
||||
'content_type': row['content_type'],
|
||||
'file_size': row['file_size'],
|
||||
'similarity': round(similarity, 4)
|
||||
})
|
||||
|
||||
# Sort by similarity descending
|
||||
results.sort(key=lambda x: x['similarity'], reverse=True)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Embedding search failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Global instance (lazy initialization)
|
||||
_semantic_search = None
|
||||
|
||||
|
||||
def reset_clip_model():
    """Drop the cached CLIP model; the next access reloads it with fresh config."""
    global _clip_model, _clip_model_name
    with _model_lock:
        _clip_model = None
        _clip_model_name = None
        logger.info("CLIP model cache cleared, will reload on next use")
|
||||
|
||||
|
||||
def get_semantic_search(unified_db=None, force_reload=False):
    """Return the process-wide SemanticSearch instance, creating it on demand.

    Args:
        unified_db: Database instance to use
        force_reload: If True, recreate the instance (useful when model config changes)
    """
    global _semantic_search

    if _semantic_search is not None and not force_reload:
        return _semantic_search

    if force_reload:
        # Drop the cached CLIP model too so it reloads with the new config
        reset_clip_model()

    if unified_db is None:
        from modules.unified_database import UnifiedDatabase
        unified_db = UnifiedDatabase()

    _semantic_search = SemanticSearch(unified_db)
    return _semantic_search
|
||||
Reference in New Issue
Block a user