import os
import math

import numpy as np
from llama_cpp import Llama
from transformers import AutoTokenizer

from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.retrievers.bm25 import BM25Retriever

# Use the tokenizer that matches the GGUF model loaded below, so the chat
# template rendered here lines up with what the model was trained on.
_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")


def messages_to_prompt(messages):
    """Render a list of chat messages into the model's prompt format."""
    messages = [{"role": m.role.value, "content": m.content} for m in messages]
    return _tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def completion_to_prompt(completion):
    """Wrap a bare completion string as a single user turn."""
    messages = [{"role": "user", "content": completion}]
    return _tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


# LlamaIndex wrapper used for retrieval-augmented generation.
llm = LlamaCPP(
    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    temperature=0.1,
    max_new_tokens=128,
    context_window=16384,
    model_kwargs={"n_gpu_layers": -1, "logits_all": False},
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)

# Separate raw llama-cpp handle with logits_all=True, so per-token logits
# are available when scoring a finished response.
llm2 = Llama(
    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    n_gpu_layers=-1,
    n_ctx=8000,
    logits_all=True,
)

# Local sentence-transformer used for dense retrieval embeddings.
embedding_model = HuggingFaceEmbedding(model_name="models/all-MiniLM-L6-v2")

Settings.llm = llm
Settings.embed_model = embedding_model


def check_if_exists():
    """Return True if both the vector index and the BM25 retriever are persisted."""
    index = os.path.exists("models/precomputed_index")
    bm25 = os.path.exists("models/bm25_retriever")
    return index and bm25


def precompute_index(data_folder="data"):
    """Chunk the documents, then persist a vector index and a BM25 retriever."""
    documents = SimpleDirectoryReader(data_folder).load_data()
    splitter = TokenTextSplitter(chunk_size=400, chunk_overlap=50)
    nodes = splitter.get_nodes_from_documents(documents)
    index = VectorStoreIndex(nodes, verbose=True)
    index.storage_context.persist(persist_dir="models/precomputed_index")
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=nodes,
        similarity_top_k=5,
        verbose=True,
    )
    bm25_retriever.persist("models/bm25_retriever")


def is_harmful(query):
    """Naive keyword-based safety filter; flags queries containing blocked terms."""
    harmful_keywords = ["bomb", "kill", "weapon", "suicide", "terror", "attack"]
    return any(keyword in query.lower() for keyword in harmful_keywords)


def is_not_relevant(query, index, threshold=0.7):
    """Return True when the best-matching chunk scores at or below `threshold`."""
    retriever = index.as_retriever(similarity_top_k=1)
    nodes = retriever.retrieve(query)
    if not nodes:
        # Nothing retrievable means the query cannot be grounded in the corpus.
        return True
    return nodes[0].score <= threshold


def get_sequence_probability(llm, input_sequence):
    """Probability the model assigns to `input_sequence` (product of token probs)."""
    llm.reset()  # clear any context left over from a previous call
    input_tokens = llm.tokenize(input_sequence.encode("utf-8"))
    llm.eval(input_tokens)
    log_probs = llm.logits_to_logprobs(np.array(llm.eval_logits))
    # Logits at position i predict token i + 1, so score each token against
    # the logits of the preceding position (the leading BOS token is not scored).
    total_log_prob = 0.0
    for i in range(1, len(input_tokens)):
        total_log_prob += log_probs[i - 1, input_tokens[i]]
    return math.exp(total_log_prob)
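

# A hedged sketch, not part of the original script: raw sequence probability
# shrinks as responses get longer, so a length-normalized perplexity
# (exp of the negative mean token log-prob) is often easier to compare
# across answers of different lengths. Lower perplexity means more fluent.
def get_perplexity(llm, input_sequence):
    """Per-token perplexity of `input_sequence` under `llm` (illustrative helper)."""
    llm.reset()
    input_tokens = llm.tokenize(input_sequence.encode("utf-8"))
    llm.eval(input_tokens)
    log_probs = llm.logits_to_logprobs(np.array(llm.eval_logits))
    token_scores = [log_probs[i - 1, input_tokens[i]] for i in range(1, len(input_tokens))]
    return math.exp(-sum(token_scores) / max(len(token_scores), 1))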


def answer_question(query):
    """Answer a finance query with hybrid (vector + BM25) retrieval and reranking."""
    if is_harmful(query):
        return "This query has been flagged as unsafe."
    print("loading bm25 retriever")
    bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
    print("loading saved vector index")
    storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
    index = load_index_from_storage(storage_context)
    if is_not_relevant(query, index, threshold=0.2):
        return "This query doesn't appear relevant to finance."
    # Fuse dense vector retrieval with sparse BM25 retrieval.
    retriever = QueryFusionRetriever(
        [
            index.as_retriever(similarity_top_k=5, verbose=True),
            bm25_retriever,
        ],
        llm=llm,
        num_queries=1,  # disable LLM query rewriting; use the raw query only
        similarity_top_k=5,
        verbose=True,
    )
    # Rerank the fused candidates with a cross-encoder before synthesis.
    reranker = SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-2-v2",
        top_n=5,
    )
    query_engine = RetrieverQueryEngine(
        retriever=retriever,
        node_postprocessors=[reranker],
    )
    response = query_engine.query(f"Answer in less than 100 words: \nQuery: {query}")
    response_text = str(response)
    # Score the generated answer with the raw llama-cpp handle.
    response_prob = get_sequence_probability(llm2, response_text)
    print(f"Output probability: {response_prob}")
    return response_text
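

# Minimal usage sketch (assumptions: a `data/` folder of finance documents is
# present, and the sample question below is illustrative, not from the source).
if __name__ == "__main__":
    if not check_if_exists():
        precompute_index("data")
    print(answer_question("What factors typically drive quarterly revenue growth?"))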