|
import gradio as gr
|
|
import io
|
|
import os
|
|
import yaml
|
|
import pyarrow
|
|
import tokenizers
|
|
from retro_reader import RetroReader
|
|
|
|
# Enable parallelism in the HuggingFace `tokenizers` library explicitly.
# Setting this before any tokenizer use also silences the fork-related
# warning the library otherwise emits. NOTE(review): "true" allows parallel
# tokenization; confirm this is intended rather than "false".
os.environ["TOKENIZERS_PARALLELISM"] = "true"
def from_library():
    """Return the `retro_reader` constants module and the RetroReader class.

    The constants import is deferred to call time; RetroReader itself is
    already imported at module level.
    """
    from retro_reader import constants
    return constants, RetroReader
# Resolve the constants module and the reader class once at import time.
C, RetroReader = from_library()
def load_model(config_path):
    """Build and return a RetroReader instance from a YAML config file.

    Args:
        config_path: Path to the inference configuration file.
    """
    reader = RetroReader.load(config_file=config_path)
    return reader
# Pre-load both model variants at startup so each request only pays
# inference cost, not model-loading latency.
model_base = load_model("configs/inference_en_electra_base.yaml")
model_large = load_model("configs/inference_en_electra_large.yaml")
def retro_reader_demo(query, context, model_choice):
    """Answer `query` against `context` with the selected RetroReader model.

    Args:
        query: The question text.
        context: The passage to read.
        model_choice: "Base" selects the base model; anything else
            (including "Large" or an unset radio) selects the large model.

    Returns:
        The predicted answer string, or "No answer found" when the model
        produces a falsy (empty) prediction.
    """
    selected = {"Base": model_base}.get(model_choice, model_large)
    outputs = selected(query=query, context=context, return_submodule_outputs=True)
    # NOTE(review): "id-01" appears to be the fixed example id used for
    # single-query inference — confirm against RetroReader's output schema.
    prediction = outputs[0]["id-01"]
    return prediction or "No answer found"
# Wire the demo function into a Gradio UI: two free-text inputs plus a
# model-size selector, with a single text output for the predicted answer.
iface = gr.Interface(
    fn=retro_reader_demo,
    inputs=[
        gr.Textbox(label="Query", placeholder="Type your query here..."),
        # Larger box for the passage, since contexts are typically multi-line.
        gr.Textbox(label="Context", placeholder="Provide the context here...", lines=10),
        # No default is set: an unselected radio passes None to the handler,
        # which falls back to the large model.
        gr.Radio(choices=["Base", "Large"], label="Model Choice")
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Retrospective Reader Demo",
    description="This interface uses the RetroReader model to perform reading comprehension tasks."
)
if __name__ == "__main__":
    # share=True requests a temporary public Gradio link in addition to the
    # local server.
    iface.launch(share=True)