|
import time |
|
import uuid |
|
|
|
import gradio as gr |
|
from gtts import gTTS |
|
from transformers import pipeline |
|
|
|
from main import index, run |
|
|
|
# Speech-to-text pipeline (Whisper base), loaded once at import time and used
# to transcribe recorded questions in get_output.
p = pipeline("automatic-speech-recognition", model="openai/whisper-base")

# Typed questions are answered by calling `run` from main.py (see run_model).

# Model names offered by the model selector radio.
models = ["GPT-3.5", "Flan UL2", "Flan T5"]
|
|
|
with gr.Blocks(theme="snehilsanyal/scikit-learn") as demo:
    # Used as an attribute holder: `state.session_id` is assigned once while
    # the UI is being built and read by run_model / set_model.
    # NOTE(review): because this object is created a single time, that
    # session_id is shared by every browser session — confirm intended.
    state = gr.State(value=[])
|
|
|
|
|
def create_session_id():
    """Return a new random session identifier.

    The identifier is a version-4 UUID rendered in its canonical
    36-character hyphenated string form.
    """
    fresh = uuid.uuid4()
    return f"{fresh}"
|
|
|
|
|
def add_text(history, text, model):
    """Answer a typed question and append the exchange to the chat history.

    Args:
        history: Current chatbot history, a list of (user, bot) pairs.
        text: The question typed into the textbox.
        model: Model name selected in the model radio.

    Returns:
        A ``(history, "")`` tuple: the (possibly) updated history and an
        empty string that clears the textbox.
    """
    # Ignore blank or whitespace-only submissions rather than spending a
    # model call on them; still clear the textbox.
    if not text or not text.strip():
        return history, ""

    print("Question asked: " + text)
    response = run_model(text, model)
    # Build a new list instead of mutating the one Gradio passed in.
    history = history + [(text, response)]
    print(history)
    return history, ""
|
|
|
|
|
def run_model(text, model):
    """Send ``text`` to ``run`` (from main.py) for the current session.

    Logs the start timestamp, the answer and the elapsed time.

    Args:
        text: The user's question.
        model: Model name chosen in the selector.

    Returns:
        The answer from ``run``, with a newline inserted before any
        "SOURCES:" marker so the citations begin on their own line.
    """
    began = time.time()
    print(f"start time:{began}")
    answer = run(text, model, state.session_id)
    finished = time.time()

    # Push the citations section onto a separate line for display.
    answer = answer.replace("SOURCES:", "\nSOURCES:") if "SOURCES:" in answer else answer

    print(answer)
    print(f"Time taken: {finished - began}")
    return answer
|
|
|
|
|
def get_output(history, audio, model):
    """Answer a spoken question and append an audio exchange to the history.

    The recording is transcribed with the Whisper pipeline ``p``, the
    transcript is answered through ``run_model``, and the answer (without
    its "SOURCES:" section) is synthesized to speech with gTTS.

    Args:
        history: Current chatbot history, a list of (user, bot) pairs.
        audio: Filepath of the recorded clip, or None when the audio
            component holds no recording.
        model: Model name selected in the model radio.

    Returns:
        A new history with a ``((audio,), (audio_path,))`` entry appended,
        or ``history`` unchanged when there is no recording.
    """
    # Clearing the audio component (e.g. by the `lambda: None` change
    # listener) can invoke this handler with None; nothing to transcribe.
    if audio is None:
        return history

    txt = p(audio)["text"]
    response = run_model(txt, model)

    # Only the answer itself is read aloud, not the citations.
    trimmed_response = response.split("SOURCES:")[0]

    # gTTS always writes MP3 data, so use a matching extension. A per-request
    # name keeps concurrently queued requests from overwriting each other's
    # reply. TODO: these files are never cleaned up.
    audio_path = f"response_{uuid.uuid4().hex}.mp3"
    myobj = gTTS(text=trimmed_response, lang='en', slow=False)
    myobj.save(audio_path)

    history = history + [((audio,), (audio_path,))]
    print(history)
    return history
|
|
|
|
|
def set_model(history, model):
    """Handle a model switch: reset the chat and call ``index`` for ``model``.

    Args:
        history: Current chatbot history (forwarded to get_first_message,
            which discards it).
        model: Newly selected model name.

    Returns:
        A fresh history containing only the greeting message.
    """
    print("Model selected: " + model)
    fresh_history = get_first_message(history)
    index(model, state.session_id)
    return fresh_history
|
|
|
|
|
def get_first_message(history):
    """Return a new chat history holding only the bot's greeting.

    The incoming ``history`` is ignored. The result is a single pair whose
    user slot is None and whose bot slot is the greeting text.
    """
    greeting = ('Learn about the course and get answers with referred sources.'
                '\nWarning! Use the bot wisely. It might give incorrect answers.')
    return [(None, greeting)]
|
|
|
|
|
def bot(history):
    """Pass-through step chained after each event; returns ``history`` as-is."""
    unchanged = history
    return unchanged
|
|
|
|
|
# Assign the session id used by run_model and set_model.
# NOTE(review): this runs once, while the UI is being built, so every visitor
# shares the same session_id — confirm that is intended.
state.session_id = create_session_id()
print("Session ID: " + state.session_id)

# Chat transcript, pre-filled with the greeting from get_first_message.
chatbot = gr.Chatbot(get_first_message([]), elem_id="chatbot", label='3D Printing Revolution').style(height=300,
                                                                                                  container=False)

# Model selector over `models`; hidden (visible=False), with default "GPT-3.5".
radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value", visible=False)
with gr.Row():
    # Typed-question input; submitting it triggers add_text below.
    txt = gr.Textbox(
        label="Ask your question here and press enter",
        placeholder="Enter text and press enter", lines=1
    ).style(container=False)

    # Microphone input (delivered as a filepath); currently hidden.
    audio = gr.Audio(source="microphone", type="filepath", visible=False)

with gr.Row():
    # Clickable sample questions that fill the textbox.
    gr.Examples(
        examples=['What is 3D printing?', 'Who are the instructors of the course?', 'What is the course about?',
                  'Which software can be used to create a design file for 3D printing?',
                  'What are the key takeaways from the course?', 'How to create a 3D printing design file?'],
        inputs=[txt],
        label="Examples")

# Typed question: answer it, append to the chat and clear the textbox, then
# run the pass-through `bot` step.
txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt], postprocess=False).then(
    bot, chatbot, chatbot
)

# Recorded question: transcribe, answer and append the spoken reply.
audio.change(fn=get_output, inputs=[chatbot, audio, radio], outputs=[chatbot], show_progress=True).then(
    bot, chatbot, chatbot
)

# Switching model resets the chat to the greeting and calls index() for it.
radio.change(fn=set_model, inputs=[chatbot, radio], outputs=[chatbot]).then(bot, chatbot, chatbot)

# Reset the recorder to empty after each change.
# NOTE(review): setting `audio` to None may itself fire audio.change again,
# so get_output can receive audio=None — verify it handles that.
audio.change(lambda: None, None, audio)

# Call index() for the default model once at startup; the returned history is
# discarded.
# NOTE(review): the Chatbot component (not a history list) is passed as
# `history`; set_model only forwards it to get_first_message, which ignores it.
set_model(chatbot, radio.value)
|
|
|
if __name__ == "__main__":
    # Configure the request queue once. A preceding bare demo.queue() would
    # just be replaced by this call, so only the configured call is kept.
    # concurrency_count (Gradio 3.x) sets how many queued events are
    # processed at the same time.
    demo.queue(concurrency_count=5)
    # debug=True keeps the main thread blocked while the app is running.
    demo.launch(debug=True)
|
|