"""Gradio app serving a query-classification model, with optional crowdsourced
flagging to a HuggingFace dataset (enabled only when HF_TOKEN is set)."""
import os
from typing import Any, Optional

import gradio as gr
import torch
from gradio.components import Component
from hydra import Hydra
from transformers import AutoTokenizer

model_name = "ellenhp/query2osm-bert-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
model = Hydra.from_pretrained(model_name).to('cpu')


class DatasetSaver(gr.FlaggingCallback):
    """Wraps an optional HuggingFaceDatasetSaver, stripping per-class
    confidences (keeps only the top label) and dropping the username
    before persisting flagged examples."""

    # None when HF_TOKEN is absent; flagging then becomes a no-op.
    inner: Optional[gr.HuggingFaceDatasetSaver] = None

    def __init__(self, inner: Optional[gr.HuggingFaceDatasetSaver]):
        self.inner = inner

    def setup(self, components: list[Component], flagging_dir: str):
        # Guard: without a token there is no dataset writer to initialize.
        if self.inner is not None:
            self.inner.setup(components, flagging_dir)

    def flag(self, flag_data: list[Any], flag_option: str = "",
             username: Optional[str] = None):
        if self.inner is None:
            return
        # Keep only [query, top label]; pass username=None for privacy.
        flag_data = [flag_data[0], {"label": flag_data[1]['label']}]
        self.inner.flag(flag_data, flag_option, None)


HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN is not None:
    hf_writer = gr.HuggingFaceDatasetSaver(
        HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)
else:
    hf_writer = None

flag_callback = DatasetSaver(hf_writer)


def predict(input_query: str) -> dict:
    """Classify a free-text query; returns {label: score} for Gradio's Label."""
    with torch.no_grad():
        print(input_query)
        # Model was trained on lowercased text, so normalize the input.
        input_text = input_query.strip().lower()
        inputs = tokenizer(input_text, return_tensors="pt")
        outputs = model.forward(inputs.input_ids)
        return {classification[0]: classification[1]
                for classification in outputs.classifications[0]}


textbox = gr.Textbox(label="Query",
                     placeholder="Where can I get a quick bite to eat?")
label = gr.Label(label="Result", num_top_classes=5)

gradio_app = gr.Interface(
    predict,
    inputs=[textbox],
    outputs=[label],
    title="Query Classification",
    allow_flagging="manual",
    flagging_options=["potentially harmful", "wrong classification"],
    flagging_callback=flag_callback,
)

if __name__ == "__main__":
    gradio_app.launch()