diff --git a/db/models.py b/db/models.py index 7376371..ec88e01 100644 --- a/db/models.py +++ b/db/models.py @@ -734,14 +734,12 @@ def get_most_used_drinks(user_id: int, user_is_postpaid: bool, limit: int = 4): list[dict]: Each dict has 'drink_type' and 'count'. """ if user_is_postpaid: - t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE postpaid_user_id = :user_id AND drink_type IS NOT NULL AND drink_type != 'Sonstiges' GROUP BY drink_type ORDER BY count DESC LIMIT :limit") + t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE postpaid_user_id = :user_id AND drink_type IS NOT NULL AND drink_type != 'Sonstiges' AND drink_type != 'None' GROUP BY drink_type ORDER BY count DESC LIMIT :limit") else: - t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE prepaid_user_id = :user_id AND drink_type IS NOT NULL AND drink_type != 'Sonstiges' GROUP BY drink_type ORDER BY count DESC LIMIT :limit") + t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE prepaid_user_id = :user_id AND drink_type IS NOT NULL AND drink_type != 'Sonstiges' AND drink_type != 'None' GROUP BY drink_type ORDER BY count DESC LIMIT :limit") with engine.connect() as connection: result = connection.execute(t, {"user_id": user_id, "limit": limit}).fetchall() - if not result: - return [] drinks = [{"drink_type": row[0], "count": row[1]} for row in result] while len(drinks) < limit: @@ -761,7 +759,7 @@ def get_stats_drink_types(): list[dict]: A list of dictionaries, each containing 'drink_type' (str) and 'count' (int). Returns an empty list if no results are found. """ - t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE drink_type IS NOT NULL GROUP BY drink_type ORDER BY count DESC") + t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE drink_type IS NOT NULL AND drink_type != 'None' GROUP BY drink_type ORDER BY count DESC") with engine.connect() as connection: result = connection.execute(t).fetchall() diff --git a/main.py b/main.py index 07952b8..31c9b11 100644 --- a/main.py +++ b/main.py @@ -497,6 +497,7 @@ def stats(request: Request): "stats_drink_types": drink_types, }) + def get_is_postpaid(user_authentik: dict) -> bool: """ Determine if a user is postpaid based on their authentication information.