# -*- coding: utf-8 -*-
"""wiki_chat_3_hack.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1chXsWeq1LzbvYIs6H73gibYmNDRbIgkD
"""
#!pip install gradio
#!pip install -U sentence-transformers
#!pip install datasets
#!pip install langchain
#!pip install openai
#!pip install faiss-cpu
#import numpy as np
import gradio as gr
#import random
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from torch import tensor as torch_tensor
from datasets import load_dataset
"""# import models"""
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256  # truncate long passages to 256 tokens
# The bi-encoder retrieves the top_k passages; a cross-encoder then re-ranks them to improve quality.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
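# A minimal sketch of the two-stage pattern on toy passages (illustrative only;
# the real passages and precomputed embeddings are loaded from the datasets below):
# toy_psgs = ["The monthly premium is $0.", "Urgent care visits cost $40."]
# toy_emb = bi_encoder.encode(toy_psgs, convert_to_tensor=True)
# q_emb = bi_encoder.encode("how much is the premium?", convert_to_tensor=True)
# retrieved = util.semantic_search(q_emb, toy_emb, top_k=2)[0]  # stage 1: fast recall
# rerank = cross_encoder.predict([["how much is the premium?", p] for p in toy_psgs])  # stage 2: precise scoring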
"""# import datasets"""
dataset = load_dataset("gfhayworth/hack_policy", split='train')
mypassages = list(dataset.to_pandas()['psg'])
dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
dataset_embed_pd = dataset_embed.to_pandas()
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
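# Sanity check (illustrative): there should be one 384-dim MiniLM embedding row
# per passage, with matching indices between mypassages and mycorpus_embeddings.
# assert mycorpus_embeddings.shape[0] == len(mypassages)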
def search(query, passages=mypassages, doc_embedding=mycorpus_embeddings, top_k=20, top_n=1):
    # Stage 1: bi-encoder retrieval over the precomputed corpus embeddings
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)  # move to GPU with .cuda() if available
    hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)
    hits = hits[0]  # hits for the first (and only) query

    ##### Re-Ranking #####
    # Stage 2: score each (query, passage) pair with the cross-encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores and keep the top_n
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    predictions = hits[:top_n]
    return predictions
# for hit in search("example query", top_n=3):
#     print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
def get_text(qry):
    # Return the plain text of the top-ranked passages for a query
    predictions = search(qry)
    prediction_text = []
    for hit in predictions:
        prediction_text.append(mypassages[hit['corpus_id']])
    return prediction_text
# def prt_rslt(qry):
#     rslt = get_text(qry)
#     for r in rslt:
#         print(r)
# prt_rslt("What is the name of the plan described by this summary of benefits?")
"""# new LLM based functions"""
import os  # LangChain's OpenAI wrapper reads OPENAI_API_KEY from the environment

from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff")
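# Note: chain_type="stuff" simply stuffs all retrieved passages into a single
# prompt, which is fine here since we only pass in the top 5 short passages.
# The OpenAI call requires OPENAI_API_KEY to be set (e.g. as a Spaces secret):
#   os.environ["OPENAI_API_KEY"] = "sk-..."  # illustrative; never hard-code real keys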
def get_text_fmt(qry, passages=mypassages, doc_embedding=mycorpus_embeddings):
    # Wrap the top passages as LangChain Documents, keeping the corpus id as the source
    predictions = search(qry, passages=passages, doc_embedding=doc_embedding, top_n=5)
    prediction_text = []
    for hit in predictions:
        page_content = passages[hit['corpus_id']]
        metadata = {"source": hit['corpus_id']}
        result = Document(page_content=page_content, metadata=metadata)
        prediction_text.append(result)
    return prediction_text
#mypassages[0]
#mycorpus_embeddings[0][:5]
# query = "What is the name of the plan described by this summary of benefits?"
# mydocs = get_text_fmt(query)
# print(len(mydocs))
# for d in mydocs:
#     print(d)
# chain_qa.run(input_documents=mydocs, question=query)
def get_llm_response(message):
    # Retrieve supporting passages and let the QA chain compose a sourced answer
    mydocs = get_text_fmt(message)
    response = chain_qa.run(input_documents=mydocs, question=message)
    return response
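# e.g. (requires a valid OPENAI_API_KEY):
# get_llm_response("how much is the monthly premium?")
# -> an answer string; with the sources chain it typically ends in "SOURCES: <corpus ids>"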
"""# chat example"""
def chat(message, history):
    # Append the (message, response) pair to the running history for the Chatbot widget
    history = history or []
    message = message.lower()
    response = get_llm_response(message)
    history.append((message, response))
    return history, history
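# Illustrative first turn (the State is empty, so history arrives as None):
# new_history, _ = chat("How much is the monthly premium?", None)
# -> [("how much is the monthly premium?", "<llm answer>")]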
css=".gradio-container {background-color: lightgray}"
with gr.Blocks(css=css) as demo:
history_state = gr.State()
gr.Markdown('# Hack QA')
title='Benefit Chatbot'
description='chatbot with search on Health Benefits'
with gr.Row():
chatbot = gr.Chatbot()
with gr.Row():
message = gr.Textbox(label='Input your question here:',
placeholder='What is the name of the plan described by this summary of benefits?',
lines=1)
submit = gr.Button(value='Send',
variant='secondary').style(full_width=False)
submit.click(chat,
inputs=[message, history_state],
outputs=[chatbot, history_state])
gr.Examples(
examples=["What is the name of the plan described by this summary of benefits?",
"How much is the monthly premium?",
"How much do I have to pay if I am admitted to the hospital?"],
inputs=message
)
demo.launch()
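# Locally this serves at http://127.0.0.1:7860 by default; on Hugging Face Spaces
# the same file runs as app.py. Pass share=True to launch() for a temporary public URL.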