feat: add support for sqlite3 and multiple databases #2
137
cleanmedia.py
137
cleanmedia.py
@ -26,6 +26,7 @@ from datetime import datetime, timedelta
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, TypeAlias, Union
|
||||
import sqlite3
|
||||
|
||||
try:
|
||||
import psycopg2
|
||||
@ -144,7 +145,7 @@ class File:
|
||||
class MediaRepository:
|
||||
"""Handle media storage and retrieval for a Dendrite server."""
|
||||
|
||||
def __init__(self, media_path: Path, connection_string: str) -> None:
|
||||
def __init__(self, media_path: Path, connection_strings: dict[str, str]) -> None:
|
||||
"""Initialize MediaRepository.
|
||||
|
||||
Args:
|
||||
@ -157,8 +158,26 @@ class MediaRepository:
|
||||
self._validate_media_path(media_path)
|
||||
self.media_path = media_path
|
||||
self._avatar_media_ids: List[MediaID] = []
|
||||
self.db_conn_string = connection_string
|
||||
self.conn = self.connect_db()
|
||||
self.conn = {}
|
||||
|
||||
for db_type, conn_string in connection_strings.items():
|
||||
self.conn[db_type] = self.connect_db(conn_string)
|
||||
|
||||
def execute(self, query: str, params=(), db_type: str = "media_api"):
    """Run a parameterized query on one of the configured databases.

    Queries are written with ``%s`` placeholders and are rewritten to the
    placeholder style of the driver behind the selected connection.

    Args:
        query: SQL text using ``%s`` placeholders.
        params: Values bound to the placeholders.
        db_type: Key into ``self.conn`` selecting which connection to use.

    Returns:
        The open cursor the query was executed on. The caller fetches from
        it and is responsible for closing it.

    Raises:
        KeyError: If no connection is stored under ``db_type``.
        Exception: Any driver error is re-raised after the cursor is closed.
    """
    connection = self.conn[db_type]
    # DB-API defines ``paramstyle`` on the driver *module*, not on connection
    # objects, and ``self.conn`` is a dict of connections. The style is
    # therefore derived from the type of the selected connection.
    if isinstance(connection, sqlite3.Connection):
        paramstyle = "qmark"
    else:
        paramstyle = "format"
    query = self._adjust_paramstyle(query, paramstyle)
    cur = connection.cursor()
    try:
        cur.execute(query, params)
    except Exception:
        # Do not leak the cursor on failure; bare raise keeps the traceback.
        cur.close()
        raise
    return cur
|
||||
|
||||
def _adjust_paramstyle(self, query: str, paramstyle: str):
|
||||
if paramstyle == "qmark":
|
||||
return query.replace("%s", "?")
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _validate_media_path(path: Path) -> None:
|
||||
@ -167,7 +186,7 @@ class MediaRepository:
|
||||
if not path.is_dir():
|
||||
raise ValueError("Media directory not found")
|
||||
|
||||
def connect_db(self) -> DBConnection:
|
||||
def connect_db(self, connection_string: str) -> DBConnection:
|
||||
"""Establish database connection.
|
||||
|
||||
Returns:
|
||||
@ -176,19 +195,20 @@ class MediaRepository:
|
||||
Raises:
|
||||
ValueError: If connection string is invalid
|
||||
"""
|
||||
if not self.db_conn_string or not self.db_conn_string.startswith(("postgres://", "postgresql://")):
|
||||
raise ValueError("Invalid PostgreSQL connection string")
|
||||
return psycopg2.connect(self.db_conn_string)
|
||||
if connection_string.startswith(("postgres://", "postgresql://")):
|
||||
return psycopg2.connect(connection_string)
|
||||
if connection_string.startswith(("file:", "sqlite:")):
|
||||
return sqlite3.connect(connection_string.removeprefix("file:"))
|
||||
raise ValueError("Invalid PostgreSQL connection string")
|
||||
|
||||
def get_single_media(self, mxid: MediaID) -> File | None:
    """Retrieve a single media file by ID.

    Args:
        mxid: Media ID matched against ``media_id`` in
            ``mediaapi_media_repository``.

    Returns:
        A ``File`` built from the first matching row, or ``None`` when the
        query returns no row.
    """
    cur = self.execute(
        "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
        (mxid,),
    )
    try:
        row = cur.fetchone()
    finally:
        # self.execute returns an open cursor; release it once the row has
        # been read so cursors do not accumulate on the connection.
        cur.close()
    # creation_ts is integer-divided by 1000 before being passed to File
    # (presumably milliseconds -> seconds; confirm against the schema).
    return File(self, row[0], row[1] // 1000, row[2]) if row else None
|
||||
|
||||
def get_local_user_media(self, user_id: UserID) -> List[File]:
|
||||
"""Get all media files created by a local user.
|
||||
@ -199,12 +219,11 @@ class MediaRepository:
|
||||
Returns:
|
||||
List of File objects
|
||||
"""
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
|
||||
(user_id,),
|
||||
)
|
||||
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
|
||||
cur = self.execute(
|
||||
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
|
||||
(user_id,),
|
||||
)
|
||||
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
|
||||
|
||||
def get_all_media(self, local: bool = False) -> List[File]:
|
||||
"""Get all media files or only remote ones.
|
||||
@ -215,13 +234,12 @@ class MediaRepository:
|
||||
Returns:
|
||||
List of File objects
|
||||
"""
|
||||
with self.conn.cursor() as cur:
|
||||
query = """SELECT media_id, creation_ts, base64hash
|
||||
FROM mediaapi_media_repository"""
|
||||
if not local:
|
||||
query += " WHERE user_id = ''"
|
||||
cur.execute(query)
|
||||
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
|
||||
query = """SELECT media_id, creation_ts, base64hash
|
||||
FROM mediaapi_media_repository"""
|
||||
if not local:
|
||||
query += " WHERE user_id = ''"
|
||||
cur = self.execute(query)
|
||||
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
|
||||
|
||||
def get_avatar_images(self) -> List[MediaID]:
|
||||
"""Get media IDs of current avatar images.
|
||||
@ -229,30 +247,28 @@ class MediaRepository:
|
||||
Returns:
|
||||
List of media IDs
|
||||
"""
|
||||
with self.conn.cursor() as cur:
|
||||
cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';")
|
||||
media_ids = []
|
||||
for (url,) in cur.fetchall():
|
||||
try:
|
||||
media_ids.append(url[url.rindex("/") + 1 :])
|
||||
except ValueError:
|
||||
logging.warning("Invalid avatar URL: %s", url)
|
||||
self._avatar_media_ids = media_ids
|
||||
return media_ids
|
||||
cur = self.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api")
|
||||
media_ids = []
|
||||
for (url,) in cur.fetchall():
|
||||
try:
|
||||
media_ids.append(url[url.rindex("/") + 1 :])
|
||||
except ValueError:
|
||||
logging.warning("Invalid avatar URL: %s", url)
|
||||
self._avatar_media_ids = media_ids
|
||||
return media_ids
|
||||
|
||||
def sanity_check_thumbnails(self) -> None:
    """Check for orphaned thumbnail entries in database.

    Counts rows in ``mediaapi_thumbnail`` whose ``media_id`` has no matching
    row in ``mediaapi_media_repository`` and logs an error when any exist.
    Nothing is modified.
    """
    # The subquery must be correlated with the outer row. An uncorrelated
    # NOT EXISTS is false for every thumbnail as soon as the media table
    # holds a single row, so orphans would never be reported.
    # NOTE(review): rows are matched on media_id only, like the other
    # lookups in this class; add media_origin if IDs can collide across
    # origins -- confirm against the schema.
    cur = self.execute(
        """SELECT COUNT(t.media_id) FROM mediaapi_thumbnail AS t
           WHERE NOT EXISTS (
               SELECT 1 FROM mediaapi_media_repository AS m
               WHERE m.media_id = t.media_id);""",
    )
    try:
        row = cur.fetchone()
    finally:
        # self.execute returns an open cursor; release it after reading.
        cur.close()
    orphan_count = row[0] if row else 0
    if orphan_count:
        logging.error(
            "You have %d thumbnails in your db that do not refer to media. "
            "This needs fixing (we don't do that)!",
            orphan_count,
        )
|
||||
|
||||
def clean_media_files(self, days: int, local: bool = False, dryrun: bool = False) -> int:
|
||||
"""Remove old media files.
|
||||
@ -290,7 +306,7 @@ class MediaRepository:
|
||||
return len(files_to_delete)
|
||||
|
||||
|
||||
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
|
||||
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]:
|
||||
"""Read database credentials and media path from config.
|
||||
|
||||
Args:
|
||||
@ -309,27 +325,28 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
|
||||
logging.error("Config file %s not found. Use --help for usage.", conf_file)
|
||||
sys.exit(1)
|
||||
|
||||
if "media_api" not in config:
|
||||
logging.error("Missing media_api section in config")
|
||||
sys.exit(1)
|
||||
|
||||
conn_string = None
|
||||
if "global" in config and "database" in config["global"]:
|
||||
conn_string = config["global"]["database"].get("connection_string")
|
||||
elif "database" in config["media_api"]:
|
||||
logging.debug("Using database config from media_api section")
|
||||
conn_string = config["media_api"]["database"].get("connection_string")
|
||||
global_conn_string = config.get("global",{}).get("database", {}).get("connection_string")
|
||||
|
||||
if not conn_string:
|
||||
logging.error("Database connection string not found in config")
|
||||
sys.exit(1)
|
||||
db_types = {
|
||||
"media_api": lambda x: x.get("media_api",{}).get("database", {}),
|
||||
"user_api": lambda x: x.get("user_api",{}).get("account_database", {}),
|
||||
}
|
||||
|
||||
conns = {}
|
||||
for db_type, config_key in db_types.items():
|
||||
value = config_key(config).get("connection_string")
|
||||
conns[db_type] = value if value else global_conn_string
|
||||
if not conns[db_type]:
|
||||
logging.error(f"Database {db_type} connection string not found in config")
|
||||
sys.exit(1)
|
||||
|
||||
base_path = config["media_api"].get("base_path")
|
||||
if not base_path:
|
||||
logging.error("base_path not found in media_api config")
|
||||
sys.exit(1)
|
||||
|
||||
return Path(base_path), conn_string
|
||||
return Path(base_path), conns
|
||||
|
||||
|
||||
def parse_options() -> argparse.Namespace:
|
||||
@ -363,8 +380,8 @@ def parse_options() -> argparse.Namespace:
|
||||
def main() -> None:
|
||||
"""Execute the media cleanup process."""
|
||||
args = parse_options()
|
||||
media_path, conn_string = read_config(args.config)
|
||||
repo = MediaRepository(media_path, conn_string)
|
||||
media_path, conn_strings = read_config(args.config)
|
||||
repo = MediaRepository(media_path, conn_strings)
|
||||
|
||||
if args.mxid:
|
||||
process_single_media(repo, args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user