|
import os
|
|
import threading
|
|
import traceback
|
|
|
|
from aiohttp import web
|
|
|
|
import impact
|
|
import folder_paths
|
|
|
|
import torchvision
|
|
|
|
import impact.core as core
|
|
import impact.impact_pack as impact_pack
|
|
from impact.utils import to_tensor
|
|
from segment_anything import SamPredictor, sam_model_registry
|
|
import numpy as np
|
|
import nodes
|
|
from PIL import Image
|
|
import io
|
|
import impact.wildcards as wildcards
|
|
import comfy
|
|
from io import BytesIO
|
|
import random
|
|
from server import PromptServer
|
|
|
|
|
|
@PromptServer.instance.routes.post("/upload/temp")
async def upload_image(request):
    """Store an uploaded image in the temp directory.

    Expects a multipart form with an ``image`` file field.  When a file with
    the same name already exists, `` (1)``, `` (2)``, ... is inserted before
    the extension until an unused name is found.

    Returns a JSON response ``{"name": <stored filename>}`` on success and
    status 400 when the field is missing, is not a file, or has no usable
    filename.
    """
    upload_dir = folder_paths.get_temp_directory()
    # exist_ok avoids failing when the directory appears between check and create.
    os.makedirs(upload_dir, exist_ok=True)

    post = await request.post()
    image = post.get("image")

    # A plain (non-file) form field has no ``file`` attribute; treat it as a bad request.
    if not image or not getattr(image, "file", None):
        return web.Response(status=400)

    # The filename is client-controlled: keep only its final component so the
    # upload cannot be written outside of upload_dir.
    filename = os.path.basename(image.filename or "")
    if not filename or filename in (".", ".."):
        return web.Response(status=400)

    stem, ext = os.path.splitext(filename)
    i = 1
    while os.path.exists(os.path.join(upload_dir, filename)):
        filename = f"{stem} ({i}){ext}"
        i += 1

    filepath = os.path.join(upload_dir, filename)

    with open(filepath, "wb") as f:
        f.write(image.file.read())

    return web.json_response({"name": filename})
|
|
|
|
|
|
# Checkpoint path of the ViT-B SAM model inside the Impact Pack model directory.
default_sam_model_name = os.path.join(impact_pack.model_path, "sams", "sam_vit_b_01ec64.pth")

# Guards the two module-level values below; taken by the SAM route handlers
# and by the worker function async_prepare_sam.
sam_lock = threading.Condition()

# Predictor created by async_prepare_sam; reset to None by the /sam/release route.
sam_predictor = None

# JSON payload of the most recent /sam/prepare request, used to skip duplicates.
last_prepare_data = None
|
|
|
|
|
|
def async_prepare_sam(image_dir, model_name, filename):
    """Build a SAM predictor for one image and publish it in ``sam_predictor``.

    Used as a thread target by the ``/sam/prepare`` handler.  All work happens
    while ``sam_lock`` is held.

    Args:
        image_dir: directory containing the image.
        model_name: checkpoint path; the SAM variant is chosen from the
            substring ``vit_h`` / ``vit_l`` it contains, ``vit_b`` otherwise.
        filename: image file name inside ``image_dir``.
    """
    global sam_predictor

    with sam_lock:
        # First matching marker wins; 'vit_h' is tested before 'vit_l'.
        model_kind = next((kind for kind in ('vit_h', 'vit_l') if kind in model_name), 'vit_b')

        sam_predictor = SamPredictor(sam_model_registry[model_kind](checkpoint=model_name))

        loaded = nodes.LoadImage().load_image(os.path.join(image_dir, filename))[0]
        # Scale to 0..255 and convert to uint8 (assumes the loader yields 0..1 floats -- TODO confirm).
        scaled = 255. * loaded.cpu().numpy().squeeze()
        image = np.clip(scaled, 0, 255).astype(np.uint8)

        force_cpu = impact.config.get_config()['sam_editor_cpu']
        device = 'cpu' if force_cpu else comfy.model_management.get_torch_device()

        # Move the model to the device only for the embedding step, then back to CPU.
        sam_predictor.model.to(device=device)
        sam_predictor.set_image(image, "RGB")
        sam_predictor.model.cpu()
