snsynth committed on
Commit
a643df2
·
1 Parent(s): 0e46c7b

change reranker

Browse files
Files changed (1) hide show
  1. rag_app/rag_2.py +41 -10
rag_app/rag_2.py CHANGED
@@ -9,13 +9,39 @@ from llama_index.core.query_engine import RetrieverQueryEngine
9
  from llama_index.core import StorageContext, load_index_from_storage
10
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
  from llama_index.core.postprocessor import LLMRerank
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  llm = LlamaCPP(
14
  model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
15
  temperature=0.1,
16
  max_new_tokens=256,
17
- context_window=16384
18
- )
 
 
 
 
19
  embedding_model = HuggingFaceEmbedding(
20
  model_name="models/all-MiniLM-L6-v2"
21
  )
@@ -34,11 +60,15 @@ def check_if_exists():
34
 
35
  def precompute_index(data_folder='data'):
36
  documents = SimpleDirectoryReader(data_folder).load_data()
37
- index = VectorStoreIndex.from_documents(documents)
 
 
 
38
  index.storage_context.persist(persist_dir='models/precomputed_index')
39
  bm25_retriever = BM25Retriever.from_defaults(
40
- nodes=documents,
41
- similarity_top_k=5
 
42
  )
43
  bm25_retriever.persist("models/bm25_retriever")
44
 
@@ -56,20 +86,21 @@ def answer_question(query):
56
 
57
  retriever = QueryFusionRetriever(
58
  [
59
- index.as_retriever(similarity_top_k=5),
60
  bm25_retriever,
61
  ],
62
  llm=llm,
63
  num_queries=1,
64
  similarity_top_k=5,
 
65
  )
66
- reranker = LLMRerank(
67
- choice_batch_size=5,
68
- top_n=5,
69
  )
70
  keyword_query_engine = RetrieverQueryEngine(
71
  retriever=retriever,
72
- node_postprocessors=[reranker]
73
  )
74
 
75
  if is_harmful(query):
 
9
  from llama_index.core import StorageContext, load_index_from_storage
10
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
  from llama_index.core.postprocessor import LLMRerank
12
+ from llama_index.core.node_parser import TokenTextSplitter
13
+ from transformers import AutoTokenizer
14
+ from llama_index.core.postprocessor import SentenceTransformerRerank
15
+
16
+ _tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
17
+
18
+
19
def messages_to_prompt(messages):
    """Render a sequence of chat messages into a single prompt string.

    Each message object is converted to the ``{"role": ..., "content": ...}``
    dict form that the tokenizer's chat template expects, then the template
    is applied untokenized with the generation prompt appended.

    NOTE(review): assumes each element exposes ``.role.value`` and
    ``.content`` (llama-index ChatMessage-like) — confirm against callers.
    """
    chat = []
    for message in messages:
        chat.append({"role": message.role.value, "content": message.content})
    return _tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
25
+
26
+
27
def completion_to_prompt(completion):
    """Wrap a bare completion string as a single-turn user message and
    render it through the tokenizer's chat template.

    Returns the templated prompt string (not token ids), with the
    generation prompt appended so the model continues as the assistant.
    """
    single_turn = [{"role": "user", "content": completion}]
    return _tokenizer.apply_chat_template(
        single_turn, tokenize=False, add_generation_prompt=True
    )
33
+
34
 
35
  llm = LlamaCPP(
36
  model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
37
  temperature=0.1,
38
  max_new_tokens=256,
39
+ context_window=16384,
40
+ model_kwargs={"n_gpu_layers":-1},
41
+ messages_to_prompt=messages_to_prompt,
42
+ completion_to_prompt=completion_to_prompt)
43
+
44
+
45
  embedding_model = HuggingFaceEmbedding(
46
  model_name="models/all-MiniLM-L6-v2"
47
  )
 
60
 
61
def precompute_index(data_folder='data'):
    """Build and persist the retrieval artifacts for *data_folder*.

    Documents are split into 400-token chunks with a 50-token overlap
    before indexing. Two artifacts are written to disk for later reuse:
    the vector index (``models/precomputed_index``) and the BM25
    retriever (``models/bm25_retriever``).
    """
    docs = SimpleDirectoryReader(data_folder).load_data()
    # Chunk first so the vector index and BM25 retriever share the same nodes.
    splitter = TokenTextSplitter(chunk_size=400, chunk_overlap=50)
    chunks = splitter.get_nodes_from_documents(docs)

    vector_index = VectorStoreIndex(chunks, verbose=True)
    vector_index.storage_context.persist(persist_dir='models/precomputed_index')

    keyword_retriever = BM25Retriever.from_defaults(
        nodes=chunks, similarity_top_k=5, verbose=True
    )
    keyword_retriever.persist("models/bm25_retriever")
74
 
 
86
 
87
  retriever = QueryFusionRetriever(
88
  [
89
+ index.as_retriever(similarity_top_k=5, verbose=True),
90
  bm25_retriever,
91
  ],
92
  llm=llm,
93
  num_queries=1,
94
  similarity_top_k=5,
95
+ verbose=True
96
  )
97
+ reranker = SentenceTransformerRerank(
98
+ model="cross-encoder/ms-marco-MiniLM-L-2-v2",
99
+ top_n=5
100
  )
101
  keyword_query_engine = RetrieverQueryEngine(
102
  retriever=retriever,
103
+ node_postprocessors=[reranker],
104
  )
105
 
106
  if is_harmful(query):