dangthr committed on
Commit
da3cf12
·
verified ·
1 Parent(s): 4eb64b8

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +41 -87
inference.py CHANGED
@@ -1,5 +1,6 @@
1
  # inference.py
2
  import os
 
3
  import argparse
4
  import random
5
  import json
@@ -14,7 +15,7 @@ from PIL import Image
14
  from diffusers import QwenImageEditPipeline
15
 
16
  # --- 从原脚本保留的辅助函数 ---
17
-
18
  SYSTEM_PROMPT = '''
19
  # Edit Instruction Rewriter
20
  You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.  
@@ -80,9 +81,8 @@ def polish_prompt(prompt, img):
80
  if not os.environ.get('DASH_API_KEY'):
81
  print("[警告] 环境变量 DASH_API_KEY 未设置,将跳过提示词重写。")
82
  return prompt
83
-
84
  full_prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
85
- for attempt in range(3): # 最多重试3次
86
  try:
87
  result = api(full_prompt, [img])
88
  if isinstance(result, str):
@@ -90,12 +90,10 @@ def polish_prompt(prompt, img):
90
  result_data = json.loads(result_json_str)
91
  else:
92
  result_data = json.loads(result)
93
-
94
  polished = result_data['Rewritten']
95
  return polished.strip().replace("\n", " ")
96
  except Exception as e:
97
  print(f"[警告] API调用失败 (尝试 {attempt + 1}): {e}")
98
-
99
  print("[错误] 多次尝试后提示词重写失败,将使用原始提示词。")
100
  return prompt
101
 
@@ -111,23 +109,11 @@ def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
111
  api_key = os.environ.get('DASH_API_KEY')
112
  if not api_key:
113
  raise EnvironmentError("DASH_API_KEY is not set")
114
-
115
- messages = [
116
- {"role": "system", "content": "you are a helpful assistant, you should provide useful answers to users."},
117
- {"role": "user", "content": []}
118
- ]
119
  for img in img_list:
120
  messages[1]["content"].append({"image": f"data:image/png;base64,{encode_image(img)}"})
121
  messages[1]["content"].append({"text": f"{prompt}"})
122
-
123
- response = dashscope.MultiModalConversation.call(
124
- api_key=api_key,
125
- model=model,
126
- messages=messages,
127
- result_format='message',
128
- response_format=kwargs.get('response_format', None),
129
- )
130
-
131
  if response.status_code == 200:
132
  return response.output.choices[0].message.content[0]['text']
133
  else:
@@ -148,113 +134,81 @@ def load_image(image_path):
148
  print(f" 详细信息: {e}")
149
  return None
150
 
151
- # --- 主推理逻辑 ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
 
153
  def main(args):
154
  """执行模型推理的主函数"""
155
  output_dir = "output"
156
  os.makedirs(output_dir, exist_ok=True)
157
-
158
  dtype = torch.bfloat16
159
  device = "cuda" if torch.cuda.is_available() else "cpu"
160
  print(f"使用设备: {device}")
161
-
162
  print("正在加载 Qwen-Image-Edit 模型...")
163
  try:
164
  pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
165
  print("模型加载完成。")
166
  except Exception as e:
167
- print(f"❌ 错误:模型加载失败。请检查网络连接和依赖项。")
168
  print(f" 详细信息: {e}")
169
  return
170
-
171
  print(f"正在从 '{args.input_image}' 加载输入图片...")
172
  input_image = load_image(args.input_image)
173
  if input_image is None:
174
  return
175
-
176
- # 设置随机种子
177
  seed = random.randint(0, np.iinfo(np.int32).max) if args.random_seed else args.seed
178
  generator = torch.Generator(device=device).manual_seed(seed)
179
-
180
- # 如果不禁用重写功能,则调用 polish_prompt
181
  prompt_to_use = polish_prompt(args.prompt, input_image) if not args.no_rewrite else args.prompt
182
-
183
  if not args.no_rewrite:
184
  print(f"重写后的提示词: '{prompt_to_use}'")
185
-
186
  print("-" * 30)
187
  print("🚀 开始推理...")
188
  print(f" - 提示词: '{prompt_to_use}'")
189
  print(f" - 随机种子: {seed}")
190
  print(f" - 推理步数: {args.steps}")
191
- print(f" -引导系数 (Guidance Scale): {args.guidance_scale}")
192
  print("-" * 30)
193
-
194
  try:
195
- images = pipe(
196
- image=input_image,
197
- prompt=prompt_to_use,
198
- negative_prompt=" ", # 固定负向提示词
199
- num_inference_steps=args.steps,
200
- generator=generator,
201
- true_cfg_scale=args.guidance_scale,
202
- num_images_per_prompt=1
203
- ).images
204
-
205
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
206
  output_path = os.path.join(output_dir, f"output_{timestamp}_{seed}.png")
207
  images[0].save(output_path)
208
  print(f"✅ 推理成功!图片已保存至: {output_path}")
209
-
210
  except Exception as e:
211
  print(f"❌ 推理过程中发生错误: {e}")
212
 
213
  # --- 命令行接口 ---
214
-
215
  if __name__ == "__main__":
216
- parser = argparse.ArgumentParser(description="Qwen 图像编辑命令行工具")
217
-
218
- parser.add_argument(
219
- "--prompt",
220
- type=str,
221
- required=True,
222
- help="必须:用于编辑图像的指令。"
223
- )
224
- parser.add_argument(
225
- "--input_image",
226
- type=str,
227
- required=True,
228
- help="必须:输入图片的本地路径或URL链接。"
229
- )
230
- parser.add_argument(
231
- "--seed",
232
- type=int,
233
- default=42,
234
- help="用于复现结果的随机种子,默认为 42。"
235
- )
236
- parser.add_argument(
237
- "--random_seed",
238
- action="store_true",
239
- help="如果设置此项,则使用一个随机种子。"
240
- )
241
- parser.add_argument(
242
- "--steps",
243
- type=int,
244
- default=50,
245
- help="推理步数,默认为 50。"
246
- )
247
- parser.add_argument(
248
- "--guidance_scale",
249
- type=float,
250
- default=4.0,
251
- help="引导系数 (CFG scale),默认为 4.0。"
252
- )
253
- parser.add_argument(
254
- "--no_rewrite",
255
- action="store_true",
256
- help="如果设置此项,则禁用提示词重写功能。"
257
- )
258
-
259
  args = parser.parse_args()
260
  main(args)
 
1
  # inference.py
2
  import os
3
+ import sys # 导入 sys 模块
4
  import argparse
5
  import random
6
  import json
 
15
  from diffusers import QwenImageEditPipeline
16
 
17
  # --- 从原脚本保留的辅助函数 ---
18
+ # SYSTEM_PROMPT, polish_prompt, encode_image, api 函数保持不变...
19
  SYSTEM_PROMPT = '''
