winfred2027 committed on
Commit 8a7f4ef · verified · 1 Parent(s): 39570c3

Upload generation.py

Files changed (1)
  1. demo_support/generation.py +204 -0
demo_support/generation.py ADDED
@@ -0,0 +1,204 @@
import torch
import torch_redstone as rst
import transformers
import numpy as np
from torch import nn
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from huggingface_hub import hf_hub_download
from diffusers import StableUnCLIPImg2ImgPipeline

# Type aliases; T and D appear in annotations below, so they must be defined
# before the classes that use them.
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]

D = torch.device


class Wrapper(transformers.modeling_utils.PreTrainedModel):
    # Pass-through "image encoder": returns a precomputed embedding unchanged,
    # wrapped as image_embeds, so the unCLIP pipeline can be driven by it directly.
    def __init__(self) -> None:
        super().__init__(transformers.configuration_utils.PretrainedConfig())
        self.param = torch.nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return rst.ObjectProxy(image_embeds=x)


class MLP(nn.Module):
    # Plain feed-forward network: Linear layers with an activation between them.
    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: T) -> T:
        return self.model(x)


class ClipCaptionModel(nn.Module):
    # ClipCap-style captioner: an embedding is projected into prefix_length GPT-2
    # token embeddings, and GPT-2 continues the sequence from that prefix.
    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel(GPT2Config.from_pretrained('gpt2'))
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))

    # @functools.lru_cache  # FIXME
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        # print(embedding_text.size())  # torch.Size([5, 67, 768])
        # print(prefix_projections.size())  # torch.Size([5, 1, 768])
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out


class ClipCaptionPrefix(ClipCaptionModel):
    # Variant in which only the prefix projection is trainable; GPT-2 stays frozen.
    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self


def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    entry_length=67,  # maximum number of words
    top_p=0.8,
    temperature=1.,
    stop_token: str = '.',
):
    # Decoding loop: nucleus (top-p) filtering of the logits followed by greedy
    # selection, starting from either a prefix embedding or a text prompt.
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device
    score_col = []
    with torch.no_grad():

        for entry_idx in range(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(torch.softmax(logits, dim=-1), -1).reshape(1, 1)
                score = torch.softmax(logits, dim=-1).reshape(-1)[next_token.item()].item()
                score_col.append(score)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze(0).cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]


@torch.no_grad()
def pc_to_text(pc_encoder: torch.nn.Module, pc, cond_scale):
    # Caption a point cloud: encode it, scale the embedding, project it into the
    # GPT-2 prefix space and decode with generate2.
    ref_dev = next(pc_encoder.parameters()).device
    prefix = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
    prefix = prefix.float() * cond_scale
    prefix = prefix.to(next(model.parameters()).device)
    prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
    text = generate2(model, tokenizer, embed=prefix_embed)
    return text


@torch.no_grad()
def pc_to_image(pc_encoder: torch.nn.Module, pc, prompt, noise_level, width, height, cfg_scale, num_steps, callback):
    # Render an image from a point cloud: the normalized embedding is handed to the
    # Stable unCLIP pipeline in place of a CLIP image embedding.
    ref_dev = next(pc_encoder.parameters()).device
    enc = pc_encoder(torch.tensor(pc.T[None], device=ref_dev))
    enc = torch.nn.functional.normalize(enc, dim=-1) * (768 ** 0.5) / 2
    if torch.cuda.is_available():
        enc = enc.to('cuda:' + str(torch.cuda.current_device()))
    # enc = enc.type(half)
    # with torch.autocast("cuda"):
    return pipe(
        prompt=', '.join(["best quality"] + ([prompt] if prompt else [])),
        negative_prompt="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
        image=enc,
        width=width, height=height,
        guidance_scale=cfg_scale,
        noise_level=noise_level,
        callback=callback,
        num_inference_steps=num_steps
    ).images[0]


# Stable unCLIP pipeline; Wrapper() stands in for the image encoder so that
# precomputed embeddings can be passed through the `image` argument.
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "diffusers/stable-diffusion-2-1-unclip-i2i-l",
    # variant="fp16",
    image_encoder=Wrapper()
)
# pe = pipe.text_encoder.text_model.embeddings
# pe.position_ids = torch.arange(pe.position_ids.shape[-1]).expand((1, -1)).to(pe.position_ids)  # workaround
if torch.cuda.is_available():
    pipe = pipe.to('cuda:' + str(torch.cuda.current_device()))
    pipe.enable_model_cpu_offload(torch.cuda.current_device())
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()

# ClipCap captioning model with pretrained Conceptual Captions weights.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
prefix_length = 10
model = ClipCaptionModel(prefix_length)
# print(model.gpt_embedding_size)
model.load_state_dict(torch.load(hf_hub_download('OpenShape/clipcap-cc', 'conceptual_weights.pt'), map_location='cpu'))
model.eval()
if torch.cuda.is_available():
    model = model.cuda()
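
A minimal usage sketch, not part of the commit and purely illustrative: it exercises the caption decoder with a random 512-dimensional vector (matching ClipCaptionModel's default prefix_size), and the trailing comments indicate how the two point-cloud entry points would be called. The encoder name, the point-array layout, and every numeric setting shown in the comments are assumptions, not values taken from this file.

import torch  # assumes the module above has already been imported/executed

# Hypothetical smoke test: decode a caption from a random prefix embedding.
with torch.no_grad():
    fake_prefix = torch.randn(1, 512, device=next(model.parameters()).device)
    prefix_embed = model.clip_project(fake_prefix).reshape(1, prefix_length, -1)
    print(generate2(model, tokenizer, embed=prefix_embed))

# With a real OpenShape point-cloud encoder `pc_encoder` and an (N, 3) point array `pc`
# (both supplied by the demo app, not by this file), the entry points would be called as:
#   caption = pc_to_text(pc_encoder, pc, cond_scale=2.0)
#   img = pc_to_image(pc_encoder, pc, "a chair", noise_level=0,
#                     width=640, height=640, cfg_scale=10, num_steps=20, callback=None)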