File size: 3,120 Bytes
ae466a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import time
import tempfile
from math import floor
from typing import Optional, List, Dict, Any

import torch
import gradio as gr
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read


# configuration
MODEL_NAME = "federerjiang/mambavoice-ja-v1"
BATCH_SIZE = 16
CHUNK_LENGTH_S = 15
FILE_LIMIT_MB = 1000
TOKEN = os.environ.get('HF_TOKEN', None)
# device setting
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16
    device = "cuda:0"
else:
    torch_dtype = torch.float32
    device = "cpu"
# define the pipeline
pipe = pipeline(
    model=MODEL_NAME,
    chunk_length_s=CHUNK_LENGTH_S,
    batch_size=BATCH_SIZE,
    torch_dtype=torch_dtype,
    device=device,
    trust_remote_code=True,
    token=TOKEN,
)


def get_prediction(inputs, prompt: Optional[str]=None):
    generate_kwargs = {
        "language": "japanese", 
        "task": "transcribe",
        "length_penalty": 0,
        "num_beams": 2,
    }
    if prompt:
        generate_kwargs['prompt_ids'] = pipe.tokenizer.get_prompt_ids(prompt, return_tensors='pt').to(device)
    prediction = pipe(inputs, return_timestamps=True, generate_kwargs=generate_kwargs)
    text = "".join([c['text'] for c in prediction['chunks']])
    return text


def transcribe(inputs: str):
    if inputs is None:
        raise gr.Error("音声ファイルが送信されていません!リクエストを送信する前に、音声ファイルをアップロードまたは録音してください。")
    with open(inputs, "rb") as f:
        inputs = f.read()
    inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
    inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
    return get_prediction(inputs)

demo = gr.Blocks()
mf_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources=["microphone"], type="filepath"),
    ],
    outputs=["text"],
    # layout="horizontal",
    theme="huggingface",
    title=f"オーディオをMambaVoice-v1で文字起こしする",
    description=f"ボタンをクリックするだけで、長時間のマイク入力やオーディオ入力を文字起こしできます!デモではMambaVoice-v1モデルを使用しており、任意の長さの音声ファイルを文字起こしすることができます。",
    allow_flagging="never",
)

file_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources=["upload"], type="filepath", label="Audio file"),
    ],
    outputs=["text"],
    # layout="horizontal",
    theme="huggingface",
    title=f"オーディオをMambaVoice-v1で文字起こしする",
    description=f"ボタンをクリックするだけで、長時間のマイク入力やオーディオ入力を文字起こしできます!デモではMambaVoice-v1モデルを使用しており、任意の長さの音声ファイルを文字起こしすることができます。",
    allow_flagging="never",
)

with demo:
    gr.TabbedInterface([mf_transcribe, file_transcribe], ["Microphone", "Audio file"])

demo.queue(max_size=10)
demo.launch()