# dubai / app.py
# fountai's picture
# adding image and mimic
# 6fd592e
# raw
# history blame
# 9.71 kB
import traceback
import uuid
from models.whisper import model
import modules.register as register
from processor import generate_audio
import json
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
import os
import requests
from modules.audio import convert, get_audio_duration
from modules.r2 import upload_to_s3, upload_image_to_s3
import threading
import queue
from diffusers import DiffusionPipeline
import torch
from datetime import datetime
import random
import numpy as np
# Scratch directory where generated images are written before being uploaded.
SAVE_DIR = "saved_images"
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR, exist_ok=True)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Base diffusion model and the LoRA adapter loaded on top of it.
repo_id = "black-forest-labs/FLUX.1-dev"
adapter_id = "guardiancc/lora"
pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
pipeline.load_lora_weights(adapter_id)
pipeline = pipeline.to(device)

# Upper bound for random seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
vpv_webhook = os.getenv("VPV_WEBHOOK")

app = FastAPI(
    title="Minha API",
    description="API de exemplo com FastAPI e Swagger",
    version="1.0.0",
)

# CORS is fully open: any origin, method and header, with credentials allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def save_generated_image(image):
    """Write *image* into SAVE_DIR under a unique PNG name and return the path.

    The file name is the current local time (second resolution) followed by
    the first eight characters of a random UUID. *image* only needs to offer
    a ``save(path)`` method.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    short_id = str(uuid.uuid4())[:8]
    target = os.path.join(SAVE_DIR, f"{stamp}_{short_id}.png")
    image.save(target)
    return target
def inference_image(prompt):
    """Generate one 512x512 image for *prompt*, upload it and return the result.

    A fresh random seed in ``[0, MAX_SEED]`` is drawn for every call. The
    image is written to a temporary PNG via ``save_generated_image``, handed
    to ``upload_image_to_s3`` and the local copy is removed afterwards.

    Returns:
        Whatever ``upload_image_to_s3`` returns (the caller treats it as the
        public URL of the image).
    """
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipeline(
        prompt=prompt,
        guidance_scale=3.5,
        num_inference_steps=20,
        width=512,
        height=512,
        generator=generator,
        joint_attention_kwargs={"scale": 0.8},
    )
    image = result.images[0]

    # save_generated_image takes the image only; passing the prompt as a
    # second positional argument raised TypeError on every call.
    filepath = save_generated_image(image)
    try:
        url = upload_image_to_s3(filepath, os.path.basename(filepath), "png")
    finally:
        # Remove the local copy even when the upload fails, so failed jobs
        # do not accumulate files in SAVE_DIR.
        os.unlink(filepath)
    return url
def download_file(url: str, timeout: float = 60.0) -> str:
    """Download *url* into the local ``downloads/`` directory.

    The file name is taken from the last segment of the URL path (query
    string and fragment are ignored). When the URL has no usable last
    segment (for example it ends in ``/``), a random name is generated
    instead so the file can still be written.

    Args:
        url: Address of the file to fetch.
        timeout: Seconds to wait for the server before giving up; without a
            timeout a stalled server would block the caller indefinitely.

    Returns:
        Path of the saved file, relative to the working directory.

    Raises:
        Exception: If the HTTP request fails or returns an error status.
    """
    from urllib.parse import urlsplit

    os.makedirs("downloads", exist_ok=True)

    file_name = os.path.basename(urlsplit(url).path)
    if file_name in ("", ".", ".."):
        # No real file name in the URL; opening "downloads/" itself would fail.
        file_name = uuid.uuid4().hex
    save_path = os.path.join("downloads", file_name)

    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise Exception(f"Erro ao baixar o arquivo: {e}") from e

    with open(save_path, "wb") as f:
        f.write(response.content)
    return save_path
@app.get("/test", include_in_schema=False)
def test():
    """Health-check route: always answers with the constant body ``{"ok": True}``."""
    status = {"ok": True}
    return status
@app.get("/", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve the Swagger UI page at the site root, pointed at ``/openapi.json``."""
    ui_page = get_swagger_ui_html(
        title="Alert Pix Ai v2",
        openapi_url="/openapi.json",
    )
    return ui_page
