Spaces:
Running
Running
import asyncio
import tempfile
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from PIL import Image, ImageEnhance
from fastapi import HTTPException
import io
import requests
import os
from dotenv import load_dotenv
from pydantic import BaseModel
from pymongo import MongoClient
from models import *
from huggingface_hub import InferenceClient
from fastapi import UploadFile, File
from fastapi.responses import JSONResponse, FileResponse
import uuid
from RyuzakiLib import GeminiLatest
class FluxAI(BaseModel):
    """Request body for the FLUX image-generation handler ``fluxai_image``."""

    # Owner of the token balance that fluxai_image charges (20 tokens per call).
    user_id: int
    # Text prompt forwarded to the FLUX.1-schnell model.
    args: str
    # If True, the image is enhanced, uploaded and captioned (JSON result)
    # instead of being streamed back as raw image/jpeg bytes.
    auto_enhancer: bool = False
class MistralAI(BaseModel):
    """Request body for the Mixtral chat handler ``mistralai_``."""

    # Sent as the content of a single "user" chat message to
    # mistralai/Mixtral-8x7B-Instruct-v0.1.
    args: str
# Populate os.environ from a .env file (python-dotenv) before the required
# settings below are read.
load_dotenv()

router = APIRouter()

# Required configuration: a missing variable raises KeyError at import time.
MONGO_URL = os.environ["MONGO_URL"]
HUGGING_TOKEN = os.environ["HUGGING_TOKEN"]
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]

# Shared MongoDB handles; the "users" collection holds the per-user
# "tokens" balances read and decremented by the helpers below.
client_mongo = MongoClient(MONGO_URL)
db = client_mongo.get_database("tiktokbot")
collection = db.get_collection("users")
async def schellwithflux(args, timeout=120):
    """Generate an image with FLUX.1-schnell via the Hugging Face Inference API.

    Args:
        args: Text prompt, sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds ``requests`` waits for the server before raising
            ``requests.exceptions.Timeout``.

    Returns:
        The raw response body (the generated image bytes) on HTTP 200,
        otherwise ``None`` (the status code is printed).

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
    """
    api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
    headers = {"Authorization": f"Bearer {HUGGING_TOKEN}"}
    payload = {"inputs": args}

    # requests is synchronous; run it on the default thread pool so the event
    # loop keeps serving other requests while the model is generating.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: requests.post(api_url, headers=headers, json=payload, timeout=timeout),
    )

    if response.status_code != 200:
        print(f"Error status {response.status_code}")
        return None
    return response.content
async def mistralai_post_message(message_str):
    """Send one user message to Mixtral-8x7B-Instruct and return the full reply.

    The completion is streamed from the Hugging Face Inference API and the
    text deltas are concatenated.

    Args:
        message_str: Content of the single ``"user"`` chat message.

    Returns:
        The generated reply text (``max_tokens=500``).
    """

    def _collect_reply():
        """Run the blocking streaming call and join the received text deltas."""
        client = InferenceClient(
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            token=HUGGING_TOKEN,
        )
        parts = []
        for chunk in client.chat_completion(
            messages=[{"role": "user", "content": message_str}],
            max_tokens=500,
            stream=True,
        ):
            if not chunk.choices:
                continue
            # delta.content is Optional in huggingface_hub's stream types; a
            # None delta would make plain string concatenation raise TypeError.
            delta = chunk.choices[0].delta.content
            if delta:
                parts.append(delta)
        return "".join(parts)

    # InferenceClient is synchronous: iterate the stream on a worker thread so
    # the event loop is not blocked for the entire generation.
    return await asyncio.get_running_loop().run_in_executor(None, _collect_reply)
def get_user_tokens_gpt(user_id):
    """Return the ``tokens`` balance stored for ``user_id``.

    Looks up one document in ``collection`` by its ``user_id`` field. An
    unknown user, or a document without a ``tokens`` field, yields 0.
    """
    record = collection.find_one({"user_id": user_id})
    return record.get("tokens", 0) if record else 0
def deduct_tokens_gpt(user_id, amount):
    """Atomically deduct ``amount`` tokens from a user's balance.

    The balance check and the decrement are a single conditional
    ``update_one``, so two concurrent requests cannot both pass the check
    and overdraw the balance.

    Args:
        user_id: Value matched against the ``user_id`` field.
        amount: Number of tokens to deduct; expected to be positive.

    Returns:
        True if a document with this ``user_id`` and at least ``amount``
        tokens was found and charged; False otherwise (unknown user, no
        ``tokens`` field, or insufficient balance).
    """
    result = collection.update_one(
        {"user_id": user_id, "tokens": {"$gte": amount}},
        {"$inc": {"tokens": -amount}},
    )
    # matched_count rather than modified_count: an $inc by 0 matches the
    # document but reports no modification.
    return result.matched_count > 0
# Local directory where uploaded files are stored and served from.
UPLOAD_DIRECTORY = "./uploads"


def _safe_extension(filename):
    """Return the extension of ``filename`` (without the dot) if it is safe, else ``""``."""
    ext = os.path.splitext(os.path.basename(filename or ""))[1].lstrip(".")
    # Keep only short ASCII letters/digits so a client-supplied name can never
    # inject path separators or other special characters into the stored name.
    if ext.isascii() and ext.isalnum() and len(ext) <= 16:
        return ext
    return ""


