Enoch commited on
Commit
bdaff92
·
1 Parent(s): 0780b55

更新了应用程序代码

Browse files
Files changed (1) hide show
  1. app.py +93 -134
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
 
2
  import gradio as gr
 
3
  import tempfile
4
  import atexit
5
- from dotenv import load_dotenv
6
- from openai import OpenAI
7
 
8
  # 加载环境变量
9
  load_dotenv()
@@ -17,7 +17,6 @@ client = OpenAI(
17
  def call_openai_api(prompt, temperature=0.7):
18
  """调用OpenAI API生成内容"""
19
  try:
20
- print(f"开始调用API,prompt长度: {len(prompt)}") # 添加日志
21
  chat_completion = client.chat.completions.create(
22
  messages=[
23
  {
@@ -34,11 +33,10 @@ def call_openai_api(prompt, temperature=0.7):
34
  max_tokens=3000,
35
  n=1
36
  )
37
- print("API调用成功") # 添加日志
38
  return chat_completion.choices[0].message.content.strip()
39
  except Exception as e:
40
- print(f"API调用详细错误:{type(e).__name__}: {str(e)}") # 增强错误信息
41
- return f"生成失败:{type(e).__name__}: {str(e)}"
42
 
43
  # 定义每个部分的Agent Prompt模板
44
  BACKGROUND_PROMPT = """
@@ -159,133 +157,93 @@ ALTERNATIVE_PROMPT = """
159
  {user_input}
160
  """
161
 
162
-
163
- # 验证函数示例
164
- def validate_content(content, expected_keywords=None):
165
- """
166
- 自动化验证内容是否合理。
167
- - 检查内容是否包含预期的关键词。
168
- - 确保内容长度合理。
169
- - 其他自定义验证规则。
170
- """
171
- if expected_keywords:
172
- for keyword in expected_keywords:
173
- if keyword.lower() not in content.lower():
174
- raise ValidationError(f"内容缺少关键词: {keyword}")
175
- if len(content) < 100:
176
- raise ValidationError("内容过短,可能不够详实。")
177
- # 添加更多验证规则
178
- return True
179
-
180
- class ValidationError(Exception):
181
- """自定义验证错误"""
182
- def __init__(self, message):
183
- self.message = message
184
- super().__init__(self.message)
185
-
186
- def get_retry_prompt(self):
187
- return f"请根据反馈修改内容: {self.message}"
188
-
189
- def generate_with_hitl(prompt_template, user_input, previous_contents=None, max_retries=3):
190
- """通用的生成函数,包含HITL流程"""
191
- print(f"开始生成内容,用户输入长度: {len(user_input)}") # 添加日志
192
- retry_prompt = ""
193
- retries = 0
194
- success = False
195
- final_content = ""
196
-
197
- while retries < max_retries and not success:
198
- try:
199
- # 生成内容
200
- if previous_contents:
201
- prompt = prompt_template.format(
202
- user_input=user_input,
203
- **previous_contents
204
- )
205
- else:
206
- prompt = prompt_template.format(user_input=user_input)
207
-
208
- print(f"第{retries + 1}次尝试生成") # 添加日志
209
- content = call_openai_api(prompt)
210
-
211
- # 检查返回内容是否包含错误信息
212
- if content.startswith("生成失败"):
213
- print(f"生成返回错误: {content}") # 添加日志
214
- raise Exception(content)
215
-
216
- # 自动化验证
217
- validate_content(content, expected_keywords=["技术", "方案"])
218
- success = True
219
- final_content = content
220
- print("内容生成成功") # 添加日志
221
-
222
- except ValidationError as e:
223
- print(f"验证失败: {str(e)}") # 添加日志
224
- retry_prompt = e.get_retry_prompt()
225
- retries += 1
226
- except Exception as e:
227
- print(f"生成过程中发生错误: {type(e).__name__}: {str(e)}") # 增强错误信息
228
- retries += 1
229
-
230
- if not success:
231
- final_content = "生成失败,请检查输入或稍后重试。"
232
-
233
- return final_content
234
- def generate_patent_document(bg_input, sh_input, pr_input, so_input, kp_input, adv_input, alt_input, progress=gr.Progress()):
235
- """
236
- 生成完整的专利交底书,包含HITL流程的每个模块。
237
- """
238
- final_document = ""
239
  try:
240
- # 生成第一部分:背景技术
241
- progress(0.14, desc="生成背景技术部分...")
242
- background_content = generate_with_hitl(BACKGROUND_PROMPT, bg_input)
243
- final_document += f"{background_content}\n\n"
244
-
245
- # 生成第二部分:现有技术缺点
246
- progress(0.28, desc="生成现有技术缺点部分...")
247
- shortcoming_content = generate_with_hitl(SHORTCOMING_PROMPT, sh_input, {"background_content": background_content})
248
- final_document += f"{shortcoming_content}\n\n"
249
-
250
- # 生成第三部分:技术问题
251
- progress(0.42, desc="生成技术问题部分...")
252
- problem_content = generate_with_hitl(PROBLEM_PROMPT, pr_input, {
253
- "background_content": background_content,
254
- "shortcoming_content": shortcoming_content
255
- })
256
- final_document += f"{problem_content}\n\n"
257
-
258
- # 生成第四部分:技术方案
259
- progress(0.56, desc="生成技术方案部分...")
260
- solution_content = generate_with_hitl(SOLUTION_PROMPT, so_input, {
261
- "background_content": background_content,
262
- "shortcoming_content": shortcoming_content,
263
- "problem_content": problem_content
264
- })
265
- final_document += f"{solution_content}\n\n"
266
-
267
- # 生成第五部分:关键点
268
- progress(0.70, desc="生成关键点部分...")
269
- keypoint_content = generate_with_hitl(KEYPOINT_PROMPT, kp_input, {"solution_content": solution_content})
270
- final_document += f"{keypoint_content}\n\n"
271
-
272
- # 生成第六部分:优点
273
- progress(0.84, desc="生成优点部分...")
274
- advantage_content = generate_with_hitl(ADVANTAGE_PROMPT, adv_input, {
275
- "shortcoming_content": shortcoming_content,
276
- "problem_content": problem_content,
277
- "solution_content": solution_content
278
- })
279
- final_document += f"{advantage_content}\n\n"
280
-
281
- # 生成第七部分:替代方案
282
- progress(1.0, desc="生成替代方案部分...")
283
- alternative_content = generate_with_hitl(ALTERNATIVE_PROMPT, alt_input, {"solution_content": solution_content})
284
- final_document += f"{alternative_content}"
285
-
286
- return final_document
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  except Exception as e:
288
- return f"生成过程中发生错误:{str(e)}"
 
289
 
290
  def generate_filename(text):
291
  """根据专利交底书内容生成合适的文件名"""
@@ -425,18 +383,19 @@ with gr.Blocks(theme=gr.themes.Soft(
425
  outputs=[download_file]
426
  )
427
 
428
- # 设置生成按钮行为
429
  generate_button.click(
430
  fn=generate_patent_document,
431
  inputs=[bg, sh, pr, so, kp, adv, alt],
432
- outputs=[final_output]
433
  ).then(
434
  # 修改 lambda 函数,添加输入参数
435
  fn=lambda x: gr.update(visible=True),
436
- inputs=[final_output],
437
  outputs=[download_file]
438
  )
439
 
 
440
  clear_button.click(
441
  fn=clear_all,
442
  inputs=[],
 
1
  import os
2
+ from openai import OpenAI
3
  import gradio as gr
4
+ from dotenv import load_dotenv
5
  import tempfile
6
  import atexit
 
 
7
 
8
  # 加载环境变量
9
  load_dotenv()
 
17
  def call_openai_api(prompt, temperature=0.7):
18
  """调用OpenAI API生成内容"""
19
  try:
 
20
  chat_completion = client.chat.completions.create(
21
  messages=[
22
  {
 
33
  max_tokens=3000,
34
  n=1
35
  )
 
36
  return chat_completion.choices[0].message.content.strip()
37
  except Exception as e:
38
+ print(f"API调用出错:{str(e)}")
39
+ return f"生成失败:{str(e)}"
40
 
41
  # 定义每个部分的Agent Prompt模板
42
  BACKGROUND_PROMPT = """
 
157
  {user_input}
158
  """
159
 
160
+ def generate_patent_document(
161
+ bg_input,
162
+ shortcoming_input,
163
+ problem_input,
164
+ solution_input,
165
+ keypoint_input,
166
+ advantage_input,
167
+ alternative_input,
168
+ progress=gr.Progress() # 使用progress参数
169
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  try:
171
+ # 依次生成各部分,并在每个阶段更新进度
172
+ progress(0, desc="正在生成背景技术部分...")
173
+ yield "111正在生成背景技术部分...", None # 更新状态文本框
174
+ bg_prompt = BACKGROUND_PROMPT.format(user_input=bg_input)
175
+ background_content = call_openai_api(bg_prompt)
176
+
177
+ progress(0.15, desc="正在生成现有技术缺点部分...")
178
+ yield "111正在生成现有技术缺点部分...", None
179
+ short_prompt = SHORTCOMING_PROMPT.format(
180
+ previous_content=background_content,
181
+ user_input=shortcoming_input
182
+ )
183
+ shortcoming_content = call_openai_api(short_prompt)
184
+
185
+ progress(0.3, desc="正在生成技术问题部分...")
186
+ yield "111正在生成技术问题部分...", None
187
+ problem_prompt_full = PROBLEM_PROMPT.format(
188
+ background_content=background_content,
189
+ shortcoming_content=shortcoming_content,
190
+ user_input=problem_input
191
+ )
192
+ problem_content = call_openai_api(problem_prompt_full)
193
+
194
+ progress(0.45, desc="正在生成技术方案部分...")
195
+ yield "111正在生成技术方案部分...", None
196
+ solution_prompt_full = SOLUTION_PROMPT.format(
197
+ background_content=background_content,
198
+ shortcoming_content=shortcoming_content,
199
+ problem_content=problem_content,
200
+ user_input=solution_input
201
+ )
202
+ solution_content = call_openai_api(solution_prompt_full)
203
+
204
+ progress(0.6, desc="正在生成关键点部分...")
205
+ yield "111正在生成关键点部分...", None
206
+ keypoint_prompt_full = KEYPOINT_PROMPT.format(
207
+ solution_content=solution_content,
208
+ user_input=keypoint_input
209
+ )
210
+ keypoint_content = call_openai_api(keypoint_prompt_full)
211
+
212
+ progress(0.75, desc="正在生成优点部分...")
213
+ yield "正在生成优点部分...", None
214
+ advantage_prompt_full = ADVANTAGE_PROMPT.format(
215
+ shortcoming_content=shortcoming_content,
216
+ problem_content=problem_content,
217
+ solution_content=solution_content,
218
+ user_input=advantage_input
219
+ )
220
+ advantage_content = call_openai_api(advantage_prompt_full)
221
+
222
+ progress(0.9, desc="正在生成替代方案部分...")
223
+ yield "正在生成替代方案部分...", None
224
+ alternative_prompt_full = ALTERNATIVE_PROMPT.format(
225
+ solution_content=solution_content,
226
+ user_input=alternative_input
227
+ )
228
+ alternative_content = call_openai_api(alternative_prompt_full)
229
+
230
+
231
+ final_document = (
232
+ f"{background_content}\n\n"
233
+ f"{shortcoming_content}\n\n"
234
+ f"{problem_content}\n\n"
235
+ f"{solution_content}\n\n"
236
+ f"{keypoint_content}\n\n"
237
+ f"{advantage_content}\n\n"
238
+ f"{alternative_content}"
239
+ )
240
+
241
+ progress(1.0, desc="生成完成!")
242
+ yield "生成完成!", final_document
243
+
244
  except Exception as e:
245
+ error_message = f"生成过程中发生错误:{str(e)}"
246
+ yield error_message, error_message
247
 
248
  def generate_filename(text):
249
  """根据专利交底书内容生成合适的文件名"""
 
383
  outputs=[download_file]
384
  )
385
 
386
+ # 不使用status和show_progress,仅依赖progress参数来显示进度条
387
  generate_button.click(
388
  fn=generate_patent_document,
389
  inputs=[bg, sh, pr, so, kp, adv, alt],
390
+ outputs=[status_box, final_output]
391
  ).then(
392
  # 修改 lambda 函数,添加输入参数
393
  fn=lambda x: gr.update(visible=True),
394
+ inputs=[final_output], # 添加输入
395
  outputs=[download_file]
396
  )
397
 
398
+
399
  clear_button.click(
400
  fn=clear_all,
401
  inputs=[],