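# main.py: a minimal retrieval-augmented generation (RAG) chatbot. It chunks a
# local text file, indexes it in FAISS with BGE embeddings, and answers
# questions with a local BART model behind a Gradio chat UI.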
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
# Load the knowledge base; swap in PyPDFLoader for PDF sources.
# loader = PyPDFLoader('bipolar.pdf')
loader = TextLoader("info.txt")
docs = loader.load()
# Split the documents into chunks; the splitter's default sizes are used here.
text_splitter = RecursiveCharacterTextSplitter()
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
documents = text_splitter.split_documents(docs)
# Embed chunks with the BGE small English model on CPU.
huggingface_embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True},  # unit-length vectors, so similarity behaves like cosine
)
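# For reference: huggingface_embeddings.embed_query("some text") returns a
# 384-dimensional vector (the output size of bge-small-en-v1.5).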
# Build a FAISS index over the chunks and expose it as a retriever.
vector = FAISS.from_documents(documents, huggingface_embeddings)
retriever = vector.as_retriever()
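# Optional check: retriever.get_relevant_documents("some question") would
# return the closest chunks (k=4 by default for LangChain vector-store retrievers).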
# Local seq2seq generator; BART-base is small enough to run on CPU.
model_name = "facebook/bart-base"
# model_name = "distilbert/distilbert-base-uncased"  # encoder-only, so it cannot load via AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=300,
    temperature=0.9,
    top_p=0.9,
    repetition_penalty=1.15,
    do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
)
# Wrap the transformers pipeline as a LangChain LLM and build the RAG chain.
local_llm = HuggingFacePipeline(pipeline=pipe)
qa_chain = RetrievalQA.from_llm(llm=local_llm, retriever=retriever)
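# Optional sanity check before wiring up the UI; the question below is only a
# placeholder and assumes info.txt covers it.
# print(qa_chain.invoke({"query": "What is this document about?"})["result"])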
def gradinterface(query, history):
    # `history` is supplied by gr.ChatInterface but unused: each query is answered statelessly.
    result = qa_chain.invoke({'query': query})
    # The model sometimes echoes the prompt; keep only the text after the last ': '.
    return result['result'].split(': ')[-1].strip()
demo = gr.ChatInterface(fn=gradinterface, title='OUR_OWN_BOT')
if __name__ == "__main__":
    demo.launch(share=True)  # share=True also serves a temporary public gradio.live link