feat: add support for sqlite3 and multiple databases #2
205
cleanmedia.py
205
cleanmedia.py
@ -21,11 +21,12 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, TypeAlias, Union
|
from typing import Any, Callable, List, Mapping, Sequence, Tuple, TypeAlias, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import psycopg2
|
import psycopg2
|
||||||
@ -35,7 +36,9 @@ except ImportError as err:
|
|||||||
raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err
|
raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err
|
||||||
|
|
||||||
# Type aliases
|
# Type aliases
|
||||||
DBConnection: TypeAlias = psycopg2.extensions.connection
|
DBConnection: TypeAlias = Union[sqlite3.Connection, psycopg2.extensions.connection]
|
||||||
|
DBCursor = Union[sqlite3.Cursor, psycopg2.extensions.cursor]
|
||||||
|
Params = Union[Sequence[Any], Mapping[str, Any]]
|
||||||
MediaID: TypeAlias = str
|
MediaID: TypeAlias = str
|
||||||
Timestamp: TypeAlias = int
|
Timestamp: TypeAlias = int
|
||||||
Base64Hash: TypeAlias = str
|
Base64Hash: TypeAlias = str
|
||||||
@ -51,12 +54,14 @@ class File:
|
|||||||
media_id: MediaID,
|
media_id: MediaID,
|
||||||
creation_ts: Timestamp,
|
creation_ts: Timestamp,
|
||||||
base64hash: Base64Hash,
|
base64hash: Base64Hash,
|
||||||
|
file_size: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a File object."""
|
"""Initialize a File object."""
|
||||||
self.repo = media_repo
|
self.repo = media_repo
|
||||||
self.media_id = media_id
|
self.media_id = media_id
|
||||||
self.create_date = datetime.fromtimestamp(creation_ts)
|
self.create_date = datetime.fromtimestamp(creation_ts)
|
||||||
self.base64hash = base64hash
|
self.base64hash = base64hash
|
||||||
|
self.file_size = file_size
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def fullpath(self) -> Path | None:
|
def fullpath(self) -> Path | None:
|
||||||
@ -110,12 +115,11 @@ class File:
|
|||||||
Returns:
|
Returns:
|
||||||
True if database entries were deleted successfully
|
True if database entries were deleted successfully
|
||||||
"""
|
"""
|
||||||
with self.repo.conn.cursor() as cur:
|
cur = self.repo._execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,))
|
||||||
cur.execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,))
|
num_thumbnails = cur.rowcount
|
||||||
num_thumbnails = cur.rowcount
|
self.repo._execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,))
|
||||||
cur.execute("DELETE from mediaapi_media_repository WHERE media_id=%s;", (self.media_id,))
|
num_media = cur.rowcount
|
||||||
num_media = cur.rowcount
|
self.repo.conn["media_api"].commit()
|
||||||
self.repo.conn.commit()
|
|
||||||
logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}")
|
logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -135,21 +139,20 @@ class File:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of thumbnails
|
Number of thumbnails
|
||||||
"""
|
"""
|
||||||
with self.repo.conn.cursor() as cur:
|
cur = self.repo._execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,))
|
||||||
cur.execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,))
|
row = cur.fetchone()
|
||||||
row = cur.fetchone()
|
return int(row[0]) if row else 0
|
||||||
return int(row[0]) if row else 0
|
|
||||||
|
|
||||||
|
|
||||||
class MediaRepository:
|
class MediaRepository:
|
||||||
"""Handle media storage and retrieval for a Dendrite server."""
|
"""Handle media storage and retrieval for a Dendrite server."""
|
||||||
|
|
||||||
def __init__(self, media_path: Path, connection_string: str) -> None:
|
def __init__(self, media_path: Path, connection_strings: dict[str, str]) -> None:
|
||||||
"""Initialize MediaRepository.
|
"""Initialize MediaRepository.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
media_path: Path to media storage directory
|
media_path: Path to media storage directory
|
||||||
connection_string: PostgreSQL connection string
|
connection_strings: Dictionary of db connection strings
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If media_path doesn't exist or connection string is invalid
|
ValueError: If media_path doesn't exist or connection string is invalid
|
||||||
@ -157,8 +160,26 @@ class MediaRepository:
|
|||||||
self._validate_media_path(media_path)
|
self._validate_media_path(media_path)
|
||||||
self.media_path = media_path
|
self.media_path = media_path
|
||||||
self._avatar_media_ids: List[MediaID] = []
|
self._avatar_media_ids: List[MediaID] = []
|
||||||
self.db_conn_string = connection_string
|
self.conn: dict[str, DBConnection] = {}
|
||||||
self.conn = self.connect_db()
|
|
||||||
|
for db_type, conn_string in connection_strings.items():
|
||||||
|
self.conn[db_type] = self.connect_db(conn_string)
|
||||||
|
|
||||||
|
def _execute(self, query: str, params: Params = (), db_type: str = "media_api") -> DBCursor:
|
||||||
|
paramstyle = getattr(self.conn, "paramstyle", "format")
|
||||||
|
query = self._adjust_paramstyle(query, paramstyle)
|
||||||
|
cur = self.conn[db_type].cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(query, params)
|
||||||
|
return cur
|
||||||
|
except Exception as e:
|
||||||
|
cur.close()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _adjust_paramstyle(self, query: str, paramstyle: str) -> str:
|
||||||
|
if paramstyle == "qmark":
|
||||||
|
return query.replace("%s", "?")
|
||||||
|
return query
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_media_path(path: Path) -> None:
|
def _validate_media_path(path: Path) -> None:
|
||||||
@ -167,28 +188,33 @@ class MediaRepository:
|
|||||||
if not path.is_dir():
|
if not path.is_dir():
|
||||||
raise ValueError("Media directory not found")
|
raise ValueError("Media directory not found")
|
||||||
|
|
||||||
def connect_db(self) -> DBConnection:
|
def connect_db(self, connection_string: str) -> DBConnection:
|
||||||
"""Establish database connection.
|
"""Establish database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_string: db connection string
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
PostgreSQL connection object
|
PostgreSQL connection object
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If connection string is invalid
|
ValueError: If connection string is invalid
|
||||||
"""
|
"""
|
||||||
if not self.db_conn_string or not self.db_conn_string.startswith(("postgres://", "postgresql://")):
|
if connection_string.startswith(("postgres://", "postgresql://")):
|
||||||
raise ValueError("Invalid PostgreSQL connection string")
|
return psycopg2.connect(connection_string)
|
||||||
return psycopg2.connect(self.db_conn_string)
|
if connection_string.startswith(("file:", "sqlite:")):
|
||||||
|
return sqlite3.connect(connection_string.removeprefix("file:"))
|
||||||
|
raise ValueError("Invalid PostgreSQL connection string")
|
||||||
|
|
||||||
def get_single_media(self, mxid: MediaID) -> File | None:
|
def get_single_media(self, mxid: MediaID) -> File | None:
|
||||||
"""Retrieve a single media file by ID."""
|
"""Retrieve a single media file by ID."""
|
||||||
with self.conn.cursor() as cur:
|
cur = self._execute(
|
||||||
cur.execute(
|
"""SELECT media_id, creation_ts, base64hash, file_size_bytes
|
||||||
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
|
from mediaapi_media_repository WHERE media_id = %s;""",
|
||||||
(mxid,),
|
(mxid,),
|
||||||
)
|
)
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
return File(self, row[0], row[1] // 1000, row[2]) if row else None
|
return File(self, row[0], row[1] // 1000, row[2], row[3]) if row else None
|
||||||
|
|
||||||
def get_local_user_media(self, user_id: UserID) -> List[File]:
|
def get_local_user_media(self, user_id: UserID) -> List[File]:
|
||||||
"""Get all media files created by a local user.
|
"""Get all media files created by a local user.
|
||||||
@ -199,12 +225,12 @@ class MediaRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
List of File objects
|
List of File objects
|
||||||
"""
|
"""
|
||||||
with self.conn.cursor() as cur:
|
cur = self._execute(
|
||||||
cur.execute(
|
"""SELECT media_id, creation_ts, base64hash, file_size_bytes
|
||||||
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
|
FROM mediaapi_media_repository WHERE user_id = %s;""",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
|
return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()]
|
||||||
|
|
||||||
def get_all_media(self, local: bool = False) -> List[File]:
|
def get_all_media(self, local: bool = False) -> List[File]:
|
||||||
"""Get all media files or only remote ones.
|
"""Get all media files or only remote ones.
|
||||||
@ -215,13 +241,12 @@ class MediaRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
List of File objects
|
List of File objects
|
||||||
"""
|
"""
|
||||||
with self.conn.cursor() as cur:
|
query = """SELECT media_id, creation_ts, base64hash, file_size_bytes
|
||||||
query = """SELECT media_id, creation_ts, base64hash
|
FROM mediaapi_media_repository"""
|
||||||
FROM mediaapi_media_repository"""
|
if not local:
|
||||||
if not local:
|
query += " WHERE user_id = ''"
|
||||||
query += " WHERE user_id = ''"
|
cur = self._execute(query)
|
||||||
cur.execute(query)
|
return [File(self, row[0], row[1] // 1000, row[2], row[3]) for row in cur.fetchall()]
|
||||||
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
|
|
||||||
|
|
||||||
def get_avatar_images(self) -> List[MediaID]:
|
def get_avatar_images(self) -> List[MediaID]:
|
||||||
"""Get media IDs of current avatar images.
|
"""Get media IDs of current avatar images.
|
||||||
@ -229,30 +254,27 @@ class MediaRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
List of media IDs
|
List of media IDs
|
||||||
"""
|
"""
|
||||||
with self.conn.cursor() as cur:
|
cur = self._execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api")
|
||||||
cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';")
|
media_ids = []
|
||||||
media_ids = []
|
for (url,) in cur.fetchall():
|
||||||
for (url,) in cur.fetchall():
|
try:
|
||||||
try:
|
media_ids.append(url[url.rindex("/") + 1 :])
|
||||||
media_ids.append(url[url.rindex("/") + 1 :])
|
except ValueError:
|
||||||
except ValueError:
|
logging.warning("Invalid avatar URL: %s", url)
|
||||||
logging.warning("Invalid avatar URL: %s", url)
|
self._avatar_media_ids = media_ids
|
||||||
self._avatar_media_ids = media_ids
|
return media_ids
|
||||||
return media_ids
|
|
||||||
|
|
||||||
def sanity_check_thumbnails(self) -> None:
|
def sanity_check_thumbnails(self) -> None:
|
||||||
"""Check for orphaned thumbnail entries in database."""
|
"""Check for orphaned thumbnail entries in database."""
|
||||||
with self.conn.cursor() as cur:
|
cur = self._execute(
|
||||||
cur.execute(
|
"""SELECT COUNT(media_id) FROM mediaapi_thumbnail
|
||||||
"""SELECT COUNT(media_id) FROM mediaapi_thumbnail
|
WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""",
|
||||||
WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""",
|
)
|
||||||
|
if (row := cur.fetchone()) and (count := row[0]):
|
||||||
|
logging.error(
|
||||||
|
"You have %d thumbnails in your db that do not refer to media. This needs fixing (we don't do that)!",
|
||||||
|
count,
|
||||||
)
|
)
|
||||||
if (row := cur.fetchone()) and (count := row[0]):
|
|
||||||
logging.error(
|
|
||||||
"You have %d thumbnails in your db that do not refer to media. "
|
|
||||||
"This needs fixing (we don't do that)!",
|
|
||||||
count,
|
|
||||||
)
|
|
||||||
|
|
||||||
def clean_media_files(self, days: int, local: bool = False, dryrun: bool = False) -> int:
|
def clean_media_files(self, days: int, local: bool = False, dryrun: bool = False) -> int:
|
||||||
"""Remove old media files.
|
"""Remove old media files.
|
||||||
@ -276,8 +298,9 @@ class MediaRepository:
|
|||||||
for f in self.get_all_media(local)
|
for f in self.get_all_media(local)
|
||||||
if f.media_id not in self._avatar_media_ids and f.create_date < cutoff_date
|
if f.media_id not in self._avatar_media_ids and f.create_date < cutoff_date
|
||||||
]
|
]
|
||||||
|
file_size_counter = 0
|
||||||
for file in files_to_delete:
|
for file in files_to_delete:
|
||||||
|
file_size_counter += file.file_size
|
||||||
if dryrun:
|
if dryrun:
|
||||||
logging.info(f"Would delete file {file.media_id} at {file.fullpath}")
|
logging.info(f"Would delete file {file.media_id} at {file.fullpath}")
|
||||||
if not file.exists():
|
if not file.exists():
|
||||||
@ -286,18 +309,18 @@ class MediaRepository:
|
|||||||
file.delete()
|
file.delete()
|
||||||
|
|
||||||
action = "Would have deleted" if dryrun else "Deleted"
|
action = "Would have deleted" if dryrun else "Deleted"
|
||||||
logging.info("%s %d files", action, len(files_to_delete))
|
logging.info("%s %d files, in total %s", action, len(files_to_delete), sizeof_fmt(file_size_counter))
|
||||||
return len(files_to_delete)
|
return len(files_to_delete)
|
||||||
|
|
||||||
|
|
||||||
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
|
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, dict[str, str]]:
|
||||||
"""Read database credentials and media path from config.
|
"""Read database credentials and media path from config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conf_file: Path to Dendrite YAML config file
|
conf_file: Path to Dendrite YAML config file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (media_path, connection_string)
|
Tuple of (media_path, connection_strings)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SystemExit: If config file is invalid or missing required fields
|
SystemExit: If config file is invalid or missing required fields
|
||||||
@ -309,27 +332,51 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
|
|||||||
logging.error("Config file %s not found. Use --help for usage.", conf_file)
|
logging.error("Config file %s not found. Use --help for usage.", conf_file)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if "media_api" not in config:
|
global_conn_string = config.get("global", {}).get("database", {}).get("connection_string")
|
||||||
logging.error("Missing media_api section in config")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
conn_string = None
|
db_types: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = {
|
||||||
if "global" in config and "database" in config["global"]:
|
"media_api": lambda x: x.get("media_api", {}).get("database", {}),
|
||||||
conn_string = config["global"]["database"].get("connection_string")
|
"user_api": lambda x: x.get("user_api", {}).get("account_database", {}),
|
||||||
elif "database" in config["media_api"]:
|
}
|
||||||
logging.debug("Using database config from media_api section")
|
|
||||||
conn_string = config["media_api"]["database"].get("connection_string")
|
|
||||||
|
|
||||||
if not conn_string:
|
conns = {}
|
||||||
logging.error("Database connection string not found in config")
|
for db_type, config_key in db_types.items():
|
||||||
sys.exit(1)
|
value = config_key(config).get("connection_string")
|
||||||
|
conns[db_type] = value if value else global_conn_string
|
||||||
|
if not conns[db_type]:
|
||||||
|
logging.error(f"Database {db_type} connection string not found in config")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
base_path = config["media_api"].get("base_path")
|
base_path = config["media_api"].get("base_path")
|
||||||
if not base_path:
|
if not base_path:
|
||||||
logging.error("base_path not found in media_api config")
|
logging.error("base_path not found in media_api config")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return Path(base_path), conn_string
|
return Path(base_path), conns
|
||||||
|
|
||||||
|
|
||||||
|
def sizeof_fmt(num: Union[int, float], suffix: str = "B") -> str:
|
||||||
|
"""
|
||||||
|
Convert a number of bytes (or other units) into a human-readable format using binary prefixes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num (int | float): The number to format, typically a size in bytes.
|
||||||
|
suffix (str): The suffix to append (default is 'B' for bytes).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A string representing the human-readable size (e.g. '1.5MiB', '42.0KiB').
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Uses binary (base-1024) units: Ki, Mi, Gi, Ti, etc.
|
||||||
|
Automatically chooses the appropriate unit based on the input value.
|
||||||
|
"""
|
||||||
|
BASE = 1024.0
|
||||||
|
|
||||||
|
for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
|
||||||
|
if abs(num) < BASE:
|
||||||
|
return f"{num:3.1f}{unit}{suffix}"
|
||||||
|
num /= 1024.0
|
||||||
|
return f"{num:.1f}Yi{suffix}"
|
||||||
|
|
||||||
|
|
||||||
def parse_options() -> argparse.Namespace:
|
def parse_options() -> argparse.Namespace:
|
||||||
@ -363,8 +410,8 @@ def parse_options() -> argparse.Namespace:
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Execute the media cleanup process."""
|
"""Execute the media cleanup process."""
|
||||||
args = parse_options()
|
args = parse_options()
|
||||||
media_path, conn_string = read_config(args.config)
|
media_path, conn_strings = read_config(args.config)
|
||||||
repo = MediaRepository(media_path, conn_string)
|
repo = MediaRepository(media_path, conn_strings)
|
||||||
|
|
||||||
if args.mxid:
|
if args.mxid:
|
||||||
process_single_media(repo, args)
|
process_single_media(repo, args)
|
||||||
|
@ -42,17 +42,18 @@ def media_repo(tmp_path: Any, mock_db_conn: Tuple[Any, Any], mocker: MockerFixtu
|
|||||||
Returns:
|
Returns:
|
||||||
Configured MediaRepository instance
|
Configured MediaRepository instance
|
||||||
"""
|
"""
|
||||||
conn_mock, _ = mock_db_conn
|
conn_mock, cursor = mock_db_conn
|
||||||
media_path = tmp_path / "media"
|
media_path = tmp_path / "media"
|
||||||
media_path.mkdir()
|
media_path.mkdir()
|
||||||
mocker.patch("cleanmedia.MediaRepository.connect_db", return_value=conn_mock)
|
mocker.patch("cleanmedia.MediaRepository.connect_db", return_value=conn_mock)
|
||||||
return MediaRepository(media_path, "postgresql://fake")
|
mocker.patch("cleanmedia.MediaRepository._execute", return_value=cursor)
|
||||||
|
return MediaRepository(media_path, {"media_api": "postgresql://fake", "user_api": "postgresql://fake"})
|
||||||
|
|
||||||
|
|
||||||
def test_file_init(mocker: MockerFixture) -> None:
|
def test_file_init(mocker: MockerFixture) -> None:
|
||||||
"""Test File class initialization."""
|
"""Test File class initialization."""
|
||||||
repo = mocker.Mock()
|
repo = mocker.Mock()
|
||||||
file = File(repo, "mxid123", 1600000000, "base64hash123")
|
file = File(repo, "mxid123", 1600000000, "base64hash123", 1000)
|
||||||
assert file.media_id == "mxid123"
|
assert file.media_id == "mxid123"
|
||||||
assert file.create_date == datetime.fromtimestamp(1600000000)
|
assert file.create_date == datetime.fromtimestamp(1600000000)
|
||||||
assert file.base64hash == "base64hash123"
|
assert file.base64hash == "base64hash123"
|
||||||
@ -61,26 +62,26 @@ def test_file_init(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
def test_file_fullpath(media_repo: MediaRepository) -> None:
|
def test_file_fullpath(media_repo: MediaRepository) -> None:
|
||||||
"""Test File.fullpath property returns correct path."""
|
"""Test File.fullpath property returns correct path."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
expected_path = media_repo.media_path / "a" / "b" / "c123"
|
expected_path = media_repo.media_path / "a" / "b" / "c123"
|
||||||
assert file.fullpath == expected_path
|
assert file.fullpath == expected_path
|
||||||
|
|
||||||
|
|
||||||
def test_file_exists_no_path(media_repo: MediaRepository) -> None:
|
def test_file_exists_no_path(media_repo: MediaRepository) -> None:
|
||||||
"""Test File.exists returns False when fullpath is None."""
|
"""Test File.exists returns False when fullpath is None."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "") # Empty hash ensures fullpath is None
|
file = File(media_repo, "mxid123", 1600000000, "", 1000) # Empty hash ensures fullpath is None
|
||||||
assert file.exists() is False
|
assert file.exists() is False
|
||||||
|
|
||||||
|
|
||||||
def test_file_delete_no_path(media_repo: MediaRepository) -> None:
|
def test_file_delete_no_path(media_repo: MediaRepository) -> None:
|
||||||
"""Test File._delete_files when file path is None."""
|
"""Test File._delete_files when file path is None."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "")
|
file = File(media_repo, "mxid123", 1600000000, "", 1000)
|
||||||
assert file._delete_files() is False
|
assert file._delete_files() is False
|
||||||
|
|
||||||
|
|
||||||
def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, caplog: Any) -> None:
|
def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture, caplog: Any) -> None:
|
||||||
"""Test File._delete_files when OSError occurs."""
|
"""Test File._delete_files when OSError occurs."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
|
|
||||||
# Create directory structure
|
# Create directory structure
|
||||||
file_path = media_repo.media_path / "a" / "b" / "c123"
|
file_path = media_repo.media_path / "a" / "b" / "c123"
|
||||||
@ -96,13 +97,13 @@ def test_file_delete_oserror(media_repo: MediaRepository, mocker: MockerFixture,
|
|||||||
|
|
||||||
def test_file_fullpath_none_if_no_hash(media_repo: MediaRepository) -> None:
|
def test_file_fullpath_none_if_no_hash(media_repo: MediaRepository) -> None:
|
||||||
"""Test File.fullpath returns None when hash is empty."""
|
"""Test File.fullpath returns None when hash is empty."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "")
|
file = File(media_repo, "mxid123", 1600000000, "", 1000)
|
||||||
assert file.fullpath is None
|
assert file.fullpath is None
|
||||||
|
|
||||||
|
|
||||||
def test_file_exists(media_repo: MediaRepository) -> None:
|
def test_file_exists(media_repo: MediaRepository) -> None:
|
||||||
"""Test File.exists returns True when file exists."""
|
"""Test File.exists returns True when file exists."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
file_path = media_repo.media_path / "a" / "b" / "c123"
|
file_path = media_repo.media_path / "a" / "b" / "c123"
|
||||||
file_path.mkdir(parents=True)
|
file_path.mkdir(parents=True)
|
||||||
(file_path / "file").touch()
|
(file_path / "file").touch()
|
||||||
@ -111,14 +112,14 @@ def test_file_exists(media_repo: MediaRepository) -> None:
|
|||||||
|
|
||||||
def test_file_not_exists(media_repo: MediaRepository) -> None:
|
def test_file_not_exists(media_repo: MediaRepository) -> None:
|
||||||
"""Test File.exists returns False when file doesn't exist."""
|
"""Test File.exists returns False when file doesn't exist."""
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
assert file.exists() is False
|
assert file.exists() is False
|
||||||
|
|
||||||
|
|
||||||
def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
|
def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
|
||||||
"""Test File.delete removes files and database entries."""
|
"""Test File.delete removes files and database entries."""
|
||||||
_, cursor_mock = mock_db_conn
|
_, cursor_mock = mock_db_conn
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
|
|
||||||
file_path = media_repo.media_path / "a" / "b" / "c123"
|
file_path = media_repo.media_path / "a" / "b" / "c123"
|
||||||
file_path.mkdir(parents=True)
|
file_path.mkdir(parents=True)
|
||||||
@ -128,22 +129,22 @@ def test_file_delete(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any])
|
|||||||
assert file.delete() is True
|
assert file.delete() is True
|
||||||
assert not file_path.exists()
|
assert not file_path.exists()
|
||||||
|
|
||||||
cursor_mock.execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",))
|
media_repo._execute.assert_any_call("DELETE from mediaapi_thumbnail WHERE media_id=%s;", ("mxid123",)) # type: ignore
|
||||||
cursor_mock.execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",))
|
media_repo._execute.assert_any_call("DELETE from mediaapi_media_repository WHERE media_id=%s;", ("mxid123",)) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
|
def test_get_single_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
|
||||||
"""Test MediaRepository.get_single_media returns correct File object."""
|
"""Test MediaRepository.get_single_media returns correct File object."""
|
||||||
_, cursor_mock = mock_db_conn
|
_, cursor_mock = mock_db_conn
|
||||||
cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123")
|
cursor_mock.fetchone.return_value = ("mxid123", 1600000000000, "abc123", 1000)
|
||||||
|
|
||||||
file = media_repo.get_single_media("mxid123")
|
file = media_repo.get_single_media("mxid123")
|
||||||
assert file is not None
|
assert file is not None
|
||||||
assert file.media_id == "mxid123"
|
assert file.media_id == "mxid123"
|
||||||
assert file.base64hash == "abc123"
|
assert file.base64hash == "abc123"
|
||||||
|
|
||||||
cursor_mock.execute.assert_called_with(
|
media_repo._execute.assert_called_with( # type: ignore
|
||||||
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
|
"SELECT media_id, creation_ts, base64hash, file_size_bytes\nfrom mediaapi_media_repository WHERE media_id = %s;",
|
||||||
("mxid123",),
|
("mxid123",),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -165,8 +166,8 @@ def test_clean_media_files(media_repo: MediaRepository, mock_db_conn: Tuple[Any,
|
|||||||
new_date = int((datetime.now() - timedelta(days=1)).timestamp())
|
new_date = int((datetime.now() - timedelta(days=1)).timestamp())
|
||||||
|
|
||||||
cursor_mock.fetchall.return_value = [
|
cursor_mock.fetchall.return_value = [
|
||||||
("old_file", old_date * 1000, "abc123"),
|
("old_file", old_date * 1000, "abc123", 1000),
|
||||||
("new_file", new_date * 1000, "def456"),
|
("new_file", new_date * 1000, "def456", 1000),
|
||||||
]
|
]
|
||||||
|
|
||||||
media_repo._avatar_media_ids = []
|
media_repo._avatar_media_ids = []
|
||||||
@ -181,7 +182,7 @@ def test_clean_media_files_dryrun(media_repo: MediaRepository, mock_db_conn: Tup
|
|||||||
|
|
||||||
old_date = int((datetime.now() - timedelta(days=31)).timestamp())
|
old_date = int((datetime.now() - timedelta(days=31)).timestamp())
|
||||||
cursor_mock.fetchall.return_value = [
|
cursor_mock.fetchall.return_value = [
|
||||||
("old_file", old_date * 1000, "abc123"),
|
("old_file", old_date * 1000, "abc123", 1000),
|
||||||
]
|
]
|
||||||
|
|
||||||
media_repo._avatar_media_ids = []
|
media_repo._avatar_media_ids = []
|
||||||
@ -252,24 +253,24 @@ def test_validate_media_path_relative(caplog: Any) -> None:
|
|||||||
def test_connect_db_success(mocker: MockerFixture) -> None:
|
def test_connect_db_success(mocker: MockerFixture) -> None:
|
||||||
"""Test successful database connection."""
|
"""Test successful database connection."""
|
||||||
mock_connect = mocker.patch("psycopg2.connect")
|
mock_connect = mocker.patch("psycopg2.connect")
|
||||||
repo = MediaRepository(Path("/tmp"), "postgresql://fake")
|
repo = MediaRepository(Path("/tmp"), {"media_api": "postgresql://fake"})
|
||||||
repo.connect_db()
|
repo.connect_db("postgresql://fake")
|
||||||
mock_connect.assert_called_with("postgresql://fake")
|
mock_connect.assert_called_with("postgresql://fake")
|
||||||
|
|
||||||
|
|
||||||
def test_connect_db_invalid_string(tmp_path: Path) -> None:
|
def test_connect_db_invalid_string(tmp_path: Path) -> None:
|
||||||
"""Test connect_db with invalid connection string."""
|
"""Test connect_db with invalid connection string."""
|
||||||
with pytest.raises(ValueError, match="Invalid PostgreSQL connection string"):
|
with pytest.raises(ValueError, match="Invalid PostgreSQL connection string"):
|
||||||
repo = MediaRepository(tmp_path, "invalid")
|
repo = MediaRepository(tmp_path, {"media_api": "invalid"})
|
||||||
repo.connect_db()
|
repo.connect_db("invalid")
|
||||||
|
|
||||||
|
|
||||||
def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
|
def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[Any, Any]) -> None:
|
||||||
"""Test get_local_user_media returns correct files."""
|
"""Test get_local_user_media returns correct files."""
|
||||||
_, cursor_mock = mock_db_conn
|
_, cursor_mock = mock_db_conn
|
||||||
cursor_mock.fetchall.return_value = [
|
cursor_mock.fetchall.return_value = [
|
||||||
("media1", 1600000000000, "hash1"),
|
("media1", 1600000000000, "hash1", 1000),
|
||||||
("media2", 1600000000000, "hash2"),
|
("media2", 1600000000000, "hash2", 1000),
|
||||||
]
|
]
|
||||||
|
|
||||||
files = media_repo.get_local_user_media("@user:domain.com")
|
files = media_repo.get_local_user_media("@user:domain.com")
|
||||||
@ -277,8 +278,8 @@ def test_get_local_user_media(media_repo: MediaRepository, mock_db_conn: Tuple[A
|
|||||||
assert files[0].media_id == "media1"
|
assert files[0].media_id == "media1"
|
||||||
assert files[1].media_id == "media2"
|
assert files[1].media_id == "media2"
|
||||||
|
|
||||||
cursor_mock.execute.assert_called_with(
|
media_repo._execute.assert_called_with( # type: ignore
|
||||||
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
|
"SELECT media_id, creation_ts, base64hash, file_size_bytes\nFROM mediaapi_media_repository WHERE user_id = %s;",
|
||||||
("@user:domain.com",),
|
("@user:domain.com",),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -333,15 +334,18 @@ def test_read_config_valid(tmp_path: Path) -> None:
|
|||||||
"""Test read_config with valid config."""
|
"""Test read_config with valid config."""
|
||||||
config_file = tmp_path / "config.yaml"
|
config_file = tmp_path / "config.yaml"
|
||||||
config_file.write_text("""
|
config_file.write_text("""
|
||||||
|
global:
|
||||||
|
database:
|
||||||
|
connection_string: postgresql://global/db
|
||||||
media_api:
|
media_api:
|
||||||
base_path: /media/path
|
base_path: /media/path
|
||||||
database:
|
database:
|
||||||
connection_string: postgresql://user:pass@localhost/db
|
connection_string: postgresql://user:pass@localhost/db
|
||||||
""")
|
""")
|
||||||
|
|
||||||
path, conn_string = read_config(config_file)
|
path, conn_strings = read_config(config_file)
|
||||||
assert path == Path("/media/path")
|
assert path == Path("/media/path")
|
||||||
assert conn_string == "postgresql://user:pass@localhost/db"
|
assert conn_strings["media_api"] == "postgresql://user:pass@localhost/db"
|
||||||
|
|
||||||
|
|
||||||
def test_read_config_global_database(tmp_path: Path) -> None:
|
def test_read_config_global_database(tmp_path: Path) -> None:
|
||||||
@ -355,8 +359,8 @@ global:
|
|||||||
connection_string: postgresql://global/db
|
connection_string: postgresql://global/db
|
||||||
""")
|
""")
|
||||||
|
|
||||||
path, conn_string = read_config(config_file)
|
_, conn_string = read_config(config_file)
|
||||||
assert conn_string == "postgresql://global/db"
|
assert conn_string["media_api"] == "postgresql://global/db"
|
||||||
|
|
||||||
|
|
||||||
def test_read_config_missing_conn_string(tmp_path: Path) -> None:
|
def test_read_config_missing_conn_string(tmp_path: Path) -> None:
|
||||||
@ -401,9 +405,9 @@ def test_file_has_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple[An
|
|||||||
_, cursor_mock = mock_db_conn
|
_, cursor_mock = mock_db_conn
|
||||||
cursor_mock.fetchone.return_value = (3,)
|
cursor_mock.fetchone.return_value = (3,)
|
||||||
|
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
assert file.has_thumbnail() == 3 # noqa PLR2004
|
assert file.has_thumbnail() == 3 # noqa PLR2004
|
||||||
cursor_mock.execute.assert_called_with(
|
media_repo._execute.assert_called_with( # type: ignore
|
||||||
"SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;",
|
"SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;",
|
||||||
("mxid123",),
|
("mxid123",),
|
||||||
)
|
)
|
||||||
@ -414,5 +418,5 @@ def test_file_has_no_thumbnails(media_repo: MediaRepository, mock_db_conn: Tuple
|
|||||||
_, cursor_mock = mock_db_conn
|
_, cursor_mock = mock_db_conn
|
||||||
cursor_mock.fetchone.return_value = None
|
cursor_mock.fetchone.return_value = None
|
||||||
|
|
||||||
file = File(media_repo, "mxid123", 1600000000, "abc123")
|
file = File(media_repo, "mxid123", 1600000000, "abc123", 1000)
|
||||||
assert file.has_thumbnail() == 0
|
assert file.has_thumbnail() == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user