jaindivyansh committed on
Commit 9d526c3 · verified · 1 Parent(s): 22f2c80
Files changed (1)
  1. rag.py +11 -0
rag.py CHANGED
@@ -1,7 +1,9 @@
 # %%
+
 import os
 import json
 
+
 import torch
 import faiss
 import numpy as np
@@ -13,8 +15,10 @@ from transformers import (
     AutoTokenizer,
 )
 
+
 HF_TOKEN = os.environ["hf_token"]
 
+
 SYSTEM_PROMPT = """You are a helpful question answering assistant. You will be given a context and a question. You need to provide the answer to the question based on the context. Answer briefly, based on the context. Only output the answer, and nothing else. Here is an example:
 
 >> Context
@@ -40,12 +44,14 @@ USER_PROMPT = """
 """
 
 
+
 def load_embedder(model_path: str, device: str) -> SentenceTransformer:
     embedder = SentenceTransformer(model_path)
     embedder.to(device)
     return embedder
 
 
+
 def load_contexts(context_file: str) -> list[str]:
     contexts = []
     with open(context_file, "r") as f_in:
@@ -56,10 +62,12 @@ def load_contexts(context_file: str) -> list[str]:
     return contexts
 
 
+
 def load_index(index_file: str) -> faiss.Index:
     return faiss.read_index(index_file)
 
 
+
 def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
     model = AutoModelForCausalLM.from_pretrained(model_path, token=HF_TOKEN)
 
@@ -78,6 +86,7 @@ def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
     return reader
 
 
+
 def construct_prompt(contexts: list[str], question: str) -> list[dict]:
     return [
         {"role": "system", "content": SYSTEM_PROMPT},
@@ -90,6 +99,7 @@ def construct_prompt(contexts: list[str], question: str) -> list[dict]:
     ]
 
 
+
 def load_all(
     embedder_path: str,
     context_file: str,
@@ -110,6 +120,7 @@ def load_all(
     }
 
 
+
 def run_query(
     question: str,
     embedder: SentenceTransformer,
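For reference, the functions touched by this diff compose a standard FAISS-based RAG loop, and a minimal driver might look like the sketch below. Everything not visible in the diff is an assumption: the model names and file paths are placeholders, the retrieval step (encode, FAISS search, prompt construction) is re-implemented inline with standard sentence-transformers and faiss calls rather than by calling run_query (whose full signature is cut off in this hunk), and a transformers version whose text-generation pipeline accepts chat messages is assumed. Note that rag.py reads the hf_token environment variable at import time, so it must be set first.

# Hypothetical driver script; model names and file paths are placeholders.
import numpy as np
import torch

from rag import (
    construct_prompt,
    load_contexts,
    load_embedder,
    load_index,
    load_reader,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

embedder = load_embedder("sentence-transformers/all-MiniLM-L6-v2", device)  # assumed model
contexts = load_contexts("contexts.jsonl")    # assumed path: one JSON record per line
index = load_index("contexts.faiss")          # assumed path to a prebuilt FAISS index
reader = load_reader("meta-llama/Llama-3.2-1B-Instruct", device)            # assumed model

question = "Who wrote On the Origin of Species?"

# Embed the question and retrieve the top-k nearest contexts from the index.
query_emb = embedder.encode([question], convert_to_numpy=True).astype(np.float32)
_, ids = index.search(query_emb, 3)
retrieved = [contexts[i] for i in ids[0]]

# Build the chat-style prompt and generate a short answer.
messages = construct_prompt(retrieved, question)
outputs = reader(messages, max_new_tokens=64, return_full_text=False)
print(outputs[0]["generated_text"])  # output structure can vary across transformers versions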