File size: 1,790 Bytes
bd53fbd
7f1e53f
 
 
f97cf59
7f1e53f
f97cf59
7f1e53f
 
8d69c8a
bd53fbd
f0fde48
 
7f1e53f
 
 
 
 
 
 
 
 
 
 
 
 
f97cf59
da72dd5
 
f97cf59
7f1e53f
61b75fc
7f1e53f
61b75fc
2416f1c
 
f97cf59
 
7f1e53f
 
 
2416f1c
 
f97cf59
7f1e53f
2416f1c
7f1e53f
 
 
2416f1c
dd41a03
f97cf59
9a529b4
f97cf59
 
2416f1c
f97cf59
 
9514cd1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from langchain_community.document_loaders import TextLoader , PyPDFLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain.chains import RetrievalQA
import torch
import gradio as gr

# --- Document ingestion -----------------------------------------------------
# Alternative source kept for reference: a PDF loader over a bipolar-disorder doc.
# loader = PyPDFLoader('bipolar.pdf')
loader = TextLoader("info.txt")
docs = loader.load()
# Default splitter settings (chunk_size=4000, chunk_overlap=200 in current
# langchain releases — verify against the installed version).
text_splitter = RecursiveCharacterTextSplitter()
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
documents = text_splitter.split_documents(docs)

# --- Embeddings -------------------------------------------------------------
# BGE-small sentence embeddings on CPU; normalization enables cosine-style
# similarity via inner product in FAISS.
huggingface_embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-small-en-v1.5",
    model_kwargs={'device':'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)

# --- Vector store / retriever ----------------------------------------------
# In-memory FAISS index built from the split documents; default retriever
# (similarity search, k=4) — confirm defaults against the installed langchain.
vector = FAISS.from_documents(documents, huggingface_embeddings)
retriever = vector.as_retriever()

# --- Generator model --------------------------------------------------------
model_name = "facebook/bart-base"
# model_name = "distilbert/distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Seq2seq generation pipeline; sampling enabled (do_sample=True) so answers
# are stochastic by design (temperature/top_p control diversity).
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=300,
    temperature=0.9,
    top_p=0.9,
    repetition_penalty=1.15,
    do_sample=True

)
# NOTE(review): HuggingFacePipeline is imported from the deprecated
# `langchain.llms` path at the top of the file; `langchain_community.llms`
# is the current home — confirm before upgrading langchain.
local_llm = HuggingFacePipeline(pipeline=pipe)
# Retrieval-augmented QA chain: retrieves chunks, stuffs them into the prompt.
qa_chain =  RetrievalQA.from_llm(llm=local_llm, retriever=retriever)



def gradinterface(query, history):
    """Gradio ChatInterface callback: answer ``query`` with the RetrievalQA chain.

    Parameters
    ----------
    query : str
        The user's current message.
    history : list
        Prior chat turns supplied by Gradio; unused, but required by the
        ``ChatInterface`` callback signature.

    Returns
    -------
    str
        The chain's answer, with any leading "...: " prefix stripped.
    """
    # .invoke() is the supported entry point; calling the chain directly
    # (qa_chain({'query': ...})) goes through the deprecated Chain.__call__
    # and emits a deprecation warning. Both return {'query': ..., 'result': ...}.
    result = qa_chain.invoke({'query': query})
    # The model sometimes echoes the prompt as "Question: ... Answer: ...";
    # keep only the text after the last ": " separator.
    return result['result'].split(': ')[-1].strip()


# Chat UI wired to the QA callback above.
demo = gr.ChatInterface(fn=gradinterface, title='OUR_OWN_BOT')

if __name__ == "__main__":
    # share=True asks Gradio to create a public tunnel URL in addition to
    # the local server.
    demo.launch(share=True)