area444 committed on
Commit
36fd81e
·
verified ·
1 Parent(s): be495d0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Query
2
+ from fastapi.responses import StreamingResponse
3
+ import os
4
+ from os import environ as env
5
+ import torch
6
+ import time
7
+ import nltk
8
+ import io
9
+ import base64
10
+ import torchaudio
11
+ from fastapi.responses import JSONResponse
12
+ from app.inference import inference, LFinference, compute_style
13
+ import numpy as np
14
+
15
# One-time service setup: fetch NLTK tokenizer data, create the FastAPI
# app, and pick the compute device used when sampling noise tensors.
for _pkg in ('punkt', 'punkt_tab'):
    nltk.download(_pkg)

app = FastAPI()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+
22
@app.get("/")
async def read_root():
    """Health-check endpoint: confirms the service is up.

    Returns:
        A small JSON object with a status message.
    """
    # NOTE(review): removed dead commented-out variants of this return,
    # one of which interpolated a secret env var into the response body —
    # commented-out code referencing secrets must not ship.
    return {"details": "Environment is running OK!"}
28
+
29
@app.post("/synthesize/")
async def synthesize(
    text: str,
    return_base64: bool = True,
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0)
):
    """Synthesize speech for a single piece of text.

    Returns JSON containing the real-time factor ("RTF") plus the audio,
    either base64-encoded WAV (default) or as a raw sample list.
    Any failure is surfaced as HTTP 500 with the error message as detail.
    """
    try:
        t0 = time.time()
        latent_noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, latent_noise,
                        diffusion_steps=diffusion_steps,
                        embedding_scale=embedding_scale)
        # RTF = wall-clock synthesis time / audio duration at 24 kHz.
        rtf = (time.time() - t0) / (len(wav) / 24000)

        if not return_base64:
            return JSONResponse(content={"RTF": rtf, "audio": wav.tolist()})

        buf = io.BytesIO()
        torchaudio.save(buf, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
        buf.seek(0)
        encoded = base64.b64encode(buf.read()).decode('utf-8')
        return JSONResponse(content={"RTF": rtf, "audio_base64": encoded})

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
57
+
58
+
59
@app.post("/synthesize_longform_streaming/")
async def synthesize_longform(
    passage: str,
    return_base64: bool = False,
    alpha: float = Query(0.7, ge=0.0, le=1.0),
    diffusion_steps: int = Query(10, ge=5, le=200),
    embedding_scale: float = Query(1.5, ge=1.0, le=5.0)
):
    """Synthesize a multi-sentence passage, carrying style across sentences.

    The passage is split on periods; each sentence is synthesized with
    LFinference, threading the previous sentence's style (`s_prev`) through
    for long-form consistency. Returns either base64-encoded JSON or a
    streamed WAV response, plus the real-time factor for JSON responses.
    Failures surface as HTTP 500 with the error message as detail.
    """
    try:
        sentences = passage.split('.')  # naive sentence split on periods
        wavs = []
        s_prev = None  # style carried between consecutive sentences

        start = time.time()

        for text in sentences:
            if text.strip() == "":
                continue
            text += '.'  # restore the period removed by split
            noise = torch.randn(1, 1, 256).to(device)
            # BUG FIX: the validated `alpha` query parameter was previously
            # ignored (alpha=0.7 hard-coded here); forward it to the model.
            wav, s_prev = LFinference(text, s_prev, noise, alpha=alpha,
                                      diffusion_steps=diffusion_steps,
                                      embedding_scale=embedding_scale)
            wavs.append(wav)

        # Fail with a clear message instead of letting np.concatenate raise
        # an opaque ValueError on an empty/whitespace-only passage.
        if not wavs:
            raise ValueError("passage contains no sentences to synthesize")

        final_wav = np.concatenate(wavs)
        rtf = (time.time() - start) / (len(final_wav) / 24000)

        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(final_wav).unsqueeze(0), 24000, format="wav")
        audio_buffer.seek(0)

        if return_base64:
            audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": audio_base64})
        else:
            return StreamingResponse(audio_buffer, media_type="audio/wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
102
+
103
+
104
+
105
@app.post("/synthesize_with_emotion/")
async def synthesize_with_emotion(
    texts: dict,
    return_base64: bool = True,
    diffusion_steps: int = Query(100, ge=5, le=200),
    embedding_scale: float = Query(5.0, ge=1.0, le=5.0)
):
    """Synthesize one clip per (emotion, text) entry in the request body.

    `texts` maps an emotion label to the sentence to render. Each result
    carries its emotion plus either base64-encoded WAV audio (default) or
    raw samples. Failures surface as HTTP 500 with the message as detail.
    """
    try:
        results = []

        for emotion, sentence in texts.items():
            latent = torch.randn(1, 1, 256).to(device)
            wav = inference(sentence, latent,
                            diffusion_steps=diffusion_steps,
                            embedding_scale=embedding_scale)

            entry = {"emotion": emotion}
            if return_base64:
                buf = io.BytesIO()
                torchaudio.save(buf, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
                buf.seek(0)
                entry["audio_base64"] = base64.b64encode(buf.read()).decode('utf-8')
            else:
                entry["audio"] = wav.tolist()
            results.append(entry)

        return JSONResponse(content={"results": results})

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
143
+
144
+
145
+
146
+
147
+
148
@app.post("/synthesize_streaming_audio/")
async def synthesize_streaming_audio(
    text: str,
    return_base64: bool = False,
    diffusion_steps: int = Query(5, ge=5, le=200),
    embedding_scale: float = Query(1.0, ge=1.0, le=5.0)
):
    """Synthesize a single text and stream the WAV back by default.

    With `return_base64=True` the audio is instead base64-encoded into a
    JSON response together with the real-time factor ("RTF"). Failures
    surface as HTTP 500 with the error message as detail.
    """
    try:
        t0 = time.time()
        latent_noise = torch.randn(1, 1, 256).to(device)
        wav = inference(text, latent_noise,
                        diffusion_steps=diffusion_steps,
                        embedding_scale=embedding_scale)
        # RTF = wall-clock synthesis time / audio duration at 24 kHz.
        rtf = (time.time() - t0) / (len(wav) / 24000)

        buf = io.BytesIO()
        torchaudio.save(buf, torch.tensor(wav).unsqueeze(0), 24000, format="wav")
        buf.seek(0)

        if return_base64:
            encoded = base64.b64encode(buf.read()).decode('utf-8')
            return JSONResponse(content={"RTF": rtf, "audio_base64": encoded})

        return StreamingResponse(buf, media_type="audio/wav")

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))