|
|
|
|
|
|
@PromptServer.instance.routes.post("/sam/prepare")
async def sam_prepare(request):
    """Start preparing the SAM predictor for the image named in the request.

    The JSON body is read for ``sam_model_name``, ``filename``, ``type`` and
    ``subfolder``.  A request identical to the last successfully accepted one
    is answered with 200 without doing any work.  Otherwise a worker thread
    running ``async_prepare_sam`` is started and 200 is returned immediately;
    the thread waits on ``sam_lock`` until this handler releases it.

    Returns status 400 when the image directory cannot be resolved.
    """
    global last_prepare_data
    data = await request.json()

    with sam_lock:
        if last_prepare_data is not None and last_prepare_data == data:
            # Same request as before: skip redundant loading.
            return web.Response(status=200)

        model_name = 'sam_vit_b_01ec64.pth'
        if data['sam_model_name'] == 'auto':
            model_name = impact.config.get_config()['sam_editor_model']

        model_name = os.path.join(impact_pack.model_path, "sams", model_name)

        # Report the checkpoint that will actually be loaded, not the models root.
        print(f"[INFO] ComfyUI-Impact-Pack: Loading SAM model '{model_name}'")

        filename, image_dir = folder_paths.annotated_filepath(data["filename"])

        if image_dir is None:
            typ = data['type'] if data['type'] != '' else 'output'
            image_dir = folder_paths.get_directory_by_type(typ)

            # Must be checked before the subfolder is appended: concatenating
            # onto None would raise instead of yielding a 400.
            if image_dir is None:
                return web.Response(status=400)

            if data['subfolder'] is not None and data['subfolder'] != '':
                image_dir += f"/{data['subfolder']}"

        # Remember the request only once it has been validated, so that a
        # rejected request is not reported as "already prepared" on retry.
        last_prepare_data = data

        thread = threading.Thread(target=async_prepare_sam, args=(image_dir, model_name, filename,))
        thread.start()

        print(f"[INFO] ComfyUI-Impact-Pack: SAM model loaded. ")
        return web.Response(status=200)
|
|
|
|
|
|
@PromptServer.instance.routes.post("/sam/release")
async def release_sam(request):
    """Drop the cached SAM predictor and answer with status 200."""
    global sam_predictor
    global last_prepare_data

    with sam_lock:
        sam_predictor = None
        # Forget the last prepare request as well; otherwise an identical
        # /sam/prepare call after a release would be skipped as a duplicate
        # and the predictor would never be rebuilt.
        last_prepare_data = None

    print(f"[INFO] ComfyUI-Impact-Pack: unloading SAM model")

    # A request handler has to return a response object.
    return web.Response(status=200)
