File size: 4,551 Bytes
9016314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()


def visualize(s, batch, prefix):
    """Dispatch to the visualizer that matches the rank of ``s``.

    The number of dimensions of ``s`` selects the routine:

    * 5 -> ``im_visualize`` using ``batch['x']``, ``batch['b']``, ``batch['m']``
    * 3 -> ``pc_visualize`` using ``batch['x']``, ``batch['b']``, ``batch['m']``
    * 4 -> ``fn_visualize`` using ``batch['xc']``, ``batch['yc']``,
      ``batch['xt']``, ``batch['yt']``

    Args:
        s: Array-like exposing a ``shape`` attribute.
        batch: Mapping holding the keys listed above for the selected branch.
            Only the keys of the selected branch are read.
        prefix: Output path prefix forwarded to the selected visualizer.

    Raises:
        ValueError: If ``s`` does not have 3, 4 or 5 dimensions.
    """
    rank = len(s.shape)
    if rank == 5:
        im_visualize(s, batch['x'], batch['b'], batch['m'], prefix)
    elif rank == 3:
        pc_visualize(s, batch['x'], batch['b'], batch['m'], prefix)
    elif rank == 4:
        fn_visualize(s, batch['xc'], batch['yc'], batch['xt'], batch['yt'], prefix)
    else:
        # Name the offending shape so a bad call is diagnosable from the
        # traceback alone.
        raise ValueError(
            f'visualize() expects samples with 3, 4 or 5 dimensions, '
            f'got shape {tuple(s.shape)}')

def im_visualize(s, x, b, m, prefix):
    """Save one tiled comparison image per batch element.

    ``s`` is unpacked as ``(B, N, H, W, C)``. For each of the ``B`` elements,
    the ``N`` images are laid side by side along the width, and three such
    strips are stacked vertically: ``x`` blended with ``m``, ``x`` blended
    with ``b`` (positions where the mask is 0 become 128), and ``s`` itself.
    The result is written to ``{prefix}_{i}.png``.

    When the last axis has size 2 the data is treated as k-space: the two
    channels are combined as real + imaginary parts, inverse-shifted,
    inverse-FFT'd, and the magnitude (scaled by 255) is used as a single
    channel image. Only the first channel of ``b`` and ``m`` is kept then.

    Args:
        s: 5-D array of samples.
        x: Array indexed like ``s`` (presumably the ground-truth images --
            confirm against caller).
        b: Mask array indexed like ``s``; multiplies ``x`` in the second strip.
        m: Mask array indexed like ``s``; multiplies ``x`` in the first strip.
        prefix: Output path prefix.
    """
    B, N, H, W, C = s.shape

    def _kspace_to_pixels(kspace):
        # real/imag channels -> complex -> undo shift -> inverse 2-D FFT ->
        # magnitude with a trailing singleton channel, scaled into uint8.
        complex_data = kspace[..., 0] + kspace[..., 1] * 1j
        unshifted = np.fft.ifftshift(complex_data, axes=(-2, -1))
        magnitude = np.expand_dims(np.absolute(np.fft.ifft2(unshifted)), axis=-1)
        return np.array(magnitude * 255, dtype=np.uint8)

    for i in range(B):
        sample, image, mask_b, mask_m = s[i], x[i], b[i], m[i]
        channels = C
        if sample.shape[-1] == 2:  # kspace
            channels = 1
            sample = _kspace_to_pixels(sample)
            image = _kspace_to_pixels(image)
            mask_b = mask_b[..., 0:1]
            mask_m = mask_m[..., 0:1]

        # (N, H, W, C) -> (H, N, W, C) -> (H, W*N, C): tile the N images
        # horizontally, then drop singleton axes.
        sample, image, mask_b, mask_m = (
            np.transpose(arr, [1, 0, 2, 3]).reshape(H, W * N, channels).squeeze()
            for arr in (sample, image, mask_b, mask_m)
        )

        strip_m = image * mask_m + (1 - mask_m) * 128
        strip_b = image * mask_b + (1 - mask_b) * 128
        stacked = np.concatenate([strip_m, strip_b, sample], axis=0)

        plt.imsave(f'{prefix}_{i}.png', stacked.astype(np.uint8))

