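"""LLM-based graders for a RAG pipeline.

Provides binary ('yes'/'no') graders for document relevance, answer
groundedness (hallucination), and answer quality, each built as a
prompt | structured-output chain.
"""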
from typing import List

from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field


class DocumentRelevance(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


class HallucinationCheck(BaseModel):
    """Binary score for hallucination in the generated answer."""

    binary_score: str = Field(
        description="Answer is grounded in the facts, 'yes' or 'no'"
    )


class AnswerQuality(BaseModel):
    """Binary score to assess whether the answer addresses the question."""

    binary_score: str = Field(
        description="Answer addresses the question, 'yes' or 'no'"
    )


def create_llm_grader(grader_type: str, llm):
    """
    Create an LLM grader chain of the specified type.

    Args:
        grader_type (str): Type of grader to create: 'document_relevance',
            'hallucination', or 'answer_quality'
        llm: Chat model that supports structured output

    Returns:
        Runnable: Prompt | structured-output chain yielding a binary score
    """
    # Select the grader schema and its prompt
    if grader_type == "document_relevance":
        structured_llm_grader = llm.with_structured_output(DocumentRelevance)
        system = """You are a grader assessing the relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
This does not need to be a stringent test; the goal is to filter out erroneous retrievals.
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
        ])
    elif grader_type == "hallucination":
        structured_llm_grader = llm.with_structured_output(HallucinationCheck)
        system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
        ])
    elif grader_type == "answer_quality":
        structured_llm_grader = llm.with_structured_output(AnswerQuality)
        system = """You are a grader assessing whether an answer addresses / resolves a question.
Give a binary score 'yes' or 'no'. 'Yes' means that the answer resolves the question."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
        ])
    else:
        raise ValueError(f"Unknown grader type: {grader_type}")

    return prompt | structured_llm_grader
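

# Note: the chain returned by create_llm_grader can also be invoked directly,
# e.g. grader = create_llm_grader("hallucination", llm); grader.invoke({...}).
# The wrappers below do exactly that for each grader type and unwrap the score.
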
def grade_document_relevance(question: str, document: str, llm) -> str:
    """
    Grade the relevance of a document to a given question.

    Args:
        question (str): User's question
        document (str): Retrieved document content
        llm: Chat model that supports structured output

    Returns:
        str: Binary score ('yes' or 'no')
    """
    grader = create_llm_grader("document_relevance", llm)
    result = grader.invoke({"question": question, "document": document})
    return result.binary_score


def check_hallucination(documents: List[str], generation: str, llm) -> str:
    """
    Check whether the generation is grounded in the provided documents.

    Args:
        documents (List[str]): List of source documents
        generation (str): LLM-generated answer
        llm: Chat model that supports structured output

    Returns:
        str: Binary score ('yes' or 'no')
    """
    grader = create_llm_grader("hallucination", llm)
    result = grader.invoke({"documents": documents, "generation": generation})
    return result.binary_score


def grade_answer_quality(question: str, generation: str, llm) -> str:
    """
    Grade how well the answer addresses the question.

    Args:
        question (str): User's original question
        generation (str): LLM-generated answer
        llm: Chat model that supports structured output

    Returns:
        str: Binary score ('yes' or 'no')
    """
    grader = create_llm_grader("answer_quality", llm)
    result = grader.invoke({"question": question, "generation": generation})
    return result.binary_score


if __name__ == "__main__":
    # Example usage with a local Ollama model (requires `ollama pull llama3.2`)
    test_question = "What are the types of agent memory?"
    test_document = (
        "Agent memory can be classified into different types such as episodic, "
        "semantic, and working memory."
    )
    test_generation = (
        "Agent memory includes episodic memory for storing experiences, semantic "
        "memory for general knowledge, and working memory for immediate processing."
    )

    llm = ChatOllama(model="llama3.2", temperature=0.1, num_predict=256, top_p=0.5)

    print("Document Relevance:", grade_document_relevance(test_question, test_document, llm))
    print("Hallucination Check:", check_hallucination([test_document], test_generation, llm))
    print("Answer Quality:", grade_answer_quality(test_question, test_generation, llm))