dangthr commited on
Commit
a578aa5
·
verified ·
1 Parent(s): 18e2bb4

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +19 -22
inference.py CHANGED
@@ -1,6 +1,6 @@
1
  # inference.py
2
  import os
3
- import sys # 导入 sys 模块
4
  import argparse
5
  import random
6
  import json
@@ -13,6 +13,7 @@ import numpy as np
13
  import requests
14
  from PIL import Image
15
  from diffusers import QwenImageEditPipeline
 
16
 
17
  # --- 从原脚本保留的辅助函数 ---
18
  # SYSTEM_PROMPT, polish_prompt, encode_image, api 函数保持不变...
@@ -77,7 +78,6 @@ Please strictly follow the rewriting rules below:
77
  '''
78
 
79
  def polish_prompt(prompt, img):
80
- """使用 DashScope API 重写和优化提示词"""
81
  if not os.environ.get('DASH_API_KEY'):
82
  print("[警告] 环境变量 DASH_API_KEY 未设置,将跳过提示词重写。")
83
  return prompt
@@ -98,13 +98,11 @@ def polish_prompt(prompt, img):
98
  return prompt
99
 
100
  def encode_image(pil_image):
101
- """将 PIL 图片编码为 base64 字符串"""
102
  buffered = BytesIO()
103
  pil_image.save(buffered, format="PNG")
104
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
105
 
106
  def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
107
- """调用 DashScope API"""
108
  import dashscope
109
  api_key = os.environ.get('DASH_API_KEY')
110
  if not api_key:
@@ -120,7 +118,6 @@ def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
120
  raise Exception(f'Failed to post: {response}')
121
 
122
  def load_image(image_path):
123
- """从本地路径或URL加载图片"""
124
  try:
125
  if image_path.startswith("http://") or image_path.startswith("https://"):
126
  response = requests.get(image_path)
@@ -134,25 +131,28 @@ def load_image(image_path):
134
  print(f" 详细信息: {e}")
135
  return None
136
 
 
137
  def prepare_model():
138
- """仅下载并缓存模型,不执行推理"""
139
- print("正在准备模型... 如果是首次运行,将开始下载模型文件(约10GB)。")
140
- print("请耐心等待,下载速度取决于您的网络状况。")
141
- dtype = torch.bfloat16
 
142
  try:
143
- QwenImageEditPipeline.from_pretrained(
144
- "Qwen/Qwen-Image-Edit",
145
- torch_dtype=dtype,
146
- low_cpu_mem_usage=True # 优化内存使用
 
147
  )
148
- print("\n✅ 模型文件已成功准备(下载/加载)到本地缓存。")
149
  return True
150
  except Exception as e:
151
- print(f"\n❌ 错误:模型下载或加载失败。请检查网络连接或磁盘空间。")
152
  print(f" 详细信息: {e}")
153
  return False
154
 
155
- # --- 主推理逻辑 ---
156
  def main(args):
157
  """执行模型推理的主函数"""
158
  output_dir = "output"
@@ -160,7 +160,7 @@ def main(args):
160
  dtype = torch.bfloat16
161
  device = "cuda" if torch.cuda.is_available() else "cpu"
162
  print(f"使用设备: {device}")
163
- print("正在加载 Qwen-Image-Edit 模型...")
164
  try:
165
  pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
166
  print("模型加载完成。")
@@ -193,15 +193,12 @@ def main(args):
193
  except Exception as e:
194
  print(f"❌ 推理过程中发生错误: {e}")
195
 
196
- # --- 命令行接口 ---
197
  if __name__ == "__main__":
198
- # 新增逻辑:检查是否只运行脚本而不带任何参数
199
  if len(sys.argv) == 1:
200
  prepare_model()
201
  print("任务完成,脚本退出。")
202
- sys.exit(0) # 正常退出
203
-
204
- # 如果带有参数,则执行原有的推理流程
205
  parser = argparse.ArgumentParser(description="Qwen 图像编辑命令行工具", epilog="如果不提供任何参数,脚本将只下载模型然后退出。")
206
  parser.add_argument("--prompt",type=str,required=True,help="必须:用于编辑图像的指令。")
207
  parser.add_argument("--input_image",type=str,required=True,help="必须:输入图片的本地路径或URL链接。")
 
1
  # inference.py
2
  import os
3
+ import sys
4
  import argparse
5
  import random
6
  import json
 
13
  import requests
14
  from PIL import Image
15
  from diffusers import QwenImageEditPipeline
16
+ from huggingface_hub import snapshot_download # <--- 新增导入
17
 
18
  # --- 从原脚本保留的辅助函数 ---
19
  # SYSTEM_PROMPT, polish_prompt, encode_image, api 函数保持不变...
 
78
  '''
