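# NOTE: the three lines below are web-page header text captured with this file;
# they are kept as comments because they are not Python code.
# jaxmetaverse's picture
# Upload folder using huggingface_hub
# 82ea528 verified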
import re
import random
import gc
import comfy.model_management as mm
from nodes import ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
def chatglm3_text_encode(chatglm3_model, prompt, clean_gpu=False):
    """Encode ``prompt`` with the ChatGLM3 text encoder into a conditioning list.

    Prompt preprocessing, in order:

    * every ``{a|b|c}`` group (braces may not be nested) is replaced by one
      randomly chosen option;
    * if a ``|`` is still present afterwards, the prompt is split on ``|`` and
      the resulting list of strings is passed to the tokenizer as one batch.

    Args:
        chatglm3_model: mapping that provides a ``'tokenizer'`` and a
            ``'text_encoder'`` entry.
        prompt: prompt string.
        clean_gpu: when true, unload all models and empty the cache before
            encoding, then empty the cache and run the garbage collector
            afterwards.

    Returns:
        ``[[prompt_embeds, {"pooled_output": text_proj}]]`` where
        ``prompt_embeds`` is taken from the second-to-last hidden state and
        ``text_proj`` from the last position of the last hidden state.
    """
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    if clean_gpu:
        mm.unload_all_models()
        mm.soft_empty_cache()

    # Randomly select one option from a "{a|b|c}" group.
    def choose_random_option(match):
        options = match.group(1).split('|')
        return random.choice(options)

    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)
    if "|" in prompt:
        # A remaining top-level "|" turns the prompt into a batch of prompts.
        prompt = prompt.split("|")

    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']
    try:
        text_encoder.to(device)
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        output = text_encoder(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
            position_ids=text_inputs['position_ids'],
            output_hidden_states=True)

        # Assumes the encoder returns hidden states as (seq, batch, hidden)
        # -- TODO confirm. The permute puts the batch dimension first; the
        # sequence length follows max_length=256 above.
        # clone() + contiguous() yields an independent, densely laid out
        # tensor (what the former no-op repeat()/view() pair produced).
        prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone().contiguous()
        # Last sequence position of the final hidden state, one row per prompt.
        text_proj = output.hidden_states[-1][-1, :, :].clone().contiguous()

        # Drop the references to the full encoder output so the cache cleanup
        # below is able to release that memory.
        del output, text_inputs
    finally:
        # Always move the encoder back, even if tokenizing/encoding failed,
        # so an error does not leave it resident on the compute device.
        text_encoder.to(offload_device)

    if clean_gpu:
        mm.soft_empty_cache()
        gc.collect()
    return [[prompt_embeds, {"pooled_output": text_proj}]]
def chatglm3_adv_text_encode(chatglm3_model, text, clean_gpu=False):
    """Encode a prompt that may contain ``BREAK`` and ``TIMESTEP`` directives.

    ``TIMESTEP`` directive (searched with ``TIMESTEP.*$``, which does not span
    newlines; only the part up to the first space is used, and every
    occurrence of that token is removed from the text):

    * ``TIMESTEP``            -> start 0.1, end 1
    * ``TIMESTEP:a``          -> start ``float(a)``, end 1
    * ``TIMESTEP:a:b[:...]``  -> start ``float(a)``, end ``float(b)``

    The remaining text is split on ``BREAK``; each non-empty, stripped piece is
    encoded with ``chatglm3_text_encode`` and the results are concatenated in
    order with ``ConditioningConcat``. If nothing is left, a single empty
    prompt is encoded.

    Args:
        chatglm3_model: passed through to ``chatglm3_text_encode``.
        text: prompt string, optionally containing the directives above.
        clean_gpu: passed through to ``chatglm3_text_encode``.

    Returns:
        The concatenated conditioning. When ``start > 0`` or ``end < 1`` it is
        instead the combination of the conditioning restricted to
        ``[0, start]`` and a zeroed-out copy restricted to ``[start, end]``.
    """
    range_start, range_end = 0, 1

    found = re.search(r'TIMESTEP.*$', text)
    if found is not None:
        # Keep only the directive token itself (everything up to the first space).
        directive = found.group().split(' ')[0]
        text = text.replace(directive, '')
        fields = directive.split(':')
        if len(fields) == 1:
            range_start = 0.1
        elif len(fields) == 2:
            range_start = float(fields[1])
        else:
            range_start = float(fields[1])
            range_end = float(fields[2])

    # Split on BREAK, dropping blank pieces; fall back to one empty prompt.
    segments = [piece for piece in (raw.strip() for raw in text.split("BREAK")) if piece != '']
    if not segments:
        segments = ['']

    # Encode the first segment, then append every following one in order.
    remaining = iter(segments)
    conditioning = chatglm3_text_encode(chatglm3_model, next(remaining), clean_gpu)
    for segment in remaining:
        encoded = chatglm3_text_encode(chatglm3_model, segment, clean_gpu)
        conditioning = ConditioningConcat().concat(conditioning, encoded)[0]

    # No custom range requested: return the plain conditioning.
    if not (range_start > 0 or range_end < 1):
        return conditioning

    # setTimeStepRange
    (ranged,) = ConditioningSetTimestepRange().set_range(conditioning, 0, range_start)
    (zeroed,) = ConditioningZeroOut().zero_out(conditioning)
    (zeroed,) = ConditioningSetTimestepRange().set_range(zeroed, range_start, range_end)
    (combined,) = ConditioningCombine().combine(zeroed, ranged)
    return combined