#!/usr/bin/env python3
"""
Base Module - Shared functionality for all media downloader modules

Provides:
- LoggingMixin: Consistent logging with universal logger and backwards-compatible callback support
- CookieManagerMixin: Centralized cookie loading/saving for scrapers
- RateLimitMixin: Smart delay handling for rate limiting
- DeferredDownloadsMixin: Track downloads for batch database recording
"""

import random
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from modules.universal_logger import get_logger


class LoggingMixin:
    """
    Mixin providing consistent logging across all modules.

    Uses the universal logger for all logging, with optional callback support
    for backwards compatibility with existing code.

    Usage:
        class MyModule(LoggingMixin):
            def __init__(self, log_callback=None):
                self._init_logger('MyModule', log_callback)
                # ... rest of init

            def do_something(self):
                self.log("Starting operation", "info")
                # ...
                self.log("Operation complete", "success")
    """

    # Class-level defaults keep the mixin safe to use even if a subclass
    # forgot to call _init_logger() (attribute access won't raise).
    _logger_name: str = 'Unknown'
    _default_module: str = 'Core'
    logger = None            # universal logger instance, set by _init_logger()
    log_callback = None      # optional legacy callback: (message, level) -> None
    show_debug: bool = True  # when False, 'debug'-level messages are dropped

    def _init_logger(self, logger_name: str, log_callback=None,
                     default_module: str = 'Core', show_debug: bool = True) -> None:
        """
        Initialize logging for this module.

        Args:
            logger_name: Name for the logger (e.g., 'Instagram', 'TikTok', 'Forum')
            log_callback: Optional callback function for backwards compatibility
            default_module: Default module name for log messages (default: 'Core')
            show_debug: Whether to show debug messages (default: True)
        """
        self._logger_name = logger_name
        self._default_module = default_module
        self.log_callback = log_callback
        self.show_debug = show_debug
        self.logger = get_logger(logger_name)

    def log(self, message: str, level: str = "info", module: Optional[str] = None) -> None:
        """
        Log a message using universal logger with optional callback.

        Args:
            message: The message to log
            level: Log level ('debug', 'info', 'warning', 'error', 'success', 'critical')
            module: Module name for the log entry (default: uses _default_module)
        """
        level_lower = level.lower()

        # Skip debug messages if show_debug is False
        if level_lower == "debug" and not self.show_debug:
            return

        # Universal logger is the primary sink; it always receives the message.
        actual_module = module or self._default_module
        self.logger.log(message, level.upper(), module=actual_module)

        # Legacy callback receives a "[Name] message" string, for callers that
        # predate the universal logger.
        if self.log_callback:
            self.log_callback(f"[{self._logger_name}] {message}", level_lower)
