ArturG9 committed on
Commit e568bb4 · verified · 1 Parent(s): 45ba126

Update functions.py

Files changed (1)
  1. functions.py +520 -0
functions.py CHANGED
@@ -1,3 +1,9 @@
+ from pydantic import BaseModel, Field
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.documents import Document
+ from langchain_community.retrievers import WikipediaRetriever
+ from langchain_community.utilities import GoogleSerperAPIWrapper
  def create_retriever_from_chroma(vectorstore_path="./docs/chroma/", search_type='mmr', k=7, chunk_size=300, chunk_overlap=30, lambda_mult=0.7):
 
      model_name = "Alibaba-NLP/gte-large-en-v1.5"
@@ -40,3 +46,517 @@ def create_retriever_from_chroma(vectorstore_path="./docs/chroma/", search_type=
 
 
      return retriever
+
+
+ def retrieval_grader_grader(llm):
+     """
+     Function to create a document-relevance grader object using a passed LLM model.
+
+     Args:
+         llm: The language model to be used for grading.
+
+     Returns:
+         Callable: A pipeline that grades document relevance based on the LLM.
+     """
+
+     # Define the class for grading documents inside the function
+     class GradeDocuments(BaseModel):
+         """Binary score for relevance check on retrieved documents."""
+         binary_score: str = Field(
+             description="Documents are relevant to the question, 'yes' or 'no'"
+         )
+
+     # Create the structured LLM grader using the passed LLM
+     structured_llm_grader = llm.with_structured_output(GradeDocuments)
+
+     # Define the prompt template
+     prompt = PromptTemplate(
+         template="""You are a teacher grading a quiz. You will be given:
+         1/ a QUESTION
+         2/ A FACT provided by the student
+
+         You are grading RELEVANCE RECALL:
+         A score of 'yes' means that ANY of the statements in the FACT are relevant to the QUESTION.
+         A score of 'no' means that NONE of the statements in the FACT are relevant to the QUESTION.
+         'yes' is the highest (best) score. 'no' is the lowest score you can give.
+
+         Explain your reasoning in a step-by-step manner. Ensure your reasoning and conclusion are correct.
+
+         Avoid simply stating the correct answer at the outset.
+
+         Question: {question} \n
+         Fact: \n\n {documents} \n\n
+
+         Give a binary score 'yes' or 'no' to indicate whether the FACT is relevant to the QUESTION. \n
+         Provide the binary score as a JSON with a single key 'binary_score' and no preamble or explanation.
+         """,
+         input_variables=['documents', 'question'],
+     )
+
+     # Combine the prompt with the structured LLM grader
+     retrieval_grader = prompt | structured_llm_grader
+
+     # Return the grader object
+     return retrieval_grader
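# Usage sketch (illustrative, not part of this commit): any LangChain chat model
# that supports .with_structured_output() can back the grader. The ChatGroq
# wrapper and model name below are assumptions, not something defined in this file.
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
retrieval_grader = retrieval_grader_grader(llm)

result = retrieval_grader.invoke({
    "question": "What is agent memory?",
    "documents": "Agents keep short-term context and long-term memories.",
})
print(result.binary_score)  # 'yes' or 'no'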
+
+
+ def create_toxicity_checker(llm):
+     """
+     Function to create a toxicity checker object using a passed LLM model.
+
+     Args:
+         llm: The language model to be used for toxicity checking.
+
+     Returns:
+         Callable: A pipeline function that checks for toxicity based on the LLM.
+     """
+
+     # Define the class for the toxicity check inside the function
+     class PromptGuard(BaseModel):
+         """Binary score for toxicity check on the question."""
+         score: str = Field(
+             description="Is the question toxic, 'yes' or 'no'"
+         )
+
+     # Create the structured LLM toxicity checker using the passed LLM
+     structured_llm_toxicity_checker = llm.with_structured_output(PromptGuard)
+
+     # Define the prompt template
+     prompt = PromptTemplate(
+         template="""You are a classifier that will score a text based upon supplied criteria.
+
+         Here are the criteria for scoring:
+
+         (1) insults
+
+         (2) threats
+
+         (3) highly negative comments
+
+         (4) any Personally Identifiable Information such as API keys
+
+         Score:
+
+         A score of 'yes' means that the text meets ANY of the criteria.
+
+         A score of 'no' means that the text meets NONE of the criteria.
+
+         Here is the question: {question}
+
+         Give a binary score 'yes' or 'no' to indicate whether the question is toxic. \n
+         Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
+         input_variables=["question"],
+     )
+
+     # Combine the prompt with the structured LLM toxicity checker
+     toxicity_grader = prompt | structured_llm_toxicity_checker
+
+     # Return the toxicity checker object
+     return toxicity_grader
+
+
+ def grade_question_toxicity(state):
+     """
+     Grades the question for toxicity.
+
+     Args:
+         state (dict): The current graph state.
+
+     Returns:
+         str: 'good' if the question passes the toxicity check, 'bad' otherwise.
+     """
+     steps = state["steps"]
+     steps.append("prompt guard")
+     score = toxicity_grader.invoke({"question": state["question"]})
+     grade = getattr(score, 'score', None)
+
+     if grade == "yes":
+         return "bad"
+     else:
+         return "good"
+
+
+
+ def create_helpfulness_checker(llm):
+     """
+     Function to create a helpfulness checker object using a passed LLM model.
+
+     Args:
+         llm: The language model to be used for checking the helpfulness of answers.
+
+     Returns:
+         Callable: A pipeline function that checks if the student's answer is helpful.
+     """
+
+     # Define the class for helpfulness grading inside the function
+     class GradeHelpfulness(BaseModel):
+         """Binary score for helpfulness check on answer."""
+         score: str = Field(
+             description="Is the answer helpful, 'yes' or 'no'"
+         )
+
+     # Create the structured LLM helpfulness checker using the passed LLM
+     structured_llm_helpfulness_checker = llm.with_structured_output(GradeHelpfulness)
+
+     # Define the prompt template
+     prompt = PromptTemplate(
+         template="""You will be given a QUESTION and a STUDENT ANSWER.
+
+         Here are the grading criteria to follow:
+
+         (1) Ensure the STUDENT ANSWER is concise and relevant to the QUESTION
+
+         (2) Ensure the STUDENT ANSWER helps to answer the QUESTION
+
+         Score:
+
+         A score of 'yes' means that the student's answer meets all of the criteria. This is the highest (best) score.
+
+         A score of 'no' means that the student's answer does not meet all of the criteria. This is the lowest possible score you can give.
+
+         Explain your reasoning in a step-by-step manner to ensure your reasoning and conclusion are correct.
+
+         Avoid simply stating the correct answer at the outset.
+
+         QUESTION: {question} \n
+         STUDENT ANSWER: {generation} \n
+
+         If the answer contains repeated phrases or repetition, return 'no'. \n
+         Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
+         input_variables=["generation", "question"],
+     )
+
+     # Combine the prompt with the structured LLM helpfulness checker
+     helpfulness_grader = prompt | structured_llm_helpfulness_checker
+
+     # Return the helpfulness checker object
+     return helpfulness_grader
+
+
+ def grade_document_relevance(question: str, document: str):
+     """Grade the relevance of a single document to a question using the module-level retrieval_grader."""
+     input_data = {"documents": document, "question": question}
+     try:
+         result = retrieval_grader.invoke(input_data)
+         return result
+     except Exception as e:
+         print(f"Error parsing result: {e}")
+         return {"binary_score": "no"}  # Default to "no" if there is an error
+
+ # Example usage
+ question = "What are the types of agent memory?"
+ documents = "Agents can have various types of memory, such as short-term memory and long-term memory."
+ grade = grade_document_relevance(question, documents)
+ print(grade)  # Expected output: binary_score='yes'
+
+
+ def create_hallucination_checker(llm):
+     """
+     Function to create a hallucination checker object using a passed LLM model.
+
+     Args:
+         llm: The language model to be used for checking hallucinations in the student's answer.
+
+     Returns:
+         Callable: A pipeline function that checks if the student's answer contains hallucinations.
+     """
+
+     # Define the class for hallucination grading inside the function
+     class GradeHallucinations(BaseModel):
+         """Binary score for hallucination check on answer."""
+         score: str = Field(
+             description="Answer is grounded in the provided facts, 'yes' or 'no'"
+         )
+
+     # Create the structured LLM hallucination checker using the passed LLM
+     structured_llm_hallucinations_checker = llm.with_structured_output(GradeHallucinations)
+
+     # Define the prompt template
+     prompt = PromptTemplate(
+         template="""You are a teacher grading a quiz.
+
+         You will be given FACTS and a STUDENT ANSWER.
+
+         You are grading the STUDENT ANSWER against the source FACTS. Focus on the correctness of the STUDENT ANSWER and the detection of any hallucinations.
+
+         Ensure that the STUDENT ANSWER meets the following criteria:
+
+         (1) it does not contain information outside of the FACTS
+
+         (2) the STUDENT ANSWER is fully grounded in and based upon information in the source documents
+
+         Score:
+
+         A score of 'yes' means that the student's answer meets all of the criteria. This is the highest (best) score.
+
+         A score of 'no' means that the student's answer does not meet all of the criteria. This is the lowest possible score you can give.
+
+         Explain your reasoning in a step-by-step manner to ensure your reasoning and conclusion are correct.
+
+         Avoid simply stating the correct answer at the outset.
+
+         STUDENT ANSWER: {generation} \n
+         FACTS: \n\n {documents} \n\n
+
+         Give a binary score 'yes' or 'no' to indicate whether the STUDENT ANSWER is grounded in the FACTS. \n
+         Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
+         """,
+         input_variables=["generation", "documents"],
+     )
+
+     # Combine the prompt with the structured LLM hallucination checker
+     hallucination_grader = prompt | structured_llm_hallucinations_checker
+
+     # Return the hallucination checker object
+     return hallucination_grader
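# Usage sketch (illustrative): the grounding grader is invoked with the source
# documents and the generated answer; `llm` is assumed to be a structured-output
# capable chat model as in the earlier sketch.
hallucination_grader = create_hallucination_checker(llm)
verdict = hallucination_grader.invoke({
    "documents": "The Eiffel Tower is in Paris.",
    "generation": "The Eiffel Tower is located in Paris, France.",
})
print(verdict.score)  # 'yes' means the answer stays within the given facts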
+
+
+ def create_question_rewriter(llm):
+     """
+     Function to create a question rewriter object using a passed LLM model.
+
+     Args:
+         llm: The language model to be used for rewriting questions.
+
+     Returns:
+         Callable: A pipeline function that rewrites questions for optimized vector store retrieval.
+     """
+
+     # Define the prompt template for question rewriting
+     re_write_prompt = PromptTemplate(
+         template="""You are a question re-writer that converts an input question to a better version that is optimized for vector store retrieval.\n
+         Your task is to enhance the question by clarifying the intent, removing any ambiguity, and including specific details to retrieve the most relevant information.\n
+         I don't need explanations, only the enhanced question.
+         Here is the initial question: \n\n {question}. Improved question with no preamble: \n """,
+         input_variables=["question"],
+     )
+
+     # Combine the prompt with the LLM and output parser
+     question_rewriter = re_write_prompt | llm | StrOutputParser()
+
+     # Return the question rewriter object
+     return question_rewriter
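# Usage sketch (illustrative): because the chain ends in StrOutputParser, the
# rewriter returns a plain string rather than a message object.
question_rewriter = create_question_rewriter(llm)
better_question = question_rewriter.invoke({"question": "agent memory types?"})
print(better_question)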
+
+
+ def transform_query(state):
+     """
+     Transform the query to produce a better question.
+
+     Args:
+         state (dict): The current graph state
+
+     Returns:
+         state (dict): Updates the question key with a re-phrased question
+     """
+
+     print("---TRANSFORM QUERY---")
+     question = state["question"]
+     documents = state["documents"]
+     steps = state["steps"]
+     steps.append("question_transformation")
+
+     # Re-write the question
+     better_question = question_rewriter.invoke({"question": question})
+     print(f"Transformed question: {better_question}")
+     return {"documents": documents, "question": better_question}
+
+
+
+
+ def format_google_results(google_results):
+     """Convert Serper.dev search results into Document objects with Wikipedia-like metadata."""
+     formatted_documents = []
+
+     # Loop through each organic result and create a Document for it
+     for result in google_results['organic']:
+         title = result.get('title', 'No title')
+         link = result.get('link', 'No link')
+         snippet = result.get('snippet', 'No summary available')
+
+         # Create a Document object with a similar metadata structure to WikipediaRetriever
+         document = Document(
+             metadata={
+                 'title': title,
+                 'summary': snippet,
+                 'source': link
+             },
+             page_content=snippet  # Use the snippet as the page content
+         )
+
+         formatted_documents.append(document)
+
+     return formatted_documents
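# For reference (illustrative values): format_google_results expects the dict
# returned by GoogleSerperAPIWrapper.results(), whose 'organic' entries carry
# 'title', 'link' and 'snippet' keys, e.g.:
example_serper_output = {
    "organic": [
        {
            "title": "Agent memory explained",
            "link": "https://example.com/agent-memory",
            "snippet": "Agents combine short-term context with long-term memory stores.",
        }
    ]
}
example_docs = format_google_results(example_serper_output)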
+
+
+ def grade_generation_v_documents_and_question(state):
+     """
+     Determines whether the generation is grounded in the documents and answers the question.
+     """
+     print("---CHECK HALLUCINATIONS---")
+     question = state["question"]
+     documents = state["documents"]
+     generation = state["generation"]
+     generation_count = state.get("generation_count", 0)  # Use state.get to avoid a KeyError
+     print(f"generation number: {generation_count}")
+
+     # Grade grounding in the retrieved documents
+     score = hallucination_grader.invoke(
+         {"documents": documents, "generation": generation}
+     )
+     grade = getattr(score, 'score', None)
+
+     # Check grounding
+     if grade == "yes":
+         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
+         # Check question-answering
+         print("---GRADE GENERATION vs QUESTION---")
+         score = answer_grader.invoke({"question": question, "generation": generation})
+         grade = getattr(score, 'score', None)
+         if grade == "yes":
+             print("---DECISION: GENERATION ADDRESSES QUESTION---")
+             return "useful"
+         else:
+             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
+             return "not useful"
+     else:
+         if generation_count > 1:
+             print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, TRANSFORM QUERY---")
+             # Give up on re-generation after more than one attempt and fall back to query transformation
+             return "not useful"
+         else:
+             print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
+             # generate() increments generation_count on the next attempt
+             print(f"generation number: {state['generation_count']}")
+             return "not supported"
+
+
+ def ask_question(state):
+     """
+     Initialize the question and the generation counter.
+
+     Args:
+         state (dict): The current graph state
+
+     Returns:
+         state (dict): Question, steps and generation count
+     """
+     steps = state["steps"]
+     question = state["question"]
+     generation_count = state.get("generation_count", 0)
+
+     steps.append("question_asked")
+     return {"question": question, "steps": steps, "generation_count": generation_count}
+
+
+ def retrieve(state):
+     """
+     Retrieve documents
+
+     Args:
+         state (dict): The current graph state
+
+     Returns:
+         state (dict): New key added to state, documents, that contains retrieved documents
+     """
+     steps = state["steps"]
+     question = state["question"]
+
+     documents = retriever.invoke(question)
+
+     steps.append("retrieve_documents")
+     return {"documents": documents, "question": question, "steps": steps}
+
+
+ def generate(state):
+     """
+     Generate an answer
+     """
+     question = state["question"]
+     documents = state["documents"]
+     generation = rag_chain.invoke({"documents": documents, "question": question})
+     steps = state["steps"]
+     steps.append("generate_answer")
+     generation_count = state["generation_count"]
+
+     generation_count += 1
+
+     return {
+         "documents": documents,
+         "question": question,
+         "generation": generation,
+         "steps": steps,
+         "generation_count": generation_count  # Include generation_count in the return
+     }
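# `rag_chain` is used by generate() but not defined in this diff. A typical LCEL
# chain compatible with the call above might look like the sketch below; the
# prompt wording is an assumption, not the module's actual chain.
rag_prompt = PromptTemplate(
    template="""Answer the question using only the following documents.

    Documents: {documents}

    Question: {question}

    Answer:""",
    input_variables=["documents", "question"],
)
rag_chain = rag_prompt | llm | StrOutputParser()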
+
+
+ def grade_documents(state):
+     """Filter retrieved documents by relevance and decide whether a web search is needed."""
+     question = state["question"]
+     documents = state["documents"]
+     steps = state["steps"]
+     steps.append("grade_document_retrieval")
+
+     filtered_docs = []
+     search = "No"
+
+     for d in documents:
+         # Call the grading function
+         score = retrieval_grader.invoke({"question": question, "documents": d.page_content})
+         print(f"Grader output for document: {score}")  # Detailed debugging output
+
+         # Extract the grade
+         grade = getattr(score, 'binary_score', None)
+         if grade and grade.lower() in ["yes", "true", "1"]:
+             filtered_docs.append(d)
+         elif len(filtered_docs) < 4:
+             search = "Yes"
+
+     # Check the decision-making process
+     print(f"Final decision - Perform web search: {search}")
+     print(f"Filtered documents count: {len(filtered_docs)}")
+
+     return {
+         "documents": filtered_docs,
+         "question": question,
+         "search": search,
+         "steps": steps,
+     }
+
+ def web_search(state):
+     """Top up the filtered documents with Wikipedia and, if needed, Google Serper results."""
+     question = state["question"]
+     documents = state.get("documents") or []
+     steps = state["steps"]
+     steps.append("web_search")
+     k = 4 - len(documents)
+     web_results_list = []
+
+     wiki_results = WikipediaRetriever(lang='en', top_k_results=1, doc_content_chars_max=1000).invoke(question)
+
+     if k < 1:
+         combined_documents = documents + wiki_results
+     else:
+         web_results = GoogleSerperAPIWrapper(k=k).results(question)
+         formatted_documents = format_google_results(web_results)
+         web_results_list.extend(formatted_documents)
+
+         combined_documents = documents + wiki_results + web_results_list
+
+     return {"documents": combined_documents, "question": question, "steps": steps}
+
+ def decide_to_generate(state):
+     """
+     Determines whether to generate an answer, or re-generate a question.
+
+     Args:
+         state (dict): The current graph state
+
+     Returns:
+         str: Binary decision for the next node to call
+     """
+     search = state["search"]
+     if search == "Yes":
+         return "search"
+     else:
+         return "generate"
+
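# The node functions above follow the LangGraph node/conditional-edge pattern and
# rely on module-level globals (retriever, retrieval_grader, toxicity_grader,
# hallucination_grader, answer_grader, question_rewriter, rag_chain) that this
# diff does not bind. One possible wiring, shown purely as an illustrative
# sketch (the edge mapping is an assumption, not the app's actual graph):
from langgraph.graph import StateGraph, END

retriever = create_retriever_from_chroma()
retrieval_grader = retrieval_grader_grader(llm)
toxicity_grader = create_toxicity_checker(llm)
hallucination_grader = create_hallucination_checker(llm)
answer_grader = create_helpfulness_checker(llm)
question_rewriter = create_question_rewriter(llm)

workflow = StateGraph(GraphState)
workflow.add_node("ask_question", ask_question)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("web_search", web_search)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)

workflow.set_entry_point("ask_question")
workflow.add_conditional_edges("ask_question", grade_question_toxicity,
                               {"good": "retrieve", "bad": END})
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges("grade_documents", decide_to_generate,
                               {"search": "web_search", "generate": "generate"})
workflow.add_edge("web_search", "generate")
workflow.add_conditional_edges("generate", grade_generation_v_documents_and_question,
                               {"useful": END, "not useful": "transform_query",
                                "not supported": "generate"})
workflow.add_edge("transform_query", "retrieve")
app = workflow.compile()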