|
|
|
|
|
|
@PromptServer.instance.routes.post("/sam/detect")
async def sam_detect(request):
    """Run SAM on the prepared image and return the combined mask as a PNG.

    The JSON body supplies ``positive_points``, ``negative_points`` and
    ``threshold``.  Responds with status 400 when no predictor has been
    prepared or when no mask could be produced.
    """
    global sam_predictor
    with sam_lock:
        if sam_predictor is None:
            return web.Response(status=400)

        force_cpu = impact.config.get_config()['sam_editor_cpu']
        device = 'cpu' if force_cpu else comfy.model_management.get_torch_device()

        sam_predictor.model.to(device=device)
        try:
            payload = await request.json()

            positives = payload['positive_points']
            negatives = payload['negative_points']
            threshold = payload['threshold']

            # Positive points are labelled 1, negative points 0, in that order.
            points = [*positives, *negatives]
            plabs = [1 for _ in positives] + [0 for _ in negatives]

            detected_masks = core.sam_predict(sam_predictor, points, plabs, None, threshold)
            mask = core.combine_masks2(detected_masks)

            if mask is None:
                return web.Response(status=400)

            # Replicate the single mask channel three times to get an RGB-shaped array.
            height, width = mask.shape[-2], mask.shape[-1]
            rgb = mask.reshape((-1, 1, height, width)).movedim(1, -1).expand(-1, -1, -1, 3)
            pixels = 255. * rgb.cpu().numpy()

            img = Image.fromarray(np.clip(pixels[0], 0, 255).astype(np.uint8))

            img_buffer = io.BytesIO()
            img.save(img_buffer, format='png')

            headers = {'Content-Type': 'image/png'}
        finally:
            # Always move the model back to CPU, whatever happened above.
            sam_predictor.model.to(device="cpu")

        return web.Response(body=img_buffer.getvalue(), headers=headers)
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/wildcards/refresh")
async def wildcards_refresh(request):
    """Call ``impact.wildcards.wildcard_load()`` and answer with status 200."""
    impact.wildcards.wildcard_load()
    ok = web.Response(status=200)
    return ok
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/wildcards/list")
async def wildcards_list(request):
    """Return the wildcard list as JSON under the ``data`` key."""
    names = impact.wildcards.get_wildcard_list()
    return web.json_response({'data': names})
|
|
|
|
|
|
@PromptServer.instance.routes.post("/impact/wildcards")
async def populate_wildcards(request):
    """Expand the wildcards in the posted ``text`` and return the result as JSON.

    The JSON body must contain ``text``; ``seed`` is optional and defaults to
    ``None``.  The response is ``{"text": <populated text>}``.
    """
    payload = await request.json()
    text = payload['text']
    seed = payload.get('seed', None)
    return web.json_response({"text": wildcards.process(text, seed)})
|
|
|
|
|
|
# Picker data keyed by node id.  The /impact/segs/picker/* routes read it by
# the `id` query parameter (using len() and indexing on the stored value), and
# onprompt_for_pickers removes entries whose key is not an ImpactSEGSPicker
# node of the submitted prompt.
segs_picker_map = {}
|
|
|
|
@PromptServer.instance.routes.get("/impact/segs/picker/count")
async def segs_picker_count(request):
    """Return, as plain text, how many items are stored for picker node ``id``.

    Responds with status 400 when the node id is not in ``segs_picker_map``.
    """
    node_id = request.rel_url.query.get('id', '')

    if node_id not in segs_picker_map:
        return web.Response(status=400)

    count = len(segs_picker_map[node_id])
    return web.Response(status=200, text=str(count))
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/segs/picker/view")
async def segs_picker(request):
    """Return item ``idx`` of picker node ``id`` as a PNG image.

    Responds with status 400 when ``idx`` is missing, not an integer or out
    of range, or when the node id is not in ``segs_picker_map``.
    """
    node_id = request.rel_url.query.get('id', '')

    try:
        idx = int(request.rel_url.query.get('idx', ''))
    except ValueError:
        # Missing or non-numeric index is a client error, not a server failure.
        return web.Response(status=400)

    # The lower bound keeps negative indices from silently addressing items from the end.
    if node_id in segs_picker_map and 0 <= idx < len(segs_picker_map[node_id]):
        # Reorder axes to channel-first and drop the leading dimension before
        # the PIL conversion (assumes a (1, H, W, C) tensor -- TODO confirm).
        img = to_tensor(segs_picker_map[node_id][idx]).permute(0, 3, 1, 2).squeeze(0)
        pil = torchvision.transforms.ToPILImage('RGB')(img)

        image_bytes = BytesIO()
        pil.save(image_bytes, format="PNG")
        image_bytes.seek(0)
        return web.Response(status=200, body=image_bytes, content_type='image/png', headers={"Content-Disposition": f"filename={node_id}{idx}.png"})

    return web.Response(status=400)
