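"""Voice chatbot demo.

Pipeline: speech is transcribed with Whisper via the Groq API, a Groq-hosted
LLaMA model generates a reply, and Parler-TTS renders the reply as audio.
Everything is exposed through a Gradio interface.
"""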
import os
from groq import Groq, GroqError
import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

# Initialize the Groq client; read the API key from the environment
# rather than hardcoding a credential in source
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Device setup for Parler-TTS
device = "cuda:0" if torch.cuda.is_available() else "cpu"
parler_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
parler_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
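
# Note: Parler-TTS generation on CPU is slow; a GPU is recommended for
# anything close to interactive latency.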

# Transcribe audio using Whisper through Groq, with error handling.
# Gradio (type="filepath" below) passes the recording as a path on disk.
def transcribe_audio(audio_path):
    try:
        with open(audio_path, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                file=(os.path.basename(audio_path), audio_file.read()),
                model="whisper-large-v3",
            )
        return transcription.text
    except GroqError as e:
        print(f"Groq transcription error: {e}")
        return "Error: Failed to transcribe audio."

# Generate a response using LLaMA through Groq, with error handling
def generate_response(text):
    try:
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": text}],
            model="llama3-70b-8192",  # Swap in whichever Groq-hosted model you prefer
        )
        # The SDK returns message objects; content is an attribute, not a dict key
        return chat_completion.choices[0].message.content
    except GroqError as e:
        print(f"Groq response generation error: {e}")
        return "Error: Failed to generate a response."

# Convert the response text to speech with Parler-TTS and write it to a WAV file
def text_to_speech(text):
    try:
        description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
        input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
        prompt_input_ids = parler_tokenizer(text, return_tensors="pt").input_ids.to(device)
        generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
        audio_arr = generation.cpu().numpy().squeeze()
        sf.write("parler_tts_out.wav", audio_arr, parler_model.config.sampling_rate)
        return "parler_tts_out.wav"
    except Exception as e:
        print(f"Parler-TTS error: {e}")
        return "Error: Failed to convert text to speech."

# Full pipeline behind the Gradio interface, with error handling at each step
def chatbot_pipeline(audio_path):
    # Step 1: Convert speech to text using Whisper through Groq
    transcribed_text = transcribe_audio(audio_path)

    # If there was an error in transcription, return the error message
    if "Error" in transcribed_text:
        return transcribed_text, None

    # Step 2: Generate a response using LLaMA through Groq
    response_text = generate_response(transcribed_text)

    # If there was an error in response generation, return the error message
    if "Error" in response_text:
        return response_text, None

    # Step 3: Convert response text to speech using Parler-TTS
    response_audio_path = text_to_speech(response_text)

    # If there was an error in TTS conversion, return the error message
    if "Error" in response_audio_path:
        return response_text, None

    # Return both text and audio for output
    return response_text, response_audio_path

# Gradio interface setup
ui = gr.Interface(
    fn=chatbot_pipeline,
    inputs=gr.Audio(type="numpy"),  # Removed 'source' and 'streaming'
    outputs=[gr.Textbox(label="Chatbot Response"), gr.Audio(label="Chatbot Voice Response")],
    live=True
)

if __name__ == "__main__":
    ui.launch()
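
# To run (file name assumed; substitute this script's actual name):
#   GROQ_API_KEY="your-key" python app.py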