# Gaepago model V1 (CPU Test)
# import packages
from transformers import AutoModelForAudioClassification
from transformers import AutoFeatureExtractor
from transformers import AutoConfig
from datasets import Dataset, Audio
import gradio as gr
import torch
from utils.postprocess import text_mapping, text_encoding
import json
# Model & dataset names
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"
TEXT_LABEL = "text_label.json"
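# text_label.json is assumed (from text_mapping's usage below) to map each
# predicted class label to the display text shown in the UI.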
# Load model & feature extractor
# Full-precision alternative:
# model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)
model = torch.jit.load("./model/gaepago-20-lite/model_quant_int8.pt")
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
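# The quantized TorchScript module does not expose .config, so id2label is
# read from the AutoConfig loaded above instead of model.config.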
# Run the model on CPU
model.to("cpu")
# Load the text labels
with open(TEXT_LABEL, "r", encoding="utf-8") as f:
    text_label = json.load(f)
# Gaepago inference function
def gaepago_fn(tmp_audio_dir):
    print(tmp_audio_dir)  # debug: log the incoming audio file path
    # Wrap the recording in a Dataset so it is decoded and resampled to 16 kHz
    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column("audio", Audio(sampling_rate=16000))
    inputs = feature_extractor(audio_dataset[0]["audio"]["array"],
                               sampling_rate=audio_dataset[0]["audio"]["sampling_rate"],
                               return_tensors="pt")
    with torch.no_grad():
        # The TorchScript model returns a dict; the eager model would use .logits
        logits = model(**inputs)["logits"]
    predicted_class_ids = torch.argmax(logits).item()
    predicted_label = config.id2label[predicted_class_ids]
    # Postprocessing
    ## 1. map the predicted label to display text
    output = text_mapping(predicted_label, text_label)
    # output = text_encoding(output)  # optional extra text-encoding step
    return output
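# Example: gaepago_fn("./sample/bark_sample.wav") should return the mapped
# text for the predicted class (the exact string depends on text_label.json).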
# Main
example_list = ["./sample/bark_sample.wav",
                "./sample/growling_sample.wav",
                "./sample/howl_sample.wav",
                "./sample/panting_sample.wav",
                "./sample/whimper_sample.wav"]
main_api = gr.Blocks()
with main_api as demo:
    gr.Markdown("## 8J Gaepago Demo (with CPU)")
    with gr.Row():
        audio = gr.Audio(source="microphone", type="filepath",
                         label="Press the record button and let us hear what your puppy is saying")
        transcription = gr.Textbox(label="What your puppy is saying right now...")
    b1 = gr.Button("Translate puppy language!")
    b1.click(gaepago_fn, inputs=audio, outputs=transcription, api_name="predict")
    examples = gr.Examples(examples=example_list, inputs=[audio])

demo.launch(show_error=True)
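# A minimal usage sketch (assuming the gradio_client package is installed and
# the app is running locally): the endpoint named "predict" above can be
# called programmatically, e.g.
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict("./sample/bark_sample.wav", api_name="/predict"))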