File size: 3,999 Bytes
a616a2e
dd4c06b
4b18df1
bab5632
8607936
0bbcfe0
dd4c06b
7f92c21
 
 
bca5261
7f92c21
 
a4e4751
 
7f92c21
3d7bd2f
a4e4751
d3b8a9b
8607936
52cfee9
7f92c21
8607936
 
 
7f92c21
8607936
 
 
 
 
 
 
 
 
 
 
 
bca5261
7f92c21
 
e648c2d
 
3d7bd2f
 
e648c2d
d3b8a9b
e648c2d
 
5576fae
8e2fde3
5576fae
 
bab5632
d3b8a9b
 
 
 
e648c2d
 
d3b8a9b
7f92c21
 
d3b8a9b
7f92c21
 
 
 
 
 
 
 
 
 
6e40332
0bbcfe0
 
3d7bd2f
7f92c21
 
 
0bbcfe0
6e40332
 
a28c209
be20f8e
7f92c21
e828a9f
7f92c21
 
6e40332
 
 
 
 
8607936
7f92c21
 
 
 
6e40332
a616a2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import spaces
import gradio as gr
import os
import orjson
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, AutoModelForCausalLM, AutoTokenizer

# Module-level model handles shared by the Gradio callbacks in this file.
# Both start out as None; the transcription and proofreading callbacks
# check for None and ask the user to load the models first.
transcribe_model = proofread_model = None

@spaces.GPU(duration=60)
def transcribe_audio(audio):
    """Transcribe an uploaded audio file with the loaded speech model.

    Args:
        audio: Path of the audio file handed over by the Gradio ``Audio``
            component (configured with ``type="filepath"``), or ``None``
            when nothing has been uploaded.

    Returns:
        The transcribed text, or a user-facing hint when no audio was
        supplied or the transcription model has not been loaded yet.
    """
    if audio is None:
        return "Please upload an audio file."
    if transcribe_model is None:
        return "Please load the transcription model first."

    has_cuda = torch.cuda.is_available()
    device = "cuda:0" if has_cuda else "cpu"
    torch_dtype = torch.float16 if has_cuda else torch.float32

    # AutoProcessor.from_pretrained() expects a repo id or local path, not a
    # model instance. The loaded model remembers where it came from in
    # ``name_or_path``, so use that to fetch the matching processor.
    processor = AutoProcessor.from_pretrained(transcribe_model.name_or_path)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=transcribe_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=25,  # long recordings are processed in 25 s chunks
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    result = pipe(audio)
    return result["text"]

@spaces.GPU(duration=120)
def proofread(text):
    """Rewrite and summarise transcribed text with the loaded chat model.

    The text is sent as the user turn of a chat whose system prompt asks for
    a cleaned-up Traditional Chinese version followed by the key points.

    Args:
        text: The transcribed text to proofread, or ``None``.

    Returns:
        The model's reply (prompt tokens excluded), or a user-facing hint
        when no text was supplied or the proofreading model is not loaded.
    """
    # Treat blank input like missing input so no generation is started for it.
    if text is None or not text.strip():
        return "Please provide the transcribed text for proofreading."
    if proofread_model is None:
        return "Please load the proofreading model first."

    messages = [
        {"role": "system", "content": "用繁體中文語體文整理這段文字,在最後加上整段文字的重點。"},
        {"role": "user", "content": text},
    ]

    # A causal LM instance has no ``tokenizer`` attribute; load the tokenizer
    # that belongs to the checkpoint the model was created from.
    tokenizer = AutoTokenizer.from_pretrained(proofread_model.name_or_path)

    # Chat messages must go through the chat template (not the plain
    # tokenizer call), as PyTorch tensors placed on the model's own device.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(proofread_model.device)

    # Explicit budget so the reply is not cut off by a short default length.
    outputs = proofread_model.generate(input_ids, max_new_tokens=1024)

    # generate() returns prompt + completion; keep only the newly generated
    # tokens so the prompt is not echoed back to the user.
    completion_ids = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)

@spaces.GPU(duration=120)
def load_models(transcribe_model_id, proofread_model_id):
    """Load both models and publish them through the module-level globals.

    Args:
        transcribe_model_id: Hub id or path of the speech-to-text checkpoint.
        proofread_model_id: Hub id or path of the causal-LM checkpoint used
            for proofreading.

    Returns:
        None. The loaded models are stored in ``transcribe_model`` and
        ``proofread_model``.
    """
    # NOTE(review): if the @spaces.GPU decorator runs this function in a
    # separate worker process, these global assignments may not be visible
    # to the other callbacks -- confirm against the ZeroGPU documentation.
    global transcribe_model, proofread_model

    has_cuda = torch.cuda.is_available()
    device = "cuda:0" if has_cuda else "cpu"
    torch_dtype = torch.float16 if has_cuda else torch.float32

    transcribe_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        transcribe_model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    transcribe_model.to(device)

    # Load the LLM with the same dtype policy as the speech model; without
    # it the weights are materialised in float32, doubling memory use on GPU.
    proofread_model = AutoModelForCausalLM.from_pretrained(
        proofread_model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    proofread_model.to(device)

# Gradio front end: model pickers and a load button on one row, then the
# audio upload, the transcription controls and the proofreading controls.
with gr.Blocks() as demo:
    gr.Markdown("""
                # Audio Transcription and Proofreading
                1. Select models for transcription and proofreading and load them
                2. Upload an audio file (Wait for the file to be fully loaded first)
                3. Transcribe the audio
                4. Proofread the transcribed text
                """)

    with gr.Row():
        transcribe_model_dropdown = gr.Dropdown(
            choices=[
                "openai/whisper-large-v2",
                "alvanlii/whisper-small-cantonese",
            ],
            value="alvanlii/whisper-small-cantonese",
            label="Select Transcription Model",
        )
        proofread_model_dropdown = gr.Dropdown(
            choices=["hfl/llama-3-chinese-8b-instruct-v3"],
            value="hfl/llama-3-chinese-8b-instruct-v3",
            label="Select Proofreading Model",
        )
        load_button = gr.Button("Load Models")

    # The callbacks receive the uploaded file as a path on disk.
    audio = gr.Audio(sources="upload", type="filepath")

    transcribe_button = gr.Button("Transcribe")
    transcribed_text = gr.Textbox(label="Transcribed Text")

    proofread_button = gr.Button("Proofread")
    proofread_output = gr.Textbox(label="Proofread Text")

    # Event wiring. Loading has no output component; it only sets globals.
    load_button.click(
        load_models,
        inputs=[transcribe_model_dropdown, proofread_model_dropdown],
    )
    transcribe_button.click(
        transcribe_audio,
        inputs=audio,
        outputs=transcribed_text,
    )
    # Proofreading runs on an explicit button press and also whenever the
    # transcribed text changes.
    proofread_button.click(
        proofread,
        inputs=transcribed_text,
        outputs=proofread_output,
    )
    transcribed_text.change(
        proofread,
        inputs=transcribed_text,
        outputs=proofread_output,
    )

demo.launch()