File size: 5,651 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import torch
import numpy as np
from PIL import Image
class JanusImageGeneration:
    """ComfyUI node that generates images from a text prompt with Janus-Pro.

    Image tokens are sampled autoregressively from the Janus language model
    using classifier-free guidance (CFG), temperature scaling and nucleus
    (top-p) filtering, then decoded into a ComfyUI ``IMAGE`` batch.
    """

    # Generation geometry: 576 image tokens form a 24x24 grid of 16 px
    # patches, i.e. a 384x384 output image.
    _IMAGE_TOKEN_NUM = 576
    _IMG_SIZE = 384
    _PATCH_SIZE = 16

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for the ComfyUI front end."""
        return {
            "required": {
                "model": ("JANUS_MODEL",),
                "processor": ("JANUS_PROCESSOR",),
                "prompt": ("STRING", {
                    "multiline": True,
                    "default": "A beautiful photo of"
                }),
                "seed": ("INT", {
                    "default": 666666666666666,
                    "min": 0,
                    "max": 0xffffffffffffffff
                }),
                "batch_size": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 16
                }),
                "cfg_weight": ("FLOAT", {
                    "default": 5.0,
                    "min": 1.0,
                    "max": 10.0,
                    "step": 0.5
                }),
                "temperature": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.1,
                    "max": 2.0,
                    "step": 0.1
                }),
                "top_p": ("FLOAT", {
                    "default": 0.95,
                    "min": 0.0,
                    "max": 1.0
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "generate_images"
    CATEGORY = "Janus-Pro"

    def generate_images(self, model, processor, prompt, seed, batch_size=1, temperature=1.0, cfg_weight=5.0, top_p=0.95):
        """Generate ``batch_size`` images for ``prompt``.

        Args:
            model: Janus model from the JANUS_MODEL input. It must expose
                ``language_model``, ``gen_head``, ``prepare_gen_img_embeds``
                and ``gen_vision_model``.
            processor: Janus processor from the JANUS_PROCESSOR input. It
                must provide ``apply_sft_template_for_multi_turn_prompts``,
                ``sft_format``, ``image_start_tag``, ``tokenizer`` and
                ``pad_id``.
            prompt: Text prompt describing the image.
            seed: Seed for torch's global RNGs (all devices).
            batch_size: Number of images generated in parallel.
            temperature: Softmax temperature applied to the guided logits.
            cfg_weight: Classifier-free guidance scale. 1.0 uses the
                conditional logits only.
            top_p: Nucleus-sampling threshold. 1.0 disables filtering.

        Returns:
            A 1-tuple with a float32 tensor of shape
            ``(batch_size, H, W, 3)`` and values in ``[0, 1]``.

        Raises:
            ImportError: The ``janus`` package is not installed.
            ValueError: The decoded output cannot be turned into a
                3-channel image batch.
        """
        try:
            # Imported only to verify that the Janus package is available.
            from janus.models import MultiModalityCausalLM  # noqa: F401
        except ImportError as err:
            raise ImportError("Please install Janus using 'pip install -r requirements.txt'") from err

        # torch.manual_seed seeds every device; the CUDA call is kept for
        # parity with the original and is ignored when CUDA is unavailable.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        grid = self._IMG_SIZE // self._PATCH_SIZE
        # Pure inference: avoid recording autograd state across the long
        # autoregressive loop.
        with torch.inference_mode():
            embed_layer = model.language_model.get_input_embeddings()
            # Token ids must live where the embedding weights are.
            device = embed_layer.weight.device
            tokens = self._build_cfg_tokens(processor, prompt, batch_size, device)
            generated_tokens = self._sample_image_tokens(
                model,
                embed_layer(tokens),
                batch_size,
                cfg_weight,
                temperature,
                top_p,
            )
            # NOTE(review): 8 is the latent channel count passed to
            # decode_code in the original code; confirm against the Janus
            # image tokenizer if the model changes.
            dec = model.gen_vision_model.decode_code(
                generated_tokens.to(dtype=torch.int),
                shape=[batch_size, 8, grid, grid],
            )
            dec = dec.to(torch.float32).cpu().numpy()

        # Built outside inference_mode so downstream nodes get a normal tensor.
        return (self._to_comfy_images(dec),)

    @staticmethod
    def _build_cfg_tokens(processor, prompt, batch_size, device):
        """Return ``(2 * batch_size, seq_len)`` prompt ids for CFG.

        Even rows are the conditional prompt. Odd rows are the unconditional
        branch, in which every token except the first and last is replaced
        by ``processor.pad_id``.
        """
        conversation = [
            {"role": "<|User|>", "content": prompt},
            {"role": "<|Assistant|>", "content": ""},
        ]
        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
            conversations=conversation,
            sft_format=processor.sft_format,
            system_prompt="",
        )
        full_prompt = sft_format + processor.image_start_tag
        input_ids = torch.LongTensor(processor.tokenizer.encode(full_prompt))

        tokens = input_ids.to(device=device, dtype=torch.int).repeat(batch_size * 2, 1)
        tokens[1::2, 1:-1] = processor.pad_id
        return tokens

    @classmethod
    def _sample_image_tokens(cls, model, inputs_embeds, batch_size, cfg_weight, temperature, top_p):
        """Autoregressively sample ``_IMAGE_TOKEN_NUM`` image tokens per image."""
        generated = torch.zeros(
            (batch_size, cls._IMAGE_TOKEN_NUM), dtype=torch.int, device=inputs_embeds.device
        )
        past_key_values = None
        for step in range(cls._IMAGE_TOKEN_NUM):
            outputs = model.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=past_key_values,
            )
            past_key_values = outputs.past_key_values

            # Classifier-free guidance: rows alternate conditional/unconditional.
            logits = model.gen_head(outputs.last_hidden_state[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            probs = cls._apply_top_p(probs, top_p)

            next_token = torch.multinomial(probs, num_samples=1)
            generated[:, step] = next_token.squeeze(dim=-1)

            # Feed the sampled token to both CFG rows of each image:
            # [t0, t0, t1, t1, ...].
            next_token = next_token.squeeze(dim=-1).repeat_interleave(2)
            inputs_embeds = model.prepare_gen_img_embeds(next_token).unsqueeze(dim=1)
        return generated

    @staticmethod
    def _apply_top_p(probs, top_p):
        """Restrict each row of ``probs`` to its top-p nucleus and renormalise.

        A token is dropped when the probability mass of strictly more likely
        tokens already reaches ``top_p``. The most likely token is always
        kept, so ``top_p == 0`` degrades to greedy choice. ``top_p >= 1``
        returns ``probs`` unchanged.
        """
        if top_p >= 1.0:
            return probs
        # Cumulative sums over a large vocabulary are unreliable in bf16/fp16.
        probs = probs.float()
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        drop = mass_before >= top_p
        drop[..., 0] = False
        sorted_probs = sorted_probs.masked_fill(drop, 0.0)
        filtered = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
        return filtered / filtered.sum(dim=-1, keepdim=True)

    @staticmethod
    def _to_comfy_images(dec):
        """Convert decoder output ``(B, C, H, W)`` in ``[-1, 1]`` to ComfyUI layout.

        Returns a float32 tensor ``(B, H, W, 3)`` clipped to ``[0, 1]``.
        Single-channel output is replicated to 3 channels. Any other
        non-RGB channel count raises ``ValueError``.
        """
        if dec.shape[1] == 1:
            dec = np.repeat(dec, 3, axis=1)
        dec = np.clip((dec + 1) / 2, 0, 1)
        images = torch.from_numpy(np.transpose(dec, (0, 2, 3, 1))).float()
        if images.ndim != 4 or images.shape[-1] != 3:
            raise ValueError(f"Unexpected shape: {images.shape}")
        return images

    @classmethod
    def IS_CHANGED(cls, seed, **kwargs):
        """Tell ComfyUI to re-run the node whenever the seed changes."""
        return seed