# evBackend / TextGen / router.py
# Author: Jofthomas (HF staff)
# Commit 30fa914 (verified): "Update TextGen/router.py"
# (copied from the Hugging Face file view: raw · history · blame · 6.76 kB)
import asyncio
import os
import time
from typing import List

from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from gradio_client import Client, handle_file
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)

from coqui import predict
from TextGen import app
from TextGen.suno import custom_generate_audio, get_audio_information
class PlayLastMusic(BaseModel):
    '''plays the latest created music'''
    # Tool schema for the LLM: the docstring above is the tool description
    # (this class is passed to llm.bind_tools in generate_text), so it must
    # stay a plain string literal and be worded for the model.
    # The field name's spelling ("Desicion") is part of the tool-call
    # arguments the model emits; renaming it would change that schema.
    Desicion: str = Field(
        ..., description="Yes or No"
    )
class CreateLyrics(BaseModel):
    '''create some Lyrics for a new music'''
    # Tool schema for the LLM: the docstring above is the tool description
    # (this class is passed to llm.bind_tools in generate_text). It must be a
    # plain string literal -- an f-string in this position is an ordinary
    # expression, not a docstring, and would leave __doc__ as None.
    # The field name's spelling ("Desicion") is part of the tool-call schema.
    Desicion: str = Field(
        ..., description="Yes or No"
    )
class CreateNewMusic(BaseModel):
    '''create a new music with the Lyrics previously computed'''
    # Tool schema for the LLM: the docstring above is the tool description
    # (this class is passed to llm.bind_tools in generate_text). It must be a
    # plain string literal -- an f-string in this position is not a docstring
    # and would leave __doc__ as None.
    # NOTE(review): the field is called "Name" but described to the model as
    # tags; both are part of the tool-call schema, so they are left unchanged.
    Name: str = Field(
        ..., description="tags to describe the new music"
    )
class Message(BaseModel):
    # Request body of POST /api/generate.
    #   npc      -- speaker name; unknown names or None get a generic prompt.
    #   messages -- dialogue history, oldest first; even positions are sent as
    #               user turns and odd positions as assistant turns.
    npc: str | None = Field(default=None)
    messages: List[str] | None = Field(default=None)
class VoiceMessage(BaseModel):
    # Request body of POST /generate_wav and POST /generate_voice.
    #   npc      -- NPC name; known names get a dedicated reference voice.
    #   input    -- text to synthesize.
    #   language -- forwarded unchanged to the TTS backend.
    #   genre    -- voice gender for generic NPCs ("Male" / "Female").
    npc: str | None = Field(default=None)
    input: str | None = Field(default=None)
    language: str | None = Field(default="en")
    genre: str | None = Field(default="Male")
# Required deployment secrets: a missing variable raises KeyError at import
# time. (song_base_api is not referenced by the code shown in this file.)
song_base_api, my_hf_token = os.environ["VERCEL_API"], os.environ["HF_TOKEN"]

# Gradio client for the hosted "Jofthomas/xtts" Space, used by /generate_wav.
tts_client = Client("Jofthomas/xtts", hf_token=my_hf_token)
# Reference voice clip for each named NPC, handed to the TTS backends by
# determine_vocie_from_npc. Its keys also decide which NPCs are "main" ones.
main_npcs = dict(
    Blacksmith="./voices/Blacksmith.mp3",
    Herbalist="./voices/female.mp3",
    Bard="./voices/Bard_voice.mp3",
)
# Role prompt that opens the conversation for each named NPC in generate_text.
main_npc_system_prompts = dict(
    Blacksmith="You are a blacksmith in a video game",
    Herbalist="You are an herbalist in a video game",
    Bard="You are a bard in a video game",
)
class Generate(BaseModel):
    # Response model of POST /api/generate: the NPC's reply text.
    text: str = Field(...)
def generate_text(messages: List[str] | None, npc: str) -> "Generate":
    """Generate an NPC's next line of dialogue with Gemini.

    Args:
        messages: Conversation history, oldest first. Even indices are sent
            as user turns, odd indices as assistant turns. ``None`` (the
            request model's default) is treated as an empty history.
        npc: NPC name. Names with an entry in ``main_npc_system_prompts`` get
            that prompt; any other value gets a generic role-play prompt.

    Returns:
        Generate: wraps the ``content`` of the model's reply.
    """
    print(npc)
    # Looking the prompt up in its own table avoids a KeyError if the voice
    # table (main_npcs) and the prompt table ever stop sharing keys.
    system_prompt = main_npc_system_prompts.get(
        npc, "you're a character in a video game. Play along."
    )
    print(system_prompt)

    # The role prompt is sent as the opening user turn, followed by the
    # history with alternating user/assistant roles.
    new_messages = [{"role": "user", "content": system_prompt}]
    new_messages.extend(
        {"role": "user" if index % 2 == 0 else "assistant", "content": message}
        for index, message in enumerate(messages or [])
    )
    print(new_messages)

    # A new client is built per request with a short reply budget and all
    # safety filters disabled.
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro-latest",
        max_output_tokens=100,
        temperature=1,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    # NOTE(review): the NPC tables spell this name "Bard", so this lowercase
    # comparison only matches a client that sends "bard". Tool calls are not
    # handled below (only .content is returned), so enabling this for "Bard"
    # needs tool-call handling first.
    if npc == "bard":
        llm = llm.bind_tools([PlayLastMusic, CreateNewMusic, CreateLyrics])
    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)
# Let browser clients from any origin call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is as
# permissive as CORS gets; restrict allow_origins to the game's hosts if the
# API ever relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/", tags=["Home"])
def api_home():
    """Root endpoint: returns a fixed message identifying the backend."""
    payload = {"detail": "Everchanging Quest backend, nothing to see here"}
    return payload