@app.get("/openapi.json", include_in_schema=False)
async def openapi():
    """Return the API schema parsed from ``swagger.json`` (path is relative to the working directory)."""
    with open("swagger.json") as schema_file:
        schema = json.load(schema_file)
    return schema
# Request body of POST /process: one audio-generation job.
class ProcessRequest(BaseModel):
    key: str  # first argument handed to generate_audio
    text: str  # text handed to generate_audio; the endpoint refuses 1000+ characters
    id: str  # used as the upload name and echoed in the webhook payload
    receiver: str  # echoed back unchanged in the webhook payload
    webhook: str  # URL that receives a POST with the job result
    censor: bool = False  # forwarded to generate_audio
    offset: float = -0.3  # forwarded to generate_audio
    format: str = "wav"  # output format passed to convert() and upload_to_s3()
    speed: float = 0.8  # forwarded to generate_audio as speed=
    crossfade: float = 0.1  # forwarded to generate_audio as crossfade=
# Request body of POST /image: one image-generation job.
class ProcessImage(BaseModel):
    prompt: str  # text prompt for the image pipeline; the endpoint refuses 5 or fewer characters
    id: str  # echoed in the webhook payload
    receiver: str  # echoed back unchanged in the webhook payload
    webhook: str  # URL that receives a POST with the job result
# Unbounded in-memory job queues: `q` receives the tuples enqueued by
# POST /process, `image_queue` the tuples enqueued by POST /image.
q, image_queue = queue.Queue(), queue.Queue()
def process_queue(q, poll_timeout=5):
    """Worker loop that consumes audio jobs from *q* forever.

    Each job is a 10-tuple
    ``(key, censor, offset, text, format, speed, crossfade, id, receiver, webhook)``.
    For every job the audio is generated, converted to the requested format,
    uploaded, the local files are removed and the result is POSTed to the
    job's webhook as JSON with the keys ``id``, ``duration``, ``receiver``
    and ``url``.

    Args:
        q: ``queue.Queue`` holding the job tuples.
        poll_timeout: Seconds to block waiting for a job before polling again.

    A failing job is printed and skipped; it never stops the loop.
    """
    while True:
        try:
            job = q.get(timeout=poll_timeout)
        except queue.Empty:
            # Nothing was dequeued, so there is nothing to mark as done.
            # Calling task_done() here raises ValueError and used to kill
            # the worker thread after the first idle timeout.
            continue

        try:
            (key, censor, offset, text, audio_format, speed, crossfade,
             job_id, receiver, webhook) = job
            audio = generate_audio(key, text, censor, offset, speed=speed, crossfade=crossfade)
            converted_path = convert(audio, audio_format)
            duration = get_audio_duration(converted_path)
            audio_url = upload_to_s3(converted_path, f"{job_id}", audio_format)
            os.remove(audio)
            os.remove(converted_path)
            payload = {
                "id": job_id,
                "duration": duration,
                "receiver": receiver,
                "url": audio_url,
            }
            # Bounded wait: an unresponsive webhook must not stall the
            # single worker (and with it every queued job) forever.
            requests.post(webhook, json=payload, timeout=30)
        except Exception as e:
            print(e)
        finally:
            # Exactly one task_done() per successfully dequeued job.
            q.task_done()
