Blakus committed on
Commit
9b668a1
·
verified ·
1 Parent(s): 2ddb872

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -53,17 +53,22 @@ def predict(prompt, language, reference_audio):
53
 
54
  sentences = split_text(prompt)
55
 
56
- temperature = config.inference.get("temperature", 0.75)
57
- repetition_penalty = config.inference.get("repetition_penalty", 5.0)
58
- gpt_cond_len = config.inference.get("gpt_cond_len", 30)
59
- gpt_cond_chunk_len = config.inference.get("gpt_cond_chunk_len", 4)
60
- max_ref_length = config.inference.get("max_ref_length", 60)
 
 
 
 
 
61
 
62
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
63
  audio_path=reference_audio,
64
  gpt_cond_len=gpt_cond_len,
65
  gpt_cond_chunk_len=gpt_cond_chunk_len,
66
- max_ref_length=max_ref_length
67
  )
68
 
69
  start_time = time.time()
@@ -76,11 +81,14 @@ def predict(prompt, language, reference_audio):
76
  gpt_cond_latent,
77
  speaker_embedding,
78
  temperature=temperature,
 
79
  repetition_penalty=repetition_penalty,
 
 
80
  )
81
  audio_segment = AudioSegment(
82
  out["wav"].tobytes(),
83
- frame_rate=24000,
84
  sample_width=2,
85
  channels=1
86
  )
 
53
 
54
  sentences = split_text(prompt)
55
 
56
+ # Usar los parámetros del config.json
57
+ temperature = config.model_args.get("temperature", 0.85)
58
+ repetition_penalty = config.model_args.get("repetition_penalty", 2.0)
59
+ length_penalty = config.model_args.get("length_penalty", 1.0)
60
+ top_k = config.model_args.get("top_k", 50)
61
+ top_p = config.model_args.get("top_p", 0.85)
62
+
63
+ gpt_cond_len = config.model_args.get("gpt_cond_len", 12)
64
+ gpt_cond_chunk_len = config.model_args.get("gpt_cond_chunk_len", 4)
65
+ max_ref_len = config.model_args.get("max_ref_len", 10)
66
 
67
  gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
68
  audio_path=reference_audio,
69
  gpt_cond_len=gpt_cond_len,
70
  gpt_cond_chunk_len=gpt_cond_chunk_len,
71
+ max_ref_len=max_ref_len
72
  )
73
 
74
  start_time = time.time()
 
81
  gpt_cond_latent,
82
  speaker_embedding,
83
  temperature=temperature,
84
+ length_penalty=length_penalty,
85
  repetition_penalty=repetition_penalty,
86
+ top_k=top_k,
87
+ top_p=top_p
88
  )
89
  audio_segment = AudioSegment(
90
  out["wav"].tobytes(),
91
+ frame_rate=config.audio["output_sample_rate"],
92
  sample_width=2,
93
  channels=1
94
  )