File size: 3,213 Bytes
341b0e8
 
db2d027
f437f2a
30b8a93
 
6bf42b4
30b8a93
f437f2a
6085a4e
6bf42b4
fe4f2dd
f437f2a
 
1d55d4a
73e234f
51c6493
f437f2a
 
eb40503
f437f2a
30b8a93
f437f2a
6bf42b4
51c6493
f437f2a
 
 
1e0339f
 
0dab96e
 
 
1e0339f
f40cccc
50f7573
f40cccc
0dab96e
 
 
 
f40cccc
0dab96e
f40cccc
3c80115
 
 
 
 
 
1e0339f
0dab96e
50f7573
1e0339f
0dab96e
 
 
f40cccc
50f7573
9389e44
1e0339f
d12676c
11a5cb0
1f323ef
11a5cb0
555598a
e92769e
c8f17c6
e92769e
a43db2b
f40cccc
a43db2b
d12676c
11a5cb0
 
 
 
 
67f8b6c
11a5cb0
 
 
 
 
 
d12676c
11a5cb0
 
 
 
d12676c
11a5cb0
 
 
f4c65b4
aa46ac9
11a5cb0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from datasets import load_dataset

# Load the knowledge-base dataset from the Hugging Face Hub.
dataset = load_dataset("Namitg02/Test")
print(dataset)

from langchain.docstore.document import Document as LangchainDocument

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Split the raw text into small overlapping chunks so each chunk fits the
# embedding model comfortably and retrieval stays fine-grained.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,
    chunk_overlap=15,
    separators=["\n\n", "\n", " ", ""],
)

# BUG FIX: the original called splitter.create_documents(str(dataset)).
# create_documents() expects a *list* of texts; a bare string is iterated
# character by character, producing one (useless) document per character.
# Build one text per dataset row instead, across every split.
# NOTE(review): assumes str(row) carries the useful text of each record —
# confirm against the actual dataset schema.
texts = [str(row) for split in dataset.values() for row in split]
docs = splitter.create_documents(texts)


from langchain_community.embeddings import HuggingFaceEmbeddings

# Sentence-transformer model used to embed every chunk.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

from langchain_community.vectorstores import Chroma

# On-disk location where Chroma persists its index.
persist_directory = 'docs/chroma/'

# Embed the chunks and store them in a persistent Chroma collection.
vectordb = Chroma.from_documents(documents=docs, embedding=embedding_model, persist_directory=persist_directory)

# Dead prototype code removed here (Streamlit login placeholders, a
# Guardrails output parser, a custom PromptTemplate, and a
# ConversationalRetrievalChain with buffer memory) — recover it from
# version control if it is ever needed again.

# The question the pipeline will answer.
question = "How can I reverse Diabetes?"

# Expose the vector store as a retriever that returns the top-2 most
# similar chunks for a query.
retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 2})

from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_community.llms import HuggingFaceEndpoint

# Hub id of the instruction-tuned reader model that writes the final answer.
READER_MODEL = "HuggingFaceH4/zephyr-7b-beta"

# BUG FIX: the original passed the bare READER_MODEL *string* straight to
# create_stuff_documents_chain(), which requires a language-model Runnable
# and rejects a plain string. Wrap the hub id in an LLM client first.
# NOTE(review): HuggingFaceEndpoint needs a HUGGINGFACEHUB_API_TOKEN in the
# environment — confirm the deployment provides one.
reader_llm = HuggingFaceEndpoint(repo_id=READER_MODEL)

# Prompt that grounds the answer strictly in the retrieved context.
qa_chat_prompt = ChatPromptTemplate.from_template("""Answer the following question based only on the provided context:

<context>
{context}
</context>

Question: {input}""")

# "Stuff" chain: concatenates the retrieved documents into the prompt.
docs_chain = create_stuff_documents_chain(reader_llm, qa_chat_prompt)

# Full pipeline: retrieve top chunks, then generate a grounded answer.
retrieval_chain = create_retrieval_chain(retriever, docs_chain)

# Reuse the `question` defined above instead of a second hard-coded copy
# of the query that had drifted out of sync with it.
response = retrieval_chain.invoke({"input": question})
print(response["answer"])