siddhartharya commited on
Commit
652d9d0
·
verified ·
1 Parent(s): 337fec8

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +48 -0
utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from groq import Groq
3
+ from gtts import gTTS
4
+ from pydantic import BaseModel, ValidationError
5
+ from typing import List, Literal
6
+ import tiktoken
7
+
8
+ groq_client = Groq()
9
+ tokenizer = tiktoken.get_encoding("cl100k_base")
10
+
11
+ class DialogueItem(BaseModel):
12
+ speaker: Literal["Host", "Guest"]
13
+ text: str
14
+
15
+ class Dialogue(BaseModel):
16
+ dialogue: List[DialogueItem]
17
+
18
+ def truncate_text(text, max_tokens=2048):
19
+ tokens = tokenizer.encode(text)
20
+ if len(tokens) > max_tokens:
21
+ return tokenizer.decode(tokens[:max_tokens])
22
+ return text
23
+
24
+ def generate_script(system_prompt: str, input_text: str, tone: str):
25
+ input_text = truncate_text(input_text)
26
+ prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"
27
+
28
+ response = groq_client.chat.completions.create(
29
+ messages=[
30
+ {"role": "system", "content": prompt},
31
+ ],
32
+ model="llama2-70b-4096",
33
+ max_tokens=2048,
34
+ temperature=0.7
35
+ )
36
+
37
+ try:
38
+ dialogue = Dialogue.model_validate_json(response.choices[0].message.content)
39
+ except ValidationError as e:
40
+ raise ValueError(f"Failed to parse dialogue JSON: {e}")
41
+
42
+ return dialogue
43
+
44
+ def generate_audio(text: str, speaker: str) -> str:
45
+ tts = gTTS(text, lang='en', tld='com' if speaker == "Host" else 'co.uk')
46
+ filename = f"{speaker.lower()}_audio.mp3"
47
+ tts.save(filename)
48
+ return filename