diff --git a/cleanmedia.py b/cleanmedia.py
index d5fb842..91e9823 100755
--- a/cleanmedia.py
+++ b/cleanmedia.py
@@ -1,4 +1,4 @@
-"""Main cleanmedia module."""
+"""Media cleanup utility for Dendrite servers."""
 
 """
 CleanMedia.
@@ -24,20 +24,33 @@ import sys
 from datetime import datetime, timedelta
 from functools import cached_property
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, TypeAlias, Union
 
 try:
     import psycopg2
     import psycopg2.extensions
     import yaml
 except ImportError as err:
-    raise Exception("Please install psycopg2 and pyyaml") from err
+    raise ImportError("Required dependencies not found. Please install psycopg2 and pyyaml.") from err
+
+# Type aliases
+DBConnection: TypeAlias = psycopg2.extensions.connection
+MediaID: TypeAlias = str
+Timestamp: TypeAlias = int
+Base64Hash: TypeAlias = str
+UserID: TypeAlias = str
 
 
 class File:
-    """Represent a file in our db together with physical file and thumbnails."""
+    """Represent a media file with its metadata and physical storage location."""
 
-    def __init__(self, media_repo: "MediaRepository", media_id: str, creation_ts: int, base64hash: str):
+    def __init__(
+        self,
+        media_repo: "MediaRepository",
+        media_id: MediaID,
+        creation_ts: Timestamp,
+        base64hash: Base64Hash,
+    ) -> None:
         """Initialize a File object."""
         self.repo = media_repo
         self.media_id = media_id
@@ -45,31 +58,57 @@ class File:
         self.base64hash = base64hash
 
     @cached_property
-    def fullpath(self) -> Optional[Path]:
-        """Returns the directory in which the "file" and all thumbnails are located, or None if no file is known."""
+    def fullpath(self) -> Path | None:
+        """Get the directory containing the file and its thumbnails.
+
+        Returns:
+            Path to directory or None if no file location is known
+        """
         if not self.base64hash:
             return None
         return self.repo.media_path / self.base64hash[0:1] / self.base64hash[1:2] / self.base64hash[2:]
 
     def delete(self) -> bool:
-        """Remove db entries and the file itself.
+        """Remove file from filesystem and database.
 
-        :returns: True on successful delete of file,
-                 False or Exception on failure
+        Returns:
+            True if deletion was successful, False otherwise
+        """
+        if not self._delete_files():
+            return False
+
+        return self._delete_db_entries()
+
+    def _delete_files(self) -> bool:
+        """Remove physical files from filesystem.
+
+        Returns:
+            True if files were deleted or didn't exist, False on error
         """
-        res = True
         if self.fullpath is None:
             logging.info(f"No known path for file id '{self.media_id}', cannot delete file.")
-            res = False
-        elif not self.fullpath.is_dir():
-            logging.debug(f"Path for file id '{self.media_id}' is not a directory or does not exist, not deleting.")
-            res = False
-        else:
+            return True  # nothing on disk to remove; the db rows must still be purged
+
+        if not self.fullpath.is_dir():
+            logging.debug(f"Path for file id '{self.media_id}' is not a directory or does not exist.")
+            return True  # already gone from disk; let delete() clean up the stale db rows
+
+        try:
             for file in self.fullpath.glob("*"):
                 file.unlink()
             self.fullpath.rmdir()
             logging.debug(f"Deleted directory {self.fullpath}")
+            return True
+        except OSError as err:
+            logging.error(f"Failed to delete files for {self.media_id}: {err}")
+            return False
 
+    def _delete_db_entries(self) -> bool:
+        """Remove file entries from database.
+
+        Returns:
+            True if database entries were deleted successfully
+        """
         with self.repo.conn.cursor() as cur:
             cur.execute("DELETE from mediaapi_thumbnail WHERE media_id=%s;", (self.media_id,))
             num_thumbnails = cur.rowcount
@@ -77,243 +116,299 @@ class File:
             num_media = cur.rowcount
         self.repo.conn.commit()
         logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}")
-        return res
+        return True
 
     def exists(self) -> bool:
-        """Return True if the media file exists on the file system."""
+        """Check if the media file exists on the filesystem.
+
+        Returns:
+            True if file exists, False otherwise
+        """
         if self.fullpath is None:
             return False
         return (self.fullpath / "file").exists()
 
     def has_thumbnail(self) -> int:
