sanchit-gandhi committed
Commit 6d71260 · Parent: a741e24

Update app.py

Files changed (1)
  1. app.py +25 -40
app.py CHANGED
@@ -2,6 +2,7 @@ import base64
 import math
 import os
 import time
+from functools import partial
 from multiprocessing import Pool
 
 import gradio as gr
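
Note: the new from functools import partial supports the worker-pool change further down. Pool.map passes exactly one positional argument to its worker function, so the extra batch_id argument to send_chunks has to be bound up front. A minimal sketch of the pattern, with a stand-in worker rather than the demo's real send_chunks:

from functools import partial
from multiprocessing import Pool

def send(batch, batch_id):
    # Stand-in for send_chunks: ship one batch, tagged with its request id.
    return (batch_id, batch)

if __name__ == "__main__":
    with Pool(4) as pool:
        # partial binds batch_id; Pool.map supplies each batch in turn.
        results = pool.map(partial(send, batch_id=42), range(8))
    print(results)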
@@ -24,8 +25,9 @@ To skip the queue, you may wish to create your own inference endpoint, details f
 
 article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
 
-API_URL = os.getenv("API_URL")
-API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")
+API_SEND_URL = os.getenv("API_SEND_URL")
+API_FORWARD_URL = os.getenv("API_FORWARD_URL")
+
 language_names = sorted(TO_LANGUAGE_CODE.keys())
 CHUNK_LENGTH_S = 30
 BATCH_SIZE = 16
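
Note: the old API_URL / API_URL_FROM_FEATURES pair becomes a send/forward split: API_SEND_URL receives the pre-processed feature batches up front, and API_FORWARD_URL triggers generation for one batch index at a time, which is what lets the progress bar advance per batch. Both are read from the environment, so a missing variable only fails later inside requests.post; a hypothetical guard (not part of this commit) would fail fast instead:

import os

API_SEND_URL = os.getenv("API_SEND_URL")
API_FORWARD_URL = os.getenv("API_FORWARD_URL")

# Hypothetical: surface a misconfigured deployment immediately rather than
# letting query() post to None at request time.
if not (API_SEND_URL and API_FORWARD_URL):
    raise RuntimeError("API_SEND_URL and API_FORWARD_URL must both be set")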
@@ -33,48 +35,32 @@ NUM_PROC = 16
 FILE_LIMIT_MB = 1000
 
 
-def query(payload):
-    response = requests.post(API_URL, json=payload)
+def query(url, payload):
+    response = requests.post(url, json=payload)
     return response.json(), response.status_code
 
 
-def inference(inputs, task=None, return_timestamps=False):
-    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
+def inference(batch_id, idx, task=None, return_timestamps=False):
+    payload = {"batch_id": batch_id, "idx": idx, "task": task, "return_timestamps": return_timestamps}
 
-    data, status_code = query(payload)
+    data, status_code = query(API_FORWARD_URL, payload)
 
     if status_code == 200:
-        text = data["text"]
+        tokens = {"tokens": np.asarray(data["tokens"])}
+        return tokens
     else:
-        text = data["detail"]
-
-    timestamps = data.get("chunks")
-    if timestamps is not None:
-        timestamps = [
-            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
-            for chunk in timestamps
-        ]
-        text = "\n".join(str(feature) for feature in timestamps)
-    return text
-
-
-def chunked_query(payload):
-    response = requests.post(API_URL_FROM_FEATURES, json=payload)
-    return response.json()
+        gr.Error(data["detail"])
 
 
-def forward(batch, task=None, return_timestamps=False):
+def send_chunks(batch, batch_id):
     feature_shape = batch["input_features"].shape
     batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
-    outputs = chunked_query(
-        {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
-    )
-    outputs["tokens"] = np.asarray(outputs["tokens"])
-    return outputs
+    query(API_SEND_URL, {"batch": batch, "feature_shape": feature_shape, "batch_id": batch_id})
 
 
-def identity(batch):
-    return batch
+def forward(batch_id, idx, task=None, return_timestamps=False):
+    outputs = inference(batch_id, idx, task, return_timestamps)
+    return outputs
 
 
 # Copied from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
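
Note: send_chunks serialises the numpy feature batch as raw bytes in base64 and ships feature_shape alongside it so the receiver can rebuild the array. Only the shape travels with the payload, so the dtype has to be agreed out of band. A self-contained sketch of that round trip (the decode half and the float32 dtype are assumptions about the backend, not code from this commit):

import base64
import numpy as np

features = np.random.randn(16, 80, 3000).astype(np.float32)  # dummy log-mel batch
feature_shape = features.shape

# Encode, as send_chunks does:
encoded = base64.b64encode(features.tobytes()).decode()

# Decode, as the receiving endpoint presumably does:
decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float32).reshape(feature_shape)

assert np.array_equal(features, decoded)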
@@ -108,22 +94,21 @@ if __name__ == "__main__":
 
     def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
         inputs_len = inputs["array"].shape[0]
-        all_chunk_start_idx = np.arange(0, inputs_len, step)
-        num_samples = len(all_chunk_start_idx)
+        all_chunk_start_batch_id = np.arange(0, inputs_len, step)
+        num_samples = len(all_chunk_start_batch_id)
         num_batches = math.ceil(num_samples / BATCH_SIZE)
-        dummy_batches = list(
-            range(num_batches)
-        )  # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
+        dummy_batches = list(range(num_batches))
 
         dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
-        progress(0, desc="Pre-processing audio file...")
-        dataloader = pool.map(identity, dataloader)
+        progress(0, desc="Sending audio to TPU...")
+        batch_id = np.random.randint(1000000)  # TODO(SG): swap to an iterator
+        pool.map(partial(send_chunks, batch_id=batch_id), dataloader)
 
         model_outputs = []
         start_time = time.time()
         # iterate over our chunked audio samples
-        for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
-            model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
+        for idx in progress.tqdm(dummy_batches, desc="Transcribing..."):
+            model_outputs.append(forward(batch_id, idx, task=task, return_timestamps=return_timestamps))
         runtime = time.time() - start_time
 
         post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
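
Note: batch_id tags one request's chunks so that concurrent users don't collide on the backend, and the author's TODO(SG) flags the weakness of np.random.randint(1000000): two sessions can draw the same id. Note also that pool.map fully consumes the preprocessing generator before the transcription loop starts, so the "Sending audio to TPU..." stage completes before the progress bar begins. One possible shape for the iterator the TODO mentions (a sketch, not the committed fix):

import itertools

# Process-wide counter: ids are unique within this frontend process.
# uuid.uuid4().hex would also avoid collisions across restarted processes.
_batch_ids = itertools.count()

def next_batch_id():
    return next(_batch_ids)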
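
Note: one review point on the error branch in inference: gr.Error is an exception class in Gradio, so calling gr.Error(data["detail"]) without raise constructs the error and discards it, and inference then falls through returning None. A sketch of how the branch would conventionally read:

import gradio as gr
import numpy as np

def handle_response(data, status_code):
    # Corrected error path for inference(): raise, don't just call, gr.Error
    # so Gradio surfaces the message in the UI.
    if status_code == 200:
        return {"tokens": np.asarray(data["tokens"])}
    raise gr.Error(data["detail"])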