|
|
|
|
|
|
@PromptServer.instance.routes.get("/view/validate")
async def view_validate(request):
    """Report whether the file named by the query parameters exists.

    Query parameters: ``filename`` (required, may be annotated) and
    ``subfolder`` (optional, defaults to empty).  Responds with status 200
    when the file exists inside the resolved base directory and with status
    400 in every other case.

    NOTE: a later handler in this module reuses the Python name
    ``view_validate`` for a different route.
    """
    query = request.rel_url.query

    if "filename" in query:
        filename = query["filename"]
        # A missing subfolder is treated as "no subfolder" instead of raising.
        subfolder = query.get("subfolder", "")
        filename, base_dir = folder_paths.annotated_filepath(filename)

        if filename == '' or filename[0] == '/' or '..' in filename:
            return web.Response(status=400)

        if base_dir is None:
            base_dir = folder_paths.get_input_directory()

        base_dir = os.path.abspath(base_dir)
        file = os.path.abspath(os.path.join(base_dir, subfolder, filename))

        # The subfolder is client-controlled: do not probe paths outside base_dir.
        try:
            confined = os.path.commonpath((base_dir, file)) == base_dir
        except ValueError:
            # Raised e.g. for paths on different drives.
            confined = False

        if not confined:
            return web.Response(status=400)

        if os.path.isfile(file):
            return web.Response(status=200)

    return web.Response(status=400)
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/validate/pb_id_image")
async def view_validate(request):
    """Report whether the file registered under preview-bridge id ``id`` exists.

    Responds with status 200 when the id is known and its file exists, and
    with status 400 otherwise.

    NOTE(review): this function shares its Python name with the
    ``/view/validate`` handler defined earlier in the module; the name is kept
    unchanged here so existing references keep working.
    """
    if "id" in request.rel_url.query:
        pb_id = request.rel_url.query["id"]

        if pb_id not in core.preview_bridge_image_id_map:
            return web.Response(status=400)

        entry = core.preview_bridge_image_id_map[pb_id]
        # get_previewbridge_image unpacks the map entry as a 2-item sequence
        # whose first element it ignores; presumably that is the file path.
        # Accept both that shape and a bare path.
        file = entry[0] if isinstance(entry, (tuple, list)) else entry
        if os.path.isfile(file):
            return web.Response(status=200)

    return web.Response(status=400)
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/set/pb_id_image")
async def set_previewbridge_image(request):
    """Register an image file for a preview-bridge node and return its id.

    Query parameters: ``node_id``, ``filename`` (may be annotated), ``type``
    (``input``, ``output``, anything else selects the temp directory) and
    ``subfolder``.  On success the id returned by
    ``core.set_previewbridge_image`` is sent as the response text.

    Responds with status 400 when ``filename`` is absent or invalid, when the
    resulting path would leave the selected directory, or when any error
    occurs (the traceback is printed).
    """
    try:
        if "filename" in request.rel_url.query:
            node_id = request.rel_url.query["node_id"]
            filename = request.rel_url.query["filename"]
            path_type = request.rel_url.query["type"]
            subfolder = request.rel_url.query["subfolder"]
            filename, output_dir = folder_paths.annotated_filepath(filename)

            if filename == '' or filename[0] == '/' or '..' in filename:
                return web.Response(status=400)

            if output_dir is None:
                if path_type == 'input':
                    output_dir = folder_paths.get_input_directory()
                elif path_type == 'output':
                    output_dir = folder_paths.get_output_directory()
                else:
                    output_dir = folder_paths.get_temp_directory()

            file = os.path.join(output_dir, subfolder, filename)

            # The subfolder is client-controlled and the registered path is
            # served back by another route, so refuse anything that resolves
            # outside of output_dir.
            root = os.path.abspath(output_dir)
            try:
                confined = os.path.commonpath((root, os.path.abspath(file))) == root
            except ValueError:
                # Raised e.g. for paths on different drives.
                confined = False

            if not confined:
                return web.Response(status=400)

            item = {
                'filename': filename,
                'type': path_type,
                'subfolder': subfolder,
            }
            pb_id = core.set_previewbridge_image(node_id, file, item)

            return web.Response(status=200, text=pb_id)
    except Exception:
        traceback.print_exc()

    return web.Response(status=400)
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/get/pb_id_image")
async def get_previewbridge_image(request):
    """Return, as JSON, the item stored with preview-bridge id ``id``.

    Responds with status 400 when the ``id`` parameter is missing or unknown.
    """
    query = request.rel_url.query
    if "id" not in query:
        return web.Response(status=400)

    pb_id = query["id"]
    if pb_id not in core.preview_bridge_image_id_map:
        return web.Response(status=400)

    # The map entry is a pair; only its second element is returned.
    _, path_item = core.preview_bridge_image_id_map[pb_id]
    return web.json_response(path_item)
