"""Gradio demo that recommends services for a free-text query using BM25."""
import json
import os

import gradio as gr
import numpy as np
from rank_bm25 import BM25Okapi
from transformers import BertTokenizer

# Hugging Face Hub token for the flagging callback below; None if the
# environment variable is not set.
HF_TOKEN = os.getenv('HF_TOKEN')
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "budu_search_data")

tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

# Close the file deterministically and decode explicitly as UTF-8 so the
# (Cyrillic) contents do not depend on the platform's default locale encoding.
with open('budu_search_syn_database.json', encoding='utf-8') as f:
    database = json.load(f)

# Parallel lists: b25corpus[i] is the document for the name b25local_names[i].
# NOTE(review): BM25Okapi expects each document to be a sequence of tokens;
# this assumes every JSON value is already tokenized (presumably with the same
# tokenizer used for queries in predict_bm25) — TODO confirm.
b25corpus = list(database.values())
b25local_names = list(database.keys())
bm25 = BM25Okapi(corpus=b25corpus)


def predict_bm25(service: str, *, top_k: int = 3, min_score: float = 1):
    """Return the best-matching service names for a user query.

    The query is lower-cased, tokenized with the module-level BERT tokenizer,
    and scored against every document in the BM25 index.

    Args:
        service: Free-text user query.
        top_k: Maximum number of names to return (default 3).
        min_score: Only documents whose BM25 score is strictly greater than
            this value are returned (default 1).

    Returns:
        A list of at most ``top_k`` names (keys of the JSON database), ordered
        from highest to lowest BM25 score. It is empty when nothing scores
        above ``min_score``.
    """
    tokenized_query = tokenizer.tokenize(service.lower())
    doc_scores = bm25.get_scores(tokenized_query)
    # Document indices ordered from highest to lowest score.
    ranked = np.argsort(doc_scores)[::-1]
    matches = [b25local_names[i] for i in ranked if doc_scores[i] > min_score]
    return matches[:top_k]


# NOTE(review): predict_bm25 returns a list, but the output component is a
# single Textbox; confirm the resulting rendering is what the UI intends.
demo = gr.Interface(
    fn=predict_bm25,
    inputs=gr.components.Textbox(label='Запрос пользователя'),
    outputs=[gr.components.Textbox(label='Рекомендованные услуги')],
    allow_flagging='auto',
    flagging_callback=hf_writer,
    examples=[['кальций'], ['узи'], ['железо'], ['прием']],
)

if __name__ == "__main__":
    demo.launch()