def pc_visualize(s, x, b, m, prefix):
    """Save a three-panel 3-D scatter figure per batch element.

    ``s`` must be 3-D; its leading axis is the batch. For element ``i`` the
    panels show, left to right: every point of ``x[i]``, the points of
    ``x[i]`` whose row in ``b[i]`` has first column equal to 1, and every
    point of ``s[i]``. The first three columns are used as coordinates.
    The figure is written to ``{prefix}_{i}.png`` and all pyplot figures are
    closed afterwards.

    Args:
        s: 3-D array of sampled points.
        x: Array of reference points, indexed like ``s``.
        b: Array whose first column flags the rows of ``x`` to show in the
            middle panel.
        m: Unused; accepted for signature parity with ``im_visualize``.
        prefix: Output path prefix.
    """
    batch_size, _, _ = s.shape
    every_row = slice(None)
    for i in range(batch_size):
        sampled, points, flags = s[i], x[i], b[i]
        flagged_rows = np.where(flags[:, 0] == 1)[0]
        fig = plt.figure(figsize=(7.5, 2.5))
        panels = (
            (131, points, every_row),
            (132, points, flagged_rows),
            (133, sampled, every_row),
        )
        for position, cloud, rows in panels:
            ax = fig.add_subplot(position, projection='3d')
            ax.scatter(cloud[rows, 0], cloud[rows, 1], cloud[rows, 2], c='g', s=5)
            ax.axis('off')
            ax.grid(False)
        plt.savefig(f'{prefix}_{i}.png')
        plt.close('all')

def fn_visualize(s, xc, yc, xt, yt, prefix):
    """Plot sampled values next to context and target points.

    ``s`` must be 4-D and is unpacked as ``(B, K, N, C)``. For each of the
    ``B`` batch elements a figure with ``K`` vertically stacked subplots is
    written to ``{prefix}_{i}.png``. Subplot ``k`` shows:

    * ``(xc, yc)`` as red crosses,
    * ``(xt, yt)`` as black dots,
    * ``(xt, s)`` as blue dots.

    All pyplot figures are closed after each save.

    Args:
        s: 4-D array of sampled values.
        xc, yc: Arrays indexed ``[i][k]`` (presumably context inputs and
            outputs -- confirm against caller).
        xt, yt: Arrays indexed ``[i][k]`` (presumably target inputs and
            outputs -- confirm against caller).
        prefix: Output path prefix.
    """
    num_items, num_rows, _, _ = s.shape
    for item in range(num_items):
        samples = s[item]
        ctx_x, ctx_y = xc[item], yc[item]
        tgt_x, tgt_y = xt[item], yt[item]
        figure = plt.figure(figsize=(4.0, 2.5 * num_rows))
        for row in range(num_rows):
            axes = figure.add_subplot(num_rows, 1, row + 1)
            axes.plot(ctx_x[row], ctx_y[row], 'rx', markersize=8)
            axes.plot(tgt_x[row], tgt_y[row], 'ko', markersize=3)
            axes.plot(tgt_x[row], samples[row], 'bo', markersize=3)
        plt.savefig(f'{prefix}_{item}.png')
        plt.close('all')


