siddhartharya's picture
Update utils.py
6c57b19 verified
raw
history blame
2.36 kB
from groq import Groq
from pydantic import BaseModel, ValidationError
from typing import List, Literal
import os
import tiktoken
from gtts import gTTS
import tempfile
import json
import re
# Shared Groq API client. The key is read from the GROQ_API_KEY environment
# variable with a plain subscript, so importing this module raises KeyError
# when that variable is not set.
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
# Tokenizer used by truncate_text to count and cut tokens.
# NOTE(review): cl100k_base is presumably only an approximation of the
# tokenization used by the Llama model called below — confirm.
tokenizer = tiktoken.get_encoding("cl100k_base")
# One spoken turn in the generated script.
class DialogueItem(BaseModel):
    # Who is speaking; only these two labels pass validation.
    speaker: Literal["Host", "Guest"]
    # The words spoken in this turn.
    text: str
# Top-level schema that generate_script validates the model's JSON reply against.
class Dialogue(BaseModel):
    # The turns of the conversation, each a DialogueItem.
    dialogue: List[DialogueItem]
def truncate_text(text, max_tokens=2048):
    """Return *text* cut down to at most *max_tokens* tokens.

    Tokens are counted with the module-level ``tokenizer``. Text that is
    already within the limit is returned unchanged (the same object, not a
    re-decoded copy); otherwise the first *max_tokens* tokens are decoded
    back into a string.
    """
    token_ids = tokenizer.encode(text)

    # Short enough already: hand back the caller's string untouched.
    if len(token_ids) <= max_tokens:
        return text

    kept_ids = token_ids[:max_tokens]
    return tokenizer.decode(kept_ids)
def generate_script(system_prompt: str, input_text: str, tone: str):
    """Ask the Groq chat model for a dialogue script and return it validated.

    The input text is first shortened with ``truncate_text``, then combined
    with the system prompt and tone into a single system message. The reply
    is expected to be JSON matching the ``Dialogue`` schema, optionally
    wrapped in a Markdown ```json fence.

    Args:
        system_prompt: Instructions placed at the start of the prompt.
        input_text: Source material for the script; truncated before sending.
        tone: Free-form tone label inserted after ``TONE:``.

    Returns:
        The parsed and validated ``Dialogue`` instance.

    Raises:
        ValueError: If the reply has no content, contains no parseable JSON
            object, or the JSON does not match the ``Dialogue`` schema. The
            underlying parse/validation error is attached as ``__cause__``.
    """
    input_text = truncate_text(input_text)
    prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"

    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": prompt},
        ],
        model="llama-3.1-70b-versatile",
        max_tokens=2048,
        temperature=0.7,
    )

    content = response.choices[0].message.content
    # A reply without a message body would make re.sub raise TypeError;
    # report it as ValueError like every other unusable-response case here.
    if content is None:
        raise ValueError(
            "Failed to find valid JSON in the response: model returned no content"
        )

    # Strip Markdown code-fence markers so a bare JSON document remains.
    content = re.sub(r'```json\s*|\s*```', '', content)

    try:
        # First, try to parse the whole reply as JSON.
        json_data = json.loads(content)
        dialogue = Dialogue.model_validate(json_data)
    except json.JSONDecodeError as json_error:
        # The reply has text around the JSON: greedily take everything from
        # the first "{" to the last "}" (DOTALL lets it span lines).
        match = re.search(r'\{.*\}', content, re.DOTALL)
        if match is None:
            raise ValueError(
                f"Failed to find valid JSON in the response: {content}"
            ) from json_error
        try:
            json_data = json.loads(match.group())
            dialogue = Dialogue.model_validate(json_data)
        except (json.JSONDecodeError, ValidationError) as e:
            raise ValueError(
                f"Failed to parse dialogue JSON: {e}\nContent: {content}"
            ) from e
    except ValidationError as e:
        # The reply was valid JSON but not shaped like a Dialogue.
        raise ValueError(
            f"Failed to validate dialogue structure: {e}\nContent: {content}"
        ) from e

    return dialogue
def generate_audio(text: str, speaker: str) -> str:
    """Synthesize *text* to an MP3 file with gTTS and return the file's path.

    Args:
        text: The words to speak.
        speaker: ``"Host"`` selects the ``com`` Google Translate domain; any
            other value selects ``co.uk``.

    Returns:
        Path of a newly created ``.mp3`` temporary file. The file is not
        deleted automatically; the caller owns it.

    Raises:
        Exception: Whatever ``gTTS`` raises while building or saving the
            audio is propagated unchanged; the temporary file is removed
            first so a failed call leaves nothing behind.
    """
    tld = 'com' if speaker == "Host" else 'co.uk'
    tts = gTTS(text, lang='en', tld=tld)

    # Reserve a unique path, then close our handle before gTTS writes to the
    # file by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
        audio_path = temp_audio.name

    try:
        tts.save(audio_path)
    except Exception:
        # delete=False means nothing else will clean this file up.
        try:
            os.remove(audio_path)
        except OSError:
            # Best-effort cleanup; the original failure is what matters.
            pass
        raise

    return audio_path