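"""HouseMD bot: a Gradio Space that answers in the style of Dr. House.

Pipeline, as implemented below: a bi-encoder embeds the user message
(L2-normalized [CLS] token), FAISS retrieves the 50 nearest candidate
answers from a precomputed embedding base, and a cross-encoder re-scores
the (query, answer) pairs to pick the reply.
"""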
import gradio as gr
import torch
import faiss
import numpy as np
import datasets
from transformers import AutoTokenizer, AutoModel

title = "HouseMD bot"

description = "Gradio Demo for telegram bot.\
To use it, simply add your text message.\
I've used the API on this Space to deploy the model on a Telegram bot."

device = "cuda" if torch.cuda.is_available() else "cpu"


def embed_bert_cls(text, model, tokenizer):
    # Return the L2-normalized [CLS] embedding of `text` as a 1-D numpy array.
    t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**{k: v.to(model.device) for k, v in t.items()})
    embeds = model_output.last_hidden_state[:, 0, :]
    embeds = torch.nn.functional.normalize(embeds)
    return embeds[0].cpu().numpy()


def get_ranked_docs(query, vec_query_base, data,
                    bi_model, bi_tok, cross_model, cross_tok):

    # Stage 1: retrieve the 50 nearest candidate answers with FAISS.
    vec_dim = vec_query_base.shape[1]
    index = faiss.IndexFlatL2(vec_dim)
    index.add(vec_query_base)
    xq = embed_bert_cls(query, bi_model, bi_tok)
    _, I = index.search(xq.reshape(1, vec_dim), 50)
    corpus = [data[int(i)]['answer'] for i in I[0]]

    # Stage 2: re-score the (query, answer) pairs with the cross-encoder.
    queries = [query] * len(corpus)
    tokenized_texts = cross_tok(
        queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
    )

    with torch.no_grad():
        model_output = cross_model(
            **{k: v.to(cross_model.device) for k, v in tokenized_texts.items()}
        )
    # Pairwise dot products of the pair [CLS] embeddings serve as scores;
    # torch.matmul (not np.matmul) keeps the result a tensor so .cpu() works.
    ce_scores = model_output.last_hidden_state[:, 0, :]
    ce_scores = torch.matmul(ce_scores, ce_scores.T)
    scores = ce_scores.cpu().numpy()
    scores_ix = np.argsort(scores)[::-1]

    return corpus[scores_ix[0][0]]


def load_dataset(url='ekaterinatao/house_md_context3'):

    # Candidate answers: only rows with label 0 (presumably House's lines).
    dataset = datasets.load_dataset(url, split='train')
    house_dataset = dataset.filter(lambda row: row['labels'] == 0)

    return house_dataset


def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):

    # Precomputed [CLS] embeddings of the candidate answers.
    # FAISS expects a float32 matrix, so cast explicitly.
    cls_dataset = datasets.load_dataset(url, split='train')
    cls_base = np.stack([embed['cls_embeds'] for embed in cls_dataset])

    return cls_base.astype(np.float32)


def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):

    # Move the model to the GPU when one is available (see `device` above).
    bi_model = AutoModel.from_pretrained(checkpoint).to(device)
    bi_tok = AutoTokenizer.from_pretrained(checkpoint)

    return bi_model, bi_tok


def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):

    cross_model = AutoModel.from_pretrained(checkpoint).to(device)
    cross_tok = AutoTokenizer.from_pretrained(checkpoint)

    return cross_model, cross_tok


# Load the data and the models once at startup rather than on every
# request; reloading them inside `get_answer` would dominate the latency.
house_dataset = load_dataset()
cls_base = load_cls_base()
bi_model, bi_tok = load_bi_enc_model()
cross_model, cross_tok = load_cross_enc_model()


def get_answer(message):

    answer = get_ranked_docs(
        query=message, vec_query_base=cls_base, data=house_dataset,
        bi_model=bi_model, bi_tok=bi_tok,
        cross_model=cross_model, cross_tok=cross_tok
    )
    return answer


interface = gr.Interface(
    fn=get_answer,
    inputs=gr.Textbox(label="Input message to House MD", lines=3),
    outputs=gr.Textbox(label="House MD's answer"),
    title=title,
    description=description
)
interface.launch()
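
# A minimal sketch of how the Telegram bot mentioned in the description could
# call this Space's API with gradio_client. The Space id below is an
# assumption; substitute the actual one.
#
# from gradio_client import Client
#
# client = Client("ekaterinatao/house_md_bot")  # hypothetical Space id
# reply = client.predict("Everybody lies.", api_name="/predict")
# print(reply)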