ekaterinatao committed on
Commit
9b436ec
·
verified ·
1 Parent(s): 97b1a14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -2,16 +2,17 @@ import gradio as gr
2
  import torch
3
  import faiss
4
  import numpy as np
5
- import pandas as pd
6
  import datasets
7
  from transformers import AutoTokenizer, AutoModel
8
 
9
  title = "HouseMD bot"
10
 
11
- description = "Gradio Demo for telegram bot \
12
- To use it, simply add your text message. \
13
  I've used the API on this Space to deploy the model on a Telegram bot."
14
 
 
 
15
 
16
  def embed_bert_cls(text, model, tokenizer):
17
  t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
@@ -29,23 +30,20 @@ def get_ranked_docs(query, vec_query_base, data,
29
  index = faiss.IndexFlatL2(vec_shape)
30
  index.add(vec_query_base)
31
  xq = embed_bert_cls(query, bi_model, bi_tok)
32
- D, I = index.search(xq.reshape(1, vec_shape), 50)
33
-
34
- corpus = []
35
- for i in I[0]:
36
- corpus.append(data['answer'][i])
37
 
38
  queries = [query] * len(corpus)
39
  tokenized_texts = cross_tok(
40
  queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
41
- ).to(config.model.device)
42
 
43
  with torch.no_grad():
44
- ce_scores = cross_model(
45
- tokenized_texts['input_ids'], tokenized_texts['attention_mask']
46
- ).last_hidden_state[:, 0, :]
47
- ce_scores = ce_scores @ ce_scores.T
48
-
49
  scores = ce_scores.cpu().numpy()
50
  scores_ix = np.argsort(scores)[::-1]
51
 
@@ -55,18 +53,15 @@ def get_ranked_docs(query, vec_query_base, data,
55
  def load_dataset(url='ekaterinatao/house_md_context3'):
56
 
57
  dataset = datasets.load_dataset(url, split='train')
58
- house_dataset = []
59
- for data in dataset:
60
- if data['labels'] == 0:
61
- house_dataset.append(data)
62
 
63
- return pd.DataFrame(house_dataset)
64
 
65
 
66
  def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
67
 
68
  cls_dataset = datasets.load_dataset(url, split='train')
69
- cls_base = np.stack([embed for embed in pd.DataFrame(cls_dataset)['cls_embeds']])
70
 
71
  return cls_base
72
 
@@ -101,6 +96,7 @@ def get_answer(message):
101
  )
102
  return answer
103
 
 
104
  interface = gr.Interface(
105
  fn=get_answer,
106
  inputs=gr.Textbox(label="Input message to House MD", lines=3),
 
2
  import torch
3
  import faiss
4
  import numpy as np
 
5
  import datasets
6
  from transformers import AutoTokenizer, AutoModel
7
 
8
  title = "HouseMD bot"
9
 
10
+ description = "Gradio Demo for telegram bot.\
11
+ To use it, simply add your text message.\n\
12
  I've used the API on this Space to deploy the model on a Telegram bot."
13
 
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
 
17
  def embed_bert_cls(text, model, tokenizer):
18
  t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
 
30
  index = faiss.IndexFlatL2(vec_shape)
31
  index.add(vec_query_base)
32
  xq = embed_bert_cls(query, bi_model, bi_tok)
33
+ _, I = index.search(xq.reshape(1, vec_shape), 50)
34
+ corpus = [data[int(i)]['answer'] for i in I[0]]
 
 
 
35
 
36
  queries = [query] * len(corpus)
37
  tokenized_texts = cross_tok(
38
  queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
39
+ ).to(device)
40
 
41
  with torch.no_grad():
42
+ model_output = cross_model(
43
+ **{k: v.to(cross_model.device) for k, v in tokenized_texts.items()}
44
+ )
45
+ ce_scores = model_output.last_hidden_state[:, 0, :]
46
+ ce_scores = np.matmul(ce_scores, ce_scores.T)
47
  scores = ce_scores.cpu().numpy()
48
  scores_ix = np.argsort(scores)[::-1]
49
 
 
53
  def load_dataset(url='ekaterinatao/house_md_context3'):
54
 
55
  dataset = datasets.load_dataset(url, split='train')
56
+ house_dataset = dataset.filter(lambda row: row['labels'] == 0)
 
 
 
57
 
58
+ return house_dataset
59
 
60
 
61
  def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
62
 
63
  cls_dataset = datasets.load_dataset(url, split='train')
64
+ cls_base = np.stack([embed['cls_embeds'] for embed in cls_dataset])
65
 
66
  return cls_base
67
 
 
96
  )
97
  return answer
98
 
99
+
100
  interface = gr.Interface(
101
  fn=get_answer,
102
  inputs=gr.Textbox(label="Input message to House MD", lines=3),