nurfarah57 commited on
Commit
7291a4c
·
verified ·
1 Parent(s): 1ad4fd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -33
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
 
3
- # Set cache dirs before imports to fix permission errors
4
  os.environ["HF_HOME"] = "/tmp"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
6
  os.environ["TORCH_HOME"] = "/tmp"
@@ -19,7 +19,7 @@ from transformers import VitsModel, AutoTokenizer
19
 
20
  app = FastAPI()
21
 
22
- # Load model/tokenizer once at startup
23
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
24
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
25
 
@@ -82,16 +82,31 @@ def normalize_text(text: str) -> str:
82
  text = text.replace("ZamZam", "SamSam")
83
  return text
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  class TextIn(BaseModel):
86
  inputs: str
87
 
88
  @app.post("/synthesize")
89
- async def synthesize(data: TextIn, test: bool = Query(False, description="Set true to generate test tone instead of TTS")):
90
  if test:
91
- # Generate 2-second 440Hz sine wave for testing playback
92
  duration_s = 2.0
93
  sample_rate = 22050
94
- t = np.linspace(0, duration_s, int(sample_rate*duration_s), endpoint=False)
95
  freq = 440
96
  waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
97
  pcm_waveform = (waveform * 32767).astype(np.int16)
@@ -101,45 +116,30 @@ async def synthesize(data: TextIn, test: bool = Query(False, description="Set tr
101
  buf.seek(0)
102
 
103
  print(f"[TEST MODE] Generated test tone: {pcm_waveform.shape[0]} samples, Sample rate: {sample_rate}")
104
-
105
  return StreamingResponse(buf, media_type="audio/wav")
106
 
107
- # Normalize input text
108
  text = normalize_text(data.inputs)
109
-
110
- # Tokenize and move to device
111
  inputs = tokenizer(text, return_tensors="pt").to(device)
112
 
113
- # Generate waveform
114
  with torch.no_grad():
115
  output = model(**inputs)
116
 
117
- print("Raw waveform shape:", output.waveform.shape)
118
-
119
- waveform = output.waveform.cpu().numpy()
120
-
121
- # Process waveform dimensions
122
- if waveform.ndim == 3:
123
- waveform = waveform[0] # batch dimension
124
- if waveform.ndim == 2:
125
- waveform = waveform.mean(axis=0) # average channels to mono
126
-
127
- print("Processed waveform shape:", waveform.shape)
128
- print("Waveform min/max before clip:", waveform.min(), waveform.max())
129
-
130
- waveform = waveform.astype(np.float32)
131
- waveform = np.clip(waveform, -1.0, 1.0)
132
-
133
- pcm_waveform = (waveform * 32767).astype(np.int16)
134
 
135
- print("PCM waveform shape:", pcm_waveform.shape)
136
- print("PCM waveform min/max:", pcm_waveform.min(), pcm_waveform.max())
137
 
138
- buf = io.BytesIO()
139
  sample_rate = getattr(model.config, "sampling_rate", 22050)
140
  print("Sample rate:", sample_rate)
141
 
142
- scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
143
- buf.seek(0)
144
 
145
- return StreamingResponse(buf, media_type="audio/wav")
 
1
  import os
2
 
3
+ # Set cache dirs BEFORE imports for permission fix
4
  os.environ["HF_HOME"] = "/tmp"
5
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
6
  os.environ["TORCH_HOME"] = "/tmp"
 
19
 
20
  app = FastAPI()
21
 
22
+ # Load model and tokenizer ONCE at startup
23
  model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
24
  tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")
25
 
 
82
  text = text.replace("ZamZam", "SamSam")
83
  return text
84
 
85
+ def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
86
+ np_waveform = waveform.cpu().numpy()
87
+
88
+ if np_waveform.ndim == 3:
89
+ np_waveform = np_waveform[0]
90
+ if np_waveform.ndim == 2:
91
+ np_waveform = np_waveform.mean(axis=0)
92
+
93
+ np_waveform = np.clip(np_waveform, -1.0, 1.0).astype(np.float32)
94
+ pcm_waveform = (np_waveform * 32767).astype(np.int16)
95
+
96
+ buf = io.BytesIO()
97
+ scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
98
+ buf.seek(0)
99
+ return buf.read()
100
+
101
  class TextIn(BaseModel):
102
  inputs: str
103
 
104
  @app.post("/synthesize")
105
+ async def synthesize(data: TextIn, test: bool = Query(False, description="Set true to return a test tone")):
106
  if test:
 
107
  duration_s = 2.0
108
  sample_rate = 22050
109
+ t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
110
  freq = 440
111
  waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
112
  pcm_waveform = (waveform * 32767).astype(np.int16)
 
116
  buf.seek(0)
117
 
118
  print(f"[TEST MODE] Generated test tone: {pcm_waveform.shape[0]} samples, Sample rate: {sample_rate}")
 
119
  return StreamingResponse(buf, media_type="audio/wav")
120
 
 
121
  text = normalize_text(data.inputs)
 
 
122
  inputs = tokenizer(text, return_tensors="pt").to(device)
123
 
 
124
  with torch.no_grad():
125
  output = model(**inputs)
126
 
127
+ print("Model output type:", type(output))
128
+ # Try to extract waveform safely:
129
+ if hasattr(output, "waveform"):
130
+ waveform = output.waveform
131
+ elif isinstance(output, dict) and "waveform" in output:
132
+ waveform = output["waveform"]
133
+ elif isinstance(output, (tuple, list)):
134
+ waveform = output[0]
135
+ else:
136
+ return {"error": "Waveform not found in model output"}
 
 
 
 
 
 
 
137
 
138
+ print("Extracted waveform shape:", waveform.shape)
 
139
 
 
140
  sample_rate = getattr(model.config, "sampling_rate", 22050)
141
  print("Sample rate:", sample_rate)
142
 
143
+ wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
 
144
 
145
+ return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")