import asyncio
import functools
import io
import logging
import os

import requests
from dotenv import load_dotenv
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pymongo import MongoClient

from models import *

class FluxAI(BaseModel):
    """Request body for the ``/akeno/fluxai`` image-generation endpoint."""

    # Account whose token balance is charged for the generation.
    user_id: int
    # Text prompt forwarded to the FLUX.1-schnell model as the ``inputs`` field.
    args: str

# Router holding the FluxAI endpoint; presumably included into the app elsewhere.
router = APIRouter()

# Populate the process environment from a .env file (if any) before reading
# the required settings below.
load_dotenv()
# Both settings are mandatory: a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]

# Shared Mongo handles; per-user token balances are read from and written to
# the "users" collection of the "tiktokbot" database, keyed by "user_id".
client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]

async def schellwithflux(args, timeout=120):
    """Generate an image for the prompt *args* with FLUX.1-schnell on Hugging Face.

    The HTTP call uses the blocking ``requests`` library. It is therefore run
    in the event loop's default thread-pool executor. Calling it directly
    inside this coroutine would freeze every other request served by the
    process until the inference finished.

    Args:
        args: Text prompt, sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the inference API (connect and read)
            before giving up. Defaults to 120.

    Returns:
        The raw response body (which the endpoint serves as the image) on
        HTTP 200, or ``None`` for any other status code.

    Raises:
        requests.RequestException: On network failure or timeout.
    """
    api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}

    # Bind the arguments now; run_in_executor only forwards positional args.
    post = functools.partial(
        requests.post, api_url, headers=headers, json=payload, timeout=timeout
    )
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, post)

    if response.status_code != 200:
        # Keep a short snippet of the body: HF error replies explain the cause
        # (e.g. model loading, rate limiting).
        logging.getLogger(__name__).error(
            "FLUX inference failed with status %s: %s",
            response.status_code,
            response.text[:200],
        )
        return None
    return response.content

def get_user_tokens_gpt(user_id):
    """Return the token balance stored for *user_id*.

    A user with no document, or a document lacking a ``tokens`` field, is
    reported as having a balance of 0.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0

def deduct_tokens_gpt(user_id, amount):
    """Atomically charge *amount* tokens to *user_id* if the balance covers it.

    The balance check and the decrement happen in a single ``update_one``
    whose filter requires ``tokens >= amount``. Two concurrent requests
    therefore cannot both pass the check and push the balance below zero.

    Args:
        user_id: Value matched against the ``user_id`` field of ``users``.
        amount: Number of tokens to deduct; expected to be positive.

    Returns:
        True if a user with a sufficient balance was found and charged.
        False otherwise: unknown user, missing ``tokens`` field, or
        insufficient balance.
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    # matched_count (not modified_count) so a matching document always
    # counts as charged, even if the update left it unchanged.
    return result.matched_count > 0

def _refund_tokens(user_id, amount):
    """Credit *amount* tokens back to *user_id* after a failed generation."""
    collection.update_one({"user_id": user_id}, {"$inc": {"tokens": amount}})


@router.post("/akeno/fluxai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def fluxai_image(payload: FluxAI):
    """Generate a FLUX image for ``payload.args``, charging ``payload.user_id``.

    Twenty tokens are deducted before generation. If generation fails (a
    non-200 reply from the inference API, or an exception), the tokens are
    credited back. Users are thus only charged for images actually returned.

    Returns:
        On success, a ``StreamingResponse`` carrying the image bytes as
        ``image/jpeg``. Otherwise, a ``SuccessResponse`` with
        ``status="False"`` and an ``error`` message under ``randydev``.
    """
    cost = 20
    if not deduct_tokens_gpt(payload.user_id, amount=cost):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"}
        )

    try:
        image_bytes = await schellwithflux(payload.args)
    except Exception as e:
        # Endpoint boundary: report the failure to the client, but keep the
        # traceback in the logs and give the user their tokens back.
        logging.getLogger(__name__).exception(
            "FLUX generation failed for user %s", payload.user_id
        )
        _refund_tokens(payload.user_id, cost)
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"}
        )

    if image_bytes is None:
        _refund_tokens(payload.user_id, cost)
        return SuccessResponse(
            status="False",
            randydev={"error": "Failed to generate an image"}
        )

    return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")