Spaces:
Running
Running
import functools

import datasets
import faiss
import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
# Title and description shown on the Gradio page.
title = "HouseMD bot"
# Implicit concatenation of separate literals (instead of backslash line
# continuations inside one literal) keeps the space between sentences.
description = (
    "Gradio Demo for telegram bot. "
    "To use it, simply add your text message. "
    "I've used the API on this Space to deploy the model on a Telegram bot."
)
# Preferred torch device: "cuda" when torch can see a GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
def embed_bert_cls(text, model, tokenizer):
    """Encode ``text`` and return its L2-normalised first-token ([CLS]) vector.

    Args:
        text: Input handed to ``tokenizer`` (with padding and truncation on).
        model: Encoder returning ``last_hidden_state``; the tokenised inputs
            are moved to ``model.device`` before the forward pass.
        tokenizer: Tokenizer paired with ``model``.

    Returns:
        A 1-D numpy array: the normalised vector at position 0 of the
        first sequence in the batch.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    # Position 0 of each sequence holds the [CLS] token.
    cls_vectors = torch.nn.functional.normalize(hidden[:, 0, :])
    first_vector = cls_vectors[0]
    return first_vector.cpu().numpy()
def get_ranked_docs(query, vec_query_base, data,
                    bi_model, bi_tok, cross_model, cross_tok, top_k=50):
    """Retrieve candidate answers for ``query`` and return the re-ranked best one.

    Stage 1 (retrieval): ``query`` is embedded with ``embed_bert_cls``. An
    exact L2 FAISS index built over ``vec_query_base`` then returns up to
    ``top_k`` nearest rows. FAISS row ids are used directly to index
    ``data``, so row ``i`` of ``vec_query_base`` must describe ``data[i]``.

    Stage 2 (re-ranking): every (query, candidate answer) pair is encoded by
    the cross-encoder. The function returns the candidate with the highest
    score in the first row of the pairwise [CLS] dot-product matrix.

    Args:
        query: The user's message.
        vec_query_base: 2-D array of precomputed embeddings, one row per
            item of ``data``.
        data: Indexable collection whose items expose an ``'answer'`` field.
        bi_model, bi_tok: Bi-encoder model/tokenizer used to embed ``query``.
        cross_model, cross_tok: Cross-encoder model/tokenizer used to re-rank.
        top_k: Maximum number of retrieved candidates. Defaults to 50, the
            value that was previously hard-coded.

    Returns:
        The ``'answer'`` text of the selected candidate.

    Raises:
        ValueError: If ``vec_query_base`` has no rows to search.
    """
    dim = vec_query_base.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vec_query_base)
    if index.ntotal == 0:
        raise ValueError("vec_query_base is empty; nothing to retrieve")

    query_vec = embed_bert_cls(query, bi_model, bi_tok)
    # When k exceeds the index size, FAISS pads missing neighbours with id -1,
    # and data[-1] would silently return the last row. Clamp k and skip padding.
    k = min(top_k, index.ntotal)
    _, ids = index.search(query_vec.reshape(1, dim), k)
    corpus = [data[int(i)]['answer'] for i in ids[0] if i >= 0]

    queries = [query] * len(corpus)
    # Send the tokens straight to the cross-encoder's own device. An extra hop
    # through the global ``device`` is redundant: it round-trips through the
    # GPU whenever CUDA is visible but the model stayed on CPU.
    tokenized = cross_tok(
        queries, corpus, max_length=128, padding=True, truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output = cross_model(
            **{name: tensor.to(cross_model.device)
               for name, tensor in tokenized.items()}
        )
    cls_vecs = output.last_hidden_state[:, 0, :]
    # Pairwise dot products between the pairs' [CLS] vectors. Computed in
    # torch (not np.matmul) so this also works when the model is on a GPU.
    scores = (cls_vecs @ cls_vecs.T).cpu().numpy()
    # ``np.argsort(scores)[::-1]`` on this 2-D matrix flips the ROWS, not the
    # sort order, so ``[0][0]`` picked the LOWEST score in the LAST row.
    # Instead, take the highest score in row 0, which matches the evident
    # "pick the top-ranked" intent.
    # NOTE(review): ranking by [CLS] dot products from a plain AutoModel is
    # unusual for a cross-encoder. If the checkpoint was trained with a
    # classification head, its logits would be the natural score — confirm.
    best = int(np.argmax(scores[0]))
    return corpus[best]
def load_dataset(url='ekaterinatao/house_md_context3'):
    """Load the ``train`` split of dataset ``url``, keeping rows with ``labels == 0``.

    The library loader is reached through the ``datasets`` module attribute,
    so this wrapper's name does not collide with it. What label 0 means is
    not visible here (presumably it marks the usable answer rows; TODO confirm).

    Returns:
        The filtered ``train`` split.
    """
    train_split = datasets.load_dataset(url, split='train')
    return train_split.filter(lambda example: example['labels'] == 0)
def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
    """Load precomputed embeddings from dataset ``url`` and stack them.

    Each record of the ``train`` split contributes its ``cls_embeds`` field.

    Returns:
        A numpy array with one row per record (2-D assuming every
        ``cls_embeds`` is a flat vector of equal length; TODO confirm).
    """
    records = datasets.load_dataset(url, split='train')
    vectors = [record['cls_embeds'] for record in records]
    return np.stack(vectors)
def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
    """Load the bi-encoder used to embed incoming queries.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from ``checkpoint``.
        The model is loaded first, then the tokenizer.
    """
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):
    """Load the cross-encoder used to re-rank (query, answer) pairs.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from ``checkpoint``.
        The model is loaded first, then the tokenizer.
    """
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
@functools.lru_cache(maxsize=1)
def _load_resources():
    """Load the answer dataset, embedding base and both encoders once.

    Only the first call pays for dataset and model loading; later calls reuse
    the same objects. A call that raises is not cached, so a failed load is
    retried on the next request.
    """
    dataset = load_dataset()
    cls_base = load_cls_base()
    bi_model, bi_tok = load_bi_enc_model()
    cross_model, cross_tok = load_cross_enc_model()
    return dataset, cls_base, bi_model, bi_tok, cross_model, cross_tok


def get_answer(message):
    """Return House MD's reply to ``message``; used as the Gradio callback.

    Args:
        message: The user's text.

    Returns:
        The answer text selected by ``get_ranked_docs``.
    """
    # Heavy resources are loaded once per process and reused across requests
    # instead of being reloaded on every call.
    dataset, cls_base, bi_model, bi_tok, cross_model, cross_tok = _load_resources()
    return get_ranked_docs(
        query=message, vec_query_base=cls_base, data=dataset,
        bi_model=bi_model, bi_tok=bi_tok,
        cross_model=cross_model, cross_tok=cross_tok,
    )
# Gradio UI: a three-line text box in, House's reply out. Launched at import.
_message_box = gr.Textbox(label="Input message to House MD", lines=3)
_answer_box = gr.Textbox(label="House MD's answer")
interface = gr.Interface(
    fn=get_answer,
    inputs=_message_box,
    outputs=_answer_box,
    title=title,
    description=description,
)
interface.launch()