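# Voice-conversation demo: transcribe speech with Whisper, generate a reply with
# SmolLM2, and speak it back with a Swahili TTS voice, all behind a Gradio UI.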
import gradio as gr
import torch
from transformers import pipeline
import time
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)

# Pick a device once so all three pipelines share it (GPU if available)
device = 0 if torch.cuda.is_available() else -1

# Define the models using pipeline (checkpoints are downloaded on first run)
asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small", chunk_length_s=30, device=device)
text_pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-360M", device=device)
tts_pipe = pipeline("text-to-speech", model="mussacharles60/swahili-tts-female-voice", device=device)

# Define conversation rules
MAX_INPUT_SIZE = 100  # max characters of recognized speech to process
PREDEFINED_ATTRIBUTES = ["name", "age", "location"]  # keywords that trigger a follow-up question
CONTEXT_HISTORY = []  # rolling window of recent prompts fed back as context

# Define the function to recognize speech
def recognize_speech(audio):
    retries = 3
    for attempt in range(retries):
        try:
            # return_timestamps=True lets Whisper transcribe clips longer than 30 s
            result = asr_pipe(audio, return_timestamps=True)
            return result["text"].strip()
        except Exception as e:
            logging.error(f"ASR failed (attempt {attempt + 1}/{retries}): {e}")
            time.sleep(1)
    return ""

# Define the function to generate text
def generate_text(prompt):
    global CONTEXT_HISTORY
    CONTEXT_HISTORY.append(prompt)
    if len(CONTEXT_HISTORY) > 5:
        CONTEXT_HISTORY.pop(0)  # keep only the five most recent turns
    context = " ".join(CONTEXT_HISTORY)
    outputs = text_pipe(
        context,
        max_new_tokens=256,      # bound the reply length without counting the prompt
        do_sample=True,          # sampling is required for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.9,
        num_return_sequences=1,
        return_full_text=False,  # return only the new text, not the echoed context
    )
    generated_text = outputs[0]["generated_text"].strip()
    return generated_text

# Define the function to synthesize speech
def synthesize_speech(text):
    # The TTS pipeline returns {"audio": ndarray, "sampling_rate": int};
    # gr.Audio accepts a (sample_rate, waveform) tuple, so repackage accordingly.
    result = tts_pipe(text)
    waveform = result["audio"].squeeze()  # drop a leading channel axis if present
    return result["sampling_rate"], waveform

# Define the function to handle conversation
def handle_conversation(audio):
    recognized_text = recognize_speech(audio)[:MAX_INPUT_SIZE]  # enforce the input cap
    if not recognized_text:
        return None, "Sorry, the audio could not be transcribed."
    if any(attr in recognized_text.lower() for attr in PREDEFINED_ATTRIBUTES):
        generated_text = generate_text(f"Please provide your {recognized_text}")
    else:
        generated_text = generate_text(recognized_text)
    synthesized_audio = synthesize_speech(generated_text)
    return synthesized_audio, generated_text
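
# Quick local sanity check, bypassing the UI ("sample.wav" is a hypothetical path):
# audio_out, reply = handle_conversation("sample.wav")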

# Define the Gradio app; UI components must be created inside the Blocks context
with gr.Blocks() as demo:
    # Input and output components (type="filepath" hands the ASR pipeline a path it accepts)
    input_audio = gr.Audio(label="Input Audio", type="filepath")
    output_audio = gr.Audio(label="Output Audio")
    output_text = gr.Textbox(label="Output Text")
    conversation_button = gr.Button("Start Conversation")

    # Wire the button to the full ASR -> LLM -> TTS round trip
    conversation_button.click(handle_conversation, inputs=input_audio, outputs=[output_audio, output_text])

# Launch the app
demo.launch()
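# demo.launch(share=True) would additionally expose a temporary public URL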