Create linePlot.py
Browse files- linePlot.py +103 -0
linePlot.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from typing import Dict, Tuple, Optional
|
5 |
+
# from datetime import datetime
|
6 |
+
|
7 |
+
def plot_stacked_time_series(time_series_df: pd.DataFrame, title: str) -> plt.Figure:
    """
    Create a stacked time series plot: one subplot per column, shared x-axis.

    Every subplot uses the same y-axis limits (the global min/max of the
    frame padded by 10% of the range) so the series can be compared by eye.
    Points are plotted against their row position; the frame's index is used
    only as tick labels on the bottom subplot.

    Args:
        time_series_df: One column per series. Values are assumed to be
            numeric; the index supplies the x tick labels.
        title: Figure-level title.

    Returns:
        The matplotlib Figure holding the stacked subplots.

    Raises:
        ValueError: If the frame has no columns.
    """
    n_series = len(time_series_df.columns)
    if n_series == 0:
        raise ValueError("time_series_df must contain at least one column")

    # squeeze=False always yields a 2-D array of Axes, so the single-series
    # case needs no special handling.
    fig, axes = plt.subplots(
        n_series, 1, figsize=(12, 2 * n_series), sharex=True, squeeze=False
    )
    fig.suptitle(title, y=1.02, fontsize=14)

    # Global limits so every subplot shares the same vertical scale.
    y_min = time_series_df.min().min()
    y_max = time_series_df.max().max()
    y_range = y_max - y_min
    # Flat data has zero range; fall back to a fixed pad so the lower and
    # upper limits are never identical.
    y_buffer = y_range * 0.1 if y_range > 0 else 1.0

    positions = range(len(time_series_df))
    last_idx = n_series - 1

    # .items() walks columns by position, so duplicated column labels still
    # yield one Series per subplot.
    for idx, (column, series) in enumerate(time_series_df.items()):
        ax = axes[idx, 0]

        ax.plot(positions, series, marker='o', markersize=3, linewidth=1.5)

        ax.set_title(column, pad=5, fontsize=10)
        ax.set_ylabel('Popularity', fontsize=9)
        ax.set_ylim(y_min - y_buffer, y_max + y_buffer)

        ax.set_xticks(positions)
        if idx == last_idx:
            # Only the bottom subplot carries the x-axis label and tick text.
            # It is handled last so its labels are the final ones applied to
            # the shared x-axis.
            ax.set_xlabel('Time Period', fontsize=10)
            ax.set_xticklabels(time_series_df.index, rotation=45, ha='right')
        else:
            ax.set_xticklabels([])

    # Adjust this figure explicitly rather than whatever pyplot considers
    # the current figure.
    fig.subplots_adjust(hspace=0.3)

    return fig
|
53 |
+
|
54 |
+
|
55 |
+
def plot_emotion_topic_grid(pair_series: Dict[Tuple[str, str], pd.Series],
                            top_n: int) -> plt.Figure:
    """
    Create a top_n x top_n grid of time series plots for emotion-topic pairs.

    Rows are the sorted emotions and columns the sorted topics found in the
    keys of ``pair_series``. All panels share the same y-axis limits (global
    min/max padded by 10% of the range). A panel whose (emotion, topic) pair
    is absent from ``pair_series`` is left empty.

    Args:
        pair_series: Maps (emotion, topic) to the series plotted in that
            panel. Values are assumed to be numeric; each series' index
            supplies the x tick labels on the bottom row.
        top_n: Number of rows and columns in the grid.

    Returns:
        The matplotlib Figure holding the grid.

    Raises:
        ValueError: If ``pair_series`` is empty, holds no numeric data, or
            contains more distinct emotions or topics than ``top_n``.
    """
    if not pair_series:
        raise ValueError("pair_series must contain at least one series")

    emotions = sorted({emotion for emotion, _ in pair_series})
    topics = sorted({topic for _, topic in pair_series})
    if len(emotions) > top_n or len(topics) > top_n:
        raise ValueError(
            f"pair_series has {len(emotions)} emotions and {len(topics)} topics, "
            f"which does not fit in a {top_n}x{top_n} grid"
        )

    # Global limits so every panel shares the same vertical scale. The
    # nan-aware reductions keep missing observations from poisoning them.
    all_values = np.concatenate(
        [np.asarray(series, dtype=float).ravel() for series in pair_series.values()]
    )
    if all_values.size == 0 or np.isnan(all_values).all():
        raise ValueError("pair_series contains no numeric data to plot")
    y_min = np.nanmin(all_values)
    y_max = np.nanmax(all_values)
    y_range = y_max - y_min
    # Flat data has zero range; fall back to a fixed pad so the lower and
    # upper limits are never identical.
    y_buffer = y_range * 0.1 if y_range > 0 else 1.0

    # Inputs are validated above, so no figure is created on bad input.
    # squeeze=False keeps `axes` 2-D even when top_n == 1.
    fig, axes = plt.subplots(top_n, top_n, figsize=(15, 15), squeeze=False)
    fig.suptitle('Emotion-Topic Pair Frequencies Over Time', y=1.02, fontsize=14)

    # Last populated row; equals top_n - 1 when the grid is full.
    bottom_row = len(emotions) - 1

    for i, emotion in enumerate(emotions):
        for j, topic in enumerate(topics):
            ax = axes[i, j]
            series = pair_series.get((emotion, topic))

            # Column headers on the first row, row labels on the first column.
            if i == 0:
                ax.set_title(f'{topic}', fontsize=10)
            if j == 0:
                ax.set_ylabel(f'{emotion}', fontsize=10)

            ax.set_ylim(y_min - y_buffer, y_max + y_buffer)
            ax.grid(True, linestyle='--', alpha=0.4)

            if series is None:
                # No data for this combination: keep the framed, empty panel.
                ax.set_xticklabels([])
            else:
                positions = range(len(series))
                ax.plot(positions, series, marker='o', markersize=3)

                if i == bottom_row:
                    # Pin the tick positions first: set_xticklabels alone
                    # attaches the labels to auto-chosen tick locations.
                    ax.set_xticks(positions)
                    ax.set_xticklabels(series.index, rotation=45, ha='right',
                                       fontsize=8)
                else:
                    ax.set_xticklabels([])

            ax.tick_params(axis='both', which='major', labelsize=8)

    fig.tight_layout()
    return fig
|