def plot_functions(m, s, batch, prefix):
    """Plot a curve with a shaded ``m - s`` .. ``m + s`` band per function.

    ``m`` must be 4-D and is unpacked as ``(B, K, N, C)``; only channel 0 of
    every array is used. For each batch element a figure with ``K`` stacked
    subplots is written to ``{prefix}_{i}.png``. Subplot ``k`` shows
    ``(xc, yc)`` as red crosses, ``(xt, yt)`` as black dots, ``m`` as a blue
    line over ``xt`` sorted ascending, and the band between ``m - s`` and
    ``m + s``. All pyplot figures are closed after each save.

    Args:
        m: 4-D array drawn as the line (presumably a predictive mean --
            confirm against caller).
        s: Array indexed like ``m`` giving the half-width of the band
            (presumably a standard deviation -- confirm against caller).
        batch: Mapping with keys ``'xc'``, ``'yc'``, ``'xt'``, ``'yt'``.
        prefix: Output path prefix.
    """
    num_items, num_rows, _, _ = m.shape
    xc, yc, xt, yt = batch['xc'], batch['yc'], batch['xt'], batch['yt']
    for item in range(num_items):
        center = m[item, :, :, 0]
        spread = s[item, :, :, 0]
        ctx_x = xc[item, :, :, 0]
        ctx_y = yc[item, :, :, 0]
        tgt_x = xt[item, :, :, 0]
        tgt_y = yt[item, :, :, 0]
        fig = plt.figure(figsize=(4.0, 2.5 * num_rows))
        for row in range(num_rows):
            # Sort by x so the line and band are drawn left to right.
            order = np.argsort(tgt_x[row])
            ax = fig.add_subplot(num_rows, 1, row + 1)
            ax.plot(ctx_x[row], ctx_y[row], 'rx', markersize=8)
            ax.plot(tgt_x[row], tgt_y[row], 'ko', markersize=3)
            ax.plot(tgt_x[row, order], center[row, order], 'b', linewidth=2)
            lower = center[row, order] - spread[row, order]
            upper = center[row, order] + spread[row, order]
            ax.fill_between(
                tgt_x[row, order],
                lower,
                upper,
                alpha=0.2,
                facecolor='#65c9f7',
                interpolate=True)
        plt.savefig(f'{prefix}_{item}.png')
        plt.close('all')

def plot_img_functions(m, s, batch, prefix, height=28, width=28):
    """Save observed / predicted / target image strips per batch element.

    ``m`` must be 4-D and is unpacked as ``(B, K, N, C)``. Values are mapped
    to pixel intensities with ``(v + 0.5) * 255`` (this presumably assumes
    inputs in ``[-0.5, 0.5]`` -- confirm against the data pipeline). For each
    batch element, three strips of ``K`` images laid side by side are stacked
    vertically and written to ``{prefix}_{i}.png``:

    1. the observed pixels: a 128-filled canvas with ``yc`` written at the
       positions ``idx`` along the third axis,
    2. ``m``,
    3. ``yt``.

    Intensities are clipped to ``[0, 255]`` before the ``uint8`` cast so that
    out-of-range values saturate instead of wrapping around.

    Args:
        m: 4-D array of values to render in the middle strip.
        s: Unused; accepted for signature parity with ``plot_functions``.
        batch: Mapping with keys ``'idx'``, ``'yc'`` and ``'yt'``.
        prefix: Output path prefix.
        height: Image height in pixels; each function must hold
            ``height * width`` values. Defaults to 28.
        width: Image width in pixels. Defaults to 28.
    """
    B, K, _, _ = m.shape
    idx, yc, yt = batch['idx'], batch['yc'], batch['yt']

    # Start from a mid-grey canvas and fill in only the observed positions.
    observed = np.ones_like(yt) * 128
    observed[:, :, idx] = (yc + 0.5) * 255.
    target = (yt + 0.5) * 255.
    predicted = (m + 0.5) * 255.

    def _to_strip(values):
        # Saturate before casting: float -> uint8 conversion of out-of-range
        # values is not well defined and otherwise shows up as wrapped pixels.
        pixels = np.clip(values, 0, 255).astype(np.uint8)
        images = np.reshape(pixels, [K, height, width])
        # (K, H, W) -> (H, K, W) -> (H, K*W): lay the K images side by side.
        return np.reshape(np.transpose(images, [1, 0, 2]), [height, K * width])

    for i in range(B):
        strips = [_to_strip(observed[i]), _to_strip(predicted[i]), _to_strip(target[i])]
        img = np.concatenate(strips, axis=0)

        plt.imsave(f'{prefix}_{i}.png', img)