nurfarah57 committed
Commit 3b6b693 · verified · 1 Parent(s): 7291a4c

Update app.py

Files changed (1)
  1. app.py +21 -22
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 
-# Set cache dirs BEFORE imports for permission fix
 os.environ["HF_HOME"] = "/tmp"
 os.environ["TRANSFORMERS_CACHE"] = "/tmp"
 os.environ["TORCH_HOME"] = "/tmp"
@@ -19,7 +18,6 @@ from transformers import VitsModel, AutoTokenizer
 
 app = FastAPI()
 
-# Load model and tokenizer ONCE at startup
 model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
 tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
 
@@ -84,15 +82,12 @@ def normalize_text(text: str) -> str:
 
 def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
     np_waveform = waveform.cpu().numpy()
-
     if np_waveform.ndim == 3:
         np_waveform = np_waveform[0]
     if np_waveform.ndim == 2:
         np_waveform = np_waveform.mean(axis=0)
-
     np_waveform = np.clip(np_waveform, -1.0, 1.0).astype(np.float32)
     pcm_waveform = (np_waveform * 32767).astype(np.int16)
-
     buf = io.BytesIO()
     scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
     buf.seek(0)
@@ -102,7 +97,25 @@ class TextIn(BaseModel):
     inputs: str
 
 @app.post("/synthesize")
-async def synthesize(data: TextIn, test: bool = Query(False, description="Set true to return a test tone")):
+async def synthesize_post(data: TextIn):
+    text = normalize_text(data.inputs)
+    inputs = tokenizer(text, return_tensors="pt").to(device)
+    with torch.no_grad():
+        output = model(**inputs)
+    if hasattr(output, "waveform"):
+        waveform = output.waveform
+    elif isinstance(output, dict) and "waveform" in output:
+        waveform = output["waveform"]
+    elif isinstance(output, (tuple, list)):
+        waveform = output[0]
+    else:
+        return {"error": "Waveform not found in model output"}
+    sample_rate = getattr(model.config, "sampling_rate", 22050)
+    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
+    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
+
+@app.get("/synthesize")
+async def synthesize_get(text: str = Query(..., description="Text to synthesize"), test: bool = Query(False)):
     if test:
         duration_s = 2.0
         sample_rate = 22050
@@ -110,22 +123,14 @@ async def synthesize(data: TextIn, test: bool = Query(False, description="Set tr
         freq = 440
         waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
         pcm_waveform = (waveform * 32767).astype(np.int16)
-
         buf = io.BytesIO()
         scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
         buf.seek(0)
-
-        print(f"[TEST MODE] Generated test tone: {pcm_waveform.shape[0]} samples, Sample rate: {sample_rate}")
         return StreamingResponse(buf, media_type="audio/wav")
-
-    text = normalize_text(data.inputs)
-    inputs = tokenizer(text, return_tensors="pt").to(device)
-
+    normalized = normalize_text(text)
+    inputs = tokenizer(normalized, return_tensors="pt").to(device)
     with torch.no_grad():
         output = model(**inputs)
-
-    print("Model output type:", type(output))
-    # Try to extract waveform safely:
     if hasattr(output, "waveform"):
         waveform = output.waveform
     elif isinstance(output, dict) and "waveform" in output:
@@ -134,12 +139,6 @@ async def synthesize(data: TextIn, test: bool = Query(False, description="Set tr
         waveform = output[0]
     else:
        return {"error": "Waveform not found in model output"}
-
-    print("Extracted waveform shape:", waveform.shape)
-
    sample_rate = getattr(model.config, "sampling_rate", 22050)
-    print("Sample rate:", sample_rate)
-
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
-
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
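
For reference, a minimal client-side sketch of how the two endpoints defined in this revision could be exercised. The base URL, output filenames, and the sample Somali sentence are assumptions for illustration; point the URL at wherever the Space or server is actually running.

import requests

BASE_URL = "http://localhost:7860"  # assumed address; adjust for your deployment

# POST /synthesize expects a JSON body matching the TextIn model: {"inputs": "<text>"}
resp = requests.post(f"{BASE_URL}/synthesize", json={"inputs": "Salaan, sidee tahay?"})
resp.raise_for_status()
with open("speech_post.wav", "wb") as f:
    f.write(resp.content)  # WAV bytes streamed back by the endpoint

# GET /synthesize takes the text as a query parameter; test=true returns a 440 Hz test tone instead
resp = requests.get(f"{BASE_URL}/synthesize", params={"text": "Salaan, sidee tahay?"})
resp.raise_for_status()
with open("speech_get.wav", "wb") as f:
    f.write(resp.content)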