arjunanand13 committed on
Commit
db9bd14
·
verified ·
1 Parent(s): d90465c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -20
app.py CHANGED
@@ -7,10 +7,12 @@ from langchain.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  import gradio as gr
9
  from langchain.embeddings import HuggingFaceEmbeddings
10
- from sentence_transformers import CrossEncoder
11
 
 
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
 
 
14
  class StopOnTokens(StoppingCriteria):
15
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
16
  for stop_ids in stop_token_ids:
@@ -18,9 +20,11 @@ class StopOnTokens(StoppingCriteria):
18
  return True
19
  return False
20
 
 
21
  model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
22
  device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
23
 
 
24
  bnb_config = BitsAndBytesConfig(
25
  load_in_4bit=True,
26
  bnb_4bit_quant_type='nf4',
@@ -31,11 +35,13 @@ bnb_config = BitsAndBytesConfig(
31
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
32
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN, quantization_config=bnb_config)
33
 
 
34
  stop_list = ['\nHuman:', '\n```\n']
35
  stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
36
  stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
37
  stopping_criteria = StoppingCriteriaList([StopOnTokens()])
38
 
 
39
  generate_text = pipeline(
40
  model=model,
41
  tokenizer=tokenizer,
@@ -49,21 +55,19 @@ generate_text = pipeline(
49
 
50
  llm = HuggingFacePipeline(pipeline=generate_text)
51
 
52
- """Load the stored FAISS index"""
53
  try:
54
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"})
55
- vectorstore = FAISS.load_local('faiss_index', embeddings)
56
- print("Loaded embeddings from FAISS Index successfully")
57
  except ImportError as e:
58
  print("FAISS could not be imported. Make sure FAISS is installed correctly.")
59
  raise e
60
 
 
61
  chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
62
 
63
  chat_history = []
64
 
65
- reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
66
-
67
  def format_prompt(query):
68
  prompt = f"""
69
  You are a knowledgeable assistant with access to a comprehensive database.
@@ -84,19 +88,8 @@ def format_prompt(query):
84
 
85
  def qa_infer(query):
86
  formatted_prompt = format_prompt(query)
87
- results = chain({"question": formatted_prompt, "chat_history": chat_history})
88
-
89
- documents = results['source_documents']
90
- query_document_pairs = [[query, doc.page_content] for doc in documents]
91
- scores = reranker.predict(query_document_pairs)
92
-
93
- """Sort documents based on the re-ranker scores"""
94
- ranked_docs = sorted(zip(scores, documents), key=lambda x: x[0], reverse=True)
95
-
96
- """Extract the best document"""
97
- best_doc = ranked_docs[0][1].page_content if ranked_docs else ""
98
-
99
- return best_doc
100
 
101
  EXAMPLES = ["How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
102
  "Can BQ25896 support I2C interface?",
@@ -104,3 +97,103 @@ EXAMPLES = ["How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
104
 
105
  demo = gr.Interface(fn=qa_infer, inputs="text", allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs="text")
106
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from langchain.chains import ConversationalRetrievalChain
8
  import gradio as gr
9
  from langchain.embeddings import HuggingFaceEmbeddings
 
10
 
11
+
12
+ # Load the Hugging Face token from environment
13
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
14
 
15
+ # Define stopping criteria
16
  class StopOnTokens(StoppingCriteria):
17
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
18
  for stop_ids in stop_token_ids:
 
20
  return True
21
  return False
22
 
23
+ # Load the LLaMA model and tokenizer
24
  model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
25
  device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
26
 
27
+ # Set quantization configuration
28
  bnb_config = BitsAndBytesConfig(
29
  load_in_4bit=True,
30
  bnb_4bit_quant_type='nf4',
 
35
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
36
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN, quantization_config=bnb_config)
37
 
38
+ # Define stopping criteria
39
  stop_list = ['\nHuman:', '\n```\n']
40
  stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
41
  stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
42
  stopping_criteria = StoppingCriteriaList([StopOnTokens()])
43
 
44
+ # Create text generation pipeline
45
  generate_text = pipeline(
46
  model=model,
47
  tokenizer=tokenizer,
 
55
 
56
  llm = HuggingFacePipeline(pipeline=generate_text)
57
 
58
+ # Load the stored FAISS index
59
  try:
60
+ vectorstore = FAISS.load_local('faiss_index', HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"}))
61
+ print("Loaded embedding successfully")
 
62
  except ImportError as e:
63
  print("FAISS could not be imported. Make sure FAISS is installed correctly.")
64
  raise e
65
 
66
+ # Set up the Conversational Retrieval Chain
67
  chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
68
 
69
  chat_history = []
70
 
 
 
71
  def format_prompt(query):
72
  prompt = f"""
73
  You are a knowledgeable assistant with access to a comprehensive database.
 
88
 
89
  def qa_infer(query):
90
  formatted_prompt = format_prompt(query)
91
+ result = chain({"question": formatted_prompt, "chat_history": chat_history})
92
+ return result['answer']
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  EXAMPLES = ["How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
95
  "Can BQ25896 support I2C interface?",
 
97
 
98
  demo = gr.Interface(fn=qa_infer, inputs="text", allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs="text")
99
  demo.launch()
100
+
101
+ # import os
102
+ # import torch
103
+ # from torch import cuda, bfloat16
104
+ # from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
105
+ # from langchain.llms import HuggingFacePipeline
106
+ # from langchain.vectorstores import FAISS
107
+ # from langchain.chains import ConversationalRetrievalChain
108
+ # import gradio as gr
109
+ # from langchain.embeddings import HuggingFaceEmbeddings
110
+
111
+ # # Load the Hugging Face token from environment
112
+ # HF_TOKEN = os.environ.get("HF_TOKEN", None)
113
+
114
+ # # Define stopping criteria
115
+ # class StopOnTokens(StoppingCriteria):
116
+ # def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
117
+ # for stop_ids in stop_token_ids:
118
+ # if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
119
+ # return True
120
+ # return False
121
+
122
+ # # Load the LLaMA model and tokenizer
123
+ # model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
124
+ # device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
125
+
126
+ # # Set quantization configuration
127
+ # bnb_config = BitsAndBytesConfig(
128
+ # load_in_4bit=True,
129
+ # bnb_4bit_quant_type='nf4',
130
+ # bnb_4bit_use_double_quant=True,
131
+ # bnb_4bit_compute_dtype=bfloat16
132
+ # )
133
+
134
+ # tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
135
+ # model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=HF_TOKEN, quantization_config=bnb_config)
136
+
137
+ # # Define stopping criteria
138
+ # stop_list = ['\nHuman:', '\n```\n']
139
+ # stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
140
+ # stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
141
+ # stopping_criteria = StoppingCriteriaList([StopOnTokens()])
142
+
143
+ # # Create text generation pipeline
144
+ # generate_text = pipeline(
145
+ # model=model,
146
+ # tokenizer=tokenizer,
147
+ # return_full_text=True,
148
+ # task='text-generation',
149
+ # stopping_criteria=stopping_criteria,
150
+ # temperature=0.1,
151
+ # max_new_tokens=512,
152
+ # repetition_penalty=1.1
153
+ # )
154
+
155
+ # llm = HuggingFacePipeline(pipeline=generate_text)
156
+
157
+ # # Load the stored FAISS index
158
+ # try:
159
+ # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"})
160
+ # vectorstore = FAISS.load_local('faiss_index', embeddings)
161
+ # print("Loaded embedding successfully")
162
+ # except ImportError as e:
163
+ # print("FAISS could not be imported. Make sure FAISS is installed correctly.")
164
+ # raise e
165
+
166
+ # # Set up the Conversational Retrieval Chain
167
+ # chain = ConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(), return_source_documents=True)
168
+
169
+ # chat_history = []
170
+
171
+ # def format_prompt(query):
172
+ # prompt = f"""
173
+ # You are a knowledgeable assistant with access to a comprehensive database.
174
+ # I need you to answer my question and provide related information in a specific format.
175
+ # Here's what I need:
176
+ # 1. A brief, general response to my question based on related answers retrieved.
177
+ # 2. A JSON-formatted output containing:
178
+ # - "question": The original question.
179
+ # - "answer": The detailed answer.
180
+ # - "related_questions": A list of related questions and their answers, each as a dictionary with the keys:
181
+ # - "question": The related question.
182
+ # - "answer": The related answer.
183
+ # Here's my question:
184
+ # {query}
185
+ # Include a brief final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
186
+ # """
187
+ # return prompt
188
+
189
+ # def qa_infer(query):
190
+ # formatted_prompt = format_prompt(query)
191
+ # result = chain({"question": formatted_prompt, "chat_history": chat_history})
192
+ # return result['answer']
193
+
194
+ # EXAMPLES = ["How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
195
+ # "Can BQ25896 support I2C interface?",
196
+ # "Does TDA2 vout support bt656 8-bit mode?"]
197
+
198
+ # demo = gr.Interface(fn=qa_infer, inputs="text", allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs="text")
199
+ # demo.launch()