import io
import os
import uuid
from typing import List

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI
from minio import Minio

import model
from audio_to_text import AudioPipeline
from text_to_img import init_text2img_pipe, predict
from utils import read_from_url

app = FastAPI()


def write_scan_audio_result(audio_id: int, scans: List[int], url: str, callback: str):
    """Generate an image for the audio file and POST the result to the callback URL."""
    score_general_threshold = 0.35      # currently unused
    score_character_threshold = 0.85    # currently unused

    image_url = audio_file_to_image_url(url)
    if image_url is None:
        image_url = ""
    print(image_url)

    callback_req = model.AudioScanCallbackRequest(id=audio_id, isValid=True, image_url=image_url)
    try:
        requests.post(callback, json=callback_req.dict())
    except Exception as ex:
        print(ex)

    # tags = list(map(lambda x: model.AudioScanTag(type="Moderation",
    #                                              confidence=x['confidence']), nsfw_tags))
    return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=None)


def write_scan_model_result(model_name: str, callback: str):
    pass


@app.post("/audio-scan")
async def audio_scan_handler(req: model.AudioScanRequest, background_tasks: BackgroundTasks):
    """Handle an audio-scan request, either in the background or inline when req.wait is set."""
    if not req.wait:
        background_tasks.add_task(write_scan_audio_result, audio_id=req.audioId, scans=req.scans,
                                  url=req.url, callback=req.callbackUrl)
        return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])
    else:
        return write_scan_audio_result(audio_id=req.audioId, scans=req.scans,
                                       url=req.url, callback=req.callbackUrl)


def audio_file_to_image_url(audiofile):
    """Transcribe the audio, run text-to-image, upload the first image to S3, return its URL."""
    file_path = read_from_url(audiofile)
    text = auto_pipeline.audio2txt(file_path)
    negative_prompt = [
        "(watermark:2)", "signature", "username", "(text:2)", "website",
        "(worst quality:2)", "(low quality:2)", "(normal quality:2)",
        "polar lowres", "jpeg", "((monochrome))", "((grayscale))",
        "sketches", "Paintings", "(blurry:2)", "cropped", "lowres", "error", "sketches",
        "(duplicate:1.331)", "(morbid:1.21)", "(mutilated:1.21)",
        "(tranny:1.331)", "(bad proportions:1.331)",
    ]
    images = predict(text, " ".join(negative_prompt), text2img_pipeline)

    # Only the first generated image is uploaded and returned.
    for image in images:
        in_mem_file = io.BytesIO()
        image.save(in_mem_file, format='png', pnginfo=None)
        in_mem_file.seek(0)
        object_name = str(uuid.uuid4())
        s3_client.put_object(
            bucket_name=os.environ.get("S3_BUCKET"),
            object_name=object_name,
            data=in_mem_file,
            length=in_mem_file.getbuffer().nbytes,
            content_type="image/png",
        )
        image_url = f'{os.environ.get("PUB_VIEW_URL")}/{object_name}'
        return image_url


# Initialized in __main__ below; declared here so the handlers can reference them.
auto_pipeline = None
text2img_pipeline = None
s3_client = None

if __name__ == "__main__":
    auto_pipeline = AudioPipeline(
        audio_text_path='/home/user/app/dedup_audio_text_80.json',
        audio_text_embeddings_path='/home/user/app/audio_text_embeddings_cpu.safetensors',
    )
    text2img_pipeline = init_text2img_pipe()
    s3_client = Minio(
        os.environ.get("S3_ENDPOINT"),
        access_key=os.environ.get("S3_ACCESS_KEY"),
        secret_key=os.environ.get("S3_SECRET_KEY"),
    )
    uvicorn.run(app, host="0.0.0.0", port=7860)
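
# Example (sketch, kept as comments so the module stays importable): calling the
# /audio-scan endpoint once the service is running. Field names mirror the attributes
# read from model.AudioScanRequest above (audioId, scans, url, wait, callbackUrl);
# the concrete values and URLs below are hypothetical, and the authoritative schema
# lives in model.py.
#
#   import requests
#
#   payload = {
#       "audioId": 123,                                   # hypothetical audio id
#       "scans": [],                                      # scan types, if any
#       "url": "https://example.com/sample.wav",          # hypothetical audio file URL
#       "wait": True,                                     # block until the image URL is ready
#       "callbackUrl": "https://example.com/callback",    # hypothetical callback receiver
#   }
#   resp = requests.post("http://localhost:7860/audio-scan", json=payload)
#   print(resp.json())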