ccm committed (verified)
Commit 6dbf6d3 · 1 Parent(s): d227ba0

Update main.py

Files changed (1)
  1. main.py +45 -17
main.py CHANGED

@@ -17,18 +17,23 @@ GREETING = (
     "https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the "
     "Design Research Collective. And the best part is that I always cite my sources! What can I tell you about today?"
 )
+EXAMPLE_QUERIES = [
+    "Tell me about new research at the intersection of additive manufacturing and machine learning",
+    "What is a physics-informed neural network and what can it be used for?",
+    "What can agent-based models do about climate change?",
+]
 EMBEDDING_MODEL_NAME = "allenai-specter"
 LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
+# LLM_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
 
 # Load the dataset and convert to pandas
-full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
+data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
 
 # Filter out any publications without an abstract
 abstract_is_null = [
-    '"abstract": null' in json.dumps(bibdict)
-    for bibdict in full_data["bib_dict"].values
+    '"abstract": null' in json.dumps(bibdict) for bibdict in data["bib_dict"].values
 ]
-data = full_data[~pandas.Series(abstract_is_null)]
+data = data[~pandas.Series(abstract_is_null)]
 data.reset_index(inplace=True)
 
 # Create a FAISS index for fast similarity search
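The rewritten filter leans on json.dumps serializing a missing abstract as null. A minimal sketch with two made-up records (hypothetical data, not drawn from ccm/publications) shows what the membership test catches:

import json

import pandas

# Two hypothetical bibliography records: one with an abstract, one without
records = pandas.Series(
    [
        {"title": "A paper", "abstract": "We study..."},
        {"title": "No abstract", "abstract": None},
    ]
)

# json.dumps(None) serializes to null, so only the second record matches
abstract_is_null = ['"abstract": null' in json.dumps(b) for b in records.values]
print(abstract_is_null)  # [False, True]

The test also depends on json.dumps emitting a single space after the colon, which is its default separator.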
@@ -44,8 +49,15 @@ index.add(vectors)
 model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
 
-# Define the search function
-def search(query: str, k: int) -> tuple[str]:
+def search(query: str, k: int) -> tuple[str, str]:
+    """
+    Searches the dataset for the top k most relevant papers to the query
+    Args:
+        query (str): The user's query
+        k (int): The number of results to return
+    Returns:
+        tuple[str, str]: A tuple containing the search results and references
+    """
     query = numpy.expand_dims(model.encode(query), axis=0)
     faiss.normalize_L2(query)
     D, I = index.search(query, k)
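The hunk context line index.add(vectors) points at index-construction code that this commit does not touch. For orientation, here is a plausible reconstruction of that step; it assumes the abstracts are embedded with the same SPECTER model and indexed with an inner-product IndexFlatIP over L2-normalized vectors, matching the faiss.normalize_L2(query) call inside search. Neither assumption is confirmed by this diff.

import faiss
import numpy
import sentence_transformers

model = sentence_transformers.SentenceTransformer("allenai-specter")

# Hypothetical stand-ins for the publication abstracts in the dataset
abstracts = [
    "A machine learning approach to additive manufacturing.",
    "Agent-based modeling of climate adaptation.",
]

# Encode, L2-normalize, and index: inner product on unit vectors is cosine similarity
vectors = numpy.asarray(model.encode(abstracts), dtype="float32")
faiss.normalize_L2(vectors)
index = faiss.IndexFlatIP(vectors.shape[1])
index.add(vectors)

D, I = index.search(vectors[:1], 1)  # the first abstract's nearest neighbor is itself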
@@ -100,20 +112,40 @@ chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
 )
 
 
-def preprocess(message: str) -> tuple[str]:
-    """Applies a preprocessing step to the user's message before the LLM receives it"""
+def preprocess(message: str) -> tuple[str, str]:
+    """
+    Applies a preprocessing step to the user's message before the LLM receives it
+    Args:
+        message (str): The user's message
+    Returns:
+        tuple[str, str]: A tuple containing the preprocessed message and a bypass variable
+    """
     block_search_results, formatted_search_results = search(message, 5)
     return block_search_results + message, formatted_search_results
 
 
 def postprocess(response: str, bypass_from_preprocessing: str) -> str:
-    """Applies a postprocessing step to the LLM's response before the user receives it"""
+    """
+    Applies a postprocessing step to the LLM's response before the user receives it
+    Args:
+        response (str): The LLM's response
+        bypass_from_preprocessing (str): The bypass variable from the preprocessing step
+    Returns:
+        str: The postprocessed response
+    """
     return response + bypass_from_preprocessing
 
 
 @spaces.GPU
-def predict(message: str, history: list[str]) -> str:
-    """This function is responsible for crafting a response"""
+def reply(message: str, history: list[str]) -> str:
+    """
+    This function is responsible for crafting a response
+    Args:
+        message (str): The user's message
+        history (list[str]): The conversation history
+    Returns:
+        str: The AI's response
+    """
 
     # Apply preprocessing
     message, bypass = preprocess(message)
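Read together, preprocess prepends the retrieved abstracts to the user's message, while postprocess re-attaches the formatted references after generation, so the citation text never passes through the model. A schematic of that round trip, with a hypothetical generate stub standing in for the Qwen call inside reply:

def generate(prompt: str) -> str:
    # Hypothetical stand-in for the transformers generation call in reply()
    return "Recent work applies agent-based models to climate policy."

message = "What can agent-based models do about climate change?"
augmented, bypass = preprocess(message)  # retrieved context + original message
response = generate(augmented)           # the LLM sees the context, not the references
print(postprocess(response, bypass))     # references appended back for the user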
@@ -150,12 +182,8 @@ def predict(message: str, history: list[str]) -> str:
 
 # Create and run the gradio interface
 gradio.ChatInterface(
-    predict,
-    examples=[
-        "Tell me about new research at the intersection of additive manufacturing and machine learning",
-        "What is a physics-informed neural network and what can it be used for?",
-        "What can agent-based models do about climate change?",
-    ],
+    reply,
+    examples=EXAMPLE_QUERIES,
     chatbot=gradio.Chatbot(
         show_label=False, show_copy_button=True, value=[["", GREETING]]
     ),
 
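For reference, gradio.ChatInterface calls its first argument with the (message, history) signature that the renamed reply provides, and examples populates the clickable starter prompts. A self-contained sketch of that contract (the echo body is hypothetical):

import gradio

def reply(message: str, history: list[str]) -> str:
    # Hypothetical echo body; the real reply() runs the full RAG pipeline
    return "You said: " + message

gradio.ChatInterface(reply, examples=["Hello there"]).launch()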