class CookieManagerMixin:
    """
    Mixin providing centralized cookie management for scrapers.

    Handles loading and saving cookies to/from the database.

    Usage:
        class MyScraper(LoggingMixin, CookieManagerMixin):
            def __init__(self, unified_db=None):
                self._init_logger('MyScraper')
                self._init_cookie_manager(unified_db, 'my_scraper')
                self._load_cookies_from_db()

            def after_auth(self, cookies):
                self._save_cookies_to_db(cookies)
    """

    # Class-level defaults so attribute access is safe even before
    # _init_cookie_manager() has been called.
    unified_db = None
    scraper_id: str = ''
    cf_handler = None  # CloudflareHandler if used
    user_agent: str = ''

    def _init_cookie_manager(self, unified_db, scraper_id: str,
                             cf_handler=None, user_agent: str = '') -> None:
        """
        Initialize cookie management.

        Args:
            unified_db: UnifiedDatabase instance
            scraper_id: ID for this scraper in database
            cf_handler: Optional CloudflareHandler instance
            user_agent: User agent string
        """
        self.unified_db = unified_db
        self.scraper_id = scraper_id
        self.cf_handler = cf_handler
        self.user_agent = user_agent

    def _load_cookies_from_db(self) -> Optional[List[Dict]]:
        """
        Load cookies from database if available.

        Returns:
            List of cookie dicts or None if not available
        """
        if not self.unified_db:
            return None

        try:
            cookies = self.unified_db.get_scraper_cookies(self.scraper_id)
            if cookies:
                # Load into CloudflareHandler if available
                # NOTE: writes the handler's private _cookies attribute directly;
                # keep in sync with CloudflareHandler's internals.
                if self.cf_handler:
                    self.cf_handler._cookies = cookies
                if hasattr(self, 'log'):
                    self.log(f"Loaded {len(cookies)} cookies from database", "debug")
                return cookies
        except Exception as e:
            # Best-effort: a cookie-load failure should not break the scraper.
            if hasattr(self, 'log'):
                self.log(f"Error loading cookies from database: {e}", "warning")

        return None

    def _save_cookies_to_db(self, cookies: List[Dict], merge: bool = True,
                            user_agent: Optional[str] = None) -> None:
        """
        Save cookies to database.

        Args:
            cookies: List of cookie dicts
            merge: Whether to merge with existing cookies
            user_agent: User agent to associate with cookies (important for cf_clearance).
                If not provided, uses self.user_agent as fallback.
        """
        if not self.unified_db:
            return

        try:
            # Use provided user_agent or fall back to self.user_agent
            ua = user_agent or self.user_agent
            self.unified_db.save_scraper_cookies(
                self.scraper_id,
                cookies,
                user_agent=ua,
                merge=merge
            )
            if hasattr(self, 'log'):
                self.log(f"Saved {len(cookies)} cookies to database (UA: {ua[:50] if ua else 'None'}...)", "debug")
        except Exception as e:
            # Best-effort: a cookie-save failure should not break the scraper.
            if hasattr(self, 'log'):
                self.log(f"Error saving cookies to database: {e}", "warning")

    def _cookies_expired(self) -> bool:
        """
        Check if cookies are expired.

        Returns:
            True if expired, False otherwise. Without a CloudflareHandler the
            answer is always True (no way to verify freshness).
        """
        if self.cf_handler:
            return self.cf_handler.cookies_expired()
        return True

    def _get_cookies_for_requests(self) -> Dict[str, str]:
        """
        Get cookies in format for requests library.

        Returns:
            Dict of cookie name -> value (empty without a CloudflareHandler)
        """
        if self.cf_handler:
            return self.cf_handler.get_cookies_dict()
        return {}
class RateLimitMixin:
    """
    Mixin providing smart rate limiting for scrapers.

    Inserts randomized pauses between requests so traffic looks less robotic
    and avoids tripping rate limits.

    Usage:
        class MyScraper(LoggingMixin, RateLimitMixin):
            def __init__(self):
                self._init_logger('MyScraper')
                self._init_rate_limiter(min_delay=5, max_delay=15, batch_delay=30)

            def download_batch(self, items):
                for i, item in enumerate(items):
                    self.download_item(item)
                    is_batch_end = (i + 1) % 10 == 0
                    self._smart_delay(is_batch_end)
    """

    # Defaults (seconds); overridden by _init_rate_limiter().
    min_delay: float = 5.0
    max_delay: float = 15.0
    batch_delay_min: float = 30.0
    batch_delay_max: float = 60.0
    error_delay: float = 120.0

    def _init_rate_limiter(
        self,
        min_delay: float = 5.0,
        max_delay: float = 15.0,
        batch_delay_min: float = 30.0,
        batch_delay_max: float = 60.0,
        error_delay: float = 120.0
    ):
        """
        Configure the rate limiter.

        Args:
            min_delay: Minimum delay between requests (seconds)
            max_delay: Maximum delay between requests (seconds)
            batch_delay_min: Minimum delay between batches (seconds)
            batch_delay_max: Maximum delay between batches (seconds)
            error_delay: Delay after errors (seconds)
        """
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.batch_delay_min = batch_delay_min
        self.batch_delay_max = batch_delay_max
        self.error_delay = error_delay

    def _smart_delay(self, is_batch_end: bool = False, had_error: bool = False):
        """
        Sleep for an appropriate amount of time before the next request.

        Args:
            is_batch_end: True if this is the end of a batch
            had_error: True if there was an error (uses the fixed error delay)
        """
        if had_error:
            pause = self.error_delay
        else:
            # Pick the bounds for this pause, then draw uniformly from them.
            lo, hi = (
                (self.batch_delay_min, self.batch_delay_max)
                if is_batch_end
                else (self.min_delay, self.max_delay)
            )
            pause = random.uniform(lo, hi)

        if hasattr(self, 'log'):
            self.log(f"Waiting {pause:.1f}s before next request", "debug")

        time.sleep(pause)

    def _delay_after_error(self):
        """Pause using the (long) post-error delay."""
        self._smart_delay(had_error=True)

    def _delay_between_items(self):
        """Pause using the normal per-item delay."""
        self._smart_delay(is_batch_end=False)

    def _delay_between_batches(self):
        """Pause using the longer between-batch delay."""
        self._smart_delay(is_batch_end=True)
