Update main.py
main.py (CHANGED)
@@ -15,11 +15,21 @@ from io import BytesIO
 from PIL import Image
 import requests
 import os
+import logging
 from dotenv import load_dotenv
 
 # Load environment variables
 load_dotenv()
 
+# Validate environment variables
+assert os.getenv('MONGO_USER') and os.getenv('MONGO_PASSWORD') and os.getenv('MONGO_HOST'), "MongoDB credentials missing!"
+assert os.getenv('LLM_API_KEY'), "LLM API Key missing!"
+assert os.getenv('BFL_API_KEY'), "BFL API Key missing!"
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Set the Hugging Face cache directory to a writable location
 os.environ['HF_HOME'] = '/tmp/huggingface_cache'
 
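A note on the assert-based checks in the hunk above: Python drops assert statements when run with the -O flag, so an explicit check is the more robust pattern. A minimal alternative sketch, not part of this commit:

import os

# Alternative to the asserts above; still raises under `python -O`.
required = ['MONGO_USER', 'MONGO_PASSWORD', 'MONGO_HOST', 'LLM_API_KEY', 'BFL_API_KEY']
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")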
@@ -51,16 +61,16 @@ aura_sr = None
 @app.on_event("startup")
 async def startup():
     global llm, aura_sr
-    llm = ChatGroq(
-        model="llama-3.3-70b-versatile",
-        temperature=0.7,
-        max_tokens=1024,
-        api_key=os.getenv('LLM_API_KEY'),
-    )
     try:
+        llm = ChatGroq(
+            model="llama-3.3-70b-versatile",
+            temperature=0.7,
+            max_tokens=1024,
+            api_key=os.getenv('LLM_API_KEY'),
+        )
         aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
     except Exception as e:
-
+        logger.error(f"Error initializing models: {e}")
 
 @app.on_event("shutdown")
 def shutdown():
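Since the except block above only logs the failure, llm (and aura_sr) can remain None after startup. The upscale endpoint further down already guards aura_sr; any endpoint that calls the LLM would need a similar check. A hypothetical sketch (no such handler is shown in this diff):

# Hypothetical guard inside an endpoint that uses the LLM,
# mirroring the aura_sr None-check used by /upscale below.
if llm is None:
    raise HTTPException(status_code=500, detail="LLM not initialized.")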
@@ -104,11 +114,13 @@ def save_image_locally(image, filename):
     return filepath
 
 def fetch_image(url):
-
-
-
-
-
+    try:
+        with requests.Session() as session:
+            response = session.get(url, timeout=10)
+            response.raise_for_status()
+            return Image.open(BytesIO(response.content))
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Error fetching image: {str(e)}")
 
 def poll_for_image_result(request_id, headers):
     timeout = 60
@@ -131,7 +143,6 @@ def poll_for_image_result(request_id, headers):
 
     raise HTTPException(status_code=500, detail="Image generation timed out.")
 
-# Endpoints
 @app.post("/new-chat", response_model=dict)
 async def new_chat():
     chat_id = generate_chat_id()
@@ -176,14 +187,12 @@ async def generate_image(request: ImageRequest):
         "max_sequence_length": 512,
     }
 
-
-    response = session.post("https://api.bfl.ml/v1/flux-pro-1.1", headers=headers, json=payload, timeout=10).json()
+    response = make_request_with_retries("https://api.bfl.ml/v1/flux-pro-1.1", headers, payload)
 
     if "id" not in response:
         raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
 
     image_url = poll_for_image_result(response["id"], headers)
-
     image = fetch_image(image_url)
     filename = f"generated_{uuid.uuid4()}.png"
     filepath = save_image_locally(image, filename)
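The request above now goes through make_request_with_retries, a helper whose definition is not included in these hunks. A minimal sketch of what such a helper might look like, assuming it simply wraps the session.post call it replaces with a few retries (names and defaults are assumptions, not the actual implementation):

import time
import requests
from fastapi import HTTPException

def make_request_with_retries(url, headers, payload, retries=3, timeout=10):
    # Hypothetical sketch: retry transient failures with exponential backoff,
    # mirroring the session.post call this commit replaces.
    last_error = None
    for attempt in range(retries):
        try:
            with requests.Session() as session:
                response = session.post(url, headers=headers, json=payload, timeout=timeout)
                response.raise_for_status()
                return response.json()
        except requests.RequestException as e:
            last_error = e
            time.sleep(2 ** attempt)  # back off 1s, 2s, 4s
    raise HTTPException(status_code=502, detail=f"Request failed after {retries} attempts: {last_error}")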
@@ -199,20 +208,21 @@ async def upscale_image(request: UpscaleRequest):
     if aura_sr is None:
         raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
 
-
-    img = fetch_image(request.image_url)
-    upscaled_image = aura_sr.upscale_4x_overlapped(img)
+    img = await run_in_threadpool(fetch_image, request.image_url)
 
+    def perform_upscaling():
+        upscaled_image = aura_sr.upscale_4x_overlapped(img)
         filename = f"upscaled_{uuid.uuid4()}.png"
-
+        return save_image_locally(upscaled_image, filename)
 
-
-
-
-
-
-
-
+    future = executor.submit(perform_upscaling)
+    filepath = await run_in_threadpool(lambda: future.result())
+
+    return {
+        "status": "Upscaling successful",
+        "file_path": filepath,
+        "file_url": f"/images/{os.path.basename(filepath)}",
+    }
 
 @app.get("/")
 async def root():
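The upscale endpoint above relies on run_in_threadpool and a module-level executor, neither of which appears in the hunks shown. A sketch of the setup it presumably depends on (assumed, not confirmed by this diff):

from concurrent.futures import ThreadPoolExecutor
from fastapi.concurrency import run_in_threadpool

# Assumed module-level pool for the blocking AuraSR upscaling work;
# the real definition lives outside the hunks shown in this diff.
executor = ThreadPoolExecutor(max_workers=2)

As a design note, executor.submit(perform_upscaling) followed by await run_in_threadpool(lambda: future.result()) runs the upscaling in the dedicated pool but still occupies a default-pool thread while waiting on the future; awaiting run_in_threadpool(perform_upscaling) directly, or asyncio's loop.run_in_executor(executor, perform_upscaling), would give the same non-blocking behavior with one fewer thread.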