File size: 3,724 Bytes
786a869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import clear_output
import seaborn as sns 

class WaveformVisualizer:
    """Plot waveforms before and after they pass through a processor.

    Only the first item of the batch (index 0) is ever drawn.

    Args:
        processor: Callable invoked as ``processor(input_data)``. Its result
            is indexed with ``[0]`` and converted via ``detach()``, so it is
            expected to return a tensor batch. The first item is plotted
            against the same time axis as the input, so it presumably has
            the same number of samples -- confirm against the processor.
        input_data: Tensor whose second dimension is the sample axis
            (``input_data.shape[1]`` samples per waveform).
        sampling_rate: Samples per second. Used to build the time axis and
            passed to the spectrogram as ``Fs``.
    """

    def __init__(self, processor, input_data, sampling_rate=1000):
        self.processor = processor
        self.input_data = input_data
        self.sampling_rate = sampling_rate
        # Time stamp (in seconds) of every sample along dimension 1.
        self.time = np.arange(input_data.shape[1]) / sampling_rate

    @staticmethod
    def _to_numpy(tensor):
        """Return *tensor* as a NumPy array, detached from autograd and on CPU."""
        # .cpu() is a no-op for CPU tensors and makes CUDA tensors convertible.
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _padded_limits(values, margin=0.2):
        """Return ``(low, high)`` axis limits that enclose *values* with padding.

        Padding is ``margin / 2`` of the data range on each side. For a signal
        symmetric around zero this matches scaling min and max by
        ``1 + margin``, but unlike plain scaling it never cuts into data that
        is entirely positive or entirely negative.
        """
        low = float(np.min(values))
        high = float(np.max(values))
        span = high - low
        if span == 0:
            # Constant signal: fall back to its magnitude (or 1.0 for zeros)
            # so the two limits never coincide.
            span = abs(low) or 1.0
        pad = 0.5 * margin * span
        return low - pad, high + pad

    @staticmethod
    def _frame_end_index(frame, frames, n_samples):
        """Return how many samples are visible at 0-based *frame* of *frames*.

        The final frame (``frames - 1``) maps to ``n_samples`` so the
        animation ends with the complete waveform drawn.
        """
        return (frame + 1) * n_samples // frames

    def plot_waveforms(self):
        """Draw a 2x2 figure: waveforms on top, spectrograms below.

        The left column shows the first input waveform, the right column the
        first processed waveform.

        Returns:
            The created matplotlib figure.
        """
        processed_data = self.processor(self.input_data)

        fig = plt.figure(figsize=(15, 10))
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

        ax1 = fig.add_subplot(gs[0, 0])
        self._plot_waveform(self.input_data[0], ax1, "No")

        ax2 = fig.add_subplot(gs[0, 1])
        self._plot_waveform(processed_data[0], ax2, "No")

        # The original left this panel empty; fill it with the input
        # spectrogram so it mirrors the processed one next to it.
        ax3 = fig.add_subplot(gs[1, 0])
        self._plot_spectrogram(self.input_data[0], ax3, "No")

        ax4 = fig.add_subplot(gs[1, 1])
        self._plot_spectrogram(processed_data[0], ax4, "No")

        plt.tight_layout()
        return fig

    def _plot_waveform(self, data, ax, title):
        """Plot one waveform tensor against ``self.time`` on *ax*."""
        data_np = self._to_numpy(data)
        ax.plot(self.time, data_np, 'b-', linewidth=1)
        ax.set_title(title)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        ax.grid(True)

    def _plot_spectrogram(self, data, ax, title):
        """Plot the spectrogram of one waveform tensor on *ax*."""
        data_np = self._to_numpy(data)
        ax.specgram(data_np, Fs=self.sampling_rate, cmap='viridis')
        ax.set_title(title)
        # specgram puts time on the x axis and frequency on the y axis.
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Frequency (Hz)')

    def animate_processing(self, frames=50):
        """Animate the original and processed waveforms being drawn over time.

        Args:
            frames: Number of animation frames; each frame reveals a further
                ``1 / frames`` of the samples.

        Returns:
            The ``FuncAnimation`` object. Keep a reference to it for as long
            as the animation should run.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

        processed_data = self.processor(self.input_data)
        data_original = self._to_numpy(self.input_data[0])
        data_processed = self._to_numpy(processed_data[0])
        n_samples = len(self.time)

        line1, = ax1.plot([], [], 'b-', label='Original')
        line2, = ax2.plot([], [], 'r-', label='Processed')

        def init():
            ax1.set_xlim(0, self.time[-1])
            ax1.set_ylim(*self._padded_limits(data_original))
            ax2.set_xlim(0, self.time[-1])
            ax2.set_ylim(*self._padded_limits(data_processed))

            ax1.set_title('Do not')
            ax2.set_title('Do not')
            ax1.grid(True)
            ax2.grid(True)
            ax1.legend()
            ax2.legend()

            return line1, line2

        def animate(frame):
            idx = self._frame_end_index(frame, frames, n_samples)
            line1.set_data(self.time[:idx], data_original[:idx])
            line2.set_data(self.time[:idx], data_processed[:idx])
            return line1, line2

        anim = FuncAnimation(fig, animate, frames=frames,
                             init_func=init, blit=True,
                             interval=50)
        plt.tight_layout()
        return anim


# Demo entry point. This used to be indented below ``return anim`` inside
# ``animate_processing`` and was therefore unreachable.
if __name__ == "__main__":
    input_size = 1000
    batch_size = 32
    # 1000 samples at 100 Hz span 10 s; the same rate is handed to the
    # visualizer so the time and frequency axes match the generated signal.
    sampling_rate = 100

    t = np.arange(input_size) / sampling_rate
    base_signal = np.sin(2 * np.pi * 1 * t) + 0.5 * np.sin(2 * np.pi * 2 * t)
    noise = np.random.normal(0, 0.1, input_size)
    signal = base_signal + noise

    input_data = torch.tensor(np.tile(signal, (batch_size, 1)), dtype=torch.float32)

    # NOTE(review): SecureWaveformProcessor is neither defined nor imported in
    # the visible part of this file -- it must be provided before this runs.
    processor = SecureWaveformProcessor(input_size=input_size, hidden_size=64)

    visualizer = WaveformVisualizer(processor, input_data, sampling_rate=sampling_rate)

    fig_static = visualizer.plot_waveforms()
    plt.show()

    anim = visualizer.animate_processing()
    plt.show()