524
web/backend/totp_manager.py
Normal file
524
web/backend/totp_manager.py
Normal file
@@ -0,0 +1,524 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TOTP Manager for Media Downloader
|
||||
Handles Time-based One-Time Password (TOTP) operations for 2FA
|
||||
|
||||
Based on backup-central's implementation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pyotp
|
||||
import qrcode
|
||||
import io
|
||||
import base64
|
||||
import sqlite3
|
||||
import hashlib
|
||||
import bcrypt
|
||||
import secrets
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from cryptography.fernet import Fernet
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent path to allow imports from modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from modules.universal_logger import get_logger
|
||||
|
||||
logger = get_logger('TOTPManager')
|
||||
|
||||
|
||||
class TOTPManager:
|
||||
def __init__(self, db_path: str = None):
|
||||
if db_path is None:
|
||||
db_path = str(Path(__file__).parent.parent.parent / 'database' / 'auth.db')
|
||||
|
||||
self.db_path = db_path
|
||||
self.issuer = 'Media Downloader'
|
||||
self.window = 1 # Allow 1 step (30s) tolerance for time skew
|
||||
|
||||
# Encryption key for TOTP secrets
|
||||
self.encryption_key = self._get_encryption_key()
|
||||
self.cipher_suite = Fernet(self.encryption_key)
|
||||
|
||||
# Rate limiting configuration
|
||||
self.max_attempts = 5
|
||||
self.rate_limit_window = timedelta(minutes=15)
|
||||
self.lockout_duration = timedelta(minutes=30)
|
||||
|
||||
# Initialize database tables
|
||||
self._init_database()
|
||||
|
||||
def _get_encryption_key(self) -> bytes:
|
||||
"""Get or generate encryption key for TOTP secrets"""
|
||||
key_file = Path(__file__).parent.parent.parent / '.totp_encryption_key'
|
||||
|
||||
if key_file.exists():
|
||||
with open(key_file, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
# Generate new key
|
||||
key = Fernet.generate_key()
|
||||
try:
|
||||
with open(key_file, 'wb') as f:
|
||||
f.write(key)
|
||||
import os
|
||||
os.chmod(key_file, 0o600)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not save TOTP encryption key: {e}", module="TOTP")
|
||||
|
||||
return key
|
||||
|
||||
def _init_database(self):
|
||||
"""Initialize TOTP-related database tables"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Ensure TOTP columns exist in users table
|
||||
try:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN totp_secret TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
|
||||
try:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN totp_enabled INTEGER NOT NULL DEFAULT 0")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cursor.execute("ALTER TABLE users ADD COLUMN totp_enrolled_at TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
|
||||
# TOTP rate limiting table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS totp_rate_limit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL,
|
||||
ip_address TEXT,
|
||||
attempts INTEGER NOT NULL DEFAULT 1,
|
||||
window_start TEXT NOT NULL,
|
||||
locked_until TEXT,
|
||||
UNIQUE(username, ip_address)
|
||||
)
|
||||
""")
|
||||
|
||||
# TOTP audit log
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS totp_audit_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
details TEXT,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def generate_secret(self, username: str) -> Dict:
|
||||
"""
|
||||
Generate a new TOTP secret and QR code for a user
|
||||
|
||||
Returns:
|
||||
dict with secret, qrCodeDataURL, manualEntryKey
|
||||
"""
|
||||
try:
|
||||
# Generate secret
|
||||
secret = pyotp.random_base32()
|
||||
|
||||
# Create provisioning URI
|
||||
totp = pyotp.TOTP(secret)
|
||||
provisioning_uri = totp.provisioning_uri(
|
||||
name=username,
|
||||
issuer_name=self.issuer
|
||||
)
|
||||
|
||||
# Generate QR code
|
||||
qr = qrcode.QRCode(version=1, box_size=10, border=4)
|
||||
qr.add_data(provisioning_uri)
|
||||
qr.make(fit=True)
|
||||
|
||||
img = qr.make_image(fill_color="black", back_color="white")
|
||||
|
||||
# Convert to base64 data URL
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format='PNG')
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode()
|
||||
qr_code_data_url = f"data:image/png;base64,{img_str}"
|
||||
|
||||
self._log_audit(username, 'totp_secret_generated', True, None, None,
|
||||
'TOTP secret generated for user')
|
||||
|
||||
return {
|
||||
'secret': secret,
|
||||
'qrCodeDataURL': qr_code_data_url,
|
||||
'manualEntryKey': secret,
|
||||
'otpauthURL': provisioning_uri
|
||||
}
|
||||
except Exception as error:
|
||||
self._log_audit(username, 'totp_secret_generation_failed', False, None, None,
|
||||
f'Error: {str(error)}')
|
||||
raise
|
||||
|
||||
def verify_token(self, secret: str, token: str) -> bool:
|
||||
"""
|
||||
Verify a TOTP token
|
||||
|
||||
Args:
|
||||
secret: Base32 encoded secret
|
||||
token: 6-digit TOTP code
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(token, valid_window=self.window)
|
||||
except Exception as e:
|
||||
logger.error(f"TOTP verification error: {e}", module="TOTP")
|
||||
return False
|
||||
|
||||
def encrypt_secret(self, secret: str) -> str:
|
||||
"""Encrypt TOTP secret for storage"""
|
||||
encrypted = self.cipher_suite.encrypt(secret.encode())
|
||||
return base64.b64encode(encrypted).decode()
|
||||
|
||||
def decrypt_secret(self, encrypted_secret: str) -> str:
|
||||
"""Decrypt TOTP secret from storage"""
|
||||
encrypted = base64.b64decode(encrypted_secret.encode())
|
||||
decrypted = self.cipher_suite.decrypt(encrypted)
|
||||
return decrypted.decode()
|
||||
|
||||
def generate_backup_codes(self, count: int = 10) -> List[str]:
|
||||
"""
|
||||
Generate backup codes for recovery
|
||||
|
||||
Args:
|
||||
count: Number of codes to generate
|
||||
|
||||
Returns:
|
||||
List of formatted backup codes
|
||||
"""
|
||||
codes = []
|
||||
for _ in range(count):
|
||||
# Generate 8-character hex code
|
||||
code = secrets.token_hex(4).upper()
|
||||
# Format as XXXX-XXXX
|
||||
formatted = f"{code[:4]}-{code[4:8]}"
|
||||
codes.append(formatted)
|
||||
return codes
|
||||
|
||||
def hash_backup_code(self, code: str) -> str:
|
||||
"""
|
||||
Hash a backup code for storage using bcrypt
|
||||
|
||||
Args:
|
||||
code: Backup code to hash
|
||||
|
||||
Returns:
|
||||
Bcrypt hash of the code
|
||||
"""
|
||||
# Use bcrypt with default work factor (12 rounds)
|
||||
return bcrypt.hashpw(code.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
|
||||
|
||||
def verify_backup_code(self, code: str, username: str) -> Tuple[bool, int]:
|
||||
"""
|
||||
Verify a backup code and mark it as used
|
||||
|
||||
Returns:
|
||||
Tuple of (valid, remaining_codes)
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get unused backup codes for user
|
||||
cursor.execute("""
|
||||
SELECT id, code_hash FROM backup_codes
|
||||
WHERE username = ? AND used = 0
|
||||
""", (username,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Check if code matches any unused code
|
||||
for row_id, stored_hash in rows:
|
||||
# Support both bcrypt (new) and SHA256 (legacy) hashes
|
||||
is_match = False
|
||||
|
||||
if stored_hash.startswith('$2b$'):
|
||||
# Bcrypt hash
|
||||
try:
|
||||
is_match = bcrypt.checkpw(code.encode('utf-8'), stored_hash.encode('utf-8'))
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying bcrypt backup code: {e}", module="TOTP")
|
||||
continue
|
||||
else:
|
||||
# Legacy SHA256 hash
|
||||
legacy_hash = hashlib.sha256(code.encode()).hexdigest()
|
||||
is_match = (stored_hash == legacy_hash)
|
||||
|
||||
if is_match:
|
||||
# Mark as used
|
||||
cursor.execute("""
|
||||
UPDATE backup_codes
|
||||
SET used = 1, used_at = ?
|
||||
WHERE id = ?
|
||||
""", (datetime.now().isoformat(), row_id))
|
||||
|
||||
conn.commit()
|
||||
|
||||
# Count remaining codes
|
||||
cursor.execute("""
|
||||
SELECT COUNT(*) FROM backup_codes
|
||||
WHERE username = ? AND used = 0
|
||||
""", (username,))
|
||||
|
||||
remaining = cursor.fetchone()[0]
|
||||
return True, remaining
|
||||
|
||||
return False, len(rows)
|
||||
|
||||
def enable_totp(self, username: str, secret: str, backup_codes: List[str]) -> bool:
|
||||
"""
|
||||
Enable TOTP for a user
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
secret: TOTP secret
|
||||
backup_codes: List of backup codes
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Encrypt secret
|
||||
encrypted_secret = self.encrypt_secret(secret)
|
||||
|
||||
# Update user
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET totp_secret = ?, totp_enabled = 1, totp_enrolled_at = ?
|
||||
WHERE username = ?
|
||||
""", (encrypted_secret, datetime.now().isoformat(), username))
|
||||
|
||||
# Delete old backup codes
|
||||
cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
|
||||
|
||||
# Insert new backup codes
|
||||
for code in backup_codes:
|
||||
code_hash = self.hash_backup_code(code)
|
||||
cursor.execute("""
|
||||
INSERT INTO backup_codes (username, code_hash, created_at)
|
||||
VALUES (?, ?, ?)
|
||||
""", (username, code_hash, datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self._log_audit(username, 'totp_enabled', True, None, None,
|
||||
f'TOTP enabled with {len(backup_codes)} backup codes')
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error enabling TOTP: {e}", module="TOTP")
|
||||
return False
|
||||
|
||||
def disable_totp(self, username: str) -> bool:
|
||||
"""
|
||||
Disable TOTP for a user
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
UPDATE users
|
||||
SET totp_secret = NULL, totp_enabled = 0, totp_enrolled_at = NULL
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
# Delete backup codes
|
||||
cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self._log_audit(username, 'totp_disabled', True, None, None, 'TOTP disabled')
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error disabling TOTP: {e}", module="TOTP")
|
||||
return False
|
||||
|
||||
def get_totp_status(self, username: str) -> Dict:
|
||||
"""
|
||||
Get TOTP status for a user
|
||||
|
||||
Returns:
|
||||
dict with enabled, enrolledAt
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT totp_enabled, totp_enrolled_at FROM users
|
||||
WHERE username = ?
|
||||
""", (username,))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
return {'enabled': False, 'enrolledAt': None}
|
||||
|
||||
return {
|
||||
'enabled': bool(row[0]),
|
||||
'enrolledAt': row[1]
|
||||
}
|
||||
|
||||
def check_rate_limit(self, username: str, ip_address: str) -> Tuple[bool, int, Optional[str]]:
|
||||
"""
|
||||
Check and enforce rate limiting for TOTP verification
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, attempts_remaining, locked_until)
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
now = datetime.now()
|
||||
window_start = now - self.rate_limit_window
|
||||
|
||||
cursor.execute("""
|
||||
SELECT attempts, window_start, locked_until
|
||||
FROM totp_rate_limit
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (username, ip_address))
|
||||
|
||||
row = cursor.fetchone()
|
||||
|
||||
if not row:
|
||||
# First attempt - create record
|
||||
cursor.execute("""
|
||||
INSERT INTO totp_rate_limit (username, ip_address, attempts, window_start)
|
||||
VALUES (?, ?, 1, ?)
|
||||
""", (username, ip_address, now.isoformat()))
|
||||
conn.commit()
|
||||
return True, self.max_attempts - 1, None
|
||||
|
||||
attempts, stored_window_start, locked_until = row
|
||||
stored_window_start = datetime.fromisoformat(stored_window_start)
|
||||
|
||||
# Check if locked out
|
||||
if locked_until:
|
||||
locked_until_dt = datetime.fromisoformat(locked_until)
|
||||
if locked_until_dt > now:
|
||||
return False, 0, locked_until
|
||||
else:
|
||||
# Lockout expired - reset
|
||||
cursor.execute("""
|
||||
DELETE FROM totp_rate_limit
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (username, ip_address))
|
||||
conn.commit()
|
||||
return True, self.max_attempts, None
|
||||
|
||||
# Check if window expired
|
||||
if stored_window_start < window_start:
|
||||
# Reset window
|
||||
cursor.execute("""
|
||||
UPDATE totp_rate_limit
|
||||
SET attempts = 1, window_start = ?, locked_until = NULL
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (now.isoformat(), username, ip_address))
|
||||
conn.commit()
|
||||
return True, self.max_attempts - 1, None
|
||||
|
||||
# Increment attempts
|
||||
new_attempts = attempts + 1
|
||||
|
||||
if new_attempts >= self.max_attempts:
|
||||
# Lock out
|
||||
locked_until = (now + self.lockout_duration).isoformat()
|
||||
cursor.execute("""
|
||||
UPDATE totp_rate_limit
|
||||
SET attempts = ?, locked_until = ?
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (new_attempts, locked_until, username, ip_address))
|
||||
conn.commit()
|
||||
return False, 0, locked_until
|
||||
else:
|
||||
cursor.execute("""
|
||||
UPDATE totp_rate_limit
|
||||
SET attempts = ?
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (new_attempts, username, ip_address))
|
||||
conn.commit()
|
||||
return True, self.max_attempts - new_attempts, None
|
||||
|
||||
def reset_rate_limit(self, username: str, ip_address: str):
|
||||
"""Reset rate limit after successful authentication"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
DELETE FROM totp_rate_limit
|
||||
WHERE username = ? AND ip_address = ?
|
||||
""", (username, ip_address))
|
||||
|
||||
conn.commit()
|
||||
|
||||
def regenerate_backup_codes(self, username: str) -> List[str]:
|
||||
"""
|
||||
Regenerate backup codes for a user
|
||||
|
||||
Returns:
|
||||
List of new backup codes
|
||||
"""
|
||||
# Generate new codes
|
||||
new_codes = self.generate_backup_codes(10)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Delete old codes
|
||||
cursor.execute("DELETE FROM backup_codes WHERE username = ?", (username,))
|
||||
|
||||
# Insert new codes
|
||||
for code in new_codes:
|
||||
code_hash = self.hash_backup_code(code)
|
||||
cursor.execute("""
|
||||
INSERT INTO backup_codes (username, code_hash, created_at)
|
||||
VALUES (?, ?, ?)
|
||||
""", (username, code_hash, datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
|
||||
self._log_audit(username, 'backup_codes_regenerated', True, None, None,
|
||||
f'Generated {len(new_codes)} new backup codes')
|
||||
|
||||
return new_codes
|
||||
|
||||
def _log_audit(self, username: str, action: str, success: bool,
|
||||
ip_address: Optional[str], user_agent: Optional[str],
|
||||
details: Optional[str] = None):
|
||||
"""Log TOTP audit event"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO totp_audit_log
|
||||
(username, action, success, ip_address, user_agent, details, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (username, action, int(success), ip_address, user_agent, details,
|
||||
datetime.now().isoformat()))
|
||||
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging TOTP audit: {e}", module="TOTP")
|
||||
Reference in New Issue
Block a user