Dash-inc commited on
Commit
897b11c
·
verified ·
1 Parent(s): 8891183

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +93 -132
main.py CHANGED
@@ -5,7 +5,7 @@ from fastapi.concurrency import run_in_threadpool
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import uuid
7
  import time
8
- import tempfile
9
  from concurrent.futures import ThreadPoolExecutor
10
  from pymongo import MongoClient
11
  from urllib.parse import quote_plus
@@ -15,7 +15,10 @@ from io import BytesIO
15
  from PIL import Image
16
  import requests
17
  import os
18
- import os
 
 
 
19
 
20
  # Set the Hugging Face cache directory to a writable location
21
  os.environ['HF_HOME'] = '/tmp/huggingface_cache'
@@ -31,22 +34,19 @@ app.add_middleware(
31
  )
32
 
33
  # Globals
34
- executor = ThreadPoolExecutor(max_workers=10)
35
  llm = None
36
- client = MongoClient(f"mongodb+srv://hammad:{quote_plus('momimaad@123')}@cluster0.2a9yu.mongodb.net/")
37
- db = client["Flux"]
38
  collection = db["chat_histories"]
39
  chat_sessions = {}
40
 
41
- # Use a temporary directory for storing images
42
  image_storage_dir = tempfile.mkdtemp()
43
- print(f"Temporary directory for images: {image_storage_dir}")
44
-
45
- # Serve the temporary image directory as a static file directory
46
  app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")
47
 
48
- # Dictionary to store images temporarily
49
- image_store = {}
50
 
51
  @app.on_event("startup")
52
  async def startup():
@@ -55,19 +55,16 @@ async def startup():
55
  model="llama-3.3-70b-versatile",
56
  temperature=0.7,
57
  max_tokens=1024,
58
- api_key="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",
59
  )
60
- # Ensure AuraSR is initialized properly
61
  try:
62
  aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
63
  except Exception as e:
64
  print(f"Error initializing AuraSR: {e}")
65
- aura_sr = None
66
-
67
 
68
  @app.on_event("shutdown")
69
  def shutdown():
70
- client.close()
71
  executor.shutdown()
72
 
73
  # Pydantic models
@@ -85,7 +82,10 @@ class ImageRequest(BaseModel):
85
  user_prompt: str
86
  chat_id: str
87
 
88
- # Helper Functions
 
 
 
89
  def generate_chat_id():
90
  chat_id = str(uuid.uuid4())
91
  chat_sessions[chat_id] = collection
@@ -103,6 +103,34 @@ def save_image_locally(image, filename):
103
  image.save(filepath, format="PNG")
104
  return filepath
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # Endpoints
107
  @app.post("/new-chat", response_model=dict)
108
  async def new_chat():
@@ -111,148 +139,81 @@ async def new_chat():
111
 
112
  @app.post("/generate-image", response_model=dict)
113
  async def generate_image(request: ImageRequest):
114
- def process_request():
115
- chat_history = get_chat_history(request.chat_id)
116
- prompt = f"""
117
- You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
118
-
119
- Image Specifications:
120
- - **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
121
- - **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
122
- - **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
123
- - **Camera and Lighting**:
124
- - Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
125
- - **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
126
- - **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
127
- - **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
128
- - **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
129
- - Negative Prompt: Avoid grainy, blurry, or deformed outputs.
130
- - **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
131
- """
132
- refined_prompt = llm.invoke(prompt).content.strip()
133
-
134
- # Save the prompt in MongoDB
135
- collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
136
- collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
137
-
138
- # API call to image generation service
139
- url = "https://api.bfl.ml/v1/flux-pro-1.1"
140
- headers = {
141
- "accept": "application/json",
142
- "x-key": "4f69d408-4979-4812-8ad2-ec6d232c9ddf",
143
- "Content-Type": "application/json"
144
- }
145
- payload = {
146
- "prompt": refined_prompt,
147
- "width": 1024,
148
- "height": 1024,
149
- "guidance_scale": 1,
150
- "num_inference_steps": 50,
151
- "max_sequence_length": 512,
152
- }
153
 
154
- response = requests.post(url, headers=headers, json=payload).json()
155
- if "id" not in response:
156
- raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
 
 
 
 
 
157
 
158
- request_id = response["id"]
 
159
 
160
- # Poll for the image result
161
- while True:
162
- time.sleep(0.5)
163
- result = requests.get(
164
- "https://api.bfl.ml/v1/get_result",
165
- headers=headers,
166
- params={"id": request_id},
167
- ).json()
168
 
