Dash-inc's picture
Update main.py
69138eb verified
raw
history blame
6.85 kB
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
import uuid
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os
import os
# Set the Hugging Face cache directory to a writable location (the default
# ~/.cache is typically read-only in the container this service runs in).
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

app = FastAPI()

# Middleware for CORS: wide-open (any origin/method/header) — acceptable for
# a public demo API, tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
executor = ThreadPoolExecutor(max_workers=10)  # pool for blocking image/LLM work
llm = None            # ChatGroq client, created in the startup handler
upscale_model = None  # AuraSR upscaler, created in the startup handler

# SECURITY: the MongoDB credentials were previously hard-coded here and are
# therefore in version control — rotate them. The URI is now read from the
# environment, falling back to the historical value so existing deployments
# keep working unchanged.
_mongo_uri = os.environ.get(
    "MONGO_URI",
    f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/",
)
client = MongoClient(_mongo_uri)
db = client["Flux"]
collection = db["chat_histories"]

# Maps chat_id -> the shared Mongo collection handle (see generate_chat_id).
chat_sessions = {}

# Use a temporary directory for storing images
image_storage_dir = tempfile.mkdtemp()
print(f"Temporary directory for images: {image_storage_dir}")

# Dictionary to store images temporarily
image_store = {}
@app.on_event("startup")
async def startup():
    """Initialize the shared LLM client and the AuraSR upscaler once per process."""
    global llm, upscale_model
    # SECURITY: the Groq API key was previously hard-coded here and is in
    # version control — rotate it. It is now read from the environment,
    # falling back to the historical value for backward compatibility.
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.7,
        max_tokens=1024,
        api_key=os.environ.get(
            "GROQ_API_KEY",
            "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",
        ),
    )
    # Downloads/loads the 4x super-resolution model (cached under HF_HOME).
    upscale_model = AuraSR.from_pretrained("fal/AuraSR-v2")
@app.on_event("shutdown")
def shutdown():
    """Release process-wide resources when the application stops."""
    # Wind down the worker pool, then drop the MongoDB connection.
    executor.shutdown()
    client.close()
# Pydantic models
class ImageRequest(BaseModel):
    """Request payload for /generate-image.

    Every field is interpolated into the prompt-engineering template that is
    sent to the LLM to produce a refined image-generation prompt.
    """
    subject: str              # main subject of the image
    style: str                # artistic style to emphasize
    color_theme: str          # overall color palette / mood
    elements: str             # key elements to include in the scene
    color_mode: str           # NOTE(review): accepted but not referenced in the prompt template — confirm intended
    lighting_conditions: str  # lighting description (depth, highlights, shadows)
    framing_style: str        # camera framing / composition
    material_details: str     # textures and material details
    text: str                 # literal text to render inside the image
    background_details: str   # background description
    user_prompt: str          # free-form user prompt, appended after chat history
    chat_id: str              # session id used to look up / store chat history
# Helper Functions
def generate_chat_id():
    """Create a fresh chat-session id, register it, and return it."""
    new_id = str(uuid.uuid4())
    # Every session currently shares the single Mongo collection handle.
    chat_sessions[new_id] = collection
    return new_id
def get_chat_history(chat_id):
    """Render the stored messages for ``chat_id`` as a "User:/AI:" transcript."""
    lines = []
    for msg in collection.find({"session_id": chat_id}):
        speaker = "User" if msg['role'] == "user" else "AI"
        lines.append(f"{speaker}: {msg['content']}")
    return "\n".join(lines)
def save_image_locally(image, filename, directory=None):
    """Save a PIL image as PNG and return the resulting file path.

    Args:
        image: PIL.Image (or any object with a compatible ``save`` method).
        filename: Base file name, e.g. "generated_<uuid>.png".
        directory: Target directory; defaults to the module-level temporary
            directory ``image_storage_dir`` (backward compatible with the
            original two-argument signature).

    Returns:
        The absolute path the image was written to.
    """
    filepath = os.path.join(directory or image_storage_dir, filename)
    image.save(filepath, format="PNG")
    return filepath