79
 
80
  def polish_prompt(prompt, img):
 
81
  if not os.environ.get('DASH_API_KEY'):
82
  print("[警告] 环境变量 DASH_API_KEY 未设置,将跳过提示词重写。")
83
  return prompt
 
98
  return prompt
99
 
100
  def encode_image(pil_image):
 
101
  buffered = BytesIO()
102
  pil_image.save(buffered, format="PNG")
103
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
104
 
105
  def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
 
106
  import dashscope
107
  api_key = os.environ.get('DASH_API_KEY')
108
  if not api_key:
 
118
  raise Exception(f'Failed to post: {response}')
119
 
120
  def load_image(image_path):
 
121
  try:
122
  if image_path.startswith("http://") or image_path.startswith("https://"):
123
  response = requests.get(image_path)
 
131
  print(f" 详细信息: {e}")
132
  return None
133
 
134
+ # --- 函数修改处 ---
135
  def prepare_model():
136
+ """仅下载模型文件到本地缓存,不加载到内存。"""
137
+ repo_id = "Qwen/Qwen-Image-Edit"
138
+ print(f"正在准备从 Hugging Face Hub 下载模型 '{repo_id}'...")
139
+ print("本操作仅下载文件,不会将模型加载到内存或显存中。")
140
+ print("如果是首次运行,将开始下载模型文件(约7GB),请耐心等待。")
141
  try:
142
+ # 使用 snapshot_download 函数只下载文件,并返回其本地路径
143
+ snapshot_download(
144
+ repo_id=repo_id,
145
+ local_dir_use_symlinks=False, # 建议设置为False以提高兼容性
146
+ resume_download=True # 支持断点续传
147
  )
148
+ print(f"\n✅ 模型 '{repo_id}' 的文件已成功下载到本地缓存。")
149
  return True
150
  except Exception as e:
151
+ print(f"\n❌ 错误:模型文件下载失败。请检查您的网络连接或仓库名称 '{repo_id}' 是否正确。")
152
  print(f" 详细信息: {e}")
153
  return False
154
 
155
+ # --- 主推理逻辑 (保持不变) ---
156
  def main(args):
157
  """执行模型推理的主函数"""
158
  output_dir = "output"
 
160
  dtype = torch.bfloat16
161
  device = "cuda" if torch.cuda.is_available() else "cpu"
162
  print(f"使用设备: {device}")
163
+ print("正在加载 Qwen-Image-Edit 模型 (从本地缓存)...") # 更新提示
164
  try:
165
  pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
166
  print("模型加载完成。")
 
193
  except Exception as e:
194
  print(f"❌ 推理过程中发生错误: {e}")
195
 
196
+ # --- 命令行接口 (保持不变) ---
197
  if __name__ == "__main__":
 
198
  if len(sys.argv) == 1:
199
  prepare_model()
200
  print("任务完成,脚本退出。")
201
+ sys.exit(0)
 
 
202
  parser = argparse.ArgumentParser(description="Qwen 图像编辑命令行工具", epilog="如果不提供任何参数,脚本将只下载模型然后退出。")
203
  parser.add_argument("--prompt",type=str,required=True,help="必须:用于编辑图像的指令。")
204
  parser.add_argument("--input_image",type=str,required=True,help="必须:输入图片的本地路径或URL链接。")