class DeferredDownloadsMixin:
    """
    Mixin for tracking downloads to be recorded in batch.

    Allows deferring database writes for better performance.

    Usage:
        class MyScraper(LoggingMixin, DeferredDownloadsMixin):
            def __init__(self):
                self._init_logger('MyScraper')
                self._init_deferred_downloads()

            def download_file(self, url, path):
                # ... download logic ...
                self._add_pending_download({
                    'platform': 'my_platform',
                    'source': 'username',
                    'file_path': str(path),
                    # ... other fields ...
                })

            def finish_batch(self):
                downloads = self.get_pending_downloads()
                self.db.record_downloads_batch(downloads)
                self.clear_pending_downloads()
    """

    # None (not []) as the class-level default so instances never share a
    # mutable list; the real list is created lazily per instance.
    pending_downloads: Optional[List[Dict]] = None

    def _init_deferred_downloads(self) -> None:
        """Initialize deferred downloads tracking with an empty list."""
        self.pending_downloads = []

    def _add_pending_download(self, download_info: Dict[str, Any]) -> None:
        """
        Add a download to the pending list.

        Args:
            download_info: Dict with download metadata
        """
        # Lazy init: tolerate subclasses that skipped _init_deferred_downloads().
        if self.pending_downloads is None:
            self.pending_downloads = []
        self.pending_downloads.append(download_info)

    def get_pending_downloads(self) -> List[Dict[str, Any]]:
        """
        Get all pending downloads.

        Returns:
            List of pending download dicts (empty list if none)
        """
        return self.pending_downloads or []

    def clear_pending_downloads(self) -> None:
        """Clear the pending downloads list."""
        self.pending_downloads = []

    def has_pending_downloads(self) -> bool:
        """Check if there are pending downloads."""
        return bool(self.pending_downloads)
