ASG Models committed
Commit b0bb61b · verified · 1 Parent(s): dce3b4a

Update app.py

Files changed (1)
  1. app.py +93 -2
app.py CHANGED
@@ -27,6 +27,89 @@ model = genai.GenerativeModel(
   # safety_settings = Adjust safety settings
   # See https://ai.google.dev/gemini-api/docs/safety-settings
 )
+import torch
+from typing import Any, Callable, Optional, Tuple, Union, Iterator
+import numpy as np
+import torch.nn as nn  # Import the missing module
+def _inference_forward_stream(
+    self,
+    input_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    speaker_embeddings: Optional[torch.Tensor] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    padding_mask: Optional[torch.Tensor] = None,
+    chunk_size: int = 32,  # Chunk size for streaming output
+) -> Iterator[torch.Tensor]:
+    """Generates speech waveforms in a streaming fashion."""
+    if attention_mask is not None:
+        padding_mask = attention_mask.unsqueeze(-1).float()
+    else:
+        padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()
+
+    # Encode the input ids into hidden states and prior statistics
+    text_encoder_output = self.text_encoder(
+        input_ids=input_ids,
+        padding_mask=padding_mask,
+        attention_mask=attention_mask,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+    hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
+    hidden_states = hidden_states.transpose(1, 2)
+    input_padding_mask = padding_mask.transpose(1, 2)
+
+    prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
+    prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
+
+    # Predict per-token durations (stochastic or deterministic predictor)
+    if self.config.use_stochastic_duration_prediction:
+        log_duration = self.duration_predictor(
+            hidden_states,
+            input_padding_mask,
+            speaker_embeddings,
+            reverse=True,
+            noise_scale=self.noise_scale_duration,
+        )
+    else:
+        log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
+
+    length_scale = 1.0 / self.speaking_rate
+    duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
+    predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
+
+    # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
+    indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
+    output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
+    output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
+
+    # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
+    attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
+    batch_size, _, output_length, input_length = attn_mask.shape
+    cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
+    indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
+    valid_indices = indices.unsqueeze(0) < cum_duration
+    valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
+    padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
+    attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
+
+    # Expand the prior distribution to the predicted output length
+    prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
+    prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
+
+    prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
+    latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
+
+    spectrogram = latents * output_padding_mask
+
+    # Vocode the spectrogram in chunk_size-wide slices and yield each
+    # waveform segment as soon as it is decoded
+    for i in range(0, spectrogram.size(-1), chunk_size):
+        with torch.no_grad():
+            wav = self.decoder(spectrogram[:, :, i : i + chunk_size], speaker_embeddings)
+            yield wav.squeeze().cpu().numpy()
+
 
 def create_chat_session():
     chat_session = model.start_chat(
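
The function above is essentially the inference path of VitsModel.forward from transformers, reworked so the HiFi-GAN decoder vocodes the spectrogram in chunk_size-wide slices and yields each waveform segment as soon as it is ready. A minimal sketch of driving it directly, assuming a transformers VITS checkpoint; the checkpoint name below is illustrative, not from this commit:

import torch
from transformers import VitsModel, AutoTokenizer

# Assumed checkpoint for illustration; app.py defines its own
# tokenizer and model_vits elsewhere.
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara")
model_vits = VitsModel.from_pretrained("facebook/mms-tts-ara")

inputs = tokenizer("مرحبا بالعالم", return_tensors="pt")
with torch.no_grad():
    # Each iteration yields one decoded waveform segment (a 1-D numpy
    # array), so playback can start before the full utterance is vocoded.
    for wav_chunk in _inference_forward_stream(
        model_vits,
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        chunk_size=32,
    ):
        print(wav_chunk.shape)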
 
@@ -72,7 +155,15 @@ def create_chat_session():
     return chat_session
 
 AI = create_chat_session()
+def generate_audio(text, speaker_id=None):
+    inputs = tokenizer(text, return_tensors="pt")
+
+    speaker_embeddings = None
+
+    with torch.no_grad():
+        for chunk in _inference_forward_stream(model_vits, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, speaker_embeddings=speaker_embeddings, chunk_size=256):
+            yield 16000, chunk
 
 
 def get_answer_ai(text):
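
generate_audio wraps the streaming forward pass into a generator of (sample_rate, waveform_chunk) tuples at 16 kHz, which is the shape Gradio's streaming audio output consumes. A hypothetical wiring sketch, not part of this commit; the component names are illustrative:

import gradio as gr

with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Text")
    # streaming=True lets Gradio play each yielded chunk as it arrives
    audio_out = gr.Audio(streaming=True, autoplay=True)
    text_in.submit(generate_audio, inputs=text_in, outputs=audio_out)

demo.launch()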
 
@@ -117,9 +208,9 @@ def text_to_speech(text):
         pad_text=''
         k+=1
 
-        yield modelspeech(out)
+        yield generate_audio(out)
     if pad_text!='':
-        yield modelspeech(pad_text)
+        yield generate_audio(pad_text)
 def dash(text):
 
     response=get_answer_ai(text)
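
Note that text_to_speech now yields the generator returned by generate_audio(...) rather than a finished waveform, so a consumer such as dash has to iterate one level deeper. A sketch of the flattening a caller would need; the helper name is hypothetical:

def stream_speech(text):
    # Each item from text_to_speech is itself a generator of
    # (sample_rate, waveform_chunk) tuples; flatten one level.
    for segment in text_to_speech(text):
        for sample_rate, wav_chunk in segment:
            yield sample_rate, wav_chunk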