query2osm / app.py
ellenhp's picture
Update to use new and improved bert model
b4df2a0
raw
history blame
1.96 kB
from gradio.components import Component
import torch
from hydra import Hydra
from transformers import AutoTokenizer
import gradio as gr
from hydra import Hydra
import os
from typing import Any, Optional
# HuggingFace Hub id of the query->OSM classification model.
model_name = "ellenhp/query2osm-bert-v1"
# NOTE(review): `padding=True` is passed to `from_pretrained`, not to the
# tokenizer __call__ — confirm it actually affects tokenization here.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
# Load the model and pin it to CPU for inference.
model = Hydra.from_pretrained(model_name).to('cpu')
class DatasetSaver(gr.FlaggingCallback):
    """Flagging callback that forwards flagged samples to an inner
    ``HuggingFaceDatasetSaver``, keeping only the query text and the
    predicted label (the full confidence distribution and the username
    are deliberately dropped before persisting).

    Bug fix: when no token is configured the app constructs this with
    ``inner=None`` but still installs it as the interface's flagging
    callback; the original code then crashed with ``AttributeError``
    inside ``setup``/``flag``. Both methods are now no-ops when there is
    no inner saver.
    """

    # Underlying dataset saver; None disables persistence entirely.
    inner: Optional[gr.HuggingFaceDatasetSaver] = None

    def __init__(self, inner):
        self.inner = inner

    def setup(self, components: list[Component], flagging_dir: str):
        # Nothing to prepare when saving is disabled (no HF_TOKEN).
        if self.inner is not None:
            self.inner.setup(components, flagging_dir)

    def flag(self,
             flag_data: list[Any],
             flag_option: str = "",
             username: str | None = None):
        if self.inner is None:
            return
        # Keep only [query, {"label": top_label}]; drop per-class scores.
        flag_data = [flag_data[0], {"label": flag_data[1]['label']}]
        # Username is intentionally not forwarded (privacy).
        self.inner.flag(flag_data, flag_option, None)
# Persist flagged examples to a crowdsourced HF dataset when a write
# token is available; otherwise flagging is wired up but does nothing.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN is None:
    hf_writer = None
else:
    hf_writer = gr.HuggingFaceDatasetSaver(
        HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)
flag_callback = DatasetSaver(hf_writer)
def predict(input_query):
    """Classify a free-text query with the query2osm model.

    Returns a mapping from class label to score, the shape expected by
    a gradio ``Label`` output.
    """
    with torch.no_grad():
        print(input_query)  # log the raw incoming query
        normalized = input_query.strip().lower()
        encoded = tokenizer(normalized, return_tensors="pt")
        output = model.forward(encoded.input_ids)
        # classifications[0]: (label, score) pairs for the single input.
        pairs = output.classifications[0]
        return {pair[0]: pair[1] for pair in pairs}
# --- Gradio UI -------------------------------------------------------
query_box = gr.Textbox(
    label="Query",
    placeholder="Quick bite to eat near me",
)
result_label = gr.Label(label="Result", num_top_classes=5)

gradio_app = gr.Interface(
    predict,
    inputs=[query_box],
    outputs=[result_label],
    title="Query Classification",
    allow_flagging="manual",
    flagging_options=["correct classification", "incorrect classification"],
    flagging_callback=flag_callback,
)

if __name__ == "__main__":
    gradio_app.launch()