File size: 2,108 Bytes
7b86ace
 
 
 
e32f802
 
7b86ace
 
e32f802
7b86ace
e32f802
 
 
7b86ace
 
50ed058
e32f802
7b86ace
4cdf93f
e32f802
7b86ace
3326af6
e32f802
7b86ace
e32f802
 
 
 
7b86ace
e32f802
 
 
 
 
 
d5175dd
e32f802
 
 
 
d5175dd
e32f802
 
0fcce2a
e32f802
d5175dd
e32f802
d5175dd
e32f802
 
 
 
 
 
 
 
 
7b86ace
 
588b6a9
5d88c86
 
e32f802
7b86ace
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#%%
import json
import logging
import re

import gradio as gr
import nltk
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sklearn.feature_extraction.text import CountVectorizer

nltk.download('stopwords')
from nltk.corpus import stopwords
#%%

# Sentence-embedding model used to encode n-grams of the user's input.
model = SentenceTransformer('sentence-transformers/multi-qa-distilbert-cos-v1')

# NLTK's Russian stop-word list extended with one extra token.
russian_stopwords = stopwords.words('russian') + ['ВАШ']

# Candidate outputs, indexed by column of the similarity matrix.
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default (JSON is UTF-8 by specification).
with open("top_150_symps_by_spec.json", 'r', encoding="utf-8") as f:
    symps = json.load(f)

# Reference embeddings the input n-grams are compared against.
# NOTE(review): presumably row i of `embs` is the embedding of `symps[i]`
# -- confirm against the script that produced these two files.
with open("embeddings.npy", 'rb') as f:
    embs = np.load(f)

_DIGIT_RUN_RE = re.compile(r'\d+')
_NON_WORD_NON_SPACE_RE = re.compile(r'[^\w\s]')


def remove_numbers(text):
    """Strip digits and punctuation from *text*.

    Runs of digits are deleted first, then every character that is neither
    a word character nor whitespace. Leading and trailing whitespace is
    trimmed from the result; inner whitespace is left as is.
    """
    digits_removed = _DIGIT_RUN_RE.sub('', text)
    cleaned = _NON_WORD_NON_SPACE_RE.sub('', digits_removed)
    return cleaned.strip()



# Splits input text into word uni-, bi- and tri-grams after stripping digits
# and punctuation, dropping the Russian stop words.
# NOTE(review): per the scikit-learn docs a custom `preprocessor` replaces the
# built-in lowercasing step, so stop-word matching here is case-sensitive --
# confirm that is intended.
vectorizer = CountVectorizer(
    preprocessor=remove_numbers,
    stop_words=russian_stopwords,
    ngram_range=(1, 3),
)
    
def get_symptomps_v2(text, treshold = 0.7):
    """Return the entries of ``symps`` that best match n-grams of *text*.

    The module-level ``vectorizer`` is refit on *text* on every call, so its
    vocabulary always reflects only the latest input. Each extracted n-gram
    is embedded with ``model`` and compared by cosine similarity with every
    row of ``embs``. For each n-gram the most similar row is kept when its
    similarity is strictly greater than *treshold* (and positive).

    Args:
        text: A single string, or a collection of strings accepted by the
            vectorizer. A bare string is wrapped in a one-element list.
        treshold: Similarity cut-off; only similarities strictly above it
            count as a match.

    Returns:
        A sorted list of the unique matched ``symps`` entries, or
        ``['Симптомы не определены']`` when nothing matched or when any
        step failed (for example, the input produced no n-grams).
    """
    not_found = ['Симптомы не определены']
    try:
        if isinstance(text, str):
            text = [text]

        # Only the fitted vocabulary is needed, not the count matrix.
        vectorizer.fit(text)
        ngrams = vectorizer.get_feature_names_out()

        text_emb = model.encode(ngrams, batch_size=64)
        similarities = cos_sim(text_emb, embs).numpy()

        # Push everything at or below the threshold to -1 so it can never be
        # chosen as a positive best match.
        masked = np.where(similarities > treshold, similarities, -1)
        best_idx = masked.argmax(axis=1)
        best_val = masked.max(axis=1)

        outputs = [symps[idx] for idx in best_idx[best_val > 0]]
        if not outputs:
            return not_found
        return np.unique(outputs).tolist()
    except Exception:
        # Best-effort endpoint: keep returning the fallback answer, but log
        # the failure instead of hiding it. Unlike a bare ``except``, this
        # lets KeyboardInterrupt/SystemExit propagate.
        logging.getLogger(__name__).exception(
            "Symptom extraction failed; returning fallback answer"
        )
        return not_found
#%%
# Slider that feeds the similarity threshold argument of get_symptomps_v2.
relevance_slider = gr.Slider(
    minimum=0,
    maximum=1,
    step=0.05,
    label="Порог релевантности",
    value=0.8,
)

# Web UI: a free-text box plus the threshold slider in, a JSON list out.
gradio_app = gr.Interface(
    get_symptomps_v2,
    inputs=['text', relevance_slider],
    outputs=[gr.JSON(label='Симптомы: ')],
    description="Введите услугу:",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    gradio_app.launch()
# %%