import re
import random
import gc

import comfy.model_management as mm
from nodes import (
    ConditioningConcat,
    ConditioningZeroOut,
    ConditioningSetTimestepRange,
    ConditioningCombine,
)


def chatglm3_text_encode(chatglm3_model, prompt, clean_gpu=False):
    """Encode a prompt with the ChatGLM3 text encoder into a conditioning list.

    Before tokenization the prompt is preprocessed:

    * every ``{a|b|c}`` group is replaced by one randomly chosen option;
    * if a ``|`` is still present afterwards, the prompt is split on it and
      the resulting parts are tokenized together as one batch.

    Args:
        chatglm3_model: Mapping that provides the ``'tokenizer'`` and
            ``'text_encoder'`` entries used for encoding.
        prompt: Prompt string.
        clean_gpu: When true, all models are unloaded and the cache is
            emptied before encoding, and the cache is emptied again (plus a
            ``gc.collect()``) once the encoder has been offloaded.

    Returns:
        ``[[prompt_embeds, {"pooled_output": text_proj}]]`` where
        ``prompt_embeds`` comes from the second-to-last hidden state and
        ``text_proj`` from index ``-1`` along the first axis of the last
        hidden state.
    """
    device = mm.get_torch_device()
    offload_device = mm.unet_offload_device()
    if clean_gpu:
        mm.unload_all_models()
        mm.soft_empty_cache()

    def choose_random_option(match):
        # Pick one alternative from a "{a|b|c}" group.
        return random.choice(match.group(1).split('|'))

    prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)

    # Any "|" left outside braces separates independent prompts that are
    # encoded together as a batch.
    if "|" in prompt:
        prompt = prompt.split("|")

    tokenizer = chatglm3_model['tokenizer']
    text_encoder = chatglm3_model['text_encoder']

    try:
        text_encoder.to(device)
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=256,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        output = text_encoder(
            input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask'],
            position_ids=text_inputs['position_ids'],
            output_hidden_states=True,
        )

        # NOTE(review): the permute and the [-1] index below presume the
        # encoder returns hidden states as (seq_len, batch, hidden) -- confirm
        # against the encoder implementation.
        # .contiguous() keeps the result densely laid out, which the former
        # no-op repeat(1, 1, 1) did implicitly.
        prompt_embeds = (
            output.hidden_states[-2].permute(1, 0, 2).clone().contiguous()
        )
        text_proj = output.hidden_states[-1][-1, :, :].clone()
    finally:
        # Always move the encoder off the compute device, even when
        # tokenization or the forward pass raises, so a failed call does not
        # leave it occupying device memory.
        text_encoder.to(offload_device)
        if clean_gpu:
            mm.soft_empty_cache()
            gc.collect()

    return [[prompt_embeds, {"pooled_output": text_proj}]]


def chatglm3_adv_text_encode(chatglm3_model, text, clean_gpu=False):
    """Encode a prompt that may contain ``BREAK`` and ``TIMESTEP`` directives.

    * ``TIMESTEP[:start[:end]]`` -- the first space-delimited token of the
      match is removed from the text and parsed. With no numbers the range
      is ``0.1 .. 1``; with only ``start`` the end stays ``1``.
    * ``BREAK`` -- splits the text into chunks that are encoded separately
      with ``chatglm3_text_encode`` and joined with ``ConditioningConcat``.

    Args:
        chatglm3_model: Forwarded unchanged to ``chatglm3_text_encode``.
        text: Prompt string, optionally containing the directives above.
        clean_gpu: Forwarded unchanged to ``chatglm3_text_encode``.

    Returns:
        The resulting conditioning. When ``start > 0`` or ``end < 1`` it is
        the combination of the encoded conditioning limited to
        ``0 .. start`` and a zeroed-out copy limited to ``start .. end``.

    Raises:
        ValueError: If a ``start``/``end`` field of the ``TIMESTEP`` directive
            is not a valid float (including an empty field).
    """
    time_start = 0
    time_end = 1

    match = re.search(r'TIMESTEP.*$', text)
    if match:
        # Only the first space-delimited token is the directive; whatever
        # follows it stays part of the prompt.
        directive = match.group().split(' ')[0]
        text = text.replace(directive, '')
        fields = directive.split(':')
        if len(fields) >= 3:
            time_start = float(fields[1])
            time_end = float(fields[2])
        elif len(fields) == 2:
            time_start = float(fields[1])
        else:
            # Bare "TIMESTEP" without numbers.
            time_start = 0.1

    chunks = [part.strip() for part in text.split("BREAK")]
    chunks = [part for part in chunks if part != '']
    if not chunks:
        # Still encode an empty prompt so a conditioning is always returned.
        chunks = ['']

    conditioning = None
    for chunk in chunks:
        cond = chatglm3_text_encode(chatglm3_model, chunk, clean_gpu)
        if conditioning is None:
            conditioning = cond
        else:
            conditioning = ConditioningConcat().concat(conditioning, cond)[0]

    # Apply the timestep range only when it differs from the full 0 .. 1 span.
    if time_start > 0 or time_end < 1:
        conditioning_2, = ConditioningSetTimestepRange().set_range(
            conditioning, 0, time_start)
        conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
        conditioning_1, = ConditioningSetTimestepRange().set_range(
            conditioning_1, time_start, time_end)
        conditioning, = ConditioningCombine().combine(
            conditioning_1, conditioning_2)

    return conditioning