File size: 2,325 Bytes
cf32bd5
 
 
 
 
 
7138209
cf32bd5
 
 
 
 
 
 
 
b2c1876
 
 
 
cf32bd5
 
 
 
 
 
 
 
 
 
 
 
 
b2c1876
 
 
 
cf32bd5
b2c1876
cf32bd5
 
 
 
bbb642a
 
 
 
 
 
 
cf32bd5
 
eb6c3f3
cf32bd5
 
 
 
 
 
 
 
bbb642a
 
b5778d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Gaepago model V1 (CPU Test)

# import package
from transformers import AutoModelForAudioClassification
from transformers import AutoFeatureExtractor
from transformers import pipeline
from datasets import Dataset, Audio
import gradio as gr
import torch

# Hugging Face Hub identifiers for the model and its training dataset.
# NOTE(review): DATASET_NAME is not referenced anywhere in the visible code.
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"

# Load the model config, the TorchScript model and the feature extractor.
# Only the Hub config (used for its ``id2label`` mapping) and the feature
# extractor come from the Hub; the network itself is the locally stored,
# int8-quantized TorchScript export instead of
# ``AutoModelForAudioClassification.from_pretrained(MODEL_NAME)``.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(MODEL_NAME)

# Run on CPU: ``map_location`` places every tensor on the CPU at load time, so
# no separate ``model.to("cpu")`` step is needed.
model = torch.jit.load(
    "./model/gaepago-20-lite/model_quant_int8.pt", map_location="cpu"
)
# Inference only: switch to eval mode in case the module was saved in
# training mode (keeps dropout and similar layers disabled).
model.eval()

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

# Gaepago inference function.
def gaepago_fn(tmp_audio_dir):
    """Classify a recorded dog vocalisation and return its label.

    Args:
        tmp_audio_dir: Path to the audio file produced by the ``gr.Audio``
            component (``type="filepath"``), or ``None`` when nothing has been
            recorded or selected yet.

    Returns:
        The predicted class name looked up in ``config.id2label``, or an empty
        string when no audio was provided.
    """
    # Gradio passes None when the button is clicked before any audio exists;
    # without this guard the dataset decoding below raises.
    if tmp_audio_dir is None:
        return ""
    print(tmp_audio_dir)

    # Decode the file, resampled to 16 kHz by the ``datasets`` Audio feature.
    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column(
        "audio", Audio(sampling_rate=16000)
    )
    clip = audio_dataset[0]["audio"]
    inputs = feature_extractor(
        clip["array"],
        sampling_rate=clip["sampling_rate"],
        return_tensors="pt",
    )

    with torch.no_grad():
        # The TorchScript model's output is indexed by key, not via a
        # ``.logits`` attribute as with a regular transformers model.
        logits = model(**inputs)["logits"]

    # argmax over the flattened logits; with a single clip in the batch this
    # is the class index.
    predicted_class_id = torch.argmax(logits).item()
    return config.id2label[predicted_class_id]

# Main: build and launch the Gradio demo UI.
# Bundled sample clips offered as one-click examples for the audio input.
example_list = [
    "./sample/bark_sample.wav",
    "./sample/growling_sample.wav",
    "./sample/howl_sample.wav",
    "./sample/panting_sample.wav",
    "./sample/whimper_sample.wav",
]

main_api = gr.Blocks()

with main_api:
    gr.Markdown("## 8J Gaepago Demo(with CPU)")
    with gr.Row():
        # NOTE(review): ``source=`` is the Gradio 3.x keyword; Gradio 4 replaced
        # it with ``sources=["microphone"]`` -- update if the pinned version moves.
        # Label: "Press the record button and let us hear what Choco is saying".
        audio = gr.Audio(
            source="microphone",
            type="filepath",
            label="녹음버튼을 눌러 초코가 하는 말을 들려주세요",
        )
        # Label: "What Choco is saying right now is...".
        transcription = gr.Textbox(label="지금 초코가 하는 말은...")
    # Button text: "Translate dog language!".
    b1 = gr.Button("강아지 언어 번역!")

    b1.click(gaepago_fn, inputs=audio, outputs=transcription)
    # No ``fn`` is given, so picking an example only populates the audio input.
    examples = gr.Examples(examples=example_list, inputs=[audio])

main_api.launch()