Vishnu-add committed on
Commit
0361dbf
·
1 Parent(s): 5dba2dd

Uncommented llm code

Browse files
Files changed (1) hide show
  1. app.py +57 -57
app.py CHANGED
@@ -16,61 +16,61 @@ from langchain.vectorstores import Chroma
16
  import os
17
 
18
  st.set_page_config(page_title="pdf-GPT", page_icon="📖", layout="wide")
19
- # @st.cache_resource
20
- # def get_model():
21
- # device = torch.device('cpu')
22
- # # device = torch.device('cuda:0')
23
-
24
- # checkpoint = "LaMini-T5-738M"
25
- # checkpoint = "MBZUAI/LaMini-T5-738M"
26
- # tokenizer = AutoTokenizer.from_pretrained(checkpoint)
27
- # base_model = AutoModelForSeq2SeqLM.from_pretrained(
28
- # checkpoint,
29
- # device_map=device,
30
- # torch_dtype = torch.float32,
31
- # # offload_folder= "/model_ck"
32
- # )
33
- # return base_model,tokenizer
34
-
35
- # @st.cache_resource
36
- # def llm_pipeline():
37
- # base_model,tokenizer = get_model()
38
- # pipe = pipeline(
39
- # 'text2text-generation',
40
- # model = base_model,
41
- # tokenizer=tokenizer,
42
- # max_length = 512,
43
- # do_sample = True,
44
- # temperature = 0.3,
45
- # top_p = 0.95,
46
- # # device=device
47
- # )
48
-
49
- # local_llm = HuggingFacePipeline(pipeline = pipe)
50
- # return local_llm
51
-
52
- # @st.cache_resource
53
- # def qa_llm():
54
- # llm = llm_pipeline()
55
- # embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
56
- # db = Chroma(persist_directory="db", embedding_function = embeddings)
57
- # retriever = db.as_retriever()
58
- # qa = RetrievalQA.from_chain_type(
59
- # llm=llm,
60
- # chain_type = "stuff",
61
- # retriever = retriever,
62
- # return_source_documents=True
63
- # )
64
- # return qa
65
-
66
-
67
- # def process_answer(instruction):
68
- # response=''
69
- # instruction = instruction
70
- # qa = qa_llm()
71
- # generated_text = qa(instruction)
72
- # answer = generated_text['result']
73
- # return answer, generated_text
74
 
75
  # Display conversation history using Streamlit messages
76
  def display_conversation(history):
@@ -174,8 +174,8 @@ def main():
174
 
175
  # Search the database for a response based on user input and update session state
176
  if user_input:
177
- # answer = process_answer({"query" : user_input})
178
- answer = user_input
179
  st.session_state["past"].append(user_input)
180
  response = answer
181
  st.session_state["generated"].append(response)
 
16
  import os
17
 
18
  st.set_page_config(page_title="pdf-GPT", page_icon="📖", layout="wide")
19
+ @st.cache_resource
20
+ def get_model():
21
+ device = torch.device('cpu')
22
+ # device = torch.device('cuda:0')
23
+
24
+ checkpoint = "LaMini-T5-738M"
25
+ checkpoint = "MBZUAI/LaMini-T5-738M"
26
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
27
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
28
+ checkpoint,
29
+ device_map=device,
30
+ torch_dtype = torch.float32,
31
+ # offload_folder= "/model_ck"
32
+ )
33
+ return base_model,tokenizer
34
+
35
+ @st.cache_resource
36
+ def llm_pipeline():
37
+ base_model,tokenizer = get_model()
38
+ pipe = pipeline(
39
+ 'text2text-generation',
40
+ model = base_model,
41
+ tokenizer=tokenizer,
42
+ max_length = 512,
43
+ do_sample = True,
44
+ temperature = 0.3,
45
+ top_p = 0.95,
46
+ # device=device
47
+ )
48
+
49
+ local_llm = HuggingFacePipeline(pipeline = pipe)
50
+ return local_llm
51
+
52
+ @st.cache_resource
53
+ def qa_llm():
54
+ llm = llm_pipeline()
55
+ embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
56
+ db = Chroma(persist_directory="db", embedding_function = embeddings)
57
+ retriever = db.as_retriever()
58
+ qa = RetrievalQA.from_chain_type(
59
+ llm=llm,
60
+ chain_type = "stuff",
61
+ retriever = retriever,
62
+ return_source_documents=True
63
+ )
64
+ return qa
65
+
66
+
67
+ def process_answer(instruction):
68
+ response=''
69
+ instruction = instruction
70
+ qa = qa_llm()
71
+ generated_text = qa(instruction)
72
+ answer = generated_text['result']
73
+ return answer, generated_text
74
 
75
  # Display conversation history using Streamlit messages
76
  def display_conversation(history):
 
174
 
175
  # Search the database for a response based on user input and update session state
176
  if user_input:
177
+ answer = process_answer({"query" : user_input})
178
+ # answer = user_input
179
  st.session_state["past"].append(user_input)
180
  response = answer
181
  st.session_state["generated"].append(response)