|
|
|
|
|
|
@PromptServer.instance.routes.get("/impact/view/pb_id_image")
async def view_previewbridge_image(request):
    """Serve the image file registered under preview-bridge id ``id``.

    Responds with status 400 when the id is missing or unknown, or when the
    file cannot be opened as an image.
    """
    if "id" in request.rel_url.query:
        pb_id = request.rel_url.query["id"]

        if pb_id in core.preview_bridge_image_id_map:
            entry = core.preview_bridge_image_id_map[pb_id]
            # get_previewbridge_image unpacks the map entry as a 2-item
            # sequence whose first element it ignores; presumably that is the
            # file path.  Accept both that shape and a bare path.
            file = entry[0] if isinstance(entry, (tuple, list)) else entry

            # Opening with PIL verifies that the file exists and is an image
            # before it is handed to the file response.
            try:
                with Image.open(file):
                    pass
            except OSError:
                return web.Response(status=400)

            filename = os.path.basename(file)
            return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})

    return web.Response(status=400)
|
|
|
|
|
|
def onprompt_for_switch(json_data):
    """Remove inputs that switch nodes have deselected from the prompt, in place.

    First pass: collect the selection of every switch-like node whose
    selection is known up front (a literal value, or a link to a node whose
    value can be read directly from the prompt):

    * ``ImpactInversedSwitch`` -> selected output slot (1-based).
    * ``ImpactSwitch`` / ``LatentSwitch`` / ``SEGSSwitch`` /
      ``ImpactMakeImageList`` with ``sel_mode`` set -> selected ``input<N>``.
    * ``ImpactConditionalBranchSelMode`` with ``sel_mode`` set -> boolean
      choosing ``tt_value`` or ``ff_value``.

    Second pass: delete, from every node, the inputs that are linked to a
    deselected inversed-switch output, plus the non-selected ``input*`` /
    ``tt_value`` / ``ff_value`` inputs of the switches themselves.

    Args:
        json_data: dict with a ``prompt`` mapping of node id -> node dict.
    """
    prompt = json_data['prompt']

    inversed_switch_info = {}
    onprompt_switch_info = {}
    onprompt_cond_branch_info = {}

    for k, v in prompt.items():
        if 'class_type' not in v:
            continue

        cls = v['class_type']
        if cls == 'ImpactInversedSwitch':
            select_input = v['inputs']['select']
            if isinstance(select_input, list) and len(select_input) == 2:
                # 'select' is a link: only an ImpactInt source can be resolved here.
                input_node = prompt[select_input[0]]
                if input_node['class_type'] == 'ImpactInt' and 'inputs' in input_node and 'value' in input_node['inputs']:
                    inversed_switch_info[k] = input_node['inputs']['value']
            else:
                inversed_switch_info[k] = select_input

        elif cls in ['ImpactSwitch', 'LatentSwitch', 'SEGSSwitch', 'ImpactMakeImageList']:
            if 'sel_mode' in v['inputs'] and v['inputs']['sel_mode'] and 'select' in v['inputs']:
                select_input = v['inputs']['select']
                if isinstance(select_input, list) and len(select_input) == 2:
                    input_node = prompt[select_input[0]]
                    if input_node['class_type'] == 'ImpactInt' and 'inputs' in input_node and 'value' in input_node['inputs']:
                        onprompt_switch_info[k] = input_node['inputs']['value']
                    if input_node['class_type'] == 'ImpactSwitch' and 'inputs' in input_node and 'select' in input_node['inputs']:
                        if isinstance(input_node['inputs']['select'], int):
                            onprompt_switch_info[k] = input_node['inputs']['select']
                        else:
                            print(f"\n##### ##### #####\n[WARN] {cls}: For the 'select' operation, only 'select_index' of the 'ImpactSwitch', which is not an input, or 'ImpactInt' and 'Primitive' are allowed as inputs.\n##### ##### #####\n")
                else:
                    onprompt_switch_info[k] = select_input

        elif cls == 'ImpactConditionalBranchSelMode':
            if 'sel_mode' in v['inputs'] and v['inputs']['sel_mode'] and 'cond' in v['inputs']:
                cond_input = v['inputs']['cond']
                if isinstance(cond_input, list) and len(cond_input) == 2:
                    input_node = prompt[cond_input[0]]
                    if (input_node['class_type'] == 'ImpactValueReceiver' and 'inputs' in input_node
                            and 'value' in input_node['inputs'] and 'typ' in input_node['inputs']):
                        if 'BOOLEAN' == input_node['inputs']['typ']:
                            try:
                                onprompt_cond_branch_info[k] = input_node['inputs']['value'].lower() == "true"
                            except AttributeError:
                                # Value is not a string; leave the branch undecided.
                                pass
                else:
                    onprompt_cond_branch_info[k] = cond_input

    for k, v in prompt.items():
        # Nodes without inputs have nothing to disable; skipping them avoids a
        # KeyError that would abort the whole pass.
        if 'inputs' not in v:
            continue

        node_inputs = v['inputs']
        disable_targets = set()

        for kk, vv in node_inputs.items():
            if isinstance(vv, list) and len(vv) == 2:
                if vv[0] in inversed_switch_info:
                    # Link slots are 0-based, the selection is 1-based.
                    if vv[1] + 1 != inversed_switch_info[vv[0]]:
                        disable_targets.add(kk)

        if k in onprompt_switch_info:
            selected_slot_name = f"input{onprompt_switch_info[k]}"
            for kk in node_inputs:
                if kk != selected_slot_name and kk.startswith('input'):
                    disable_targets.add(kk)

        if k in onprompt_cond_branch_info:
            selected_slot_name = "tt_value" if onprompt_cond_branch_info[k] else "ff_value"
            for kk in node_inputs:
                if kk in ['tt_value', 'ff_value'] and kk != selected_slot_name:
                    disable_targets.add(kk)

        for kk in disable_targets:
            del node_inputs[kk]
