AXCXEPT committed on
Commit
93ddbeb
·
verified ·
1 Parent(s): ef99e69

Create modeling_custom_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_custom_qwen.py +421 -0
modeling_custom_qwen.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Qwen2Config, Qwen2ForCausalLM
2
+ import torch
3
+ import requests
4
+ from bs4 import BeautifulSoup
5
+ from duckduckgo_search import DDGS
6
+ import logging
7
+ import re
8
+
9
+ # Logging configuration
10
+ logging.basicConfig(level=logging.INFO)
11
+
12
class CustomQwen2Config(Qwen2Config):
    """Qwen2 configuration registered under the ``custom_qwen2config`` model type.

    The class adds no new configuration fields; it only changes the
    ``model_type`` tag and how a config is loaded and serialised.
    """

    model_type = "custom_qwen2config"

    def __init__(self, **kwargs):
        """Forward every keyword argument unchanged to ``Qwen2Config``."""
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load a configuration from a model name or path.

        The raw dictionary is fetched with ``get_config_dict`` and handed,
        together with the keyword arguments that call returns, straight to
        the parent ``from_dict``.

        NOTE(review): this bypasses the parent ``from_pretrained``;
        presumably intentional so that a stored ``model_type`` different
        from this class's tag is accepted quietly -- confirm.
        """
        raw_config, remaining_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return super().from_dict(raw_config, **remaining_kwargs)

    def to_dict(self):
        """Serialise the configuration, always stamping this class's model type."""
        serialized = super().to_dict()
        serialized.update(model_type=self.model_type)
        return serialized
