siddhartharya commited on
Commit
08f2510
1 Parent(s): de2153f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +19 -8
utils.py CHANGED
@@ -3,16 +3,26 @@ from pydantic import BaseModel, ValidationError
3
  from typing import List, Literal
4
  import os
5
  import tiktoken
6
- from gtts import gTTS
7
  import tempfile
8
  import json
9
  import re
 
 
 
10
 
11
  groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
12
  tokenizer = tiktoken.get_encoding("cl100k_base")
13
 
 
 
 
 
 
 
 
 
14
  class DialogueItem(BaseModel):
15
- speaker: Literal["Host", "Guest"]
16
  text: str
17
 
18
  class Dialogue(BaseModel):
@@ -37,16 +47,13 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
37
  temperature=0.7
38
  )
39
 
40
- # Extract content and remove any markdown code block syntax
41
  content = response.choices[0].message.content
42
  content = re.sub(r'```json\s*|\s*```', '', content)
43
 
44
  try:
45
- # First, try to parse as JSON
46
  json_data = json.loads(content)
47
  dialogue = Dialogue.model_validate(json_data)
48
  except json.JSONDecodeError as json_error:
49
- # If JSON parsing fails, try to extract JSON from the text
50
  match = re.search(r'\{.*\}', content, re.DOTALL)
51
  if match:
52
  try:
@@ -62,7 +69,11 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
62
  return dialogue
63
 
64
  def generate_audio(text: str, speaker: str) -> str:
65
- tts = gTTS(text, lang='en', tld='com' if speaker == "Host" else 'co.uk')
66
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
67
- tts.save(temp_audio.name)
 
 
 
 
68
  return temp_audio.name
 
3
  from typing import List, Literal
4
  import os
5
  import tiktoken
 
6
  import tempfile
7
  import json
8
  import re
9
+ from transformers import pipeline
10
+ import torch
11
+ import soundfile as sf
12
 
13
  groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
14
  tokenizer = tiktoken.get_encoding("cl100k_base")
15
 
16
+ # Initialize TTS pipelines
17
+ tts_male = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
18
+ tts_female = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
19
+
20
+ # Load speaker embeddings
21
+ male_embedding = torch.load("https://huggingface.co/microsoft/speecht5_tts/resolve/main/en_speaker_1.pt")
22
+ female_embedding = torch.load("https://huggingface.co/microsoft/speecht5_tts/resolve/main/en_speaker_9.pt")
23
+
24
  class DialogueItem(BaseModel):
25
+ speaker: Literal["John", "Sarah"]
26
  text: str
27
 
28
  class Dialogue(BaseModel):
 
47
  temperature=0.7
48
  )
49
 
 
50
  content = response.choices[0].message.content
51
  content = re.sub(r'```json\s*|\s*```', '', content)
52
 
53
  try:
 
54
  json_data = json.loads(content)
55
  dialogue = Dialogue.model_validate(json_data)
56
  except json.JSONDecodeError as json_error:
 
57
  match = re.search(r'\{.*\}', content, re.DOTALL)
58
  if match:
59
  try:
 
69
  return dialogue
70
 
71
  def generate_audio(text: str, speaker: str) -> str:
72
+ if speaker == "John":
73
+ speech = tts_male(text, speaker_embeddings=male_embedding)
74
+ else: # Sarah
75
+ speech = tts_female(text, speaker_embeddings=female_embedding)
76
+
77
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
78
+ sf.write(temp_audio.name, speech["audio"], speech["sampling_rate"])
79
  return temp_audio.name