siddhartharya's picture
Update utils.py
e81cad8 verified
raw
history blame
2.82 kB
from groq import Groq
from pydantic import BaseModel, ValidationError
from typing import List, Literal
import os
import tiktoken
import json
import re
from gtts import gTTS
import tempfile
# Shared Groq API client. The key is read with os.environ[...], so importing
# this module raises KeyError when GROQ_API_KEY is not set.
groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
# tiktoken "cl100k_base" encoding used by truncate_text to count and cut tokens.
# NOTE(review): this is presumably only an approximation of the token count of
# the model called in generate_script -- confirm the limit is conservative enough.
tokenizer = tiktoken.get_encoding("cl100k_base")
class DialogueItem(BaseModel):
    """One turn of the podcast dialogue: who speaks and what they say."""

    # Speaker name; only the literal values "Maria" and "Sarah" are accepted.
    speaker: Literal["Maria", "Sarah"]
    # The text spoken in this turn.
    text: str
class Dialogue(BaseModel):
    """A podcast script; generate_script validates the model's JSON against it."""

    # Dialogue turns, kept in the order given in the parsed JSON.
    dialogue: List[DialogueItem]
def truncate_text(text, max_tokens=2048):
    """Return *text* limited to at most *max_tokens* tokens.

    The text is encoded with the module-level ``tokenizer``. When it already
    fits within the budget the original string is returned untouched;
    otherwise the first *max_tokens* tokens are decoded back into a string.
    """
    token_ids = tokenizer.encode(text)
    # Short enough: hand back the caller's string as-is (no encode/decode round trip).
    if len(token_ids) <= max_tokens:
        return text
    return tokenizer.decode(token_ids[:max_tokens])
def generate_script(system_prompt: str, input_text: str, tone: str, target_length: str):
    """Ask the Groq chat model for a podcast script and return it as a Dialogue.

    Args:
        system_prompt: Base instructions placed at the top of the prompt.
        input_text: Source material; truncated with ``truncate_text`` before use.
        tone: Desired tone, inserted verbatim into the prompt.
        target_length: Length label. ``"Short (1-2 min)"`` selects a 300-word
            budget; any other value selects 750 words.

    Returns:
        The validated ``Dialogue`` parsed from the model's JSON reply.

    Raises:
        ValueError: If the reply is missing, contains no parsable JSON, or the
            JSON does not match the ``Dialogue`` schema.
    """
    input_text = truncate_text(input_text)
    word_limit = 300 if target_length == "Short (1-2 min)" else 750  # Assuming 150 words per minute
    # The prompt body is kept flush-left so the text sent to the model carries
    # no extra indentation.
    prompt = f"""
{system_prompt}
TONE: {tone}
TARGET LENGTH: {target_length} (approximately {word_limit} words)
INPUT TEXT: {input_text}
Generate a complete, well-structured podcast script that:
1. Starts with a proper introduction
2. Covers the main points from the input text
3. Has a natural flow of conversation between Maria and Sarah
4. Concludes with a summary and sign-off
5. Fits within the {word_limit} word limit for the target length of {target_length}
Ensure the script is not abruptly cut off and forms a complete conversation.
"""
    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": prompt},
        ],
        model="llama-3.1-70b-versatile",
        max_tokens=2048,
        temperature=0.7
    )
    return _parse_dialogue(response.choices[0].message.content)


def _parse_dialogue(content):
    """Turn the raw model reply into a validated Dialogue, or raise ValueError."""
    # The message content may be None; report that with the same ValueError
    # used for other unusable replies instead of a TypeError from re.sub.
    if content is None:
        raise ValueError(f"Failed to find valid JSON in the response: {content}")
    # Strip Markdown code fences the model may wrap around the JSON.
    content = re.sub(r'```json\s*|\s*```', '', content)
    try:
        json_data = json.loads(content)
    except json.JSONDecodeError as json_error:
        # The reply is not pure JSON: fall back to the outermost {...} span.
        match = re.search(r'\{.*\}', content, re.DOTALL)
        if not match:
            raise ValueError(
                f"Failed to find valid JSON in the response: {content}"
            ) from json_error
        try:
            return Dialogue.model_validate(json.loads(match.group()))
        except (json.JSONDecodeError, ValidationError) as e:
            raise ValueError(f"Failed to parse dialogue JSON: {e}\nContent: {content}") from e
    try:
        return Dialogue.model_validate(json_data)
    except ValidationError as e:
        raise ValueError(f"Failed to validate dialogue structure: {e}\nContent: {content}") from e
def generate_audio(text: str, speaker: str) -> str:
    """Synthesize *text* to an MP3 file with gTTS and return the file's path.

    Args:
        text: The text to speak.
        speaker: Name of the speaker. Currently unused: every speaker is
            rendered with the same English voice (``lang='en'``, ``tld='com'``).

    Returns:
        Path of a newly created temporary ``.mp3`` file. The file is created
        with ``delete=False``, so the caller is responsible for removing it.

    Raises:
        Exception: Whatever ``gTTS.save`` raises is propagated unchanged; the
            temporary file is removed first so failed calls leave nothing behind.
    """
    tts = gTTS(text=text, lang='en', tld='com')
    # Reserve a unique path, then close the handle before gTTS writes to it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
        audio_path = temp_audio.name
    try:
        tts.save(audio_path)
    except Exception:
        # delete=False means nothing else cleans up; remove the file that was
        # just created so a failed synthesis does not leak it.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        raise
    return audio_path