"""FastAPI routes for FLUX image generation and Mistral chat, with
MongoDB-backed per-user token accounting.

NOTE(review): blocking calls (``requests.post``, synchronous ``pymongo``,
``InferenceClient``) are made inside ``async def`` handlers and will block
the event loop under load.  Fixing that needs ``httpx``/``motor`` (new
dependencies), so it is flagged here rather than changed.
"""

import base64
import io
import os
import uuid

import requests
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from pymongo import MongoClient
from RyuzakiLib import GeminiLatest

from models import *


class FluxAI(BaseModel):
    """Request payload for the FLUX image-generation endpoint."""

    user_id: int
    api_key: str
    args: str
    auto_enhancer: bool = False


class MistralAI(BaseModel):
    """Request payload for the Mistral chat endpoint."""

    args: str


router = APIRouter()
load_dotenv()

# Required configuration — a missing variable fails fast at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]

client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]


async def schellwithflux(args):
    """Generate an image with FLUX.1-schnell via the HF inference API.

    :param args: text prompt forwarded as the model input.
    :returns: raw image bytes on success, or ``None`` on a non-200 response.
    """
    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code != 200:
        print(f"Error status {response.status_code}")
        return None
    return response.content


async def mistralai_post_message(message_str):
    """Stream a chat completion from Mixtral-8x7B and return the full text.

    :param message_str: the user message to send.
    :returns: the concatenated streamed response.
    """
    client = InferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        token=HUGGING_TOKEN,
    )
    output = ""
    for message in client.chat_completion(
        messages=[{"role": "user", "content": message_str}],
        max_tokens=500,
        stream=True,
    ):
        # Fix: the final streamed chunk's delta.content can be None,
        # which would make `output += None` raise TypeError.
        output += message.choices[0].delta.content or ""
    return output


def get_user_tokens_gpt(user_id):
    """Return the user's current token balance (0 if the user is unknown)."""
    user = collection.find_one({"user_id": user_id})
    if not user:
        return 0
    return user.get("tokens", 0)


def deduct_tokens_gpt(user_id, amount):
    """Atomically deduct ``amount`` tokens from the user.

    Fix vs. original: the read-then-update sequence was racy (two
    concurrent requests could both pass the balance check).  A single
    filtered ``$inc`` makes check-and-deduct atomic on the server.

    :returns: True if the deduction happened, False on insufficient funds.
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    return result.modified_count > 0


@router.get("/akeno/gettoken")
async def get_token_with_flux(user_id: int):
    """Report the user's token balance (a zero balance reports as 'False')."""
    tokens = get_user_tokens_gpt(user_id)
    if tokens:
        return SuccessResponse(
            status="True",
            randydev={"tokens": f"Current tokens: {tokens}."},
        )
    else:
        return SuccessResponse(
            status="False",
            randydev={"tokens": f"Not enough tokens. Current tokens: {tokens}."},
        )


@router.post("/akeno/mistralai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def mistralai_(payload: MistralAI):
    """Chat endpoint: forward the prompt to Mixtral and wrap the reply."""
    try:
        response = await mistralai_post_message(payload.args)
        return SuccessResponse(
            status="True",
            randydev={"message": response},
        )
    except Exception as e:
        # Errors are reported in-band; the route always returns 200.
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"},
        )


def get_all_api_keys():
    """Collect every user's ``ryuzaki_api_key`` (skipping users without one)."""
    return [
        doc["ryuzaki_api_key"]
        for doc in collection.find({})
        if doc.get("ryuzaki_api_key")
    ]


@router.post("/akeno/fluxai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def fluxai_image(payload: FluxAI):
    """Generate an image from a prompt; optionally enhance and caption it.

    Fixes vs. original:
    - The API key is validated *before* tokens are deducted, so an invalid
      key no longer silently costs the user 20 tokens.
    - The enhanced image is written to a unique temp file (the fixed name
      "akeno.jpg" raced under concurrent requests) and removed afterwards.
    """
    # Validate the key first — deducting before this check lost tokens
    # on every request with a bad key.
    if payload.api_key not in get_all_api_keys():
        return SuccessResponse(
            status="False",
            randydev={"error": "Error required api_key"},
        )
    if not deduct_tokens_gpt(payload.user_id, amount=20):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens} and required api_key."},
        )
    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"},
            )
        if payload.auto_enhancer:
            with Image.open(io.BytesIO(image_bytes)) as image:
                image = ImageEnhance.Sharpness(image).enhance(1.5)
                image = ImageEnhance.Contrast(image).enhance(1.2)
                image = ImageEnhance.Color(image).enhance(1.1)
                # Unique name per request; a shared fixed filename would
                # let concurrent requests clobber each other's output.
                tmp_path = f"akeno_{uuid.uuid4().hex}.jpg"
                image.save(tmp_path, format="JPEG", quality=95)
            try:
                with open(tmp_path, "rb") as image_file:
                    encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
                example_test = "Accurately identify the baked good in the image and provide an appropriate and recipe consistent with your analysis."
                # NOTE(review): get_response_image appears to take a file
                # path (the original passed the saved filename) — confirm
                # against RyuzakiLib before changing to in-memory bytes.
                x = GeminiLatest(api_keys=GOOGLE_API_KEY)
                response = x.get_response_image(example_test, tmp_path)
            finally:
                # Best-effort cleanup of the temp file.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
            return SuccessResponse(
                status="True",
                randydev={"image_data": encoded_string, "caption": response},
            )
        else:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"},
        )