"""FastAPI routes for Mixtral chat and FLUX image generation.

Image generation is paid for with per-user tokens stored in the MongoDB
``tiktokbot.users`` collection.
"""

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from PIL import Image, ImageEnhance
from fastapi import HTTPException
import io
import requests
import os
from dotenv import load_dotenv
from pydantic import BaseModel
from pymongo import MongoClient
from models import *
from huggingface_hub import InferenceClient
from fastapi import UploadFile
from fastapi.responses import JSONResponse
import uuid
from RyuzakiLib import GeminiLatest

# Added after the original imports so the ``from models import *`` ordering
# (and anything it might shadow) is unchanged.
import logging
from fastapi.concurrency import run_in_threadpool


class FluxAI(BaseModel):
    """Request body for ``/akeno/fluxai``."""

    user_id: int  # key used to look up / charge tokens in ``collection``
    args: str  # text prompt forwarded to FLUX.1-schnell
    auto_enhancer: bool = False  # enhance + caption instead of streaming raw bytes


class MistralAI(BaseModel):
    """Request body for ``/akeno/mistralai``."""

    args: str  # user message sent to the Mixtral chat model


router = APIRouter()
load_dotenv()

# Required configuration; a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]

client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]

# Seconds to wait for the Hugging Face image endpoint. Without a timeout a
# stalled connection would keep the request (and its worker) alive forever.
FLUX_REQUEST_TIMEOUT = 120

# Cost, in tokens, of one /akeno/fluxai call.
FLUX_TOKEN_COST = 20

logger = logging.getLogger(__name__)


async def schellwithflux(args, timeout=FLUX_REQUEST_TIMEOUT):
    """Generate an image for the prompt ``args`` with FLUX.1-schnell.

    Args:
        args: Text prompt sent as ``{"inputs": args}``.
        timeout: Seconds before ``requests`` gives up on the HTTP call.

    Returns:
        The raw response body (image bytes) on HTTP 200, otherwise ``None``.

    Raises:
        requests.RequestException: on network failure or timeout.
    """
    api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    # requests is blocking; run it in a worker thread so the event loop keeps
    # serving other requests while the model generates.
    response = await run_in_threadpool(
        requests.post, api_url, headers=headers, json=payload, timeout=timeout
    )
    if response.status_code != 200:
        logger.error("Error status %s", response.status_code)
        return None
    return response.content


def _collect_mistral_stream(message_str):
    """Blocking helper: stream a Mixtral chat completion and join its text."""
    client = InferenceClient(
        "mistralai/Mixtral-8x7B-Instruct-v0.1", token=HUGGING_TOKEN
    )
    parts = []
    for chunk in client.chat_completion(
        messages=[{"role": "user", "content": message_str}],
        max_tokens=500,
        stream=True,
    ):
        if not chunk.choices:
            continue
        # delta.content is Optional in streamed chunks; treat None as "".
        parts.append(chunk.choices[0].delta.content or "")
    return "".join(parts)


async def mistralai_post_message(message_str):
    """Send ``message_str`` to Mixtral-8x7B-Instruct and return the full reply.

    The reply is capped at 500 new tokens. The synchronous streaming client
    runs in a worker thread so the event loop is not blocked.
    """
    return await run_in_threadpool(_collect_mistral_stream, message_str)


def get_user_tokens_gpt(user_id):
    """Return the ``tokens`` balance for ``user_id``, or 0 if unknown/unset."""
    user = collection.find_one({"user_id": user_id})
    if not user:
        return 0
    return user.get("tokens", 0)


def deduct_tokens_gpt(user_id, amount):
    """Atomically subtract ``amount`` tokens if the user can afford it.

    The balance check and the decrement are one conditional ``update_one``.
    Concurrent requests therefore cannot both pass the check and overdraw the
    balance, which the previous read-then-write sequence allowed.

    Returns:
        True if a document with ``tokens >= amount`` was found and charged,
        False otherwise (including when the user does not exist).
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    return result.matched_count > 0


@router.post(
    "/akeno/mistralai",
    response_model=SuccessResponse,
    responses={422: {"model": SuccessResponse}},
)
async def mistralai_(payload: MistralAI):
    """Chat with Mixtral; errors are reported in-band with ``status="False"``."""
    try:
        response = await mistralai_post_message(payload.args)
        return SuccessResponse(
            status="True",
            randydev={"message": response},
        )
    except Exception as e:
        logger.exception("mistralai_ failed")
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"},
        )


def _enhance_image(image_bytes):
    """Apply fixed sharpness/contrast/color boosts and return JPEG bytes."""
    with Image.open(io.BytesIO(image_bytes)) as source:
        image = source
        # JPEG cannot store alpha or palette modes; save() would raise on them.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        image = ImageEnhance.Sharpness(image).enhance(1.5)
        image = ImageEnhance.Contrast(image).enhance(1.2)
        image = ImageEnhance.Color(image).enhance(1.1)
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=95)
    return buffer.getvalue()


def _save_jpeg(data):
    """Write JPEG ``data`` to ``uploads/<random>.jpg`` and return that path.

    The extension is always ``.jpg`` because the content is always JPEG.
    Taking it from the client's upload filename mislabelled the file and
    crashed when no filename was sent.
    """
    file_path = os.path.join("uploads", f"{uuid.uuid4().hex}.jpg")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(data)
    return file_path


@router.post(
    "/akeno/fluxai",
    response_model=SuccessResponse,
    responses={422: {"model": SuccessResponse}},
)
async def fluxai_image(payload: FluxAI, file: UploadFile):
    """Generate an image for ``payload.args``, charging the user first.

    Without ``auto_enhancer`` the raw image is streamed back as
    ``image/jpeg``. With it, the image is enhanced and saved under
    ``uploads/``. It is then captioned via Gemini, and a URL plus caption
    are returned.

    NOTE(review): ``file`` is kept for interface compatibility; its content
    is never read. Combining a Pydantic JSON body with an UploadFile
    (multipart) in one route is likely unsatisfiable by a single request.
    Confirm how clients call this.
    """
    if not deduct_tokens_gpt(payload.user_id, amount=FLUX_TOKEN_COST):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"},
        )

    # NOTE(review): tokens are charged before generation and are not refunded
    # on any failure below. Confirm that is intended.
    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"},
            )

        if not payload.auto_enhancer:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")

        file_path = _save_jpeg(_enhance_image(image_bytes))
        # NOTE(review): this prompt asks about a baked good whatever was
        # generated, and reads "an appropriate and recipe". Confirm wording.
        example_test = "Accurately identify the baked good in the image and provide an appropriate and recipe consistent with your analysis."
        x = GeminiLatest(api_keys=GOOGLE_API_KEY)
        response = x.get_response_image(example_test, file_path)
        url = f"https://randydev-ryuzaki-api.hf.space/{file_path}"
        return SuccessResponse(
            status="True",
            randydev={"url": url, "caption": response},
        )
    except Exception as e:
        logger.exception("fluxai_image failed")
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"},
        )