import matplotlib.pyplot as plt
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .pr_datasets_all import FUNC_DICT

# global matplotlib settings (plt itself is not used elsewhere in this module)
plt.rcParams["figure.figsize"] = (20, 3)
plt.rcParams["figure.dpi"] = 300
plt.rcParams["savefig.dpi"] = 300


def model_fn(x, t, y=None, rule=None,
             model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
    # y must be the composer label; rule is a dummy input kept so model_fn
    # matches the shared (x, t, y, rule) interface
    y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
    if class_cond:
        if cfg:
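            # classifier-free guidance: eps = (1 + w) * eps_cond - w * eps_uncond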
            return (1 + w) * model(x, t, y) - w * model(x, t, y_null)
        else:
            return model(x, t, y)
    else:
        return model(x, t, y_null)


def dc_model_fn(x, t, y=None, rule=None,
                model=nn.Identity(), num_classes=3, class_cond=True, cfg=False, w=0.):
    # the DiffCollage score model takes input as 4 x pitch x time, so permute the
    # last two axes in and permute the output back to 4 x time x pitch
    x = x.permute(0, 1, 3, 2)
    y_null = th.tensor([num_classes] * x.shape[0], device=x.device)
    if class_cond:
        if cfg:
            eps = (1 + w) * model(x, t, y) - w * model(x, t, y_null)
            return eps.permute(0, 1, 3, 2)  # need to return 4 x time x pitch
        else:
            return model(x, t, y).permute(0, 1, 3, 2)
    else:
        return model(x, t, y_null).permute(0, 1, 3, 2)


# In all cond_fns below, y is a dummy input; rule carries the real conditioning signal.
def grad_nn_zt_xentropy(x, t, y=None, rule=None, classifier_scale=1., classifier=nn.Identity()):
    # cross-entropy cond_fn: grad_x log p(rule | x), with the classifier evaluated
    # at t = 0 (the input t is a dummy kept for the shared interface; accepting it,
    # along with classifier_scale, keeps this compatible with composite_nn_zt)
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), rule.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * classifier_scale


def grad_nn_zt_mse(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity()):
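    # MSE cond_fn on noisy z_t: grad_x of -||classifier(x, t) - rule||^2,
    # scaled by classifier_scale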
    assert rule is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
        return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale


def grad_nn_zt_chord(x, t, y=None, rule=None, classifier_scale=10., classifier=nn.Identity(), both=False):
    # chord cond_fn: the classifier predicts per-sample key logits and
    # per-timestep chord logits
    assert rule is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        key_logits, chord_logits = classifier(x_in, t)
        if both:
            # rule packs the key index in column 0 and the chord sequence after it;
            # cross_entropy needs 1-D class-index targets, hence rule[:, 0]
            rule_key = rule[:, 0]
            rule_chord = rule[:, 1:].reshape(-1)
            chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
            key_log_probs = - F.cross_entropy(key_logits, rule_key, reduction="none")
            chord_log_probs = - F.cross_entropy(chord_logits, rule_chord, reduction="none")
            chord_log_probs = chord_log_probs.reshape(x_in.shape[0], -1).mean(dim=-1)
            log_probs = key_log_probs + chord_log_probs
        else:
            rule = rule.reshape(-1)
            chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
            log_probs = - F.cross_entropy(chord_logits, rule, reduction="none")
        return th.autograd.grad(log_probs.sum(), x_in)[0] * classifier_scale


def nn_z0_chord_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity(), both=False):
    # chord log-prob on a predicted clean z_0; classifier_scale acts as a step size,
    # and the input t is a dummy that gets overridden with zeros
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    key_logits, chord_logits = classifier(x, t)
    if both:
        # same target layout as grad_nn_zt_chord: key index in column 0, chords after
        rule_key = rule[:, 0]
        rule_chord = rule[:, 1:].reshape(-1)
        chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
        key_log_probs = - F.cross_entropy(key_logits, rule_key, reduction="none")
        chord_log_probs = - F.cross_entropy(chord_logits, rule_chord, reduction="none")
        chord_log_probs = chord_log_probs.reshape(x.shape[0], -1).mean(dim=-1)
        log_probs = key_log_probs + chord_log_probs
    else:
        rule = rule.reshape(-1)
        chord_logits = chord_logits.reshape(-1, chord_logits.shape[-1])
        log_probs = - F.cross_entropy(chord_logits, rule, reduction="none")
        log_probs = log_probs.reshape(x.shape[0], -1).mean(dim=-1)
    return log_probs * classifier_scale


def nn_z0_mse_dummy(x, t, y=None, rule=None, classifier_scale=0.1, classifier=nn.Identity()):
    # MSE log-prob on a predicted clean z_0; the input t is a dummy argument
    # (supplied by wrap_model in respace) and is overridden with zeros
    assert rule is not None
    t = th.zeros(x.shape[0], device=x.device)
    logits = classifier(x, t)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs * classifier_scale


def nn_z0_mse(x, rule=None, classifier=nn.Identity()):
    # unscaled variant of nn_z0_mse_dummy: takes no t; the classifier is evaluated at t = 0
    t = th.zeros(x.shape[0], device=x.device)
    logits = classifier(x, t)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


def rule_x0_mse_dummy(x, t, y=None, rule=None, rule_name='pitch_hist'):
    # differentiable rule: score rule(x_0) so the caller can differentiate through it;
    # the input t is a dummy argument (supplied by wrap_model in respace)
    logits = FUNC_DICT[rule_name](x)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


def rule_x0_mse(x, rule=None, rule_name='pitch_hist', soft=False):
    # soften a non-differentiable rule so rule(x_0) can be differentiated through;
    # softening did not work well in practice, so callers keep soft=False
    logits = FUNC_DICT[rule_name](x, soft=soft)
    log_probs = - F.mse_loss(logits, rule, reduction="none").sum(dim=-1)
    return log_probs


class _WrappedFn:
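    # thin callable wrapper exposing the shared (x, t, y, rule) cond_fn signature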
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x, t, y=None, rule=None):
        return self.fn(x, t, y, rule)


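# Registry of guidance callbacks: grad_nn_zt_* return gradients w.r.t. the noisy z_t,
# nn_z0_* return log-probabilities evaluated on a predicted clean z_0, and
# rule_x0_* score x_0 directly through a (soft) rule from FUNC_DICT.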
function_map = {
    "grad_nn_zt_xentropy": grad_nn_zt_xentropy,
    "grad_nn_zt_mse": grad_nn_zt_mse,
    "grad_nn_zt_chord": grad_nn_zt_chord,
    "nn_z0_chord_dummy": nn_z0_chord_dummy,
    "nn_z0_mse_dummy": nn_z0_mse_dummy,
    "nn_z0_mse": nn_z0_mse,
    "rule_x0_mse_dummy": rule_x0_mse_dummy,
    "rule_x0_mse": rule_x0_mse
}


def composite_nn_zt(x, t, y=None, rule=None, fns=None, classifier_scales=None, classifiers=None, rule_names=None):
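    # sum guidance terms from several classifiers; each name in fns must map to a
    # function accepting the (x, t, y, rule, classifier_scale, classifier) interface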
    num_classifiers = len(classifiers)
    out = 0
    for i in range(num_classifiers):
        out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]],
                                    classifier_scale=classifier_scales[i], classifier=classifiers[i])
    return out


def composite_rule(x, t, y=None, rule=None, fns=None, classifier_scales=None, rule_names=None):
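    # sum rule-based scores; each name in fns must map to a function accepting
    # (x, t, y, rule, rule_name); the classifier_scales weighting is applied here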
    out = 0
    for i in range(len(fns)):
        out += function_map[fns[i]](x, t, y=y, rule=rule[rule_names[i]], rule_name=rule_names[i]) * classifier_scales[i]
    return out
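

if __name__ == "__main__":
    # Minimal smoke test (not part of the original pipeline): exercises model_fn,
    # grad_nn_zt_mse, and composite_nn_zt on toy shapes. It assumes x is a
    # batch x channel x time x pitch piano-roll tensor; _ToyScore and _ToyRegressor
    # are hypothetical stand-ins for the trained networks. Run it with
    # `python -m <package>.<this module>` (the relative import above prevents
    # direct execution of the file).
    class _ToyScore(nn.Module):
        def forward(self, x, t, y):
            return th.zeros_like(x)

    class _ToyRegressor(nn.Module):
        def forward(self, x, t):
            # pool the roll down to a fixed-size 12-dim attribute vector
            return x.mean(dim=(1, 2))[:, :12]

    B, C, T, P = 2, 1, 16, 128
    x = th.randn(B, C, T, P)
    t = th.randint(0, 1000, (B,))
    y = th.zeros(B, dtype=th.long)
    rule = th.randn(B, 12)

    # classifier-free guidance path returns an eps with the input shape
    eps = model_fn(x, t, y=y, model=_ToyScore(), num_classes=3,
                   class_cond=True, cfg=True, w=2.)
    assert eps.shape == x.shape

    # gradient-based guidance w.r.t. the noisy input
    clf = _ToyRegressor()
    g = grad_nn_zt_mse(x, t, rule=rule, classifier_scale=1., classifier=clf)
    assert g.shape == x.shape

    # composite guidance over a single (hypothetical) 'pitch_hist' rule
    out = composite_nn_zt(x, t, rule={"pitch_hist": rule},
                          fns=["grad_nn_zt_mse"], classifier_scales=[1.],
                          classifiers=[clf], rule_names=["pitch_hist"])
    assert out.shape == x.shape
    print("cond_fn smoke test passed")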