169
- if result["status"] == "Ready":
170
- if "result" in result and "sample" in result["result"]:
171
- image_url = result["result"]["sample"]
172
-
173
- # Download the image
174
- image_response = requests.get(image_url)
175
- if image_response.status_code == 200:
176
- img = Image.open(BytesIO(image_response.content))
177
- filename = f"generated_{uuid.uuid4()}.png"
178
- filepath = save_image_locally(img, filename)
179
-
180
- # Generate a URL for the image
181
- file_url = f"/images/{filename}"
182
- return filepath, file_url
183
- else:
184
- raise HTTPException(status_code=500, detail="Failed to download the image")
185
- else:
186
- raise HTTPException(status_code=500, detail="Expected 'sample' key not found in the result")
187
- elif result["status"] == "Error":
188
- raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
189
-
190
- future = executor.submit(process_request)
191
- filepath, file_url = await run_in_threadpool(future.result)
192
 
193
  return {
194
  "status": "Image generated successfully",
195
  "file_path": filepath,
196
- "file_url": file_url,
197
  }
198
 
199
- class UpscaleRequest(BaseModel):
200
- image_url: str
201
-
202
-
203
- def process_image(image_url):
204
- if aura_sr is None:
205
- raise RuntimeError("aura_sr is None. Ensure it's initialized during startup.")
206
-
207
- response = requests.get(image_url)
208
- img = Image.open(BytesIO(response.content))
209
-
210
- try:
211
- upscaled_image = aura_sr.upscale_4x_overlapped(img)
212
- except Exception as e:
213
- raise RuntimeError(f"Error during upscaling: {str(e)}")
214
-
215
- filename = f"upscaled_{uuid.uuid4()}.png"
216
- filepath = save_image_locally(upscaled_image, filename)
217
- return filepath
218
-
219
-
220
-
221
  @app.post("/upscale-image", response_model=dict)
222
  async def upscale_image(request: UpscaleRequest):
223
  if aura_sr is None:
224
  raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
225
 
226
  try:
227
- # Fetch the image from the provided URL
228
- response = requests.get(request.image_url)
229
- if response.status_code != 200:
230
- raise HTTPException(status_code=400, detail="Failed to fetch the image from the provided URL.")
231
-
232
- img = Image.open(BytesIO(response.content))
233
-
234
- # Perform upscaling
235
  upscaled_image = aura_sr.upscale_4x_overlapped(img)
236
 
237
- # Save the upscaled image locally
238
  filename = f"upscaled_{uuid.uuid4()}.png"
239
  filepath = save_image_locally(upscaled_image, filename)
240
 
241
- # Generate a public URL for the upscaled image
242
- file_url = f"/images/{filename}"
243
-
244
  return {
245
  "status": "Upscaling successful",
246
  "file_path": filepath,
247
- "file_url": file_url,
248
  }
249
  except Exception as e:
250
  raise HTTPException(status_code=500, detail=f"Error during upscaling: {str(e)}")
251
 
252
-
253
-
254
-
255
-
256
  @app.get("/")
257
  async def root():
258
- return {"message": "API is up and running!"}
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import uuid
7
  import time
8
+ tempfile
9
  from concurrent.futures import ThreadPoolExecutor
10
  from pymongo import MongoClient
11
  from urllib.parse import quote_plus
 
15
  from PIL import Image
16
  import requests
17
  import os
18
+ from dotenv import load_dotenv
19
+
20
+ # Load environment variables
21
+ load_dotenv()
22
 
23
  # Set the Hugging Face cache directory to a writable location
24
  os.environ['HF_HOME'] = '/tmp/huggingface_cache'
 
34
  )
35
 
36
  # Globals
37
+ executor = ThreadPoolExecutor(max_workers=5)
38
  llm = None
39
+ mongo_client = MongoClient(f"mongodb+srv://{os.getenv('MONGO_USER')}:{quote_plus(os.getenv('MONGO_PASSWORD'))}@{os.getenv('MONGO_HOST')}/")
40
+ db = mongo_client["Flux"]
41
  collection = db["chat_histories"]
42
  chat_sessions = {}
43
 
44
+ # Temporary directory for storing images
45
  image_storage_dir = tempfile.mkdtemp()
 
 
 
46
  app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")
47
 
48
+ # Initialize AuraSR during startup
49
+ aura_sr = None
50
 
51
  @app.on_event("startup")
52
  async def startup():
 
55
  model="llama-3.3-70b-versatile",
56
  temperature=0.7,
57
  max_tokens=1024,
58
+ api_key=os.getenv('LLM_API_KEY'),
59
  )
 
