test_ing / app.py
AyushS9020's picture
Update app.py
d569504 verified
import os
import tempfile
import asyncio
from fastapi import FastAPI, File, UploadFile, Response
import uvicorn
from groq import Groq
from transformers import VitsModel, AutoTokenizer
import torch
import torchaudio
from io import BytesIO
# --- One-time environment / model setup (runs at import) -------------------
# libgomp1 is an OpenMP runtime needed by PyTorch wheels on minimal images.
# NOTE(review): os.system + apt assumes a Debian-based container with root
# privileges; these calls fail silently elsewhere.
os.system('apt update')
os.system('apt install libgomp1 -y')

# Preload the MMS English TTS model and tokenizer once so every request
# reuses the same weights instead of reloading per call.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

# Run TTS inference on GPU when one is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts_model = tts_model.to(device)

# Groq client configuration.
# SECURITY: a real API key was hardcoded here. Prefer the GROQ_API_KEY
# environment variable; the literal is kept only as a fallback for
# backward compatibility and should be revoked/rotated.
api_key = os.environ.get(
    "GROQ_API_KEY",
    "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",
)
chat_model = "llama3-8b-8192"  # default chat model for get_chat_response
client = Groq(api_key=api_key)

# FastAPI application instance.
app = FastAPI()
# Convert audio to text using Groq's API
async def audio_to_text(file: UploadFile) -> str:
    """Transcribe an uploaded audio file to English text via Groq Whisper.

    Args:
        file: Uploaded audio in any format the Whisper endpoint accepts.

    Returns:
        The transcription text returned by the API.
    """
    audio_data = await file.read()
    # The Groq SDK call is synchronous; run it in a worker thread so it
    # does not block the asyncio event loop while the request is in flight.
    transcription = await asyncio.to_thread(
        client.audio.transcriptions.create,
        file=(file.filename, audio_data),
        model="whisper-large-v3",  # The Whisper model for transcription
        prompt="Specify context or spelling",  # Optional: transcription context
        response_format="json",
        language="en",
        temperature=0.0,
    )
    return transcription.text
# Get chat response from Groq API
async def get_chat_response(api_key, model, user_message, temperature=0.5, max_tokens=258, top_p=1, stop=None):
    """Generate a short assistant reply for *user_message* via the Groq chat API.

    Args:
        api_key: Groq API key used to build a client for this call.
        model: Chat model identifier (e.g. "llama3-8b-8192").
        user_message: The transcribed user utterance.
        temperature / max_tokens / top_p / stop: Standard sampling controls.

    Returns:
        The assistant message content as a string.
    """
    # Renamed from `client` to avoid shadowing the module-level Groq client.
    local_client = Groq(api_key=api_key)
    # Sync SDK call — run in a worker thread so the event loop stays free.
    chat_completion = await asyncio.to_thread(
        local_client.chat.completions.create,
        messages=[
            # System message constrains output to short, TTS-friendly text.
            {"role": "system", "content": "You are a virtual human assistant in an AR and VR environment. Your responses should be short, concise, and suitable for text-to-speech conversion. Avoid numbers in digit form."},
            {"role": "user", "content": user_message},
        ],
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=stop,
        stream=False,
    )
    return chat_completion.choices[0].message.content
# Convert text to speech using the Vits TTS model
async def text_to_speech(text, filename="output.wav"):
    """Synthesize *text* to speech with the preloaded VITS model.

    Args:
        text: Non-empty text to synthesize.
        filename: Unused; kept for backward compatibility with callers.

    Returns:
        A BytesIO positioned at 0 containing WAV-encoded audio.

    Raises:
        ValueError: If *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        raise ValueError("Input text is empty or invalid")

    # Tokenize and move tensors to the model's device.
    inputs = tokenizer(text, return_tensors="pt")
    inputs['input_ids'] = inputs['input_ids'].to(torch.long)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Generate the waveform without tracking gradients (inference only).
    with torch.no_grad():
        output = tts_model(**inputs).waveform

    # torchaudio.save accepts a file-like object, so encode straight into an
    # in-memory buffer. This removes the temp file the original created,
    # which leaked on disk if saving or reading raised before os.remove.
    audio_buffer = BytesIO()
    torchaudio.save(
        audio_buffer,
        output.cpu(),
        sample_rate=tts_model.config.sampling_rate,
        format="wav",
    )
    audio_buffer.seek(0)  # rewind so callers can read from the start
    return audio_buffer
# Main API endpoint for processing audio
@app.get('/')
@app.post("/processaudio")
async def process_audio(audio_file: UploadFile = File(...)):
    """Full pipeline endpoint: speech -> text -> LLM reply -> speech.

    Accepts an uploaded audio file, transcribes it, asks the chat model for
    a short reply, and returns that reply synthesized as a WAV attachment.

    NOTE(review): the GET '/' route shares this handler, which requires a
    file upload, so a plain browser GET will always return 422 — consider
    a separate health-check endpoint.
    """
    # '\\L' fixes the original '\L', an invalid escape sequence that emits a
    # SyntaxWarning on modern Python; the runtime string is unchanged.
    print('-\\LOG : Request on /')

    # 1) Transcribe the uploaded audio.
    user_message = await audio_to_text(audio_file)

    # 2) Generate a chat response from Groq.
    response_text = await get_chat_response(api_key, chat_model, user_message)
    if not response_text:
        return Response(content="Error: Generated response text is empty or invalid.", media_type="text/plain")

    # 3) Synthesize the reply and return it as a downloadable WAV file.
    audio_output = await text_to_speech(response_text)
    return Response(
        content=audio_output.read(),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=response.wav"},
    )
# # Start the Uvicorn server for FastAPI
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)