File size: 4,351 Bytes
0f7e22b
 
 
 
 
 
 
 
 
 
 
8994fef
 
730a407
0f7e22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d569504
0f7e22b
 
d569504
 
0f7e22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae1999d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import tempfile
import asyncio
from fastapi import FastAPI, File, UploadFile, Response
import uvicorn
from groq import Groq
from transformers import VitsModel, AutoTokenizer
import torch
import torchaudio
from io import BytesIO

# Install the OpenMP runtime (libgomp) that torch's compiled kernels link
# against. NOTE(review): shelling out with os.system at import time assumes a
# Debian-based container running as root — confirm this matches the
# deployment target before relying on it.
os.system('apt update')
os.system('apt install libgomp1 -y')

# Preload TTS models once at startup so requests don't pay the load cost.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

# Ensure the models are using GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts_model = tts_model.to(device)

# Initialize the Groq client with API key
# SECURITY(review): this API key is hard-coded and committed to source —
# rotate it immediately and load it from the environment instead, e.g.
# api_key = os.environ["GROQ_API_KEY"].
api_key = "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"  # Replace with your actual Groq API key
chat_model = "llama3-8b-8192"
client = Groq(api_key=api_key)

# Initialize FastAPI app
app = FastAPI()

# Convert audio to text using Groq's API
async def audio_to_text(file: UploadFile):
    """Transcribe an uploaded audio file to English text via Groq Whisper.

    Parameters:
        file: The uploaded audio file from the request.

    Returns:
        The transcription text returned by the Whisper model.
    """
    audio_data = await file.read()
    # The Groq SDK call is synchronous; run it in a worker thread so the
    # blocking HTTP request does not stall the asyncio event loop.
    transcription = await asyncio.to_thread(
        client.audio.transcriptions.create,
        file=(file.filename, audio_data),
        model="whisper-large-v3",  # The Whisper model for transcription
        prompt="Specify context or spelling",  # Optional: Customize transcription context
        response_format="json",
        language="en",
        temperature=0.0,
    )
    return transcription.text

# Get chat response from Groq API
async def get_chat_response(api_key, model, user_message, temperature=0.5, max_tokens=258, top_p=1, stop=None):
    """Generate a short TTS-friendly chat reply from the Groq chat API.

    Parameters:
        api_key: Groq API key used to build the client for this call.
        model: Chat model identifier (e.g. "llama3-8b-8192").
        user_message: The user's transcribed utterance.
        temperature, max_tokens, top_p, stop: Standard sampling controls
            passed through to the completion call.

    Returns:
        The assistant message content as a string.
    """
    # A per-call client is built from the supplied api_key so callers may
    # use a key different from the module-level one.
    groq_client = Groq(api_key=api_key)

    # The SDK call is blocking; run it in a worker thread so this coroutine
    # does not stall the event loop while waiting on the network.
    chat_completion = await asyncio.to_thread(
        groq_client.chat.completions.create,
        messages=[
            {"role": "system", "content": "You are a virtual human assistant in an AR and VR environment. Your responses should be short, concise, and suitable for text-to-speech conversion. Avoid numbers in digit form."},
            {"role": "user", "content": user_message}
        ],
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=stop,
        stream=False,
    )

    return chat_completion.choices[0].message.content

# Convert text to speech using the Vits TTS model
async def text_to_speech(text, filename="output.wav"):
    """Synthesize speech for *text* and return it as an in-memory WAV buffer.

    Parameters:
        text: The text to synthesize; must be non-empty.
        filename: Unused; retained for backward compatibility with callers.

    Returns:
        A BytesIO positioned at the start of the WAV data.

    Raises:
        ValueError: If *text* is empty or whitespace-only.
    """
    if not text or text.strip() == "":
        raise ValueError("Input text is empty or invalid")

    # Tokenize the input text and move tensors to the model's device.
    inputs = tokenizer(text, return_tensors="pt")
    inputs['input_ids'] = inputs['input_ids'].to(torch.long)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    print(f"Input IDs shape: {inputs['input_ids'].shape}")

    # Model inference is CPU/GPU-bound and blocking; run it in a worker
    # thread so the event loop stays responsive.
    def _synthesize():
        with torch.no_grad():
            return tts_model(**inputs).waveform

    output = await asyncio.to_thread(_synthesize)

    # Encode straight into an in-memory buffer. The previous temp-file
    # round-trip leaked the file on disk whenever torchaudio.save raised
    # before os.remove ran; torchaudio.save accepts file-like objects.
    audio_buffer = BytesIO()
    torchaudio.save(audio_buffer, output.cpu(), sample_rate=tts_model.config.sampling_rate, format="wav")

    audio_buffer.seek(0)  # Rewind the buffer for reading
    return audio_buffer

# Main API endpoint for processing audio
# NOTE(review): the GET '/' route requires a file upload, so plain GET
# requests to '/' will fail validation (422) — confirm whether a separate
# health-check handler was intended for '/'.
@app.get('/')
@app.post("/processaudio")
async def process_audio(audio_file: UploadFile = File(...)):
    """Full pipeline: uploaded audio -> transcript -> chat reply -> WAV speech.

    Parameters:
        audio_file: The uploaded user audio to transcribe.

    Returns:
        A WAV audio Response of the spoken reply, or a plain-text error
        Response when the chat model produced no text.
    """
    # '\L' was an invalid escape sequence (SyntaxWarning in Python 3.12+);
    # '\\L' prints the identical '-\LOG' prefix.
    print('-\\LOG : Request on /')
    # Convert uploaded audio to text
    user_message = await audio_to_text(audio_file)

    # Generate a chat response from Groq
    response_text = await get_chat_response(api_key, chat_model, user_message)

    # Ensure response_text is valid before attempting synthesis
    if not response_text:
        return Response(content="Error: Generated response text is empty or invalid.", media_type="text/plain")

    # Convert the chat response to speech
    audio_output = await text_to_speech(response_text)

    # Return the generated speech as a WAV file in the response
    return Response(content=audio_output.read(), media_type="audio/wav", headers={
        "Content-Disposition": "attachment; filename=response.wav"
    })

# # Start the Uvicorn server for FastAPI
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)