antonin perrot-audet committed
Commit 9a32f24 · 1 Parent(s): fc9da2a

fix gradio client for huggingface Space

Files changed (1)
  1. app.py +77 -67
app.py CHANGED
@@ -2,24 +2,29 @@
 
 import os
 import gradio as gr
+from gradio_client import Client
+import requests
 
 from dotenv import load_dotenv
 from pydub import AudioSegment
 from tqdm.auto import tqdm
-print('starting')
 
-load_dotenv()
+print("starting")
 
-from gradio_client import Client
+load_dotenv()
 
 HF_API = os.getenv("HF_API")
 SEAMLESS_API_URL = os.getenv("SEAMLESS_API_URL") # path to Seamlessm4t API endpoint
 GPU_AVAILABLE = os.getenv("GPU_AVAILABLE")
 DEFAULT_TARGET_LANGUAGE = "French"
-MISTRAL_SUMMARY_URL= "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
-LLAMA_SUMMARY_URL="https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
+MISTRAL_SUMMARY_URL = (
+    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
+)
+LLAMA_SUMMARY_URL = (
+    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
+)
 
-print('env setup ok')
+print("env setup ok")
 
 
 DESCRIPTION = """
@@ -34,29 +39,31 @@ To duplicate this repo, you have to give permission from three reopsitories and
 
 """
 from pyannote.audio import Pipeline
-#initialize diarization pipeline
+
+# initialize diarization pipeline
 diarizer = Pipeline.from_pretrained(
-    "pyannote/speaker-diarization-3.1",
-    use_auth_token=HF_API)
+    "pyannote/speaker-diarization-3.1", use_auth_token=HF_API
+)
 # send pipeline to GPU (when available)
 import torch
+
 diarizer.to(torch.device(GPU_AVAILABLE))
 
-print('diarizer setup ok')
+print("diarizer setup ok")
 
 
 # predict is a generator that incrementally yields recognized text with speaker label
 def predict(target_language, input_audio):
-    print('->predict started')
+    print("->predict started")
     print(target_language, type(input_audio), input_audio)
 
-    print('-->diarization')
+    print("-->diarization")
     diarized = diarizer(input_audio, min_speakers=2, max_speakers=5)
-
-    print('-->automatic speech recognition')
+
+    print("-->automatic speech recognition")
     # split audio according to diarization
     song = AudioSegment.from_wav(input_audio)
-    client = Client(SEAMLESS_API_URL, hf_token=HF_API)
+    client = Client(SEAMLESS_API_URL, hf_token=HF_API, serialize=False)
     output_text = ""
     for turn, _, speaker in diarized.itertracks(yield_label=True):
         print(speaker, turn)
@@ -64,11 +71,7 @@ def predict(target_language, input_audio):
         clipped = song[turn.start * 1000 : turn.end * 1000]
         clipped.export(f"my.wav", format="wav", bitrate=16000)
 
-        result = client.predict(
-            f"my.wav",
-            target_language,
-            api_name="/asr"
-        )
+        result = client.predict(f"my.wav", target_language, api_name="/asr")
 
         current_text = f"speaker: {speaker} text: {result} "
         print(current_text)
@@ -81,11 +84,25 @@ def predict(target_language, input_audio):
             print(e)
 
 
-import requests
+def automatic_speech_recognition(language, filename):
+    match language:
+        case "French":
+            api_url = "https://api-inference.huggingface.co/models/bofenghuang/whisper-large-v3-french"
+        case "English":
+            api_url = "https://api-inference.huggingface.co/models/facebook/wav2vec2-base-960h"
+        case _:
+            return f"Unknown language {language}"
+    print(f"-> automatic_speech_recognition with {api_url}")
+    with open(filename, "rb") as f:
+        data = f.read()
+    response = requests.post(
+        api_url, headers={"Authorization": f"Bearer {HF_API}"}, data=data
+    )
+    return response.json()["text"]
 
 
 def generate_summary_llama3(language, transcript):
-    queryTxt = f'''
+    queryTxt = f"""
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 You are a helpful and truthful patient-doctor encounter summary writer.
@@ -108,26 +125,27 @@ The summary only includes relevant sections.
 {transcript}
 </transcript><|eot_id|>
 <|start_header_id|>assistant<|end_header_id|>
-'''
-
+"""
+
     payload = {
-        "inputs": queryTxt,
-        "parameters": {
-            "return_full_text": False,
-            "wait_for_model": True,
-            "min_length": 1000
-        },
-        "options": {
-            "use_cache": False
-        }
+        "inputs": queryTxt,
+        "parameters": {
+            "return_full_text": False,
+            "wait_for_model": True,
+            "min_length": 1000,
+        },
+        "options": {"use_cache": False},
     }
 
-    response = requests.post(LLAMA_SUMMARY_URL, headers = {"Authorization": f"Bearer {HF_API}"}, json=payload)
+    response = requests.post(
+        LLAMA_SUMMARY_URL, headers={"Authorization": f"Bearer {HF_API}"}, json=payload
+    )
     print(response.json())
-    return response.json()[0]['generated_text'][len('<summary>'):]
+    return response.json()[0]["generated_text"][len("<summary>") :]
+
 
 def generate_summary_mistral(language, transcript):
-    sysPrompt = f'''<s>[INST]
+    sysPrompt = f"""<s>[INST]
 You are a helpful and truthful patient-doctor encounter summary writer.
 Users sends you transcripts of patient-doctor encounter and you create accurate and concise summaries.
 The summary only contains informations from the transcript.
@@ -143,41 +161,43 @@ The summary only includes relevant sections.
 # Additional Notes
 </template>
 
-'''
-    queryTxt=f'''
+"""
+    queryTxt = f"""
 <transcript>
 {transcript}
 </transcript>
 [/INST]
-'''
-
+"""
+
     payload = {
-        "inputs": sysPrompt + queryTxt,
-        "parameters": {
-            "return_full_text": False,
-            "wait_for_model": True,
-            "min_length": 1000
-        },
-        "options": {
-            "use_cache": False
-        }
+        "inputs": sysPrompt + queryTxt,
+        "parameters": {
+            "return_full_text": False,
+            "wait_for_model": True,
+            "min_length": 1000,
+        },
+        "options": {"use_cache": False},
     }
 
-    response = requests.post(MISTRAL_SUMMARY_URL, headers = {"Authorization": f"Bearer {HF_API}"}, json=payload)
+    response = requests.post(
+        MISTRAL_SUMMARY_URL, headers={"Authorization": f"Bearer {HF_API}"}, json=payload
+    )
     print(response.json())
-    return response.json()[0]['generated_text'][len('<summary>'):]
+    return response.json()[0]["generated_text"][len("<summary>") :]
+
 
 def generate_summary(model, language, transcript):
     match model:
         case "Mistral-7B":
             print("-> summarize with mistral")
-            return generate_summary_mistral( language, transcript)
+            return generate_summary_mistral(language, transcript)
         case "LLAMA3":
             print("-> summarize with llama3")
             return generate_summary_llama3(language, transcript)
         case _:
             return f"Unknown model {model}"
 
+
 def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
     mic = audio_source == "microphone"
     return (
@@ -191,16 +211,14 @@ with gr.Blocks() as demo:
     with gr.Group():
         with gr.Row():
             target_language = gr.Dropdown(
-                choices= ["French", "English"],
+                choices=["French", "English"],
                 label="Output Language",
                 value="French",
                 interactive=True,
                 info="Select your target language",
             )
         with gr.Row() as audio_box:
-            input_audio = gr.Audio(
-                type="filepath"
-            )
+            input_audio = gr.Audio(type="filepath")
             submit = gr.Button("Transcribe")
         transcribe_output = gr.Textbox(
             label="Transcribed Text",
@@ -212,16 +230,13 @@ with gr.Blocks() as demo:
         )
         submit.click(
             fn=predict,
-            inputs=[
-                target_language,
-                input_audio
-            ],
+            inputs=[target_language, input_audio],
             outputs=[transcribe_output],
             api_name="predict",
         )
     with gr.Row():
         sumary_model = gr.Dropdown(
-            choices= ["Mistral-7B", "LLAMA3"],
+            choices=["Mistral-7B", "LLAMA3"],
             label="Summary model",
             value="Mistral-7B",
             interactive=True,
@@ -238,15 +253,10 @@ with gr.Blocks() as demo:
         )
         summarize.click(
             fn=generate_summary,
-            inputs=[
-                sumary_model,
-                target_language,
-                transcribe_output
-            ],
+            inputs=[sumary_model, target_language, transcribe_output],
             outputs=[summary_output],
             api_name="predict",
         )
     gr.Markdown(DUPLICATE)
 
 demo.queue(max_size=50).launch()
-
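
The core of this fix is how the Space's ASR endpoint is called through gradio_client: the Client is now constructed with serialize=False and the prediction is issued in a single predict() call. A minimal standalone sketch of that usage, assuming a Seamlessm4t Space exposing an /asr route that takes an audio filepath and a target-language name (SEAMLESS_API_URL and HF_API read from the environment, as in app.py):

import os
from gradio_client import Client

# Connect to the remote Gradio Space; serialize=False (the change this commit
# applies) passes arguments through as-is instead of client-side serialization.
client = Client(
    os.getenv("SEAMLESS_API_URL"),
    hf_token=os.getenv("HF_API"),
    serialize=False,
)

# Call the named /asr endpoint with (audio filepath, target language).
text = client.predict("my.wav", "French", api_name="/asr")
print(text)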