Spaces:
Runtime error
Runtime error
File size: 2,108 Bytes
7b86ace e32f802 7b86ace e32f802 7b86ace e32f802 7b86ace 50ed058 e32f802 7b86ace 4cdf93f e32f802 7b86ace 3326af6 e32f802 7b86ace e32f802 7b86ace e32f802 d5175dd e32f802 d5175dd e32f802 0fcce2a e32f802 d5175dd e32f802 d5175dd e32f802 7b86ace 588b6a9 5d88c86 e32f802 7b86ace |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
#%%
import pandas as pd
import numpy as np
import torch
import json
import re
from sentence_transformers.util import cos_sim
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
import gradio as gr
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
#%%
# Sentence-embedding model used to encode candidate n-grams at query time.
model = SentenceTransformer('sentence-transformers/multi-qa-distilbert-cos-v1')

# NLTK's Russian stop-word list plus one project-specific extra token.
russian_stopwords = stopwords.words('russian') + ['ВАШ']

# Symptom labels. get_symptomps_v2 indexes this collection with the column
# index of the best-matching entry in `embs`, so the two are presumably
# aligned position by position -- TODO confirm against the script that
# produced both files.
# The encoding is pinned to UTF-8 (the JSON interchange encoding) so loading
# does not depend on the platform's default locale encoding.
with open("top_150_symps_by_spec.json", 'r', encoding='utf-8') as f:
    symps = json.load(f)

# Precomputed embedding matrix that query n-grams are compared against.
with open("embeddings.npy", 'rb') as f:
    embs = np.load(f)
def remove_numbers(text):
    """Return *text* with digits and punctuation removed.

    Two substitutions are applied in order: runs of digits are deleted,
    then every character that is neither a word character nor whitespace.
    Leading and trailing whitespace is trimmed afterwards; interior
    whitespace is left exactly as the substitutions produced it.
    """
    cleaned = text
    for pattern in (r'\d+', r'[^\w\s]'):
        cleaned = re.sub(pattern, '', cleaned)
    return cleaned.strip()
# N-gram extractor for incoming text: produces 1- to 3-word n-grams after
# cleaning each document with remove_numbers and dropping Russian stop words.
# NOTE(review): per the scikit-learn docs a custom `preprocessor` replaces the
# built-in lowercase/strip_accents stage, so tokens presumably keep their
# original case -- confirm the stop-word list matches the expected casing.
vectorizer = CountVectorizer(
    preprocessor=remove_numbers,
    stop_words=russian_stopwords,
    ngram_range=(1, 3),
)
def get_symptomps_v2(text, treshold = 0.7):
    """Map free text to the known symptoms it most closely resembles.

    The text is split into n-grams with the module-level ``vectorizer``
    settings, each n-gram is embedded with ``model`` and compared by cosine
    similarity against the precomputed ``embs`` matrix. For every n-gram the
    single best-scoring entry is kept if its score exceeds ``treshold``.

    Args:
        text: A string, or an iterable of strings, to analyse.
        treshold: Minimum cosine similarity (exclusive) for a match to be
            accepted. Scores that are not strictly positive are never
            accepted, whatever the threshold.

    Returns:
        A sorted list of unique matched entries from ``symps``, or a
        one-element list holding a "symptoms not determined" message when
        nothing matches or processing fails.
    """
    import logging

    not_found = ['Симптомы не определены']
    try:
        if isinstance(text, str):
            text = [text]

        # Extract n-grams with a stateless analyzer instead of refitting the
        # shared vectorizer: fit_transform would overwrite its vocabulary on
        # every call. Sorting reproduces the order of the fitted vocabulary.
        analyzer = vectorizer.build_analyzer()
        ngrams = sorted({ngram for doc in text for ngram in analyzer(doc)})
        if not ngrams:
            return not_found

        text_emb = model.encode(ngrams, batch_size=64)
        sims = cos_sim(text_emb, embs).numpy()

        # Best candidate per n-gram, accepted only above the threshold (and
        # only if strictly positive).
        best_idx = sims.argmax(axis=1)
        best_score = sims[np.arange(len(best_idx)), best_idx]
        accepted = best_score > max(treshold, 0)

        outputs = [symps[idx] for idx in best_idx[accepted]]
        if not outputs:
            return not_found
        return np.unique(outputs).tolist()
    except Exception:
        # Deliberate best-effort: the UI always gets an answer, but the
        # failure is logged rather than silently discarded.
        logging.getLogger(__name__).exception("Symptom extraction failed")
        return not_found
#%%
# Web UI: a text box and a relevance-threshold slider feed get_symptomps_v2,
# and the resulting list is rendered as JSON.
# The slider's initial value (0.8) is what the UI passes as the threshold; the
# function's own default only applies when it is called without one.
gradio_app = gr.Interface(
    fn=get_symptomps_v2,
    inputs=[
        'text',
        gr.Slider(
            label="Порог релевантности",
            value=0.8,
            minimum=0,
            maximum=1,
            step=0.05,
        ),
    ],
    outputs=[gr.JSON(label='Симптомы: ')],
    description="Введите услугу:",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    gradio_app.launch()
# %%
|