# Endpoints
@app.post("/new-chat", response_model=dict)
async def new_chat():
    """Start a new chat session and hand back its identifier."""
    return {"chat_id": generate_chat_id()}
@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest):
    """Refine the user's prompt via the LLM, generate an image through the
    BFL Flux API, save it locally, and return its location plus the data.

    Returns a dict with the local file path, a download URL, and the PNG
    bytes base64-encoded. (The original returned raw ``bytes`` inside the
    JSON dict, which FastAPI's JSON encoder cannot serialize for binary
    PNG data — the response always failed to encode.)

    Raises:
        HTTPException(502): if the image-generation API response does not
            contain the expected ``result.sample`` URL.
    """
    import base64

    def process_request():
        # Blocking section: LLM call, Mongo writes, two HTTP round-trips.
        chat_history = get_chat_history(request.chat_id)
        prompt = f"""
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
Image Specifications:
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
- **Camera and Lighting**:
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
- Write the prompt only in the output. Do not include anything else except the prompt. Do not write Generated Image Prompt: in the output.
Chat History:
{chat_history}
User Prompt:
{request.user_prompt}
Generated Image Prompt:
"""
        refined_prompt = llm.invoke(prompt).content.strip()
        # Persist both sides of the exchange for future history lookups.
        collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
        collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
        # Call the Flux image-generation API with the refined prompt.
        response = requests.post(
            "https://api.bfl.ml/v1/flux-pro-1.1",
            json={"prompt": refined_prompt}
        ).json()
        try:
            image_url = response["result"]["sample"]
        except (KeyError, TypeError):
            # Surface upstream failures instead of an opaque 500 KeyError.
            raise HTTPException(status_code=502,
                                detail="Image generation API returned an unexpected response")
        # Download and save the image locally.
        image_response = requests.get(image_url)
        img = Image.open(BytesIO(image_response.content))
        filename = f"generated_{uuid.uuid4()}.png"
        return save_image_locally(img, filename)

    # Run the blocking work off the event loop. The original submitted the
    # job to the executor AND blocked a second threadpool thread on
    # future.result() — one threadpool hop is sufficient.
    filepath = await run_in_threadpool(process_request)

    with open(filepath, "rb") as f:
        image_data = f.read()

    # NOTE(review): no route serving /images/... is visible in this file —
    # confirm a static mount for image_storage_dir exists elsewhere.
    file_url = f"/images/{os.path.basename(filepath)}"

    return {
        "status": "Image generated successfully",
        "file_path": filepath,
        "file_url": file_url,
        # Base64 so the payload survives JSON encoding.
        "image": base64.b64encode(image_data).decode("ascii"),
    }
@app.post("/upscale-image", response_model=dict)
async def upscale_image(image_url: str, background_tasks: BackgroundTasks):
    """Queue a 4x upscale of the image at ``image_url`` and return immediately.

    The upscaled PNG is written to the local image directory by a background
    task after the response has been sent; the caller gets no handle to the
    result in this version of the API.
    """
    def process_image():
        # Fetch the source image, upscale it, and persist the result.
        response = requests.get(image_url)
        img = Image.open(BytesIO(response.content))
        upscaled_image = upscale_model.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        return save_image_locally(upscaled_image, filename)

    # BUG FIX: the original did `task = executor.submit(process_image)` and
    # then `background_tasks.add_task(task)`. A Future is not callable, so
    # the background task raised TypeError after the response was sent, and
    # the work had already been scheduled once on the executor anyway.
    # Schedule the callable exactly once via BackgroundTasks.
    background_tasks.add_task(process_image)
    return {"status": "Processing"}
@app.get("/")
async def root():
    """Liveness probe: confirm the API is reachable."""
    payload = {"message": "API is up and running!"}
    return payload