Anis1123 committed
Commit 917878c · verified · 1 Parent(s): e5a35d8

Create app.py

Files changed (1)
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import List, Tuple
from huggingface_hub import InferenceClient  # only needed for the commented-out HF backend below
from groq import Groq
import edge_tts
import tempfile
import asyncio
import os

app = FastAPI()

# Alternative backend via the Hugging Face Inference API, kept for reference:
# client = InferenceClient("unsloth/gemma-2b-it-bnb-4bit")

# Groq client. Read the API key from the environment (e.g. a Space secret)
# instead of hardcoding it in the source.
client = Groq(api_key=os.environ["GROQ_API_KEY"])

async def text_to_speech(text, voice, rate, pitch):
    # "en-GB-MaisieNeural - en-GB (Female)" -> "en-GB-MaisieNeural"
    voice_short_name = voice.split(" - ")[0]
    rate_str = f"{rate:+d}%"
    pitch_str = f"{pitch:+d}Hz"
    communicate = edge_tts.Communicate(text, voice_short_name, rate=rate_str, pitch=pitch_str)
    submaker = edge_tts.SubMaker()

    # Stream the synthesized audio into a temporary MP3 file and collect
    # word-boundary events for the subtitles.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                tmp_file.write(chunk["data"])
            elif chunk["type"] == "WordBoundary":
                submaker.create_sub((chunk["offset"], chunk["duration"]), chunk["text"])

    # Write the generated WebVTT subtitles to a second temporary file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".vtt", mode="w", encoding="utf-8") as tmp_file:
        tmp_vtt_path = tmp_file.name
        tmp_file.write(submaker.generate_subs())

    return tmp_path, tmp_vtt_path, None


def tts_interface(text, voice, rate, pitch):
    # Synchronous wrapper: FastAPI runs plain `def` endpoints in a thread pool,
    # so asyncio.run() is safe here (no event loop is running in that thread).
    audio, vtt, warning = asyncio.run(text_to_speech(text, voice, rate, pitch))
    return audio, vtt, warning

@app.get("/")
def greet_json():
    return {"Hello": "World!"}


# Schema for incoming chat requests
class ChatRequest(BaseModel):
    message: str
    history: List[Tuple[str, str]] = []
    system_message: str
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95


@app.get("/file/")
def file(path: str):
    # Serve the MP3 generated by /chat.
    return FileResponse(path, media_type="audio/mpeg", filename="audio.mp3")


@app.get("/file-vtt/")
def fileVtt(path: str):
    # Serve the matching WebVTT subtitle file.
    return FileResponse(path)

# Generate a chat reply with Groq, then synthesize it to speech
@app.post("/chat")
def chat(request: ChatRequest):
    messages = [{"role": "system", "content": request.system_message}]

    for user_msg, assistant_msg in request.history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": request.message})

    try:
        response = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=messages,
            max_tokens=request.max_tokens,
            stream=False,
            stop=None,
            temperature=request.temperature,
            top_p=request.top_p,
        )

        # Strip markdown bold markers before sending the text to TTS.
        text = response.choices[0].message.content.replace("**", "")
        audio_path, vtt_path, _warning = tts_interface(
            text, "en-GB-MaisieNeural - en-GB (Female)", 0, 0
        )

        if not os.path.exists(audio_path):
            raise HTTPException(status_code=500, detail="Audio file was not generated")

        return {
            "text": text,
            "audio": audio_path,
            "vtt": vtt_path,
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
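
For quick manual testing, here is a minimal client sketch. It assumes the app is reachable at http://localhost:7860 (a common port for Hugging Face Spaces); the base URL, the example prompt, and the output filenames are illustrative assumptions, not part of this commit.

import requests

BASE_URL = "http://localhost:7860"  # assumed; replace with the deployed Space URL

# Ask /chat for a reply; the server also synthesizes it to speech.
resp = requests.post(
    f"{BASE_URL}/chat",
    json={
        "message": "Give me one fun fact about space travel.",
        "history": [],
        "system_message": "You are a concise, friendly assistant.",
    },
)
resp.raise_for_status()
data = resp.json()
print(data["text"])

# Fetch the generated audio and subtitles through the file endpoints,
# which take the server-side temp paths returned by /chat.
audio = requests.get(f"{BASE_URL}/file/", params={"path": data["audio"]})
with open("reply.mp3", "wb") as f:
    f.write(audio.content)

vtt = requests.get(f"{BASE_URL}/file-vtt/", params={"path": data["vtt"]})
with open("reply.vtt", "wb") as f:
    f.write(vtt.content)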