import os
from typing import Any, Optional

import gradio as gr
import torch
from gradio.components import Component
from hydra import Hydra
from transformers import AutoTokenizer
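
# Load the fine-tuned query-classification model and its tokenizer from the
# Hugging Face Hub; inference runs on CPU.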
model_name = "ellenhp/query2osm-bert-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
model = Hydra.from_pretrained(model_name).to('cpu')

class DatasetSaver(gr.FlaggingCallback):
    """Wraps an optional HuggingFaceDatasetSaver, storing only the query text
    and top predicted label (per-class confidences and username are dropped)."""

    inner: Optional[gr.HuggingFaceDatasetSaver] = None

    def __init__(self, inner: Optional[gr.HuggingFaceDatasetSaver]):
        self.inner = inner

    def setup(self, components: list[Component], flagging_dir: str):
        if self.inner is not None:
            self.inner.setup(components, flagging_dir)

    def flag(self,
             flag_data: list[Any],
             flag_option: str = "",
             username: Optional[str] = None):
        if self.inner is None:
            return  # no dataset writer configured (e.g. HF_TOKEN missing)
        flag_data = [flag_data[0], {"label": flag_data[1]["label"]}]
        self.inner.flag(flag_data, flag_option, None)
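
# Flagged examples are uploaded to a Hugging Face dataset when an HF_TOKEN is
# available; otherwise the DatasetSaver above degrades to a no-op.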
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN is not None:
    hf_writer = gr.HuggingFaceDatasetSaver(
        HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)
else:
    hf_writer = None
flag_callback = DatasetSaver(hf_writer)

def predict(input_query):
    """Classify a free-text query, returning a label-to-score mapping."""
    with torch.no_grad():
        print(input_query)  # log the raw query
        input_text = input_query.strip().lower()
        inputs = tokenizer(input_text, return_tensors="pt")
        outputs = model.forward(inputs.input_ids)
        # outputs.classifications[0] holds (label, score) pairs for the query.
        return {name: score for name, score in outputs.classifications[0]}
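
# Build the UI: a single query textbox in, the top-5 class scores out, with
# manual flagging wired to the dataset saver above.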
textbox = gr.Textbox(label="Query",
                     placeholder="Where can I get a quick bite to eat?")
label = gr.Label(label="Result", num_top_classes=5)

gradio_app = gr.Interface(
    predict,
    inputs=[textbox],
    outputs=[label],
    title="Query Classification",
    allow_flagging="manual",
    flagging_options=["potentially harmful", "wrong classification"],
    flagging_callback=flag_callback,
)

if __name__ == "__main__":
    gradio_app.launch()
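
# A minimal sketch of calling the running app with gradio_client, assuming
# Gradio's default local URL and the standard "/predict" endpoint name:
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("Where can I get a quick bite to eat?",
#                        api_name="/predict"))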