# FastAPI service: LLM-refined image generation (BFL Flux API) + AuraSR 4x upscaling.
# Standard library
import asyncio
import logging
import os
import tempfile
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from urllib.parse import quote_plus

# Third-party
import requests
from aura_sr import AuraSR
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from langchain_groq import ChatGroq
from PIL import Image
from pydantic import BaseModel
from pymongo import MongoClient
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Fail fast at import time if required credentials are missing.
# Explicit raise instead of `assert`: assertions are stripped under `python -O`,
# which would silently let the app start without credentials.
_REQUIRED_ENV = ("MONGO_USER", "MONGO_PASSWORD", "MONGO_HOST", "LLM_API_KEY", "BFL_API_KEY")
_missing = [name for name in _REQUIRED_ENV if not os.getenv(name)]
if _missing:
    raise RuntimeError(f"Missing required environment variables: {', '.join(_missing)}")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the Hugging Face cache directory to a writable location
# (Spaces containers only allow writes under /tmp).
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
app = FastAPI()

# Middleware for CORS
# NOTE(review): wildcard origins/methods/headers is wide open — acceptable for
# a demo, but `allow_origins` should be restricted before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
# Bounded worker pool used for the heavy upscaling work (see /upscale-image).
executor = ThreadPoolExecutor(max_workers=5)
# Populated in startup(); stays None if model initialization fails.
llm = None
# Password is URL-escaped so special characters survive the connection URI.
mongo_client = MongoClient(f"mongodb+srv://{os.getenv('MONGO_USER')}:{quote_plus(os.getenv('MONGO_PASSWORD'))}@{os.getenv('MONGO_HOST')}/")
db = mongo_client["Flux"]
collection = db["chat_histories"]
# Maps chat_id -> collection handle; written by generate_chat_id(), never read here.
chat_sessions = {}

# Temporary directory for storing images, served statically under /images.
# NOTE(review): per-process tempdir — images vanish on restart and are never
# cleaned up while the process lives.
image_storage_dir = tempfile.mkdtemp()
app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")

# Initialize AuraSR during startup; stays None if initialization fails.
aura_sr = None
@app.on_event("startup")
async def startup():
    """Load the prompt-refinement LLM and the AuraSR upscaler once at boot.

    Initialization failures are logged rather than raised, so the app still
    starts; the endpoints guard against the resulting None globals.
    """
    global llm, aura_sr
    llm_settings = {
        "model": "llama-3.3-70b-versatile",
        "temperature": 0.7,
        "max_tokens": 1024,
        "api_key": os.getenv('LLM_API_KEY'),
    }
    try:
        llm = ChatGroq(**llm_settings)
        # 4x super-resolution model; weights are cached under HF_HOME.
        aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
    except Exception as e:
        logger.error(f"Error initializing models: {e}")
@app.on_event("shutdown")
def shutdown():
    """Release external resources when the server stops."""
    # Close the MongoDB connection pool first, then drain the worker pool.
    for cleanup in (mongo_client.close, executor.shutdown):
        cleanup()
# Pydantic models
class ImageRequest(BaseModel):
    """Request body for POST /generate-image.

    Every descriptive field is interpolated verbatim into the LLM
    prompt-refinement template; `chat_id` ties the request to a MongoDB
    chat history session.
    """
    subject: str              # main subject of the image
    style: str                # artistic style to emphasize
    color_theme: str          # overall color palette / mood
    elements: str             # extra key elements to include
    color_mode: str           # NOTE(review): accepted but unused by the prompt template
    lighting_conditions: str  # lighting description (depth, highlights, shadows)
    framing_style: str        # camera framing / composition
    material_details: str     # textures and materials to highlight
    text: str                 # literal text to render inside the image
    background_details: str   # background description
    user_prompt: str          # raw user prompt, stored in chat history
    chat_id: str              # session id from /new-chat
class UpscaleRequest(BaseModel):
    """Request body for POST /upscale-image."""
    # Publicly fetchable URL of the image to upscale.
    image_url: str
# Helper functions
def generate_chat_id():
    """Mint a new session id and register it in the in-memory session map."""
    session_id = str(uuid.uuid4())
    # NOTE(review): stores the shared collection handle per session — looks
    # vestigial, but kept for compatibility with the original behavior.
    chat_sessions[session_id] = collection
    return session_id
def get_chat_history(chat_id):
    """Return the session's stored messages as a newline-joined transcript."""
    transcript = []
    for msg in collection.find({"session_id": chat_id}):
        if msg['role'] == "user":
            transcript.append(f"User: {msg['content']}")
        else:
            transcript.append(f"AI: {msg['content']}")
    return "\n".join(transcript)
def save_image_locally(image, filename):
    """Write a PIL image as PNG into the served images directory; return its path."""
    destination = os.path.join(image_storage_dir, filename)
    image.save(destination, format="PNG")
    return destination