20
  # Edit Instruction Rewriter
21
  You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.  
 
81
  if not os.environ.get('DASH_API_KEY'):
82
  print("[警告] 环境变量 DASH_API_KEY 未设置,将跳过提示词重写。")
83
  return prompt
 
84
  full_prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
85
+ for attempt in range(3):
86
  try:
87
  result = api(full_prompt, [img])
88
  if isinstance(result, str):
 
90
  result_data = json.loads(result_json_str)
91
  else:
92
  result_data = json.loads(result)
 
93
  polished = result_data['Rewritten']
94
  return polished.strip().replace("\n", " ")
95
  except Exception as e:
96
  print(f"[警告] API调用失败 (尝试 {attempt + 1}): {e}")
 
97
  print("[错误] 多次尝试后提示词重写失败,将使用原始提示词。")
98
  return prompt
99
 
 
109
  api_key = os.environ.get('DASH_API_KEY')
110
  if not api_key:
111
  raise EnvironmentError("DASH_API_KEY is not set")
112
+ messages = [{"role": "system", "content": "you are a helpful assistant, you should provide useful answers to users."},{"role": "user", "content": []}]
 
 
 
 
113
  for img in img_list:
114
  messages[1]["content"].append({"image": f"data:image/png;base64,{encode_image(img)}"})
115
  messages[1]["content"].append({"text": f"{prompt}"})
116
+ response = dashscope.MultiModalConversation.call(api_key=api_key,model=model,messages=messages,result_format='message',response_format=kwargs.get('response_format', None),)
 
 
 
 
 
 
 
 
117
  if response.status_code == 200:
118
  return response.output.choices[0].message.content[0]['text']
119
  else:
 
134
  print(f" 详细信息: {e}")
135
  return None
136
 