def process_image(q, poll_timeout=5):
    """Worker loop that consumes image jobs from *q* forever.

    Each job is a 4-tuple ``(prompt, id, receiver, webhook)``. The image is
    produced by ``inference_image`` and the result is POSTed to the job's
    webhook as JSON with the keys ``id``, ``receiver``, ``url`` and
    ``type`` (always ``"image"``).

    Args:
        q: ``queue.Queue`` holding the job tuples.
        poll_timeout: Seconds to block waiting for a job before polling again.

    A failing job is printed and skipped; it never stops the loop.
    """
    while True:
        try:
            job = q.get(timeout=poll_timeout)
        except queue.Empty:
            # Nothing was dequeued, so there is nothing to mark as done.
            # Calling task_done() here raises ValueError and used to kill
            # the worker thread after the first idle timeout.
            continue

        try:
            prompt, job_id, receiver, webhook = job
            image_url = inference_image(prompt)
            payload = {
                "id": job_id,
                "receiver": receiver,
                "url": image_url,
                "type": "image",
            }
            # Bounded wait so an unresponsive webhook cannot stall the worker.
            requests.post(webhook, json=payload, timeout=30)
        except Exception as e:
            print(e)
        finally:
            # Exactly one task_done() per successfully dequeued job.
            q.task_done()
# Background consumer for the audio jobs enqueued by POST /process.
worker_thread = threading.Thread(target=process_queue, args=(q,))
worker_thread.start()

# Background consumer for the image jobs enqueued by POST /image. It has to
# run the image worker on image_queue; starting a second audio worker on `q`
# left image jobs unconsumed.
# NOTE: keep this above the /image endpoint definition, which rebinds the
# module-level name `process_image` to the route handler.
imagge_worker = threading.Thread(target=process_image, args=(image_queue,))
imagge_worker.start()
@app.post("/process")
def process_audio(payload: ProcessRequest):
    """Validate an audio request and enqueue it for the background worker.

    The job is only queued here; generation happens in the audio worker,
    which reports the result to ``payload.webhook``.

    Returns:
        ``{"success": True, "err": ""}`` once the job is queued.

    Raises:
        HTTPException: 400 when the text has 1000 characters or more (or on
            a ValueError while queueing), 500 on any other failure, after a
            best-effort alert to the Discord webhook.
    """
    if len(payload.text) >= 1000:
        # The original raised with detail=str(e) where `e` was undefined,
        # which turned this validation into a NameError.
        raise HTTPException(status_code=400, detail="text must be shorter than 1000 characters")

    # Field order must match the tuple unpacked by the audio worker.
    job = (
        payload.key,
        payload.censor,
        payload.offset,
        payload.text,
        payload.format,
        payload.speed,
        payload.crossfade,
        payload.id,
        payload.receiver,
        payload.webhook,
    )
    try:
        q.put(job)
        return {"success": True, "err": ""}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        error_trace = traceback.format_exc()
        # NOTE(review): this URL embeds a secret webhook token; it should be
        # moved to configuration and rotated rather than live in source.
        dc_callback = "https://discord.com/api/webhooks/1285586984898662511/QNVvY2rtoKICamlXsC1BreBaYjS9341jz9ANCDBzayXt4C7v-vTFzKfUtKQkwW7BwpfP"
        alert = {
            "content": "",
            "tts": False,
            "embeds": [
                {
                    "type": "rich",
                    "title": "Erro aconteceu na IA - MIMIC - processo",
                    "description": f"Erro: {str(e)}\n\nDetalhes do erro:\n```{error_trace}```",
                }
            ],
        }
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        try:
            requests.post(dc_callback, headers=headers, data=json.dumps(alert), timeout=10)
        except Exception:
            # The alert is best-effort: a failure to reach Discord must not
            # replace the error that is reported to the client.
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/image")
def process_image(payload: ProcessImage):
    """Validate an image request and enqueue it on ``image_queue``.

    The job is only queued here; the result is reported to
    ``payload.webhook`` by whichever worker consumes ``image_queue``.

    Returns:
        ``{"success": True, "err": ""}`` once the job is queued.

    Raises:
        HTTPException: 400 when the prompt has 5 characters or fewer (or on
            a ValueError while queueing), 500 on any other failure, after a
            best-effort alert to the Discord webhook.
    """
    if len(payload.prompt) <= 5:
        # The original raised with detail=str(e) where `e` was undefined,
        # which turned this validation into a NameError.
        raise HTTPException(status_code=400, detail="prompt must be longer than 5 characters")

    # Field order: (prompt, id, receiver, webhook).
    job = (payload.prompt, payload.id, payload.receiver, payload.webhook)
    try:
        image_queue.put(job)
        return {"success": True, "err": ""}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        error_trace = traceback.format_exc()
        # NOTE(review): this URL embeds a secret webhook token; it should be
        # moved to configuration and rotated rather than live in source.
        dc_callback = "https://discord.com/api/webhooks/1285586984898662511/QNVvY2rtoKICamlXsC1BreBaYjS9341jz9ANCDBzayXt4C7v-vTFzKfUtKQkwW7BwpfP"
        alert = {
            "content": "",
            "tts": False,
            "embeds": [
                {
                    "type": "rich",
                    "title": "Erro aconteceu na IA - MIMIC - 2 ia",
                    "description": f"Erro: {str(e)}\n\nDetalhes do erro:\n```{error_trace}```",
                }
            ],
        }
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        try:
            requests.post(dc_callback, headers=headers, data=json.dumps(alert), timeout=10)
        except Exception:
            # The alert is best-effort: a failure to reach Discord must not
            # replace the error that is reported to the client.
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
# Request body of POST /train.
class TrainRequest(BaseModel):
    audio: HttpUrl  # URL of the audio file that /train downloads
    key: str  # passed to register.process_audio together with the downloaded file
    endpoint: str  # URL that receives the completion POST
    id: str  # echoed in the completion POST