|
|
|
|
def onprompt_for_pickers(json_data):
    """Drop ``segs_picker_map`` entries whose node is not a picker in this prompt.

    An entry is kept only when its key is the id of a node whose
    ``class_type`` is ``ImpactSEGSPicker`` in ``json_data['prompt']``.
    """
    active_pickers = {
        node_id
        for node_id, node in json_data['prompt'].items()
        if 'class_type' in node and node['class_type'] == 'ImpactSEGSPicker'
    }

    # Iterate over a snapshot of the keys because the map is modified in the loop.
    for key in list(segs_picker_map):
        if key not in active_pickers:
            del segs_picker_map[key]
|
|
|
|
|
|
def gc_preview_bridge_cache(json_data):
    """Delete ``core.preview_bridge_cache`` entries for nodes absent from the prompt.

    Each removed key is reported on stdout.
    """
    live_keys = json_data['prompt'].keys()

    # Collect first, delete afterwards, so the cache is not resized while iterated.
    stale_keys = [key for key in core.preview_bridge_cache.keys() if key not in live_keys]

    for key in stale_keys:
        print(f"key deleted: {key}")
        del core.preview_bridge_cache[key]
|
|
|
|
|
|
def workflow_imagereceiver_update(json_data):
    """Set ``image`` to ``"#DATA"`` on ImageReceiver nodes that save to the workflow.

    Only nodes whose ``class_type`` is ``ImageReceiver`` and whose
    ``save_to_workflow`` input is truthy are modified; the prompt is changed
    in place.
    """
    receivers = (
        node
        for node in json_data['prompt'].values()
        if 'class_type' in node and node['class_type'] == 'ImageReceiver'
    )

    for node in receivers:
        node_inputs = node['inputs']
        if node_inputs['save_to_workflow']:
            node_inputs['image'] = "#DATA"
