Vera-ZWY commited on
Commit
d74eb02
1 Parent(s): a708fda

Create linePlot.py

Browse files
Files changed (1) hide show
  1. linePlot.py +103 -0
linePlot.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from typing import Dict, Tuple, Optional
5
+ # from datetime import datetime
6
+
7
+ def plot_stacked_time_series(time_series_df: pd.DataFrame, title: str) -> plt.Figure:
8
+ """
9
+ Create stacked time series plot with shared x-axis (cleaned up version)
10
+ """
11
+ n_series = len(time_series_df.columns)
12
+
13
+ # Create figure with shared x-axis
14
+ fig, axes = plt.subplots(n_series, 1, figsize=(12, 2*n_series), sharex=True)
15
+ fig.suptitle(title, y=1.02, fontsize=14)
16
+
17
+ # Calculate y-axis limits for consistent scaling
18
+ y_min = time_series_df.min().min()
19
+ y_max = time_series_df.max().max()
20
+ y_range = y_max - y_min
21
+ y_buffer = y_range * 0.1
22
+
23
+ for idx, column in enumerate(time_series_df.columns):
24
+ ax = axes[idx] if n_series > 1 else axes
25
+ series = time_series_df[column]
26
+
27
+ # Plot the line with smaller markers
28
+ ax.plot(range(len(series)), series, marker='o', markersize=3, linewidth=1.5)
29
+
30
+ # Set title and labels
31
+ ax.set_title(column, pad=5, fontsize=10)
32
+ ax.set_ylabel('Popularity', fontsize=9)
33
+
34
+ # Set consistent y-axis limits with buffer
35
+ ax.set_ylim(y_min - y_buffer, y_max + y_buffer)
36
+
37
+ # Add grid with lower alpha
38
+ # ax.grid(True, linestyle='--', alpha=0.4)
39
+
40
+ # Format x-axis (only for bottom subplot)
41
+ if idx == n_series-1:
42
+ ax.set_xlabel('Time Period', fontsize=10)
43
+ ax.set_xticks(range(len(series)))
44
+ ax.set_xticklabels(time_series_df.index, rotation=45, ha='right')
45
+ else:
46
+ ax.set_xticks(range(len(series)))
47
+ ax.set_xticklabels([])
48
+
49
+ # Adjust spacing between subplots
50
+ plt.subplots_adjust(hspace=0.3)
51
+
52
+ return fig
53
+
54
+
55
+ def plot_emotion_topic_grid(pair_series: Dict[Tuple[str, str], pd.Series],
56
+ top_n: int) -> plt.Figure:
57
+ """
58
+ Create grid of time series plots for emotion-topic pairs with improved formatting
59
+ """
60
+ fig, axes = plt.subplots(top_n, top_n, figsize=(15, 15))
61
+ fig.suptitle('Emotion-Topic Pair Frequencies Over Time', y=1.02, fontsize=14)
62
+
63
+ # Get unique emotions and topics
64
+ emotions = sorted(set(e for e, _ in pair_series.keys()))
65
+ topics = sorted(set(t for _, t in pair_series.keys()))
66
+
67
+ # Calculate global y-axis limits for consistent scaling
68
+ all_values = [series.values for series in pair_series.values()]
69
+ y_min = min(min(values) for values in all_values)
70
+ y_max = max(max(values) for values in all_values)
71
+ y_range = y_max - y_min
72
+ y_buffer = y_range * 0.1
73
+
74
+ for i, emotion in enumerate(emotions):
75
+ for j, topic in enumerate(topics):
76
+ ax = axes[i, j]
77
+ series = pair_series[(emotion, topic)]
78
+
79
+ # Plot the line
80
+ ax.plot(range(len(series)), series, marker='o', markersize=3)
81
+
82
+ # Set titles and labels
83
+ if i == 0:
84
+ ax.set_title(f'{topic}', fontsize=10)
85
+ if j == 0:
86
+ ax.set_ylabel(f'{emotion}', fontsize=10)
87
+
88
+ # Set consistent y-axis limits
89
+ ax.set_ylim(y_min - y_buffer, y_max + y_buffer)
90
+
91
+ # Add grid
92
+ ax.grid(True, linestyle='--', alpha=0.4)
93
+
94
+ # Format ticks
95
+ if i == top_n-1: # Only bottom row shows x-labels
96
+ ax.set_xticklabels(series.index, rotation=45, fontsize=8)
97
+ else:
98
+ ax.set_xticklabels([])
99
+
100
+ ax.tick_params(axis='both', which='major', labelsize=8)
101
+
102
+ plt.tight_layout()
103
+ return fig