Commit 6d71260 · Parent: a741e24 · Update app.py
app.py
CHANGED
@@ -2,6 +2,7 @@ import base64
 import math
 import os
 import time
+from functools import partial
 from multiprocessing import Pool
 
 import gradio as gr
@@ -24,8 +25,9 @@ To skip the queue, you may wish to create your own inference endpoint, details f
 
 article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
 
-API_URL = os.getenv("API_URL")
-API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")
+API_SEND_URL = os.getenv("API_SEND_URL")
+API_FORWARD_URL = os.getenv("API_FORWARD_URL")
+
 language_names = sorted(TO_LANGUAGE_CODE.keys())
 CHUNK_LENGTH_S = 30
 BATCH_SIZE = 16
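This hunk swaps the single inference URL for two endpoints: one that receives the pre-processed audio chunks and one that runs the forward pass. The real URLs are private to the Space and read from the environment; the snippet below is only a hypothetical local setup to show how the two variables might be pointed at a self-hosted server.

# Hypothetical local configuration — placeholder addresses, not the Space's real endpoints.
import os

os.environ.setdefault("API_SEND_URL", "http://localhost:8000/send")        # receives pre-processed chunks
os.environ.setdefault("API_FORWARD_URL", "http://localhost:8000/forward")  # returns tokens per batch index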
@@ -33,48 +35,32 @@ NUM_PROC = 16
 FILE_LIMIT_MB = 1000
 
 
-def query(payload):
-    response = requests.post(
+def query(url, payload):
+    response = requests.post(url, json=payload)
     return response.json(), response.status_code
 
 
-def inference(
-    payload = {"
+def inference(batch_id, idx, task=None, return_timestamps=False):
+    payload = {"batch_id": batch_id, "idx": idx, "task": task, "return_timestamps": return_timestamps}
 
-    data, status_code = query(payload)
+    data, status_code = query(API_FORWARD_URL, payload)
 
     if status_code == 200:
-
+        tokens = {"tokens": np.asarray(data["tokens"])}
+        return tokens
     else:
-
-
-    timestamps = data.get("chunks")
-    if timestamps is not None:
-        timestamps = [
-            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
-            for chunk in timestamps
-        ]
-        text = "\n".join(str(feature) for feature in timestamps)
-    return text
-
-
-def chunked_query(payload):
-    response = requests.post(API_URL_FROM_FEATURES, json=payload)
-    return response.json()
+        gr.Error(data["detail"])
 
 
-def forward(batch, task=None, return_timestamps=False):
+def send_chunks(batch, batch_id):
     feature_shape = batch["input_features"].shape
     batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
-    outputs = chunked_query(
-        {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
-    )
-    outputs["tokens"] = np.asarray(outputs["tokens"])
-    return outputs
+    query(API_SEND_URL, {"batch": batch, "feature_shape": feature_shape, "batch_id": batch_id})
 
 
-def
-
+def forward(batch_id, idx, task=None, return_timestamps=False):
+    outputs = inference(batch_id, idx, task, return_timestamps)
+    return outputs
 
 
 # Copied from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
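send_chunks serialises the model features by flattening the array to bytes and base64-encoding it, shipping the shape alongside so the receiving end can rebuild the array. The server side is not part of this commit; the sketch below is a minimal guess at the inverse step, assuming the endpoint reconstructs with NumPy and that the dtype matches whatever the client serialised.

import base64
import numpy as np

def decode_features(encoded: str, feature_shape, dtype=np.float32):
    # Hypothetical inverse of the encoding in send_chunks:
    # base64 string -> raw bytes -> ndarray of the original shape.
    # The dtype is an assumption; it must match the client-side array.
    raw = base64.b64decode(encoded)
    return np.frombuffer(raw, dtype=dtype).reshape(feature_shape)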
@@ -108,22 +94,21 @@ if __name__ == "__main__":
 
     def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
         inputs_len = inputs["array"].shape[0]
-
-        num_samples = len(
+        all_chunk_start_batch_id = np.arange(0, inputs_len, step)
+        num_samples = len(all_chunk_start_batch_id)
         num_batches = math.ceil(num_samples / BATCH_SIZE)
-        dummy_batches = list(
-            range(num_batches)
-        )  # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
+        dummy_batches = list(range(num_batches))
 
         dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
-        progress(0, desc="
-
+        progress(0, desc="Sending audio to TPU...")
+        batch_id = np.random.randint(1000000)  # TODO(SG): swap to an iterator
+        pool.map(partial(send_chunks, batch_id=batch_id), dataloader)
 
         model_outputs = []
         start_time = time.time()
         # iterate over our chunked audio samples
-        for
-            model_outputs.append(forward(
+        for idx in progress.tqdm(dummy_batches, desc="Transcribing..."):
+            model_outputs.append(forward(batch_id, idx, task=task, return_timestamps=return_timestamps))
         runtime = time.time() - start_time
 
         post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
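The upload is fanned out with pool.map, binding the shared batch_id via functools.partial so each worker posts one pre-processed chunk. Below is a self-contained illustration of that dispatch pattern with a toy send function standing in for send_chunks; it is not the app's code, just the same partial-plus-Pool idiom.

from functools import partial
from multiprocessing import Pool

def send(chunk, batch_id):
    # Stand-in for send_chunks: each worker handles one chunk under the shared id.
    print(f"sending chunk {chunk} for batch {batch_id}")

if __name__ == "__main__":
    with Pool(4) as pool:
        pool.map(partial(send, batch_id=42), range(8))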
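Taken together, the new client flow is: pre-process the audio into chunks, push them all to API_SEND_URL under a random batch_id, then call API_FORWARD_URL once per batch index and post-process the collected tokens. A hedged sketch of a single round trip, assuming `batch` is one pre-processed chunk dict as yielded by the dataloader above and that both endpoints are reachable:

import numpy as np

# `batch` is assumed to exist (one item from processor.preprocess_batch, not defined here).
batch_id = int(np.random.randint(1_000_000))   # the app keys a whole upload by one random id
send_chunks(batch, batch_id)                   # POST base64-encoded features to API_SEND_URL
result = forward(batch_id, idx=0, task="transcribe", return_timestamps=False)
print(result["tokens"].shape)                  # token ids come back as a NumPy array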