feat: support sqlite3 and multiple databases

This commit is contained in:
Christian Groschupp 2025-03-27 09:58:57 +01:00
parent f80b805a71
commit d318296e7e

View File

@ -26,6 +26,7 @@ from datetime import datetime, timedelta
from functools import cached_property from functools import cached_property
from pathlib import Path from pathlib import Path
from typing import List, Tuple, TypeAlias, Union from typing import List, Tuple, TypeAlias, Union
import sqlite3
try: try:
import psycopg2 import psycopg2
@ -144,7 +145,7 @@ class File:
class MediaRepository: class MediaRepository:
"""Handle media storage and retrieval for a Dendrite server.""" """Handle media storage and retrieval for a Dendrite server."""
def __init__(self, media_path: Path, connection_string: str) -> None: def __init__(self, media_path: Path, connection_strings: dict[str, str]) -> None:
"""Initialize MediaRepository. """Initialize MediaRepository.
Args: Args:
@ -157,8 +158,26 @@ class MediaRepository:
self._validate_media_path(media_path) self._validate_media_path(media_path)
self.media_path = media_path self.media_path = media_path
self._avatar_media_ids: List[MediaID] = [] self._avatar_media_ids: List[MediaID] = []
self.db_conn_string = connection_string self.conn = {}
self.conn = self.connect_db()
for db_type, conn_string in connection_strings.items():
self.conn[db_type] = self.connect_db(conn_string)
def execute(self, query: str, params=(), db_type: str = "media_api"):
    """Run ``query`` on the connection registered under ``db_type``.

    Queries are written with ``%s`` placeholders. They are rewritten to
    ``?`` (via ``_adjust_paramstyle``) when the selected connection is a
    SQLite connection, and passed through unchanged otherwise.

    Args:
        query: SQL statement using ``%s`` placeholders.
        params: Values bound to the placeholders.
        db_type: Key into ``self.conn`` selecting which connection to use.

    Returns:
        The cursor the statement was executed on; the caller fetches from it.

    Raises:
        KeyError: If no connection is registered under ``db_type``.
        Exception: Any driver error is re-raised after the cursor is closed.
    """
    conn = self.conn[db_type]
    # ``self.conn`` is a dict of connections, and DB-API drivers expose
    # ``paramstyle`` on the module rather than on the connection, so the
    # style has to be derived from the concrete connection being used.
    paramstyle = "qmark" if isinstance(conn, sqlite3.Connection) else "format"
    query = self._adjust_paramstyle(query, paramstyle)
    cur = conn.cursor()
    try:
        cur.execute(query, params)
    except Exception:
        # Do not leak the cursor when the statement fails.
        cur.close()
        raise
    return cur
