fountai commited on
Commit
6fd592e
·
1 Parent(s): 1fe2f2f

adding image and mimic

Browse files
Files changed (3) hide show
  1. app.py +151 -17
  2. modules/r2.py +9 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -11,9 +11,29 @@ from fastapi.openapi.docs import get_swagger_ui_html
11
  import os
12
  import requests
13
  from modules.audio import convert, get_audio_duration
14
- from modules.r2 import upload_to_s3
15
  import threading
16
  import queue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  vpv_webhook = os.environ.get("VPV_WEBHOOK")
19
 
@@ -26,6 +46,34 @@ app.add_middleware(
26
  allow_headers=["*"],
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def download_file(url: str) -> str:
30
  """
31
  Baixa um arquivo da URL fornecida e o salva no diretório 'downloads/'.
@@ -71,8 +119,67 @@ class ProcessRequest(BaseModel):
71
  offset: float = -0.3
72
  format: str = "wav"
73
  speed: float = 0.8
74
- crossfade: float = 0.06
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  @app.post("/process")
77
  def process_audio(payload: ProcessRequest):
78
  key = payload.key
@@ -90,21 +197,7 @@ def process_audio(payload: ProcessRequest):
90
  raise HTTPException(status_code=500, detail=str(e))
91
 
92
  try:
93
- audio = generate_audio(key, text, censor, offset, speed=speed, crossfade=crossfade)
94
- convertedAudioPath = convert(audio, format)
95
- duration = get_audio_duration(convertedAudioPath)
96
- audioUrl = upload_to_s3(convertedAudioPath, f"{id}", format)
97
- os.remove(audio)
98
- os.remove(convertedAudioPath)
99
-
100
- payload = {
101
- "id": id,
102
- "duration": duration,
103
- "receiver": receiver,
104
- "url": audioUrl
105
- }
106
-
107
- requests.post(webhook, json=payload)
108
  return {"success": True, "err": ""}
109
 
110
  except ValueError as e:
@@ -133,6 +226,47 @@ def process_audio(payload: ProcessRequest):
133
  requests.post(dc_callback, headers=headers, data=json.dumps(data))
134
  raise HTTPException(status_code=500, detail=str(e))
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  class TrainRequest(BaseModel):
137
  audio: HttpUrl
138
  key: str
 
11
  import os
12
  import requests
13
  from modules.audio import convert, get_audio_duration
14
+ from modules.r2 import upload_to_s3, upload_image_to_s3
15
  import threading
16
  import queue
17
+ from diffusers import DiffusionPipeline
18
+ import torch
19
+ from datetime import datetime
20
+ import random
21
+ import numpy as np
22
+
23
+ SAVE_DIR = "saved_images"
24
+ if not os.path.exists(SAVE_DIR):
25
+ os.makedirs(SAVE_DIR, exist_ok=True)
26
+
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ repo_id = "black-forest-labs/FLUX.1-dev"
29
+ adapter_id = "guardiancc/lora"
30
+
31
+ pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
32
+ pipeline.load_lora_weights(adapter_id)
33
+ pipeline = pipeline.to(device)
34
+
35
+ MAX_SEED = np.iinfo(np.int32).max
36
+ MAX_IMAGE_SIZE = 1024
37
 
38
  vpv_webhook = os.environ.get("VPV_WEBHOOK")
39
 
 
46
  allow_headers=["*"],
47
  )
48
 
49
+
50
+ def save_generated_image(image):
51
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
52
+ unique_id = str(uuid.uuid4())[:8]
53
+ filename = f"{timestamp}_{unique_id}.png"
54
+ filepath = os.path.join(SAVE_DIR, filename)
55
+ image.save(filepath)
56
+
57
+ return filepath
58
+
59
+ def inference_image(prompt):
60
+ seed = random.randint(0, MAX_SEED)
61
+ generator = torch.Generator(device=device).manual_seed(seed)
62
+ image = pipeline(
63
+ prompt=prompt,
64
+ guidance_scale=3.5,
65
+ num_inference_steps=20,
66
+ width=512,
67
+ height=512,
68
+ generator=generator,
69
+ joint_attention_kwargs={"scale": 0.8},
70
+ ).images[0]
71
+
72
+ filepath = save_generated_image(image, prompt)
73
+ url = upload_image_to_s3(filepath, os.path.basename(filepath), "png")
74
+ os.unlink(filepath)
75
+ return url
76
+
77
  def download_file(url: str) -> str:
78
  """
79
  Baixa um arquivo da URL fornecida e o salva no diretório 'downloads/'.
 
119
  offset: float = -0.3
120
  format: str = "wav"
121
  speed: float = 0.8
122
+ crossfade: float = 0.1
123
+
124
+ class ProcessImage(BaseModel):
125
+ prompt: str
126
+ id: str
127
+ receiver: str
128
+ webhook: str
129
+
130
+ q = queue.Queue()
131
+ image_queue = queue.Queue()
132
 
133
+ def process_queue(q):
134
+ while True:
135
+ try:
136
+ key, censor, offset, text, format, speed, crossfade, id, receiver, webhook = q.get(timeout=5)
137
+ audio = generate_audio(key, text, censor, offset, speed=speed, crossfade=crossfade)
138
+ convertedAudioPath = convert(audio, format)
139
+ duration = get_audio_duration(convertedAudioPath)
140
+ audioUrl = upload_to_s3(convertedAudioPath, f"{id}", format)
141
+ os.remove(audio)
142
+ os.remove(convertedAudioPath)
143
+
144
+ payload = {
145
+ "id": id,
146
+ "duration": duration,
147
+ "receiver": receiver,
148
+ "url": audioUrl
149
+ }
150
+
151
+ requests.post(webhook, json=payload)
152
+ except Exception as e:
153
+ print(e)
154
+ finally:
155
+ q.task_done()
156
+
157
+ def process_image(q):
158
+ while True:
159
+ try:
160
+ prompt, id, receiver, webhook = q.get(timeout=5)
161
+ image = inference_image(prompt)
162
+
163
+ payload = {
164
+ "id": id,
165
+ "receiver": receiver,
166
+ "url": image,
167
+ "type": "image"
168
+ }
169
+
170
+ requests.post(webhook, json=payload)
171
+ except Exception as e:
172
+ print(e)
173
+ finally:
174
+ q.task_done()
175
+
176
+
177
+ worker_thread = threading.Thread(target=process_queue, args=(q,))
178
+ worker_thread.start()
179
+
180
+ imagge_worker = threading.Thread(target=process_queue, args=(q,))
181
+ imagge_worker.start()
182
+
183
  @app.post("/process")
184
  def process_audio(payload: ProcessRequest):
185
  key = payload.key
 
197
  raise HTTPException(status_code=500, detail=str(e))
198
 
199
  try:
200
+ q.put((key, censor, offset, text, format, speed, crossfade, id, receiver, webhook))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  return {"success": True, "err": ""}
202
 
203
  except ValueError as e:
 
226
  requests.post(dc_callback, headers=headers, data=json.dumps(data))
227
  raise HTTPException(status_code=500, detail=str(e))
228
 
229
+ @app.post("/image")
230
+ def process_image(payload: ProcessImage):
231
+ prompt = payload.prompt
232
+ id = payload.id
233
+ receiver = payload.receiver
234
+ webhook = payload.webhook
235
+
236
+ if len(prompt) <= 5:
237
+ raise HTTPException(status_code=500, detail=str(e))
238
+
239
+ try:
240
+ image_queue.put(( prompt, id, receiver, webhook))
241
+ return {"success": True, "err": ""}
242
+
243
+ except ValueError as e:
244
+ raise HTTPException(status_code=400, detail=str(e))
245
+
246
+ except Exception as e:
247
+ error_trace = traceback.format_exc()
248
+ dc_callback = "https://discord.com/api/webhooks/1285586984898662511/QNVvY2rtoKICamlXsC1BreBaYjS9341jz9ANCDBzayXt4C7v-vTFzKfUtKQkwW7BwpfP"
249
+
250
+ data = {
251
+ "content": "",
252
+ "tts": False,
253
+ "embeds": [
254
+ {
255
+ "type": "rich",
256
+ "title": f"Erro aconteceu na IA - MIMIC - 2 ia",
257
+ "description": f"Erro: {str(e)}\n\nDetalhes do erro:\n```{error_trace}```"
258
+ }
259
+ ]
260
+ }
261
+
262
+ headers = {
263
+ "Content-Type": "application/json",
264
+ "Accept": "application/json",
265
+ }
266
+ requests.post(dc_callback, headers=headers, data=json.dumps(data))
267
+ raise HTTPException(status_code=500, detail=str(e))
268
+
269
+
270
  class TrainRequest(BaseModel):
271
  audio: HttpUrl
272
  key: str
modules/r2.py CHANGED
@@ -28,6 +28,15 @@ def upload_to_s3(path, name, extension):
28
  ExpiresIn=604800
29
  )
30
  return url
 
 
 
 
 
 
 
 
 
31
 
32
  def get_url(name):
33
  url = s3.generate_presigned_url(
 
28
  ExpiresIn=604800
29
  )
30
  return url
31
+
32
+ def upload_image_to_s3(path, name, extension):
33
+ s3.upload_file(path, bucket, name, ExtraArgs={'ContentType': f'image/{extension}', 'ACL': 'public-read'})
34
+ url = s3.generate_presigned_url(
35
+ 'get_object',
36
+ Params={'Bucket': bucket, 'Key': name},
37
+ ExpiresIn=604800
38
+ )
39
+ return url
40
 
41
  def get_url(name):
42
  url = s3.generate_presigned_url(
requirements.txt CHANGED
@@ -7,4 +7,5 @@ phonemizer
7
  pydub
8
  fastapi
9
  uvicorn
10
- uuid
 
 
7
  pydub
8
  fastapi
9
  uvicorn
10
+ uuid
11
+ diffusers