Mirror / src /app /demo1_deprecated.py
Spico's picture
update
5953ef9
import gradio as gr
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.logging import logger
from src.task import MrcQaTask, SchemaGuidedInstructBertTask
# Module-level setup executed at import time: hands "app.log" to rex as the
# log path. NOTE(review): judging by the helper's name it presumably also
# seeds the RNGs with rex's default seed -- confirm in rex.utils.initialization.
set_seed_and_log_path(log_path="app.log")
class MrcQaPipeline:
    """Single-example inference wrapper around an ``MrcQaTask``."""

    def __init__(self, task_dir: str, load_path: str = None) -> None:
        """Build the task from ``task_dir`` and optionally load a checkpoint.

        Args:
            task_dir: Directory handed to ``MrcQaTask.from_taskdir``.
            load_path: Optional checkpoint path. When it is ``None`` the task
                is created with ``load_best_model=True``; when it is truthy
                the checkpoint is loaded explicitly afterwards (without
                history).
        """
        wants_best_model = load_path is None
        self.task = MrcQaTask.from_taskdir(
            task_dir, load_best_model=wants_best_model, initialize=False
        )
        if not load_path:
            # Nothing explicit to load (``None`` or an empty string).
            return
        self.task.load(load_path, load_history=False)

    def predict(self, query, context, background=None):
        """Run the task on one query/context pair and return its prediction.

        Args:
            query: Value stored under the ``"query"`` key of the sample.
            context: Value stored under the ``"context"`` key of the sample.
            background: Optional value stored under ``"background"``.

        Returns:
            The first element of what ``self.task.predict`` returns for the
            one-sample batch.
        """
        sample = {
            "query": query,
            "context": context,
            "background": background,
        }
        batch = [sample]
        prediction = self.task.predict(batch)[0]

        # Attach the prediction to the logged record so the debug line shows
        # input and output together.
        logged = batch[0]
        logged["pred"] = prediction
        logger.opt(colors=False).debug(logged)
        return prediction
class InstructBertPipeline:
    """Single-example inference wrapper around a ``SchemaGuidedInstructBertTask``."""

    def __init__(self, task_dir: str, load_path: str = None) -> None:
        """Build the task from ``task_dir`` and optionally load a checkpoint.

        Args:
            task_dir: Directory handed to
                ``SchemaGuidedInstructBertTask.from_taskdir``.
            load_path: Optional checkpoint path. When it is ``None`` the task
                is created with ``load_best_model=True``; when it is truthy
                the checkpoint is loaded explicitly afterwards (without
                history).
        """
        self.task = SchemaGuidedInstructBertTask.from_taskdir(
            task_dir, load_best_model=load_path is None, initialize=False
        )
        if load_path:
            self.task.load(load_path, load_history=False)

    def predict(self, instruction, schema, text, background):
        """Run the task on one instruction/schema/text sample.

        The previous body built the sample from ``query`` and ``context``,
        names that do not exist in this method, so every call raised
        ``NameError`` and the real arguments were never forwarded.

        Args:
            instruction: Value stored under ``"instruction"``.
            schema: Value stored under ``"schema"``.
            text: Value stored under ``"text"``.
            background: Value stored under ``"background"``.

        Returns:
            The first element of what ``self.task.predict`` returns for the
            one-sample batch.
        """
        # NOTE(review): the key names mirror the parameter names; confirm they
        # match the input schema SchemaGuidedInstructBertTask.predict expects.
        data = [
            {
                "instruction": instruction,
                "schema": schema,
                "text": text,
                "background": background,
            }
        ]
        results = self.task.predict(data)
        ret = results[0]
        # Log input and prediction together in a single debug record.
        data[0]["pred"] = ret
        logger.opt(colors=False).debug(data[0])
        return ret
def mrc_qa(task_dir: str = "outputs/RobertaBase_data20230314v2"):
    """Launch the Gradio demo for the MRC-QA pipeline.

    The original body instantiated ``Pipeline``, a name that is not defined
    in this module, so the function failed with ``NameError`` before any UI
    was built. ``MrcQaPipeline`` is the class whose ``predict(query, context,
    background)`` signature matches the three inputs wired to the button.

    Args:
        task_dir: Task directory passed to ``MrcQaPipeline``. Defaults to the
            path that used to be hard-coded here.
    """
    pipe = MrcQaPipeline(task_dir)

    with gr.Blocks() as demo:
        gr.Markdown("# 🪞 Mirror Mirror")

        with gr.Row():
            # Left column: the three text inputs.
            with gr.Column():
                with gr.Row():
                    query = gr.Textbox(
                        label="Query", placeholder="Mirror Mirror, tell me ..."
                    )
                with gr.Row():
                    context = gr.TextArea(
                        label="Candidates",
                        placeholder="Separated by comma (,) without spaces.",
                    )
                with gr.Row():
                    background = gr.TextArea(
                        label="Background",
                        placeholder="Background explanation, could be empty",
                    )

            # Right column: trigger button and the prediction output.
            with gr.Column():
                with gr.Row():
                    trigger_button = gr.Button("Tell me the truth", variant="primary")
                with gr.Row():
                    output = gr.TextArea(label="Output")

        # Event listeners must be registered inside the Blocks context.
        trigger_button.click(
            pipe.predict, inputs=[query, context, background], outputs=output
        )

    demo.launch(show_error=True, share=False)
def instruct_bert_pipeline():
task = SchemaGuidedInstructBertTask.from_taskdir()