Spaces:
Runtime error
Runtime error
Update inference.py
Browse files- inference.py +41 -87
inference.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
# inference.py
|
2 |
import os
|
|
|
3 |
import argparse
|
4 |
import random
|
5 |
import json
|
@@ -14,7 +15,7 @@ from PIL import Image
|
|
14 |
from diffusers import QwenImageEditPipeline
|
15 |
|
16 |
# --- 从原脚本保留的辅助函数 ---
|
17 |
-
|
18 |
SYSTEM_PROMPT = '''
|
19 |
# Edit Instruction Rewriter
|
20 |
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.
|
@@ -80,9 +81,8 @@ def polish_prompt(prompt, img):
|
|
80 |
if not os.environ.get('DASH_API_KEY'):
|
81 |
print("[警告] 环境变量 DASH_API_KEY 未设置,将跳过提示词重写。")
|
82 |
return prompt
|
83 |
-
|
84 |
full_prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
|
85 |
-
for attempt in range(3):
|
86 |
try:
|
87 |
result = api(full_prompt, [img])
|
88 |
if isinstance(result, str):
|
@@ -90,12 +90,10 @@ def polish_prompt(prompt, img):
|
|
90 |
result_data = json.loads(result_json_str)
|
91 |
else:
|
92 |
result_data = json.loads(result)
|
93 |
-
|
94 |
polished = result_data['Rewritten']
|
95 |
return polished.strip().replace("\n", " ")
|
96 |
except Exception as e:
|
97 |
print(f"[警告] API调用失败 (尝试 {attempt + 1}): {e}")
|
98 |
-
|
99 |
print("[错误] 多次尝试后提示词重写失败,将使用原始提示词。")
|
100 |
return prompt
|
101 |
|
@@ -111,23 +109,11 @@ def api(prompt, img_list, model="qwen-vl-max-latest", kwargs={}):
|
|
111 |
api_key = os.environ.get('DASH_API_KEY')
|
112 |
if not api_key:
|
113 |
raise EnvironmentError("DASH_API_KEY is not set")
|
114 |
-
|
115 |
-
messages = [
|
116 |
-
{"role": "system", "content": "you are a helpful assistant, you should provide useful answers to users."},
|
117 |
-
{"role": "user", "content": []}
|
118 |
-
]
|
119 |
for img in img_list:
|
120 |
messages[1]["content"].append({"image": f"data:image/png;base64,{encode_image(img)}"})
|
121 |
messages[1]["content"].append({"text": f"{prompt}"})
|
122 |
-
|
123 |
-
response = dashscope.MultiModalConversation.call(
|
124 |
-
api_key=api_key,
|
125 |
-
model=model,
|
126 |
-
messages=messages,
|
127 |
-
result_format='message',
|
128 |
-
response_format=kwargs.get('response_format', None),
|
129 |
-
)
|
130 |
-
|
131 |
if response.status_code == 200:
|
132 |
return response.output.choices[0].message.content[0]['text']
|
133 |
else:
|
@@ -148,113 +134,81 @@ def load_image(image_path):
|
|
148 |
print(f" 详细信息: {e}")
|
149 |
return None
|
150 |
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
|
|
153 |
def main(args):
|
154 |
"""执行模型推理的主函数"""
|
155 |
output_dir = "output"
|
156 |
os.makedirs(output_dir, exist_ok=True)
|
157 |
-
|
158 |
dtype = torch.bfloat16
|
159 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
160 |
print(f"使用设备: {device}")
|
161 |
-
|
162 |
print("正在加载 Qwen-Image-Edit 模型...")
|
163 |
try:
|
164 |
pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
|
165 |
print("模型加载完成。")
|
166 |
except Exception as e:
|
167 |
-
print(f"❌
|
168 |
print(f" 详细信息: {e}")
|
169 |
return
|
170 |
-
|
171 |
print(f"正在从 '{args.input_image}' 加载输入图片...")
|
172 |
input_image = load_image(args.input_image)
|
173 |
if input_image is None:
|
174 |
return
|
175 |
-
|
176 |
-
# 设置随机种子
|
177 |
seed = random.randint(0, np.iinfo(np.int32).max) if args.random_seed else args.seed
|
178 |
generator = torch.Generator(device=device).manual_seed(seed)
|
179 |
-
|
180 |
-
# 如果不禁用重写功能,则调用 polish_prompt
|
181 |
prompt_to_use = polish_prompt(args.prompt, input_image) if not args.no_rewrite else args.prompt
|
182 |
-
|
183 |
if not args.no_rewrite:
|
184 |
print(f"重写后的提示词: '{prompt_to_use}'")
|
185 |
-
|
186 |
print("-" * 30)
|
187 |
print("🚀 开始推理...")
|
188 |
print(f" - 提示词: '{prompt_to_use}'")
|
189 |
print(f" - 随机种子: {seed}")
|
190 |
print(f" - 推理步数: {args.steps}")
|
191 |
-
print(f"
|
192 |
print("-" * 30)
|
193 |
-
|
194 |
try:
|
195 |
-
images = pipe(
|
196 |
-
image=input_image,
|
197 |
-
prompt=prompt_to_use,
|
198 |
-
negative_prompt=" ", # 固定负向提示词
|
199 |
-
num_inference_steps=args.steps,
|
200 |
-
generator=generator,
|
201 |
-
true_cfg_scale=args.guidance_scale,
|
202 |
-
num_images_per_prompt=1
|
203 |
-
).images
|
204 |
-
|
205 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
206 |
output_path = os.path.join(output_dir, f"output_{timestamp}_{seed}.png")
|
207 |
images[0].save(output_path)
|
208 |
print(f"✅ 推理成功!图片已保存至: {output_path}")
|
209 |
-
|
210 |
except Exception as e:
|
211 |
print(f"❌ 推理过程中发生错误: {e}")
|
212 |
|
213 |
# --- 命令行接口 ---
|
214 |
-
|
215 |
if __name__ == "__main__":
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
"
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
)
|
224 |
-
parser.add_argument(
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
)
|
230 |
-
parser.add_argument(
|
231 |
-
"--seed",
|
232 |
-
type=int,
|
233 |
-
default=42,
|
234 |
-
help="用于复现结果的随机种子,默认为 42。"
|
235 |
-
)
|
236 |
-
parser.add_argument(
|
237 |
-
"--random_seed",
|
238 |
-
action="store_true",
|
239 |
-
help="如果设置此项,则使用一个随机种子。"
|
240 |
-
)
|
241 |
-
parser.add_argument(
|
242 |
-
"--steps",
|
243 |
-
type=int,
|
244 |
-
default=50,
|
245 |
-
help="推理步数,默认为 50。"
|
246 |
-
)
|
247 |
-
parser.add_argument(
|
248 |
-
"--guidance_scale",
|
249 |
-
type=float,
|
250 |
-
default=4.0,
|
251 |
-
help="引导系数 (CFG scale),默认为 4.0。"
|
252 |
-
)
|
253 |
-
parser.add_argument(
|
254 |
-
"--no_rewrite",
|
255 |
-
action="store_true",
|
256 |
-
help="如果设置此项,则禁用提示词重写功能。"
|
257 |
-
)
|
258 |
-
|
259 |
args = parser.parse_args()
|
260 |
main(args)
|
|
|
1 |
# inference.py
|
2 |
import os
|
3 |
+
import sys # 导入 sys 模块
|
4 |
import argparse
|
5 |
import random
|
6 |
import json
|
|
|
15 |
from diffusers import QwenImageEditPipeline
|
16 |
|
17 |
# --- 从原脚本保留的辅助函数 ---
|
18 |
+
# SYSTEM_PROMPT, polish_prompt, encode_image, api 函数保持不变...
|
19 |
SYSTEM_PROMPT = '''
|
20 |
# Edit Instruction Rewriter
|
21 |
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.
|
|
|
81 |
if not os.environ.get('DASH_API_KEY'):
|
82 |
print("[警告] 环境变量 DASH_API_KEY 未设置,将跳过提示词重写。")
|
83 |
return prompt
|
|
|
84 |
full_prompt = f"{SYSTEM_PROMPT}\n\nUser Input: {prompt}\n\nRewritten Prompt:"
|
85 |
+
for attempt in range(3):
|
86 |
try:
|
87 |
result = api(full_prompt, [img])
|
88 |
if isinstance(result, str):
|
|
|
90 |
result_data = json.loads(result_json_str)
|
91 |
else:
|
92 |
result_data = json.loads(result)
|
|
|
93 |
polished = result_data['Rewritten']
|
94 |
return polished.strip().replace("\n", " ")
|
95 |
except Exception as e:
|
96 |
print(f"[警告] API调用失败 (尝试 {attempt + 1}): {e}")
|
|
|
97 |
print("[错误] 多次尝试后提示词重写失败,将使用原始提示词。")
|
98 |
return prompt
|
99 |
|
|
|
109 |
api_key = os.environ.get('DASH_API_KEY')
|
110 |
if not api_key:
|
111 |
raise EnvironmentError("DASH_API_KEY is not set")
|
112 |
+
messages = [{"role": "system", "content": "you are a helpful assistant, you should provide useful answers to users."},{"role": "user", "content": []}]
|
|
|
|
|
|
|
|
|
113 |
for img in img_list:
|
114 |
messages[1]["content"].append({"image": f"data:image/png;base64,{encode_image(img)}"})
|
115 |
messages[1]["content"].append({"text": f"{prompt}"})
|
116 |
+
response = dashscope.MultiModalConversation.call(api_key=api_key,model=model,messages=messages,result_format='message',response_format=kwargs.get('response_format', None),)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
if response.status_code == 200:
|
118 |
return response.output.choices[0].message.content[0]['text']
|
119 |
else:
|
|
|
134 |
print(f" 详细信息: {e}")
|
135 |
return None
|
136 |
|
137 |
+
def prepare_model():
|
138 |
+
"""仅下载并缓存模型,不执行推理"""
|
139 |
+
print("正在准备模型... 如果是首次运行,将开始下载模型文件(约10GB)。")
|
140 |
+
print("请耐心等待,下载速度取决于您的网络状况。")
|
141 |
+
dtype = torch.bfloat16
|
142 |
+
try:
|
143 |
+
QwenImageEditPipeline.from_pretrained(
|
144 |
+
"Qwen/Qwen-Image-Edit",
|
145 |
+
torch_dtype=dtype,
|
146 |
+
low_cpu_mem_usage=True # 优化内存使用
|
147 |
+
)
|
148 |
+
print("\n✅ 模型文件已成功准备(下载/加载)到本地缓存。")
|
149 |
+
return True
|
150 |
+
except Exception as e:
|
151 |
+
print(f"\n❌ 错误:模型下载或加载失败。请检查网络连接或磁盘空间。")
|
152 |
+
print(f" 详细信息: {e}")
|
153 |
+
return False
|
154 |
|
155 |
+
# --- 主推理逻辑 ---
|
156 |
def main(args):
|
157 |
"""执行模型推理的主函数"""
|
158 |
output_dir = "output"
|
159 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
160 |
dtype = torch.bfloat16
|
161 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
162 |
print(f"使用设备: {device}")
|
|
|
163 |
print("正在加载 Qwen-Image-Edit 模型...")
|
164 |
try:
|
165 |
pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
|
166 |
print("模型加载完成。")
|
167 |
except Exception as e:
|
168 |
+
print(f"❌ 错误:模型加载失败。")
|
169 |
print(f" 详细信息: {e}")
|
170 |
return
|
|
|
171 |
print(f"正在从 '{args.input_image}' 加载输入图片...")
|
172 |
input_image = load_image(args.input_image)
|
173 |
if input_image is None:
|
174 |
return
|
|
|
|
|
175 |
seed = random.randint(0, np.iinfo(np.int32).max) if args.random_seed else args.seed
|
176 |
generator = torch.Generator(device=device).manual_seed(seed)
|
|
|
|
|
177 |
prompt_to_use = polish_prompt(args.prompt, input_image) if not args.no_rewrite else args.prompt
|
|
|
178 |
if not args.no_rewrite:
|
179 |
print(f"重写后的提示词: '{prompt_to_use}'")
|
|
|
180 |
print("-" * 30)
|
181 |
print("🚀 开始推理...")
|
182 |
print(f" - 提示词: '{prompt_to_use}'")
|
183 |
print(f" - 随机种子: {seed}")
|
184 |
print(f" - 推理步数: {args.steps}")
|
185 |
+
print(f" - 引导系数 (Guidance Scale): {args.guidance_scale}")
|
186 |
print("-" * 30)
|
|
|
187 |
try:
|
188 |
+
images = pipe(image=input_image,prompt=prompt_to_use,negative_prompt=" ",num_inference_steps=args.steps,generator=generator,true_cfg_scale=args.guidance_scale,num_images_per_prompt=1).images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
190 |
output_path = os.path.join(output_dir, f"output_{timestamp}_{seed}.png")
|
191 |
images[0].save(output_path)
|
192 |
print(f"✅ 推理成功!图片已保存至: {output_path}")
|
|
|
193 |
except Exception as e:
|
194 |
print(f"❌ 推理过程中发生错误: {e}")
|
195 |
|
196 |
# --- 命令行接口 ---
|
|
|
197 |
if __name__ == "__main__":
|
198 |
+
# 新增逻辑:检查是否只运行脚本而不带任何参数
|
199 |
+
if len(sys.argv) == 1:
|
200 |
+
prepare_model()
|
201 |
+
print("任务完成,脚本退出。")
|
202 |
+
sys.exit(0) # 正常退出
|
203 |
+
|
204 |
+
# 如果带有参数,则执行原有的推理流程
|
205 |
+
parser = argparse.ArgumentParser(description="Qwen 图像编辑命令行工具", epilog="如果不提供任何参数,脚本将只下载模型然后退出。")
|
206 |
+
parser.add_argument("--prompt",type=str,required=True,help="必须:用于编辑图像的指令。")
|
207 |
+
parser.add_argument("--input_image",type=str,required=True,help="必须:输入图片的本地路径或URL链接。")
|
208 |
+
parser.add_argument("--seed",type=int,default=42,help="用于复现结果的随机种子,默认为 42。")
|
209 |
+
parser.add_argument("--random_seed",action="store_true",help="如果设置此项,则使用一个随机种子。")
|
210 |
+
parser.add_argument("--steps",type=int,default=50,help="推理步数,默认为 50。")
|
211 |
+
parser.add_argument("--guidance_scale",type=float,default=4.0,help="引导系数 (CFG scale),默认为 4.0。")
|
212 |
+
parser.add_argument("--no_rewrite",action="store_true",help="如果设置此项,则禁用提示词重写功能。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
args = parser.parse_args()
|
214 |
main(args)
|