import os
import tempfile
import asyncio
from fastapi import FastAPI, File, UploadFile, Response
import uvicorn
from groq import Groq
from transformers import VitsModel, AutoTokenizer
import torch
import torchaudio
from io import BytesIO

os.system('apt update')
os.system('apt install libgomp1 -y')

# Preload TTS models
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

# Ensure the model is using GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts_model = tts_model.to(device)

# Initialize the Groq client with API key
api_key = os.environ.get("GROQ_API_KEY")  # Read the Groq API key from the environment; do not hard-code secrets
chat_model = "llama3-8b-8192"
client = Groq(api_key=api_key)

# Initialize FastAPI app
app = FastAPI()

# Convert audio to text using Groq's API
async def audio_to_text(file: UploadFile):
    audio_data = await file.read()
    transcription = client.audio.transcriptions.create(
        file=(file.filename, audio_data),
        model="whisper-large-v3",  # The Whisper model for transcription
        prompt="Specify context or spelling",  # Optional: customize transcription context
        response_format="json",
        language="en",
        temperature=0.0
    )
    return transcription.text
# Get chat response from Groq API
async def get_chat_response(api_key, model, user_message, temperature=0.5, max_tokens=258, top_p=1, stop=None):
    client = Groq(api_key=api_key)
    # Chat completion with a system message to control output format
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a virtual human assistant in an AR and VR environment. Your responses should be short, concise, and suitable for text-to-speech conversion. Avoid numbers in digit form."},
            {"role": "user", "content": user_message}
        ],
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=stop,
        stream=False,
    )
    return chat_completion.choices[0].message.content
# Convert text to speech using the VITS TTS model
async def text_to_speech(text, filename="output.wav"):
    if not text or text.strip() == "":
        raise ValueError("Input text is empty or invalid")

    # Tokenize the input text for TTS
    inputs = tokenizer(text, return_tensors="pt")
    inputs['input_ids'] = inputs['input_ids'].to(torch.long)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    print(f"Input IDs shape: {inputs['input_ids'].shape}")

    # Generate waveform from text
    with torch.no_grad():
        output = tts_model(**inputs).waveform

    # Save the generated waveform to a temporary WAV file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_filename = temp_file.name
    torchaudio.save(temp_filename, output.cpu(), sample_rate=tts_model.config.sampling_rate, format="wav")

    # Read the generated audio into a BytesIO buffer
    with open(temp_filename, "rb") as f:
        audio_buffer = BytesIO(f.read())
    os.remove(temp_filename)  # Delete the temporary file after reading

    audio_buffer.seek(0)  # Rewind the buffer for reading
    return audio_buffer
# Main API endpoint for processing audio
@app.post("/")
async def process_audio(audio_file: UploadFile = File(...)):
    print('LOG: Request on /')

    # Convert uploaded audio to text
    user_message = await audio_to_text(audio_file)

    # Generate a chat response from Groq
    response_text = await get_chat_response(api_key, chat_model, user_message)

    # Ensure response_text is valid
    if not response_text:
        return Response(content="Error: Generated response text is empty or invalid.", media_type="text/plain")

    # Convert the chat response to speech
    audio_output = await text_to_speech(response_text)

    # Return the generated speech as a WAV file in the response
    return Response(content=audio_output.read(), media_type="audio/wav", headers={
        "Content-Disposition": "attachment; filename=response.wav"
    })
# # Start the Uvicorn server for FastAPI
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)
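
# Example client call, a minimal sketch: it assumes the server is running
# locally on port 8000 and that "input.wav" is an existing recording (both
# are assumptions, not part of the app above). The form field name
# "audio_file" matches the parameter of process_audio.
#
# import requests
#
# with open("input.wav", "rb") as f:
#     resp = requests.post(
#         "http://localhost:8000/",
#         files={"audio_file": ("input.wav", f, "audio/wav")},
#     )
# resp.raise_for_status()
# with open("response.wav", "wb") as out:
#     out.write(resp.content)  # Save the synthesized reply audio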