# Scraped file-viewer header (page chrome, not part of the program):
#   author: siddhartharya
#   commit: "Update utils.py" (ba22d1b, verified)
#   size: 1.51 kB
from groq import Groq
from pydantic import BaseModel, ValidationError
from typing import List, Literal
import os
import tiktoken
from gtts import gTTS
import tempfile
# Shared Groq API client. The key is read from the GROQ_API_KEY environment
# variable at import time, so a missing variable raises KeyError on import.
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
# Tokenizer used by truncate_text to count and cut tokens.
# NOTE(review): cl100k_base is presumably not the tokenizer of the Llama model
# called in generate_script, so token counts are likely approximate — confirm.
tokenizer = tiktoken.get_encoding("cl100k_base")
class DialogueItem(BaseModel):
    """One turn of a dialogue: who is speaking and what they say."""

    # Speaker label; validation only accepts the literals "Host" or "Guest".
    speaker: Literal["Host", "Guest"]
    # The text attributed to this speaker.
    text: str
class Dialogue(BaseModel):
    """Top-level schema: an object with a ``dialogue`` list of DialogueItem."""

    # Dialogue turns, in the order they appear in the validated input.
    dialogue: List[DialogueItem]
def truncate_text(text, max_tokens=2048):
    """Return ``text`` cut down to at most ``max_tokens`` tokenizer tokens.

    The text is encoded with the module-level ``tokenizer``. If it already
    fits within the budget it is returned unchanged (the original object, not
    a re-decoded copy); otherwise the first ``max_tokens`` tokens are decoded
    back into a string.

    Args:
        text: The string to limit.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        The original text, or its decoded token prefix when it was too long.
    """
    token_ids = tokenizer.encode(text)

    # Fast path: nothing to cut, hand back the input untouched.
    if len(token_ids) <= max_tokens:
        return text

    kept_ids = token_ids[:max_tokens]
    return tokenizer.decode(kept_ids)
def generate_script(system_prompt: str, input_text: str, tone: str) -> "Dialogue":
    """Request a dialogue script from the Groq chat API and validate it.

    The input text is first shortened with ``truncate_text``; the system
    prompt, tone and (truncated) input are then combined into a single
    system message and sent to the chat-completions endpoint. The content of
    the first returned choice is parsed as JSON into a ``Dialogue``.

    Args:
        system_prompt: Instructions placed at the start of the prompt.
        input_text: Source material for the script; truncated before sending.
        tone: Tone label inserted after ``TONE:`` in the prompt.

    Returns:
        The validated ``Dialogue`` parsed from the model's reply.

    Raises:
        ValueError: If the reply does not validate against ``Dialogue``. The
            underlying ``ValidationError`` is attached as ``__cause__``.
    """
    input_text = truncate_text(input_text)
    prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"

    # NOTE(review): the model id is hard-coded; confirm it is still served.
    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": prompt},
        ],
        model="llama-3.1-70b-versatile",
        max_tokens=2048,
        temperature=0.7,
    )

    # Only the parsing step is guarded: the handler is specific to validation
    # failures, so the attribute/index access stays outside the try body.
    raw_reply = response.choices[0].message.content
    try:
        return Dialogue.model_validate_json(raw_reply)
    except ValidationError as e:
        # Chain explicitly so the traceback shows the validation error as the
        # direct cause rather than as an incidental error during handling.
        raise ValueError(f"Failed to parse dialogue JSON: {e}") from e
def generate_audio(text: str, speaker: str) -> str:
    """Synthesize ``text`` to an MP3 file with gTTS and return the file path.

    The gTTS ``tld`` is ``'com'`` when ``speaker`` is ``"Host"`` and
    ``'co.uk'`` for any other value (presumably to give the two speakers
    different voices — confirm). The file is created with ``delete=False``,
    so the caller is responsible for removing it when done.

    Args:
        text: The text to speak.
        speaker: Speaker label; only ``"Host"`` selects the ``'com'`` tld.

    Returns:
        Path of the newly written temporary ``.mp3`` file.

    Raises:
        Exception: Whatever gTTS raises while building or saving the audio;
            the temporary file is removed before the error propagates.
    """
    tts = gTTS(text, lang='en', tld='com' if speaker == "Host" else 'co.uk')

    # Reserve a unique path, then close the handle BEFORE gTTS writes to it:
    # reopening a still-open NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
        audio_path = temp_audio.name

    try:
        tts.save(audio_path)
    except Exception:
        # delete=False means nobody else cleans up; don't leave an empty or
        # partial file behind when synthesis fails.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        raise
    return audio_path