File size: 1,082 Bytes
652d9d0 8412e92 652d9d0 8412e92 652d9d0 8412e92 652d9d0 8412e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from groq import Groq
from pydantic import BaseModel, ValidationError
from typing import List, Literal
import os
# Shared Groq API client used by generate_script(). It is constructed when the
# module is imported, so a missing GROQ_API_KEY environment variable surfaces
# immediately as a KeyError from os.environ.
groq_client = Groq(
    api_key=os.environ["GROQ_API_KEY"],
)
# One turn of the generated script: who is speaking and what they say.
# (Comments are used instead of a class docstring because pydantic copies the
# docstring into the model's JSON schema description.)
class DialogueItem(BaseModel):
    # Only these two labels validate; any other value fails pydantic validation.
    speaker: Literal["Host", "Guest"]
    # The spoken line for this turn.
    text: str
# Expected shape of the model's JSON reply parsed in generate_script():
#   {"dialogue": [{"speaker": "Host" | "Guest", "text": "..."}, ...]}
class Dialogue(BaseModel):
    # Conversation turns, in the order they appear in the JSON list.
    dialogue: List[DialogueItem]
def _strip_code_fence(content: str) -> str:
    """Return *content* with one surrounding Markdown code fence removed, if any.

    Models may wrap a JSON reply in ```json ... ``` even when asked for raw JSON,
    and pydantic's JSON parser rejects that wrapper. Content that does not start
    with a fence is returned unchanged.
    """
    stripped = content.strip()
    if not stripped.startswith("```"):
        return content
    # Drop the opening fence line, which may carry a language tag such as "json".
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def generate_script(
    system_prompt: str,
    input_text: str,
    tone: str,
    *,
    model: str = "llama-3.1-70b-versatile",
    max_tokens: int = 2048,
    temperature: float = 0.7,
) -> Dialogue:
    """Ask the Groq chat model for a Host/Guest dialogue and parse the reply.

    The system prompt, tone and (truncated) input text are combined into a single
    system message. The reply content is validated against ``Dialogue``.

    Args:
        system_prompt: Instructions placed at the start of the request.
        input_text: Source material; passed through ``truncate_text`` first.
        tone: Inserted into the prompt as ``TONE: <tone>``.
        model: Groq model id. NOTE(review): Groq's deprecation list appears to
            retire ``llama-3.1-70b-versatile`` in favour of
            ``llama-3.3-70b-versatile``; confirm and update the default.
        max_tokens: Upper bound on tokens generated by the model.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The validated ``Dialogue``.

    Raises:
        ValueError: If the model returns no content, or the content is not JSON
            matching the ``Dialogue`` schema (the pydantic error is chained).
    """
    # NOTE(review): truncate_text is not defined in this file; it must be
    # defined or imported at module level before this function is called.
    input_text = truncate_text(input_text)
    prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"
    response = groq_client.chat.completions.create(
        messages=[{"role": "system", "content": prompt}],
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    # The SDK may return None or an empty string; report that explicitly rather
    # than letting it surface as an obscure JSON-parse error.
    if not content:
        raise ValueError("Failed to parse dialogue JSON: model returned no content")
    try:
        return Dialogue.model_validate_json(_strip_code_fence(content))
    except ValidationError as e:
        raise ValueError(f"Failed to parse dialogue JSON: {e}") from e
# NOTE: generate_script() calls truncate_text(), which is not defined in this file;
# define it here or import it from the module that provides it.