diff --git a/db/models.py b/db/models.py index 4ae414a..f6469db 100644 --- a/db/models.py +++ b/db/models.py @@ -18,6 +18,7 @@ Functions: """ import secrets import datetime +import random from sqlalchemy import create_engine, text from fastapi import HTTPException @@ -31,6 +32,7 @@ DATABASE_URL = "sqlite:///" + str(DATABASE_FILE) engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) DRINK_COST = 100 # cent +AVAILABLE_DRINKS = ["Paulaner Spezi", "Mio Mate", "Club Mate", "Eistee Pfirsisch"] with engine.connect() as conn: # Create a table for postpaid users @@ -465,7 +467,8 @@ def get_last_drink(user_id: int, user_is_postpaid: bool, max_since_seconds: int last_drink_time = last_drink_time.replace(tzinfo=datetime.timezone.utc) if (now - last_drink_time).total_seconds() > max_since_seconds: return None - return {"id": drink_id, "timestamp": timestamp, "drink_type": drink_type} + drink_obj = {"id": drink_id, "timestamp": timestamp, "drink_type": drink_type} + return drink_obj def revert_last_drink(user_id: int, user_is_postpaid: bool, drink_id: int, drink_cost: int = DRINK_COST): if user_is_postpaid: @@ -513,3 +516,23 @@ def update_drink_type(user_id: int, user_is_postpaid: bool, drink_id, drink_type raise HTTPException(status_code=404, detail="Drink not found") connection.commit() return result.rowcount + +def get_most_used_drinks(user_id: int, user_is_postpaid: bool, limit: int = 4): + if user_is_postpaid: + t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE postpaid_user_id = :user_id AND drink_type IS NOT NULL AND drink_type != 'Sonstiges' GROUP BY drink_type ORDER BY count DESC LIMIT :limit") + else: + t = text("SELECT drink_type, count(drink_type) as count FROM drinks WHERE prepaid_user_id = :user_id AND drink_type IS NOT NULL AND drink_type != 'Sonstiges' GROUP BY drink_type ORDER BY count DESC LIMIT :limit") + + with engine.connect() as connection: + result = connection.execute(t, {"user_id": user_id, "limit": limit}).fetchall() + if not result: + return [] + drinks = [{"drink_type": row[0], "count": row[1]} for row in result] + + while len(drinks) < limit: + random_drink = random.choice(AVAILABLE_DRINKS) + if any(drink["drink_type"] == random_drink for drink in drinks): + continue + drinks.append({"drink_type": random_drink, "count": 0}) + + return drinks diff --git a/main.py b/main.py index 7c3b04c..6e7bec8 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ from db.models import del_user_prepaid from db.models import get_last_drink from db.models import revert_last_drink from db.models import update_drink_type +from db.models import get_most_used_drinks from auth import oidc import os @@ -107,6 +108,9 @@ def home(request: Request): # get last drink for current user, if not less than 60 seconds ago last_drink = get_last_drink(user_db_id, user_is_postpaid, 60) + most_used_drinks = get_most_used_drinks(user_db_id, user_is_postpaid, 3) + most_used_drinks.append({"drink_type": "Sonstiges", "count": 0}) # Ensure "Sonstiges" is always included + return templates.TemplateResponse("index.html", { "request": request, "user": user_authentik, @@ -116,7 +120,7 @@ def home(request: Request): "db_users_prepaid": db_users_prepaid, "prepaid_users_from_curr_user": prepaid_users_from_curr_user, "last_drink": last_drink, - "avail_drink_types": ["Paulaner Spezi", "Mio Mate", "Club Mate", "Sonstiges"], + "avail_drink_types": most_used_drinks, }) @app.get("/login", response_class=HTMLResponse) diff --git a/static/drinks/eistee_pfirsisch.png b/static/drinks/eistee_pfirsisch.png new file mode 100644 index 0000000..d415518 Binary files /dev/null and b/static/drinks/eistee_pfirsisch.png differ diff --git a/static/drinks/sonstiges.png b/static/drinks/sonstiges.png new file mode 100644 index 0000000..921ff89 Binary files /dev/null and b/static/drinks/sonstiges.png differ diff --git a/templates/index.html b/templates/index.html index b4eff68..fb25771 100644 --- a/templates/index.html +++ b/templates/index.html @@ -78,11 +78,14 @@ content %}