File size: 3,593 Bytes
171f55b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0c5f90
171f55b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import io
import os
from typing import Union, List

import requests
import uvicorn
from fastapi import BackgroundTasks, FastAPI

import model
from audio_to_text import AudioPipeline
from text_to_img import init_text2img_pipe, predict
from utils import read_from_url
from minio import Minio
import uuid

app = FastAPI()


def write_scan_audio_result(audio_id: int, scans: List[int], url: str, callback: str,
                            callback_timeout: float = 30.0):
    """Generate an image for an audio file and report it to the callback URL.

    The audio at *url* is converted into an image by ``audio_file_to_image_url``.
    The resulting public image URL (or ``""`` when no image was produced) is
    POSTed to *callback* as an ``AudioScanCallbackRequest``. Failures while
    delivering the callback are printed and otherwise ignored (best effort), so
    the scan response is still returned.

    Args:
        audio_id: Identifier echoed back in the callback payload.
        scans: Requested scan identifiers. Currently unused.
        url: Location of the audio file to process.
        callback: URL that receives the callback POST.
        callback_timeout: Seconds to wait for the callback endpoint before giving up.

    Returns:
        An ``AudioScanResponse`` with ``ok=True`` and ``tags=None``.
    """
    image_url = audio_file_to_image_url(url)
    if image_url is None:
        image_url = ""

    print(image_url)

    callback_request = model.AudioScanCallbackRequest(id=audio_id, isValid=True, image_url=image_url)

    try:
        # Without a timeout, an unresponsive callback endpoint would block this
        # worker (or background task) indefinitely.
        response = requests.post(callback, json=callback_request.dict(), timeout=callback_timeout)
        # Report non-2xx replies instead of silently treating them as delivered.
        response.raise_for_status()
    except Exception as ex:
        # Deliberate best-effort delivery: log and continue so the caller still
        # receives a scan response.
        print(ex)

    return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=None)


def write_scan_model_result(model_name: str, callback: str):
    """Placeholder for reporting model-scan results; not implemented yet.

    Accepts a model name and a callback URL, performs no work, and returns ``None``.
    """
    return None


@app.post("/audio-scan")
async def image_scan_handler(req: model.AudioScanRequest, background_tasks: BackgroundTasks):
    """Handle ``POST /audio-scan``.

    When ``req.wait`` is false, the work is scheduled as a background task and
    an immediate ``AudioScanResponse`` with ``ok=True`` and empty tags is
    returned. Otherwise the request waits for ``write_scan_audio_result`` and
    returns its response.
    """
    # Local import keeps this fix self-contained; FastAPI re-exports Starlette's helper.
    from fastapi.concurrency import run_in_threadpool

    if not req.wait:
        background_tasks.add_task(write_scan_audio_result,
                                  audio_id=req.audioId,
                                  scans=req.scans,
                                  url=req.url, callback=req.callbackUrl)
        return model.AudioScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])

    # write_scan_audio_result is synchronous and blocking (model inference,
    # S3 upload, HTTP callback). Calling it directly inside this coroutine would
    # stall the event loop for every other request, so run it in the thread pool.
    return await run_in_threadpool(write_scan_audio_result,
                                   audio_id=req.audioId,
                                   scans=req.scans,
                                   url=req.url, callback=req.callbackUrl)


def audio_file_to_image_url(audiofile):
    """Turn an audio file into a generated image and return its public URL.

    The audio is fetched with ``read_from_url`` and converted to a text prompt by
    the module-level ``auto_pipeline``. That prompt is then rendered by
    ``predict`` with the module-level ``text2img_pipeline``. Only the first
    generated image is uploaded, as PNG, to the bucket named by the ``S3_BUCKET``
    environment variable via ``s3_client``.

    Args:
        audiofile: URL of the audio file to process.

    Returns:
        ``"<PUB_VIEW_URL>/<uuid>"`` for the uploaded image, or ``None`` when the
        pipeline produced no images.
    """
    file_path = read_from_url(audiofile)
    text = auto_pipeline.audio2txt(file_path)
    negative_prompt = [
        "(watermark:2)", "signature", "username", "(text:2)", "website",
        "(worst quality:2)", "(low quality:2)", "(normal quality:2)", "polar lowres", "jpeg",
        "((monochrome))", "((grayscale))", "sketches", "Paintings",
        "(blurry:2)", "cropped", "lowres", "error", "sketches",
        "(duplicate:1.331)", "(morbid:1.21)", "(mutilated:1.21)", "(tranny:1.331)",
        "(bad proportions:1.331)",
    ]
    images = predict(text, " ".join(negative_prompt), text2img_pipeline)

    # Only the first image is ever used; works for lists and other iterables alike.
    image = next(iter(images), None)
    if image is None:
        return None

    in_mem_file = io.BytesIO()
    image.save(in_mem_file, format='png', pnginfo=None)
    in_mem_file.seek(0)

    # Minio validates object_name as a string; passing a bare uuid.UUID object
    # fails that validation, so convert explicitly.
    object_name = str(uuid.uuid4())

    s3_client.put_object(
        bucket_name=os.environ.get("S3_BUCKET"),
        object_name=object_name,
        data=in_mem_file,
        length=in_mem_file.getbuffer().nbytes,
        content_type="image/png"
    )
    return f'{os.environ.get("PUB_VIEW_URL")}/{object_name}'


# Shared resources used by the request handlers. They stay None until main()
# initialises them; a module-level ``global`` statement would be a no-op here.
auto_pipeline = None
text2img_pipeline = None
s3_client = None


def main():
    """Load the audio/text-to-image pipelines, create the S3 client, and start the server.

    The S3 client reads ``S3_ENDPOINT``, ``S3_ACCESS_KEY`` and ``S3_SECRET_KEY``
    from the environment. The server listens on 0.0.0.0:7860.
    """
    # ``global`` is required so the assignments below rebind the module-level
    # names that the handlers read.
    global auto_pipeline, text2img_pipeline, s3_client

    auto_pipeline = AudioPipeline(audio_text_path='/home/user/app/dedup_audio_text_80.json',
                                  audio_text_embeddings_path='/home/user/app/audio_text_embeddings_cpu.safetensors')
    text2img_pipeline = init_text2img_pipe()
    s3_client = Minio(
        os.environ.get("S3_ENDPOINT"),
        access_key=os.environ.get("S3_ACCESS_KEY"),
        secret_key=os.environ.get("S3_SECRET_KEY"),
    )

    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()