from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uuid
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os

app = FastAPI()
# Middleware for CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Globals
executor = ThreadPoolExecutor(max_workers=10)
llm = None
upscale_model = None

client = MongoClient(f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/")
db = client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}

image_storage_dir = "./images"  # Directory to save images locally

# Ensure the image storage directory exists
os.makedirs(image_storage_dir, exist_ok=True)
# Load the LLM and upscaler once at application startup
@app.on_event("startup")
async def startup():
    global llm, upscale_model
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=1024,
        api_key="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",
    )
    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
# Release the MongoDB connection and the thread pool on shutdown
@app.on_event("shutdown")
def shutdown():
    client.close()
    executor.shutdown()
# Pydantic models
class ImageRequest(BaseModel):
    subject: str
    style: str
    color_theme: str
    elements: str
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    text: str
    background_details: str
    user_prompt: str
    chat_id: str
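# Example ImageRequest body (illustrative values only; every field above is a
# free-form string, and chat_id is expected to come from the new_chat endpoint):
#
# {
#     "subject": "a lighthouse on a cliff",
#     "style": "watercolor",
#     "color_theme": "cool blues",
#     "elements": "seagulls, crashing waves",
#     "color_mode": "full color",
#     "lighting_conditions": "golden hour",
#     "framing_style": "wide shot",
#     "material_details": "rough paper texture",
#     "text": "",
#     "background_details": "stormy sky",
#     "user_prompt": "make it feel dramatic",
#     "chat_id": "<uuid returned by new_chat>"
# }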
# Helper Functions
def generate_chat_id():
    chat_id = str(uuid.uuid4())
    chat_sessions[chat_id] = collection
    return chat_id


def get_chat_history(chat_id):
    messages = collection.find({"session_id": chat_id})
    return "\n".join(
        f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
        for msg in messages
    )


def save_image_locally(image, filename):
    filepath = os.path.join(image_storage_dir, filename)
    image.save(filepath, format="PNG")
    return filepath
# Endpoints (route paths below are assumed from the function names)
@app.post("/new-chat")
async def new_chat():
    chat_id = generate_chat_id()
    return {"chat_id": chat_id}
@app.post("/generate-image")
async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
    def process_request():
        chat_history = get_chat_history(request.chat_id)
        prompt = f"""
        Subject: {request.subject}
        Style: {request.style}
        ...
        Chat History: {chat_history}
        User Prompt: {request.user_prompt}
        """
        refined_prompt = llm.invoke(prompt).content.strip()
        collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
        collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})

        # Request image generation from the external API
        response = requests.post(
            "https://api.bfl.ml/v1/flux-pro-1.1",
            json={"prompt": refined_prompt}
        ).json()
        image_url = response["result"]["sample"]

        # Download and save the image locally
        image_response = requests.get(image_url)
        img = Image.open(BytesIO(image_response.content))
        filename = f"generated_{uuid.uuid4()}.png"
        filepath = save_image_locally(img, filename)
        return filepath

    # Hand the work to the thread pool once the response has been sent
    background_tasks.add_task(executor.submit, process_request)
    return {"status": "Processing"}
@app.post("/upscale-image")
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
    def process_image():
        response = requests.get(image_url)
        img = Image.open(BytesIO(response.content))
        upscaled_image = upscale_model.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        filepath = save_image_locally(upscaled_image, filename)
        return filepath

    # Hand the work to the thread pool once the response has been sent
    background_tasks.add_task(executor.submit, process_image)
    return {"status": "Processing"}
@app.post("/set-prompt")
async def set_prompt(chat_id: str, user_prompt: str):
    chat_history = get_chat_history(chat_id)
    refined_prompt = llm.invoke(f"{chat_history}\nUser Prompt: {user_prompt}").content.strip()
    collection.insert_one({"session_id": chat_id, "role": "user", "content": user_prompt})
    collection.insert_one({"session_id": chat_id, "role": "ai", "content": refined_prompt})
    return {"refined_prompt": refined_prompt}