File size: 5,626 Bytes
b6b3361
fb3c56f
257e92c
1ffdc41
9f5ffb5
07631e1
 
 
 
ac79266
6234905
 
 
 
 
07631e1
235bceb
74a9029
a309f78
257e92c
 
 
 
 
122cc50
257e92c
 
 
 
 
122cc50
257e92c
 
 
 
 
 
122cc50
 
a309f78
 
74a9029
b667879
 
 
 
 
826930b
 
fb3c56f
07631e1
6aa3952
 
ef53cde
6aa3952
826930b
990cdd9
2290e9c
990cdd9
826930b
da84aa4
 
 
 
 
07631e1
 
 
da84aa4
 
c4e11f2
 
 
 
da84aa4
c4e11f2
da84aa4
 
 
 
 
 
74a9029
 
257e92c
a6818a7
257e92c
74a9029
257e92c
 
 
 
 
74a9029
257e92c
 
 
da84aa4
be1fb13
48a69b6
07631e1
 
 
 
 
 
 
 
 
 
 
12bb931
07631e1
 
a309f78
da84aa4
6aa3952
b667879
826930b
 
 
 
 
990cdd9
826930b
990cdd9
826930b
 
b667879
7f18929
9f5ffb5
6aa3952
d865d73
9f5ffb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aa3952
 
fb3c56f
 
1ffdc41
fb3c56f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import asyncio
import os
import time
from typing import List

from fastapi import FastAPI, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from gradio_client import Client, handle_file
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)

from TextGen.suno import custom_generate_audio, get_audio_information
from TextGen import app

class PlayLastMusic(BaseModel):
    '''Play the latest created music.'''
    # Tool schema bound to the bard's LLM in generate_text; the class docstring
    # is presumably what LangChain sends as the tool description.
    # "Desicion" (sic) is kept as spelled: it is the tool's argument name.
    Desicion: str = Field(
        ..., description="Yes or No"
    )

class CreateLyrics(BaseModel):
    '''Create some lyrics for a new piece of music.'''
    # Must be a plain string literal: an f-string in this position is not a
    # docstring, which left __doc__ as None and the tool without a description.
    # "Desicion" (sic) is kept as spelled: it is the tool's argument name.
    Desicion: str = Field(
        ..., description="Yes or No"
    )

class CreateNewMusic(BaseModel):
    '''Create a new piece of music with the lyrics previously computed.'''
    # Must be a plain string literal: an f-string in this position is not a
    # docstring, which left __doc__ as None and the tool without a description.
    # Despite its name, the "Name" argument carries the style tags of the song
    # (see its description); the name is kept because it is the tool's schema.
    Name: str = Field(
        ..., description="tags to describe the new music"
    )



class Message(BaseModel):
    # Request body of POST /api/generate.
    # npc: name of the speaking NPC; names outside main_npcs get a generic prompt.
    npc: str | None = Field(default=None)
    # messages: conversation so far, oldest first, alternating player / NPC turns.
    messages: List[str] | None = Field(default=None)
    
class VoiceMessage(BaseModel):
    # Request body of POST /generate_wav.
    # npc: NPC whose voice is used; names in main_npcs have a dedicated sample.
    npc: str | None = Field(default=None)
    # input: text to synthesize.
    input: str | None = Field(default=None)
    # language: language code forwarded to the TTS Space.
    language: str | None = Field(default="en")
    # genre: voice gender ("Male" / "Female") for NPCs without a dedicated sample.
    genre: str | None = Field(default="Male")
    
# Required environment configuration, read at import time: a missing variable
# raises KeyError and the module fails to load.
# NOTE(review): song_base_api is not referenced in this file's visible code —
# confirm it is still needed.
song_base_api=os.environ["VERCEL_API"]

# Token passed to the Gradio client below to authenticate against the Space.
my_hf_token=os.environ["HF_TOKEN"]

# Client for the "Jofthomas/xtts" Space, created once at import time and
# reused by every /generate_wav request.
tts_client = Client("Jofthomas/xtts",hf_token=my_hf_token)

# Reference voice sample of each named NPC, handed to the TTS Space as the
# voice to clone.
main_npcs = dict(
    Blacksmith="./voices/Blacksmith.mp3",
    Herbalist="./voices/female.mp3",
    Bard="./voices/Bard_voice.mp3",
)
# Role-play system prompt of each named NPC; keys mirror main_npcs.
main_npc_system_prompts = dict(
    Blacksmith="You are a blacksmith in a video game",
    Herbalist="You are an herbalist in a video game",
    Bard="You are a bard in a video game",
)
class Generate(BaseModel):
    # Response body of POST /api/generate: the NPC's generated reply.
    text: str = Field(...)