def make_request_with_retries(url, headers, payload, retries=3, delay=2):
    """POST `payload` as JSON to `url`, retrying transient network failures.

    :param url: The URL for the request.
    :param headers: Headers to include in the request.
    :param payload: JSON-serializable body.
    :param retries: Total number of attempts before giving up.
    :param delay: Seconds to sleep between attempts.
    :return: Decoded JSON response.
    :raises HTTPException: 500 once all attempts are exhausted.
    """
    for attempt in range(1, retries + 1):
        try:
            with requests.Session() as session:
                response = session.post(url, headers=headers, json=payload, timeout=30)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if attempt == retries:
                raise HTTPException(status_code=500, detail=f"Request failed after {retries} attempts: {str(e)}")
            time.sleep(delay)
def fetch_image(url):
    """Download `url` and decode it into a PIL Image.

    Raises HTTPException(400) on any download or decode failure.
    """
    try:
        with requests.Session() as session:
            resp = session.get(url, timeout=30)
            resp.raise_for_status()
            raw = resp.content
        return Image.open(BytesIO(raw))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error fetching image: {str(e)}")
def poll_for_image_result(request_id, headers):
    """Poll the BFL get_result endpoint until the image is ready.

    :param request_id: id returned by the generation request.
    :param headers: auth headers (x-key etc.) for the BFL API.
    :return: URL of the generated sample, or None if the API omits it.
    :raises HTTPException: 500 on generation error or after a 60 s timeout.
    """
    timeout = 60
    start_time = time.time()
    # Reuse one session for the whole polling loop instead of opening a new
    # connection every 0.5 s iteration (the original rebuilt it each pass).
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            time.sleep(0.5)
            result = session.get(
                "https://api.bfl.ml/v1/get_result",
                headers=headers,
                params={"id": request_id},
                timeout=10
            ).json()
            # .get() instead of ["status"]: a malformed response previously
            # raised a bare KeyError; now it just keeps polling until timeout.
            status = result.get("status")
            if status == "Ready":
                return result["result"].get("sample")
            if status == "Error":
                raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
    raise HTTPException(status_code=500, detail="Image generation timed out.")
@app.post("/new-chat", response_model=dict)
async def new_chat():
    """Start a fresh chat session and return its identifier."""
    return {"chat_id": generate_chat_id()}
@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest):
    """Refine the user's specs into a prompt via the LLM, generate an image
    through the BFL Flux API, save it locally, and log the exchange to MongoDB.

    :raises HTTPException: 500 if the LLM is unavailable, the API returns no
        request id, generation fails, or polling times out; 400 if the
        generated image cannot be fetched.
    """
    # startup() logs and swallows init errors, so guard here (mirrors the
    # aura_sr check in /upscale-image) instead of crashing on llm.invoke.
    if llm is None:
        raise HTTPException(status_code=500, detail="LLM not initialized.")
    # NOTE(review): fetched but never interpolated into the prompt below —
    # looks like history was meant to be included; confirm intent.
    chat_history = get_chat_history(request.chat_id)
    prompt = f"""
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
Image Specifications:
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
- **Camera and Lighting**:
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
"""
    refined_prompt = llm.invoke(prompt).content.strip()
    # Persist both sides of the exchange for this session.
    collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
    collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
    headers = {
        "accept": "application/json",
        "x-key": os.getenv('BFL_API_KEY'),
        "Content-Type": "application/json"
    }
    payload = {
        "prompt": refined_prompt,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 1,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    }
    response = make_request_with_retries("https://api.bfl.ml/v1/flux-pro-1.1", headers, payload)
    if "id" not in response:
        raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
    image_url = poll_for_image_result(response["id"], headers)
    image = fetch_image(image_url)
    filename = f"generated_{uuid.uuid4()}.png"
    filepath = save_image_locally(image, filename)
    return {
        "status": "Image generated successfully",
        "file_path": filepath,
        # BUG FIX: was the literal placeholder "/images/(unknown)"; point at
        # the saved file, matching the /upscale-image response shape.
        "file_url": f"/images/{filename}",
    }
@app.post("/upscale-image", response_model=dict)
async def upscale_image(request: UpscaleRequest):
    """Fetch an image from a URL and upscale it 4x with AuraSR.

    :raises HTTPException: 500 if the model never initialized, 400 if the
        source image cannot be fetched.
    """
    if aura_sr is None:
        raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
    # Download off the event loop; fetch_image raises HTTPException(400) on failure.
    img = await run_in_threadpool(fetch_image, request.image_url)

    def perform_upscaling():
        # Heavy model inference; runs on the bounded module executor so at
        # most 5 upscales are in flight at once.
        upscaled_image = aura_sr.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        return save_image_locally(upscaled_image, filename)

    # Await the executor future directly. The original parked a second
    # threadpool thread on `future.result()` just to wait; wrap_future keeps
    # the 5-worker bound without burning an extra thread.
    filepath = await asyncio.wrap_future(executor.submit(perform_upscaling))
    return {
        "status": "Upscaling successful",
        "file_path": filepath,
        "file_url": f"/images/{os.path.basename(filepath)}",
    }
@app.get("/")
async def root():
    """Liveness probe: confirm the API is reachable."""
    status_message = "API is up and running!"
    return {"message": status_message}