# audio_img/main.py
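"""FastAPI service that turns an audio clip into a generated image.

The /audio-scan endpoint downloads the audio from the request URL,
transcribes it with AudioPipeline, feeds the transcript to the
text_to_img pipeline, uploads the resulting PNG to S3-compatible storage
via Minio, and posts the public image URL to the caller's callback URL.
"""
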
import io
import os
from typing import List
import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI
import model
from audio_to_text import AudioPipeline
from text_to_img import init_text2img_pipe, predict
from utils import read_from_url
from minio import Minio
import uuid

app = FastAPI()


def write_scan_audio_result(audio_id: int, scans: List[int], url: str, callback: str):
    """Generate an image for the audio file and POST the result to the callback URL."""
    # Thresholds kept from the original scan flow; not used in this path yet.
    score_general_threshold = 0.35
    score_character_threshold = 0.85

    image_url = audio_file_to_image_url(url)
    if image_url is None:
        image_url = ""
    print(image_url)

    callback_req = model.AudioScanCallbackRequest(id=audio_id, isValid=True, image_url=image_url)
    try:
        requests.post(callback, json=callback_req.dict())
    except Exception as ex:
        print(ex)

    # tags = list(map(lambda x: model.AudioScanTag(type="Moderation",
    #                                              confidence=x['confidence']), nsfw_tags))
    ret = model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=None)
    return ret

def write_scan_model_result(model_name: str, callback: str):
    pass  # model-scan callbacks are not implemented yet


@app.post("/audio-scan")
async def audio_scan_handler(req: model.AudioScanRequest, background_tasks: BackgroundTasks):
    # If the caller does not want to wait, run the scan in the background
    # and reply immediately with an empty successful response.
    if not req.wait:
        background_tasks.add_task(write_scan_audio_result,
                                  audio_id=req.audioId,
                                  scans=req.scans,
                                  url=req.url, callback=req.callbackUrl)
        return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])
    else:
        ret = write_scan_audio_result(audio_id=req.audioId,
                                      scans=req.scans,
                                      url=req.url, callback=req.callbackUrl)
        return ret
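
# Example call (shape assumed from the fields the handler reads above; the
# authoritative schema lives in model.AudioScanRequest):
#
#   POST /audio-scan
#   {
#       "audioId": 123,
#       "scans": [],
#       "url": "https://example.com/clip.wav",
#       "wait": true,
#       "callbackUrl": "https://example.com/audio-callback"
#   }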


def audio_file_to_image_url(audiofile):
    """Transcribe the audio, render an image from the text, upload it, and return its public URL."""
    file_path = read_from_url(audiofile)
    text = auto_pipeline.audio2txt(file_path)

    # Negative prompt used to steer the text-to-image model away from watermarks,
    # text overlays, and low-quality or distorted outputs.
    negative_prompt = [
        "(watermark:2)", "signature", "username", "(text:2)", "website",
        "(worst quality:2)", "(low quality:2)", "(normal quality:2)", "polar lowres", "jpeg",
        "((monochrome))", "((grayscale))", "sketches", "Paintings",
        "(blurry:2)", "cropped", "lowres", "error", "sketches",
        "(duplicate:1.331)", "(morbid:1.21)", "(mutilated:1.21)", "(tranny:1.331)",
        "(bad proportions:1.331)",
    ]
    images = predict(text, " ".join(negative_prompt), text2img_pipeline)

    image_url = None
    for image in images:
        # Serialize each generated PIL image to an in-memory PNG and upload it.
        in_mem_file = io.BytesIO()
        image.save(in_mem_file, format="PNG", pnginfo=None)
        in_mem_file.seek(0)

        object_name = f"{uuid.uuid4()}.png"
        s3_client.put_object(
            bucket_name=os.environ.get("S3_BUCKET"),
            object_name=object_name,
            data=in_mem_file,
            length=in_mem_file.getbuffer().nbytes,
            content_type="image/png",
        )
        image_url = f'{os.environ.get("PUB_VIEW_URL")}/{object_name}'

    # Returns the URL of the last uploaded image, or None if no images were generated.
    return image_url
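

# Configuration comes from the environment (names taken from the code below):
#   S3_ENDPOINT, S3_ACCESS_KEY, S3_SECRET_KEY - Minio / S3-compatible storage connection
#   S3_BUCKET                                 - bucket that receives the generated PNGs
#   PUB_VIEW_URL                              - public base URL used to build image links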

# Heavy resources are created in the __main__ block and shared as module-level
# globals with the request handlers above.
auto_pipeline = None
text2img_pipeline = None
s3_client = None

if __name__ == "__main__":
    auto_pipeline = AudioPipeline(
        audio_text_path='/home/user/app/dedup_audio_text_80.json',
        audio_text_embeddings_path='/home/user/app/audio_text_embeddings_cpu.safetensors')
    text2img_pipeline = init_text2img_pipe()
    s3_client = Minio(
        os.environ.get("S3_ENDPOINT"),
        access_key=os.environ.get("S3_ACCESS_KEY"),
        secret_key=os.environ.get("S3_SECRET_KEY"),
    )
    uvicorn.run(app, host="0.0.0.0", port=7860)