28
+
29
class CustomQwen2Model(Qwen2ForCausalLM):
    """Qwen2 causal LM whose ``generate`` runs a search-augmented answer pipeline.

    ``generate`` decodes the incoming prompt, optionally looks up web
    snippets through DuckDuckGo, asks the model itself whether the
    collected snippets are sufficient, and then produces the answer with a
    five-step prompt chain (understand, plan, draft, reflect, finalise).
    When nothing useful is found it falls back to plain self-reasoning.

    A tokenizer must be attached with ``set_tokenizer`` before ``generate``
    is called, because every stage works on decoded text.
    """

    config_class = CustomQwen2Config

    # Placeholder document returned by ``retrieve_relevant_information``
    # when no query could be built or the search produced nothing.
    _NO_RESULTS_TEXT = "No relevant information found."
    # Both keyword prompts ask the model for exactly three keywords.
    _MAX_KEYWORDS = 3

    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = None
        self.embedding_model = None  # not used by any method of this class
        self.max_iterations = 5  # maximum number of keyword regeneration rounds
        self.use_search = True  # whether to consult the web at all
        self.top_k = 3  # number of documents kept from each search
        self.max_search_attempts = 3  # searches tried per keyword round

    def set_tokenizer(self, tokenizer=None):
        """Attach the tokenizer used to encode prompts and decode outputs."""
        self.tokenizer = tokenizer

    # Parameter setters
    def set_max_iterations(self, max_iterations):
        """Set the maximum number of keyword regeneration rounds."""
        self.max_iterations = max_iterations

    def set_use_search(self, use_search):
        """Enable or disable the web-search stage of ``generate``."""
        self.use_search = use_search

    def set_top_k(self, top_k):
        """Set how many documents are kept from each search."""
        self.top_k = top_k

    # ------------------------------------------------------------------
    # Low-level generation helpers
    # ------------------------------------------------------------------
    def generate_step(self, input_ids, max_new_tokens=150):
        """Run one plain generation pass of the underlying Qwen2 model.

        Args:
            input_ids: Prompt token ids; moved to the model's device.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The token ids returned by the parent ``generate`` (for this
            decoder-only model: the prompt followed by the continuation).
        """
        input_ids = input_ids.to(self.device)
        # ``super().generate`` deliberately skips this class's own
        # ``generate`` override so the pipeline does not recurse.
        return super().generate(input_ids, max_new_tokens=max_new_tokens)

    def _generate_continuation(self, prompt, max_new_tokens):
        """Return only the text the model generated after ``prompt``.

        The parent ``generate`` echoes the prompt tokens in front of the
        continuation. Slicing them off by length means text that merely
        appears inside the prompt (instructions, user input, documents) can
        never be mistaken for model output.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=max_new_tokens)
        new_token_ids = output_ids[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    def extract_response(self, output_ids, keyword):
        """Return the decoded text that follows ``keyword`` in ``output_ids``.

        Kept for backward compatibility; the pipeline itself now decodes
        only the newly generated tokens instead of searching for a marker.

        Args:
            output_ids: Batch of generated token ids; only row 0 is decoded.
            keyword: Literal marker to look for (first occurrence wins).

        Returns:
            The stripped text after the marker, or ``"[ALL]"`` followed by
            the whole decoded text when the marker does not occur.
        """
        raw_response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        match = re.search(rf"{re.escape(keyword)}\s*(.*)", raw_response, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Marker missing: hand back everything, flagged with "[ALL]".
        return "[ALL]" + raw_response

    # ------------------------------------------------------------------
    # Pipeline entry point
    # ------------------------------------------------------------------
    def generate(self, input_ids, max_new_tokens=150, **kwargs):
        """Answer the prompt in ``input_ids`` using the search-augmented pipeline.

        Args:
            input_ids: Prompt token ids; only the first sequence
                (``input_ids[0]``) is used.
            max_new_tokens: Token budget for each answer-generation step.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The final answer re-encoded with
            ``tokenizer.encode(..., return_tensors="pt")`` on the model's
            device. The prompt tokens are not included.

        Raises:
            ValueError: If no tokenizer has been attached.
        """
        if self.tokenizer is None:
            raise ValueError(
                "No tokenizer is set; call set_tokenizer(tokenizer) before generate()."
            )

        logging.info("Maximum keyword regeneration attempts: %s", self.max_iterations)
        logging.info(
            "External URL reference: %s", "Enabled" if self.use_search else "Disabled"
        )
        logging.info("k_top value: %s", self.top_k)

        question = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # Without search there is nothing to iterate over, so go straight
        # to self-reasoning instead of spinning through empty rounds.
        summarized_info = self._gather_information(question) if self.use_search else None

        if summarized_info is None:
            logging.info("Relevant data sources not found. Performing self-reasoning.")
            final_response = self.self_reasoning(question, max_new_tokens)
        else:
            # Answer the user's actual question, not the keyword-augmented
            # instruction that was only meant to steer the search.
            final_response = self.generate_answer(question, summarized_info, max_new_tokens)

        return self.tokenizer.encode(final_response, return_tensors="pt").to(self.device)

    def _gather_information(self, question):
        """Search the web until the model judges the findings sufficient.

        Each round performs up to ``max_search_attempts`` searches. If none
        is sufficient, new keywords are generated and appended to the search
        instruction for the next round, for at most ``max_iterations``
        rounds.

        Returns:
            The summarised findings, or ``None`` when no round produced
            sufficient information.
        """
        instruction = question
        for keyword_round in range(1, self.max_iterations + 1):
            logging.info(
                "Keyword regeneration attempt: %s/%s", keyword_round, self.max_iterations
            )
            logging.info("Retrieving relevant information using external URL references...")

            for search_attempt in range(1, self.max_search_attempts + 1):
                logging.info(
                    "Search attempt: %s/%s", search_attempt, self.max_search_attempts
                )
                relevant_docs = self.retrieve_relevant_information(
                    instruction, top_k=self.top_k
                )
                summarized_info = self.summarize_documents(relevant_docs, instruction)
                if self.is_answer_sufficient(summarized_info, instruction):
                    logging.info("Sufficient information found.")
                    return summarized_info
                logging.info("Insufficient information. Attempting next search.")

            new_keywords = self.generate_new_keywords(instruction)
            if not new_keywords:
                logging.warning("Failed to generate new keywords.")
                break
            instruction = self.update_instruction_with_new_keywords(instruction, new_keywords)
            logging.info("Retrying search with new keywords: %s", new_keywords)
        return None

    # ------------------------------------------------------------------
    # Web search
    # ------------------------------------------------------------------
    def retrieve_relevant_information(self, user_input, top_k=3):
        """Search DuckDuckGo for snippets related to ``user_input``.

        Args:
            user_input: Text from which the search query is generated.
            top_k: Maximum number of snippets to return.

        Returns:
            Up to ``top_k`` result bodies/snippets. When no query could be
            built, the search failed, or it returned nothing, a one-element
            list holding ``"No relevant information found."``.
        """
        search_query = self.generate_search_query(user_input)
        if not search_query:
            logging.warning("Search query is empty.")
            return [self._NO_RESULTS_TEXT]

        try:
            with DDGS() as ddgs:
                search_results = list(
                    ddgs.text(
                        keywords=search_query,
                        region='wt-wt',
                        safesearch='off',
                        timelimit=None,
                        max_results=20,
                    )
                )
        except Exception:
            # A failed lookup (network error, rate limit, ...) must not
            # abort the whole answer; the caller falls back to other rounds
            # or to self-reasoning.
            logging.exception("Web search failed for query: %s", search_query)
            return [self._NO_RESULTS_TEXT]

        if not search_results:
            return [self._NO_RESULTS_TEXT]

        # Prefer the 'body' field and fall back to 'snippet'; drop results
        # that carry neither.
        documents = []
        for result in search_results:
            text = result.get('body') or result.get('snippet')
            if text:
                documents.append(text)
        return documents[:top_k]

    @staticmethod
    def _parse_keywords(continuation, limit):
        """Turn the model's bullet-list continuation into a search query.

        Args:
            continuation: Text generated after a prompt that ends with "-".
            limit: Maximum number of keywords to keep.

        Returns:
            The keywords joined by single spaces, or ``""`` if none found.
        """
        keywords = []
        # The prompt already supplied the first bullet's dash, so restore
        # it before reading the list line by line.
        for raw_line in ("-" + continuation).splitlines():
            line = raw_line.strip()
            if not line:
                continue
            if not line.startswith("-"):
                # First non-bullet line: the list is over and anything
                # after it is unrelated text the model kept generating.
                break
            keyword = line.lstrip("-").strip()
            if keyword:
                keywords.append(keyword)
            if len(keywords) >= limit:
                break
        return " ".join(keywords)

    def _request_keywords(self, prompt):
        """Ask the model for keywords with ``prompt`` and parse its bullet list."""
        continuation = self._generate_continuation(prompt, max_new_tokens=50)
        return self._parse_keywords(continuation, self._MAX_KEYWORDS)

    def generate_search_query(self, user_input):
        """Generate a search query (space-joined keywords) for ``user_input``.

        Returns:
            The query string, or ``""`` when no keywords could be extracted.
        """
        prompt = f"""