-        """Return the number of thumbnails associated with this file."""
+        """Count thumbnails associated with this file.
+
+        Returns:
+            Number of thumbnails
+        """
         with self.repo.conn.cursor() as cur:
-            cur.execute(f"select COUNT(media_id) from mediaapi_thumbnail WHERE media_id='{self.media_id}';")
+            cur.execute("SELECT COUNT(media_id) FROM mediaapi_thumbnail WHERE media_id = %s;", (self.media_id,))
             row = cur.fetchone()
-            if row is None:
-                return 0
-            return int(row[0])
+            return int(row[0]) if row else 0
 
 
 class MediaRepository:
-    """Handle a dendrite media repository."""
+    """Handle media storage and retrieval for a Dendrite server."""
 
-    def __init__(self, media_path: Path, connection_string: str):
-        """Initialize a MediaRepository object."""
+    def __init__(self, media_path: Path, connection_string: str) -> None:
+        """Initialize MediaRepository.
+
+        Args:
+            media_path: Path to media storage directory
+            connection_string: PostgreSQL connection string
+
+        Raises:
+            ValueError: If media_path doesn't exist or connection string is invalid
+        """
+        self._validate_media_path(media_path)
         self.media_path = media_path
-        if not self.media_path.is_absolute():
-            logging.warn("The media path is relative, make sure you run this script in the correct directory!")
-        if not self.media_path.is_dir():
-            raise Exception("The configured media dir cannot be found!")
-        self._avatar_media_ids: List[str] = []
-
+        self._avatar_media_ids: List[MediaID] = []
         self.db_conn_string = connection_string
         self.conn = self.connect_db()
 
-    def connect_db(self) -> psycopg2.extensions.connection:
-        """Return a connection to the database."""
-        if self.db_conn_string is None or not self.db_conn_string.startswith(("postgres://", "postgresql://")):
-            errstr = "DB connection not a postgres one"
-            logging.error(errstr)
-            raise ValueError(errstr)
+    @staticmethod
+    def _validate_media_path(path: Path) -> None:
+        if not path.is_absolute():
+            logging.warning("Media path is relative. Ensure correct working directory!")
+        if not path.is_dir():
+            raise ValueError("Media directory not found")
+
+    def connect_db(self) -> DBConnection:
+        """Establish database connection.
+
+        Returns:
+            PostgreSQL connection object
+
+        Raises:
+            ValueError: If connection string is invalid
+        """
+        if not self.db_conn_string or not self.db_conn_string.startswith(("postgres://", "postgresql://")):
+            raise ValueError("Invalid PostgreSQL connection string")
         return psycopg2.connect(self.db_conn_string)
 
