Spaces:
Running
Running
import asyncio
import functools
import io
import logging
import os

import requests
from dotenv import load_dotenv
from fastapi import APIRouter, Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from pymongo import MongoClient

from models import *
class FluxAI(BaseModel):
    """Request body accepted by ``fluxai_image``."""

    # Identifies the user document whose ``tokens`` balance is checked/charged
    user_id: int
    # Text prompt forwarded to the FLUX model as ``inputs``
    args: str
    # If True, the generated image is sharpened, contrast- and colour-boosted
    auto_enhancer: bool = False
# Populate os.environ from a local .env file (if present) before reading config.
load_dotenv()

router = APIRouter()

# Required settings: a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]

# User records (including the "tokens" balance) live in tiktokbot.users.
client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]
async def schellwithflux(args, timeout=120):
    """Ask the hosted FLUX.1-schnell model to render an image for prompt ``args``.

    ``requests.post`` is blocking, so it is executed in the event loop's default
    thread-pool executor; awaiting it no longer freezes every other request
    served by this process while the model renders.

    Args:
        args: Text prompt, sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the HTTP response (previously unbounded).

    Returns:
        The raw response body (image bytes) on HTTP 200, otherwise ``None``.

    Raises:
        requests.RequestException: on connection errors or timeout; left to the
            caller, which already handles arbitrary exceptions.
    """
    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        functools.partial(
            requests.post, API_URL, headers=headers, json=payload, timeout=timeout
        ),
    )
    if response.status_code != 200:
        logging.getLogger(__name__).error("Error status %s", response.status_code)
        return None
    return response.content
def get_user_tokens_gpt(user_id):
    """Look up the token balance stored for ``user_id``.

    Returns 0 when no document matches ``user_id`` or when the matching
    document has no ``tokens`` field.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0
def deduct_tokens_gpt(user_id, amount):
    """Take ``amount`` tokens from ``user_id`` if the balance covers it.

    The balance check and the decrement are a single conditional
    ``update_one``. MongoDB applies a single-document update atomically, so two
    concurrent requests (e.g. from different workers) can no longer both pass a
    separate read-then-write check and drive the balance negative.

    A user with no document, or a document without a ``tokens`` field, is
    never matched and so is never charged.

    Returns:
        True if a document with ``user_id`` and ``tokens >= amount`` was found
        and decremented, False otherwise.
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    # matched_count rather than modified_count: an $inc by 0 matches but
    # leaves the document unmodified.
    return result.matched_count > 0
_FLUX_TOKEN_COST = 20  # tokens charged for one FLUX generation request


def _refund_tokens_gpt(user_id, amount):
    """Credit ``amount`` tokens back to ``user_id`` after a failed generation."""
    collection.update_one({"user_id": user_id}, {"$inc": {"tokens": amount}})


def _enhance_image(image_bytes):
    """Sharpen and boost contrast/colour of an image, re-encoded as JPEG.

    Returns:
        An ``io.BytesIO`` holding the JPEG data, rewound to position 0.
    """
    with Image.open(io.BytesIO(image_bytes)) as source:
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P) and
        # ``save(format="JPEG")`` would raise; RGB and L are written as-is.
        image = source if source.mode in ("RGB", "L") else source.convert("RGB")
        image = ImageEnhance.Sharpness(image).enhance(1.5)
        image = ImageEnhance.Contrast(image).enhance(1.2)
        image = ImageEnhance.Color(image).enhance(1.1)
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=95)
    buffer.seek(0)
    return buffer


async def fluxai_image(payload: FluxAI):
    """Generate an image for ``payload.args`` with FLUX, charging the user tokens.

    ``_FLUX_TOKEN_COST`` tokens are deducted up front. If generation or
    post-processing fails, the tokens are credited back so the user is not
    billed for a request that produced nothing.

    Returns:
        On success, a ``StreamingResponse`` with ``image/jpeg`` content,
        enhanced when ``payload.auto_enhancer`` is set. On failure, a
        ``SuccessResponse`` with ``status="False"`` and an ``error`` message.

    NOTE(review): no route decorator is visible on this handler; confirm it is
    registered on ``router``.
    """
    if not deduct_tokens_gpt(payload.user_id, amount=_FLUX_TOKEN_COST):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"},
        )

    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            error = "Failed to generate an image"
        elif payload.auto_enhancer:
            return StreamingResponse(_enhance_image(image_bytes), media_type="image/jpeg")
        else:
            # NOTE(review): bytes are passed through untouched; the media type
            # assumes the model API returns JPEG — confirm.
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")
    except Exception as e:
        logging.getLogger(__name__).exception("FLUX image generation failed")
        error = f"An error occurred: {str(e)}"

    # Only failure paths reach here: the user was charged but got no image.
    _refund_tokens_gpt(payload.user_id, _FLUX_TOKEN_COST)
    return SuccessResponse(status="False", randydev={"error": error})