class BaseDatabaseAdapter:
    """
    Base class for platform-specific database adapters.

    Provides common functionality for recording and querying downloads.
    Platform-specific adapters should inherit from this class.

    Usage:
        class MyPlatformAdapter(BaseDatabaseAdapter):
            def __init__(self, unified_db):
                super().__init__(unified_db, platform='my_platform')

            def record_download(self, content_id, username, filename, **kwargs):
                # Platform-specific URL construction
                url = f"https://my_platform.com/{username}/{content_id}"
                return self._record_download_internal(
                    url=url,
                    source=username,
                    filename=filename,
                    **kwargs
                )
    """

    def __init__(self, unified_db, platform: str, method: Optional[str] = None):
        """
        Initialize base adapter.

        Args:
            unified_db: UnifiedDatabase instance
            platform: Platform name (e.g., 'instagram', 'tiktok')
            method: Optional method identifier for multi-method platforms
                (falls back to the platform name)
        """
        self.db = unified_db
        self.unified_db = unified_db  # Alias for compatibility
        self.platform = platform
        self.method = method or platform

    def get_connection(self, for_write: bool = False):
        """Get database connection (delegates to UnifiedDatabase)."""
        return self.db.get_connection(for_write)

    def get_file_hash(self, file_path: str) -> Optional[str]:
        """Calculate SHA256 hash of a file (delegates to UnifiedDatabase)."""
        return self.db.get_file_hash(file_path)

    def get_download_by_file_hash(self, file_hash: str) -> Optional[Dict]:
        """Get download record by file hash (delegates to UnifiedDatabase)."""
        return self.db.get_download_by_file_hash(file_hash)

    def get_download_by_media_id(self, media_id: str) -> Optional[Dict]:
        """Get download record by media_id, scoped to this platform/method."""
        return self.db.get_download_by_media_id(media_id, self.platform, self.method)

    def is_already_downloaded_by_hash(self, file_path: str) -> bool:
        """Check if file is already downloaded by comparing file hash."""
        file_hash = self.get_file_hash(file_path)
        if not file_hash:
            return False
        return self.get_download_by_file_hash(file_hash) is not None

    def is_already_downloaded_by_media_id(self, media_id: str) -> bool:
        """Check if content is already downloaded by media_id.

        NOTE(review): unlike get_download_by_media_id, this query filters by
        platform only, not method — presumably intentional; confirm.
        """
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT 1 FROM downloads
                WHERE platform = ?
                AND media_id = ?
                LIMIT 1
            ''', (self.platform, media_id))
            return cursor.fetchone() is not None

    def _calculate_file_hash(self, file_path: str) -> Optional[str]:
        """Safely calculate file hash; None for empty paths or missing files."""
        if not file_path:
            return None
        try:
            if Path(file_path).exists():
                return self.get_file_hash(file_path)
        except Exception:
            # Best-effort: a hashing failure must not break download recording.
            pass
        return None

    def _detect_content_type(self, filename: str) -> str:
        """Detect content type ('image' or 'video') from filename extension.

        Anything that is not a known image extension is treated as video.
        """
        ext = Path(filename).suffix.lower()
        image_exts = {'.jpg', '.jpeg', '.png', '.gif', '.heic', '.heif', '.webp', '.bmp', '.tiff'}
        return 'image' if ext in image_exts else 'video'

    def _record_download_internal(
        self,
        url: str,
        source: str,
        filename: str,
        content_type: Optional[str] = None,
        file_path: Optional[str] = None,
        post_date=None,
        metadata: Optional[Dict] = None,
        file_hash: Optional[str] = None,
        **extra_kwargs
    ) -> bool:
        """
        Internal method to record a download.

        Args:
            url: Unique URL/identifier for the content
            source: Username or source identifier
            filename: Downloaded filename
            content_type: 'image' or 'video' (auto-detected if not provided)
            file_path: Full path to downloaded file
            post_date: Original post date
            metadata: Additional metadata dict
            file_hash: Pre-computed file hash (computed if not provided and file_path exists)
            **extra_kwargs: Additional arguments passed to unified_db.record_download

        Returns:
            Whatever unified_db.record_download returns (truthy on success)
        """
        # Auto-detect content type if not provided
        if not content_type:
            content_type = self._detect_content_type(filename)

        # Calculate file hash if not provided
        if not file_hash and file_path:
            file_hash = self._calculate_file_hash(file_path)

        return self.db.record_download(
            url=url,
            platform=self.platform,
            source=source,
            content_type=content_type,
            filename=filename,
            file_path=file_path,
            file_hash=file_hash,
            post_date=post_date,
            metadata=metadata,
            method=self.method,
            **extra_kwargs
        )