feat: update tests

This commit is contained in:
Christian Groschupp 2025-03-27 11:27:16 +01:00
parent 9751bf28b9
commit 19476c593f
2 changed files with 66 additions and 63 deletions

View File

@ -25,7 +25,7 @@ import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import List, Tuple, TypeAlias, Union from typing import List, Tuple, TypeAlias, Union, Sequence, Mapping, Any, Callable
import sqlite3 import sqlite3
try: try:
@ -36,7 +36,9 @@ except ImportError as err:
raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err
# Type aliases # Type aliases
DBConnection: TypeAlias = psycopg2.extensions.connection DBConnection: TypeAlias = Union[sqlite3.Connection, psycopg2.extensions.connection]
DBCursor = Union[sqlite3.Cursor, psycopg2.extensions.cursor]
Params = Union[Sequence[Any], Mapping[str, Any]]
MediaID: TypeAlias = str MediaID: TypeAlias = str
Timestamp: TypeAlias = int Timestamp: TypeAlias = int
Base64Hash: TypeAlias = str Base64Hash: TypeAlias = str
@ -113,12 +115,11 @@ class File:
Returns: Returns:
True if database entries were deleted successfully True if database entries were deleted successfully
""" """
with self.repo.conn.cursor() as cur: cur = self.repo._execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,))
cur.execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,)) num_thumbnails = cur.rowcount
num_thumbnails = cur.rowcount self.repo._execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,))
cur.execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,)) num_media = cur.rowcount
num_media = cur.rowcount self.repo.conn["media_api"].commit()
self.repo.conn.commit()
logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}") logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}")
return True return True
@ -138,10 +139,9 @@ class File:
Returns: Returns:
Number of thumbnails Number of thumbnails
""" """
with self.repo.conn.cursor() as cur: cur = self.repo._execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,))
cur.execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,)) row = cur.fetchone()
row = cur.fetchone() return int(row[0]) if row else 0
return int(row[0]) if row else 0
class MediaRepository: class MediaRepository:
@ -160,12 +160,12 @@ class MediaRepository:
self._validate_media_path(media_path) self._validate_media_path(media_path)
self.media_path = media_path self.media_path = media_path
self._avatar_media_ids: List[MediaID] = [] self._avatar_media_ids: List[MediaID] = []
self.conn = {} self.conn: dict[str, DBConnection] = {}
for db_type, conn_string in connection_strings.items(): for db_type, conn_string in connection_strings.items():
self.conn[db_type] = self.connect_db(conn_string) self.conn[db_type] = self.connect_db(conn_string)
def _execute(self, query: str, params=(), db_type: str = "media_api"): def _execute(self, query: str, params: Params = (), db_type: str = "media_api") -> DBCursor:
paramstyle = getattr(self.conn, "paramstyle", "format") paramstyle = getattr(self.conn, "paramstyle", "format")
query = self._adjust_paramstyle(query, paramstyle) query = self._adjust_paramstyle(query, paramstyle)
cur = self.conn[db_type].cursor() cur = self.conn[db_type].cursor()
@ -176,7 +176,7 @@ class MediaRepository:
cur.close() cur.close()
raise e raise e
def _adjust_paramstyle(self, query: str, paramstyle: str): def _adjust_paramstyle(self, query: str, paramstyle: str) -> str:
if paramstyle == "qmark": if paramstyle == "qmark":
return query.replace("%s", "?") return query.replace("%s", "?")
return query return query
@ -206,11 +206,11 @@ class MediaRepository:
def get_single_media(self, mxid: MediaID) -> File | None: def get_single_media(self, mxid: MediaID) -> File | None:
"""Retrieve a single media file by ID.""" """Retrieve a single media file by ID."""
cur = self._execute( cur = self._execute(
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;",
(mxid,), (mxid,),
) )
row = cur.fetchone() row = cur.fetchone()
return File(self, row[0], row[1] // 1000, row[2]) if row else None return File(self, row[0], row[1] // 1000, row[2], row[3]) if row else None
def get_local_user_media(self, user_id: UserID) -> List[File]: def get_local_user_media(self, user_id: UserID) -> List[File]:
"""Get all media files created by a local user. """Get all media files created by a local user.
@ -222,10 +222,10 @@ class MediaRepository:
List of File objects List of File objects
""" """
cur = self._execute( cur = self._execute(
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;",
(user_id,), (user_id,),
) )
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()]
def get_all_media(self, local: bool = False) -> List[File]: def get_all_media(self, local: bool = False) -> List[File]:
"""Get all media files or only remote ones. """Get all media files or only remote ones.
@ -267,8 +267,7 @@ class MediaRepository:
) )
if (row := cur.fetchone()) and (count := row[0]): if (row := cur.fetchone()) and (count := row[0]):
logging.error( logging.error(
"You have %d thumbnails in your db that do not refer to media. " "You have %d thumbnails in your db that do not refer to media. " "This needs fixing (we don't do that)!",
"This needs fixing (we don't do that)!",
count, count,
) )
@ -309,7 +308,7 @@ class MediaRepository:
return len(files_to_delete) return len(files_to_delete)
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]: def read_config(conf_file: Union[str, Path]) -> Tuple[Path, dict[str, str]]:
"""Read database credentials and media path from config. """Read database credentials and media path from config.
Args: Args:
@ -328,12 +327,11 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]
logging.error("Config file %s not found. Use --help for usage.", conf_file) logging.error("Config file %s not found. Use --help for usage.", conf_file)
sys.exit(1) sys.exit(1)
global_conn_string = config.get("global", {}).get("database", {}).get("connection_string")
global_conn_string = config.get("global",{}).get("database", {}).get("connection_string") db_types: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = {
"media_api": lambda x: x.get("media_api", {}).get("database", {}),
db_types = { "user_api": lambda x: x.get("user_api", {}).get("account_database", {}),
"media_api": lambda x: x.get("media_api",{}).get("database", {}),
"user_api": lambda x: x.get("user_api",{}).get("account_database", {}),
} }
conns = {} conns = {}
@ -351,6 +349,7 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]
return Path(base_path), conns return Path(base_path), conns
def sizeof_fmt(num: Union[int, float], suffix: str = "B") -> str: def sizeof_fmt(num: Union[int, float], suffix: str = "B") -> str:
""" """
Convert a number of bytes (or other units) into a human-readable format using binary prefixes. Convert a number of bytes (or other units) into a human-readable format using binary prefixes.

