import os
import time
from typing import List

from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)
from gradio_client import Client

from TextGen import app
from TextGen.suno import custom_generate_audio, get_audio_information

class Message(BaseModel):
    npc: str | None = None
    messages: List[str] | None = None


class VoiceMessage(BaseModel):
    npc: str | None = None
    input: str | None = None
    language: str | None = "en"
    genre: str | None = "Male"

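# A quick sketch of the payload shapes these models accept (all field values below are
# illustrative examples, not fixtures from the project):
#   Message(npc="Blacksmith", messages=["Hi!", "Greetings, traveler.", "What do you sell?"])
#   VoiceMessage(npc="Bard", input="Sing me a tale of the old kingdom.", language="en", genre="Male")
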
song_base_api = os.environ["VERCEL_API"]
my_hf_token = os.environ["HF_TOKEN"]

tts_client = Client("https://jofthomas-xtts.hf.space/", hf_token=my_hf_token)

main_npcs = {
    "Blacksmith": "./voices/Blacksmith.mp3",
    "Herbalist": "./voices/female.mp3",
    "Bard": "./voices/Bard_voice.mp3",
}

main_npc_system_prompts = {
    "Blacksmith": "You are a blacksmith in a video game",
    "Herbalist": "You are an herbalist in a video game",
    "Bard": "You are a bard in a video game",
}

class Generate(BaseModel):
    text: str

def generate_text(messages: List[str], npc: str) -> Generate:
    """Build an alternating user/assistant history and ask Gemini for the next in-character reply."""
    print(npc)
    if npc in main_npcs:
        system_prompt = main_npc_system_prompts[npc]
    else:
        system_prompt = "You're a character in a video game. Play along."
    print(system_prompt)

    # The system instructions are sent as the first user turn, followed by the
    # conversation history with alternating user/assistant roles.
    new_messages = [{"role": "user", "content": system_prompt}]
    for index, message in enumerate(messages):
        if index % 2 == 0:
            new_messages.append({"role": "user", "content": message})
        else:
            new_messages.append({"role": "assistant", "content": message})
    print(new_messages)

    # Initialize the LLM
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")  # route path is an assumption
def api_home():
    return {'detail': 'Welcome to FastAPI TextGen Tutorial!'}

@app.post("/generate")  # route path is an assumption
def inference(message: Message):
    return generate_text(messages=message.messages, npc=message.npc)

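# Illustrative client call for the endpoint above (base URL, port, and path are assumptions
# for local testing, not values taken from the project):
#
#   import requests
#   payload = {"npc": "Blacksmith", "messages": ["Hello!", "Greetings, traveler.", "Can you sharpen my sword?"]}
#   reply = requests.post("http://localhost:7860/generate", json=payload).json()
#   print(reply["text"])
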
# Dummy function for now
def determine_voice_from_npc(npc, genre):
    """Pick a reference voice clip: a dedicated file for main NPCs, otherwise a default based on genre."""
    if npc in main_npcs:
        return main_npcs[npc]
    if genre == "Male":
        return "./voices/default_male.mp3"
    if genre == "Female":
        return "./voices/default_female.mp3"
    return "./voices/narator_out.wav"

@app.post("/generate_wav")  # route path is an assumption
async def generate_wav(message: VoiceMessage):
    try:
        voice = determine_voice_from_npc(message.npc, message.genre)
        # Use the Gradio client to generate the wav file
        result = tts_client.predict(
            message.input,     # str in 'Text Prompt' Textbox component
            message.language,  # str in 'Language' Dropdown component
            voice,             # str (filepath or URL) in 'Reference Audio' Audio component
            voice,             # str (filepath or URL) in 'Use Microphone for Reference' Audio component
            False,             # bool in 'Use Microphone' Checkbox component
            False,             # bool in 'Cleanup Reference Voice' Checkbox component
            False,             # bool in 'Do not use language auto-detect' Checkbox component
            True,              # bool in 'Agree' Checkbox component
            fn_index=1,
        )
        # Get the path of the generated wav file
        wav_file_path = result[1]
        # Return the generated wav file as a response
        return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

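# Illustrative client call for the endpoint above (base URL, port, and path are assumptions):
#
#   import requests
#   payload = {"npc": "Bard", "input": "Welcome to the tavern, traveler!", "language": "en", "genre": "Male"}
#   response = requests.post("http://localhost:7860/generate_wav", json=payload)
#   with open("npc_line.wav", "wb") as f:
#       f.write(response.content)
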
@app.get("/generate_song")  # route path and method are assumptions
async def generate_song(text: str):
    try:
        data = custom_generate_audio({
            "prompt": f"{text}",
            "make_instrumental": False,
            "wait_audio": False,
        })
        ids = f"{data[0]['id']},{data[1]['id']}"
        print(f"ids: {ids}")
        # Poll until the two generated tracks start streaming (up to 60 attempts, 5 s apart).
        for _ in range(60):
            data = get_audio_information(ids)
            if data[0]["status"] == 'streaming':
                print(f"{data[0]['id']} ==> {data[0]['audio_url']}")
                print(f"{data[1]['id']} ==> {data[1]['audio_url']}")
                break
            # sleep 5s
            time.sleep(5)
    except Exception as e:
        print(f"Error: {e}")

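# Minimal local-run sketch. On Hugging Face Spaces the server is normally started by the Space's
# own entrypoint, so this __main__ guard and the port below are assumptions for local testing.
if __name__ == "__main__":
    import uvicorn

    # 7860 is the port Hugging Face Spaces conventionally expose.
    uvicorn.run(app, host="0.0.0.0", port=7860)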