|
|
|
|
|
|
def regional_sampler_seed_update(json_data):
    """Send the next ``seed_2nd`` value of every RegionalSampler node as feedback.

    Depending on ``seed_2nd_mode`` the next seed is the current one plus one
    (wrapping to 0 above the upper bound), minus one (wrapping to the upper
    bound below 0), or a random value in the inclusive range.  Any other mode
    sends nothing.  The prompt itself is not modified; the value is only
    reported through an ``impact-node-feedback`` message.
    """
    seed_limit = 1125899906842624

    for node_id, node in json_data['prompt'].items():
        if 'class_type' not in node or node['class_type'] != 'RegionalSampler':
            continue

        mode = node['inputs']['seed_2nd_mode']

        if mode == 'increment':
            candidate = node['inputs']['seed_2nd'] + 1
            next_seed = 0 if candidate > seed_limit else candidate
        elif mode == 'decrement':
            candidate = node['inputs']['seed_2nd'] - 1
            next_seed = seed_limit if candidate < 0 else candidate
        elif mode == 'randomize':
            next_seed = random.randint(0, seed_limit)
        else:
            # Unknown mode: nothing to report for this node.
            continue

        feedback = {"node_id": node_id, "widget_name": "seed_2nd", "type": "INT", "value": next_seed}
        PromptServer.instance.send_sync("impact-node-feedback", feedback)
|
|
|
|
|
|
def onprompt_populate_wildcards(json_data):
    """Populate wildcard nodes of the prompt before execution, in place.

    For every ``ImpactWildcardEncode`` / ``ImpactWildcardProcessor`` node whose
    ``mode`` input is truthy and whose ``populated_text`` is a string:

    * the seed is taken from the ``seed`` input, or, when that input is a
      link, from the linked ``ImpactInt`` (``value``) or ``Seed (rgthree)``
      (``seed``) node; any other linked node is reported and the wildcard
      node is skipped;
    * ``populated_text`` is replaced by ``wildcards.process(wildcard_text, seed)``
      and ``mode`` is set to ``False``;
    * the new text is sent as an ``impact-node-feedback`` message.

    If the request carries ``extra_data.extra_pnginfo``, the matching workflow
    nodes get the new text in ``widgets_values[1]`` and ``False`` in
    ``widgets_values[2]``.
    """
    prompt = json_data['prompt']

    updated_widget_values = {}
    for k, v in prompt.items():
        if 'class_type' not in v or v['class_type'] not in ('ImpactWildcardEncode', 'ImpactWildcardProcessor'):
            continue

        inputs = v['inputs']
        if not (inputs['mode'] and isinstance(inputs['populated_text'], str)):
            continue

        if isinstance(inputs['seed'], list):
            try:
                input_node = prompt[inputs['seed'][0]]
                # elif (not a second independent `if`): otherwise an ImpactInt
                # source would fall through to the "not allowed" branch below.
                if input_node['class_type'] == 'ImpactInt':
                    input_seed = int(input_node['inputs']['value'])
                elif input_node['class_type'] == 'Seed (rgthree)':
                    input_seed = int(input_node['inputs']['seed'])
                else:
                    print(f"[Impact Pack] Only `ImpactInt`, `Seed (rgthree)` and `Primitive` Node are allowed as the seed for '{v['class_type']}'. It will be ignored. ")
                    continue
            except (KeyError, IndexError, TypeError, ValueError):
                # Dangling link or a seed value that is not convertible to int.
                continue
        else:
            input_seed = int(inputs['seed'])

        inputs['populated_text'] = wildcards.process(inputs['wildcard_text'], input_seed)
        inputs['mode'] = False

        PromptServer.instance.send_sync("impact-node-feedback", {"node_id": k, "widget_name": "populated_text", "type": "STRING", "value": inputs['populated_text']})
        updated_widget_values[k] = inputs['populated_text']

    if 'extra_data' in json_data and 'extra_pnginfo' in json_data['extra_data']:
        for node in json_data['extra_data']['extra_pnginfo']['workflow']['nodes']:
            key = str(node['id'])
            if key in updated_widget_values:
                node['widgets_values'][1] = updated_widget_values[key]
                node['widgets_values'][2] = False
