from typing import Optional

import gradio as gr
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.logging import logger

from src.task import MrcQaTask, SchemaGuidedInstructBertTask

# Module-level side effect: seed the RNGs and direct logging to ``app.log``
# as soon as this module is imported, before any task/model is built.
set_seed_and_log_path(log_path="app.log")


def _build_task(task_cls, task_dir: str, load_path: Optional[str] = None):
    """Create a task from ``task_dir`` and optionally load a checkpoint.

    Args:
        task_cls: Task class exposing ``from_taskdir`` (e.g. ``MrcQaTask``).
        task_dir: Directory passed to ``task_cls.from_taskdir``.
        load_path: Optional checkpoint path. When falsy (``None`` or ``""``),
            ``from_taskdir`` is asked to load the best model; otherwise the
            checkpoint is loaded via ``task.load`` without its history.

    Returns:
        The constructed task object.
    """
    # Use one truthiness test for both decisions so that an empty string can
    # no longer skip *both* the best-model load and the checkpoint load.
    use_checkpoint = bool(load_path)
    task = task_cls.from_taskdir(
        task_dir, load_best_model=not use_checkpoint, initialize=False
    )
    if use_checkpoint:
        task.load(load_path, load_history=False)
    return task


def _predict_single(task, record: dict):
    """Run ``task.predict`` on a one-element batch and log input + output.

    Args:
        task: A loaded task object exposing ``predict(list_of_dicts)``.
        record: One input example. It is mutated in place: a ``"pred"`` key
            holding the prediction is added so the debug log line shows the
            input and the output together.

    Returns:
        The first (and only) element of ``task.predict([record])``.
    """
    results = task.predict([record])
    ret = results[0]
    record["pred"] = ret
    logger.opt(colors=False).debug(record)
    return ret


class MrcQaPipeline:
    """Inference wrapper around a trained ``MrcQaTask``."""

    def __init__(self, task_dir: str, load_path: Optional[str] = None) -> None:
        """Build the task from ``task_dir``.

        Args:
            task_dir: Task directory passed to ``MrcQaTask.from_taskdir``.
            load_path: Optional checkpoint to load instead of the best model.
        """
        self.task = _build_task(MrcQaTask, task_dir, load_path)

    def predict(self, query, context, background=None):
        """Predict an answer for a single query.

        Args:
            query: The question text.
            context: The candidate/context text, passed through unchanged.
            background: Optional background text (may be ``None``).

        Returns:
            The task's prediction for this single example.
        """
        return _predict_single(
            self.task,
            {
                "query": query,
                "context": context,
                "background": background,
            },
        )


class InstructBertPipeline:
    """Inference wrapper around a trained ``SchemaGuidedInstructBertTask``."""

    def __init__(self, task_dir: str, load_path: Optional[str] = None) -> None:
        """Build the task from ``task_dir``.

        Args:
            task_dir: Task directory passed to
                ``SchemaGuidedInstructBertTask.from_taskdir``.
            load_path: Optional checkpoint to load instead of the best model.
        """
        self.task = _build_task(SchemaGuidedInstructBertTask, task_dir, load_path)

    def predict(self, instruction, schema, text, background=None):
        """Predict for a single instruction/schema/text example.

        Args:
            instruction: The instruction text.
            schema: The schema the prediction should follow.
            text: The input text.
            background: Optional background text (may be ``None``).

        Returns:
            The task's prediction for this single example.
        """
        # NOTE(review): the previous body referenced undefined names
        # (``query``/``context``) and always raised NameError. The record keys
        # below mirror this method's parameter names -- TODO confirm they match
        # what SchemaGuidedInstructBertTask.predict expects.
        return _predict_single(
            self.task,
            {
                "instruction": instruction,
                "schema": schema,
                "text": text,
                "background": background,
            },
        )


def mrc_qa(task_dir: str = "outputs/RobertaBase_data20230314v2") -> None:
    """Launch a Gradio demo for the MRC QA pipeline.

    Args:
        task_dir: Task directory to load; defaults to the previously
            hard-coded experiment output directory.
    """
    # Was ``Pipeline(...)``, an undefined name that raised NameError.
    pipe = MrcQaPipeline(task_dir)

    with gr.Blocks() as demo:
        gr.Markdown("# 🪞 Mirror Mirror")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    query = gr.Textbox(
                        label="Query", placeholder="Mirror Mirror, tell me ..."
                    )
                with gr.Row():
                    context = gr.TextArea(
                        label="Candidates",
                        placeholder="Separated by comma (,) without spaces.",
                    )
                with gr.Row():
                    background = gr.TextArea(
                        label="Background",
                        placeholder="Background explanation, could be empty",
                    )
            with gr.Column():
                with gr.Row():
                    trigger_button = gr.Button("Tell me the truth", variant="primary")
                with gr.Row():
                    output = gr.TextArea(label="Output")

        # The three inputs map positionally onto MrcQaPipeline.predict's
        # (query, context, background) parameters.
        trigger_button.click(
            pipe.predict, inputs=[query, context, background], outputs=output
        )

    demo.launch(show_error=True, share=False)


def instruct_bert_pipeline():
    # NOTE(review): appears to be an unfinished stub -- ``from_taskdir`` is
    # called here without the ``task_dir`` argument that the other call sites
    # in this file pass. Presumably meant to mirror ``mrc_qa`` using
    # ``InstructBertPipeline``; TODO confirm and complete.
    task = SchemaGuidedInstructBertTask.from_taskdir()