Spaces:
Sleeping
Sleeping
File size: 4,335 Bytes
0b5fda4 3fa3177 0b5fda4 3fa3177 0b5fda4 3fa3177 e6d6c2e 3fa3177 f4be8f1 3fa3177 f4be8f1 3fa3177 f4be8f1 e6d6c2e f4be8f1 0b5fda4 3fa3177 0b5fda4 3fa3177 0b5fda4 f4be8f1 3fa3177 0b5fda4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline
import torch
import base64
import textwrap
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import RetrievalQA
from streamlit_chat import message
@st.cache_resource
def get_model():
    """Load the LaMini-T5 seq2seq model and its tokenizer on the CPU.

    The function is decorated with ``st.cache_resource`` so the loaded
    objects can be reused instead of being rebuilt on every call.

    Returns:
        tuple: ``(base_model, tokenizer)`` -- the ``AutoModelForSeq2SeqLM``
        instance and the ``AutoTokenizer`` instance created from the same
        checkpoint.
    """
    # Use torch.device('cuda:0') here instead to run on a GPU.
    device = torch.device('cpu')
    # Single source of truth for the checkpoint name; the previous bare
    # "LaMini-T5-738M" assignment was overwritten before it was ever read.
    checkpoint = "MBZUAI/LaMini-T5-738M"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        device_map=device,
        torch_dtype=torch.float32,
    )
    return base_model, tokenizer
@st.cache_resource
def llm_pipeline():
    """Build the text2text-generation pipeline and wrap it for LangChain.

    The model and tokenizer come from ``get_model()``. Generation uses
    sampling with ``temperature=0.3``, ``top_p=0.95`` and ``max_length=512``.

    Returns:
        HuggingFacePipeline: the LangChain LLM wrapper around the pipeline.
    """
    model, tok = get_model()
    # Generation settings are grouped here so they are easy to find.
    generation_settings = {
        "max_length": 512,
        "do_sample": True,
        "temperature": 0.3,
        "top_p": 0.95,
    }
    generator = pipeline(
        'text2text-generation',
        model=model,
        tokenizer=tok,
        **generation_settings,
    )
    return HuggingFacePipeline(pipeline=generator)
@st.cache_resource
def qa_llm():
    """Assemble the retrieval question-answering chain.

    Combines the LLM from ``llm_pipeline()`` with a retriever over the
    Chroma store persisted in the ``db`` directory, embedded with the
    ``all-MiniLM-L6-v2`` sentence-transformer model.

    Returns:
        The ``RetrievalQA`` chain (chain type ``"stuff"``), configured to
        return the source documents alongside the answer.
    """
    language_model = llm_pipeline()
    embedding_fn = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store = Chroma(
        persist_directory="db",
        embedding_function=embedding_fn,
    )
    doc_retriever = vector_store.as_retriever()
    return RetrievalQA.from_chain_type(
        llm=language_model,
        chain_type="stuff",
        retriever=doc_retriever,
        return_source_documents=True,
    )
def process_answer(instruction):
    """Run the QA chain on ``instruction`` and return its answer.

    Args:
        instruction: The input handed straight to the chain returned by
            ``qa_llm()``. ``main`` passes a dict of the form
            ``{"query": <user text>}``.

    Returns:
        tuple: ``(answer, generated_text)`` where ``generated_text`` is the
        full chain output and ``answer`` is its ``'result'`` entry.
    """
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text['result']
    return answer, generated_text
# Render the stored chat history as Streamlit chat messages
def display_conversation(history):
    """Draw every exchange stored in ``history``.

    Args:
        history: Mapping with two parallel lists: ``"past"`` (user inputs)
            and ``"generated"`` (replies). A reply is either a plain string
            or an indexable pair whose first item is the answer text and
            whose second item holds a ``'source_documents'`` entry.
    """
    for idx, reply in enumerate(history["generated"]):
        bubble_key = str(idx)
        message(history["past"][idx], is_user=True, key=bubble_key + "_user")

        if isinstance(reply, str):
            message(reply, key=bubble_key)
            continue

        # Structured reply: show the answer text, then collect its sources.
        message(reply[0], key=bubble_key)
        sources_list = [
            doc.metadata['source'] for doc in reply[1]['source_documents']
        ]
        # Uncomment below line to display sources
        # message(str(set(sources_list)), key="source_" + bubble_key)
def main():
    """Render the "Chat with your pdf" page and handle one chat turn.

    Shows the title and an "About the App" expander, reads the chat input,
    seeds the session history on first run, appends the new question and
    the value returned by ``process_answer`` when the user submitted text,
    and finally redraws the whole conversation.
    """
    st.title("Chat with your pdf📚")

    about_text = (
        "\n"
        "This is a Generative AI powered Question and Answering app "
        "that responds to questions about your PDF file.\n"
    )
    with st.expander("About the App"):
        st.markdown(about_text)

    user_input = st.chat_input("", key="input")

    # Seed the history with a greeting pair the first time the page runs.
    state = st.session_state
    if "generated" not in state:
        state["generated"] = ["I am ready to help you"]
    if "past" not in state:
        state["past"] = ["Hey There!"]

    # Query the QA chain for the new question and record both sides.
    if user_input:
        reply = process_answer({"query": user_input})
        state["past"].append(user_input)
        state["generated"].append(reply)

    # Redraw the full conversation.
    if state["generated"]:
        display_conversation(state)
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|