"""k-diffusion transformer diffusion models, version 2.
Codes adopted from https://github.com/crowsonkb/k-diffusion
"""
from contextlib import contextmanager
import math
import threading

# Thread-local storage for the active FLOP counter, so counting in one
# thread does not interfere with counting in another.
state = threading.local()
state.flop_counter = None
@contextmanager
def flop_counter(enable=True):
    """Install a FlopCounter for the duration of the `with` block."""
    try:
        # getattr() avoids an AttributeError in threads other than the one
        # that initialized `state`.
        old_flop_counter = getattr(state, "flop_counter", None)
        state.flop_counter = FlopCounter() if enable else None
        yield state.flop_counter
    finally:
        state.flop_counter = old_flop_counter
class FlopCounter:
    """Records (formula, shapes) pairs and sums them lazily on demand."""

    def __init__(self):
        self.ops = []

    def op(self, op, *args, **kwargs):
        self.ops.append((op, args, kwargs))

    @property
    def flops(self):
        flops = 0
        for op, args, kwargs in self.ops:
            flops += op(*args, **kwargs)
        return flops
def op(op, *args, **kwargs):
    """Record an op with the active counter, if any; otherwise a no-op."""
    if getattr(state, "flop_counter", None):
        state.flop_counter.op(op, *args, **kwargs)
def op_linear(x, weight):
    # One multiply-add per element of the (..., in_features) input shape `x`
    # per output feature; `weight` is the (out_features, in_features) shape.
    return math.prod(x) * weight[0]


def op_attention(q, k, v):
    # QK^T costs prod(b) * s_q * s_k * d_q multiply-adds and attn @ V costs
    # prod(b) * s_q * s_k * d_v, assuming matching batch dims across q, k, v.
    *b, s_q, d_q = q
    *_, s_k, d_k = k
    *_, s_v, d_v = v
    return math.prod(b) * s_q * s_k * (d_q + d_v)


def op_natten(q, k, v, kernel_size):
    # Neighborhood attention: each query attends only to a kernel_size x
    # kernel_size window, for both the QK^T and attn @ V matmuls.
    *q_rest, d_q = q
    *_, d_v = v
    return math.prod(q_rest) * (d_q + d_v) * kernel_size**2