ekaterinatao committed on
Commit
ef9812d
·
verified ·
1 Parent(s): d90a3af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# All imports grouped at the top of the file (PEP 8) — the original scattered
# them around the UI constants.
import datasets
import faiss
import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel

# Copy shown on the Gradio page (kept byte-identical: it is user-facing text).
title = "HouseMD bot"

description = "Gradio Demo for telegram bot \
To use it, simply add your text message. \
I've used the API on this Space to deploy the model on a Telegram bot."
17
def embed_bert_cls(text, model, tokenizer):
    """Encode *text* into a single L2-normalized [CLS] embedding.

    Tokenizes the text, runs the encoder without gradients, takes the
    first-token vector of the last hidden state, normalizes it to unit
    length and returns it as a 1-D numpy array.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    # Move every tensor in the batch onto the model's device before the pass.
    batch = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        output = model(**batch)
    # [CLS] is the first token; normalize so distances compare directions only.
    cls_vector = torch.nn.functional.normalize(output.last_hidden_state[:, 0, :])
    return cls_vector[0].cpu().numpy()
24
+
25
+
26
def get_ranked_docs(query, vec_query_base, data,
                    bi_model, bi_tok, cross_model, cross_tok):
    """Retrieve the best answer for *query*: bi-encoder recall + cross-encoder re-rank.

    Parameters
    ----------
    query : str
        The user's message.
    vec_query_base : np.ndarray
        2-D array of precomputed [CLS] embeddings, one row per candidate record.
    data : sequence of dict
        Candidate records parallel to `vec_query_base`; each has an 'answer' key.
    bi_model, bi_tok
        Bi-encoder model/tokenizer used to embed the query.
    cross_model, cross_tok
        Cross-encoder model/tokenizer used to re-rank the recalled candidates.

    Returns
    -------
    str
        The top-ranked candidate answer.
    """
    vec_shape = vec_query_base.shape[1]
    index = faiss.IndexFlatL2(vec_shape)
    index.add(vec_query_base)
    xq = embed_bert_cls(query, bi_model, bi_tok)
    # Recall the 50 nearest candidates by L2 distance in embedding space.
    D, I = index.search(xq.reshape(1, vec_shape), 50)

    # BUG FIX: the original used `data['answer'][i]`, which raises TypeError —
    # `data` is the *list* of records produced by load_dataset(), so the record
    # must be indexed first, then its 'answer' field.
    corpus = [data[i]['answer'] for i in I[0]]

    queries = [query] * len(corpus)
    # BUG FIX: the original called `.to(config.model.device)` but no `config`
    # exists anywhere in this file (NameError at runtime); use the
    # cross-encoder's own device instead.
    tokenized_texts = cross_tok(
        queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
    ).to(cross_model.device)

    with torch.no_grad():
        ce_scores = cross_model(
            tokenized_texts['input_ids'], tokenized_texts['attention_mask']
        ).last_hidden_state[:, 0, :]
        # Pairwise dot products between the 50 (query, candidate) encodings.
        ce_scores = ce_scores @ ce_scores.T

    scores = ce_scores.cpu().numpy()
    # NOTE(review): argsort over the 2-D score matrix followed by a row
    # reversal and [0][0] is preserved exactly as written, but the ranking
    # criterion looks fragile — confirm this picks the intended candidate.
    scores_ix = np.argsort(scores)[::-1]

    return corpus[scores_ix[0][0]]
54
+
55
+
56
def load_dataset(url='ekaterinatao/house_md_context3'):
    """Download the dialog dataset and return the records whose 'labels' field is 0.

    Returns a plain Python list of record dicts (presumably the House MD
    character's lines — confirm against the dataset card).
    """
    train_split = datasets.load_dataset(url, split='train')
    return [record for record in train_split if record['labels'] == 0]
65
+
66
+
67
def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
    """Fetch precomputed [CLS] embeddings and stack them into a single 2-D array."""
    embeds_split = datasets.load_dataset(url, split='train')
    frame = pd.DataFrame(embeds_split)
    # One row per record; np.stack turns the column of vectors into a matrix.
    return np.stack(list(frame['cls_embeds']))
73
+
74
+
75
def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
    """Load the bi-encoder checkpoint; returns a (model, tokenizer) pair."""
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
81
+
82
+
83
def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):
    """Load the cross-encoder checkpoint; returns a (model, tokenizer) pair."""
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
89
+
90
+
91
def get_answer(message):
    """Answer one chat *message* with the top-ranked House MD line.

    PERFORMANCE FIX: the original re-downloaded the dataset, the embedding
    base and BOTH model checkpoints on every single message. The heavy
    resources are now loaded once on the first call and cached on the
    function object; per-call behavior is otherwise unchanged.
    """
    if not hasattr(get_answer, '_resources'):
        get_answer._resources = (
            load_dataset(),
            load_cls_base(),
            load_bi_enc_model(),
            load_cross_enc_model(),
        )
    dataset, cls_base, bi_enc_model, cross_enc_model = get_answer._resources

    return get_ranked_docs(
        query=message, vec_query_base=cls_base, data=dataset,
        bi_model=bi_enc_model[0], bi_tok=bi_enc_model[1],
        cross_model=cross_enc_model[0], cross_tok=cross_enc_model[1]
    )
104
+
105
# BUG FIX: the original mixed Gradio generations — `gr.inputs.Textbox` is the
# legacy 2.x namespace (removed in Gradio 3+) while the output already used
# the modern `gr.Textbox`. `enable_queue=` on Interface is likewise the
# deprecated spelling of `.queue()`. Unified on the modern API.
interface = gr.Interface(
    fn=get_answer,
    inputs=gr.Textbox(lines=3, label="Input message to House MD"),
    outputs=gr.Textbox(label="House MD's answer"),
    title=title,
    description=description,
)
interface.queue().launch(debug=True)