# main.py — added by Dash-inc in commit db93af0 ("Create main.py", verified; 4.73 kB).
import asyncio
import logging
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from urllib.parse import quote_plus, urlparse

import requests
from aura_sr import AuraSR
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain_groq import ChatGroq
from PIL import Image
from pydantic import BaseModel
from pymongo import MongoClient
# The FastAPI application; every endpoint in this module is registered on it.
app = FastAPI()

# CORS is fully open: any origin, any HTTP method and any request header.
# NOTE(review): wildcard origins are convenient in development; consider
# restricting them for a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Globals

# Shared worker pool for the blocking image jobs submitted by the endpoints.
executor = ThreadPoolExecutor(max_workers=10)

# Set by the startup handler; None until the application has started.
llm = None
upscale_model = None

# MongoDB connection. Credentials come from the environment rather than being
# committed to source (the password previously hard-coded here is public and
# must be rotated). A missing MONGO_USER / MONGO_PASSWORD fails fast at import
# with a KeyError naming the variable. MONGO_HOST defaults to the cluster this
# service was written against.
_mongo_user = quote_plus(os.environ["MONGO_USER"])
_mongo_password = quote_plus(os.environ["MONGO_PASSWORD"])
_mongo_host = os.environ.get("MONGO_HOST", "cluster0.2a9yu.mongodb.net")
client = MongoClient(f"mongodb+srv://{_mongo_user}:{_mongo_password}@{_mongo_host}/")
db = client["Flux"]
# One document per chat message: {"session_id": ..., "role": ..., "content": ...}.
collection = db["chat_histories"]

# chat_id -> history collection, filled by generate_chat_id().
# NOTE(review): entries are never removed, so this grows with every new chat.
chat_sessions = {}

image_storage_dir = "./images"  # Directory to save images locally
# Ensure the image storage directory exists
os.makedirs(image_storage_dir, exist_ok=True)
@app.on_event("startup")
async def startup():
    """Create the Groq chat model and load the AuraSR upscaler once per process.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable.
    The key that used to be hard-coded here is public and must be revoked.

    Raises:
        RuntimeError: if ``GROQ_API_KEY`` is not set.
    """
    global llm, upscale_model
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set")
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=1024,
        api_key=api_key,
    )
    # Loaded once here so every upscale job reuses the same model instance.
    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
@app.on_event("shutdown")
def shutdown():
    """Drain background jobs, then release the MongoDB client.

    Order matters: jobs running on ``executor`` read and write chat history
    through ``collection``. The pool is therefore shut down first, waiting
    for those jobs to finish, and only then is the client closed.
    """
    executor.shutdown(wait=True)
    client.close()
# Pydantic models
class ImageRequest(BaseModel):
    """JSON request body for ``POST /generate-image``.

    Every field is a required string. ``user_prompt`` is also appended to the
    chat history, and ``chat_id`` selects which stored history is loaded
    (ids are issued by ``POST /new-chat``).
    """

    subject: str
    style: str
    color_theme: str
    elements: str
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    text: str
    background_details: str
    user_prompt: str  # free-text request; stored in the chat history
    chat_id: str  # session id; keys the "session_id" field of history documents
# Helper Functions
def generate_chat_id():
    """Create, register and return a new chat session id.

    The id is a random UUID4 in its canonical string form. It is recorded in
    ``chat_sessions``, mapped to the shared chat-history collection.
    """
    new_id = f"{uuid.uuid4()}"
    chat_sessions.update({new_id: collection})
    return new_id
def get_chat_history(chat_id):
    """Return the stored conversation for ``chat_id`` as a plain-text transcript.

    Each stored message becomes one line: ``User: <content>`` when its role
    is ``"user"``, and ``AI: <content>`` otherwise. Lines are separated by
    newlines. A session with no messages yields an empty string.
    """
    # MongoDB does not guarantee insertion order for an unsorted find().
    # ObjectIds begin with their creation timestamp, so sorting ascending on
    # _id returns the messages in the order they were written.
    messages = collection.find({"session_id": chat_id}).sort("_id", 1)
    return "\n".join(
        f"User: {msg['content']}" if msg["role"] == "user" else f"AI: {msg['content']}"
        for msg in messages
    )
def save_image_locally(image, filename):
    """Save ``image`` as a PNG named ``filename`` inside ``image_storage_dir``.

    Returns the path the file was written to.
    """
    target_path = os.path.join(image_storage_dir, filename)
    image.save(target_path, format="PNG")
    return target_path
# Endpoints
@app.post("/new-chat", response_model=dict)
async def new_chat():
    """Start a new chat session and return it as ``{"chat_id": <id>}``."""
    return {"chat_id": generate_chat_id()}