def _adjust_paramstyle(self, query: str, paramstyle: str):
if paramstyle == "qmark":
return query.replace("%s", "?")
return query
@staticmethod @staticmethod
def _validate_media_path(path: Path) -> None: def _validate_media_path(path: Path) -> None:
@ -167,7 +186,7 @@ class MediaRepository:
if not path.is_dir(): if not path.is_dir():
raise ValueError("Media directory not found") raise ValueError("Media directory not found")
def connect_db(self) -> DBConnection: def connect_db(self, connection_string: str) -> DBConnection:
"""Establish database connection. """Establish database connection.
Returns: Returns:
@ -176,14 +195,15 @@ class MediaRepository:
Raises: Raises:
ValueError: If connection string is invalid ValueError: If connection string is invalid
""" """
if not self.db_conn_string or not self.db_conn_string.startswith(("postgres://", "postgresql://")): if connection_string.startswith(("postgres://", "postgresql://")):
return psycopg2.connect(connection_string)
if connection_string.startswith(("file:", "sqlite:")):
return sqlite3.connect(connection_string.removeprefix("file:"))
raise ValueError("Invalid PostgreSQL connection string") raise ValueError("Invalid PostgreSQL connection string")
return psycopg2.connect(self.db_conn_string)
def get_single_media(self, mxid: MediaID) -> File | None: def get_single_media(self, mxid: MediaID) -> File | None:
"""Retrieve a single media file by ID.""" """Retrieve a single media file by ID."""
with self.conn.cursor() as cur: cur = self.execute(
cur.execute(
"SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;", "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
(mxid,), (mxid,),
) )
@ -199,8 +219,7 @@ class MediaRepository:
Returns: Returns:
List of File objects List of File objects
""" """
with self.conn.cursor() as cur: cur = self.execute(
cur.execute(
"SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;", "SELECT media_id, creation_ts, base64hash FROM mediaapi_media_repository WHERE user_id = %s;",
(user_id,), (user_id,),
) )
@ -215,12 +234,11 @@ class MediaRepository:
Returns: Returns:
List of File objects List of File objects
""" """
with self.conn.cursor() as cur:
query = """SELECT media_id, creation_ts, base64hash query = """SELECT media_id, creation_ts, base64hash
FROM mediaapi_media_repository""" FROM mediaapi_media_repository"""
if not local: if not local:
query += " WHERE user_id = ''" query += " WHERE user_id = ''"
cur.execute(query) cur = self.execute(query)
return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()] return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
def get_avatar_images(self) -> List[MediaID]: def get_avatar_images(self) -> List[MediaID]:
@ -229,8 +247,7 @@ class MediaRepository:
Returns: Returns:
List of media IDs List of media IDs
""" """
with self.conn.cursor() as cur: cur = self.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';", db_type="user_api")
cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';")
media_ids = [] media_ids = []
for (url,) in cur.fetchall(): for (url,) in cur.fetchall():
try: try:
@ -242,8 +259,7 @@ class MediaRepository:
def sanity_check_thumbnails(self) -> None: def sanity_check_thumbnails(self) -> None:
"""Check for orphaned thumbnail entries in database.""" """Check for orphaned thumbnail entries in database."""
with self.conn.cursor() as cur: cur = self.execute(
cur.execute(
"""SELECT COUNT(media_id) FROM mediaapi_thumbnail """SELECT COUNT(media_id) FROM mediaapi_thumbnail
WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""", WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);""",
) )
@ -290,7 +306,7 @@ class MediaRepository:
return len(files_to_delete) return len(files_to_delete)
def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]: def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str, dict[str, str]]:
"""Read database credentials and media path from config. """Read database credentials and media path from config.
Args: Args:
@ -309,19 +325,20 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
logging.error("Config file %s not found. Use --help for usage.", conf_file) logging.error("Config file %s not found. Use --help for usage.", conf_file)
sys.exit(1) sys.exit(1)
if "media_api" not in config:
logging.error("Missing media_api section in config")
sys.exit(1)
conn_string = None global_conn_string = config.get("global",{}).get("database", {}).get("connection_string")
if "global" in config and "database" in config["global"]:
conn_string = config["global"]["database"].get("connection_string")
elif "database" in config["media_api"]:
logging.debug("Using database config from media_api section")
conn_string = config["media_api"]["database"].get("connection_string")
if not conn_string: db_types = {
logging.error("Database connection string not found in config") "media_api": lambda x: x.get("media_api",{}).get("database", {}),
"user_api": lambda x: x.get("user_api",{}).get("account_database", {}),
}
conns = {}
for db_type, config_key in db_types.items():
value = config_key(config).get("connection_string")
conns[db_type] = value if value else global_conn_string
if not conns[db_type]:
logging.error(f"Database {db_type} connection string not found in config")
sys.exit(1) sys.exit(1)
base_path = config["media_api"].get("base_path") base_path = config["media_api"].get("base_path")
@ -329,7 +346,7 @@ def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
logging.error("base_path not found in media_api config") logging.error("base_path not found in media_api config")
sys.exit(1) sys.exit(1)
return Path(base_path), conn_string return Path(base_path), conns
def parse_options() -> argparse.Namespace: def parse_options() -> argparse.Namespace:
@ -363,8 +380,8 @@ def parse_options() -> argparse.Namespace:
def main() -> None: def main() -> None:
"""Execute the media cleanup process.""" """Execute the media cleanup process."""
args = parse_options() args = parse_options()
media_path, conn_string = read_config(args.config) media_path, conn_strings = read_config(args.config)
repo = MediaRepository(media_path, conn_string) repo = MediaRepository(media_path, conn_strings)
if args.mxid: if args.mxid:
process_single_media(repo, args) process_single_media(repo, args)