diff --git a/cleanmedia.py b/cleanmedia.py index 942e1e7..2a00720 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -26,6 +26,7 @@ from datetime import datetime, timedelta from functools import cached_property from pathlib import Path from typing import List, Tuple, TypeAlias, Union +import sqlite3 try: import psycopg2 @@ -144,7 +145,7 @@ class File: class MediaRepository: """Handle media storage and retrieval for a Dendrite server.""" - def __init__(self, media_path: Path, connection_string: str) -> None: + def __init__(self, media_path: Path, connection_strings: dict[str, str]) -> None: """Initialize MediaRepository. Args: @@ -157,8 +158,26 @@ class MediaRepository: self._validate_media_path(media_path) self.media_path = media_path self._avatar_media_ids: List[MediaID] = [] - self.db_conn_string = connection_string - self.conn = self.connect_db() + self.conn = {} + + for db_type, conn_string in connection_strings.items(): + self.conn[db_type] = self.connect_db(conn_string) + + def execute(self, query: str, params=(), db_type: str = "media_api"): + paramstyle = getattr(self.conn, "paramstyle", "format") + query = self._adjust_paramstyle(query, paramstyle) + cur = self.conn[db_type].cursor() + try: + cur.execute(query, params) + return cur + except Exception as e: + cur.close() + raise e + + def _adjust_paramstyle(self, query: str, paramstyle: str): + if paramstyle == "qmark": + return query.replace("%s", "?") + return query @staticmethod def _validate_media_path(path: Path) -> None: @@ -167,7 +186,7 @@ class MediaRepository: if not path.is_dir(): raise ValueError("Media directory not found") - def connect_db(self) -> DBConnection: + def connect_db(self, connection_string: str) -> DBConnection: """Establish database connection. Returns: @@ -176,19 +195,20 @@ class MediaRepository: Raises: ValueError: If connection string is invalid """ - if not self.db_conn_string or not self.db_conn_string.startswith(("postgres://", "postgresql://")): - raise ValueError("Invalid PostgreSQL connection string") - return psycopg2.connect(self.db_conn_string) + if connection_string.startswith(("postgres://", "postgresql://")): + return psycopg2.connect(connection_string) + if connection_string.startswith(("file:", "sqlite:")): + return sqlite3.connect(connection_string.removeprefix("file:")) + raise ValueError("Invalid PostgreSQL connection string") def get_single_media(self, mxid: MediaID) -> File | None: """Retrieve a single media file by ID.""" - with self.conn.cursor() as cur: - cur.execute( + cur = self.execute( "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", (mxid,), ) - row = cur.fetchone() - return File(self, row[0], row[1] // 1000, row[2]) if row else None + row = cur.fetchone() + return File(self, row[0], row[1] // 1000, row[2]) if row else None def get_local_user_media(self, user_id: UserID) -> List[File]: """Get all media files created by a local user. @@ -199,12 +219,11 @@ class MediaRepository: Returns: List of File objects """ - with self.conn.cursor() as cur: - cur.execute( - "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", - (user_id,), - ) - return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] + cur = self.execute( + "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", + (user_id,), + ) + return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] def get_all_media(self, local: bool = False) -> List[File]: """Get all media files or only remote ones. @@ -215,13 +234,12 @@ class MediaRepository: Returns: List of File objects """ - with self.conn.cursor() as cur: - query = """SELECT media_id, creation_ts, base64hash - FROM mediaapi_media_repository""" - if not local: - query += " WHERE user_id = ''" - cur.execute(query) - return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] + query = """SELECT media_id, creation_ts, base64hash + FROM mediaapi_media_repository""" + if not local: + query += " WHERE user_id = ''" + cur = self.execute(query) + return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] def get_avatar_images(self) -> List[MediaID]: """Get media IDs of current avatar images. @@ -229,30 +247,28 @@ class MediaRepository: Returns: List of media IDs """ - with self.conn.cursor() as cur: - cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';") - media_ids = [] - for (url,) in cur.fetchall(): - try: - media_ids.append(url[url.rindex("/") + 1 :]) - except ValueError: - logging.warning("Invalid avatar URL: %s", url) - self._avatar_media_ids = media_ids - return media_ids + cur = self.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api") + media_ids = [] + for (url,) in cur.fetchall(): + try: + media_ids.append(url[url.rindex("/") + 1 :]) + except ValueError: + logging.warning("Invalid avatar URL: %s", url) + self._avatar_media_ids = media_ids + return media_ids def sanity_check_thumbnails(self) -> None: """Check for orphaned thumbnail entries in database.""" - with self.conn.cursor() as cur: - cur.execute( - """SELECT COUNT(media_id) FROM mediaapi_thumbnail - WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""", + cur = self.execute( + """SELECT COUNT(media_id) FROM mediaapi_thumbnail + WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""", + ) + if (row := cur.fetchone()) and (count := row[0]): + logging.error( + "You have %d thumbnails in your db that do not refer to media. " + "This needs fixing (we don't do that)!", + count, ) - if (row := cur.fetchone()) and (count := row[0]): - logging.error( - "You have %d thumbnails in your db that do not refer to media. " - "This needs fixing (we don't do that)!", - count, - ) def clean_media_files(self, days: int, local: bool = False, dryrun: bool = False) -> int: """Remove old media files. @@ -290,7 +306,7 @@ class MediaRepository: return len(files_to_delete) -def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]: +def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]: """Read database credentials and media path from config. Args: @@ -309,27 +325,28 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]: logging.error("Config file %s not found. Use --help for usage.", conf_file) sys.exit(1) - if "media_api" not in config: - logging.error("Missing media_api section in config") - sys.exit(1) - conn_string = None - if "global" in config and "database" in config["global"]: - conn_string = config["global"]["database"].get("connection_string") - elif "database" in config["media_api"]: - logging.debug("Using database config from media_api section") - conn_string = config["media_api"]["database"].get("connection_string") + global_conn_string = config.get("global",{}).get("database", {}).get("connection_string") - if not conn_string: - logging.error("Database connection string not found in config") - sys.exit(1) + db_types = { + "media_api": lambda x: x.get("media_api",{}).get("database", {}), + "user_api": lambda x: x.get("user_api",{}).get("account_database", {}), + } + + conns = {} + for db_type, config_key in db_types.items(): + value = config_key(config).get("connection_string") + conns[db_type] = value if value else global_conn_string + if not conns[db_type]: + logging.error(f"Database {db_type} connection string not found in config") + sys.exit(1) base_path = config["media_api"].get("base_path") if not base_path: logging.error("base_path not found in media_api config") sys.exit(1) - return Path(base_path), conn_string + return Path(base_path), conns def parse_options() -> argparse.Namespace: @@ -363,8 +380,8 @@ def parse_options() -> argparse.Namespace: def main() -> None: """Execute the media cleanup process.""" args = parse_options() - media_path, conn_string = read_config(args.config) - repo = MediaRepository(media_path, conn_string) + media_path, conn_strings = read_config(args.config) + repo = MediaRepository(media_path, conn_strings) if args.mxid: process_single_media(repo, args)