Dash-inc commited on
Commit
a8df78e
·
verified ·
1 Parent(s): bab31ce

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -27
main.py CHANGED
@@ -15,11 +15,21 @@ from io import BytesIO
15
  from PIL import Image
16
  import requests
17
  import os
 
18
  from dotenv import load_dotenv
19
 
20
  # Load environment variables
21
  load_dotenv()
22
 
 
 
 
 
 
 
 
 
 
23
  # Set the Hugging Face cache directory to a writable location
24
  os.environ['HF_HOME'] = '/tmp/huggingface_cache'
25
 
@@ -51,16 +61,16 @@ aura_sr = None
51
  @app.on_event("startup")
52
  async def startup():
53
  global llm, aura_sr
54
- llm = ChatGroq(
55
- model="llama-3.3-70b-versatile",
56
- temperature=0.7,
57
- max_tokens=1024,
58
- api_key=os.getenv('LLM_API_KEY'),
59
- )
60
  try:
 
 
 
 
 
 
61
  aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
62
  except Exception as e:
63
- print(f"Error initializing AuraSR: {e}")
64
 
65
  @app.on_event("shutdown")
66
  def shutdown():
@@ -104,11 +114,13 @@ def save_image_locally(image, filename):
104
  return filepath
105
 
106
  def fetch_image(url):
107
- with requests.Session() as session:
108
- response = session.get(url, timeout=10)
109
- if response.status_code != 200:
110
- raise HTTPException(status_code=400, detail="Failed to fetch image.")
111
- return Image.open(BytesIO(response.content))
 
 
112
 
113
  def poll_for_image_result(request_id, headers):
114
  timeout = 60
@@ -131,7 +143,6 @@ def poll_for_image_result(request_id, headers):
131
 
132
  raise HTTPException(status_code=500, detail="Image generation timed out.")
133
 
134
- # Endpoints
135
  @app.post("/new-chat", response_model=dict)
136
  async def new_chat():
137
  chat_id = generate_chat_id()
@@ -176,14 +187,12 @@ async def generate_image(request: ImageRequest):
176
  "max_sequence_length": 512,
177
  }
178
 
179
- with requests.Session() as session:
180
- response = session.post("https://api.bfl.ml/v1/flux-pro-1.1", headers=headers, json=payload, timeout=10).json()
181
 
182
  if "id" not in response:
183
  raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
184
 
185
  image_url = poll_for_image_result(response["id"], headers)
186
-
187
  image = fetch_image(image_url)
188
  filename = f"generated_{uuid.uuid4()}.png"
189
  filepath = save_image_locally(image, filename)
@@ -199,20 +208,21 @@ async def upscale_image(request: UpscaleRequest):
199
  if aura_sr is None:
200
  raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
201
 
202
- try:
203
- img = fetch_image(request.image_url)
204
- upscaled_image = aura_sr.upscale_4x_overlapped(img)
205
 
 
 
206
  filename = f"upscaled_{uuid.uuid4()}.png"
207
- filepath = save_image_locally(upscaled_image, filename)
208
 
209
- return {
210
- "status": "Upscaling successful",
211
- "file_path": filepath,
212
- "file_url": f"/images/{filename}",
213
- }
214
- except Exception as e:
215
- raise HTTPException(status_code=500, detail=f"Error during upscaling: {str(e)}")
 
216
 
217
  @app.get("/")
218
  async def root():
 
15
  from PIL import Image
16
  import requests
17
  import os
18
+ import logging
19
  from dotenv import load_dotenv
20
 
21
  # Load environment variables
22
  load_dotenv()
23
 
24
+ # Validate environment variables
25
+ assert os.getenv('MONGO_USER') and os.getenv('MONGO_PASSWORD') and os.getenv('MONGO_HOST'), "MongoDB credentials missing!"
26
+ assert os.getenv('LLM_API_KEY'), "LLM API Key missing!"
27
+ assert os.getenv('BFL_API_KEY'), "BFL API Key missing!"
28
+
29
+ # Configure logging
30
+ logging.basicConfig(level=logging.INFO)
31
+ logger = logging.getLogger(__name__)
32
+
33
  # Set the Hugging Face cache directory to a writable location
34
  os.environ['HF_HOME'] = '/tmp/huggingface_cache'
35
 
 
61
  @app.on_event("startup")
62
  async def startup():
63
  global llm, aura_sr
 
 
 
 
 
 
64
  try:
65
+ llm = ChatGroq(
66
+ model="llama-3.3-70b-versatile",
67
+ temperature=0.7,
68
+ max_tokens=1024,
69
+ api_key=os.getenv('LLM_API_KEY'),
70
+ )
71
  aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
72
  except Exception as e:
73
+ logger.error(f"Error initializing models: {e}")
74
 
75
  @app.on_event("shutdown")
76
  def shutdown():
 
114
  return filepath
115
 
116
  def fetch_image(url):
117
+ try:
118
+ with requests.Session() as session:
119
+ response = session.get(url, timeout=10)
120
+ response.raise_for_status()
121
+ return Image.open(BytesIO(response.content))
122
+ except Exception as e:
123
+ raise HTTPException(status_code=400, detail=f"Error fetching image: {str(e)}")
124
 
125
  def poll_for_image_result(request_id, headers):
126
  timeout = 60
 
143
 
144
  raise HTTPException(status_code=500, detail="Image generation timed out.")
145
 
 
146
  @app.post("/new-chat", response_model=dict)
147
  async def new_chat():
148
  chat_id = generate_chat_id()
 
187
  "max_sequence_length": 512,
188
  }
189
 
190
+ response = make_request_with_retries("https://api.bfl.ml/v1/flux-pro-1.1", headers, payload)
 
191
 
192
  if "id" not in response:
193
  raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
194
 
195
  image_url = poll_for_image_result(response["id"], headers)
 
196
  image = fetch_image(image_url)
197
  filename = f"generated_{uuid.uuid4()}.png"
198
  filepath = save_image_locally(image, filename)
 
208
  if aura_sr is None:
209
  raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
210
 
211
+ img = await run_in_threadpool(fetch_image, request.image_url)
 
 
212
 
213
+ def perform_upscaling():
214
+ upscaled_image = aura_sr.upscale_4x_overlapped(img)
215
  filename = f"upscaled_{uuid.uuid4()}.png"
216
+ return save_image_locally(upscaled_image, filename)
217
 
218
+ future = executor.submit(perform_upscaling)
219
+ filepath = await run_in_threadpool(lambda: future.result())
220
+
221
+ return {
222
+ "status": "Upscaling successful",
223
+ "file_path": filepath,
224
+ "file_url": f"/images/{os.path.basename(filepath)}",
225
+ }
226
 
227
  @app.get("/")
228
  async def root():