from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
import uuid
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os
import time
# Set the Hugging Face cache directory to a writable location
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
app = FastAPI()
# Middleware for CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Globals
executor = ThreadPoolExecutor(max_workers=10)
llm = None
upscale_model = None
client = MongoClient(f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/")
db = client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}
# Use a temporary directory for storing images
image_storage_dir = tempfile.mkdtemp()
print(f"Temporary directory for images: {image_storage_dir}")
# Dictionary to store images temporarily
image_store = {}
@app.on_event("startup")
async def startup():
global llm, upscale_model
llm = ChatGroq(
model="llama-3.3-70b-versatile",
temperature=0.7,
max_tokens=1024,
api_key="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",
)
upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
@app.on_event("shutdown")
def shutdown():
client.close()
executor.shutdown()
# Pydantic models
class ImageRequest(BaseModel):
subject: str
style: str
color_theme: str
elements: str
color_mode: str
lighting_conditions: str
framing_style: str
material_details: str
text: str
background_details: str
user_prompt: str
chat_id: str
# Helper Functions
def generate_chat_id():
chat_id = str(uuid.uuid4())
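    # Register the new session; every chat currently maps to the same MongoDB collection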
chat_sessions[chat_id] = collection
return chat_id
def get_chat_history(chat_id):
messages = collection.find({"session_id": chat_id})
return "\n".join(
f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
for msg in messages
)
def save_image_locally(image, filename):
filepath = os.path.join(image_storage_dir, filename)
image.save(filepath, format="PNG")
return filepath
# Endpoints
@app.post("/new-chat", response_model=dict)
async def new_chat():
chat_id = generate_chat_id()
return {"chat_id": chat_id}
@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest):
def process_request():
chat_history = get_chat_history(request.chat_id)
prompt = f"""
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
Image Specifications:
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
- **Camera and Lighting**:
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
"""
refined_prompt = llm.invoke(prompt).content.strip()
# Save the prompt in MongoDB
collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
# API call to image generation service
url = "https://api.bfl.ml/v1/flux-pro-1.1"
headers = {
"accept": "application/json",
"x-key": "4f69d408-4979-4812-8ad2-ec6d232c9ddf",
"Content-Type": "application/json"
}
payload = {
"prompt": refined_prompt,
"width": 1024,
"height": 1024,
"guidance_scale": 1,
"num_inference_steps": 50,
"max_sequence_length": 512,
}
# Initial request to generate the image
response = requests.post(url, headers=headers, json=payload).json()
if "id" not in response:
raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
request_id = response["id"]
# Poll for the image result
while True:
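            # Keep polling every 0.5s until the service reports "Ready" or "Error"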
time.sleep(0.5)
result = requests.get(
"https://api.bfl.ml/v1/get_result",
headers=headers,
params={"id": request_id},
).json()
if result["status"] == "Ready":
if "result" in result and "sample" in result["result"]:
image_url = result["result"]["sample"]
# Download the image
image_response = requests.get(image_url)
if image_response.status_code == 200:
img = Image.open(BytesIO(image_response.content))
filename = f"generated_{uuid.uuid4()}.png"
filepath = save_image_locally(img, filename)
return filepath
else:
raise HTTPException(status_code=500, detail="Failed to download the image")
else:
raise HTTPException(status_code=500, detail="Expected 'sample' key not found in the result")
elif result["status"] == "Error":
raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
# Run the request processing in a thread
future = executor.submit(process_request)
filepath = await run_in_threadpool(future.result)
return {
"status": "Image generated successfully",
"file_path": filepath,
}
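# Example /generate-image request body (illustrative values; chat_id comes from /new-chat):
# {
#   "subject": "a red fox", "style": "watercolor", "color_theme": "warm autumn",
#   "elements": "falling leaves", "color_mode": "full color",
#   "lighting_conditions": "golden hour", "framing_style": "close-up",
#   "material_details": "soft fur", "text": "Hello", "background_details": "misty forest",
#   "user_prompt": "a fox in the woods", "chat_id": "<uuid from /new-chat>"
# }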
@app.post("/upscale-image", response_model=dict)
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
def process_image():
        response = requests.get(image_url)
        response.raise_for_status()  # fail early if the image could not be downloaded
        img = Image.open(BytesIO(response.content))
upscaled_image = upscale_model.upscale_4x_overlapped(img)
filename = f"upscaled_{uuid.uuid4()}.png"
filepath = save_image_locally(upscaled_image, filename)
return filepath
    # add_task expects a callable, so pass the function itself; the upscale then runs
    # in the background after the "Processing" response is sent
    background_tasks.add_task(process_image)
return {"status": "Processing"}
@app.get("/")
async def root():
return {"message": "API is up and running!"}