ccm committed on
Commit
03004bb
·
verified ·
1 Parent(s): c487a2a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -13
main.py CHANGED
@@ -1,5 +1,5 @@
1
  import threading # to allow streaming response
2
- import time # to pave the deliver of the message
3
 
4
  import datasets # for loading RAG database
5
  import faiss # to create a search index
@@ -47,11 +47,11 @@ chat_model = transformers.AutoModelForCausalLM.from_pretrained(
47
 
48
  # Create a FAISS index for fast similarity search
49
  vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
50
- index = faiss.IndexFlatL2(len(data["embedding"][0]))
51
- index.metric_type = faiss.METRIC_INNER_PRODUCT
52
  faiss.normalize_L2(vectors)
53
- index.train(vectors)
54
- index.add(vectors)
55
 
56
 
57
  def preprocess(query: str, k: int) -> tuple[str, str]:
@@ -65,8 +65,8 @@ def preprocess(query: str, k: int) -> tuple[str, str]:
65
  """
66
  encoded_query = numpy.expand_dims(embedding_model.encode(query), axis=0)
67
  faiss.normalize_L2(encoded_query)
68
- D, I = index.search(encoded_query, k)
69
- top_five = data.loc[I[0]]
70
 
71
  print(top_five["text"].values)
72
 
@@ -86,7 +86,6 @@ def preprocess(query: str, k: int) -> tuple[str, str]:
86
  title = top_five["title"].values[i]
87
  id = top_five["id"].values[i]
88
  url = "https://doi.org/10.1115/" + id
89
- path = top_five["path"].values[i]
90
  text = top_five["text"].values[i]
91
 
92
  research_excerpts += (
@@ -104,16 +103,26 @@ def preprocess(query: str, k: int) -> tuple[str, str]:
104
 
105
  print(references)
106
 
107
- return prompt, "\n\n### References\n\n" + "\n".join(
108
  [
109
- str(i + 1)
110
  + ". "
111
- + ref
112
- + "\n - ".join(["", *['"...' + x + '..."' for x in references[ref]]])
113
- for i, ref in enumerate(references.keys())
 
 
 
 
 
 
 
 
114
  ]
115
  )
116
 
 
 
117
 
118
  def postprocess(response: str, bypass_from_preprocessing: str) -> str:
119
  """
 
1
  import threading # to allow streaming response
2
+ import time # to pave the delivery of the message
3
 
4
  import datasets # for loading RAG database
5
  import faiss # to create a search index
 
47
 
48
  # Create a FAISS index for fast similarity search
49
  vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
50
+ excerpt_index = faiss.IndexFlatL2(len(data["embedding"][0]))
51
+ excerpt_index.metric_type = faiss.METRIC_INNER_PRODUCT
52
  faiss.normalize_L2(vectors)
53
+ excerpt_index.train(vectors)
54
+ excerpt_index.add(vectors)
55
 
56
 
57
  def preprocess(query: str, k: int) -> tuple[str, str]:
 
65
  """
66
  encoded_query = numpy.expand_dims(embedding_model.encode(query), axis=0)
67
  faiss.normalize_L2(encoded_query)
68
+ _, indices = excerpt_index.search(encoded_query, k)
69
+ top_five = data.loc[indices[0]]
70
 
71
  print(top_five["text"].values)
72
 
 
86
  title = top_five["title"].values[i]
87
  id = top_five["id"].values[i]
88
  url = "https://doi.org/10.1115/" + id
 
89
  text = top_five["text"].values[i]
90
 
91
  research_excerpts += (
 
103
 
104
  print(references)
105
 
106
+ list_of_references = "\n".join(
107
  [
108
+ str(idx + 1)
109
  + ". "
110
+ + hyperlinked_title
111
+ + "\n\n> ".join(
112
+ [
113
+ "",
114
+ *[
115
+ '"...' + excerpt + '..."'
116
+ for excerpt in references[hyperlinked_title]
117
+ ],
118
+ ]
119
+ )
120
+ for idx, hyperlinked_title in enumerate(references.keys())
121
  ]
122
  )
123
 
124
+ return prompt, "\n\n### References\n\n" + list_of_references
125
+
126
 
127
  def postprocess(response: str, bypass_from_preprocessing: str) -> str:
128
  """