async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded file under a random name and return its public URL.

    The stored name is ``<uuid4 hex>.<ext>``, where ``<ext>`` is the client's
    file extension if it is short and ASCII-alphanumeric (otherwise the name
    has no extension).

    Returns:
        JSONResponse 200 ``{"url": ...}`` on success, or 500
        ``{"error": <message>}`` if reading or writing the file fails.
    """
    try:
        ext = _safe_extension(file.filename)
        unique_filename = uuid.uuid4().hex + (f".{ext}" if ext else "")
        # The directory is not guaranteed to exist on a fresh deployment.
        os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)
        file_location = os.path.join(UPLOAD_DIRECTORY, unique_filename)
        # Read first so a failed read does not leave an empty file behind.
        contents = await file.read()
        with open(file_location, "wb") as f:
            f.write(contents)
        return JSONResponse(
            status_code=200,
            content={"url": f"https://randydev-ryuzaki-api.hf.space/api/v1/uploads/{unique_filename}"},
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
async def serve_file(filename: str):
    """Serve a previously uploaded file from ``UPLOAD_DIRECTORY``.

    Args:
        filename: Client-supplied name of the stored file (untrusted).

    Returns:
        A FileResponse for a regular file located inside the upload
        directory; otherwise a 404 JSONResponse ``{"error": "File not found"}``.
    """
    base_dir = os.path.realpath(UPLOAD_DIRECTORY)
    file_location = os.path.realpath(os.path.join(base_dir, filename))
    # Reject names that resolve outside the upload directory ("../x",
    # absolute paths, escaping symlinks). commonpath raises ValueError when
    # the paths are on different drives (Windows), which is also "outside".
    try:
        inside = os.path.commonpath([base_dir, file_location]) == base_dir
    except ValueError:
        inside = False
    # isfile (not exists) so a directory such as "." is never handed to
    # FileResponse, which can only stream regular files.
    if inside and os.path.isfile(file_location):
        return FileResponse(file_location)
    return JSONResponse(
        status_code=404,
        content={"error": "File not found"},
    )
async def mistralai_(payload: MistralAI):
    """Chat with Mixtral and wrap the reply in a SuccessResponse.

    Any exception from generation or response construction is reported as a
    SuccessResponse with ``status="False"`` and an ``error`` message.
    """
    try:
        reply = await mistralai_post_message(payload.args)
        result = SuccessResponse(status="True", randydev={"message": reply})
    except Exception as exc:
        result = SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {exc!s}"},
        )
    return result
def _enhance_to_jpeg(image_bytes):
    """Sharpen and boost contrast/saturation of an image; return JPEG bytes (quality 95)."""
    with Image.open(io.BytesIO(image_bytes)) as source:
        # JPEG cannot store alpha or palette modes; normalise to RGB first.
        image = source.convert("RGB")
    image = ImageEnhance.Sharpness(image).enhance(1.5)
    image = ImageEnhance.Contrast(image).enhance(1.2)
    image = ImageEnhance.Color(image).enhance(1.1)
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=95)
    return buffer.getvalue()


def _upload_and_caption(jpeg_bytes, timeout=60):
    """Upload JPEG bytes to the public upload endpoint and caption them with Gemini.

    Returns:
        ``(url, caption)``: the ``url`` field of the upload endpoint's JSON
        response (None if absent) and Gemini's description of the image.
    """
    # A per-request temporary file, so concurrent requests never read or
    # overwrite each other's image.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        tmp.write(jpeg_bytes)
        file_path = tmp.name
    try:
        with open(file_path, "rb") as fh:
            uploaded = requests.post(
                "https://randydev-ryuzaki-api.hf.space/api/v1/uploadfile/",
                files={"file": fh},
                timeout=timeout,
            ).json()
        gemini = GeminiLatest(api_keys=GOOGLE_API_KEY)
        caption = gemini.get_response_image("Explain how this picture looks like.", file_path)
        return uploaded.get("url"), caption
    finally:
        try:
            os.remove(file_path)
        except OSError:
            pass  # best-effort cleanup of the temporary file


async def fluxai_image(payload: FluxAI):
    """Generate a FLUX image for ``payload.args`` after charging 20 tokens.

    Without ``auto_enhancer`` the raw image is streamed back as
    ``image/jpeg``. With it, the image is enhanced, uploaded and captioned
    with Gemini, and a SuccessResponse carrying ``url`` and ``caption`` is
    returned. Insufficient tokens and any failure are reported as a
    SuccessResponse with ``status="False"``.
    """
    if not deduct_tokens_gpt(payload.user_id, amount=20):
        tokens = get_user_tokens_gpt(payload.user_id)
        return SuccessResponse(
            status="False",
            randydev={"error": f"Not enough tokens. Current tokens: {tokens}. Please support @xtdevs"},
        )

    # NOTE(review): the 20 tokens are not refunded if generation or
    # captioning fails below; confirm this is intended.
    try:
        image_bytes = await schellwithflux(payload.args)
        if image_bytes is None:
            return SuccessResponse(
                status="False",
                randydev={"error": "Failed to generate an image"}
            )
        if not payload.auto_enhancer:
            return StreamingResponse(io.BytesIO(image_bytes), media_type="image/jpeg")

        # Enhancement is CPU-bound and upload/caption are blocking HTTP calls;
        # run them off the event loop. The upload URL is on the same host as
        # the URLs upload_file returns, so blocking the loop here could stall
        # the very request being waited on.
        loop = asyncio.get_running_loop()
        enhanced = await loop.run_in_executor(None, _enhance_to_jpeg, image_bytes)
        url, caption = await loop.run_in_executor(None, _upload_and_caption, enhanced)
        return SuccessResponse(
            status="True",
            randydev={"url": url, "caption": caption}
        )
    except Exception as e:
        return SuccessResponse(
            status="False",
            randydev={"error": f"An error occurred: {str(e)}"}
        )