# Hugging Face Space (status: Paused)
# Gaepago model V1 (CPU Test) | |
# import package | |
from transformers import AutoModelForAudioClassification | |
from transformers import AutoFeatureExtractor | |
from transformers import pipeline | |
from datasets import Dataset, Audio | |
import gradio as gr | |
import torch | |
# Model / dataset identifiers passed to `from_pretrained`
MODEL_NAME = "Gae8J/gaepago-20"
DATASET_NAME = "Gae8J/modeling_v1"  # not used by the code below
# Load the config (used for its id2label mapping) and the feature extractor
# from MODEL_NAME, and the TorchScript export of the model from local disk
# (int8-quantized, per its filename). The original float model
# (AutoModelForAudioClassification.from_pretrained(MODEL_NAME)) is not loaded.
from transformers import AutoConfig
config = AutoConfig.from_pretrained(MODEL_NAME)
# map_location remaps every stored tensor onto the CPU while deserializing,
# so loading also works on CPU-only hosts if the export was saved from a GPU.
model = torch.jit.load(
    "./model/gaepago-20-lite/model_quant_int8.pt",
    map_location="cpu",
)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
# Run the model on the CPU (already there via map_location; kept explicit).
model.to("cpu")
# Inference only: switch to eval mode in case the module was exported while
# in training mode (would otherwise keep e.g. dropout active).
model.eval()
# Gaepago inference function
def gaepago_fn(tmp_audio_dir):
    """Classify a recorded audio clip and return the predicted label.

    Args:
        tmp_audio_dir: Path of the audio file produced by the ``gr.Audio``
            component (configured with ``type="filepath"``). Gradio passes
            ``None`` when the button is clicked before anything is recorded.

    Returns:
        The label name from ``config.id2label`` for the highest-scoring class.

    Raises:
        gr.Error: if no audio file was provided.
    """
    if not tmp_audio_dir:
        # Show a readable message in the UI instead of a decoder traceback.
        raise gr.Error("No audio recorded. Please record a clip first.")
    print(tmp_audio_dir)
    # Let `datasets` decode the file and resample it to 16 kHz.
    audio_dataset = Dataset.from_dict({"audio": [tmp_audio_dir]}).cast_column(
        "audio", Audio(sampling_rate=16000)
    )
    # Every row access re-decodes the audio, so fetch the decoded sample once.
    sample = audio_dataset[0]["audio"]
    inputs = feature_extractor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
        return_tensors="pt",
    )
    with torch.no_grad():
        # The TorchScript model's output is indexed by the "logits" key.
        logits = model(**inputs)["logits"]
    # One clip (batch of 1): argmax over the flattened logits is the class id.
    predicted_class_id = torch.argmax(logits).item()
    return config.id2label[predicted_class_id]
# Main: Gradio UI that feeds a microphone recording to gaepago_fn
main_api = gr.Blocks()
with main_api:
    gr.Markdown("## 8J Gaepago Demo(with CPU)")
    with gr.Row():
        # type="filepath" hands gaepago_fn the path of the recorded file.
        # Label: "Press the record button and let us hear what Choco says"
        # (Choco is presumably the dog's name).
        audio = gr.Audio(
            source="microphone",
            type="filepath",
            label='녹음버튼을 눌러 초코가 하는 말을 들려주세요',
        )
        # Label: "What Choco is saying right now is..."
        transcription = gr.Textbox(label='지금 초코가 하는 말은...')
    # Button text: "Dog language translation!"
    b1 = gr.Button("강아지 언어 번역!")
    b1.click(gaepago_fn, inputs=audio, outputs=transcription)
main_api.launch(share=True)