# NOTE: this file was exported from a Hugging Face Space file view; the page
# chrome below was captured as bare text (not valid Python) and is preserved
# here as a comment for provenance:
#   Dash-inc's picture
#   Update main.py
#   df5fc1f verified
#   raw | history | blame — 6.13 kB
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uuid
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os
# --- Application setup -------------------------------------------------------
app = FastAPI()

# Middleware for CORS — wide open (any origin/method/header); tighten for
# production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
# Thread pool that runs the blocking work (LLM calls, HTTP requests, image
# I/O) off the async event loop.
executor = ThreadPoolExecutor(max_workers=10)
llm = None            # populated on startup (ChatGroq client)
upscale_model = None  # populated on startup (AuraSR super-resolution model)

# SECURITY: the MongoDB username/password were previously hard-coded in this
# file. They are now read from the environment, falling back to the original
# values so existing deployments keep working — rotate these credentials and
# set MONGO_USER / MONGO_PASSWORD instead of relying on the fallbacks.
_mongo_user = os.environ.get("MONGO_USER", "hammad")
_mongo_password = os.environ.get("MONGO_PASSWORD", "momimaad@123")
client = MongoClient(
    f"mongodb+srv://{_mongo_user}:{quote_plus(_mongo_password)}@cluster0.2a9yu.mongodb.net/"
)
db = client["Flux"]
collection = db["chat_histories"]

# Maps chat_id -> collection handle. Every session currently shares the one
# collection; the dict records which ids have been issued.
chat_sessions = {}

# Use a temporary directory for storing images
image_storage_dir = tempfile.mkdtemp()
print(f"Temporary directory for images: {image_storage_dir}")

# Dictionary to store images temporarily
image_store = {}
@app.on_event("startup")
async def startup():
    """Initialise the LLM client and the super-resolution model once at boot."""
    global llm, upscale_model
    # SECURITY: the Groq API key was previously hard-coded (and therefore
    # leaked) in this file. It is now read from the environment, with the old
    # key kept only as a backward-compatible fallback — rotate that key and
    # set GROQ_API_KEY instead.
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=1024,
        api_key=os.environ.get(
            "GROQ_API_KEY",
            "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",
        ),
    )
    # Pretrained 4x upscaler used by the /upscale-image endpoint.
    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
@app.on_event("shutdown")
def shutdown():
    """Release external resources when the application stops."""
    # Close the MongoDB connection first, then drain the worker pool.
    client.close()
    executor.shutdown()
# Pydantic models
class ImageRequest(BaseModel):
    """Request payload for the /generate-image endpoint.

    Every string field below is interpolated into the prompt sent to the
    LLM when composing the image-generation prompt; ``chat_id`` identifies
    the session under which the exchange is recorded.
    """

    subject: str
    style: str
    color_theme: str
    elements: str
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    text: str
    background_details: str
    user_prompt: str
    chat_id: str
# Helper Functions
def generate_chat_id():
    """Mint a fresh chat-session id and register it.

    All sessions share the single Mongo collection; the ``chat_sessions``
    mapping just records which ids have been handed out.
    """
    new_id = str(uuid.uuid4())
    chat_sessions[new_id] = collection
    return new_id
def get_chat_history(chat_id):
    """Return the session's messages as a newline-joined transcript.

    Each stored document becomes a ``User: ...`` or ``AI: ...`` line
    depending on its ``role`` field.
    """
    transcript = []
    for msg in collection.find({"session_id": chat_id}):
        speaker = "User" if msg['role'] == "user" else "AI"
        transcript.append(f"{speaker}: {msg['content']}")
    return "\n".join(transcript)
def save_image_locally(image, filename, directory=None):
    """Save *image* as a PNG and return the path it was written to.

    Args:
        image: Any object with a PIL-style ``save(path, format=...)`` method.
        filename: Bare file name (no directory component).
        directory: Target directory. Defaults to the module-level temporary
            ``image_storage_dir``, so existing callers are unaffected; the
            parameter generalizes the previously hard-coded location.

    Returns:
        The full path of the written file.
    """
    if directory is None:
        directory = image_storage_dir
    filepath = os.path.join(directory, filename)
    image.save(filepath, format="PNG")
    return filepath
# Endpoints
@app.post("/new-chat", response_model=dict)
async def new_chat():
    """Open a new chat session and hand back its identifier."""
    return {"chat_id": generate_chat_id()}
@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest, background_tasks: BackgroundTasks):
    """Kick off prompt refinement and image generation in the background.

    Responds immediately with a processing status; the heavy work (LLM call,
    remote image generation, download, local save) runs on the thread pool.

    Args:
        request: The full set of user-selected image attributes plus chat id.
        background_tasks: Kept for interface compatibility; the work is
            dispatched via the module-level executor instead (see bug note
            at the bottom).
    """
    def process_request():
        chat_history = get_chat_history(request.chat_id)
        # BUG FIX: the original f-string referenced bare names (subject,
        # style, chat_history_text, user_prompt, ...) that do not exist in
        # this scope — every call raised NameError. They are fields of
        # `request` (and the local `chat_history`). Note: request.color_mode
        # is intentionally still unused, matching the original template.
        prompt = f"""
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
Image Specifications:
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
- **Camera and Lighting**:
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
- Write the prompt only in the output. Do not include anything else except the prompt. Do not write Generated Image Prompt: in the output.
Chat History:
{chat_history}
User Prompt:
{request.user_prompt}
Generated Image Prompt:
"""
        refined_prompt = llm.invoke(prompt).content.strip()
        # Persist both sides of the exchange to the chat history.
        collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
        collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
        # Generate the image via the external Flux endpoint.
        response = requests.post(
            "https://api.bfl.ml/v1/flux-pro-1.1",
            json={"prompt": refined_prompt}
        ).json()
        image_url = response["result"]["sample"]
        # Download and save the image locally
        image_response = requests.get(image_url)
        img = Image.open(BytesIO(image_response.content))
        filename = f"generated_{uuid.uuid4()}.png"
        filepath = save_image_locally(img, filename)
        return filepath

    # BUG FIX: the original both submitted the job to the executor AND passed
    # the resulting Future to background_tasks.add_task(). Starlette calls
    # its tasks — and a concurrent.futures.Future is not callable — so every
    # request raised TypeError after the response was sent (and the job was
    # effectively scheduled twice). Submitting to the executor alone is
    # sufficient: it already runs process_request asynchronously.
    executor.submit(process_request)
    return {"status": "Processing"}
@app.post("/upscale-image", response_model=dict)
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
    """Fetch an image by URL and upscale it 4x in the background.

    Responds immediately; the download, AuraSR upscaling, and local save run
    on the thread pool.

    Args:
        image_url: URL of the source image to upscale.
        background_tasks: Kept for interface compatibility; the work is
            dispatched via the module-level executor (see bug note below).
    """
    def process_image():
        response = requests.get(image_url)
        img = Image.open(BytesIO(response.content))
        upscaled_image = upscale_model.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        filepath = save_image_locally(upscaled_image, filename)
        return filepath

    # BUG FIX: as in /generate-image, the original passed the executor's
    # Future to background_tasks.add_task(), which expects a callable and
    # therefore raised TypeError after the response. The executor submission
    # alone already runs the job asynchronously.
    executor.submit(process_image)
    return {"status": "Processing"}