|
from gradio.components import Component |
|
import torch |
|
from hydra import Hydra |
|
from transformers import AutoTokenizer |
|
import gradio as gr |
|
from hydra import Hydra |
|
import os |
|
from typing import Any, Optional |
|
|
|
# Hub checkpoint that provides both the tokenizer and the Hydra classifier.
model_name = "ellenhp/query2osm-bert-v1"

# NOTE(review): `padding=True` is supplied at load time rather than per call;
# confirm it actually influences the tokenizer calls made in predict().
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)

# The classifier is loaded once at import time and kept on the CPU.
model = Hydra.from_pretrained(model_name).to(device="cpu")
|
|
|
|
|
class DatasetSaver(gr.FlaggingCallback):
    """Flagging callback that forwards a trimmed record to a wrapped saver.

    Only the query text and the ``"label"`` entry of the Label output are
    forwarded; every other key of the output value is dropped. When no inner
    saver is configured (``inner`` is ``None``), setup and flagging are
    no-ops, so the app still starts without dataset credentials.
    """

    # Saver that actually persists flagged samples; None disables flagging.
    inner: Optional[gr.HuggingFaceDatasetSaver] = None

    def __init__(self, inner: Optional[gr.HuggingFaceDatasetSaver]):
        self.inner = inner

    def setup(self, components: list[Component], flagging_dir: str):
        """Prepare the wrapped saver, if any, for the given components."""
        if self.inner is None:
            # No saver configured: nothing to set up.
            return
        self.inner.setup(components, flagging_dir)

    def flag(self,
             flag_data: list[Any],
             flag_option: str = "",
             username: str | None = None) -> int:
        """Forward the query and its top label to the wrapped saver.

        Args:
            flag_data: Component values in interface order. Index 0 is
                assumed to be the query text and index 1 the Label output
                (a dict with a ``"label"`` key) — TODO confirm against the
                interface definition.
            flag_option: The flagging option the user selected.
            username: Ignored; ``None`` is always forwarded to the inner
                saver (presumably to keep flags anonymous).

        Returns:
            Whatever the wrapped saver's ``flag`` returns (per the
            ``FlaggingCallback`` contract, the number of flagged samples),
            or 0 when no saver is configured.
        """
        if self.inner is None:
            return 0
        output = flag_data[1]
        # Flagging before any prediction has run leaves the output empty;
        # record a missing label instead of raising.
        top_label = output['label'] if output else None
        trimmed = [flag_data[0], {"label": top_label}]
        return self.inner.flag(trimmed, flag_option, None)
|
|
|
|
|
# Flags are pushed to a Hugging Face dataset only when a token is available.
# An empty HF_TOKEN is treated the same as an unset one, because an empty
# token cannot authenticate.
HF_TOKEN = os.getenv('HF_TOKEN')

if HF_TOKEN:
    # NOTE(review): the arguments are positional. Against recent Gradio
    # signatures they map to (hf_token, dataset_name, private, info_filename,
    # separate_dirs); confirm that "data.csv" is really meant as the
    # info filename.
    hf_writer = gr.HuggingFaceDatasetSaver(
        HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)
else:
    hf_writer = None

# The wrapper receives None when no credentials are configured.
flag_callback = DatasetSaver(hf_writer)
|
|
|
|
|
def predict(input_query):
    """Classify a free-text query with the Hydra model.

    Args:
        input_query: Raw text from the query textbox. It is stripped and
            lower-cased before tokenization.

    Returns:
        A dict built from the pairs in ``outputs.classifications[0]``, which
        is the single item in the one-query batch. Each pair's first element
        is the key and its second element is the value. The dict is displayed
        by the ``gr.Label`` output.
    """
    # NOTE(review): this echoes every raw user query to stdout (e.g. the
    # hosting logs); confirm that is acceptable for crowdsourced input.
    print(input_query)
    normalized = input_query.strip().lower()
    with torch.no_grad():
        encoded = tokenizer(normalized, return_tensors="pt")
        # Call the module itself rather than `.forward` so that any
        # registered forward hooks still run.
        outputs = model(encoded.input_ids)
    # assumes each entry is a (class_name, score) pair — TODO confirm
    return {pair[0]: pair[1] for pair in outputs.classifications[0]}
|
|
|
|
|
# Single free-text query in; the five highest-scoring classes out.
textbox = gr.Textbox(
    label="Query",
    placeholder="Where can I get a quick bite to eat?",
)
label = gr.Label(num_top_classes=5, label="Result")

# Manual flagging lets users report results; reports are handed to the
# module-level flag_callback.
gradio_app = gr.Interface(
    fn=predict,
    inputs=[textbox],
    outputs=[label],
    title="Query Classification",
    allow_flagging="manual",
    flagging_options=["potentially harmful", "wrong classification"],
    flagging_callback=flag_callback,
)

if __name__ == "__main__":
    gradio_app.launch()
|
|