User's question:
{user_input}

Organize what you need to know to answer this problem and list three keywords to research.

Keywords:
-"""
        search_query = self._request_keywords(prompt)
        if search_query:
            logging.info("Generated search query: %s", search_query)
        else:
            logging.warning("Failed to generate keywords.")
        return search_query

    def generate_new_keywords(self, user_input):
        """Ask the model for a fresh set of keywords after a failed round.

        Returns:
            The new keywords joined by spaces, or ``""`` on failure.
        """
        prompt = f"""
User's question:
{user_input}

Insufficient information was obtained. Please generate new keywords.
List three new keywords.

Keywords:
-"""
        search_query = self._request_keywords(prompt)
        if search_query:
            logging.info("Regenerated search query: %s", search_query)
        else:
            logging.warning("Failed to extract regenerated keywords.")
        return search_query

    def update_instruction_with_new_keywords(self, instruction, new_keywords):
        """Return ``instruction`` with ``" Keywords: <new_keywords>"`` appended."""
        return f"{instruction} Keywords: {new_keywords}"

    def is_answer_sufficient(self, summarized_info, user_input):
        """Ask the model whether ``summarized_info`` suffices to answer.

        Only the generated continuation is inspected: the prompt itself
        contains the word "Yes", so looking at the echoed prompt would make
        this check succeed unconditionally.

        Returns:
            ``True`` if the model's reply contains "Yes". ``False``
            otherwise, including when there is no information to judge.
        """
        if not summarized_info or not summarized_info.strip():
            return False

        prompt = f"""
User's question:
{user_input}

Retrieved information:
{summarized_info}

Based on this information, determine if you can answer the user's question.
If yes, respond with "Yes". If no, respond with "No" only.
"""
        verdict = self._generate_continuation(prompt, max_new_tokens=10)
        return "Yes" in verdict

    # ------------------------------------------------------------------
    # Answer generation
    # ------------------------------------------------------------------
    def generate_answer(self, user_input, summarized_info, max_new_tokens=150):
        """Produce the final answer from the retrieved information.

        Runs five chained prompts: understand the question, plan, draft an
        initial answer, reflect on it, and write the final answer.

        Args:
            user_input: The user's question.
            summarized_info: Summarised web findings to ground the answer.
            max_new_tokens: Token budget for each of the five steps.

        Returns:
            The final answer text.
        """
        # Step 1: Understanding the question and extracting key points
        step1_prompt = f"""
#User's question:
{user_input}

#Step 1: Understanding the question and extracting key points
Accurately understand the user's question or instructions.
Output the rules for answering and the tasks to be performed in a bullet list.

