ryuzaki-api / fluxai.py
randydev's picture
Update fluxai.py
8f2d0ff verified
raw
history blame
3.25 kB
import asyncio
import functools
import io
import logging
import os

import requests
from dotenv import load_dotenv
from fastapi import APIRouter, Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from pymongo import MongoClient

from models import *
# Request body for POST /akeno/fluxai (see fluxai_image).
class FluxAI(BaseModel):
    # Identifies whose token balance is checked and debited in the "users" collection.
    user_id: int
    # Text prompt forwarded verbatim as the "inputs" field to FLUX.1-schnell.
    args: str
    # When True, the generated image is sharpened, contrast/colour boosted
    # and re-encoded as JPEG (quality 95) before being streamed back.
    auto_enhancer: bool = False
# Pull variables from a local .env file (if any) into os.environ before
# any configuration below is read.
load_dotenv()

router = APIRouter()

# Both settings are mandatory: a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]

# Token balances are stored in the "users" collection of the "tiktokbot" database.
client_mongo = MongoClient(MONGO_URL)
db = client_mongo.get_database("tiktokbot")
collection = db.get_collection("users")
async def schellwithflux(args, timeout=120):
    """Generate an image for a prompt with the FLUX.1-schnell inference API.

    Args:
        args: Text prompt, sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds ``requests`` waits for the Hugging Face API before
            raising ``requests.Timeout``. Without a timeout a stalled
            upstream would hang this request indefinitely.

    Returns:
        The raw response body (presumably the encoded image) when the API
        answers HTTP 200, otherwise ``None``.

    Raises:
        requests.RequestException: on connection failures or timeout.
    """
    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    # requests is synchronous; calling it directly inside a coroutine would
    # block the event loop (and every other request) for the whole
    # generation. Run it on the loop's default thread pool instead.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        functools.partial(
            requests.post, API_URL, headers=headers, json=payload, timeout=timeout
        ),
    )
    if response.status_code != 200:
        logging.getLogger(__name__).error(
            "FLUX inference request failed with status %s", response.status_code
        )
        return None
    return response.content
def get_user_tokens_gpt(user_id):
    """Return the token balance recorded for ``user_id``.

    Yields 0 when no document matches ``user_id`` (or the match is empty),
    and 0 when the document has no ``tokens`` field.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0
def deduct_tokens_gpt(user_id, amount):
    """Debit ``amount`` tokens from ``user_id`` if the balance covers it.

    The balance check and the decrement are one ``update_one`` call. The
    filter only matches a document whose ``tokens`` is at least ``amount``,
    so MongoDB applies check-and-debit atomically on that document. Two
    concurrent requests therefore cannot both pass the check and overdraw
    the balance, as a separate read followed by a write would allow.

    Args:
        user_id: Value of the ``user_id`` field identifying the user document.
        amount: Number of tokens to remove (callers pass a positive cost).

    Returns:
        True if the tokens were deducted. False if no document for
        ``user_id`` has a ``tokens`` value of at least ``amount``; this
        includes users with no record or no ``tokens`` field.
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    # matched_count (not modified_count) so that a zero-cost debit on an
    # existing, sufficient balance still reports success.
    return result.matched_count == 1
# Price, in tokens, of one FLUX image generation.
_FLUXAI_TOKEN_COST = 20


def _refund_tokens(user_id, amount):
    """Best-effort: credit ``amount`` tokens back to ``user_id`` after a failed generation."""
    try:
        collection.update_one({"user_id": user_id}, {"$inc": {"tokens": amount}})
    except Exception:
        # Returning the original error to the client is more useful than a
        # 500 here; log so the missed refund can be reconciled manually.
        logging.getLogger(__name__).exception(
            "Failed to refund %s tokens to user %s", amount, user_id
        )


def _enhance_image(image_bytes):
    """Sharpen, boost contrast and colour, and re-encode ``image_bytes`` as JPEG.

    Returns an ``io.BytesIO`` rewound to the start of the JPEG data.
    """
    with Image.open(io.BytesIO(image_bytes)) as source:
        # JPEG cannot store alpha or palette images (RGBA/LA/P would make
        # save() raise), so normalise to RGB before enhancing/encoding.
        # convert() returns an independent copy, usable after the file closes.
        image = source.convert("RGB")
    image = ImageEnhance.Sharpness(image).enhance(1.5)
    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Color(image).enhance(1.1)
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=95)
    buffer.seek(0)
    return buffer


@router.post("/akeno/fluxai", response_model=SuccessResponse, responses={422: {"model": SuccessResponse}})
async def fluxai_image(payload: FluxAI):
    """Generate an image from ``payload.args`` and stream it back as JPEG.

    Charges the caller ``_FLUXAI_TOKEN_COST`` tokens up front. If generation
    fails, either because the model returns no image or because any step
    raises, the tokens are refunded and a ``SuccessResponse`` with
    ``status="False"`` and an ``error`` message is returned instead of the
    image. A caller without enough tokens gets the same kind of response,
    which includes their current balance.
    """
    if not deduct_tokens_gpt(payload.user_id, amount=_FLUXAI_TOKEN_COST):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"}
        )

    try:
        # Generate the image from the flux AI model
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            error = "Failed to generate an image"
        elif payload.auto_enhancer:
            return StreamingResponse(_enhance_image(image_bytes), media_type="image/jpeg")
        else:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")
    except Exception as e:
        error = f"An error occurred: {str(e)}"

    # Single failure path: the user paid for an image they did not get.
    _refund_tokens(payload.user_id, _FLUXAI_TOKEN_COST)
    return SuccessResponse(
        status="False",
        randydev={"error": error}
    )