Songwei Ge commited on
Commit
d0745b6
·
1 Parent(s): 51be712
app.py CHANGED
@@ -44,7 +44,7 @@ def main():
44
  # parse json to span attributes
45
  base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
46
  color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
47
- json.loads(text_input))
48
 
49
  # create control input for region diffusion
50
  region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
 
44
  # parse json to span attributes
45
  base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
46
  color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
47
+ json.loads(text_input), device)
48
 
49
  # create control input for region diffusion
50
  region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
models/__pycache__/region_diffusion.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/region_diffusion.cpython-38.pyc and b/models/__pycache__/region_diffusion.cpython-38.pyc differ
 
utils/attention_utils.py CHANGED
@@ -184,5 +184,5 @@ def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0,
184
  token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
185
  obj_tokens, save_dir, seed, tokens_vis)
186
  attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
187
- [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
188
  return attention_maps_averaged_normalized, token_maps_vis
 
184
  token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
185
  obj_tokens, save_dir, seed, tokens_vis)
186
  attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
187
+ [1, 4, 1, 1]).to(attention_maps_averaged_sum.device) for attn_mask in attention_maps_averaged_normalized]
188
  return attention_maps_averaged_normalized, token_maps_vis
utils/richtext_utils.py CHANGED
@@ -27,7 +27,7 @@ def seed_everything(seed):
27
  torch.cuda.manual_seed(seed)
28
 
29
 
30
- def hex_to_rgb(hex_string, return_nearest_color=False):
31
  r"""
32
  Covert Hex triplet to RGB triplet.
33
  """
@@ -40,8 +40,8 @@ def hex_to_rgb(hex_string, return_nearest_color=False):
40
  rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41
  if return_nearest_color:
42
  nearest_color = find_nearest_color(rgb)
43
- return rgb.cuda(), nearest_color
44
- return rgb.cuda()
45
 
46
 
47
  def find_nearest_color(rgb):
@@ -56,7 +56,7 @@ def find_nearest_color(rgb):
56
  return nearest_color
57
 
58
 
59
- def font2style(font):
60
  r"""
61
  Convert the font name to the style name.
62
  """
@@ -71,7 +71,7 @@ def font2style(font):
71
  'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72
 
73
 
74
- def parse_json(json_str):
75
  r"""
76
  Convert the JSON string to attributes.
77
  """
@@ -121,7 +121,7 @@ def parse_json(json_str):
121
  if 'color' in span['attributes']:
122
  use_grad_guidance = True
123
  color_rgb, nearest_color = hex_to_rgb(
124
- span['attributes']['color'], True)
125
  if prev_color_rgb == color_rgb:
126
  prev_text_prompt = color_text_prompts[-1]
127
  color_text_prompts[-1] = prev_text_prompt + \
@@ -197,8 +197,8 @@ def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes)
197
  word_pos.append(base_tokens.index(size_token)+1)
198
  font_sizes.append(font_size)
199
  if len(word_pos) > 0:
200
- word_pos = torch.LongTensor(word_pos).cuda()
201
- font_sizes = torch.FloatTensor(font_sizes).cuda()
202
  else:
203
  word_pos = None
204
  font_sizes = None
 
27
  torch.cuda.manual_seed(seed)
28
 
29
 
30
+ def hex_to_rgb(hex_string, return_nearest_color=False, device='cuda'):
31
  r"""
32
  Covert Hex triplet to RGB triplet.
33
  """
 
40
  rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
41
  if return_nearest_color:
42
  nearest_color = find_nearest_color(rgb)
43
+ return rgb.to(device), nearest_color
44
+ return rgb.to(device)
45
 
46
 
47
  def find_nearest_color(rgb):
 
56
  return nearest_color
57
 
58
 
59
+ def font2style(font, device='cuda'):
60
  r"""
61
  Convert the font name to the style name.
62
  """
 
71
  'Akronim': 'Abstract Cubism, Pablo Picasso', }[font]
72
 
73
 
74
+ def parse_json(json_str, device):
75
  r"""
76
  Convert the JSON string to attributes.
77
  """
 
121
  if 'color' in span['attributes']:
122
  use_grad_guidance = True
123
  color_rgb, nearest_color = hex_to_rgb(
124
+ span['attributes']['color'], True, device=device)
125
  if prev_color_rgb == color_rgb:
126
  prev_text_prompt = color_text_prompts[-1]
127
  color_text_prompts[-1] = prev_text_prompt + \
 
197
  word_pos.append(base_tokens.index(size_token)+1)
198
  font_sizes.append(font_size)
199
  if len(word_pos) > 0:
200
+ word_pos = torch.LongTensor(word_pos).to(model.device)
201
+ font_sizes = torch.FloatTensor(font_sizes).to(model.device)
202
  else:
203
  word_pos = None
204
  font_sizes = None