60
  try:
61
  aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
62
  except Exception as e:
63
  print(f"Error initializing AuraSR: {e}")
 
 
64
 
65
  @app.on_event("shutdown")
66
  def shutdown():
67
+ mongo_client.close()
68
  executor.shutdown()
69
 
70
  # Pydantic models
 
82
  user_prompt: str
83
  chat_id: str
84
 
85
+ class UpscaleRequest(BaseModel):
86
+ image_url: str
87
+
88
+ # Helper functions
89
  def generate_chat_id():
90
  chat_id = str(uuid.uuid4())
91
  chat_sessions[chat_id] = collection
 
103
  image.save(filepath, format="PNG")
104
  return filepath
105
 
106
+ def fetch_image(url):
107
+ with requests.Session() as session:
108
+ response = session.get(url, timeout=10)
109
+ if response.status_code != 200:
110
+ raise HTTPException(status_code=400, detail="Failed to fetch image.")
111
+ return Image.open(BytesIO(response.content))
112
+
113
+ def poll_for_image_result(request_id, headers):
114
+ timeout = 60
115
+ start_time = time.time()
116
+
117
+ while time.time() - start_time < timeout:
118
+ time.sleep(0.5)
119
+ with requests.Session() as session:
120
+ result = session.get(
121
+ "https://api.bfl.ml/v1/get_result",
122
+ headers=headers,
123
+ params={"id": request_id},
124
+ timeout=10
125
+ ).json()
126
+
127
+ if result["status"] == "Ready":
128
+ return result["result"].get("sample")
129
+ elif result["status"] == "Error":
130
+ raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
131
+
132
+ raise HTTPException(status_code=500, detail="Image generation timed out.")
133
+
134
  # Endpoints
135
  @app.post("/new-chat", response_model=dict)
136
  async def new_chat():
 
139
 
140
  @app.post("/generate-image", response_model=dict)
141
  async def generate_image(request: ImageRequest):
142
+ chat_history = get_chat_history(request.chat_id)
143
+ prompt = f"""
144
+ You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
145
+
146
+ Image Specifications:
147
+ - **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
148
+ - **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
149
+ - **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
150
+ - **Camera and Lighting**:
151
+ - Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
152
+ - **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
153
+ - **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
154
+ - **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
155
+ - **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
156
+ - Negative Prompt: Avoid grainy, blurry, or deformed outputs.
157
+ - **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
158
+ """
159
+
160
+ refined_prompt = llm.invoke(prompt).content.strip()
161
+ collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
162
+ collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})
163
+
164
+ headers = {
165
+ "accept": "application/json",
166
+ "x-key": os.getenv('BFL_API_KEY'),
167
+ "Content-Type": "application/json"
168
+ }
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ payload = {
171
+ "prompt": refined_prompt,
172
+ "width": 1024,
173
+ "height": 1024,
174
+ "guidance_scale": 1,
175
+ "num_inference_steps": 50,
176
+ "max_sequence_length": 512,
177
+ }
178
 
179
+ with requests.Session() as session:
180
+ response = session.post("https://api.bfl.ml/v1/flux-pro-1.1", headers=headers, json=payload, timeout=10).json()
181
 
182
+ if "id" not in response:
183
+ raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")
184
+
185
+ image_url = poll_for_image_result(response["id"], headers)
 
 
 
 
186
 
187
+ image = fetch_image(image_url)
188
+ filename = f"generated_{uuid.uuid4()}.png"
189
+ filepath = save_image_locally(image, filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  return {
192
  "status": "Image generated successfully",
193
  "file_path": filepath,
194
+ "file_url": f"/images/{filename}",
195
  }
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  @app.post("/upscale-image", response_model=dict)
198
  async def upscale_image(request: UpscaleRequest):
199
  if aura_sr is None:
200
  raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
201
 
202
  try:
203
+ img = fetch_image(request.image_url)
 
 
 
 
 
 
 
204
  upscaled_image = aura_sr.upscale_4x_overlapped(img)
205
 
 
206
  filename = f"upscaled_{uuid.uuid4()}.png"
207
  filepath = save_image_locally(upscaled_image, filename)
208
 
 
 
 
209
  return {
210
  "status": "Upscaling successful",
211
  "file_path": filepath,
212
+ "file_url": f"/images/{filename}",
213
  }
214
  except Exception as e:
215
  raise HTTPException(status_code=500, detail=f"Error during upscaling: {str(e)}")
216
 
 
 
 
 
217
  @app.get("/")
218
  async def root():
219
+ return {"message": "API is up and running!"}