divyanshusingh committed
Commit 26175df · 1 Parent(s): 7c627e9

Added: model.py


Adds a script that builds the question-answering pipeline: HuggingFace embeddings over a Pinecone index for retrieval, with an instruction-tuned LLM queried through a LangChain LLMChain.

Files changed (1):
  1. model.py +78 -0
model.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import subprocess
+ from dotenv import load_dotenv
+
+ # Pull HUGGINGFACEHUB_API_TOKEN and PINECONE_API_KEY from a local .env file, if present.
+ load_dotenv()
+ try:
+     os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+     PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
+ except TypeError:
+     # The assignment above raises TypeError when the token is unset (os.getenv
+     # returned None). The fallback below only yields a real key in CI, where
+     # GitHub Actions substitutes ${{ secrets.PINECONE_API_KEY }} before bash runs.
+     PINECONE_API_KEY = subprocess.check_output(
+         ["bash", "-c", "echo ${{ secrets.PINECONE_API_KEY }}"]
+     ).decode("utf-8").strip()
+
+ import pinecone
+ import torch
+ from langchain import PromptTemplate, LLMChain, HuggingFacePipeline
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import Pinecone
+ from langchain.chains.question_answering import load_qa_chain  # imported but unused below
+ from langchain.chains import RetrievalQA  # imported but unused below
+ from transformers import pipeline
+
+
+ def get_llm(model_name, pinecone_index, llm):
+     # model_name is the embedding backbone, e.g. "bert-large-uncased" or "t5-large".
+     model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
+     embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
+
+     # Connect to the existing Pinecone index that holds the document embeddings.
+     pinecone.init(
+         api_key=PINECONE_API_KEY,
+         environment="us-east-1-aws"
+     )
+     index = pinecone.Index(pinecone_index)
+     print(index.describe_index_stats())
+
+     # Wrap the index as a LangChain vector store; "text" is the metadata key
+     # under which each passage's raw text is stored.
+     docsearch = Pinecone(index, embeddings.embed_query, "text")
+
+     # Load the instruction-following LLM (e.g. databricks/dolly-v2-3b).
+     instruct_pipeline = pipeline(
+         model=llm,
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True,
+         device_map="auto",
+         return_full_text=True,
+         do_sample=False,
+         max_new_tokens=128,
+     )
+     llm = HuggingFacePipeline(pipeline=instruct_pipeline)
+
+     template = """Context: {context}
+
+ Question: {question}
+
+ Answer: Let's go step by step."""
+
+     prompt = PromptTemplate(template=template, input_variables=["question", "context"])
+     llm_chain = LLMChain(prompt=prompt, llm=llm)
+     return llm_chain, docsearch
+
+
+ if __name__ == "__main__":
+     model_name = "bert-large-uncased"
+     pinecone_index = "bert-large-uncased"
+     llm = "databricks/dolly-v2-3b"
+     llm_chain, docsearch = get_llm(model_name, pinecone_index, llm)
+     print(":" * 40)
+     questions = [
+         "what is the name of the first Hindi newspaper published in Bihar?",
+         "what is the capital of Bihar?",
+         "Brief about the Gupta Dynasty",
+     ]
+     for question in questions:
+         # Retrieve the top-3 most similar passages and concatenate them as context.
+         # (The original call also passed metadata=False, which is not a
+         # similarity_search parameter and has been dropped.)
+         docs = docsearch.similarity_search(question, k=3)
+         content = "".join(doc.page_content for doc in docs)
+         print(f"{question}")
+         response = llm_chain.predict(question=question, context=content)
+         print(f"{response}\n{'--' * 25}")
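The script only queries an existing Pinecone index; it assumes the index named above is already populated with passage embeddings, with each passage's raw text stored under the "text" metadata key that get_llm reads back. For reference, here is a minimal, hypothetical sketch of how such an index could be built with the same embedding model (the corpus file name and chunking parameters are illustrative, not taken from this commit):

import os
import pinecone
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Pinecone

pinecone.init(api_key=os.getenv("PINECONE_API_KEY"), environment="us-east-1-aws")

# Hypothetical corpus file; split it so each vector covers a short passage.
raw_text = open("corpus.txt", encoding="utf-8").read()
chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_text(raw_text)

# Use the same embedding model as model.py so query vectors match index vectors;
# the index must already exist with a matching dimension (1024 for bert-large-uncased).
embeddings = HuggingFaceEmbeddings(model_name="bert-large-uncased")

# from_texts upserts each chunk and stores its text under the default "text"
# metadata key, which is exactly the text_key model.py passes to Pinecone(...).
Pinecone.from_texts(chunks, embeddings, index_name="bert-large-uncased")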
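Aside: load_qa_chain and RetrievalQA are imported but never used. If the intent is to replace the manual retrieve-and-prompt loop in __main__, LangChain's built-in wrapper performs both steps in one call; a hedged sketch of that variant follows (an assumption about intent, not part of this commit, and it requires get_llm to also return the wrapped llm, which it currently does not):

from langchain.chains import RetrievalQA

# llm is the HuggingFacePipeline built inside get_llm(); docsearch is the Pinecone
# vector store. chain_type="stuff" concatenates the retrieved passages into the
# prompt, mirroring the manual loop in __main__.
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=docsearch.as_retriever(search_kwargs={"k": 3}),
)
print(qa.run("what is the capital of Bihar?"))

Note that this variant uses LangChain's default QA prompt rather than the "Let's go step by step." template defined in get_llm.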