@app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
def inference(message: Message):
    """Produce the NPC's reply for the dialogue carried in ``message``."""
    speaker, history = message.npc, message.messages
    return generate_text(messages=history, npc=speaker)
# Placeholder voice selection ("dummy function for now"): named NPCs get their
# own clip; everyone else gets a generic voice picked by ``genre``.
def determine_vocie_from_npc(npc, genre):
    """Return the path of the reference voice clip to use for ``npc``.

    The function name keeps its original spelling because callers use it.

    Args:
        npc: NPC name. Keys of ``main_npcs`` map to their dedicated clip.
        genre: Voice gender for other NPCs, ``"Male"`` or ``"Female"``
            (case-sensitive). Any other value, including ``None``, selects
            the narrator voice.

    Returns:
        str: Relative path of an audio file under ``./voices``.
    """
    if npc in main_npcs:
        return main_npcs[npc]
    if genre == "Male":
        # This branch was missing its ``return``, so male generic NPCs fell
        # through to the narrator voice.
        return "./voices/default_male.mp3"
    if genre == "Female":
        return "./voices/default_female.mp3"
    return "./voices/narator_out.wav"
@app.post("/generate_wav")
async def generate_wav(message: VoiceMessage):
    """Synthesize ``message.input`` with the hosted XTTS Space and stream it.

    The reference voice comes from ``determine_vocie_from_npc``. It is wrapped
    with ``handle_file`` so ``tts_client`` can upload it to the Space.

    Raises:
        HTTPException: 500 if voice selection or the remote call fails.
    """
    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        audio_file_pth = handle_file(voice)
        # Client.predict blocks until the Space answers, so it runs in a
        # worker thread instead of stalling the event loop. Running it before
        # the response is built also lets a failure become an HTTP 500; an
        # exception raised inside the streaming generator would bypass this
        # handler and only truncate the stream.
        result = await asyncio.to_thread(
            tts_client.predict,
            prompt=message.input,
            language=message.language,
            audio_file_pth=audio_file_pth,
            mic_file_path=None,
            use_mic=False,
            voice_cleanup=False,
            no_lang_auto_detect=False,
            agree=True,
            api_name="/predict",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    async def audio_stream():
        # NOTE(review): assumes ``result`` iterates as (sampling_rate, array)
        # pairs whose arrays support .tobytes(); confirm against the Space's
        # /predict output.
        for _sampling_rate, audio_chunk in result:
            yield audio_chunk.tobytes()
            await asyncio.sleep(0)  # let other tasks run between chunks

    return StreamingResponse(audio_stream(), media_type="audio/wav")
@app.post("/generate_voice")
async def generate_voice(message: VoiceMessage):
    """Synthesize ``message.input`` with ``coqui.predict`` and stream it.

    Raises:
        HTTPException: 500 if voice selection or synthesis fails.
    """
    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        # NOTE(review): handle_file() wraps a path for gradio_client uploads;
        # confirm that coqui.predict accepts that wrapper rather than a plain
        # path.
        audio_file_pth = handle_file(voice)
        # Synthesis runs before the response is built, so a failure becomes an
        # HTTP 500. An exception raised inside the streaming generator would
        # bypass this handler.
        # NOTE(review): predict() is called synchronously and blocks the event
        # loop while it runs. Offload it (e.g. asyncio.to_thread) once it is
        # confirmed safe to call from concurrent worker threads.
        result = predict(
            prompt=message.input,
            language=message.language,
            audio_file_pth=audio_file_pth,
            mic_file_path=None,
            use_mic=False,
            voice_cleanup=False,
            no_lang_auto_detect=False,
            agree=True,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    async def audio_stream():
        # NOTE(review): assumes ``result`` iterates as (sampling_rate, array)
        # pairs whose arrays support .tobytes().
        for _sampling_rate, audio_chunk in result:
            yield audio_chunk.tobytes()
            await asyncio.sleep(0)  # let other tasks run between chunks

    return StreamingResponse(audio_stream(), media_type="audio/wav")
@app.get("/generate_song")
async def generate_song(text: str):
    """Start a Suno song generation for ``text`` and wait until it is playable.

    Polls ``get_audio_information`` up to 60 times, 5 seconds apart (about
    5 minutes), until the first clip reports a ready status.

    Returns:
        The most recent list returned by ``get_audio_information``. Its entries
        carry at least ``id``, ``status`` and ``audio_url``. On timeout this is
        the last (not yet ready) poll result, so callers should check
        ``status``.

    Raises:
        HTTPException: 500 if a Suno call fails or returns an unexpected shape.
    """
    try:
        # The Suno helpers are synchronous (their results are indexed
        # immediately), so they run in worker threads to keep the loop free.
        data = await asyncio.to_thread(
            custom_generate_audio,
            {
                "prompt": f"{text}",
                "make_instrumental": False,
                "wait_audio": False,
            },
        )
        # NOTE(review): assumes every request yields at least two clips.
        ids = f"{data[0]['id']},{data[1]['id']}"
        print(f"ids: {ids}")
        for _ in range(60):
            data = await asyncio.to_thread(get_audio_information, ids)
            # "streaming" means the clip can already be played. "complete" is
            # presumably the state that follows it (confirm with the Suno
            # wrapper); without it, a fast generation could be waited on for
            # the full window.
            if data[0]["status"] in ("streaming", "complete"):
                print(f"{data[0]['id']} ==> {data[0]['audio_url']}")
                print(f"{data[1]['id']} ==> {data[1]['audio_url']}")
                break
            # Non-blocking wait: time.sleep would stall every other request.
            await asyncio.sleep(5)
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return data