|
|
|
|
|
|
def onprompt_for_remote(json_data):
    """Apply ``ImpactRemoteBoolean`` / ``ImpactRemoteInt`` nodes to their targets.

    Each remote node names a target node (``node_id``) and one of its inputs
    (``widget_name``).  When the target exists, has that input, and the
    input's current value matches the remote node's kind (``bool`` for
    ``ImpactRemoteBoolean``; ``int`` or ``float`` for ``ImpactRemoteInt``),
    the input is overwritten with the remote node's ``value`` and an
    ``impact-node-feedback`` message is sent.  Remote nodes that do not match
    are skipped.
    """
    prompt = json_data['prompt']

    for v in prompt.values():
        if 'class_type' not in v:
            continue

        cls = v['class_type']
        if cls != 'ImpactRemoteBoolean' and cls != 'ImpactRemoteInt':
            continue

        inputs = v['inputs']
        node_id = str(inputs['node_id'])

        if node_id not in prompt:
            continue

        target_inputs = prompt[node_id]['inputs']

        widget_name = inputs['widget_name']
        if widget_name not in target_inputs:
            continue

        current = target_inputs[widget_name]
        widget_type = None
        if cls == 'ImpactRemoteBoolean' and isinstance(current, bool):
            widget_type = 'BOOLEAN'
        elif cls == 'ImpactRemoteInt' and isinstance(current, (int, float)):
            widget_type = 'INT'

        if widget_type is None:
            # Skip only this remote node; a mismatch here must not prevent
            # the remaining remote nodes from being applied.
            continue

        target_inputs[widget_name] = inputs['value']
        PromptServer.instance.send_sync("impact-node-feedback", {"node_id": node_id, "widget_name": widget_name, "type": widget_type, "value": inputs['value']})
|
|
|
|
|
|
def onprompt(json_data):
    """Run every Impact Pack prompt hook on ``json_data`` and return it.

    The hooks run in a fixed order and may modify ``json_data`` in place.
    ``core.current_prompt`` is updated only when all of them succeed.  Any
    exception is reported as a warning and the remaining hooks are skipped;
    ``json_data`` is returned in either case.
    """
    try:
        hooks = (
            onprompt_for_remote,
            onprompt_for_switch,
            onprompt_for_pickers,
            onprompt_populate_wildcards,
            gc_preview_bridge_cache,
            workflow_imagereceiver_update,
            regional_sampler_seed_update,
        )
        for hook in hooks:
            hook(json_data)

        core.current_prompt = json_data
    except Exception as e:
        print(f"[WARN] ComfyUI-Impact-Pack: Error on prompt - several features will not work.\n{e}")

    return json_data
|
|
|
|
|
|
# Register onprompt with the server's on-prompt handler list (presumably it is
# then invoked with the JSON of each submitted prompt -- verify against PromptServer).
PromptServer.instance.add_on_prompt_handler(onprompt)
|
|
|