Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -5,7 +5,7 @@ from fastapi.concurrency import run_in_threadpool
|
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
import uuid
|
7 |
import time
|
8 |
-
|
9 |
from concurrent.futures import ThreadPoolExecutor
|
10 |
from pymongo import MongoClient
|
11 |
from urllib.parse import quote_plus
|
@@ -15,7 +15,10 @@ from io import BytesIO
|
|
15 |
from PIL import Image
|
16 |
import requests
|
17 |
import os
|
18 |
-
import
|
|
|
|
|
|
|
19 |
|
20 |
# Set the Hugging Face cache directory to a writable location
|
21 |
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
|
@@ -31,22 +34,19 @@ app.add_middleware(
|
|
31 |
)
|
32 |
|
33 |
# Globals
|
34 |
-
executor = ThreadPoolExecutor(max_workers=
|
35 |
llm = None
|
36 |
-
|
37 |
-
db =
|
38 |
collection = db["chat_histories"]
|
39 |
chat_sessions = {}
|
40 |
|
41 |
-
#
|
42 |
image_storage_dir = tempfile.mkdtemp()
|
43 |
-
print(f"Temporary directory for images: {image_storage_dir}")
|
44 |
-
|
45 |
-
# Serve the temporary image directory as a static file directory
|
46 |
app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
|
51 |
@app.on_event("startup")
|
52 |
async def startup():
|
@@ -55,19 +55,16 @@ async def startup():
|
|
55 |
model="llama-3.3-70b-versatile",
|
56 |
temperature=0.7,
|
57 |
max_tokens=1024,
|
58 |
-
api_key=
|
59 |
)
|
60 |
-
# Ensure AuraSR is initialized properly
|
61 |
try:
|
62 |
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
|
63 |
except Exception as e:
|
64 |
print(f"Error initializing AuraSR: {e}")
|
65 |
-
aura_sr = None
|
66 |
-
|
67 |
|
68 |
@app.on_event("shutdown")
|
69 |
def shutdown():
|
70 |
-
|
71 |
executor.shutdown()
|
72 |
|
73 |
# Pydantic models
|
@@ -85,7 +82,10 @@ class ImageRequest(BaseModel):
|
|
85 |
user_prompt: str
|
86 |
chat_id: str
|
87 |
|
88 |
-
|
|
|
|
|
|
|
89 |
def generate_chat_id():
|
90 |
chat_id = str(uuid.uuid4())
|
91 |
chat_sessions[chat_id] = collection
|
@@ -103,6 +103,34 @@ def save_image_locally(image, filename):
|
|
103 |
image.save(filepath, format="PNG")
|
104 |
return filepath
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
# Endpoints
|
107 |
@app.post("/new-chat", response_model=dict)
|
108 |
async def new_chat():
|
@@ -111,148 +139,81 @@ async def new_chat():
|
|
111 |
|
112 |
@app.post("/generate-image", response_model=dict)
|
113 |
async def generate_image(request: ImageRequest):
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
"accept": "application/json",
|
142 |
-
"x-key": "4f69d408-4979-4812-8ad2-ec6d232c9ddf",
|
143 |
-
"Content-Type": "application/json"
|
144 |
-
}
|
145 |
-
payload = {
|
146 |
-
"prompt": refined_prompt,
|
147 |
-
"width": 1024,
|
148 |
-
"height": 1024,
|
149 |
-
"guidance_scale": 1,
|
150 |
-
"num_inference_steps": 50,
|
151 |
-
"max_sequence_length": 512,
|
152 |
-
}
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
-
|
|
|
159 |
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
"https://api.bfl.ml/v1/get_result",
|
165 |
-
headers=headers,
|
166 |
-
params={"id": request_id},
|
167 |
-
).json()
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
# Download the image
|
174 |
-
image_response = requests.get(image_url)
|
175 |
-
if image_response.status_code == 200:
|
176 |
-
img = Image.open(BytesIO(image_response.content))
|
177 |
-
filename = f"generated_{uuid.uuid4()}.png"
|
178 |
-
filepath = save_image_locally(img, filename)
|
179 |
-
|
180 |
-
# Generate a URL for the image
|
181 |
-
file_url = f"/images/{filename}"
|
182 |
-
return filepath, file_url
|
183 |
-
else:
|
184 |
-
raise HTTPException(status_code=500, detail="Failed to download the image")
|
185 |
-
else:
|
186 |
-
raise HTTPException(status_code=500, detail="Expected 'sample' key not found in the result")
|
187 |
-
elif result["status"] == "Error":
|
188 |
-
raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
|
189 |
-
|
190 |
-
future = executor.submit(process_request)
|
191 |
-
filepath, file_url = await run_in_threadpool(future.result)
|
192 |
|
193 |
return {
|
194 |
"status": "Image generated successfully",
|
195 |
"file_path": filepath,
|
196 |
-
"file_url":
|
197 |
}
|
198 |
|
199 |
-
class UpscaleRequest(BaseModel):
|
200 |
-
image_url: str
|
201 |
-
|
202 |
-
|
203 |
-
def process_image(image_url):
|
204 |
-
if aura_sr is None:
|
205 |
-
raise RuntimeError("aura_sr is None. Ensure it's initialized during startup.")
|
206 |
-
|
207 |
-
response = requests.get(image_url)
|
208 |
-
img = Image.open(BytesIO(response.content))
|
209 |
-
|
210 |
-
try:
|
211 |
-
upscaled_image = aura_sr.upscale_4x_overlapped(img)
|
212 |
-
except Exception as e:
|
213 |
-
raise RuntimeError(f"Error during upscaling: {str(e)}")
|
214 |
-
|
215 |
-
filename = f"upscaled_{uuid.uuid4()}.png"
|
216 |
-
filepath = save_image_locally(upscaled_image, filename)
|
217 |
-
return filepath
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
@app.post("/upscale-image", response_model=dict)
|
222 |
async def upscale_image(request: UpscaleRequest):
|
223 |
if aura_sr is None:
|
224 |
raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
|
225 |
|
226 |
try:
|
227 |
-
|
228 |
-
response = requests.get(request.image_url)
|
229 |
-
if response.status_code != 200:
|
230 |
-
raise HTTPException(status_code=400, detail="Failed to fetch the image from the provided URL.")
|
231 |
-
|
232 |
-
img = Image.open(BytesIO(response.content))
|
233 |
-
|
234 |
-
# Perform upscaling
|
235 |
upscaled_image = aura_sr.upscale_4x_overlapped(img)
|
236 |
|
237 |
-
# Save the upscaled image locally
|
238 |
filename = f"upscaled_{uuid.uuid4()}.png"
|
239 |
filepath = save_image_locally(upscaled_image, filename)
|
240 |
|
241 |
-
# Generate a public URL for the upscaled image
|
242 |
-
file_url = f"/images/{filename}"
|
243 |
-
|
244 |
return {
|
245 |
"status": "Upscaling successful",
|
246 |
"file_path": filepath,
|
247 |
-
"file_url":
|
248 |
}
|
249 |
except Exception as e:
|
250 |
raise HTTPException(status_code=500, detail=f"Error during upscaling: {str(e)}")
|
251 |
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
@app.get("/")
|
257 |
async def root():
|
258 |
-
return {"message": "API is up and running!"}
|
|
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
import uuid
|
7 |
import time
|
8 |
+
tempfile
|
9 |
from concurrent.futures import ThreadPoolExecutor
|
10 |
from pymongo import MongoClient
|
11 |
from urllib.parse import quote_plus
|
|
|
15 |
from PIL import Image
|
16 |
import requests
|
17 |
import os
|
18 |
+
from dotenv import load_dotenv
|
19 |
+
|
20 |
+
# Load environment variables
|
21 |
+
load_dotenv()
|
22 |
|
23 |
# Set the Hugging Face cache directory to a writable location
|
24 |
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
|
|
|
34 |
)
|
35 |
|
36 |
# Globals
|
37 |
+
executor = ThreadPoolExecutor(max_workers=5)
|
38 |
llm = None
|
39 |
+
mongo_client = MongoClient(f"mongodb+srv://{os.getenv('MONGO_USER')}:{quote_plus(os.getenv('MONGO_PASSWORD'))}@{os.getenv('MONGO_HOST')}/")
|
40 |
+
db = mongo_client["Flux"]
|
41 |
collection = db["chat_histories"]
|
42 |
chat_sessions = {}
|
43 |
|
44 |
+
# Temporary directory for storing images
|
45 |
image_storage_dir = tempfile.mkdtemp()
|
|
|
|
|
|
|
46 |
app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")
|
47 |
|
48 |
+
# Initialize AuraSR during startup
|
49 |
+
aura_sr = None
|
50 |
|
51 |
@app.on_event("startup")
|
52 |
async def startup():
|
|
|
55 |
model="llama-3.3-70b-versatile",
|
56 |
temperature=0.7,
|
57 |
max_tokens=1024,
|
58 |
+
api_key=os.getenv('LLM_API_KEY'),
|
59 |
)
|
|
|
60 |
try:
|
61 |
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
|
62 |
except Exception as e:
|
63 |
print(f"Error initializing AuraSR: {e}")
|
|
|
|
|
64 |
|
65 |
@app.on_event("shutdown")
|
66 |
def shutdown():
|
67 |
+
mongo_client.close()
|
68 |
executor.shutdown()
|
69 |
|
70 |
# Pydantic models
|
|
|
82 |
user_prompt: str
|
83 |
chat_id: str
|
84 |
|
85 |
+
class UpscaleRequest(BaseModel):
|
86 |
+
image_url: str
|
87 |
+
|
88 |
+
# Helper functions
|
89 |
def generate_chat_id():
|
90 |
chat_id = str(uuid.uuid4())
|
91 |
chat_sessions[chat_id] = collection
|
|
|
103 |
image.save(filepath, format="PNG")
|
104 |
return filepath
|
105 |
|
106 |
+
def fetch_image(url):
|
107 |
+
with requests.Session() as session:
|
108 |
+
response = session.get(url, timeout=10)
|
109 |
+
if response.status_code != 200:
|
110 |
+
raise HTTPException(status_code=400, detail="Failed to fetch image.")
|
111 |
+
return Image.open(BytesIO(response.content))
|
112 |
+
|
113 |
+
def poll_for_image_result(request_id, headers):
|
114 |
+
timeout = 60
|
115 |
+
start_time = time.time()
|
116 |
+
|
117 |
+
while time.time() - start_time < timeout:
|
118 |
+
time.sleep(0.5)
|
119 |
+
with requests.Session() as session:
|
120 |
+
result = session.get(
|
121 |
+
"https://api.bfl.ml/v1/get_result",
|
122 |
+
headers=headers,
|
123 |
+
params={"id": request_id},
|
124 |
+
timeout=10
|
125 |
+
).json()
|
126 |
+
|
127 |
+
if result["status"] == "Ready":
|
128 |
+
return result["result"].get("sample")
|
129 |
+
elif result["status"] == "Error":
|
130 |
+
raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
|
131 |
+
|
132 |
+
raise HTTPException(status_code=500, detail="Image generation timed out.")
|
133 |
+
|
134 |
# Endpoints
|
135 |
@app.post("/new-chat", response_model=dict)
|
136 |
async def new_chat():
|
|
|
139 |
|
140 |
@app.post("/generate-image", response_model=dict)
|
141 |
async def generate_image(request: ImageRequest):
|
142 |
+
chat_history = get_chat_history(request.chat_id)
|
143 |
+
prompt = f"""
|
144 |
+
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
|
145 |
+
|
146 |
+
Image Specifications:
|
147 |
+
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
|
148 |
+
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
|
149 |
+
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
|
150 |
+
- **Camera and Lighting**:
|
151 |
+
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
|
152 |
+
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
|
153 |
+
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
|
154 |
+
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
|
155 |
+
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
|
156 |
+
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
|
157 |
+
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
|
158 |
+
"""
|
159 |
+
|
160 |
+
refined_prompt = llm.invoke(prompt).content.strip()
|
161 |
+
collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
|
162 |
+
collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
|
163 |
+
|
164 |
+
headers = {
|
165 |
+
"accept": "application/json",
|
166 |
+
"x-key": os.getenv('BFL_API_KEY'),
|
167 |
+
"Content-Type": "application/json"
|
168 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
payload = {
|
171 |
+
"prompt": refined_prompt,
|
172 |
+
"width": 1024,
|
173 |
+
"height": 1024,
|
174 |
+
"guidance_scale": 1,
|
175 |
+
"num_inference_steps": 50,
|
176 |
+
"max_sequence_length": 512,
|
177 |
+
}
|
178 |
|
179 |
+
with requests.Session() as session:
|
180 |
+
response = session.post("https://api.bfl.ml/v1/flux-pro-1.1", headers=headers, json=payload, timeout=10).json()
|
181 |
|
182 |
+
if "id" not in response:
|
183 |
+
raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
|
184 |
+
|
185 |
+
image_url = poll_for_image_result(response["id"], headers)
|
|
|
|
|
|
|
|
|
186 |
|
187 |
+
image = fetch_image(image_url)
|
188 |
+
filename = f"generated_{uuid.uuid4()}.png"
|
189 |
+
filepath = save_image_locally(image, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
return {
|
192 |
"status": "Image generated successfully",
|
193 |
"file_path": filepath,
|
194 |
+
"file_url": f"/images/{filename}",
|
195 |
}
|
196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
@app.post("/upscale-image", response_model=dict)
|
198 |
async def upscale_image(request: UpscaleRequest):
|
199 |
if aura_sr is None:
|
200 |
raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
|
201 |
|
202 |
try:
|
203 |
+
img = fetch_image(request.image_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
upscaled_image = aura_sr.upscale_4x_overlapped(img)
|
205 |
|
|
|
206 |
filename = f"upscaled_{uuid.uuid4()}.png"
|
207 |
filepath = save_image_locally(upscaled_image, filename)
|
208 |
|
|
|
|
|
|
|
209 |
return {
|
210 |
"status": "Upscaling successful",
|
211 |
"file_path": filepath,
|
212 |
+
"file_url": f"/images/{filename}",
|
213 |
}
|
214 |
except Exception as e:
|
215 |
raise HTTPException(status_code=500, detail=f"Error during upscaling: {str(e)}")
|
216 |
|
|
|
|
|
|
|
|
|
217 |
@app.get("/")
|
218 |
async def root():
|
219 |
+
return {"message": "API is up and running!"}
|