from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from PIL import Image, ImageEnhance
from fastapi import HTTPException
import io
import requests
import os
import base64
from dotenv import load_dotenv
from pydantic import BaseModel
from pymongo import MongoClient
from models import *
from huggingface_hub import InferenceClient
from fastapi import UploadFile
from fastapi.responses import JSONResponse
import uuid
from RyuzakiLib import GeminiLatest

class FluxAI(BaseModel):
    # Request body accepted by the /akeno/fluxai endpoint.
    # Documented with comments rather than a docstring so the generated
    # schema for this model is left untouched.
    user_id: int  # user whose token balance is charged for the request
    api_key: str  # must equal one of the stored "ryuzaki_api_key" values
    args: str  # prompt text sent as "inputs" to the FLUX image model
    auto_enhancer: bool = False  # when true, the image is enhanced and captioned

class MistralAI(BaseModel):
    # Request body accepted by the /akeno/mistralai endpoint.
    args: str  # text sent to the model as the single user message

# Router on which the endpoints defined in this module are registered.
router = APIRouter()

# Load variables from a .env file (if present) before reading the environment.
load_dotenv()
# Required settings: a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]

# Module-wide MongoDB handles; `collection` is the "users" collection of the
# "tiktokbot" database and is shared by every helper below.
client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]

async def schellwithflux(args):
    """Generate an image with the hosted FLUX.1-schnell model.

    Args:
        args: Prompt text, sent as the ``inputs`` field of the JSON payload.

    Returns:
        The raw response body (image bytes) when the API answers with HTTP
        200, otherwise ``None``.

    Raises:
        requests.RequestException: On connection problems or when the
            request exceeds the timeout.
    """
    import asyncio
    import functools

    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    # `requests` is blocking: run it in the default executor so the event loop
    # keeps serving other requests, and bound the wait with a timeout so a
    # stalled upstream cannot hang this coroutine forever.
    send_request = functools.partial(
        requests.post, API_URL, headers=headers, json=payload, timeout=120
    )
    response = await asyncio.get_running_loop().run_in_executor(None, send_request)
    if response.status_code != 200:
        print(f"Error status {response.status_code}")
        return None
    return response.content

async def mistralai_post_message(message_str):
    """Send one user message to Mixtral-8x7B-Instruct and return the full reply.

    The completion is requested in streaming mode (at most 500 tokens) and the
    streamed fragments are concatenated into a single string.

    Args:
        message_str: Content of the single user message.

    Returns:
        The concatenated text of all streamed chunks.
    """
    client = InferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        token=HUGGING_TOKEN
    )
    stream = client.chat_completion(
        messages=[{"role": "user", "content": message_str}],
        max_tokens=500,
        stream=True
    )
    # A streamed delta may carry no text (content is None); treat that as an
    # empty fragment instead of failing on `str + None`.
    return "".join(chunk.choices[0].delta.content or "" for chunk in stream)

def get_user_tokens_gpt(user_id):
    """Return the stored "tokens" value for ``user_id``.

    Yields 0 when no matching user document is found or when the document
    has no "tokens" field.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0

def deduct_tokens_gpt(user_id, amount):
    """Atomically subtract ``amount`` tokens from a user's balance.

    The balance check and the decrement are expressed as one conditional
    update, so two concurrent requests cannot both pass the check and push
    the balance below zero.

    Args:
        user_id: Value of the "user_id" field identifying the user document.
        amount: Number of tokens to subtract.

    Returns:
        True when a document with at least ``amount`` tokens was found and
        updated; False otherwise (unknown user, missing "tokens" field, or
        insufficient balance).
    """
    result = collection.update_one(
        # Only match when the balance covers the charge.
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    return result.matched_count > 0


@router.get("/akeno/gettoken")
async def get_token_with_flux(user_id: int):
    # Report the caller's token balance.
    # A falsy balance (e.g. 0, which is also what an unknown user yields) is
    # answered with status "False"; anything else with status "True".
    balance = get_user_tokens_gpt(user_id)
    if not balance:
        return SuccessResponse(
            status="False",
            randydev={"tokens": f"Not enough tokens. Current tokens: {balance}."},
        )
    return SuccessResponse(
        status="True",
        randydev={"tokens": f"Current tokens: {balance}."},
    )

@router.post("/akeno/mistralai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def mistralai_(payload: MistralAI):
    # Forward the request text to the Mixtral helper and wrap the reply.
    # Any exception is reported in the response body with status "False"
    # instead of propagating to the framework.
    try:
        reply = await mistralai_post_message(payload.args)
        return SuccessResponse(status="True", randydev={"message": reply})
    except Exception as exc:
        detail = str(exc)
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {detail}"},
        )

def get_all_api_keys():
    """Collect every truthy "ryuzaki_api_key" found in the users collection.

    Documents without the field, or with an empty value, are skipped.
    """
    stored_keys = (doc.get("ryuzaki_api_key") for doc in collection.find({}))
    return [key for key in stored_keys if key]

def _enhance_image_to_jpeg(image_bytes):
    """Sharpen an image, boost its contrast and color, and return JPEG bytes."""
    with Image.open(io.BytesIO(image_bytes)) as source:
        enhanced = ImageEnhance.Sharpness(source).enhance(1.5)
        enhanced = ImageEnhance.Contrast(enhanced).enhance(1.2)
        enhanced = ImageEnhance.Color(enhanced).enhance(1.1)
        buffer = io.BytesIO()
        enhanced.save(buffer, format="JPEG", quality=95)
    return buffer.getvalue()


@router.post("/akeno/fluxai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def fluxai_image(payload: FluxAI):
    # Generate an image from `payload.args`, charging 20 tokens per request.
    #
    # Flow:
    #   1. Reject the request when `payload.api_key` is not a stored key.
    #   2. Deduct 20 tokens; reject when the balance is insufficient.
    #   3. Generate the image. Without `auto_enhancer` the raw bytes are
    #      streamed back as image/jpeg. With it, the image is enhanced,
    #      captioned by Gemini, and returned base64-encoded in the JSON body.
    #
    # Every failure is reported as a SuccessResponse with status "False".
    import tempfile

    # Validate the key BEFORE charging, so an invalid key costs no tokens.
    if payload.api_key not in get_all_api_keys():
        return SuccessResponse(
            status="False",
            randydev={"error": "Error required api_key"}
        )

    if not deduct_tokens_gpt(payload.user_id, amount=20):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens} and required api_key."}
        )

    # NOTE(review): tokens are not refunded when generation fails below —
    # confirm whether that is intended.
    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"}
            )
        if not payload.auto_enhancer:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")

        enhanced_bytes = _enhance_image_to_jpeg(image_bytes)
        encoded_string = base64.b64encode(enhanced_bytes).decode('utf-8')

        # The captioning call is given a file path, so the image has to exist
        # on disk. A unique temporary file (instead of a fixed name in the
        # working directory) keeps concurrent requests from overwriting each
        # other's image.
        fd, image_path = tempfile.mkstemp(prefix="akeno-", suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as image_file:
                image_file.write(enhanced_bytes)
            example_test = "Accurately identify the baked good in the image and provide an appropriate and recipe consistent with your analysis."
            x = GeminiLatest(api_keys=GOOGLE_API_KEY)
            # Assumes the call finishes reading the file before it returns —
            # TODO confirm against the library.
            response = x.get_response_image(example_test, image_path)
        finally:
            # Best-effort cleanup; a failed delete must not mask the result.
            try:
                os.remove(image_path)
            except OSError:
                pass
        return SuccessResponse(
            status="True",
            randydev={"image_data": encoded_string, "caption": response}
        )
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"}
        )