fountai commited on
Commit
8b66847
1 Parent(s): 6c279e4
Files changed (1) hide show
  1. app.py +2 -122
app.py CHANGED
@@ -11,30 +11,9 @@ from fastapi.openapi.docs import get_swagger_ui_html
11
  import os
12
  import requests
13
  from modules.audio import convert, get_audio_duration
14
- from modules.r2 import upload_to_s3, upload_image_to_s3
15
  import threading
16
  import queue
17
- from diffusers import DiffusionPipeline
18
- import torch
19
- from datetime import datetime
20
- import random
21
- import numpy as np
22
-
23
- SAVE_DIR = "saved_images"
24
- if not os.path.exists(SAVE_DIR):
25
- os.makedirs(SAVE_DIR, exist_ok=True)
26
-
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
- repo_id = "black-forest-labs/FLUX.1-dev"
29
- adapter_id = "guardiancc/lora"
30
-
31
- pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
32
- pipeline.load_lora_weights(adapter_id)
33
- pipeline.enable_sequential_cpu_offload()
34
- pipeline = pipeline.to(device)
35
-
36
- MAX_SEED = np.iinfo(np.int32).max
37
- MAX_IMAGE_SIZE = 1024
38
 
39
  vpv_webhook = os.environ.get("VPV_WEBHOOK")
40
 
@@ -47,34 +26,6 @@ app.add_middleware(
47
  allow_headers=["*"],
48
  )
49
 
50
-
51
- def save_generated_image(image):
52
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
53
- unique_id = str(uuid.uuid4())[:8]
54
- filename = f"{timestamp}_{unique_id}.png"
55
- filepath = os.path.join(SAVE_DIR, filename)
56
- image.save(filepath)
57
-
58
- return filepath
59
-
60
- def inference_image(prompt):
61
- seed = random.randint(0, MAX_SEED)
62
- generator = torch.Generator(device=device).manual_seed(seed)
63
- image = pipeline(
64
- prompt=prompt,
65
- guidance_scale=3.5,
66
- num_inference_steps=20,
67
- width=512,
68
- height=512,
69
- generator=generator,
70
- joint_attention_kwargs={"scale": 0.8},
71
- ).images[0]
72
-
73
- filepath = save_generated_image(image, prompt)
74
- url = upload_image_to_s3(filepath, os.path.basename(filepath), "png")
75
- os.unlink(filepath)
76
- return url
77
-
78
  def download_file(url: str) -> str:
79
  """
80
  Baixa um arquivo da URL fornecida e o salva no diret贸rio 'downloads/'.
@@ -109,7 +60,6 @@ async def openapi():
109
  with open("swagger.json") as f:
110
  return json.load(f)
111
 
112
-
113
  class ProcessRequest(BaseModel):
114
  key: str
115
  text: str
@@ -121,15 +71,9 @@ class ProcessRequest(BaseModel):
121
  format: str = "wav"
122
  speed: float = 0.8
123
  crossfade: float = 0.1
124
-
125
- class ProcessImage(BaseModel):
126
- prompt: str
127
- id: str
128
- receiver: str
129
- webhook: str
130
 
131
  q = queue.Queue()
132
- image_queue = queue.Queue()
133
 
134
  def process_queue(q):
135
  while True:
@@ -154,32 +98,9 @@ def process_queue(q):
154
  print(e)
155
  finally:
156
  q.task_done()
157
-
158
- def process_image(q):
159
- while True:
160
- try:
161
- prompt, id, receiver, webhook = q.get(timeout=5)
162
- image = inference_image(prompt)
163
-
164
- payload = {
165
- "id": id,
166
- "receiver": receiver,
167
- "url": image,
168
- "type": "image"
169
- }
170
-
171
- requests.post(webhook, json=payload)
172
- except Exception as e:
173
- print(e)
174
- finally:
175
- q.task_done()
176
-
177
 
178
  worker_thread = threading.Thread(target=process_queue, args=(q,))
179
  worker_thread.start()
180
-
181
- imagge_worker = threading.Thread(target=process_queue, args=(q,))
182
- imagge_worker.start()
183
 
184
  @app.post("/process")
185
  def process_audio(payload: ProcessRequest):
@@ -227,47 +148,6 @@ def process_audio(payload: ProcessRequest):
227
  requests.post(dc_callback, headers=headers, data=json.dumps(data))
228
  raise HTTPException(status_code=500, detail=str(e))
229
 
230
- @app.post("/image")
231
- def process_image(payload: ProcessImage):
232
- prompt = payload.prompt
233
- id = payload.id
234
- receiver = payload.receiver
235
- webhook = payload.webhook
236
-
237
- if len(prompt) <= 5:
238
- raise HTTPException(status_code=500, detail=str(e))
239
-
240
- try:
241
- image_queue.put(( prompt, id, receiver, webhook))
242
- return {"success": True, "err": ""}
243
-
244
- except ValueError as e:
245
- raise HTTPException(status_code=400, detail=str(e))
246
-
247
- except Exception as e:
248
- error_trace = traceback.format_exc()
249
- dc_callback = "https://discord.com/api/webhooks/1285586984898662511/QNVvY2rtoKICamlXsC1BreBaYjS9341jz9ANCDBzayXt4C7v-vTFzKfUtKQkwW7BwpfP"
250
-
251
- data = {
252
- "content": "",
253
- "tts": False,
254
- "embeds": [
255
- {
256
- "type": "rich",
257
- "title": f"Erro aconteceu na IA - MIMIC - 2 ia",
258
- "description": f"Erro: {str(e)}\n\nDetalhes do erro:\n```{error_trace}```"
259
- }
260
- ]
261
- }
262
-
263
- headers = {
264
- "Content-Type": "application/json",
265
- "Accept": "application/json",
266
- }
267
- requests.post(dc_callback, headers=headers, data=json.dumps(data))
268
- raise HTTPException(status_code=500, detail=str(e))
269
-
270
-
271
  class TrainRequest(BaseModel):
272
  audio: HttpUrl
273
  key: str
 
11
  import os
12
  import requests
13
  from modules.audio import convert, get_audio_duration
14
+ from modules.r2 import upload_to_s3
15
  import threading
16
  import queue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  vpv_webhook = os.environ.get("VPV_WEBHOOK")
19
 
 
26
  allow_headers=["*"],
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def download_file(url: str) -> str:
30
  """
31
  Baixa um arquivo da URL fornecida e o salva no diret贸rio 'downloads/'.
 
60
  with open("swagger.json") as f:
61
  return json.load(f)
62
 
 
63
  class ProcessRequest(BaseModel):
64
  key: str
65
  text: str
 
71
  format: str = "wav"
72
  speed: float = 0.8
73
  crossfade: float = 0.1
74
+
 
 
 
 
 
75
 
76
  q = queue.Queue()
 
77
 
78
  def process_queue(q):
79
  while True:
 
98
  print(e)
99
  finally:
100
  q.task_done()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  worker_thread = threading.Thread(target=process_queue, args=(q,))
103
  worker_thread.start()
 
 
 
104
 
105
  @app.post("/process")
106
  def process_audio(payload: ProcessRequest):
 
148
  requests.post(dc_callback, headers=headers, data=json.dumps(data))
149
  raise HTTPException(status_code=500, detail=str(e))
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  class TrainRequest(BaseModel):
152
  audio: HttpUrl
153
  key: str