Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,747 Bytes
82635c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import requests
from datetime import datetime,timedelta
import re
# Module-level store filled by the forward hooks created below: maps a
# module's qualified name to the most recent ``attn_map`` its processor exposed.
attn_maps = dict()


def hook_fn(name):
    """Build a forward hook that harvests a processor's ``attn_map``.

    :param name: Key under which the captured map is stored in ``attn_maps``.
    :return: A callable suitable for ``nn.Module.register_forward_hook``.
        When the hooked module's ``processor`` carries an ``attn_map``
        attribute, the hook moves it into ``attn_maps[name]`` and deletes it
        from the processor; otherwise it does nothing. The hook always returns
        ``None``, so the module's output is left untouched.
    """
    def forward_hook(module, input, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return None
        attn_maps[name] = processor.attn_map
        # Drop the reference on the processor so it is not reused on a later call.
        del processor.attn_map
        return None

    return forward_hook
def register_cross_attention_hook(unet):
    """Attach a ``hook_fn`` forward hook to every ``attn2*`` submodule of *unet*.

    A submodule is selected when the last dotted component of its qualified
    name starts with ``"attn2"``; its hook is keyed by the full qualified name.

    :param unet: Any module exposing ``named_modules()``; hooks are added in place.
    :return: The same *unet* object, to allow chaining.
    """
    cross_attn_layers = [
        (qualified_name, submodule)
        for qualified_name, submodule in unet.named_modules()
        if qualified_name.rpartition(".")[2].startswith("attn2")
    ]
    for qualified_name, submodule in cross_attn_layers:
        submodule.register_forward_hook(hook_fn(qualified_name))
    return unet
def upscale(attn_map, target_size, *, vae_scale_factor=8, max_levels=5):
    """Resize one layer's attention map to pixel resolution.

    :param attn_map: Tensor that is averaged over dim 0; the result must be
        2-D ``(spatial, second_axis)``. Presumably ``(heads, h*w, tokens)``
        from an attention processor -- TODO confirm against the processor
        that sets ``attn_map``.
    :param target_size: ``(height, width)`` of the output.
    :param vae_scale_factor: Downsampling factor between ``target_size`` and
        the finest attention grid (default 8, the previously hard-coded value).
    :param max_levels: Number of resolutions ``2**i`` (``i < max_levels``)
        tried when inferring the map's spatial grid (default 5).
    :return: ``float32`` tensor of shape ``(second_axis, height, width)``,
        softmax-normalised along its first axis.
    :raises ValueError: If no level's grid matches the map's spatial length.
    """
    # (spatial, second_axis) -> (second_axis, spatial)
    attn_map = torch.mean(attn_map, dim=0).permute(1, 0)
    spatial_len = attn_map.shape[1]
    height, width = target_size[0], target_size[1]

    grid_size = None
    for level in range(max_levels):
        scale = 2 ** level
        # At this level the pixel area equals the grid area times factor**2.
        if (height // scale) * (width // scale) == spatial_len * vae_scale_factor ** 2:
            cell = scale * vae_scale_factor
            grid_size = (height // cell, width // cell)
            break
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if grid_size is None:
        raise ValueError(
            f"cannot infer the spatial grid of an attention map with {spatial_len} "
            f"positions for target size {tuple(target_size)}"
        )

    attn_map = attn_map.view(attn_map.shape[0], *grid_size)
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )[0]
    return torch.softmax(attn_map, dim=0)
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average every captured attention map into a single pixel-space map.

    Each value in the module-level ``attn_maps`` is split into ``batch_size``
    chunks along dim 0. Chunk 0 is taken when ``instance_or_negative`` is true,
    otherwise chunk 1 (presumably the unconditional / conditional halves under
    classifier-free guidance -- confirm with the caller). The chunk is squeezed,
    resized with ``upscale`` to ``image_size``, and the per-layer results are
    averaged element-wise.

    :param image_size: Target ``(height, width)`` forwarded to ``upscale``.
    :param batch_size: Number of chunks each stored map is split into.
    :param instance_or_negative: Selects chunk 0 instead of chunk 1.
    :param detach: When true, each map is moved to CPU first. Note this only
        calls ``.cpu()``; it does not call ``.detach()``.
    :return: Element-wise mean of the upscaled maps. Raises if ``attn_maps``
        is empty, because ``torch.stack`` needs at least one tensor.
    """
    chunk_index = 0 if instance_or_negative else 1
    upscaled_layers = []
    for layer_map in attn_maps.values():
        if detach:
            layer_map = layer_map.cpu()
        batch_chunks = torch.chunk(layer_map, batch_size)
        upscaled_layers.append(upscale(batch_chunks[chunk_index].squeeze(), image_size))
    stacked = torch.stack(upscaled_layers, dim=0)
    return torch.mean(stacked, dim=0)
def attnmaps2images(net_attn_maps):
    """Render each 2-D attention map as an 8-bit grayscale PIL image.

    Every map is min-max normalised to ``[0, 255]`` and truncated to ``uint8``.
    A map with zero range (constant) or NaNs is rendered all black. Before this
    fix, such a map divided by zero and cast NaN to ``uint8``, which is undefined.

    :param net_attn_maps: Iterable of tensors supporting ``.cpu().numpy()``;
        presumably the stacked result of ``get_net_attn_map``, iterated along
        its first axis so each item is ``(H, W)``.
    :return: List of ``PIL.Image.Image``, one per map, in input order.
    """
    images = []
    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()
        lo = np.min(attn_map)
        value_range = np.max(attn_map) - lo
        if value_range > 0:
            normalized_attn_map = (attn_map - lo) / value_range * 255
        else:
            # Zero (or NaN) range: nothing to normalise; avoid the 0/0 -> NaN cast.
            normalized_attn_map = np.zeros(attn_map.shape)
        images.append(Image.fromarray(normalized_attn_map.astype(np.uint8)))
    return images
def is_torch2_available():
    """Return True when ``torch.nn.functional.scaled_dot_product_attention`` exists.

    The attribute's presence serves as a feature probe for PyTorch 2.x
    instead of parsing ``torch.__version__``.
    """
    has_sdpa = hasattr(F, "scaled_dot_product_attention")
    return has_sdpa
class RemoteJson:
    """Fetch a JSON document from a URL on demand and cache it for a fixed interval."""

    def __init__(self, url, refresh_gap_seconds=3600, processor=None, timeout=10):
        """
        Initialize the RemoteJson cache.
        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback applied to freshly loaded JSON; its return
            value is what ``get()`` hands out.
        :param timeout: Seconds passed to ``requests.get`` so that a stalled server
            cannot block ``get()`` indefinitely. ``None`` restores the old
            wait-forever behaviour.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.timeout = timeout
        self.json_data = None  # last successfully loaded (and processed) payload
        self.last_updated = None  # datetime of the last successful load, or None

    def _load_json(self):
        """
        Load JSON from the remote URL. If the request fails or the body is not
        valid JSON, print the error and return None.
        """
        try:
            response = requests.get(self.url, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError: older requests versions raise a plain JSON decode
            # error from .json() that is not a RequestException subclass.
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if self.last_updated is None:
            return True  # never loaded successfully: always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL. If it fails, keep the previous data.

        An empty document (``{}``/``[]``) counts as a successful load; only a
        ``None`` result from ``_load_json`` (failure, or a literal JSON ``null``)
        keeps the previous data. The processor runs before the cache is updated,
        so if it raises, the previous data and timestamp stay intact and the
        exception propagates.
        """
        new_json = self._load_json()
        if new_json is None:
            print("Failed to update JSON. Keeping the previous version.")
            return
        if self.processor:
            new_json = self.processor(new_json)
        self.json_data = new_json
        self.last_updated = datetime.now()
        print("JSON updated successfully.")

    def get(self):
        """
        Get the JSON, checking whether it needs to be refreshed.
        If refresh is required, it fetches the new data and applies the processor.
        Returns None if no load has ever succeeded.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")
        return self.json_data
def extract_key_value_pairs(input_string):
    """Find every ``[key:value]`` tag in *input_string*.

    The key is everything up to the FIRST ``:`` inside the brackets and the
    value is the rest up to the closing ``]``. Values may therefore contain
    ``:`` themselves (e.g. ``[img:http://host/a.png]``). Neither part may
    contain ``]`` or be empty.

    :return: One dict per tag, in order of appearance, with ``"key"``,
        ``"value"`` and ``"raw"`` (the full bracketed text as it appeared).
    """
    # The key excludes ':' so the split happens at the first colon; a greedy
    # key would split at the last colon and truncate values such as URLs.
    pattern = r"\[([^\]:]+):([^\]]+)\]"
    return [
        {"key": match.group(1), "value": match.group(2), "raw": match.group(0)}
        for match in re.finditer(pattern, input_string)
    ]
def extract_characters(prefix, input_string):
    """Collect ``<prefix>name`` placeholders from *input_string*.

    A placeholder is *prefix* followed by one or more characters other than
    whitespace, ``,`` and ``$``. It must be followed by whitespace, a comma,
    or the end of the string. *prefix* is matched literally, e.g. ``"@"`` in
    ``"@alice, @bob"``.

    :return: One dict per placeholder, in order of appearance, with ``"raw"``
        (prefix + name) and ``"key"`` (name only).
    """
    # re.escape: prefixes such as "$", "+" or "." are regex metacharacters and
    # would otherwise match nothing or raise re.error.
    pattern = rf"{re.escape(prefix)}([^\s,$]+)(?=\s|,|$)"
    return [{"raw": f"{prefix}{key}", "key": key} for key in re.findall(pattern, input_string)]