View File

@ -42,17 +42,18 @@ def media_repo(tmp_path: Any, mock_db_conn: Tuple[Any, Any], mocker: MockerFixtu
Returns: Returns:
Configured MediaRepository instance Configured MediaRepository instance
""" """
conn_mock, _ = mock_db_conn conn_mock, cursor = mock_db_conn
media_path = tmp_path / "media" media_path = tmp_path / "media"
media_path.mkdir() media_path.mkdir()
mocker.patch("cleanmedia.MediaRepository.connect_db", return_value=conn_mock) mocker.patch("cleanmedia.MediaRepository.connect_db", return_value=conn_mock)
return MediaRepository(media_path, "postgresql://fake") mocker.patch("cleanmedia.MediaRepository._execute", return_value=cursor)
return MediaRepository(media_path, {"media_api": "postgresql://fake", "user_api": "postgresql://fake"})
def test_file_init(mocker: MockerFixture) -> None: def test_file_init(mocker: MockerFixture) -> None:
"""Test File class initialization.""" """Test File class initialization."""
repo = mocker.Mock() repo = mocker.Mock()
file = File(repo, "mxid123", 1600000000, "base64hash123") file = File(repo, "mxid123", 1600000000, "base64hash123", 1000)
assert file.media_id == "mxid123" assert file.media_id == "mxid123"
assert file.create_date == datetime.fromtimestamp(1600000000) assert file.create_date == datetime.fromtimestamp(1600000000)
assert file.base64hash == "base64hash123" assert file.base64hash == "base64hash123"
@ -61,26 +62,26 @@ def test_file_init(mocker: MockerFixture) -> None:
def test_file_fullpath(media_repo: MediaRepository) -> None: def test_file_fullpath(media_repo: MediaRepository) -> None:
"""Test File.fullpath property returns correct path.""" """Test File.fullpath property returns correct path."""
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
expected_path = media_repo.media_path / "a" / "b" / "c123" expected_path = media_repo.media_path / "a" / "b" / "c123"
assert file.fullpath == expected_path assert file.fullpath == expected_path
def test_file_exists_no_path(media_repo: MediaRepository) -> None: def test_file_exists_no_path(media_repo: MediaRepository) -> None:
"""Test File.exists returns False when fullpath is None.""" """Test File.exists returns False when fullpath is None."""
file = File(media_repo, "mxid123", 1600000000, "") # Empty hash ensures fullpath is None file = File(media_repo, "mxid123", 1600000000, "", 1000) # Empty hash ensures fullpath is None
assert file.exists() is False assert file.exists() is False
def test_file_delete_no_path(media_repo: MediaRepository) -> None: def test_file_delete_no_path(media_repo: MediaRepository) -> None:
"""Test File._delete_files when file path is None.""" """Test File._delete_files when file path is None."""
file = File(media_repo, "mxid123", 1600000000, "") file = File(media_repo, "mxid123", 1600000000, "", 1000)
assert file._delete_files() is False assert file._delete_files() is False
def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, caplog: Any) -> None: def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, caplog: Any) -> None:
"""Test File._delete_files when OSError occurs.""" """Test File._delete_files when OSError occurs."""
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
# Create directory structure # Create directory structure
file_path = media_repo.media_path / "a" / "b" / "c123" file_path = media_repo.media_path / "a" / "b" / "c123"
@ -96,13 +97,13 @@ def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture,
def test_file_fullpath_none_if_no_hash(media_repo: MediaRepository) -> None: def test_file_fullpath_none_if_no_hash(media_repo: MediaRepository) -> None:
"""Test File.fullpath returns None when hash is empty.""" """Test File.fullpath returns None when hash is empty."""
file = File(media_repo, "mxid123", 1600000000, "") file = File(media_repo, "mxid123", 1600000000, "", 1000)
assert file.fullpath is None assert file.fullpath is None
def test_file_exists(media_repo: MediaRepository) -> None: def test_file_exists(media_repo: MediaRepository) -> None:
"""Test File.exists returns True when file exists.""" """Test File.exists returns True when file exists."""
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
file_path = media_repo.media_path / "a" / "b" / "c123" file_path = media_repo.media_path / "a" / "b" / "c123"
file_path.mkdir(parents=True) file_path.mkdir(parents=True)
(file_path / "file").touch() (file_path / "file").touch()
@ -111,14 +112,14 @@ def test_file_exists(media_repo: MediaRepository) -> None:
def test_file_not_exists(media_repo: MediaRepository) -> None: def test_file_not_exists(media_repo: MediaRepository) -> None:
"""Test File.exists returns False when file doesn't exist.""" """Test File.exists returns False when file doesn't exist."""
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
assert file.exists() is False assert file.exists() is False
def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
"""Test File.delete removes files and database entries.""" """Test File.delete removes files and database entries."""
_, cursor_mock = mock_db_conn _, cursor_mock = mock_db_conn
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
file_path = media_repo.media_path / "a" / "b" / "c123" file_path = media_repo.media_path / "a" / "b" / "c123"
file_path.mkdir(parents=True) file_path.mkdir(parents=True)
@ -128,22 +129,22 @@ def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any])
assert file.delete() is True assert file.delete() is True
assert not file_path.exists() assert not file_path.exists()
cursor_mock.execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",)) media_repo._execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",))
cursor_mock.execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",)) media_repo._execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",))
def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
"""Test MediaRepository.get_single_media returns correct File object.""" """Test MediaRepository.get_single_media returns correct File object."""
_, cursor_mock = mock_db_conn _, cursor_mock = mock_db_conn
cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123") cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123", 1000)
file = media_repo.get_single_media("mxid123") file = media_repo.get_single_media("mxid123")
assert file is not None assert file is not None
assert file.media_id == "mxid123" assert file.media_id == "mxid123"
assert file.base64hash == "abc123" assert file.base64hash == "abc123"
cursor_mock.execute.assert_called_with( media_repo._execute.assert_called_with(
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", "SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;",
("mxid123",), ("mxid123",),
) )
@ -165,8 +166,8 @@ def test_clean_media_files(media_repo: MediaRepository, mock_db_conn: Tuple[Any,
new_date = int((datetime.now() - timedelta(days=1)).timestamp()) new_date = int((datetime.now() - timedelta(days=1)).timestamp())
cursor_mock.fetchall.return_value = [ cursor_mock.fetchall.return_value = [
("old_file", old_date * 1000, "abc123"), ("old_file", old_date * 1000, "abc123", 1000),
("new_file", new_date * 1000, "def456"), ("new_file", new_date * 1000, "def456", 1000),
] ]
media_repo._avatar_media_ids = [] media_repo._avatar_media_ids = []
@ -181,7 +182,7 @@ def test_clean_media_files_dryrun(media_repo: MediaRepository, mock_db_conn: Tup
old_date = int((datetime.now() - timedelta(days=31)).timestamp()) old_date = int((datetime.now() - timedelta(days=31)).timestamp())
cursor_mock.fetchall.return_value = [ cursor_mock.fetchall.return_value = [
("old_file", old_date * 1000, "abc123"), ("old_file", old_date * 1000, "abc123", 1000),
] ]
media_repo._avatar_media_ids = [] media_repo._avatar_media_ids = []
@ -252,24 +253,24 @@ def test_validate_media_path_relative(caplog: Any) -> None:
def test_connect_db_success(mocker: MockerFixture) -> None: def test_connect_db_success(mocker: MockerFixture) -> None:
"""Test successful database connection.""" """Test successful database connection."""
mock_connect = mocker.patch("psycopg2.connect") mock_connect = mocker.patch("psycopg2.connect")
repo = MediaRepository(Path("/tmp"), "postgresql://fake") repo = MediaRepository(Path("/tmp"), {"media_api": "postgresql://fake"})
repo.connect_db() repo.connect_db("postgresql://fake")
mock_connect.assert_called_with("postgresql://fake") mock_connect.assert_called_with("postgresql://fake")
def test_connect_db_invalid_string(tmp_path: Path) -> None: def test_connect_db_invalid_string(tmp_path: Path) -> None:
"""Test connect_db with invalid connection string.""" """Test connect_db with invalid connection string."""
with pytest.raises(ValueError, match="Invalid PostgreSQL connection string"): with pytest.raises(ValueError, match="Invalid PostgreSQL connection string"):
repo = MediaRepository(tmp_path, "invalid") repo = MediaRepository(tmp_path, {"media_api": "invalid"})
repo.connect_db() repo.connect_db("invalid")
def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None: def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
"""Test get_local_user_media returns correct files.""" """Test get_local_user_media returns correct files."""
_, cursor_mock = mock_db_conn _, cursor_mock = mock_db_conn
cursor_mock.fetchall.return_value = [ cursor_mock.fetchall.return_value = [
("media1", 1600000000000, "hash1"), ("media1", 1600000000000, "hash1", 1000),
("media2", 1600000000000, "hash2"), ("media2", 1600000000000, "hash2", 1000),
] ]
files = media_repo.get_local_user_media("@user:domain.com") files = media_repo.get_local_user_media("@user:domain.com")
@ -277,8 +278,8 @@ def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[A
assert files[0].media_id == "media1" assert files[0].media_id == "media1"
assert files[1].media_id == "media2" assert files[1].media_id == "media2"
cursor_mock.execute.assert_called_with( media_repo._execute.assert_called_with(
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", "SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;",
("@user:domain.com",), ("@user:domain.com",),
) )
@ -333,15 +334,18 @@ def test_read_config_valid(tmp_path: Path) -> None:
"""Test read_config with valid config.""" """Test read_config with valid config."""
config_file = tmp_path / "config.yaml" config_file = tmp_path / "config.yaml"
config_file.write_text(""" config_file.write_text("""
global:
database:
connection_string: postgresql://global/db
media_api: media_api:
base_path: /media/path base_path: /media/path
database: database:
connection_string: postgresql://user:pass@localhost/db connection_string: postgresql://user:pass@localhost/db
""") """)
path, conn_string = read_config(config_file) path, conn_strings = read_config(config_file)
assert path == Path("/media/path") assert path == Path("/media/path")
assert conn_string == "postgresql://user:pass@localhost/db" assert conn_strings["media_api"] == "postgresql://user:pass@localhost/db"
def test_read_config_global_database(tmp_path: Path) -> None: def test_read_config_global_database(tmp_path: Path) -> None:
@ -355,8 +359,8 @@ global:
connection_string: postgresql://global/db connection_string: postgresql://global/db
""") """)
path, conn_string = read_config(config_file) _, conn_string = read_config(config_file)
assert conn_string == "postgresql://global/db" assert conn_string["media_api"] == "postgresql://global/db"
def test_read_config_missing_conn_string(tmp_path: Path) -> None: def test_read_config_missing_conn_string(tmp_path: Path) -> None:
@ -401,9 +405,9 @@ def test_file_has_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple[An
_, cursor_mock = mock_db_conn _, cursor_mock = mock_db_conn
cursor_mock.fetchone.return_value = (3,) cursor_mock.fetchone.return_value = (3,)
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
assert file.has_thumbnail() == 3 # noqa PLR2004 assert file.has_thumbnail() == 3 # noqa PLR2004
cursor_mock.execute.assert_called_with( media_repo._execute.assert_called_with(
"SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", "SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;",
("mxid123",), ("mxid123",),
) )
@ -414,5 +418,5 @@ def test_file_has_no_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple
_, cursor_mock = mock_db_conn _, cursor_mock = mock_db_conn
cursor_mock.fetchone.return_value = None cursor_mock.fetchone.return_value = None
file = File(media_repo, "mxid123", 1600000000, "abc123") file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
assert file.has_thumbnail() == 0 assert file.has_thumbnail() == 0