eventhorizon28
commited on
Update handler.py
Browse files- handler.py +12 -3
handler.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
-
from transformers import pipeline
|
|
|
2 |
|
3 |
class EndpointHandler():
|
4 |
def __init__(self, path=""):
|
5 |
# create inference pipeline
|
6 |
-
self.pipeline = pipeline("text-to-speech", path)
|
7 |
|
8 |
def __call__(self, data: Any) -> Any:
|
9 |
inputs = data.pop("inputs", data)
|
@@ -14,5 +15,13 @@ class EndpointHandler():
|
|
14 |
prediction = self.pipeline(inputs, **parameters)
|
15 |
else:
|
16 |
prediction = self.pipeline(inputs)
|
|
|
17 |
# postprocess the prediction
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
+
from typing import Any
|
3 |
|
4 |
class EndpointHandler():
|
5 |
def __init__(self, path=""):
|
6 |
# create inference pipeline
|
7 |
+
self.pipeline = pipeline("text-to-speech", model=path)
|
8 |
|
9 |
def __call__(self, data: Any) -> Any:
|
10 |
inputs = data.pop("inputs", data)
|
|
|
15 |
prediction = self.pipeline(inputs, **parameters)
|
16 |
else:
|
17 |
prediction = self.pipeline(inputs)
|
18 |
+
|
19 |
# postprocess the prediction
|
20 |
+
audio_array = prediction['audio']
|
21 |
+
sampling_rate = prediction['sampling_rate']
|
22 |
+
|
23 |
+
# If you need to return raw audio data
|
24 |
+
return {
|
25 |
+
"audio": audio_array,
|
26 |
+
"sampling_rate": sampling_rate
|
27 |
+
}
|