ginic commited on
Commit
09d4e3b
·
1 Parent(s): 44993c6

Added TextGrid output to model with download button

Browse files
Files changed (1) hide show
  1. app.py +94 -36
app.py CHANGED
@@ -1,10 +1,16 @@
1
  from pathlib import Path
 
2
 
3
  import gradio as gr
4
-
 
 
5
  from transformers import pipeline
6
 
 
7
  DEFAULT_MODEL = "ginic/data_seed_bs64_4_wav2vec2-large-xlsr-53-buckeye-ipa"
 
 
8
 
9
 
10
  VALID_MODELS = [
@@ -23,10 +29,27 @@ VALID_MODELS = [
23
  "ginic/gender_split_70_female_3_wav2vec2-large-xlsr-53-buckeye-ipa",
24
  "ginic/gender_split_70_female_4_wav2vec2-large-xlsr-53-buckeye-ipa",
25
  "ginic/gender_split_70_female_5_wav2vec2-large-xlsr-53-buckeye-ipa",
 
 
 
 
 
 
26
  ]
27
 
28
 
29
- def load_model_and_predict(model_name: str, audio_in: str, model_state: dict):
 
 
 
 
 
 
 
 
 
 
 
30
  if model_state["model_name"] != model_name:
31
  model_state = {
32
  "loaded_model": pipeline(
@@ -35,16 +58,50 @@ def load_model_and_predict(model_name: str, audio_in: str, model_state: dict):
35
  "model_name": model_name,
36
  }
37
 
 
38
  return (
39
- model_state["loaded_model"](audio_in)["text"],
40
  model_state,
41
- gr.DownloadButton("Download TextGrid file", visible=True),
 
 
 
 
42
  )
43
 
44
 
45
- def download_textgrid(audio_in, textgrid_tier_name, prediction):
46
- # TODO
47
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  def launch_demo():
@@ -71,45 +128,46 @@ def launch_demo():
71
 
72
  prediction = gr.Textbox(label="Predicted IPA transcription")
73
 
 
 
 
74
  textgrid_tier = gr.Textbox(
75
  label="TextGrid Tier Name", value="transcription", interactive=True
76
  )
77
 
78
- download_btn = gr.DownloadButton("Download TextGrid file", visible=False)
 
 
79
 
80
- # If user updates model name or audio, run prediction
81
- audio_in.input(
82
- fn=load_model_and_predict,
83
- inputs=[model_name, audio_in, model_state],
84
- outputs=[prediction, model_state, download_btn],
 
 
 
 
 
85
  )
86
- model_name.change(
 
 
 
87
  fn=load_model_and_predict,
88
  inputs=[model_name, audio_in, model_state],
89
- outputs=[prediction, model_state, download_btn],
 
 
 
 
 
 
 
 
90
  )
91
 
92
- # demo = gr.Interface(
93
- # fn=load_model_and_predict,
94
- # inputs=[
95
- # gr.Dropdown(
96
- # VALID_MODELS,
97
- # value=DEFAULT_MODEL,
98
- # label="IPA transcription ASR model",
99
- # info="Select the model to use for prediction.",
100
- # ),
101
- # gr.Audio(type="filepath", show_download_button=True),
102
- # gr.State(
103
- # value=initial_model
104
- # ), # Store the name of the currently loaded model
105
- # ],
106
- # outputs=[gr.Textbox(label="Predicted IPA transcription"), gr.State()],
107
- # allow_flagging="never",
108
- # title="Automatic International Phonetic Alphabet Transcription",
109
- # description="This demo allows you to experiment with producing phonetic transcriptions of uploaded or recorded audio using a selected automatic speech recognition (ASR) model.",
110
- # )
111
-
112
- demo.launch()
113
 
114
 
115
  if __name__ == "__main__":
 
1
  from pathlib import Path
2
+ import tempfile
3
 
4
  import gradio as gr
5
+ import librosa
6
+ import tgt.core
7
+ import tgt.io3
8
  from transformers import pipeline
9
 
10
+ TEXTGRID_DIR = tempfile.mkdtemp()
11
  DEFAULT_MODEL = "ginic/data_seed_bs64_4_wav2vec2-large-xlsr-53-buckeye-ipa"
12
+ TEXTGRID_DOWNLOAD_TEXT = "Download TextGrid file"
13
+ TEXTGRID_NAME_INPUT_LABEL = "TextGrid file name"
14
 
15
 
16
  VALID_MODELS = [
 
29
  "ginic/gender_split_70_female_3_wav2vec2-large-xlsr-53-buckeye-ipa",
30
  "ginic/gender_split_70_female_4_wav2vec2-large-xlsr-53-buckeye-ipa",
31
  "ginic/gender_split_70_female_5_wav2vec2-large-xlsr-53-buckeye-ipa",
32
+ "ginic/vary_individuals_old_only_1_wav2vec2-large-xlsr-53-buckeye-ipa",
33
+ "ginic/vary_individuals_old_only_2_wav2vec2-large-xlsr-53-buckeye-ipa",
34
+ "ginic/vary_individuals_old_only_3_wav2vec2-large-xlsr-53-buckeye-ipa",
35
+ "ginic/vary_individuals_young_only_1_wav2vec2-large-xlsr-53-buckeye-ipa",
36
+ "ginic/vary_individuals_young_only_2_wav2vec2-large-xlsr-53-buckeye-ipa",
37
+ "ginic/vary_individuals_young_only_3_wav2vec2-large-xlsr-53-buckeye-ipa",
38
  ]
39
 
40
 
41
+ def load_model_and_predict(
42
+ model_name: str,
43
+ audio_in: str,
44
+ model_state: dict,
45
+ ):
46
+ if audio_in is None:
47
+ return (
48
+ "",
49
+ model_state,
50
+ gr.Textbox(label=TEXTGRID_NAME_INPUT_LABEL, interactive=False),
51
+ )
52
+
53
  if model_state["model_name"] != model_name:
54
  model_state = {
55
  "loaded_model": pipeline(
 
58
  "model_name": model_name,
59
  }
60
 
61
+ prediction = model_state["loaded_model"](audio_in)["text"]
62
  return (
63
+ prediction,
64
  model_state,
65
+ gr.Textbox(
66
+ label=TEXTGRID_NAME_INPUT_LABEL,
67
+ interactive=True,
68
+ value=Path(audio_in).with_suffix(".TextGrid").name,
69
+ ),
70
  )
71
 
72
 
73
+ def get_textgrid_contents(audio_in, textgrid_tier_name, transcription_prediction):
74
+ if audio_in is None or transcription_prediction is None:
75
+ return ""
76
+
77
+ duration = librosa.get_duration(path=audio_in)
78
+
79
+ annotation = tgt.core.Interval(0, duration, transcription_prediction)
80
+ transcription_tier = tgt.core.IntervalTier(
81
+ start_time=0, end_time=duration, name=textgrid_tier_name
82
+ )
83
+ transcription_tier.add_annotation(annotation)
84
+ textgrid = tgt.core.TextGrid()
85
+ textgrid.add_tier(transcription_tier)
86
+ return tgt.io3.export_to_long_textgrid(textgrid)
87
+
88
+
89
+ def write_textgrid(textgrid_contents, textgrid_filename):
90
+ """Writes the text grid contents to a named file in the temporary directory.
91
+ Returns the path for download.
92
+ """
93
+ textgrid_path = Path(TEXTGRID_DIR) / Path(textgrid_filename).name
94
+ textgrid_path.write_text(textgrid_contents)
95
+ return textgrid_path
96
+
97
+
98
+ def get_interactive_download_button(textgrid_contents, textgrid_filename):
99
+ return gr.DownloadButton(
100
+ label=TEXTGRID_DOWNLOAD_TEXT,
101
+ variant="primary",
102
+ interactive=True,
103
+ value=write_textgrid(textgrid_contents, textgrid_filename),
104
+ )
105
 
106
 
107
  def launch_demo():
 
128
 
129
  prediction = gr.Textbox(label="Predicted IPA transcription")
130
 
131
+ gr.Markdown("""## TextGrid File Options
132
+ Change these inputs if you'd like to customize and download the transcription in [TextGrid format](https://www.fon.hum.uva.nl/praat/manual/TextGrid_file_formats.html) for Praat.
133
+ """)
134
  textgrid_tier = gr.Textbox(
135
  label="TextGrid Tier Name", value="transcription", interactive=True
136
  )
137
 
138
+ textgrid_filename = gr.Textbox(
139
+ label=TEXTGRID_NAME_INPUT_LABEL, interactive=False
140
+ )
141
 
142
+ textgrid_contents = gr.Textbox(
143
+ label="TextGrid Contents",
144
+ value=get_textgrid_contents,
145
+ inputs=[audio_in, textgrid_tier, prediction],
146
+ )
147
+
148
+ download_btn = gr.DownloadButton(
149
+ label=TEXTGRID_DOWNLOAD_TEXT,
150
+ interactive=False, # Don't allow download button to be active until an upload happened
151
+ variant="primary",
152
  )
153
+
154
+ # Update prediction if model or audio changes
155
+ gr.on(
156
+ triggers=[audio_in.input, model_name.change],
157
  fn=load_model_and_predict,
158
  inputs=[model_name, audio_in, model_state],
159
+ outputs=[prediction, model_state, textgrid_filename],
160
+ )
161
+
162
+ # Download button becomes interactive if user updates audio or textgrid params
163
+ gr.on(
164
+ triggers=[textgrid_contents.change, textgrid_filename.change],
165
+ fn=get_interactive_download_button,
166
+ inputs=[textgrid_contents, textgrid_filename],
167
+ outputs=[download_btn],
168
  )
169
 
170
+ demo.launch(max_file_size="100mb")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  if __name__ == "__main__":