File size: 4,328 Bytes
9357625 9e35a98 9357625 0090cdd 9e35a98 f424cff 76e81d0 9357625 76e81d0 9357625 76e81d0 ca29824 76e81d0 40025b2 76e81d0 9e35a98 12163f7 76e81d0 9e35a98 0090cdd 76e81d0 0090cdd 9e35a98 40025b2 76e81d0 151932e 76e81d0 40025b2 76e81d0 9e35a98 40025b2 76e81d0 40025b2 9e35a98 0090cdd 9e35a98 76e81d0 9e35a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import tempfile
import time
import uuid

import gradio as gr
from gtts import gTTS
from transformers import pipeline

from main import index, run
# Speech-recognition pipeline (Whisper base) used to transcribe recorded audio.
p = pipeline("automatic-speech-recognition", model="openai/whisper-base")

# Use text to call chat method from main.py
# Model names offered by the model selector.
models = ["GPT-3.5", "Flan UL2", "Flan T5"]

with gr.Blocks(theme='snehilsanyal/scikit-learn') as demo:
    # Holder component; a `session_id` attribute is attached to it later and
    # read by the handlers that call into main.py.
    state = gr.State([])
def create_session_id() -> str:
    """Return a newly generated random UUID (version 4) as a string."""
    return str(uuid.uuid4())
def add_text(history, text, model):
    """Answer a typed question and append the exchange to the chat history.

    Args:
        history: Chat history as a list of ``(user_message, bot_message)``
            pairs.
        text: The question typed by the user.
        model: Model name forwarded to ``run_model``.

    Returns:
        A ``(history, "")`` tuple: the updated history and an empty string.
        The empty string is the new textbox value, i.e. it clears the input.
    """
    # Ignore blank submissions so that no model call is made for them.
    if not text or not text.strip():
        return history, ""

    print("Question asked: " + text)
    response = run_model(text, model)
    # Build a new list instead of mutating the caller's history in place.
    history = history + [(text, response)]
    print(history)
    return history, ""
def run_model(text, model):
    """Send a question to ``run`` from main.py and return the formatted answer.

    Logs the start time, the answer and the elapsed wall-clock time to stdout.

    Args:
        text: The question to ask.
        model: Model name passed through to ``run``.

    Returns:
        The answer string, with a newline inserted before every ``SOURCES:``
        marker so the sources appear on their own line.
    """
    start_time = time.time()
    print("start time:" + str(start_time))
    response = run(text, model, state.session_id)
    end_time = time.time()
    # `str.replace` is a no-op when the marker is absent, so no membership
    # check is needed first.
    response = response.replace("SOURCES:", "\nSOURCES:")
    print(response)
    print("Time taken: " + str(end_time - start_time))
    return response
def get_output(history, audio, model):
    """Answer a spoken question and append a spoken reply to the chat history.

    The recording is transcribed with the module-level pipeline ``p``, the
    transcript is answered via ``run_model``, and the answer (without its
    ``SOURCES:`` section) is synthesized to an audio file with gTTS.

    Args:
        history: Chat history list; it is appended to in place.
        audio: File path of the recording, or ``None`` when there is none.
        model: Model name forwarded to ``run_model``.

    Returns:
        The same ``history`` list, extended with one
        ``((audio,), (reply_audio_path,))`` entry, or with
        ``((audio,), response_text)`` when there is nothing to synthesize.
        Returned unchanged when ``audio`` is ``None``.
    """
    # The recorder is reset to None by another change handler, so this
    # callback can be invoked without a recording; there is nothing to
    # transcribe in that case.
    if audio is None:
        return history

    txt = p(audio)["text"]
    response = run_model(txt, model)

    # Speak only the answer: drop everything from "SOURCES:" to the end.
    trimmed_response = response.split("SOURCES:")[0]
    if not trimmed_response.strip():
        # gTTS rejects empty text, so fall back to showing the reply as text.
        history.append(((audio,), response))
        print(history)
        return history

    # Use a unique file per reply so concurrent requests do not overwrite each
    # other's audio. The file is kept on disk so it can still be read after
    # this handler returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        audio_path = tmp.name
    myobj = gTTS(text=trimmed_response, lang='en', slow=False)
    myobj.save(audio_path)

    history.append(((audio,), (audio_path,)))
    print(history)
    return history
def set_model(history, model):
    """Switch to ``model`` and reset the conversation.

    Args:
        history: Current chat history; it is replaced, not extended.
        model: Name of the selected model.

    Returns:
        A fresh history holding only the welcome message.
    """
    print("Model selected: " + model)
    history = get_first_message(history)
    # NOTE(review): presumably (re)builds the index for this model and
    # session -- confirm against `index` in main.py.
    index(model, state.session_id)
    return history
def get_first_message(history):
    """Return a new chat history that contains only the welcome message.

    Args:
        history: Ignored; the returned list never depends on it.

    Returns:
        A one-element list with a ``(None, welcome_text)`` pair, i.e. a bot
        message with no user message.
    """
    return [
        (
            None,
            'Learn about the course and get answers with referred sources.\nWarning! Use the bot wisely. It might give incorrect answers.',
        )
    ]
def bot(history):
    """Return ``history`` unchanged.

    Used as the follow-up step chained after the event handlers.
    """
    return history
with demo:
    # A single id is generated when the module loads and stored as an
    # attribute on `state`; the handlers read it when calling main.py.
    state.session_id = create_session_id()
    print("Session ID: " + state.session_id)

    # Chat window, pre-filled with the welcome message.
    chatbot = gr.Chatbot(
        get_first_message([]),
        elem_id="chatbot",
        label='3D Printing Revolution',
    ).style(height=300, container=False)

    # Model selector (created hidden).
    radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value", visible=False)

    with gr.Row():
        txt = gr.Textbox(
            label="Ask your question here and press enter",
            placeholder="Enter text and press enter",
            lines=1,
        ).style(container=False)
        # Microphone input (created hidden); yields a file path.
        audio = gr.Audio(source="microphone", type="filepath", visible=False)

    with gr.Row():
        gr.Examples(
            examples=[
                'What is 3D printing?',
                'Who are the instructors of the course?',
                'What is the course about?',
                'Which software can be used to create a design file for 3D printing?',
                'What are the key takeaways from the course?',
                'How to create a 3D printing design file?',
            ],
            inputs=[txt],
            label="Examples",
        )

    # Typed question: answer it, update the chat and clear the textbox.
    txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt], postprocess=False).then(
        bot, chatbot, chatbot
    )
    # Spoken question: transcribe, answer and reply with audio.
    audio.change(fn=get_output, inputs=[chatbot, audio, radio], outputs=[chatbot], show_progress=True).then(
        bot, chatbot, chatbot
    )
    # Changing the model resets the conversation.
    radio.change(fn=set_model, inputs=[chatbot, radio], outputs=[chatbot]).then(bot, chatbot, chatbot)
    # Reset the recorder after every change.
    audio.change(lambda: None, None, audio)

    # Run the model setup once for the default selection. The returned
    # history is discarded; the call is made for its side effects only.
    set_model(chatbot, radio.value)
if __name__ == "__main__":
    # One queue() call is sufficient: a second call replaces the settings of
    # the first, so only the configured one is kept.
    demo.queue(concurrency_count=5)
    demo.launch(debug=True)
|