@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
    """Queue an image-generation job for ``request`` and return immediately.

    The job runs on ``executor`` and performs these steps:
      1. Asks the LLM to refine a prompt built from every field of ``request``
         plus the session's chat history.
      2. Appends the user prompt and the refined prompt to the history.
      3. Calls the BFL Flux endpoint.
      4. Downloads the resulting image and saves it under ``image_storage_dir``.

    Failures are logged rather than silently dropped. ``background_tasks`` is
    kept only so the signature stays unchanged; the job is scheduled on
    ``executor`` alone.

    Returns:
        ``{"status": "Processing"}``. The saved image path is not reported
        back to the client.
    """
    # Seconds to wait on each outbound HTTP call. Without a timeout, a stalled
    # remote service would occupy an executor worker indefinitely.
    http_timeout = 60

    def process_request():
        chat_history = get_chat_history(request.chat_id)
        # Include every descriptive field so none is dropped from the LLM prompt.
        prompt = "\n".join([
            f"Subject: {request.subject}",
            f"Style: {request.style}",
            f"Color Theme: {request.color_theme}",
            f"Elements: {request.elements}",
            f"Color Mode: {request.color_mode}",
            f"Lighting Conditions: {request.lighting_conditions}",
            f"Framing Style: {request.framing_style}",
            f"Material Details: {request.material_details}",
            f"Text: {request.text}",
            f"Background Details: {request.background_details}",
            f"Chat History: {chat_history}",
            f"User Prompt: {request.user_prompt}",
        ])
        refined_prompt = llm.invoke(prompt).content.strip()
        # Ordered insert: the user message is stored before the AI reply.
        collection.insert_many([
            {"session_id": request.chat_id, "role": "user", "content": request.user_prompt},
            {"session_id": request.chat_id, "role": "ai", "content": refined_prompt},
        ])

        # Image generation request.
        # NOTE(review): the hosted BFL API appears to require an ``x-key`` auth
        # header and to return a task id to poll, not an immediate result;
        # confirm against the BFL API docs.
        bfl_response = requests.post(
            "https://api.bfl.ml/v1/flux-pro-1.1",
            json={"prompt": refined_prompt},
            timeout=http_timeout,
        )
        bfl_response.raise_for_status()
        image_url = bfl_response.json()["result"]["sample"]

        # Download and save the image locally.
        image_response = requests.get(image_url, timeout=http_timeout)
        image_response.raise_for_status()
        img = Image.open(BytesIO(image_response.content))
        return save_image_locally(img, f"generated_{uuid.uuid4()}.png")

    def report_failure(future):
        # An exception raised inside an executor job is only stored on its
        # Future. Nothing else inspects that Future, so surface it here.
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logging.getLogger(__name__).error(
                "Image generation failed for chat %s", request.chat_id, exc_info=error
            )

    # Schedule on the executor only. The Future must not also be handed to
    # background_tasks.add_task(), which would try to call it after the response.
    executor.submit(process_request).add_done_callback(report_failure)
    return {"status": "Processing"}
@app.post("/upscale-image", response_model=dict)
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
    """Queue a 4x AuraSR upscale of the image at ``image_url``.

    ``image_url`` arrives as a query parameter. The job runs on ``executor``:
    it downloads the image, upscales it with ``upscale_model`` and saves it
    under ``image_storage_dir``. Failures are logged. ``background_tasks`` is
    kept only so the signature stays unchanged.

    Returns:
        ``{"status": "Processing"}``.

    Raises:
        HTTPException: 400 if ``image_url`` is not an absolute http(s) URL.
    """
    # SECURITY NOTE(review): the server fetches a caller-supplied URL, which is
    # an SSRF risk. Restrict it to trusted hosts if this endpoint is public.
    parsed = urlparse(image_url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        # Reject up front so the caller gets an error instead of a silent
        # background failure.
        raise HTTPException(status_code=400, detail="image_url must be an absolute http(s) URL")

    # Seconds to wait on the download, so a stalled host cannot pin a worker.
    http_timeout = 60

    def process_image():
        response = requests.get(image_url, timeout=http_timeout)
        response.raise_for_status()
        # Normalise RGBA/palette/greyscale inputs to 3 channels.
        # NOTE(review): assumes the AuraSR generator takes RGB input; confirm.
        img = Image.open(BytesIO(response.content)).convert("RGB")
        upscaled_image = upscale_model.upscale_4x_overlapped(img)
        return save_image_locally(upscaled_image, f"upscaled_{uuid.uuid4()}.png")

    def report_failure(future):
        # An executor job's exception is otherwise only stored on its Future.
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            logging.getLogger(__name__).error(
                "Upscale failed for %s", image_url, exc_info=error
            )

    # Schedule on the executor only. Passing the Future to
    # background_tasks.add_task() would make Starlette try to call it.
    executor.submit(process_image).add_done_callback(report_failure)
    return {"status": "Processing"}
@app.post("/set-prompt", response_model=dict)
async def set_prompt(chat_id: str, user_prompt: str):
    """Refine ``user_prompt`` with the LLM in the context of the chat so far.

    ``chat_id`` and ``user_prompt`` arrive as query parameters. The user
    prompt and the LLM's reply are appended to the session's chat history.

    Returns:
        ``{"refined_prompt": <stripped LLM reply>}``.
    """

    def refine_and_store():
        chat_history = get_chat_history(chat_id)
        refined = llm.invoke(f"{chat_history}\nUser Prompt: {user_prompt}").content.strip()
        # Ordered insert: the user message is stored before the AI reply.
        collection.insert_many([
            {"session_id": chat_id, "role": "user", "content": user_prompt},
            {"session_id": chat_id, "role": "ai", "content": refined},
        ])
        return refined

    # The LLM call and the pymongo calls block. Run them on the loop's default
    # thread pool so they do not stall the event loop, and so they do not queue
    # behind long image jobs on the shared ``executor``.
    loop = asyncio.get_running_loop()
    refined_prompt = await loop.run_in_executor(None, refine_and_store)
    return {"refined_prompt": refined_prompt}