def generate_text(messages: List[str] | None, npc: str) -> Generate:
    """Generate the next line of dialogue for an NPC with Gemini.

    The chat sent to the model is the NPC's system prompt (as a leading "user"
    turn) followed by ``messages``, alternating user / assistant roles starting
    with user.

    Args:
        messages: Conversation so far, oldest first. Even indices are sent as
            user turns, odd indices as assistant turns. ``None`` (the default
            of ``Message.messages``) is treated as an empty conversation.
        npc: NPC name. Keys of ``main_npc_system_prompts`` get their dedicated
            prompt; any other value gets a generic role-play prompt.

    Returns:
        Generate whose ``text`` is the model response's ``content``.
    """
    print(npc)
    # Look the prompt up in the dict we actually index, so a name present in
    # main_npcs but missing here falls back instead of raising KeyError.
    system_prompt = main_npc_system_prompts.get(
        npc, "you're a character in a video game. Play along."
    )
    print(system_prompt)
    new_messages = [{"role": "user", "content": system_prompt}]
    new_messages.extend(
        {"role": "user" if index % 2 == 0 else "assistant", "content": message}
        for index, message in enumerate(messages or [])
    )
    print(new_messages)
    # The model is configured per request; all safety filters are disabled.
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro-latest",
        max_output_tokens=100,
        temperature=1,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        },
    )
    # NOTE(review): NPC names in main_npcs are capitalised ("Bard"), so this
    # lowercase comparison never matches and the music tools are never bound.
    # Enabling it presumably also requires handling tool calls below (a tool
    # call may come back with empty `content`) — confirm intent before fixing.
    if npc == "bard":
        llm = llm.bind_tools([PlayLastMusic, CreateNewMusic, CreateLyrics])

    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)

# CORS policy: every origin, method and header is allowed, with credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True
# presumably lets any site make credentialed requests — confirm intended.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

@app.get("/", tags=["Home"])
def api_home():
    # Root endpoint: a fixed payload confirming the backend is reachable.
    payload = dict(detail='Everchanging Quest backend, nothing to see here')
    return payload

@app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
def inference(message: Message):
    # Thin route: delegate to generate_text with the request's NPC and history.
    reply = generate_text(npc=message.npc, messages=message.messages)
    return reply

# Placeholder voice selection until a real NPC -> voice mapping exists.
def determine_vocie_from_npc(npc, genre):
    """Return the path of the reference voice sample to clone for an NPC.

    Args:
        npc: NPC name; names in ``main_npcs`` use their dedicated sample.
        genre: Voice gender for other NPCs, "Male" or "Female". Any other
            value (including None) falls back to the narrator voice.

    Returns:
        Path (str) of an audio file under ./voices/.
    """
    if npc in main_npcs:
        return main_npcs[npc]
    if genre == "Male":
        # This branch used to lack `return`, so male NPCs silently got the
        # narrator voice. NOTE(review): confirm this file is deployed.
        return "./voices/default_male.mp3"
    if genre == "Female":
        return "./voices/default_female.mp3"
    return "./voices/narator_out.wav"
    
@app.post("/generate_wav")
async def generate_wav(message: VoiceMessage):
    # Synthesize `message.input` with the XTTS Space, cloning the voice picked
    # for the NPC, and stream the result back as audio/wav.
    # Any failure while choosing the voice or synthesizing becomes HTTP 500.
    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        audio_file_pth = handle_file(voice)
        # Call the TTS eagerly, here rather than inside the streaming
        # generator: once a StreamingResponse has started, an exception can
        # no longer become a 500, so the old except clause never saw TTS
        # errors. to_thread keeps the blocking Gradio call off the event loop.
        result = await asyncio.to_thread(
            tts_client.predict,
            prompt=message.input,
            language=message.language,
            audio_file_pth=audio_file_pth,
            mic_file_path=None,
            use_mic=False,
            voice_cleanup=False,
            no_lang_auto_detect=False,
            agree=True,
            api_name="/predict",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    def audio_stream():
        # NOTE(review): assumes `result` iterates as (sampling_rate, array)
        # pairs whose arrays expose .tobytes(); the raw bytes carry no WAV
        # header — confirm against the Space's /predict output.
        for _sampling_rate, audio_chunk in result:
            yield audio_chunk.tobytes()

    return StreamingResponse(audio_stream(), media_type="audio/wav")


@app.get("/generate_song")
async def generate_song(text: str):
    # Ask the Suno backend to compose a song from `text` (sent as the prompt),
    # then poll every 5 s, at most 60 times, until the clips are playable and
    # return their information.
    # Errors -> HTTP 500; no playable clip after polling -> HTTP 504.
    try:
        data = custom_generate_audio({
            "prompt": text,
            "make_instrumental": False,
            "wait_audio": False,
        })
        # Join every returned clip id instead of assuming exactly two clips.
        ids = ",".join(clip["id"] for clip in data)
        print(f"ids: {ids}")

        for _ in range(60):
            data = get_audio_information(ids)
            # Presumably a clip may already be "complete" by the time it is
            # polled; accept it as well as "streaming" so the loop does not
            # spin until the timeout — confirm the backend's status values.
            if data[0]["status"] in ("streaming", "complete"):
                for clip in data:
                    print(f"{clip['id']} ==> {clip['audio_url']}")
                return data
            # Non-blocking wait: time.sleep here would stall every request.
            await asyncio.sleep(5)
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    raise HTTPException(status_code=504, detail="Timed out waiting for the generated song")