137
+ def prepare_model():
138
+ """仅下载并缓存模型,不执行推理"""
139
+ print("正在准备模型... 如果是首次运行,将开始下载模型文件(约10GB)。")
140
+ print("请耐心等待,下载速度取决于您的网络状况。")
141
+ dtype = torch.bfloat16
142
+ try:
143
+ QwenImageEditPipeline.from_pretrained(
144
+ "Qwen/Qwen-Image-Edit",
145
+ torch_dtype=dtype,
146
+ low_cpu_mem_usage=True # 优化内存使用
147
+ )
148
+ print("\n✅ 模型文件已成功准备(下载/加载)到本地缓存。")
149
+ return True
150
+ except Exception as e:
151
+ print(f"\n❌ 错误:模型下载或加载失败。请检查网络连接或磁盘空间。")
152
+ print(f" 详细信息: {e}")
153
+ return False
154
 
155
+ # --- 主推理逻辑 ---
156
  def main(args):
157
  """执行模型推理的主函数"""
158
  output_dir = "output"
159
  os.makedirs(output_dir, exist_ok=True)
 
160
  dtype = torch.bfloat16
161
  device = "cuda" if torch.cuda.is_available() else "cpu"
162
  print(f"使用设备: {device}")
 
163
  print("正在加载 Qwen-Image-Edit 模型...")
164
  try:
165
  pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
166
  print("模型加载完成。")
167
  except Exception as e:
168
+ print(f"❌ 错误:模型加载失败。")
169
  print(f" 详细信息: {e}")
170
  return
 
171
  print(f"正在从 '{args.input_image}' 加载输入图片...")
172
  input_image = load_image(args.input_image)
173
  if input_image is None:
174
  return
 
 
175
  seed = random.randint(0, np.iinfo(np.int32).max) if args.random_seed else args.seed
176
  generator = torch.Generator(device=device).manual_seed(seed)
 
 
177
  prompt_to_use = polish_prompt(args.prompt, input_image) if not args.no_rewrite else args.prompt
 
178
  if not args.no_rewrite:
179
  print(f"重写后的提示词: '{prompt_to_use}'")
 
180
  print("-" * 30)
181
  print("🚀 开始推理...")
182
  print(f" - 提示词: '{prompt_to_use}'")
183
  print(f" - 随机种子: {seed}")
184
  print(f" - 推理步数: {args.steps}")
185
+ print(f" - 引导系数 (Guidance Scale): {args.guidance_scale}")
186
  print("-" * 30)
 
187
  try:
188
+ images = pipe(image=input_image,prompt=prompt_to_use,negative_prompt=" ",num_inference_steps=args.steps,generator=generator,true_cfg_scale=args.guidance_scale,num_images_per_prompt=1).images
 
 
 
 
 
 
 
 
 
189
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
190
  output_path = os.path.join(output_dir, f"output_{timestamp}_{seed}.png")
191
  images[0].save(output_path)
192
  print(f"✅ 推理成功!图片已保存至: {output_path}")
 
193
  except Exception as e:
194
  print(f"❌ 推理过程中发生错误: {e}")
195
 
196
  # --- 命令行接口 ---
 
197
  if __name__ == "__main__":
198
+ # 新增逻辑:检查是否只运行脚本而不带任何参数
199
+ if len(sys.argv) == 1:
200
+ prepare_model()
201
+ print("任务完成,脚本退出。")
202
+ sys.exit(0) # 正常退出
203
+
204
+ # 如果带有参数,则执行原有的推理流程
205
+ parser = argparse.ArgumentParser(description="Qwen 图像编辑命令行工具", epilog="如果不提供任何参数,脚本将只下载模型然后退出。")
206
+ parser.add_argument("--prompt",type=str,required=True,help="必须:用于编辑图像的指令。")
207
+ parser.add_argument("--input_image",type=str,required=True,help="必须:输入图片的本地路径或URL链接。")
208
+ parser.add_argument("--seed",type=int,default=42,help="用于复现结果的随机种子,默认为 42。")
209
+ parser.add_argument("--random_seed",action="store_true",help="如果设置此项,则使用一个随机种子。")
210
+ parser.add_argument("--steps",type=int,default=50,help="推理步数,默认为 50。")
211
+ parser.add_argument("--guidance_scale",type=float,default=4.0,help="引导系数 (CFG scale),默认为 4.0。")
212
+ parser.add_argument("--no_rewrite",action="store_true",help="如果设置此项,则禁用提示词重写功能。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  args = parser.parse_args()
214
  main(args)