File size: 4,126 Bytes
42a2568
 
9b150b6
42a2568
9b150b6
42a2568
 
 
 
 
 
 
 
 
9b150b6
42a2568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from myrpunct import RestorePuncts
from youtube_transcript_api import YouTubeTranscriptApi
import gradio as gr
import re

def get_srt(input_link):
    """Fetch the transcript of a YouTube video as newline-separated text.

    Args:
        input_link: A YouTube URL carrying a ``v=<video id>`` query parameter,
            e.g. ``https://www.youtube.com/watch?v=6MI0f6YjJIk``.

    Returns:
        The ``text`` of every transcript entry joined with newlines, or an
        error message starting with "Error:" when no video id can be found.
    """
    # Match only the ``v`` parameter itself: stop at the next ``&``/``#`` so
    # trailing parameters (``&t=42s``) are not glued onto the id, and require
    # ``?``/``&`` (or string start) before it so keys such as ``dev=1`` are
    # not mistaken for the video id.
    match = re.search(r"(?:^|[?&])v=([^&#\s]+)", input_link)
    if match is None:
        return "Error: Invalid Link, it does not contain a 'v=<video id>' parameter."
    video_id = match.group(1)
    print("video_id: ", video_id)
    transcript_raw = YouTubeTranscriptApi.get_transcript(video_id)
    return '\n'.join(entry['text'] for entry in transcript_raw)

def predict(input_text, input_file, input_link, input_checkbox):
    """Route the input source selected in the UI to the punctuation pipeline.

    Args:
        input_text: Raw text from the text box (may be empty or None).
        input_file: The uploaded file, either an object exposing a ``.name``
            path (older gradio versions) or a path string / os.PathLike.
        input_link: YouTube URL from the link box (may be empty or None).
        input_checkbox: The selected source: "Text", "File" or "Link".

    Returns:
        The punctuated text, or a message starting with "Error" when the
        selected input is missing or could not be processed.
    """
    import os

    if input_checkbox == "File" and input_file is not None:
        print("Input File ...")
        # Newer gradio versions hand over a path, older ones a tempfile wrapper.
        if isinstance(input_file, (str, os.PathLike)):
            file_path = input_file
        else:
            file_path = input_file.name
        with open(file_path, encoding="utf-8") as file:
            input_file_read = file.read()
        return run_predict(input_file_read)
    if input_checkbox == "Text" and input_text and input_text.strip():
        print("Input Text ...")
        return run_predict(input_text)
    if input_checkbox == "Link" and input_link and input_link.strip():
        print("Input Link ...", input_link)
        input_link_text = get_srt(input_link.strip())
        # get_srt reports failures with an "Error:" prefix; a substring test
        # would also trip on transcripts that merely mention the word "Error".
        if input_link_text.startswith("Error:"):
            return input_link_text
        return run_predict(input_link_text)
    return ("Error: Please provide either an input text, a file or a YouTube "
            "link and select an option accordingly.")

# Shared RestorePuncts instance, created on first use. Constructing it loads
# the punctuation model (presumably the BERT weights credited in the UI), so
# it should not be rebuilt for every request.
_rpunct_model = None

def _get_rpunct():
    """Return the shared RestorePuncts instance, creating it on first use."""
    global _rpunct_model
    if _rpunct_model is None:
        _rpunct_model = RestorePuncts()
    return _rpunct_model

def run_predict(input_text):
    """Restore punctuation and casing while keeping the input's line layout.

    Each line of a transcript corresponds to one frame of the video, so the
    punctuated words are regrouped into lines that hold exactly as many words
    as the matching (non-blank) input line.

    Args:
        input_text: Unpunctuated text, typically one caption per line.

    Returns:
        The punctuated text with the original line breaks, "" for blank
        input, or an error message starting with "Error:" when the model
        output has a different number of words than the input.
    """
    # Words of every non-blank input line; their counts are the break points.
    source_lines = [line.split() for line in input_text.splitlines()]
    source_lines = [words for words in source_lines if words]
    if not source_lines:
        return ""

    output_text = _get_rpunct().punctuate(input_text)
    print("Punctuation finished...")

    # split() without an argument ignores repeated spaces/tabs, which would
    # otherwise yield empty tokens and a spurious length mismatch.
    punctuated_words = output_text.split()
    source_word_count = sum(len(words) for words in source_lines)
    if source_word_count != len(punctuated_words):
        return ("Error: The length of the transcript and the punctuated file "
                f"should be the same: {source_word_count} != {len(punctuated_words)}")

    # Slice the punctuated words back into lines of the original sizes.
    restored_lines = []
    start = 0
    for words in source_lines:
        end = start + len(words)
        restored_lines.append(' '.join(punctuated_words[start:end]))
        start = end
    return '\n'.join(restored_lines)
 
if __name__ == "__main__":

    title = "Rpunct Gradio App"
    description = """
<b>Description</b>: <br>
The model restores punctuation and casing, i.e. the punctuation marks [! ? . , - : ; ' ] and the upper-casing of words. <br>
<b>Usage</b>: <br>
There are three input types: free text, an uploaded file, or a YouTube video link. <br>
Because all three inputs can be filled in at the same time, <br>
select the input type that should be processed.
"""
    article = "Model by [felflare](https://huggingface.co/felflare/bert-restore-punctuation)"

    sample_link = "https://www.youtube.com/watch?v=6MI0f6YjJIk"

    # One example row; values follow the order of ``inputs`` below, i.e. the
    # parameters of predict(input_text, input_file, input_link, input_checkbox).
    examples = [["my name is clara and i live in berkeley california", "sample.srt", sample_link, "Text"]]

    interface = gr.Interface(
        fn=predict,
        # Labels distinguish the two text boxes, which were otherwise identical.
        inputs=[
            gr.Textbox(label="Text"),
            gr.File(label="File"),
            gr.Textbox(label="YouTube link"),
            gr.Radio(["Text", "File", "Link"], type="value", label="Input Type"),
        ],
        outputs=["text"],
        title=title,
        description=description,
        article=article,
        examples=examples,
        allow_flagging="never",
    )

    interface.launch()

# Saving flagged data to a Hugging Face dataset:
# https://github.com/gradio-app/gradio/issues/914
# The best option is to use a Hugging Face dataset as the storage for flagged data. To do that, use the HuggingFaceDatasetSaver() flagging handler, which makes this easy.
# Here is an example Space that uses it: https://huggingface.co/spaces/abidlabs/crowd-speech