audio_img / main.py
pengdaqian
add more
f0c5f90
raw
history blame
3.59 kB
import io
import os
from typing import Union, List
import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI
import model
from audio_to_text import AudioPipeline
from text_to_img import init_text2img_pipe, predict
from utils import read_from_url
from minio import Minio
import uuid
# ASGI application object: the @app.post route below registers on it, and the
# script entry point serves it with uvicorn.run(app, ...).
app: FastAPI = FastAPI()
# Seconds to wait for the callback endpoint. Without a timeout, requests.post
# can block the calling (threadpool / background-task) thread indefinitely.
_CALLBACK_TIMEOUT_SECONDS = 30


def write_scan_audio_result(audio_id: int, scans: List[int], url: str, callback: str):
    """Generate an image for an audio file and report it to a callback URL.

    Args:
        audio_id: Identifier echoed back as ``id`` in the callback payload.
        scans: Requested scan types; currently unused, kept for API compatibility.
        url: Location of the audio file, passed to audio_file_to_image_url.
        callback: URL that receives a POST whose JSON body is an
            AudioScanCallbackRequest (``image_url`` is "" when no image was made).

    Returns:
        An AudioScanResponse with ok=True and an empty tag list. Failures while
        POSTing the callback are printed and do not change the response;
        exceptions raised while producing the image propagate to the caller.
    """
    image_url = audio_file_to_image_url(url)
    if image_url is None:
        image_url = ""
    print(image_url)

    callback_req = model.AudioScanCallbackRequest(id=audio_id, isValid=True, image_url=image_url)
    payload = callback_req.dict()
    try:
        resp = requests.post(callback, json=payload, timeout=_CALLBACK_TIMEOUT_SECONDS)
        # Report non-2xx replies from the callback endpoint instead of ignoring them.
        resp.raise_for_status()
    except requests.RequestException as ex:
        # Best effort: a failed callback must not fail the scan itself.
        print(f"audio-scan callback to {callback} failed: {ex}")

    # tags=[] matches the response the handler returns on its non-waiting path.
    return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])
def write_scan_model_result(model_name: str, callback: str):
    """Placeholder for reporting a model-scan result.

    Not implemented yet: both arguments are accepted and ignored, and the
    function returns None without side effects.
    """
    return None
@app.post("/audio-scan")
async def image_scan_handler(req: model.AudioScanRequest, background_tasks: BackgroundTasks):
    """Handle POST /audio-scan (despite the function name, the input is audio).

    When ``req.wait`` is false, write_scan_audio_result is scheduled as a
    background task and an immediate ok response is returned; the generated
    image URL is delivered later to ``req.callbackUrl``.

    When ``req.wait`` is true, write_scan_audio_result runs before responding
    (it still POSTs to ``req.callbackUrl``) and its AudioScanResponse is returned.
    """
    from fastapi.concurrency import run_in_threadpool

    if not req.wait:
        background_tasks.add_task(
            write_scan_audio_result,
            audio_id=req.audioId,
            scans=req.scans,
            url=req.url,
            callback=req.callbackUrl,
        )
        return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])

    # write_scan_audio_result is synchronous and slow (image generation, S3
    # upload, HTTP callback). Calling it directly inside this coroutine would
    # block the event loop and stall every other request, so run it in the
    # threadpool and await the result.
    return await run_in_threadpool(
        write_scan_audio_result,
        audio_id=req.audioId,
        scans=req.scans,
        url=req.url,
        callback=req.callbackUrl,
    )
# Negative prompt sent with every text-to-image request. The terms are kept
# exactly as originally written (including the repeated "sketches") so the
# generated images do not change.
_NEGATIVE_PROMPT = " ".join([
    "(watermark:2)", "signature", "username", "(text:2)", "website",
    "(worst quality:2)", "(low quality:2)", "(normal quality:2)", "polar lowres", "jpeg",
    "((monochrome))", "((grayscale))", "sketches", "Paintings",
    "(blurry:2)", "cropped", "lowres", "error", "sketches",
    "(duplicate:1.331)", "(morbid:1.21)", "(mutilated:1.21)", "(tranny:1.331)",
    "(bad proportions:1.331)",
])


def audio_file_to_image_url(audiofile):
    """Generate an image from an audio file and return its public URL.

    Fetches ``audiofile`` with read_from_url and converts it to text with
    ``auto_pipeline.audio2txt``. It then generates images from that text with
    predict() and ``text2img_pipeline``. The first image is uploaded as PNG to
    the bucket named by $S3_BUCKET via ``s3_client``.

    Args:
        audiofile: Location of the audio, passed to read_from_url.

    Returns:
        "<$PUB_VIEW_URL>/<uuid>" for the uploaded image, or None when predict
        produced no images.

    NOTE(review): relies on the module-level auto_pipeline, text2img_pipeline
    and s3_client having been initialised by the script entry point.
    NOTE(review): if read_from_url writes a local file, nothing here deletes
    it — confirm.
    """
    file_path = read_from_url(audiofile)
    text = auto_pipeline.audio2txt(file_path)
    images = predict(text, _NEGATIVE_PROMPT, text2img_pipeline)

    # Only one URL is reported to the caller, so only the first generated
    # image is uploaded; None tells the caller no image was produced.
    image = next(iter(images), None)
    if image is None:
        return None

    png_buffer = io.BytesIO()
    image.save(png_buffer, format='png', pnginfo=None)
    png_buffer.seek(0)

    # Minio validates object_name as a string; a bare uuid.UUID is rejected
    # with TypeError, so convert it explicitly.
    object_name = str(uuid.uuid4())
    s3_client.put_object(
        bucket_name=os.environ.get("S3_BUCKET"),
        object_name=object_name,
        data=png_buffer,
        length=png_buffer.getbuffer().nbytes,
        content_type="image/png",
    )
    return f'{os.environ.get("PUB_VIEW_URL")}/{object_name}'
# Shared resources used by the request handlers. They are assigned in main();
# if this module is imported without running main() (e.g. by an external ASGI
# server) they remain None.
auto_pipeline = None
text2img_pipeline = None
s3_client = None


def main():
    """Initialise the shared models and S3 client, then start the HTTP server.

    Steps:
    - Load AudioPipeline from the audio-text JSON and embeddings files under
      /home/user/app.
    - Build the text-to-image pipeline.
    - Create a Minio client from $S3_ENDPOINT, $S3_ACCESS_KEY and $S3_SECRET_KEY.
    - Serve ``app`` with uvicorn on 0.0.0.0:7860.
    """
    # Rebind the module-level names read by the request handlers.
    global auto_pipeline, text2img_pipeline, s3_client

    auto_pipeline = AudioPipeline(audio_text_path='/home/user/app/dedup_audio_text_80.json',
                                  audio_text_embeddings_path='/home/user/app/audio_text_embeddings_cpu.safetensors')
    text2img_pipeline = init_text2img_pipe()
    s3_client = Minio(
        os.environ.get("S3_ENDPOINT"),
        access_key=os.environ.get("S3_ACCESS_KEY"),
        secret_key=os.environ.get("S3_SECRET_KEY"),
    )
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()