-    def get_single_media(self, mxid: str) -> Optional[File]:
-        """Return a File object or None for given media ID."""
-        with self.conn.cursor() as cur:
-            sql_str = "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;"
-            cur.execute(sql_str, (mxid,))
-            row = cur.fetchone()
-            if row is None:
-                return None
-            return File(self, row[0], row[1] // 1000, row[2])
-
-    def get_local_user_media(self, user_id: str) -> List[File]:
-        """Return all media created by a local user.
-
-        :params:
-            :user_id: (`str`) of form "@user:servername.com"
-        :returns: `List[File]`
-        """
-        with self.conn.cursor() as cur:
-            sql_str = "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE user_id = %s;"
-            cur.execute(sql_str, (user_id,))
-            files = []
-            for row in cur.fetchall():
-                f = File(self, row[0], row[1] // 1000, row[2])
-                files.append(f)
-        return files
-
-    def get_all_media(self, local: bool = False) -> List[File]:
-        """Return a list of remote media or ALL media if local==True."""
-        with self.conn.cursor() as cur:
-            sql_str = "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository"
-            if not local:
-                sql_str += " WHERE user_id = ''"
-            sql_str += ";"
-            cur.execute(sql_str)
-            files = []
-            for row in cur.fetchall():
-                f = File(self, row[0], row[1] // 1000, row[2])
-                files.append(f)
-        return files
-
-    def get_avatar_images(self) -> List[str]:
-        """Return a list of media_id which are current avatar images.
-
-        We don't want to clean up those. Save & cache them internally.
-        """
-        media_id = []
-        with self.conn.cursor() as cur:
-            cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';")
-            for row in cur.fetchall():
-                url = row[0]
-                try:
-                    media_id.append(url[url.rindex("/") + 1 :])
-                except ValueError:
-                    logging.warn("No slash in URL '%s'!", url)
-            self._avatar_media_ids = media_id
-        return self._avatar_media_ids
-
-    def sanity_check_thumbnails(self) -> None:
-        """Check for thumbnails in db that don't refer to existing media."""
+    def get_single_media(self, mxid: MediaID) -> File | None:
+        """Retrieve a single media file by ID."""
         with self.conn.cursor() as cur:
             cur.execute(
-                "SELECT COUNT(media_id) from mediaapi_thumbnail "
-                "WHERE NOT EXISTS (SELECT media_id FROM mediaapi_media_repository);",
+                "SELECT media_id, creation_ts, base64hash from mediaapi_media_repository WHERE media_id = %s;",
+                (mxid,),
             )
             row = cur.fetchone()
-            if row is not None and row[0]:
+            return File(self, row[0], row[1] // 1000, row[2]) if row else None
+
+    def get_local_user_media(self, user_id: UserID) -> List[File]:
+        """Get all media files created by a local user.
+
+        Args:
+            user_id: User ID in format "@user:servername.com"
+
+        Returns:
+            List of File objects
+        """
+        with self.conn.cursor() as cur:
+            cur.execute(
+                """SELECT media_id, creation_ts, base64hash
+                FROM mediaapi_media_repository
+                WHERE user_id = %s;""",
+                (user_id,),
+            )
+            return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
+
+    def get_all_media(self, local: bool = False) -> List[File]:
+        """Get all media files or only remote ones.
+
+        Args:
+            local: If True, include local media files
+
+        Returns:
+            List of File objects
+        """
+        with self.conn.cursor() as cur:
+            query = """SELECT media_id, creation_ts, base64hash
+                      FROM mediaapi_media_repository"""
+            if not local:
+                query += " WHERE user_id = ''"
+            cur.execute(query)
+            return [File(self, row[0], row[1] // 1000, row[2]) for row in cur.fetchall()]
+
+    def get_avatar_images(self) -> List[MediaID]:
+        """Get media IDs of current avatar images.
+
+        Returns:
+            List of media IDs
+        """
+        with self.conn.cursor() as cur:
+            cur.execute("SELECT avatar_url FROM userapi_profiles WHERE avatar_url > '';")
+            media_ids = []
+            for (url,) in cur.fetchall():
+                try:
+                    media_ids.append(url[url.rindex("/") + 1 :])
+                except ValueError:
+                    logging.warning("Invalid avatar URL: %s", url)
+            self._avatar_media_ids = media_ids
+            return media_ids
+
+    def sanity_check_thumbnails(self) -> None:
+        """Check for orphaned thumbnail entries in database."""
+        with self.conn.cursor() as cur:
+            cur.execute(
+                """SELECT COUNT(media_id) FROM mediaapi_thumbnail AS t
+                WHERE NOT EXISTS (SELECT 1 FROM mediaapi_media_repository AS m WHERE m.media_id = t.media_id);""",
+            )
+            if (row := cur.fetchone()) and (count := row[0]):
                 logging.error(
-                    "You have {} thumbnails in your db that do not refer to media. "
-                    "This needs fixing (we don't do that)!".format(row[0]),
+                    "You have %d thumbnails in your db that do not refer to media. "
+                    "This needs fixing (we don't do that)!",
+                    count,
                 )
 
     def clean_media_files(self, days: int, local: bool = False, dryrun: bool = False) -> int:
-        """Remove old media files from this repository.
+        """Remove old media files.
 
-        :params:
-            :days: (int) delete media files older than N days.
-            :local: (bool) Also delete media originating from local users
-            :dryrun: (bool) Do not actually delete any files (just count)
-        :returns: (int) The number of files that were/would be deleted
+        Args:
+            days: Delete files older than this many days
+            local: If True, include local media files
+            dryrun: If True, only simulate deletion
+
+        Returns:
+            Number of files deleted (or that would be deleted in dryrun mode)
         """
         if local:
             self.get_avatar_images()
 
-        cleantime = datetime.today() - timedelta(days=days)
-        logging.info("Deleting remote media older than %s", cleantime)
-        num_deleted = 0
-        files = self.get_all_media(local)
-        for file in [f for f in files if f.media_id not in self._avatar_media_ids]:
-            if file.create_date < cleantime:
-                num_deleted += 1
-                if dryrun:
-                    logging.info(f"Pretending to delete file id {file.media_id} on path {file.fullpath}.")
-                    if not file.exists():
-                        logging.info(f"File id {file.media_id} does not physically exist (path {file.fullpath}).")
-                else:
-                    file.delete()
-        info_str = "Deleted %d files during the run."
-        if dryrun:
-            info_str = "%d files would have been deleted during the run."
-        logging.info(info_str, num_deleted)
+        cutoff_date = datetime.today() - timedelta(days=days)
+        logging.info("Deleting remote media older than %s", cutoff_date)
 
-        return num_deleted
+        files_to_delete = [
+            f
+            for f in self.get_all_media(local)
+            if f.media_id not in self._avatar_media_ids and f.create_date < cutoff_date
+        ]
+
+        for file in files_to_delete:
+            if dryrun:
+                logging.info(f"Would delete file {file.media_id} at {file.fullpath}")
+                if not file.exists():
+                    logging.info(f"File {file.media_id} doesn't exist at {file.fullpath}")
+            else:
+                file.delete()
+
+        action = "Would have deleted" if dryrun else "Deleted"
+        logging.info("%s %d files", action, len(files_to_delete))
+        return len(files_to_delete)
 
 
 def read_config(conf_file: Union[str, Path]) -> Tuple[Path, str]:
-    """Return db credentials and media path from dendrite config file."""
+    """Read database credentials and media path from config.
+
+    Args:
+        conf_file: Path to Dendrite YAML config file
+
+    Returns:
+        Tuple of (media_path, connection_string)
+
+    Raises:
+        SystemExit: If config file is invalid or missing required fields
+    """
     try:
         with open(conf_file) as f:
             config = yaml.safe_load(f)
     except FileNotFoundError:
-        errstr = f"Config file {conf_file} not found. Use the --help option to find out more."
-        logging.error(errstr)
+        logging.error("Config file %s not found. Use --help for usage.", conf_file)
         sys.exit(1)
 
     if "media_api" not in config:
-        logging.error("Missing section media_api")
+        logging.error("Missing media_api section in config")
         sys.exit(1)
 
-    CONN_STR = None
+    conn_string = None
     if "global" in config and "database" in config["global"]:
-        CONN_STR = config["global"]["database"].get("connection_string", None)
+        conn_string = config["global"]["database"].get("connection_string")
     elif "database" in config["media_api"]:
-        logging.debug("No database section in global, but one in media_api, using that")
-        CONN_STR = config["media_api"]["database"].get("connection_string", None)
+        logging.debug("Using database config from media_api section")
+        conn_string = config["media_api"]["database"].get("connection_string")
 
-    if CONN_STR is None:
-        logging.error("Did not find connection string to media database.")
+    if not conn_string:
+        logging.error("Database connection string not found in config")
         sys.exit(1)
 
-    BASE_PATH = Path(config["media_api"].get("base_path", None))
-
-    if BASE_PATH is None:
-        logging.error("Missing base_path in media_api")
+    base_path = config["media_api"].get("base_path")
+    if not base_path:
+        logging.error("base_path not found in media_api config")
         sys.exit(1)
-    return (BASE_PATH, CONN_STR)
+
+    return Path(base_path), conn_string
 
 
 def parse_options() -> argparse.Namespace:
-    """Return parsed command line options."""
-    loglevel = logging.INFO
-    parser = argparse.ArgumentParser(
-        prog="cleanmedia",
-        description="Deletes 30 day old remote media files from dendrite servers",
-    )
-    parser.add_argument("-c", "--config", default="config.yaml", help="location of the dendrite.yaml config file.")
-    parser.add_argument("-m", "--mxid", dest="mxid", help="Just delete media . (no cleanup otherwise)")
-    parser.add_argument(
-        "-u",
-        "--userid",
-        dest="userid",
-        help=(
-            "Delete all media by local user '\\@user:domain.com'. "
-            "(ie, a user on hour homeserver. no cleanup otherwise)"
-        ),
-    )
-    parser.add_argument("-t", "--days", dest="days", default="30", type=int, help="Keep remote media for days.")
-    parser.add_argument("-l", "--local", action="store_true", help="Also purge local (ie, from *our* users) media.")
-    parser.add_argument("-n", "--dryrun", action="store_true", help="Dry run (don't actually modify any files).")
-    parser.add_argument("-q", "--quiet", action="store_true", help="Reduce output verbosity.")
-    parser.add_argument("-d", "--debug", action="store_true", help="Increase output verbosity.")
-    args: argparse.Namespace = parser.parse_args()
+    """Parse command line arguments.
+
+    Returns:
+        Parsed argument namespace
+    """
+    parser = argparse.ArgumentParser(prog="cleanmedia", description="Delete old media files from Dendrite servers")
+    parser.add_argument("-c", "--config", default="config.yaml", help="Path to dendrite.yaml config file")
+    parser.add_argument("-m", "--mxid", help="Delete specific media ID")
+    parser.add_argument("-u", "--userid", help="Delete all media from local user '@user:domain.com'")
+    parser.add_argument("-t", "--days", type=int, default=30, help="Keep remote media for DAYS days")
+    parser.add_argument("-l", "--local", action="store_true", help="Include local user media in cleanup")
+    parser.add_argument("-n", "--dryrun", action="store_true", help="Simulate cleanup without modifying files")
+    parser.add_argument("-q", "--quiet", action="store_true", help="Reduce output verbosity")
+    parser.add_argument("-d", "--debug", action="store_true", help="Increase output verbosity")
+
+    args = parser.parse_args()
+
+    log_level = logging.INFO
     if args.debug:
-        loglevel = logging.DEBUG
+        log_level = logging.DEBUG
     elif args.quiet:
-        loglevel = logging.WARNING
-    logging.basicConfig(level=loglevel, format="%(levelname)s - %(message)s")
+        log_level = logging.WARNING
+    logging.basicConfig(level=log_level, format="%(levelname)s - %(message)s")
+
     return args
 
 
-if __name__ == "__main__":
+def main() -> None:
+    """Execute the media cleanup process."""
     args = parse_options()
-    (MEDIA_PATH, CONN_STR) = read_config(args.config)
-    mr = MediaRepository(MEDIA_PATH, CONN_STR)
+    media_path, conn_string = read_config(args.config)
+    repo = MediaRepository(media_path, conn_string)
 
     if args.mxid:
-        logging.info("Attempting to delete media '%s'", args.mxid)
-        file = mr.get_single_media(args.mxid)
-        if file:
-            logging.info("Found media with id '%s'", args.mxid)
-            if not args.dryrun:
-                file.delete()
+        process_single_media(repo, args)
     elif args.userid:
-        logging.info("Attempting to delete media by user '%s'", args.userid)
-        files = mr.get_local_user_media(args.userid)
-        num_deleted = 0
-        for file in files:
-            num_deleted += 1
-            if args.dryrun:
-                logging.info(f"Pretending to delete file id {file.media_id} on path {file.fullpath}.")
-            else:
-                file.delete()
-        info_str = "Deleted %d files during the run."
-        if args.dryrun:
-            info_str = "%d files would have been deleted during the run."
-        logging.info(info_str, num_deleted)
-
+        process_user_media(repo, args)
     else:
-        mr.sanity_check_thumbnails()
-        mr.clean_media_files(args.days, args.local, args.dryrun)
+        repo.sanity_check_thumbnails()
+        repo.clean_media_files(args.days, args.local, args.dryrun)
+
+
+def process_single_media(repo: MediaRepository, args: argparse.Namespace) -> None:
+    """Handle deletion of a single media file.
+
+    Args:
+        repo: MediaRepository instance
+        args: Parsed command line arguments
+    """
+    logging.info("Attempting to delete media '%s'", args.mxid)
+    if file := repo.get_single_media(args.mxid):
+        logging.info("Found media with id '%s'", args.mxid)
+        if not args.dryrun:
+            file.delete()
+
+
+def process_user_media(repo: MediaRepository, args: argparse.Namespace) -> None:
+    """Handle deletion of all media from a user.
+
+    Args:
+        repo: MediaRepository instance
+        args: Parsed command line arguments
+    """
+    logging.info("Attempting to delete media by user '%s'", args.userid)
+    files = repo.get_local_user_media(args.userid)
+
+    for file in files:
+        if args.dryrun:
+            logging.info("Would delete file %s at %s", file.media_id, file.fullpath)
+        else:
+            file.delete()
+
+    action = "Would delete" if args.dryrun else "Deleted"
+    logging.info("%s %d files", action, len(files))
+
+
+if __name__ == "__main__":
+    main()