Spaces:
Sleeping
Sleeping
import json | |
import numpy as np | |
import os | |
from transformers import BertTokenizer | |
from rank_bm25 import BM25Okapi | |
import gradio as gr | |
# Hugging Face access token, read from the environment (e.g. a Space secret).
# NOTE(review): os.getenv returns None when HF_TOKEN is unset; presumably the
# dataset saver then cannot push flagged samples — confirm before deploying.
HF_TOKEN = os.getenv('HF_TOKEN')

# Flagging callback backed by the Hugging Face dataset "budu_search_data".
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "budu_search_data")

# Tokenizer applied to incoming queries. The corpus below is passed to BM25 as
# already-tokenized documents, so it is assumed to have been produced with this
# same tokenizer (TODO confirm).
tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

# Synonym database: a JSON object mapping each service name to its document.
# JSON is UTF-8 by specification, so decode explicitly as UTF-8 (the Cyrillic
# content must not depend on the platform's locale encoding), and use `with`
# so the file handle is closed once the data has been loaded.
with open('budu_search_syn_database.json', encoding='utf-8') as f:
    database = json.load(f)

# Parallel lists (dict iteration order is consistent between keys() and
# values()): b25local_names[i] is the service whose document is b25corpus[i].
b25corpus = list(database.values())
b25local_names = list(database.keys())

# BM25 index over the documents; its score positions line up with b25local_names.
bm25 = BM25Okapi(corpus=b25corpus)
def predict_bm25(service, *, top_k=3, min_score=1.0):
    """Return up to ``top_k`` service names that best match a user query.

    The query is lowercased, tokenized with the module-level ``tokenizer`` and
    scored against every document in the module-level ``bm25`` index. Names
    from ``b25local_names`` come back in descending score order. Only names
    whose score is strictly greater than ``min_score`` are kept.

    Args:
        service: Free-text user query. ``None`` or an empty string yields ``[]``.
        top_k: Maximum number of names to return (keyword-only, default 3).
            Values <= 0 yield ``[]``.
        min_score: Exclusive lower bound on the BM25 score (keyword-only,
            default 1.0).

    Returns:
        A list of at most ``top_k`` service names (``str``), best match first.
    """
    # The UI may pass None or "" for an empty textbox; nothing can match then.
    if not service or top_k <= 0:
        return []

    tokenized_query = tokenizer.tokenize(service.lower())
    doc_scores = bm25.get_scores(tokenized_query)

    # Rank by descending score. A stable sort keeps corpus order among equal
    # scores, so results are deterministic (default argsort is not stable).
    ranked = np.argsort(-doc_scores, kind="stable")

    # `ranked` is in descending score order, so the entries above the threshold
    # form a prefix: filtering the first `top_k` is equivalent to filtering
    # everything and then truncating, without materializing all names.
    return [
        b25local_names[i]
        for i in ranked[:top_k]
        if doc_scores[i] > min_score
    ]
# Gradio UI: one free-text query box in, one textbox of recommended services
# out. With allow_flagging='auto', Gradio flags every submission through the
# `hf_writer` callback.
query_box = gr.components.Textbox(label='Запрос пользователя')
# NOTE(review): predict_bm25 returns a list; this Textbox presumably renders it
# via str(), i.e. as a Python list literal — confirm that is the intended UX.
result_box = gr.components.Textbox(label='Рекомендованные услуги')

# Sample queries shown under the input; Gradio expects one row per example.
example_queries = ['кальций', 'узи', 'железо', 'прием']

demo = gr.Interface(
    fn=predict_bm25,
    inputs=query_box,
    outputs=[result_box],
    allow_flagging='auto',
    flagging_callback=hf_writer,
    examples=[[query] for query in example_queries],
)

if __name__ == "__main__":
    demo.launch()