Spaces:
Build error
Build error
import io | |
import os | |
from typing import Union, List | |
import requests | |
import uvicorn | |
from fastapi import BackgroundTasks, FastAPI | |
import model | |
from audio_to_text import AudioPipeline | |
from text_to_img import init_text2img_pipe, predict | |
from utils import read_from_url | |
from minio import Minio | |
import uuid | |
# FastAPI application instance; passed to uvicorn.run in the __main__ section.
app = FastAPI()
def write_scan_audio_result(audio_id: int, scans: List[int], url: str, callback: str,
                            callback_timeout: float = 30.0):
    """Generate an image for an audio file, notify the callback URL, and build the response.

    The audio at ``url`` is turned into an uploaded image via
    ``audio_file_to_image_url``; an ``AudioScanCallbackRequest`` carrying the
    resulting URL is then POSTed to ``callback`` on a best-effort basis.

    Args:
        audio_id: Identifier echoed back in the callback payload.
        scans: Requested scan kinds. Not used by this function at present.
        url: Location of the audio file to process.
        callback: URL that receives the callback POST.
        callback_timeout: Seconds to wait on the callback endpoint before giving
            up, so an unresponsive endpoint cannot block the worker indefinitely.

    Returns:
        An ``AudioScanResponse`` with ``ok=True`` and an empty tag list.
    """
    image_url = audio_file_to_image_url(url)
    if image_url is None:
        # No image was produced; the callback model still needs a string.
        image_url = ""
    print(image_url)

    callback_request = model.AudioScanCallbackRequest(id=audio_id, isValid=True, image_url=image_url)
    try:
        requests.post(callback, json=callback_request.dict(), timeout=callback_timeout)
    except Exception as ex:
        # Best-effort notification: a failed callback must not fail the scan itself.
        print(ex)

    # tags=[] matches the response returned by the non-waiting request path.
    return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])
def write_scan_model_result(model_name: str, callback: str):
    """Placeholder for reporting model-scan results; currently a no-op returning None."""
async def image_scan_handler(req: model.AudioScanRequest, background_tasks: BackgroundTasks):
    """Run an audio scan either in the background or inline, depending on ``req.wait``.

    When ``req.wait`` is falsy the work is queued on ``background_tasks`` and an
    immediate ``ok`` response is returned; otherwise the scan runs to completion
    and its ``AudioScanResponse`` is returned.
    """
    # NOTE(review): no ``@app.<method>(...)`` route decorator is visible on this
    # coroutine, so FastAPI will not expose it as an endpoint — confirm the
    # intended path and register it.
    from fastapi.concurrency import run_in_threadpool

    scan_kwargs = dict(
        audio_id=req.audioId,
        scans=req.scans,
        url=req.url,
        callback=req.callbackUrl,
    )
    if not req.wait:
        background_tasks.add_task(write_scan_audio_result, **scan_kwargs)
        return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])

    # write_scan_audio_result is synchronous and slow (model inference, S3 upload,
    # callback HTTP); calling it directly here would block the event loop.
    return await run_in_threadpool(write_scan_audio_result, **scan_kwargs)
def audio_file_to_image_url(audiofile):
    """Generate image(s) from an audio file, upload them to S3, and return a public URL.

    The audio is fetched with ``read_from_url``, transcribed to a text prompt by
    the global ``auto_pipeline``, and rendered by ``predict`` with the global
    ``text2img_pipeline``. Each resulting image is uploaded as PNG through the
    global ``s3_client`` into the bucket named by ``S3_BUCKET``.

    Args:
        audiofile: Location of the audio file, as accepted by ``read_from_url``.

    Returns:
        ``{PUB_VIEW_URL}/{object_name}`` for the last uploaded image, or ``None``
        when ``predict`` produced no images.
    """
    file_path = read_from_url(audiofile)
    text = auto_pipeline.audio2txt(file_path)
    negative_prompt = [
        "(watermark:2)", "signature", "username", "(text:2)", "website",
        "(worst quality:2)", "(low quality:2)", "(normal quality:2)", "polar lowres", "jpeg",
        "((monochrome))", "((grayscale))", "sketches", "Paintings",
        "(blurry:2)", "cropped", "lowres", "error", "sketches",
        "(duplicate:1.331)", "(morbid:1.21)", "(mutilated:1.21)", "(tranny:1.331)",
        "(bad proportions:1.331)",
    ]
    images = predict(text, " ".join(negative_prompt), text2img_pipeline)

    # Loop-invariant configuration, read once.
    bucket_name = os.environ.get("S3_BUCKET")
    public_base_url = os.environ.get("PUB_VIEW_URL")

    # Stays None if there are no images, which the caller explicitly checks for
    # (previously this raised UnboundLocalError).
    image_url = None
    for image in images:
        in_mem_file = io.BytesIO()
        image.save(in_mem_file, format='png', pnginfo=None)
        in_mem_file.seek(0)
        # Minio validates object names with string methods, so a bare UUID
        # object is rejected; convert it to its canonical string form.
        object_name = str(uuid.uuid4())
        s3_client.put_object(
            bucket_name=bucket_name,
            object_name=object_name,
            data=in_mem_file,
            length=in_mem_file.getbuffer().nbytes,
            content_type="image/png",
        )
        image_url = f'{public_base_url}/{object_name}'
    return image_url
if __name__ == "__main__":
    # Assignments at module scope already bind module globals, so no `global`
    # statement is needed. audio_file_to_image_url reads auto_pipeline,
    # text2img_pipeline and s3_client as globals; they are bound ONLY when this
    # file is run as a script. Serving `app` via an external `uvicorn module:app`
    # would leave them undefined.
    auto_pipeline = AudioPipeline(audio_text_path='/home/user/app/dedup_audio_text_80.json',
                                  audio_text_embeddings_path='/home/user/app/audio_text_embeddings_cpu.safetensors')
    text2img_pipeline = init_text2img_pipe()
    s3_client = Minio(
        os.environ.get("S3_ENDPOINT"),
        access_key=os.environ.get("S3_ACCESS_KEY"),
        secret_key=os.environ.get("S3_SECRET_KEY"),
    )
    uvicorn.run(app, host="0.0.0.0", port=7860)