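# RAG demo: load a Hugging Face dataset, chunk and embed it into Chroma, then
# answer questions with a 4-bit quantized Zephyr-7B reader behind a Gradio UI.
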
#from langchain_community.document_loaders import PyPDFLoader

from datasets import load_dataset

# Pull the knowledge base from the Hub; the print shows its splits and columns.
dataset = load_dataset("Namitg02/Test")
print(dataset)

from langchain.docstore.document import Document as LangchainDocument

#RAW_KNOWLEDGE_BASE = [LangchainDocument(page_content=str(dataset))]

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Small overlapping chunks keep each embedding focused on a single passage.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=100, chunk_overlap=15, separators=["\n\n", "\n", " ", ""]
)
#docs = splitter.split_documents(RAW_KNOWLEDGE_BASE)
# create_documents() expects a list of texts, not a DatasetDict repr; serialize
# each record (assumes the data lives in the default "train" split).
docs = splitter.create_documents([str(record) for record in dataset["train"]])


from langchain_community.embeddings import HuggingFaceEmbeddings

# Small, fast sentence-transformers model for indexing the chunks.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# embeddings = embedding_model.embed_documents([d.page_content for d in docs])


from langchain_community.vectorstores import Chroma
persist_directory = 'docs/chroma/'

vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embedding_model,
    persist_directory=persist_directory
)
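
# The index is persisted under docs/chroma/ and can be reloaded later with
# Chroma(persist_directory=persist_directory, embedding_function=embedding_model).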

retriever = vectordb.as_retriever()

#docs_ss = vectordb.similarity_search("Can I reverse Diabetes?", k=3)
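
# Optional sanity check (a sketch; the query string is illustrative): inspect
# what the retriever returns before the LLM is wired in.
for doc in retriever.invoke("What is diabetes?"):
    print(doc.page_content[:80])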


#from langchain.chains import RetrievalQA
#qa_chain = RetrievalQA.from_chain_type(
#    llm,  # the HuggingFacePipeline LLM built below
#    retriever=vectordb.as_retriever()
#)

from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM, pipeline


READER_MODEL = "HuggingFaceH4/zephyr-7b-beta"

# 4-bit NF4 quantization so the 7B reader fits on a single GPU (requires the
# bitsandbytes package and a CUDA device).
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained(READER_MODEL, quantization_config=bnb_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL)

from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser


text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=False,  # emit only the completion, not the echoed prompt
    max_new_tokens=100,
)

# Wrap the transformers pipeline as a LangChain LLM so it composes with LCEL.
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


#from langchain_community.output_parsers.rail_parser import GuardrailsOutputParser

# Zephyr chat format: the retrieved context rides in the system turn.
prompt_template = """
<|system|>
Answer the question based on your knowledge. Use the following context to help:

{context}

</s>
<|user|>
{question}
</s>
<|assistant|>
"""

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
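
# Quick check (a sketch): render the template with placeholder values to
# verify the Zephyr chat markers land where expected.
print(QA_CHAIN_PROMPT.format(context="(retrieved chunks)", question="(user question)"))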


llm_chain = QA_CHAIN_PROMPT | llm | StrOutputParser()
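
# llm_chain covers prompt -> model -> plain string; retrieval is composed in
# front of it below.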


from langchain_core.runnables import RunnablePassthrough


def format_docs(docs):
    # Join the retrieved chunks into one context string; otherwise the raw
    # Document list would be stuffed into the prompt verbatim.
    return "\n\n".join(doc.page_content for doc in docs)

# The dict step fans the incoming question out to the retriever (filling
# {context}) and passes it through unchanged (filling {question}).
rag_chain = {"context": retriever | format_docs, "question": RunnablePassthrough()} | llm_chain


#from langchain.chains import ConversationalRetrievalChain
#from langchain.memory import ConversationBufferMemory
#memory = ConversationBufferMemory(
#    memory_key="chat_history",
#    return_messages=True
#)

question = "Can I reverse Diabetes?"
print(rag_chain.invoke(question))

#qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)

import gradio as gr

# gr.load("READER_MODEL") would look up a model named literally "READER_MODEL"
# and bypass the RAG chain entirely; serve the chain itself instead.
def answer(user_question):
    return rag_chain.invoke(user_question)

gr.Interface(fn=answer, inputs="text", outputs="text").launch()