Spaces:
Sleeping
Sleeping
change reranker
Browse files- rag_app/rag_2.py +41 -10
rag_app/rag_2.py
CHANGED
@@ -9,13 +9,39 @@ from llama_index.core.query_engine import RetrieverQueryEngine
|
|
9 |
from llama_index.core import StorageContext, load_index_from_storage
|
10 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
11 |
from llama_index.core.postprocessor import LLMRerank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
llm = LlamaCPP(
|
14 |
model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
|
15 |
temperature=0.1,
|
16 |
max_new_tokens=256,
|
17 |
-
context_window=16384
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
embedding_model = HuggingFaceEmbedding(
|
20 |
model_name="models/all-MiniLM-L6-v2"
|
21 |
)
|
@@ -34,11 +60,15 @@ def check_if_exists():
|
|
34 |
|
35 |
def precompute_index(data_folder='data'):
|
36 |
documents = SimpleDirectoryReader(data_folder).load_data()
|
37 |
-
|
|
|
|
|
|
|
38 |
index.storage_context.persist(persist_dir='models/precomputed_index')
|
39 |
bm25_retriever = BM25Retriever.from_defaults(
|
40 |
-
nodes=
|
41 |
-
similarity_top_k=5
|
|
|
42 |
)
|
43 |
bm25_retriever.persist("models/bm25_retriever")
|
44 |
|
@@ -56,20 +86,21 @@ def answer_question(query):
|
|
56 |
|
57 |
retriever = QueryFusionRetriever(
|
58 |
[
|
59 |
-
index.as_retriever(similarity_top_k=5),
|
60 |
bm25_retriever,
|
61 |
],
|
62 |
llm=llm,
|
63 |
num_queries=1,
|
64 |
similarity_top_k=5,
|
|
|
65 |
)
|
66 |
-
reranker =
|
67 |
-
|
68 |
-
top_n=5
|
69 |
)
|
70 |
keyword_query_engine = RetrieverQueryEngine(
|
71 |
retriever=retriever,
|
72 |
-
node_postprocessors=[reranker]
|
73 |
)
|
74 |
|
75 |
if is_harmful(query):
|
|
|
9 |
from llama_index.core import StorageContext, load_index_from_storage
|
10 |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
11 |
from llama_index.core.postprocessor import LLMRerank
|
12 |
+
from llama_index.core.node_parser import TokenTextSplitter
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
from llama_index.core.postprocessor import SentenceTransformerRerank
|
15 |
+
|
16 |
+
# Tokenizer whose chat template is used by messages_to_prompt and
# completion_to_prompt to format prompts.
# NOTE(review): this loads the Qwen2.5-7B-Instruct tokenizer, while the
# LlamaCPP model configured in this file is a Llama-3.2-1B-Instruct GGUF --
# confirm that formatting Llama prompts with a Qwen chat template is intended.
_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
17 |
+
|
18 |
+
|
19 |
+
def messages_to_prompt(messages):
    """Render a sequence of chat messages into a single prompt string.

    Each message is reduced to a ``{"role": ..., "content": ...}`` dict
    (the role is read from ``message.role.value``) and the resulting list is
    passed to the module-level tokenizer's ``apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``.

    Args:
        messages: Iterable of objects exposing ``role.value`` and ``content``.

    Returns:
        Whatever ``_tokenizer.apply_chat_template`` returns for those
        arguments (the untokenized prompt).
    """
    chat = []
    for message in messages:
        chat.append({"role": message.role.value, "content": message.content})

    return _tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
|
25 |
+
|
26 |
+
|
27 |
+
def completion_to_prompt(completion):
    """Wrap a bare completion string as a single-turn user prompt.

    The text is placed in one ``{"role": "user", ...}`` message and rendered
    through the module-level tokenizer's ``apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``.

    Args:
        completion: The text to send as the user's message.

    Returns:
        Whatever ``_tokenizer.apply_chat_template`` returns for those
        arguments (the untokenized prompt).
    """
    single_turn = [{"role": "user", "content": completion}]
    return _tokenizer.apply_chat_template(
        single_turn,
        tokenize=False,
        add_generation_prompt=True,
    )
|
33 |
+
|
34 |
|
35 |
# LLM backed by a local GGUF file loaded through LlamaCPP. Prompt formatting
# is delegated to the two helper callables defined in this file.
llm = LlamaCPP(
    model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
    # Generation settings.
    temperature=0.1,
    max_new_tokens=256,
    context_window=16384,
    # Forwarded to the underlying llama.cpp model; -1 presumably requests
    # offloading every layer to the GPU -- confirm against llama-cpp-python.
    model_kwargs={"n_gpu_layers": -1},
    # Chat-template based prompt builders.
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)
|
43 |
+
|
44 |
+
|
45 |
# Embedding model loaded from a path inside the local "models" directory.
embedding_model = HuggingFaceEmbedding(model_name="models/all-MiniLM-L6-v2")
|
|
|
60 |
|
61 |
def precompute_index(data_folder='data'):
|
62 |
documents = SimpleDirectoryReader(data_folder).load_data()
|
63 |
+
splitter = TokenTextSplitter(chunk_size=400, chunk_overlap=50)
|
64 |
+
nodes = splitter.get_nodes_from_documents(documents)
|
65 |
+
index = VectorStoreIndex(nodes, verbose=True)
|
66 |
+
# index = VectorStoreIndex.from_documents(documents)
|
67 |
index.storage_context.persist(persist_dir='models/precomputed_index')
|
68 |
bm25_retriever = BM25Retriever.from_defaults(
|
69 |
+
nodes=nodes,
|
70 |
+
similarity_top_k=5,
|
71 |
+
verbose=True
|
72 |
)
|
73 |
bm25_retriever.persist("models/bm25_retriever")
|
74 |
|
|
|
86 |
|
87 |
retriever = QueryFusionRetriever(
|
88 |
[
|
89 |
+
index.as_retriever(similarity_top_k=5, verbose=True),
|
90 |
bm25_retriever,
|
91 |
],
|
92 |
llm=llm,
|
93 |
num_queries=1,
|
94 |
similarity_top_k=5,
|
95 |
+
verbose=True
|
96 |
)
|
97 |
+
reranker = SentenceTransformerRerank(
|
98 |
+
model="cross-encoder/ms-marco-MiniLM-L-2-v2",
|
99 |
+
top_n=5
|
100 |
)
|
101 |
keyword_query_engine = RetrieverQueryEngine(
|
102 |
retriever=retriever,
|
103 |
+
node_postprocessors=[reranker],
|
104 |
)
|
105 |
|
106 |
if is_harmful(query):
|