siddhartharya
commited on
Commit
•
08f2510
1
Parent(s):
de2153f
Update utils.py
Browse files
utils.py
CHANGED
@@ -3,16 +3,26 @@ from pydantic import BaseModel, ValidationError
|
|
3 |
from typing import List, Literal
|
4 |
import os
|
5 |
import tiktoken
|
6 |
-
from gtts import gTTS
|
7 |
import tempfile
|
8 |
import json
|
9 |
import re
|
|
|
|
|
|
|
10 |
|
11 |
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
|
12 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
class DialogueItem(BaseModel):
|
15 |
-
speaker: Literal["
|
16 |
text: str
|
17 |
|
18 |
class Dialogue(BaseModel):
|
@@ -37,16 +47,13 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
|
|
37 |
temperature=0.7
|
38 |
)
|
39 |
|
40 |
-
# Extract content and remove any markdown code block syntax
|
41 |
content = response.choices[0].message.content
|
42 |
content = re.sub(r'```json\s*|\s*```', '', content)
|
43 |
|
44 |
try:
|
45 |
-
# First, try to parse as JSON
|
46 |
json_data = json.loads(content)
|
47 |
dialogue = Dialogue.model_validate(json_data)
|
48 |
except json.JSONDecodeError as json_error:
|
49 |
-
# If JSON parsing fails, try to extract JSON from the text
|
50 |
match = re.search(r'\{.*\}', content, re.DOTALL)
|
51 |
if match:
|
52 |
try:
|
@@ -62,7 +69,11 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
|
|
62 |
return dialogue
|
63 |
|
64 |
def generate_audio(text: str, speaker: str) -> str:
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
68 |
return temp_audio.name
|
|
|
3 |
from typing import List, Literal
|
4 |
import os
|
5 |
import tiktoken
|
|
|
6 |
import tempfile
|
7 |
import json
|
8 |
import re
|
9 |
+
from transformers import pipeline
|
10 |
+
import torch
|
11 |
+
import soundfile as sf
|
12 |
|
13 |
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
|
14 |
tokenizer = tiktoken.get_encoding("cl100k_base")
|
15 |
|
16 |
+
# Initialize TTS pipelines
|
17 |
+
tts_male = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
|
18 |
+
tts_female = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
|
19 |
+
|
20 |
+
# Load speaker embeddings
|
21 |
+
male_embedding = torch.load("https://huggingface.co/microsoft/speecht5_tts/resolve/main/en_speaker_1.pt")
|
22 |
+
female_embedding = torch.load("https://huggingface.co/microsoft/speecht5_tts/resolve/main/en_speaker_9.pt")
|
23 |
+
|
24 |
class DialogueItem(BaseModel):
|
25 |
+
speaker: Literal["John", "Sarah"]
|
26 |
text: str
|
27 |
|
28 |
class Dialogue(BaseModel):
|
|
|
47 |
temperature=0.7
|
48 |
)
|
49 |
|
|
|
50 |
content = response.choices[0].message.content
|
51 |
content = re.sub(r'```json\s*|\s*```', '', content)
|
52 |
|
53 |
try:
|
|
|
54 |
json_data = json.loads(content)
|
55 |
dialogue = Dialogue.model_validate(json_data)
|
56 |
except json.JSONDecodeError as json_error:
|
|
|
57 |
match = re.search(r'\{.*\}', content, re.DOTALL)
|
58 |
if match:
|
59 |
try:
|
|
|
69 |
return dialogue
|
70 |
|
71 |
def generate_audio(text: str, speaker: str) -> str:
|
72 |
+
if speaker == "John":
|
73 |
+
speech = tts_male(text, speaker_embeddings=male_embedding)
|
74 |
+
else: # Sarah
|
75 |
+
speech = tts_female(text, speaker_embeddings=female_embedding)
|
76 |
+
|
77 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
|
78 |
+
sf.write(temp_audio.name, speech["audio"], speech["sampling_rate"])
|
79 |
return temp_audio.name
|