feat: update tests

This commit is contained in:
Christian Groschupp 2025-03-27 11:27:16 +01:00
parent 9751bf28b9
commit 19476c593f
2 changed files with 66 additions and 63 deletions

View File

@ -25,7 +25,7 @@ import sys
from datetime import datetime, timedelta
from functools import cached_property
from pathlib import Path
from typing import List, Tuple, TypeAlias, Union
from typing import List, Tuple, TypeAlias, Union, Sequence, Mapping, Any, Callable
import sqlite3
try:
@ -36,7 +36,9 @@ except ImportError as err:
raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err
# Type aliases
DBConnection: TypeAlias = psycopg2.extensions.connection
DBConnection: TypeAlias = Union[sqlite3.Connection, psycopg2.extensions.connection]
DBCursor = Union[sqlite3.Cursor, psycopg2.extensions.cursor]
Params = Union[Sequence[Any], Mapping[str, Any]]
MediaID: TypeAlias = str
Timestamp: TypeAlias = int
Base64Hash: TypeAlias = str
@ -113,12 +115,11 @@ class File:
Returns:
True if database entries were deleted successfully
"""
with self.repo.conn.cursor() as cur:
cur.execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,))
cur = self.repo._execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,))
num_thumbnails = cur.rowcount
cur.execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,))
self.repo._execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,))
num_media = cur.rowcount
self.repo.conn.commit()
self.repo.conn["media_api"].commit()
logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}")
return True
@ -138,8 +139,7 @@ class File:
Returns:
Number of thumbnails
"""
with self.repo.conn.cursor() as cur:
cur.execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,))
cur = self.repo._execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,))
row = cur.fetchone()
return int(row[0]) if row else 0
@ -160,12 +160,12 @@ class MediaRepository:
self._validate_media_path(media_path)
self.media_path = media_path
self._avatar_media_ids: List[MediaID] = []
self.conn = {}
self.conn: dict[str, DBConnection] = {}
for db_type, conn_string in connection_strings.items():
self.conn[db_type] = self.connect_db(conn_string)
def _execute(self, query: str, params=(), db_type: str = "media_api"):
def _execute(self, query: str, params: Params = (), db_type: str = "media_api") -> DBCursor:
paramstyle = getattr(self.conn, "paramstyle", "format")
query = self._adjust_paramstyle(query, paramstyle)
cur = self.conn[db_type].cursor()
@ -176,7 +176,7 @@ class MediaRepository:
cur.close()
raise e
def _adjust_paramstyle(self, query: str, paramstyle: str):
def _adjust_paramstyle(self, query: str, paramstyle: str) -> str:
if paramstyle == "qmark":
return query.replace("%s", "?")
return query
@ -206,11 +206,11 @@ class MediaRepository:
def get_single_media(self, mxid: MediaID) -> File | None:
"""Retrieve a single media file by ID."""
cur = self._execute(
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
"SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;",
(mxid,),
)
row = cur.fetchone()
return File(self, row[0], row[1] // 1000, row[2]) if row else None
return File(self, row[0], row[1] // 1000, row[2], row[3]) if row else None
def get_local_user_media(self, user_id: UserID) -> List[File]:
"""Get all media files created by a local user.
@ -222,10 +222,10 @@ class MediaRepository:
List of File objects
"""
cur = self._execute(
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
"SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;",
(user_id,),
)
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()]
def get_all_media(self, local: bool = False) -> List[File]:
"""Get all media files or only remote ones.
@ -267,8 +267,7 @@ class MediaRepository:
)
if (row := cur.fetchone()) and (count := row[0]):
logging.error(
"You have %d thumbnails in your db that do not refer to media. "
"This needs fixing (we don't do that)!",
"You have %d thumbnails in your db that do not refer to media. " "This needs fixing (we don't do that)!",
count,
)
@ -309,7 +308,7 @@ class MediaRepository:
return len(files_to_delete)
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]:
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, dict[str, str]]:
"""Read database credentials and media path from config.
Args:
@ -328,10 +327,9 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]
logging.error("Config file %s not found. Use --help for usage.", conf_file)
sys.exit(1)
global_conn_string = config.get("global", {}).get("database", {}).get("connection_string")
db_types = {
db_types: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = {
"media_api": lambda x: x.get("media_api", {}).get("database", {}),
"user_api": lambda x: x.get("user_api", {}).get("account_database", {}),
}
@ -351,6 +349,7 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]
return Path(base_path), conns
def sizeof_fmt(num: Union[int, float], suffix: str = "B") -> str:
"""
Convert a number of bytes (or other units) into a human-readable format using binary prefixes.

View File

@ -42,17 +42,18 @@ def media_repo(tmp_path: Any, mock_db_conn: Tuple[Any, Any], mocker: MockerFixtu
Returns:
Configured MediaRepository instance
"""
conn_mock, _ = mock_db_conn
conn_mock, cursor = mock_db_conn
media_path = tmp_path / "media"
media_path.mkdir()
mocker.patch("cleanmedia.MediaRepository.connect_db", return_value=conn_mock)
return MediaRepository(media_path, "postgresql://fake")
mocker.patch("cleanmedia.MediaRepository._execute", return_value=cursor)
return MediaRepository(media_path, {"media_api": "postgresql://fake", "user_api": "postgresql://fake"})
def test_file_init(mocker: MockerFixture) -> None:
"""Test File class initialization."""
repo = mocker.Mock()
file = File(repo, "mxid123", 1600000000, "base64hash123")
file = File(repo, "mxid123", 1600000000, "base64hash123", 1000)
assert file.media_id == "mxid123"
assert file.create_date == datetime.fromtimestamp(1600000000)
assert file.base64hash == "base64hash123"
@ -61,26 +62,26 @@ def test_file_init(mocker: MockerFixture) -> None:
def test_file_fullpath(media_repo: MediaRepository) -> None:
"""Test File.fullpath property returns correct path."""
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
expected_path = media_repo.media_path / "a" / "b" / "c123"
assert file.fullpath == expected_path
def test_file_exists_no_path(media_repo: MediaRepository) -> None:
"""Test File.exists returns False when fullpath is None."""
file = File(media_repo, "mxid123", 1600000000, "") # Empty hash ensures fullpath is None
file = File(media_repo, "mxid123", 1600000000, "", 1000) # Empty hash ensures fullpath is None
assert file.exists() is False
def test_file_delete_no_path(media_repo: MediaRepository) -> None:
"""Test File._delete_files when file path is None."""
file = File(media_repo, "mxid123", 1600000000, "")
file = File(media_repo, "mxid123", 1600000000, "", 1000)
assert file._delete_files() is False
def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, caplog: Any) -> None:
"""Test File._delete_files when OSError occurs."""
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
# Create directory structure
file_path = media_repo.media_path / "a" / "b" / "c123"
@ -96,13 +97,13 @@ def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture,
def test_file_fullpath_none_if_no_hash(media_repo: MediaRepository) -> None:
"""Test File.fullpath returns None when hash is empty."""
file = File(media_repo, "mxid123", 1600000000, "")
file = File(media_repo, "mxid123", 1600000000, "", 1000)
assert file.fullpath is None
def test_file_exists(media_repo: MediaRepository) -> None:
"""Test File.exists returns True when file exists."""
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
file_path = media_repo.media_path / "a" / "b" / "c123"
file_path.mkdir(parents=True)
(file_path / "file").touch()
@ -111,14 +112,14 @@ def test_file_exists(media_repo: MediaRepository) -> None:
def test_file_not_exists(media_repo: MediaRepository) -> None:
"""Test File.exists returns False when file doesn't exist."""
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
assert file.exists() is False
def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
"""Test File.delete removes files and database entries."""
_, cursor_mock = mock_db_conn
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
file_path = media_repo.media_path / "a" / "b" / "c123"
file_path.mkdir(parents=True)
@ -128,22 +129,22 @@ def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any])
assert file.delete() is True
assert not file_path.exists()
cursor_mock.execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",))
cursor_mock.execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",))
media_repo._execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",))
media_repo._execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",))
def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
"""Test MediaRepository.get_single_media returns correct File object."""
_, cursor_mock = mock_db_conn
cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123")
cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123", 1000)
file = media_repo.get_single_media("mxid123")
assert file is not None
assert file.media_id == "mxid123"
assert file.base64hash == "abc123"
cursor_mock.execute.assert_called_with(
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
media_repo._execute.assert_called_with(
"SELECT media_id, creation_ts, base64hash, file_size_bytes from mediaapi_media_repository WHERE media_id = %s;",
("mxid123",),
)
@ -165,8 +166,8 @@ def test_clean_media_files(media_repo: MediaRepository, mock_db_conn: Tuple[Any,
new_date = int((datetime.now() - timedelta(days=1)).timestamp())
cursor_mock.fetchall.return_value = [
("old_file", old_date * 1000, "abc123"),
("new_file", new_date * 1000, "def456"),
("old_file", old_date * 1000, "abc123", 1000),
("new_file", new_date * 1000, "def456", 1000),
]
media_repo._avatar_media_ids = []
@ -181,7 +182,7 @@ def test_clean_media_files_dryrun(media_repo: MediaRepository, mock_db_conn: Tup
old_date = int((datetime.now() - timedelta(days=31)).timestamp())
cursor_mock.fetchall.return_value = [
("old_file", old_date * 1000, "abc123"),
("old_file", old_date * 1000, "abc123", 1000),
]
media_repo._avatar_media_ids = []
@ -252,24 +253,24 @@ def test_validate_media_path_relative(caplog: Any) -> None:
def test_connect_db_success(mocker: MockerFixture) -> None:
"""Test successful database connection."""
mock_connect = mocker.patch("psycopg2.connect")
repo = MediaRepository(Path("/tmp"), "postgresql://fake")
repo.connect_db()
repo = MediaRepository(Path("/tmp"), {"media_api": "postgresql://fake"})
repo.connect_db("postgresql://fake")
mock_connect.assert_called_with("postgresql://fake")
def test_connect_db_invalid_string(tmp_path: Path) -> None:
"""Test connect_db with invalid connection string."""
with pytest.raises(ValueError, match="Invalid PostgreSQL connection string"):
repo = MediaRepository(tmp_path, "invalid")
repo.connect_db()
repo = MediaRepository(tmp_path, {"media_api": "invalid"})
repo.connect_db("invalid")
def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
"""Test get_local_user_media returns correct files."""
_, cursor_mock = mock_db_conn
cursor_mock.fetchall.return_value = [
("media1", 1600000000000, "hash1"),
("media2", 1600000000000, "hash2"),
("media1", 1600000000000, "hash1", 1000),
("media2", 1600000000000, "hash2", 1000),
]
files = media_repo.get_local_user_media("@user:domain.com")
@ -277,8 +278,8 @@ def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[A
assert files[0].media_id == "media1"
assert files[1].media_id == "media2"
cursor_mock.execute.assert_called_with(
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
media_repo._execute.assert_called_with(
"SELECT media_id, creation_ts, base64hash, file_size_bytes FROM mediaapi_media_repository WHERE user_id = %s;",
("@user:domain.com",),
)
@ -333,15 +334,18 @@ def test_read_config_valid(tmp_path: Path) -> None:
"""Test read_config with valid config."""
config_file = tmp_path / "config.yaml"
config_file.write_text("""
global:
database:
connection_string: postgresql://global/db
media_api:
base_path: /media/path
database:
connection_string: postgresql://user:pass@localhost/db
""")
path, conn_string = read_config(config_file)
path, conn_strings = read_config(config_file)
assert path == Path("/media/path")
assert conn_string == "postgresql://user:pass@localhost/db"
assert conn_strings["media_api"] == "postgresql://user:pass@localhost/db"
def test_read_config_global_database(tmp_path: Path) -> None:
@ -355,8 +359,8 @@ global:
connection_string: postgresql://global/db
""")
path, conn_string = read_config(config_file)
assert conn_string == "postgresql://global/db"
_, conn_string = read_config(config_file)
assert conn_string["media_api"] == "postgresql://global/db"
def test_read_config_missing_conn_string(tmp_path: Path) -> None:
@ -401,9 +405,9 @@ def test_file_has_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple[An
_, cursor_mock = mock_db_conn
cursor_mock.fetchone.return_value = (3,)
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
assert file.has_thumbnail() == 3 # noqa PLR2004
cursor_mock.execute.assert_called_with(
media_repo._execute.assert_called_with(
"SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;",
("mxid123",),
)
@ -414,5 +418,5 @@ def test_file_has_no_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple
_, cursor_mock = mock_db_conn
cursor_mock.fetchone.return_value = None
file = File(media_repo, "mxid123", 1600000000, "abc123")
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
assert file.has_thumbnail() == 0