#Rules for answering and tasks to be performed:
"""
        step1_response = self._generate_continuation(step1_prompt, max_new_tokens)
        logging.info(
            "Understanding the question...\n======================\n%s", step1_response
        )

        # Step 2: Considerations for problem-solving
        step2_prompt = """
#Step 2: Considerations for problem-solving
Based on the content of Step 1, consider approaches and necessary information for solving the problem.

#Step 2 response:
"""
        step2_response = self._generate_continuation(
            step1_response + step2_prompt, max_new_tokens
        )
        logging.info(
            "Considering approaches for problem-solving...\n======================\n%s",
            step2_response,
        )

        # Step 3: Creating the initial answer
        step3_prompt = f"""
#Step 3: Creating the initial answer
Based on the content so far, create an initial answer to the user's question.
Your information may not be up-to-date. Fully consider information from the internet.

#Latest internet information:
{summarized_info}

#Initial answer:
"""
        step3_response = self._generate_continuation(
            step2_response + step3_prompt, max_new_tokens
        )
        logging.info(
            "Creating the initial answer...\n======================\n%s", step3_response
        )

        # Step 4: Reflection (Self-verification)
        reflection_prompt = f"""
#Step 4: Reflection (Self-verification)
Verify whether the initial answer accurately responds to the user's question or instructions, and point out any errors or areas for improvement.
Be cautious of overinterpreting the instructions and critically assess whether you have accurately understood them.
Your information may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question and provide an accurate answer to the question itself. (Ensure that you provide an answer to the question itself)

#User's question:
{user_input}

#Latest internet information:
{summarized_info}

#Initial answer:
{step3_response}

#Reflection result:
"""
        reflection_response = self._generate_continuation(reflection_prompt, max_new_tokens)
        logging.info(
            "Performing reflection...\n======================\n%s", reflection_response
        )

        # Step 5: Creating the final answer
        final_prompt = f"""
#Step 5: Creating the final answer
Based on the reflection results, modify the initial answer as needed.
Your knowledge may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question, and check for overinterpretation, misunderstandings, omissions, and careless mistakes.
Create the final answer incorporating these.

#Initial answer:
{step3_response}

#Reflection result:
{reflection_response}

#Latest internet information:
{summarized_info}

#User's question:
{user_input}

Please provide the final answer to the user's question.
#Final answer:
"""
        return self._generate_continuation(final_prompt, max_new_tokens)

    def self_reasoning(self, user_input, max_new_tokens=150):
        """Answer ``user_input`` from the model's own knowledge, without web data.

        Returns:
            The generated answer text.
        """
        prompt = f"""
User's question:
{user_input}

No relevant information was found on the internet. Please use your own knowledge and reasoning to answer.

#Answer based on self-reasoning:
"""
        generated_text = self._generate_continuation(prompt, max_new_tokens)
        logging.info(
            "Answer based on self-reasoning:\n======================\n%s", generated_text
        )
        return generated_text

    # ------------------------------------------------------------------
    # Document processing
    # ------------------------------------------------------------------
    def process_document(self, doc, user_input):
        """Draft an answer to ``user_input`` based on a single document.

        Args:
            doc: Document text; only its first 2000 characters are used.
            user_input: The question the document should help answer.

        Returns:
            The generated answer, or ``""`` when the model's reply contains
            the phrase "low relevance".
        """
        excerpt = doc[:2000]  # truncate long documents to keep the prompt small
        prompt = f"""
User's question:
{user_input}

Content of the document:
{excerpt}

Do not think of the question superficially. Use paradoxes and rephrasing to organize.
Create an answer to the question based on the content of this document.
Understand the points of disagreement between your own thoughts and the answer you would create based on this document, and prioritize the answer based on the document.

Answer:
"""
        generated_text = self._generate_continuation(prompt, max_new_tokens=500)
        logging.info("Document processing result: %s", generated_text)
        if "low relevance" in generated_text:
            return ""
        return generated_text

    def summarize_documents(self, documents, user_input):
        """Process every document and join the non-empty results.

        The "No relevant information found." placeholder is skipped, so it
        is never presented to the model as if it were a real document.

        Returns:
            The per-document answers separated by blank lines, or ``""``
            when nothing usable was produced.
        """
        summaries = []
        for doc in documents:
            if doc == self._NO_RESULTS_TEXT:
                continue
            processed_text = self.process_document(doc, user_input)
            if processed_text:
                summaries.append(processed_text)
        return "\n\n".join(summaries)