# NOTE: "Spaces: Sleeping" lines were Hugging Face Spaces page chrome left over
# from scraping — this file is the Space's FastAPI application source.
import os
import tempfile
import uuid
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from urllib.parse import quote_plus

import requests
from aura_sr import AuraSR
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from langchain_groq import ChatGroq
from PIL import Image
from pydantic import BaseModel
from pymongo import MongoClient
app = FastAPI()

# Middleware for CORS — wide-open ("*") origins/methods/headers; tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
executor = ThreadPoolExecutor(max_workers=10)  # runs blocking LLM/image work off the event loop
llm = None            # set in startup()
upscale_model = None  # set in startup()

# SECURITY: a MongoDB username/password was hard-coded here. Prefer the
# MONGODB_URI environment variable; the inline fallback preserves the original
# behavior when the variable is unset — rotate this password.
client = MongoClient(
    os.environ.get(
        "MONGODB_URI",
        f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/",
    )
)
db = client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}  # chat_id -> collection handle (see generate_chat_id)

# Use a temporary directory for storing images
image_storage_dir = tempfile.mkdtemp()
print(f"Temporary directory for images: {image_storage_dir}")

# Dictionary to store images temporarily
image_store = {}
async def startup():
    """Initialize the Groq LLM client and the AuraSR super-resolution model.

    NOTE(review): no @app.on_event("startup") decorator is visible in this
    file — confirm this coroutine is registered elsewhere, otherwise
    ``llm`` and ``upscale_model`` remain None.
    """
    global llm, upscale_model
    # SECURITY: the API key was hard-coded in source. Prefer the GROQ_API_KEY
    # environment variable; the fallback preserves original behavior — rotate
    # this key, it has been exposed.
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=1024,
        api_key=os.environ.get(
            "GROQ_API_KEY",
            "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",
        ),
    )
    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
def shutdown():
    # Release external resources: close the MongoDB connection and stop the
    # worker thread pool (shutdown() waits for in-flight tasks by default).
    # NOTE(review): no @app.on_event("shutdown") decorator is visible here —
    # confirm this is registered with the app elsewhere.
    client.close()
    executor.shutdown()
# Pydantic models
class ImageRequest(BaseModel):
    """Request payload for the image-generation endpoint.

    Each field fills one slot of the prompt template that is sent to the
    LLM before calling the image API.
    """
    subject: str              # main subject of the image
    style: str                # artistic style to emphasize
    color_theme: str          # overall color mood/tone
    elements: str             # key elements to include in the scene
    color_mode: str           # declared but not referenced in the prompt template below
    lighting_conditions: str  # lighting description (depth, highlights, shadows)
    framing_style: str        # camera framing/composition
    material_details: str     # textures/materials to highlight
    text: str                 # literal text to render inside the image
    background_details: str   # background description
    user_prompt: str          # free-form user request, appended to the template
    chat_id: str              # session id returned by the new-chat endpoint
# Helper Functions
def generate_chat_id():
    """Create a fresh session id and register it in the in-memory session map."""
    new_id = str(uuid.uuid4())
    # Every session currently maps to the same Mongo collection handle.
    chat_sessions[new_id] = collection
    return new_id
def get_chat_history(chat_id):
    """Render the stored conversation for *chat_id* as newline-joined text."""
    lines = []
    for msg in collection.find({"session_id": chat_id}):
        speaker = "User" if msg['role'] == "user" else "AI"
        lines.append(f"{speaker}: {msg['content']}")
    return "\n".join(lines)
def save_image_locally(image, filename):
    """Write *image* as a PNG into the temp image directory; return its path."""
    destination = os.path.join(image_storage_dir, filename)
    image.save(destination, format="PNG")
    return destination
# Endpoints
async def new_chat():
    """Start a new chat session and hand the client its id."""
    # NOTE(review): no route decorator is visible here — confirm this is
    # registered with the app elsewhere.
    return {"chat_id": generate_chat_id()}
async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
    """Refine the user's request into an image prompt and generate the image.

    Returns immediately with a processing status; the heavy work (LLM call,
    image API call, download, save) runs on the thread-pool executor.
    """
    def process_request():
        chat_history = get_chat_history(request.chat_id)
        # BUG FIX: the original template interpolated bare names (subject,
        # style, chat_history_text, user_prompt, ...) that were never defined,
        # raising NameError — every value must come from the request model or
        # the local chat_history.
        prompt = f"""
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
Image Specifications:
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
- **Camera and Lighting**:
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
- Write the prompt only in the output. Do not include anything else except the prompt. Do not write Generated Image Prompt: in the output.
Chat History:
{chat_history}
User Prompt:
{request.user_prompt}
Generated Image Prompt:
"""
        refined_prompt = llm.invoke(prompt).content.strip()
        # Persist both sides of the exchange for future chat-history context.
        collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
        collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
        # Call the image-generation API with the refined prompt.
        # NOTE(review): no auth header is sent — confirm whether this endpoint
        # requires an API key.
        response = requests.post(
            "https://api.bfl.ml/v1/flux-pro-1.1",
            json={"prompt": refined_prompt}
        ).json()
        image_url = response["result"]["sample"]
        # Download and save the image locally
        image_response = requests.get(image_url)
        img = Image.open(BytesIO(image_response.content))
        filename = f"generated_{uuid.uuid4()}.png"
        filepath = save_image_locally(img, filename)
        return filepath

    # BUG FIX: the original also called background_tasks.add_task(task) with
    # the Future returned by executor.submit(); a Future is not callable, so
    # Starlette raised TypeError after the response. executor.submit() alone
    # already schedules the work on a worker thread.
    executor.submit(process_request)
    return {"status": "Processing"}
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
    """Download *image_url*, upscale it 4x with AuraSR, and save it locally.

    Returns immediately with a processing status; the download/upscale/save
    pipeline runs on the thread-pool executor.
    """
    def process_image():
        response = requests.get(image_url)
        img = Image.open(BytesIO(response.content))
        upscaled_image = upscale_model.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        filepath = save_image_locally(upscaled_image, filename)
        return filepath

    # BUG FIX: the original passed the Future from executor.submit() to
    # background_tasks.add_task(); a Future is not callable, so Starlette
    # raised TypeError when running it. Submitting to the executor is
    # sufficient to run the work in the background.
    executor.submit(process_image)
    return {"status": "Processing"}