Spaces:
Running
Running
import asyncio
import functools
import io
import logging
import os

import requests
from dotenv import load_dotenv
from fastapi import APIRouter, Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from PIL import Image, ImageEnhance
from pydantic import BaseModel
from pymongo import MongoClient

from models import *
class FluxAI(BaseModel):
    """Request body accepted by the FLUX image-generation handler."""

    # Identifies the user whose token balance is charged for the image.
    user_id: int
    # Prompt text sent to the FLUX.1-schnell model as "inputs".
    args: str
    # When True, the generated image is sharpened, contrast- and colour-boosted.
    auto_enhancer: bool = False
class MistralAI(BaseModel):
    """Request body accepted by the Mixtral chat handler."""

    # The user's message, sent verbatim as the single chat turn.
    args: str
# Populate os.environ from a .env file (if present) before reading settings.
load_dotenv()

router = APIRouter()

# Required settings: a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]

# MongoDB handles; `collection` holds the per-user token balances.
client_mongo = MongoClient(MONGO_URL)
db = client_mongo["tiktokbot"]
collection = db["users"]
async def schellwithflux(args, timeout=120):
    """Generate an image from a prompt with FLUX.1-schnell (HF Inference API).

    Args:
        args: Prompt text, sent as the request's ``inputs`` field.
        timeout: Seconds to wait for the HTTP request before giving up
            (new, optional; previously the request could hang forever).

    Returns:
        The raw response body (image bytes) on HTTP 200, otherwise None.

    Raises:
        requests.exceptions.RequestException: on network errors or timeout.
    """
    api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}

    # requests is synchronous; run it in the default thread pool so the
    # event loop keeps serving other requests while FLUX generates.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        functools.partial(
            requests.post, api_url, headers=headers, json=payload, timeout=timeout
        ),
    )

    if response.status_code != 200:
        logging.getLogger(__name__).error(
            "FLUX request failed: error status %s", response.status_code
        )
        return None
    return response.content
async def mistralai_post_message(message_str):
    """Send one user message to Mixtral-8x7B-Instruct and return its reply.

    The completion is streamed from the Hugging Face Inference API (capped at
    ``max_tokens=500``) and the text fragments are joined into one string.

    Args:
        message_str: The user's message, sent as a single ``user`` turn.

    Returns:
        The concatenated reply text.
    """

    def _collect():
        # Runs in a worker thread: the streaming client is synchronous.
        client = InferenceClient(
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            token=HUGGING_TOKEN,
        )
        parts = []
        for chunk in client.chat_completion(
            messages=[{"role": "user", "content": message_str}],
            max_tokens=500,
            stream=True,
        ):
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            # delta.content may be None for some stream chunks; appending
            # None to a str would raise TypeError.
            if content:
                parts.append(content)
        return "".join(parts)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _collect)
def get_user_tokens_gpt(user_id):
    """Return the token balance stored for ``user_id``.

    A user with no document in ``collection``, or whose document has no
    ``tokens`` field, is reported as having 0 tokens.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0
def deduct_tokens_gpt(user_id, amount):
    """Charge ``amount`` tokens to ``user_id`` if the balance allows it.

    Args:
        user_id: Value of the ``user_id`` field identifying the user document.
        amount: Number of tokens to subtract (callers in this file pass 20).

    Returns:
        True if a user document with ``tokens >= amount`` was found and
        decremented; False otherwise (unknown user, missing ``tokens`` field,
        or insufficient balance).
    """
    # Check and decrement in a single atomic update: a separate read followed
    # by an update lets two concurrent requests both pass the balance check
    # and overdraw the account.
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    # matched_count (not modified_count) so the result reflects whether the
    # balance condition held, independent of whether a write was reported.
    return result.matched_count == 1
async def mistralai_(payload: MistralAI):
    """Answer a Mixtral chat request, wrapped in a ``SuccessResponse``.

    On success the response has status "True" and ``{"message": reply}``;
    any exception is caught and reported with status "False" and
    ``{"error": ...}`` carrying the exception text.
    """
    try:
        reply = await mistralai_post_message(payload.args)
        return SuccessResponse(status="True", randydev={"message": reply})
    except Exception as exc:
        failure = {"error": f"An error occurred: {str(exc)}"}
        return SuccessResponse(status="False", randydev=failure)
def _refund_tokens_gpt(user_id, amount):
    """Give back ``amount`` tokens taken earlier by ``deduct_tokens_gpt``."""
    collection.update_one({"user_id": user_id}, {"$inc": {"tokens": amount}})


def _enhance_image(image_bytes):
    """Sharpen, boost contrast and colour, and re-encode as a JPEG buffer."""
    with Image.open(io.BytesIO(image_bytes)) as original:
        # JPEG cannot store RGBA/P images; convert() also fully loads the
        # pixels so the source can be closed here.
        image = original.convert("RGB")
    image = ImageEnhance.Sharpness(image).enhance(1.5)
    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Color(image).enhance(1.1)
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=95)
    buffer.seek(0)
    return buffer


async def fluxai_image(payload: FluxAI):
    """Generate a FLUX image for the user, charging 20 tokens.

    Returns a JPEG ``StreamingResponse`` on success (enhanced when
    ``payload.auto_enhancer`` is True). If the user cannot pay, returns a
    "Not enough tokens" ``SuccessResponse``. If generation fails (no image
    or an exception), refunds the charged tokens and returns a
    ``SuccessResponse`` with status "False" describing the error.
    """
    cost = 20
    if not deduct_tokens_gpt(payload.user_id, amount=cost):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"},
        )

    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is not None:
            if payload.auto_enhancer:
                stream = _enhance_image(image_bytes)
            else:
                stream = io.BytesIO(image_bytes)
            return StreamingResponse(stream, media_type="image/jpeg")
        error = "Failed to generate an image"
    except Exception as e:
        error = f"An error occurred: {str(e)}"

    # Generation failed after the user was charged: refund exactly once,
    # outside the try so a refund failure is not re-reported as a
    # generation error.
    _refund_tokens_gpt(payload.user_id, cost)
    return SuccessResponse(status="False", randydev={"error": error})