# query2osm / app.py — Hugging Face Space by ellenhp
# Commit bc83430 ("Check in space"), 1.96 kB
# Standard library.
import os
from typing import Any, Optional

# Third party.
import gradio as gr
import torch
from gradio.components import Component
from hydra import Hydra  # fix: was imported twice in the original
from transformers import AutoTokenizer

# Checkpoint on the Hugging Face Hub used for query classification.
model_name = "ellenhp/query2osm-bert-v1"

# NOTE(review): `padding=True` is forwarded to the tokenizer *constructor*;
# padding is normally a per-call argument to `tokenizer(...)` — confirm this
# kwarg actually has the intended effect.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
# This Space runs inference on CPU only.
model = Hydra.from_pretrained(model_name).to('cpu')
class DatasetSaver(gr.FlaggingCallback):
    """Flagging callback that anonymizes flag data before persisting it.

    Wraps an optional ``gr.HuggingFaceDatasetSaver``. Only the query text and
    the top predicted label are forwarded, and the username is always dropped.
    When no inner saver is configured (``HF_TOKEN`` missing), flags are
    silently discarded.
    """

    # Delegate that actually writes to the dataset; None when no HF token
    # is available (see module-level wiring).
    inner: Optional[gr.HuggingFaceDatasetSaver] = None

    def __init__(self, inner):
        self.inner = inner

    def setup(self, components: list[Component], flagging_dir: str):
        # Bug fix: the original called self.inner.setup unconditionally,
        # raising AttributeError when constructed with inner=None
        # (the HF_TOKEN-absent path in this file).
        if self.inner is not None:
            self.inner.setup(components, flagging_dir)

    def flag(self,
             flag_data: list[Any],
             flag_option: str = "",
             username: str | None = None):
        # Bug fix: no-op instead of crashing when no dataset writer exists.
        if self.inner is None:
            return
        # Keep only the raw query and the top label; pass None as username
        # so no user identity is recorded.
        flag_data = [flag_data[0], {"label": flag_data[1]['label']}]
        self.inner.flag(flag_data, flag_option, None)
# Persist flagged examples to a crowdsourced HF dataset, but only when a
# write token is configured in the environment; otherwise run without one.
HF_TOKEN = os.getenv('HF_TOKEN')
hf_writer = (
    gr.HuggingFaceDatasetSaver(
        HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)
    if HF_TOKEN is not None
    else None
)
flag_callback = DatasetSaver(hf_writer)
def predict(input_query):
    """Classify a free-text query, returning a label -> score mapping."""
    with torch.no_grad():
        # Debug trace of the raw incoming query.
        print(input_query)
        normalized = input_query.strip().lower()
        encoded = tokenizer(normalized, return_tensors="pt")
        outputs = model.forward(encoded.input_ids)
        # Each entry pairs a label with its score; build the result dict.
        return {entry[0]: entry[1] for entry in outputs.classifications[0]}
# ---- Gradio UI wiring ----
textbox = gr.Textbox(
    label="Query",
    placeholder="Where can I get a quick bite to eat?",
)
label = gr.Label(label="Result", num_top_classes=5)

gradio_app = gr.Interface(
    fn=predict,
    inputs=[textbox],
    outputs=[label],
    title="Query Classification",
    allow_flagging="manual",
    flagging_options=["potentially harmful", "wrong classification"],
    flagging_callback=flag_callback,
)

if __name__ == "__main__":
    gradio_app.launch()