From d318296e7ec43751e9d32206221a61c2a20b33ff Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 09:58:57 +0100 Subject: [PATCH 1/8] feat: support sqlite3 and multiple databases --- cleanmedia.py | 137 ++++++++++++++++++++++++++++---------------------- 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index 942e1e7..2a00720 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -26,6 +26,7 @@ from datetime import datetime, timedelta from functools import cached_property from pathlib import Path from typing import List, Tuple, TypeAlias, Union +import sqlite3 try: import psycopg2 @@ -144,7 +145,7 @@ class File: class MediaRepository: """Handle media storage and retrieval for a Dendrite server.""" - def __init__(self, media_path: Path, connection_string: str) -> None: + def __init__(self, media_path: Path, connection_strings: dict[str, str]) -> None: """Initialize MediaRepository. Args: @@ -157,8 +158,26 @@ class MediaRepository: self._validate_media_path(media_path) self.media_path = media_path self._avatar_media_ids: List[MediaID] = [] - self.db_conn_string = connection_string - self.conn = self.connect_db() + self.conn = {} + + for db_type, conn_string in connection_strings.items(): + self.conn[db_type] = self.connect_db(conn_string) + + def execute(self, query: str, params=(), db_type: str = "media_api"): + paramstyle = getattr(self.conn, "paramstyle", "format") + query = self._adjust_paramstyle(query, paramstyle) + cur = self.conn[db_type].cursor() + try: + cur.execute(query, params) + return cur + except Exception as e: + cur.close() + raise e + + def _adjust_paramstyle(self, query: str, paramstyle: str): + if paramstyle == "qmark": + return query.replace("%s", "?") + return query @staticmethod def _validate_media_path(path: Path) -> None: @@ -167,7 +186,7 @@ class MediaRepository: if not path.is_dir(): raise ValueError("Media directory not found") - def connect_db(self) -> DBConnection: + def connect_db(self, connection_string: str) -> DBConnection: """Establish database connection. Returns: @@ -176,19 +195,20 @@ class MediaRepository: Raises: ValueError: If connection string is invalid """ - if not self.db_conn_string or not self.db_conn_string.startswith(("postgres://", "postgresql://")): - raise ValueError("Invalid PostgreSQL connection string") - return psycopg2.connect(self.db_conn_string) + if connection_string.startswith(("postgres://", "postgresql://")): + return psycopg2.connect(connection_string) + if connection_string.startswith(("file:", "sqlite:")): + return sqlite3.connect(connection_string.removeprefix("file:")) + raise ValueError("Invalid PostgreSQL connection string") def get_single_media(self, mxid: MediaID) -> File | None: """Retrieve a single media file by ID.""" - with self.conn.cursor() as cur: - cur.execute( + cur = self.execute( "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", (mxid,), ) - row = cur.fetchone() - return File(self, row[0], row[1] // 1000, row[2]) if row else None + row = cur.fetchone() + return File(self, row[0], row[1] // 1000, row[2]) if row else None def get_local_user_media(self, user_id: UserID) -> List[File]: """Get all media files created by a local user. @@ -199,12 +219,11 @@ class MediaRepository: Returns: List of File objects """ - with self.conn.cursor() as cur: - cur.execute( - "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", - (user_id,), - ) - return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] + cur = self.execute( + "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", + (user_id,), + ) + return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] def get_all_media(self, local: bool = False) -> List[File]: """Get all media files or only remote ones. @@ -215,13 +234,12 @@ class MediaRepository: Returns: List of File objects """ - with self.conn.cursor() as cur: - query = """SELECT media_id, creation_ts, base64hash - FROM mediaapi_media_repository""" - if not local: - query += " WHERE user_id = ''" - cur.execute(query) - return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] + query = """SELECT media_id, creation_ts, base64hash + FROM mediaapi_media_repository""" + if not local: + query += " WHERE user_id = ''" + cur = self.execute(query) + return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] def get_avatar_images(self) -> List[MediaID]: """Get media IDs of current avatar images. @@ -229,30 +247,28 @@ class MediaRepository: Returns: List of media IDs """ - with self.conn.cursor() as cur: - cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';") - media_ids = [] - for (url,) in cur.fetchall(): - try: - media_ids.append(url[url.rindex("/") + 1 :]) - except ValueError: - logging.warning("Invalid avatar URL: %s", url) - self._avatar_media_ids = media_ids - return media_ids + cur = self.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api") + media_ids = [] + for (url,) in cur.fetchall(): + try: + media_ids.append(url[url.rindex("/") + 1 :]) + except ValueError: + logging.warning("Invalid avatar URL: %s", url) + self._avatar_media_ids = media_ids + return media_ids def sanity_check_thumbnails(self) -> None: """Check for orphaned thumbnail entries in database.""" - with self.conn.cursor() as cur: - cur.execute( - """SELECT COUNT(media_id) FROM mediaapi_thumbnail - WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""", + cur = self.execute( + """SELECT COUNT(media_id) FROM mediaapi_thumbnail + WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""", + ) + if (row := cur.fetchone()) and (count := row[0]): + logging.error( + "You have %d thumbnails in your db that do not refer to media. " + "This needs fixing (we don't do that)!", + count, ) - if (row := cur.fetchone()) and (count := row[0]): - logging.error( - "You have %d thumbnails in your db that do not refer to media. " - "This needs fixing (we don't do that)!", - count, - ) def clean_media_files(self, days: int, local: bool = False, dryrun: bool = False) -> int: """Remove old media files. @@ -290,7 +306,7 @@ class MediaRepository: return len(files_to_delete) -def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]: +def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]: """Read database credentials and media path from config. Args: @@ -309,27 +325,28 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]: logging.error("Config file %s not found. Use --help for usage.", conf_file) sys.exit(1) - if "media_api" not in config: - logging.error("Missing media_api section in config") - sys.exit(1) - conn_string = None - if "global" in config and "database" in config["global"]: - conn_string = config["global"]["database"].get("connection_string") - elif "database" in config["media_api"]: - logging.debug("Using database config from media_api section") - conn_string = config["media_api"]["database"].get("connection_string") + global_conn_string = config.get("global",{}).get("database", {}).get("connection_string") - if not conn_string: - logging.error("Database connection string not found in config") - sys.exit(1) + db_types = { + "media_api": lambda x: x.get("media_api",{}).get("database", {}), + "user_api": lambda x: x.get("user_api",{}).get("account_database", {}), + } + + conns = {} + for db_type, config_key in db_types.items(): + value = config_key(config).get("connection_string") + conns[db_type] = value if value else global_conn_string + if not conns[db_type]: + logging.error(f"Database {db_type} connection string not found in config") + sys.exit(1) base_path = config["media_api"].get("base_path") if not base_path: logging.error("base_path not found in media_api config") sys.exit(1) - return Path(base_path), conn_string + return Path(base_path), conns def parse_options() -> argparse.Namespace: @@ -363,8 +380,8 @@ def parse_options() -> argparse.Namespace: def main() -> None: """Execute the media cleanup process.""" args = parse_options() - media_path, conn_string = read_config(args.config) - repo = MediaRepository(media_path, conn_string) + media_path, conn_strings = read_config(args.config) + repo = MediaRepository(media_path, conn_strings) if args.mxid: process_single_media(repo, args) -- 2.39.5 From 9751bf28b981d6c437f44517147ef2dcdeebac78 Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 10:18:52 +0100 Subject: [PATCH 2/8] feat: calculate saved space --- cleanmedia.py | 46 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index 2a00720..58825d6 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -52,12 +52,14 @@ class File: media_id: MediaID, creation_ts: Timestamp, base64hash: Base64Hash, + file_size: int, ) -> None: """Initialize a File object.""" self.repo = media_repo self.media_id = media_id self.create_date = datetime.fromtimestamp(creation_ts) self.base64hash = base64hash + self.file_size = file_size @cached_property def fullpath(self) -> Path | None: @@ -163,7 +165,7 @@ class MediaRepository: for db_type, conn_string in connection_strings.items(): self.conn[db_type] = self.connect_db(conn_string) - def execute(self, query: str, params=(), db_type: str = "media_api"): + def _execute(self, query: str, params=(), db_type: str = "media_api"): paramstyle = getattr(self.conn, "paramstyle", "format") query = self._adjust_paramstyle(query, paramstyle) cur = self.conn[db_type].cursor() @@ -203,7 +205,7 @@ class MediaRepository: def get_single_media(self, mxid: MediaID) -> File | None: """Retrieve a single media file by ID.""" - cur = self.execute( + cur = self._execute( "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", (mxid,), ) @@ -219,7 +221,7 @@ class MediaRepository: Returns: List of File objects """ - cur = self.execute( + cur = self._execute( "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", (user_id,), ) @@ -234,12 +236,12 @@ class MediaRepository: Returns: List of File objects """ - query = """SELECT media_id, creation_ts, base64hash + query = """SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository""" if not local: query += " WHERE user_id = ''" - cur = self.execute(query) - return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] + cur = self._execute(query) + return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()] def get_avatar_images(self) -> List[MediaID]: """Get media IDs of current avatar images. @@ -247,7 +249,7 @@ class MediaRepository: Returns: List of media IDs """ - cur = self.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api") + cur = self._execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api") media_ids = [] for (url,) in cur.fetchall(): try: @@ -259,7 +261,7 @@ class MediaRepository: def sanity_check_thumbnails(self) -> None: """Check for orphaned thumbnail entries in database.""" - cur = self.execute( + cur = self._execute( """SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""", ) @@ -292,8 +294,9 @@ class MediaRepository: for f in self.get_all_media(local) if f.media_id not in self._avatar_media_ids and f.create_date < cutoff_date ] - + file_size_counter = 0 for file in files_to_delete: + file_size_counter += file.file_size if dryrun: logging.info(f"Would delete file {file.media_id} at {file.fullpath}") if not file.exists(): @@ -302,7 +305,7 @@ class MediaRepository: file.delete() action = "Would have deleted" if dryrun else "Deleted" - logging.info("%s %d files", action, len(files_to_delete)) + logging.info("%s %d files, in total %s", action, len(files_to_delete), sizeof_fmt(file_size_counter)) return len(files_to_delete) @@ -348,6 +351,29 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]] return Path(base_path), conns +def sizeof_fmt(num: Union[int, float], suffix: str = "B") -> str: + """ + Convert a number of bytes (or other units) into a human-readable format using binary prefixes. + + Args: + num (int | float): The number to format, typically a size in bytes. + suffix (str): The suffix to append (default is 'B' for bytes). + + Returns: + str: A string representing the human-readable size (e.g. '1.5MiB', '42.0KiB'). + + Notes: + Uses binary (base-1024) units: Ki, Mi, Gi, Ti, etc. + Automatically chooses the appropriate unit based on the input value. + """ + BASE = 1024.0 + + for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"): + if abs(num) < BASE: + return f"{num:3.1f}{unit}{suffix}" + num /= 1024.0 + return f"{num:.1f}Yi{suffix}" + def parse_options() -> argparse.Namespace: """Parse command line arguments. -- 2.39.5 From 19476c593f2d0f553033eaae8eafc31496e34303 Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 11:27:16 +0100 Subject: [PATCH 3/8] feat: update tests --- cleanmedia.py | 57 ++++++++++++++++--------------- tests/test_cleanmedia.py | 72 +++++++++++++++++++++------------------- 2 files changed, 66 insertions(+), 63 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index 58825d6..136d4b1 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -25,7 +25,7 @@ import sys from datetime import datetime, timedelta from functools import cached_property from pathlib import Path -from typing import List, Tuple, TypeAlias, Union +from typing import List, Tuple, TypeAlias, Union, Sequence, Mapping, Any, Callable import sqlite3 try: @@ -36,7 +36,9 @@ except ImportError as err: raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err # Type aliases -DBConnection: TypeAlias = psycopg2.extensions.connection +DBConnection: TypeAlias = Union[sqlite3.Connection, psycopg2.extensions.connection] +DBCursor = Union[sqlite3.Cursor, psycopg2.extensions.cursor] +Params = Union[Sequence[Any], Mapping[str, Any]] MediaID: TypeAlias = str Timestamp: TypeAlias = int Base64Hash: TypeAlias = str @@ -113,12 +115,11 @@ class File: Returns: True if database entries were deleted successfully """ - with self.repo.conn.cursor() as cur: - cur.execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,)) - num_thumbnails = cur.rowcount - cur.execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,)) - num_media = cur.rowcount - self.repo.conn.commit() + cur = self.repo._execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,)) + num_thumbnails = cur.rowcount + self.repo._execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,)) + num_media = cur.rowcount + self.repo.conn["media_api"].commit() logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}") return True @@ -138,10 +139,9 @@ class File: Returns: Number of thumbnails """ - with self.repo.conn.cursor() as cur: - cur.execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,)) - row = cur.fetchone() - return int(row[0]) if row else 0 + cur = self.repo._execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,)) + row = cur.fetchone() + return int(row[0]) if row else 0 class MediaRepository: @@ -160,12 +160,12 @@ class MediaRepository: self._validate_media_path(media_path) self.media_path = media_path self._avatar_media_ids: List[MediaID] = [] - self.conn = {} + self.conn: dict[str, DBConnection] = {} for db_type, conn_string in connection_strings.items(): self.conn[db_type] = self.connect_db(conn_string) - def _execute(self, query: str, params=(), db_type: str = "media_api"): + def _execute(self, query: str, params: Params = (), db_type: str = "media_api") -> DBCursor: paramstyle = getattr(self.conn, "paramstyle", "format") query = self._adjust_paramstyle(query, paramstyle) cur = self.conn[db_type].cursor() @@ -176,7 +176,7 @@ class MediaRepository: cur.close() raise e - def _adjust_paramstyle(self, query: str, paramstyle: str): + def _adjust_paramstyle(self, query: str, paramstyle: str) -> str: if paramstyle == "qmark": return query.replace("%s", "?") return query @@ -206,11 +206,11 @@ class MediaRepository: def get_single_media(self, mxid: MediaID) -> File | None: """Retrieve a single media file by ID.""" cur = self._execute( - "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", - (mxid,), - ) + "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;", + (mxid,), + ) row = cur.fetchone() - return File(self, row[0], row[1] // 1000, row[2]) if row else None + return File(self, row[0], row[1] // 1000, row[2], row[3]) if row else None def get_local_user_media(self, user_id: UserID) -> List[File]: """Get all media files created by a local user. @@ -222,10 +222,10 @@ class MediaRepository: List of File objects """ cur = self._execute( - "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", + "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;", (user_id,), ) - return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] + return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()] def get_all_media(self, local: bool = False) -> List[File]: """Get all media files or only remote ones. @@ -267,8 +267,7 @@ class MediaRepository: ) if (row := cur.fetchone()) and (count := row[0]): logging.error( - "You have %d thumbnails in your db that do not refer to media. " - "This needs fixing (we don't do that)!", + "You have %d thumbnails in your db that do not refer to media. " "This needs fixing (we don't do that)!", count, ) @@ -309,7 +308,7 @@ class MediaRepository: return len(files_to_delete) -def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]: +def read_config(conf_file: Union[str, Path]) -> Tuple[Path, dict[str, str]]: """Read database credentials and media path from config. Args: @@ -328,12 +327,11 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]] logging.error("Config file %s not found. Use --help for usage.", conf_file) sys.exit(1) + global_conn_string = config.get("global", {}).get("database", {}).get("connection_string") - global_conn_string = config.get("global",{}).get("database", {}).get("connection_string") - - db_types = { - "media_api": lambda x: x.get("media_api",{}).get("database", {}), - "user_api": lambda x: x.get("user_api",{}).get("account_database", {}), + db_types: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = { + "media_api": lambda x: x.get("media_api", {}).get("database", {}), + "user_api": lambda x: x.get("user_api", {}).get("account_database", {}), } conns = {} @@ -351,6 +349,7 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]] return Path(base_path), conns + def sizeof_fmt(num: Union[int, float], suffix: str = "B") -> str: """ Convert a number of bytes (or other units) into a human-readable format using binary prefixes. diff --git a/tests/test_cleanmedia.py b/tests/test_cleanmedia.py index ef5f16e..f31b449 100644 --- a/tests/test_cleanmedia.py +++ b/tests/test_cleanmedia.py @@ -42,17 +42,18 @@ def media_repo(tmp_path: Any, mock_db_conn: Tuple[Any, Any], mocker: MockerFixtu Returns: Configured MediaRepository instance """ - conn_mock, _ = mock_db_conn + conn_mock, cursor = mock_db_conn media_path = tmp_path / "media" media_path.mkdir() mocker.patch("cleanmedia.MediaRepository.connect_db", return_value=conn_mock) - return MediaRepository(media_path, "postgresql://fake") + mocker.patch("cleanmedia.MediaRepository._execute", return_value=cursor) + return MediaRepository(media_path, {"media_api": "postgresql://fake", "user_api": "postgresql://fake"}) def test_file_init(mocker: MockerFixture) -> None: """Test File class initialization.""" repo = mocker.Mock() - file = File(repo, "mxid123", 1600000000, "base64hash123") + file = File(repo, "mxid123", 1600000000, "base64hash123", 1000) assert file.media_id == "mxid123" assert file.create_date == datetime.fromtimestamp(1600000000) assert file.base64hash == "base64hash123" @@ -61,26 +62,26 @@ def test_file_init(mocker: MockerFixture) -> None: def test_file_fullpath(media_repo: MediaRepository) -> None: """Test File.fullpath property returns correct path.""" - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) expected_path = media_repo.media_path / "a" / "b" / "c123" assert file.fullpath == expected_path def test_file_exists_no_path(media_repo: MediaRepository) -> None: """Test File.exists returns False when fullpath is None.""" - file = File(media_repo, "mxid123", 1600000000, "") # Empty hash ensures fullpath is None + file = File(media_repo, "mxid123", 1600000000, "", 1000) # Empty hash ensures fullpath is None assert file.exists() is False def test_file_delete_no_path(media_repo: MediaRepository) -> None: """Test File._delete_files when file path is None.""" - file = File(media_repo, "mxid123", 1600000000, "") + file = File(media_repo, "mxid123", 1600000000, "", 1000) assert file._delete_files() is False def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, caplog: Any) -> None: """Test File._delete_files when OSError occurs.""" - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) # Create directory structure file_path = media_repo.media_path / "a" / "b" / "c123" @@ -96,13 +97,13 @@ def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, def test_file_fullpath_none_if_no_hash(media_repo: MediaRepository) -> None: """Test File.fullpath returns None when hash is empty.""" - file = File(media_repo, "mxid123", 1600000000, "") + file = File(media_repo, "mxid123", 1600000000, "", 1000) assert file.fullpath is None def test_file_exists(media_repo: MediaRepository) -> None: """Test File.exists returns True when file exists.""" - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) file_path = media_repo.media_path / "a" / "b" / "c123" file_path.mkdir(parents=True) (file_path / "file").touch() @@ -111,14 +112,14 @@ def test_file_exists(media_repo: MediaRepository) -> None: def test_file_not_exists(media_repo: MediaRepository) -> None: """Test File.exists returns False when file doesn't exist.""" - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) assert file.exists() is False def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: """Test File.delete removes files and database entries.""" _, cursor_mock = mock_db_conn - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) file_path = media_repo.media_path / "a" / "b" / "c123" file_path.mkdir(parents=True) @@ -128,22 +129,22 @@ def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) assert file.delete() is True assert not file_path.exists() - cursor_mock.execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",)) - cursor_mock.execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",)) + media_repo._execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",)) + media_repo._execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",)) def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: """Test MediaRepository.get_single_media returns correct File object.""" _, cursor_mock = mock_db_conn - cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123") + cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123", 1000) file = media_repo.get_single_media("mxid123") assert file is not None assert file.media_id == "mxid123" assert file.base64hash == "abc123" - cursor_mock.execute.assert_called_with( - "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", + media_repo._execute.assert_called_with( + "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;", ("mxid123",), ) @@ -165,8 +166,8 @@ def test_clean_media_files(media_repo: MediaRepository, mock_db_conn: Tuple[Any, new_date = int((datetime.now() - timedelta(days=1)).timestamp()) cursor_mock.fetchall.return_value = [ - ("old_file", old_date * 1000, "abc123"), - ("new_file", new_date * 1000, "def456"), + ("old_file", old_date * 1000, "abc123", 1000), + ("new_file", new_date * 1000, "def456", 1000), ] media_repo._avatar_media_ids = [] @@ -181,7 +182,7 @@ def test_clean_media_files_dryrun(media_repo: MediaRepository, mock_db_conn: Tup old_date = int((datetime.now() - timedelta(days=31)).timestamp()) cursor_mock.fetchall.return_value = [ - ("old_file", old_date * 1000, "abc123"), + ("old_file", old_date * 1000, "abc123", 1000), ] media_repo._avatar_media_ids = [] @@ -252,24 +253,24 @@ def test_validate_media_path_relative(caplog: Any) -> None: def test_connect_db_success(mocker: MockerFixture) -> None: """Test successful database connection.""" mock_connect = mocker.patch("psycopg2.connect") - repo = MediaRepository(Path("/tmp"), "postgresql://fake") - repo.connect_db() + repo = MediaRepository(Path("/tmp"), {"media_api": "postgresql://fake"}) + repo.connect_db("postgresql://fake") mock_connect.assert_called_with("postgresql://fake") def test_connect_db_invalid_string(tmp_path: Path) -> None: """Test connect_db with invalid connection string.""" with pytest.raises(ValueError, match="Invalid PostgreSQL connection string"): - repo = MediaRepository(tmp_path, "invalid") - repo.connect_db() + repo = MediaRepository(tmp_path, {"media_api": "invalid"}) + repo.connect_db("invalid") def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: """Test get_local_user_media returns correct files.""" _, cursor_mock = mock_db_conn cursor_mock.fetchall.return_value = [ - ("media1", 1600000000000, "hash1"), - ("media2", 1600000000000, "hash2"), + ("media1", 1600000000000, "hash1", 1000), + ("media2", 1600000000000, "hash2", 1000), ] files = media_repo.get_local_user_media("@user:domain.com") @@ -277,8 +278,8 @@ def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[A assert files[0].media_id == "media1" assert files[1].media_id == "media2" - cursor_mock.execute.assert_called_with( - "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", + media_repo._execute.assert_called_with( + "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;", ("@user:domain.com",), ) @@ -333,15 +334,18 @@ def test_read_config_valid(tmp_path: Path) -> None: """Test read_config with valid config.""" config_file = tmp_path / "config.yaml" config_file.write_text(""" +global: + database: + connection_string: postgresql://global/db media_api: base_path: /media/path database: connection_string: postgresql://user:pass@localhost/db """) - path, conn_string = read_config(config_file) + path, conn_strings = read_config(config_file) assert path == Path("/media/path") - assert conn_string == "postgresql://user:pass@localhost/db" + assert conn_strings["media_api"] == "postgresql://user:pass@localhost/db" def test_read_config_global_database(tmp_path: Path) -> None: @@ -355,8 +359,8 @@ global: connection_string: postgresql://global/db """) - path, conn_string = read_config(config_file) - assert conn_string == "postgresql://global/db" + _, conn_string = read_config(config_file) + assert conn_string["media_api"] == "postgresql://global/db" def test_read_config_missing_conn_string(tmp_path: Path) -> None: @@ -401,9 +405,9 @@ def test_file_has_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple[An _, cursor_mock = mock_db_conn cursor_mock.fetchone.return_value = (3,) - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) assert file.has_thumbnail() == 3 # noqa PLR2004 - cursor_mock.execute.assert_called_with( + media_repo._execute.assert_called_with( "SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", ("mxid123",), ) @@ -414,5 +418,5 @@ def test_file_has_no_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple _, cursor_mock = mock_db_conn cursor_mock.fetchone.return_value = None - file = File(media_repo, "mxid123", 1600000000, "abc123") + file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) assert file.has_thumbnail() == 0 -- 2.39.5 From 5b248cd46b8e5766784e8f60789aad0cbf077700 Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 11:49:36 +0100 Subject: [PATCH 4/8] style: lint code --- cleanmedia.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index 136d4b1..0e66174 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -206,7 +206,8 @@ class MediaRepository: def get_single_media(self, mxid: MediaID) -> File | None: """Retrieve a single media file by ID.""" cur = self._execute( - "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;", + """SELECT media_id, creation_ts, base64hash, file_size_bytes + from mediaapi_media_repository WHERE media_id = %s;""", (mxid,), ) row = cur.fetchone() @@ -222,7 +223,8 @@ class MediaRepository: List of File objects """ cur = self._execute( - "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;", + """SELECT media_id, creation_ts, base64hash, file_size_bytes + FROM mediaapi_media_repository WHERE user_id = %s;""", (user_id,), ) return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()] @@ -267,7 +269,7 @@ class MediaRepository: ) if (row := cur.fetchone()) and (count := row[0]): logging.error( - "You have %d thumbnails in your db that do not refer to media. " "This needs fixing (we don't do that)!", + "You have %d thumbnails in your db that do not refer to media. This needs fixing (we don't do that)!", count, ) -- 2.39.5 From fdf878d2e39db131da9f29c36452266c8ac2d908 Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 12:11:49 +0100 Subject: [PATCH 5/8] fix: ignore some type checks in test file --- tests/test_cleanmedia.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_cleanmedia.py b/tests/test_cleanmedia.py index f31b449..0089033 100644 --- a/tests/test_cleanmedia.py +++ b/tests/test_cleanmedia.py @@ -129,8 +129,8 @@ def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) assert file.delete() is True assert not file_path.exists() - media_repo._execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",)) - media_repo._execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",)) + media_repo._execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",)) # type: ignore + media_repo._execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",)) # type: ignore def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: @@ -143,7 +143,7 @@ def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, assert file.media_id == "mxid123" assert file.base64hash == "abc123" - media_repo._execute.assert_called_with( + media_repo._execute.assert_called_with( # type: ignore "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;", ("mxid123",), ) @@ -278,7 +278,7 @@ def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[A assert files[0].media_id == "media1" assert files[1].media_id == "media2" - media_repo._execute.assert_called_with( + media_repo._execute.assert_called_with( # type: ignore "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;", ("@user:domain.com",), ) @@ -407,7 +407,7 @@ def test_file_has_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple[An file = File(media_repo, "mxid123", 1600000000, "abc123", 1000) assert file.has_thumbnail() == 3 # noqa PLR2004 - media_repo._execute.assert_called_with( + media_repo._execute.assert_called_with( # type: ignore "SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", ("mxid123",), ) -- 2.39.5 From 78fb13fb268b64e588d24eee03a7ef39cb0aaeec Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 12:20:25 +0100 Subject: [PATCH 6/8] fix: change sql format --- cleanmedia.py | 6 +++--- tests/test_cleanmedia.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index 0e66174..fc9c5c7 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -207,7 +207,7 @@ class MediaRepository: """Retrieve a single media file by ID.""" cur = self._execute( """SELECT media_id, creation_ts, base64hash, file_size_bytes - from mediaapi_media_repository WHERE media_id = %s;""", +from mediaapi_media_repository WHERE media_id = %s;""", (mxid,), ) row = cur.fetchone() @@ -224,7 +224,7 @@ class MediaRepository: """ cur = self._execute( """SELECT media_id, creation_ts, base64hash, file_size_bytes - FROM mediaapi_media_repository WHERE user_id = %s;""", +FROM mediaapi_media_repository WHERE user_id = %s;""", (user_id,), ) return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()] @@ -239,7 +239,7 @@ class MediaRepository: List of File objects """ query = """SELECT media_id, creation_ts, base64hash, file_size_bytes - FROM mediaapi_media_repository""" +FROM mediaapi_media_repository""" if not local: query += " WHERE user_id = ''" cur = self._execute(query) diff --git a/tests/test_cleanmedia.py b/tests/test_cleanmedia.py index 0089033..e048e5e 100644 --- a/tests/test_cleanmedia.py +++ b/tests/test_cleanmedia.py @@ -144,7 +144,7 @@ def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, assert file.base64hash == "abc123" media_repo._execute.assert_called_with( # type: ignore - "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;", + "SELECT media_id, creation_ts, base64hash, file_size_bytes\nfrom mediaapi_media_repository WHERE media_id = %s;", ("mxid123",), ) @@ -279,7 +279,7 @@ def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[A assert files[1].media_id == "media2" media_repo._execute.assert_called_with( # type: ignore - "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;", + "SELECT media_id, creation_ts, base64hash, file_size_bytes\nFROM mediaapi_media_repository WHERE user_id = %s;", ("@user:domain.com",), ) -- 2.39.5 From 016d7db63e32606017b971e233d75e16588104c5 Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Thu, 27 Mar 2025 12:47:55 +0100 Subject: [PATCH 7/8] docs: update doc strings --- cleanmedia.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index fc9c5c7..0fa4b42 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -152,7 +152,7 @@ class MediaRepository: Args: media_path: Path to media storage directory - connection_string: PostgreSQL connection string + connection_strings: Dictionary of db connection strings Raises: ValueError: If media_path doesn't exist or connection string is invalid @@ -191,6 +191,9 @@ class MediaRepository: def connect_db(self, connection_string: str) -> DBConnection: """Establish database connection. + Args: + connection_string: db connection string + Returns: PostgreSQL connection object @@ -317,7 +320,7 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, dict[str, str]]: conf_file: Path to Dendrite YAML config file Returns: - Tuple of (media_path, connection_string) + Tuple of (media_path, connection_strings) Raises: SystemExit: If config file is invalid or missing required fields -- 2.39.5 From 3efc1a9dd1ad22b18eaf57f48277668ad9cad7cd Mon Sep 17 00:00:00 2001 From: Christian Groschupp Date: Fri, 28 Mar 2025 22:48:28 +0100 Subject: [PATCH 8/8] style: sort imports --- cleanmedia.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cleanmedia.py b/cleanmedia.py index 0fa4b42..86d43e6 100755 --- a/cleanmedia.py +++ b/cleanmedia.py @@ -21,12 +21,12 @@ along with this program. If not, see . import argparse import logging +import sqlite3 import sys from datetime import datetime, timedelta from functools import cached_property from pathlib import Path -from typing import List, Tuple, TypeAlias, Union, Sequence, Mapping, Any, Callable -import sqlite3 +from typing import Any, Callable, List, Mapping, Sequence, Tuple, TypeAlias, Union try: import psycopg2 -- 2.39.5