"""Voice assistant API.

Pipeline per request: uploaded audio -> Groq Whisper transcription ->
Groq chat completion -> MMS VITS text-to-speech -> WAV response.
"""

import asyncio
import os
import tempfile
from io import BytesIO

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, Response, UploadFile
from groq import Groq
from transformers import AutoTokenizer, VitsModel

# NOTE(review): installing OS packages at import time needs root and network
# access and runs on every start; presumably libgomp1 is needed by the torch
# runtime in the deployment image -- confirm, and prefer baking it into the
# image instead.
os.system('apt update')
os.system('apt install libgomp1 -y')

# Preload the TTS model and tokenizer once, at startup.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts_model = tts_model.to(device)

# SECURITY: an API key used to be hard-coded here. It was committed to source
# and must be treated as leaked (revoke it). The key now comes from the
# GROQ_API_KEY environment variable.
api_key = os.environ.get("GROQ_API_KEY")
chat_model = "llama3-8b-8192"
client = Groq(api_key=api_key)

app = FastAPI()

# Instructions that keep the LLM output short and speakable by the TTS model.
SYSTEM_PROMPT = (
    "You are a virtual human assistant in an AR and VR environment. "
    "Your responses should be short, concise, and suitable for "
    "text-to-speech conversion. Avoid numbers in digit form."
)


async def audio_to_text(file: UploadFile):
    """Transcribe an uploaded audio file to English text with Groq Whisper.

    Args:
        file: The uploaded audio. Its filename and raw bytes are sent to Groq.

    Returns:
        The transcribed text.
    """
    audio_data = await file.read()
    # The Groq client call is synchronous; run it in a worker thread so the
    # event loop keeps serving other requests during the HTTP round trip.
    transcription = await asyncio.to_thread(
        client.audio.transcriptions.create,
        file=(file.filename, audio_data),
        model="whisper-large-v3",
        # NOTE(review): this placeholder prompt is sent to Whisper verbatim and
        # may bias the transcription; replace with real context or remove.
        prompt="Specify context or spelling",
        response_format="json",
        language="en",
        temperature=0.0,
    )
    return transcription.text


async def get_chat_response(api_key, model, user_message, temperature=0.5,
                            max_tokens=258, top_p=1, stop=None):
    """Ask the Groq chat model for a short, TTS-friendly reply.

    Args:
        api_key: Groq API key to authenticate with.
        model: Groq chat model name.
        user_message: The user's (transcribed) message.
        temperature, max_tokens, top_p, stop: Passed through to the
            chat completion request unchanged.

    Returns:
        The content of the first completion choice.
    """
    # Reuse the shared client when the key matches instead of building a new
    # HTTP client for every request.
    chat_client = client if api_key == client.api_key else Groq(api_key=api_key)

    # The synchronous SDK call runs in a worker thread (see audio_to_text).
    chat_completion = await asyncio.to_thread(
        chat_client.chat.completions.create,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_message},
        ],
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=stop,
        stream=False,
    )
    return chat_completion.choices[0].message.content


def _synthesize_wav(text):
    """Run the VITS model on *text* and return the waveform encoded as WAV bytes."""
    inputs = tokenizer(text, return_tensors="pt")
    if inputs["input_ids"].numel() == 0:
        # The tokenizer can drop characters missing from its vocabulary; an
        # empty sequence would otherwise crash inside the model.
        raise ValueError("Input text produced no tokens for TTS")
    inputs["input_ids"] = inputs["input_ids"].to(torch.long)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    print(f"Input IDs shape: {inputs['input_ids'].shape}")

    # Grad mode is thread-local, so no_grad must be entered in the thread that
    # actually runs the model.
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform

    # Round-trip through a temporary file, which is always removed, even when
    # saving or reading fails.
    fd, temp_filename = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        torchaudio.save(temp_filename, waveform.cpu(),
                        sample_rate=tts_model.config.sampling_rate, format="wav")
        with open(temp_filename, "rb") as f:
            return f.read()
    finally:
        os.remove(temp_filename)


async def text_to_speech(text, filename="output.wav"):
    """Convert *text* to speech and return it as an in-memory WAV file.

    Args:
        text: Text to speak.
        filename: Unused; kept for backward compatibility with existing callers.

    Returns:
        A BytesIO holding the WAV data, positioned at the start.

    Raises:
        ValueError: If *text* is empty/whitespace or yields no TTS tokens.
    """
    if not text or text.strip() == "":
        raise ValueError("Input text is empty or invalid")
    # Model inference is compute-bound; keep it off the event loop.
    wav_bytes = await asyncio.to_thread(_synthesize_wav, text)
    return BytesIO(wav_bytes)


# NOTE(review): the GET route shares a handler that requires a multipart file
# body, so a plain GET / is rejected by request validation; consider a
# separate health-check handler.
@app.get('/')
@app.post("/processaudio")
async def process_audio(audio_file: UploadFile = File(...)):
    """Answer a spoken question with a spoken reply.

    Transcribes *audio_file*, asks the chat model for a reply, synthesizes it,
    and returns it as a WAV attachment. Failures in the pipeline are reported
    as a plain-text body.
    """
    # Raw string: same printed text as before, without the invalid "\L" escape.
    print(r'-\LOG : Request on /')

    user_message = await audio_to_text(audio_file)
    if not user_message or not user_message.strip():
        return Response(content="Error: Transcribed text is empty or invalid.",
                        media_type="text/plain")

    response_text = await get_chat_response(api_key, chat_model, user_message)
    if not response_text or not response_text.strip():
        return Response(content="Error: Generated response text is empty or invalid.",
                        media_type="text/plain")

    try:
        audio_output = await text_to_speech(response_text)
    except ValueError as err:
        return Response(content=f"Error: {err}", media_type="text/plain")

    return Response(
        content=audio_output.getvalue(),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=response.wav"},
    )


# # Start the Uvicorn server for FastAPI
# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=8000)