File size: 3,721 Bytes
d74eb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7709ffd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Optional
# from datetime import datetime

def plot_stacked_time_series(time_series_df: pd.DataFrame, title: str) -> plt.Figure:
    """
    Create a stacked time series plot: one subplot per column, shared x-axis.

    All subplots use the same y-axis limits (the global minimum and maximum
    of the frame, padded by 10% of the value range) so the series can be
    compared against each other.

    Args:
        time_series_df: Frame with one column per series. Each column is
            drawn in its own subplot; the frame's index supplies the
            x-tick labels of the bottom subplot.
        title: Figure-level title.

    Returns:
        The matplotlib Figure containing the stacked subplots.

    Raises:
        ValueError: If ``time_series_df`` has no columns.
    """
    n_series = len(time_series_df.columns)
    if n_series == 0:
        # Fail early with a clear message instead of an opaque matplotlib
        # error about a zero-row grid.
        raise ValueError("time_series_df must contain at least one column")

    # squeeze=False keeps `axes` two-dimensional even for a single series,
    # so no special case is needed when indexing it.
    fig, axes = plt.subplots(
        n_series, 1, figsize=(12, 2 * n_series), sharex=True, squeeze=False
    )
    fig.suptitle(title, y=1.02, fontsize=14)

    # Calculate y-axis limits for consistent scaling
    y_min = time_series_df.min().min()
    y_max = time_series_df.max().max()
    y_buffer = (y_max - y_min) * 0.1
    if y_buffer == 0:
        # Constant data: identical lower/upper limits would make the axis
        # transform singular, so fall back to a non-zero padding.
        y_buffer = abs(y_max) * 0.1 or 1.0

    x_positions = range(len(time_series_df))
    last_idx = n_series - 1

    # items() yields exactly one Series per column, even when column labels
    # are duplicated.
    for idx, (column, series) in enumerate(time_series_df.items()):
        ax = axes[idx, 0]

        # Plot the line with smaller markers
        ax.plot(x_positions, series, marker='o', markersize=3, linewidth=1.5)

        # Set title and labels
        ax.set_title(column, pad=5, fontsize=10)
        ax.set_ylabel('Popularity', fontsize=9)

        # Set consistent y-axis limits with buffer
        ax.set_ylim(y_min - y_buffer, y_max + y_buffer)

        # One tick per observation; only the bottom subplot shows labels.
        # The bottom subplot is handled last, so its labels are the ones
        # that end up on the shared x-axis.
        ax.set_xticks(x_positions)
        if idx == last_idx:
            ax.set_xlabel('Time Period', fontsize=10)
            ax.set_xticklabels(time_series_df.index, rotation=45, ha='right')
        else:
            ax.set_xticklabels([])

    # Adjust spacing between subplots (on this figure explicitly, rather
    # than on whatever pyplot considers the current figure)
    fig.subplots_adjust(hspace=0.3)

    return fig


def plot_emotion_topic_grid(pair_series: Dict[Tuple[str, str], pd.Series],
                          top_n: int) -> plt.Figure:
    """
    Create a grid of time series plots for emotion-topic pairs.

    Rows correspond to the sorted unique emotions and columns to the sorted
    unique topics found in the keys of ``pair_series``. All subplots share
    the same y-axis limits (global min/max over every series, padded by 10%
    of the value range).

    Args:
        pair_series: Mapping from ``(emotion, topic)`` to the series plotted
            in that cell; the series index supplies the x-tick labels of the
            bottom populated row.
        top_n: Number of rows and columns of the subplot grid. The keys are
            expected to span at most ``top_n`` distinct emotions and topics
            (more than that does not fit in the grid).

    Returns:
        The matplotlib Figure containing the grid.

    Raises:
        ValueError: If ``pair_series`` is empty or holds no values at all.
        KeyError: If an ``(emotion, topic)`` combination of the unique
            emotions and topics is missing from ``pair_series``.
    """
    if not pair_series:
        raise ValueError("pair_series must contain at least one (emotion, topic) pair")

    # Calculate global y-axis limits for consistent scaling. The NaN-aware
    # reductions are used because the built-in min()/max() give
    # order-dependent results when NaN is present.
    all_values = np.concatenate(
        [np.asarray(series.values, dtype=float) for series in pair_series.values()]
    )
    if all_values.size == 0:
        raise ValueError("pair_series must contain at least one non-empty series")
    y_min = np.nanmin(all_values)
    y_max = np.nanmax(all_values)
    y_buffer = (y_max - y_min) * 0.1
    if y_buffer == 0:
        # Constant data: identical lower/upper limits would make the axis
        # transform singular, so fall back to a non-zero padding.
        y_buffer = abs(y_max) * 0.1 or 1.0

    # squeeze=False keeps `axes` two-dimensional even when top_n == 1
    # (otherwise plt.subplots returns a bare Axes and axes[i, j] fails).
    fig, axes = plt.subplots(top_n, top_n, figsize=(15, 15), squeeze=False)
    fig.suptitle('Emotion-Topic Pair Frequencies Over Time', y=1.02, fontsize=14)

    # Get unique emotions and topics
    emotions = sorted({e for e, _ in pair_series})
    topics = sorted({t for _, t in pair_series})

    # Last row that actually holds data; equals top_n - 1 for a full grid.
    last_row = len(emotions) - 1

    for i, emotion in enumerate(emotions):
        for j, topic in enumerate(topics):
            ax = axes[i, j]
            series = pair_series[(emotion, topic)]

            # Plot the line
            ax.plot(range(len(series)), series, marker='o', markersize=3)

            # Column titles on the top row, row labels on the first column
            if i == 0:
                ax.set_title(f'{topic}', fontsize=10)
            if j == 0:
                ax.set_ylabel(f'{emotion}', fontsize=10)

            # Set consistent y-axis limits
            ax.set_ylim(y_min - y_buffer, y_max + y_buffer)

            # Add grid
            ax.grid(True, linestyle='--', alpha=0.4)

            # Pin one tick per observation before labelling; without this
            # the labels would be attached to auto-generated tick positions
            # and no longer line up with the data points.
            ax.set_xticks(range(len(series)))
            if i == last_row:  # Only bottom row shows x-labels
                ax.set_xticklabels(series.index, rotation=45, fontsize=8)
            else:
                ax.set_xticklabels([])

            ax.tick_params(axis='both', which='major', labelsize=8)

    # Hide grid cells that received no data (fewer than top_n emotions or
    # topics) instead of leaving empty axes boxes in the figure.
    for ax in axes[len(emotions):, :].ravel():
        ax.set_visible(False)
    for ax in axes[:len(emotions), len(topics):].ravel():
        ax.set_visible(False)

    fig.tight_layout()
    return fig