@app.post("/train")
def create_item(payload: TrainRequest):
    """Download the audio, register it under ``payload.key`` and notify the caller.

    After ``register.process_audio`` succeeds, ``{"success": True, "id": ...}``
    is POSTed to ``payload.endpoint``, with up to three attempts. The
    notification is best-effort: if every attempt fails the request still
    succeeds.

    Returns:
        The value returned by ``register.process_audio``.

    Raises:
        HTTPException: 400 on ValueError, 500 on any other failure, after a
            best-effort alert to the Discord webhook.
    """
    try:
        src = download_file(str(payload.audio))
        data = register.process_audio(src, payload.key)

        # Build the body once under its own name. Rebinding `payload` to this
        # dict made every retry fail on `payload.id` (AttributeError), so only
        # the first attempt was ever sent.
        notification = {"success": True, "id": payload.id}
        for _ in range(3):
            try:
                requests.post(payload.endpoint, json=notification, timeout=30)
                break
            except Exception:
                # Keep going: the notification is best-effort, but leave a
                # trace instead of swallowing the failure silently.
                traceback.print_exc()
        return data
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        error_trace = traceback.format_exc()
        # NOTE(review): this URL embeds a secret webhook token; it should be
        # moved to configuration and rotated rather than live in source.
        dc_callback = "https://discord.com/api/webhooks/1285586984898662511/QNVvY2rtoKICamlXsC1BreBaYjS9341jz9ANCDBzayXt4C7v-vTFzKfUtKQkwW7BwpfP"
        alert = {
            "content": "",
            "tts": False,
            "embeds": [
                {
                    "type": "rich",
                    "title": "Erro aconteceu na IA -MIMIC - treinar",
                    "description": f"Erro: {str(e)}\n\nDetalhes do erro:\n```{error_trace}```",
                }
            ],
        }
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        try:
            requests.post(dc_callback, headers=headers, data=json.dumps(alert), timeout=10)
        except Exception:
            # The alert is best-effort: a failure to reach Discord must not
            # replace the error that is reported to the client.
            traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: serve the `app` object of module `app` on all
    # interfaces, port 7860.
    from uvicorn import run as serve

    serve("app:app", host="0.0.0.0", port=7860)