mypy: Make the typing checks pass without error

This commit is contained in:
Sebastian Spaeth 2023-04-19 18:19:50 +02:00
parent 52484e7351
commit 94a643d636

View File

@ -15,11 +15,11 @@
# mediaapi_media_repository: media_id | media_origin | content_type | file_size_bytes | creation_ts | upload_name | base64hash | user_id
# mediaapi_thumbnail: media_id | media_origin | content_type | file_size_bytes | creation_ts | width | height | resize_method
import argparse, logging
from datetime import datetime, timedelta
from functools import cached_property
from pathlib import Path
import argparse, logging, typing
from typing import Optional, Union, List, Tuple
try:
import psycopg2
@ -42,12 +42,12 @@ class File:
self.base64hash = base64hash
@cached_property
def fullpath(self):
def fullpath(self) -> Optional[Path]:
"""returns the directory in which the "file" and all thumbnails are located, or None if no file is known"""
if not self.base64hash: return None
return self.repo.media_path / self.base64hash[0:1] / self.base64hash[1:2] / self.base64hash[2:]
def delete(self):
def delete(self) -> bool:
"""Delete db entries, and the file itself
:returns: True on successful delete of file,
@ -75,22 +75,23 @@ class File:
logging.debug(f"Deleted {num_media} + {num_thumbnails} db entries for media id {self.media_id}")
return res
def exists(self):
def exists(self) -> bool:
"""returns True if the media file itself exists on the file system"""
if self.fullpath is None:
return False
return (self.fullpath / 'file').exists()
def has_thumbnail(self):
def has_thumbnail(self) -> int:
"""Returns the number of thumbnails associated with this file"""
with self.repo.conn.cursor() as cur:
cur.execute(f"select COUNT(media_id) from mediaapi_thumbnail WHERE media_id='{self.media_id}';")
row = cur.fetchone()
return(row[0])
if row is None:
return 0
return(int(row[0]))
#----------------------------------------------------------------------
class MediaRepository:
def __init__(self, media_path: Path, connection_string: str):
# media_path is a pathlib.Path
self.media_path = media_path
@ -100,22 +101,21 @@ class MediaRepository:
raise Exception(f"The configured media dir cannot be found!")
# psql db connection
self.conn = None
self.db_conn_string = connection_string
self.connect_db();
self.conn = self.connect_db();
def connect_db(self):
def connect_db(self) -> psycopg2.connection:
#postgresql://user:pass@localhost/database?params
if self.db_conn_string is None or not self.db_conn_string.startswith(("postgres://","postgresql://")):
errstr = "DB connection not a postgres one"
logging.error(errstr)
raise ValueError(errstr)
self.conn = psycopg2.connect(self.db_conn_string)
return psycopg2.connect(self.db_conn_string)
def get_remote_media(self):
def get_remote_media(self) -> List[File]:
with self.conn.cursor() as cur:
# media_id | media_origin | content_type | file_size_bytes | creation_ts | upload_name | base64hash | user_id
res = cur.execute("select media_id, creation_ts, base64hash from mediaapi_media_repository WHERE user_id = '';")
cur.execute("select media_id, creation_ts, base64hash from mediaapi_media_repository WHERE user_id = '';")
files = []
for row in cur.fetchall():
# creation_ts is ms since the epoch, so convert to seconds
@ -123,7 +123,7 @@ class MediaRepository:
files.append(f)
return files
#--------------------------------------------------------------
def read_config(conf_file):
def read_config(conf_file: Union[str,Path]) -> Tuple[Path, str]:
"""Read in the dendrite config file and return db creds and media path"""
try:
with open(conf_file) as f:
@ -154,7 +154,7 @@ def read_config(conf_file):
exit(1)
return (BASE_PATH, CONN_STR)
def parse_options():
def parse_options() -> argparse.Namespace:
loglevel=logging.INFO # default
parser = argparse.ArgumentParser(
prog = 'cleanmedia',
@ -166,7 +166,7 @@ def parse_options():
parser.add_argument('-n', '--dryrun', action='store_true',
help="Dry run (don't actually modify any files).")
parser.add_argument('-d', '--debug', action='store_true', help="Turn debug output on.")
args = parser.parse_args()
args: argparse.Namespace = parser.parse_args()
if args.debug: loglevel=logging.DEBUG
logging.basicConfig(level=loglevel, format= '%(levelname)s - %(message)s')
return args