import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import requests
from datetime import datetime, timedelta
import re

# Global store for cross-attention maps captured by the forward hooks below.
attn_maps = {}

def hook_fn(name):
    def forward_hook(module, input, output):
        # The attention processor stashes its map on itself as `attn_map`;
        # move it into the global store and delete it to free memory.
        if hasattr(module.processor, "attn_map"):
            attn_maps[name] = module.processor.attn_map
            del module.processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    # In diffusers' UNet, cross-attention modules are named `attn2`
    # (`attn1` is self-attention), so hook only those.
    for name, module in unet.named_modules():
        if name.split('.')[-1].startswith('attn2'):
            module.register_forward_hook(hook_fn(name))

    return unet
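
# Example usage (a sketch, not a documented API): assuming a loaded diffusers
# pipeline `pipe` whose attention processors stash an `attn_map` attribute on
# themselves, register the hooks once before inference:
#   pipe.unet = register_cross_attention_hook(pipe.unet)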

def upscale(attn_map, target_size):
    # Average over heads, then put tokens on the first axis: (tokens, pixels).
    attn_map = torch.mean(attn_map, dim=0)
    attn_map = attn_map.permute(1, 0)
    temp_size = None

    # Recover the latent resolution this map was computed at. Attention runs
    # at several downsampled scales inside the UNet, and the VAE contributes
    # another factor of 8 between latent and image space.
    for i in range(5):
        scale = 2 ** i
        if (target_size[0] // scale) * (target_size[1] // scale) == attn_map.shape[1] * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    assert temp_size is not None, "temp_size must not be None"

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    # Normalize across tokens so each pixel's scores sum to 1.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map

def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    # Under classifier-free guidance the batch is typically [negative, positive];
    # pick chunk 0 for the negative half or chunk 1 for the conditional half.
    idx = 0 if instance_or_negative else 1
    net_attn_maps = []

    for name, attn_map in attn_maps.items():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        attn_map = upscale(attn_map, image_size)
        net_attn_maps.append(attn_map)

    # Average the per-layer maps into a single (tokens, H, W) tensor.
    net_attn_maps = torch.mean(torch.stack(net_attn_maps, dim=0), dim=0)

    return net_attn_maps
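
# Example usage (a sketch with assumed names): after running the pipeline with
# the hooks registered, aggregate the captured maps at the output resolution:
#   attn_map = get_net_attn_map((512, 512))  # shape: (num_tokens, 512, 512)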

def attnmaps2images(net_attn_maps):
    images = []

    for attn_map in net_attn_maps:
        attn_map = attn_map.cpu().numpy()

        # Min-max normalize to [0, 255]; the small epsilon guards against a
        # constant map, which would otherwise divide by zero.
        normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map) + 1e-8) * 255
        normalized_attn_map = normalized_attn_map.astype(np.uint8)
        image = Image.fromarray(normalized_attn_map)
        images.append(image)

    return images
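
# Example usage (a sketch): save one grayscale heatmap per text token.
#   for i, img in enumerate(attnmaps2images(attn_map)):
#       img.save(f"token_{i}.png")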

def is_torch2_available():
    # torch >= 2.0 ships F.scaled_dot_product_attention (fused attention).
    return hasattr(F, "scaled_dot_product_attention")


class RemoteJson:
    def __init__(self, url, refresh_gap_seconds=3600, processor=None):
        """
        Initialize the RemoteJson manager.
        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback to process the JSON after it loads successfully.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.json_data = None
        self.last_updated = None

    def _load_json(self):
        """
        Load JSON from the remote URL. If loading fails, return None.
        """
        try:
            # A timeout keeps a stalled request from hanging the caller.
            response = requests.get(self.url, timeout=10)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if not self.last_updated:
            return True  # If no last update, always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL. If it fails, keep the previous data.
        """
        new_json = self._load_json()
        # Compare against None explicitly so that empty-but-valid payloads
        # (e.g. {} or []) are still accepted as updates.
        if new_json is not None:
            self.json_data = new_json
            self.last_updated = datetime.now()
            print("JSON updated successfully.")
            if self.processor:
                self.json_data = self.processor(self.json_data)
        else:
            print("Failed to update JSON. Keeping the previous version.")

    def get(self):
        """
        Get the JSON, checking whether it needs to be refreshed.
        If refresh is required, it fetches the new data and applies the processor.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")

        return self.json_data
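
# Example usage (a sketch; the URL is a placeholder, not a real endpoint):
#   styles = RemoteJson("https://example.com/styles.json", refresh_gap_seconds=600)
#   data = styles.get()  # fetched once, then served from cache for 10 minutes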

def extract_key_value_pairs(input_string):
    # Match [xxx:yyy] pairs, where yyy may contain special characters.
    pattern = r"\[([^\]]+):([^\]]+)\]"

    # Collect each match along with the raw matched substring.
    matches = re.finditer(pattern, input_string)
    result = [{"key": match.group(1), "value": match.group(2), "raw": match.group(0)} for match in matches]

    return result
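
# Example (a sketch): extract_key_value_pairs("[style:anime] portrait of a girl")
# returns [{"key": "style", "value": "anime", "raw": "[style:anime]"}].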

def extract_characters(prefix, input_string):
    # Match placeholders that start with `prefix` and end at whitespace, a
    # comma, or the end of the string. Note that `$` inside the character
    # class is a literal dollar sign, not an end-of-string anchor. The prefix
    # is escaped so regex metacharacters in it are treated literally.
    pattern = rf"{re.escape(prefix)}([^\s,$]+)(?=\s|,|$)"

    # Find all matches in the input string.
    matches = re.findall(pattern, input_string)

    # Return a list of dictionaries with the extracted placeholders.
    result = [{"raw": f"{prefix}{match}", "key": match} for match in matches]

    return result
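
# Example (a sketch): extract_characters("@", "a photo of @alice and @bob, smiling")
# returns [{"raw": "@alice", "key": "alice"}, {"raw": "@bob", "key": "bob"}].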