import os
from groq import Groq, GroqError
import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
# Initialize the Groq client, reading the API key from the environment
# rather than hardcoding a secret in the source
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
# Device setup for Parler-TTS
device = "cuda:0" if torch.cuda.is_available() else "cpu"
parler_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
parler_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
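# Note: from_pretrained downloads and caches the weights from the Hugging Face
# Hub on first use, so the first launch can take a while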
# Function to transcribe audio using Whisper through Groq, with error handling
def transcribe_audio(audio):
    try:
        # Gradio delivers (sample_rate, data) for type="numpy"; write it to a
        # WAV file so it can be uploaded in a format Groq accepts
        sample_rate, data = audio
        sf.write("input_audio.wav", data, sample_rate)
        with open("input_audio.wav", "rb") as audio_file:
            transcription_response = client.audio.transcriptions.create(
                file=("input_audio.wav", audio_file.read()),
                model="whisper-large-v3",
            )
        return transcription_response.text
    except GroqError as e:
        print(f"Groq transcription error: {e}")
        return "Error: Failed to transcribe audio."
# Function to generate a response using LLaMA through Groq, with error handling
def generate_response(text):
    try:
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": text}],
            model="llama3-70b-8192",  # Swap in whichever Groq chat model you're using
        )
        # The SDK returns objects, not dicts, so read the content as an attribute
        return chat_completion.choices[0].message.content
    except GroqError as e:
        print(f"Groq response generation error: {e}")
        return "Error: Failed to generate a response."
# Function to convert text to speech using Parler-TTS, run locally
def text_to_speech(text):
    try:
        # The description prompt controls the synthesized voice's characteristics
        description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
        prompt_input_ids = parler_tokenizer(text, return_tensors="pt").input_ids.to(device)
        generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
        audio_arr = generation.cpu().numpy().squeeze()
        sf.write("parler_tts_out.wav", audio_arr, parler_model.config.sampling_rate)
        return "parler_tts_out.wav"
    except Exception as e:
        print(f"Parler-TTS error: {e}")
        return "Error: Failed to convert text to speech."
# Pipeline combining all the components, with error handling at each step
def chatbot_pipeline(audio):
    # Step 1: Convert speech to text using Whisper through Groq
    transcribed_text = transcribe_audio(audio)
    # If transcription failed, surface the error message
    if transcribed_text.startswith("Error"):
        return transcribed_text, None
    # Step 2: Generate a response using LLaMA through Groq
    response_text = generate_response(transcribed_text)
    # If response generation failed, surface the error message
    if response_text.startswith("Error"):
        return response_text, None
    # Step 3: Convert the response text to speech using Parler-TTS
    response_audio_path = text_to_speech(response_text)
    # If TTS conversion failed, return the text response without audio
    if response_audio_path.startswith("Error"):
        return response_text, None
    # Return both text and audio for output
    return response_text, response_audio_path
# Gradio interface setup
ui = gr.Interface(
    fn=chatbot_pipeline,
    inputs=gr.Audio(type="numpy"),  # Gradio passes audio as a (sample_rate, data) tuple
    outputs=[gr.Textbox(label="Chatbot Response"), gr.Audio(label="Chatbot Voice Response")],
    live=True,
)

ui.launch()
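# A minimal sketch of exercising the pipeline without the UI, assuming a local
# recording named "sample.wav" exists (the filename is only an example):
#
#   data, sample_rate = sf.read("sample.wav", dtype="int16")
#   text, audio_path = chatbot_pipeline((sample_rate, data))
#   print(text, audio_path)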