Vera-ZWY's picture
Update linePlot.py
7709ffd verified
raw
history blame
3.72 kB
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Optional
# from datetime import datetime
def plot_stacked_time_series(time_series_df: pd.DataFrame, title: str) -> plt.Figure:
    """
    Create a stacked time-series plot: one subplot per column, shared x-axis.

    Each column of ``time_series_df`` is drawn as a line with small markers in its
    own row. Points are placed at integer positions ``0..len(time_series_df)-1`` and
    the bottom subplot labels those positions with the DataFrame index. Every row
    uses the same y-range (global min/max of the finite values plus a 10% buffer)
    so the rows are directly comparable.

    Args:
        time_series_df: Numeric data; one column per series, one row per time period.
        title: Figure-level title.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: If ``time_series_df`` has no columns.
    """
    n_series = len(time_series_df.columns)
    if n_series == 0:
        # plt.subplots(0, 1) would raise ValueError too, but with a less helpful message.
        raise ValueError("time_series_df must have at least one column")

    # squeeze=False always returns a 2-D array, so a single column needs no special case.
    fig, axes = plt.subplots(n_series, 1, figsize=(12, 2 * n_series),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]
    fig.suptitle(title, y=1.02, fontsize=14)

    # Shared y-limits computed from finite values only: NaN/inf limits make
    # set_ylim raise, and identical limits are singular.
    values = time_series_df.to_numpy(dtype=float)
    finite = values[np.isfinite(values)]
    y_limits = None
    if finite.size:
        y_min, y_max = finite.min(), finite.max()
        y_buffer = (y_max - y_min) * 0.1
        if y_buffer == 0:
            # Constant data: pad by 10% of its magnitude (1.0 around zero).
            y_buffer = abs(y_min) * 0.1 or 1.0
        y_limits = (y_min - y_buffer, y_max + y_buffer)

    positions = range(len(time_series_df))
    for ax, column in zip(axes, time_series_df.columns):
        ax.plot(positions, time_series_df[column], marker='o', markersize=3, linewidth=1.5)
        ax.set_title(column, pad=5, fontsize=10)
        ax.set_ylabel('Popularity', fontsize=9)
        if y_limits is not None:
            ax.set_ylim(*y_limits)

    # With sharex=True the x tick locator/formatter is shared by every row, and
    # plt.subplots already hides the tick labels of all rows except the bottom
    # one, so ticks and labels are configured once, on the bottom axes.
    bottom = axes[-1]
    bottom.set_xlabel('Time Period', fontsize=10)
    bottom.set_xticks(positions)
    bottom.set_xticklabels(time_series_df.index, rotation=45, ha='right')

    # Extra vertical room between rows so each row's title does not collide
    # with the row above.
    fig.subplots_adjust(hspace=0.3)
    return fig
def plot_emotion_topic_grid(pair_series: Dict[Tuple[str, str], pd.Series],
                            top_n: int) -> plt.Figure:
    """
    Create a ``top_n`` x ``top_n`` grid of time-series plots for emotion-topic pairs.

    Rows are emotions and columns are topics, each taken from the keys of
    ``pair_series``, sorted, and capped at ``top_n``. Cell (i, j) plots the series
    for ``(emotions[i], topics[j])`` against integer positions. The top row carries
    topic titles, the left column carries emotion y-labels, and the last populated
    row labels its x ticks with the series index. All cells share one y-range
    (global min/max of the finite values plus a 10% buffer).

    Args:
        pair_series: Mapping from ``(emotion, topic)`` to a series of values over
            time. Combinations missing from the mapping leave their cell empty.
        top_n: Number of grid rows and columns.

    Returns:
        The created matplotlib Figure.
    """
    # squeeze=False keeps axes 2-D so axes[i, j] also works when top_n == 1.
    fig, axes = plt.subplots(top_n, top_n, figsize=(15, 15), squeeze=False)
    fig.suptitle('Emotion-Topic Pair Frequencies Over Time', y=1.02, fontsize=14)

    # Cap at top_n so more distinct labels than grid cells cannot index past the grid.
    emotions = sorted({e for e, _ in pair_series})[:top_n]
    topics = sorted({t for _, t in pair_series})[:top_n]

    # Global y-limits from finite values only; empty series, NaN and inf are ignored
    # (Python's min()/max() would raise on empty input and misbehave with NaN).
    arrays = [np.asarray(s, dtype=float).ravel() for s in pair_series.values()]
    values = np.concatenate(arrays) if arrays else np.empty(0)
    finite = values[np.isfinite(values)]
    y_limits = None
    if finite.size:
        y_min, y_max = finite.min(), finite.max()
        y_buffer = (y_max - y_min) * 0.1
        if y_buffer == 0:
            # Constant data: pad by 10% of its magnitude (1.0 around zero) so the
            # limits are not singular.
            y_buffer = abs(y_min) * 0.1 or 1.0
        y_limits = (y_min - y_buffer, y_max + y_buffer)

    # x tick labels go on the last row that actually holds data.
    last_row = len(emotions) - 1
    for i, emotion in enumerate(emotions):
        for j, topic in enumerate(topics):
            ax = axes[i, j]
            series = pair_series.get((emotion, topic))
            if series is not None:
                positions = range(len(series))
                ax.plot(positions, series, marker='o', markersize=3)
                # Fix tick positions before labelling: set_xticklabels on the default
                # AutoLocator would attach the labels to the wrong ticks.
                ax.set_xticks(positions)
                if i == last_row:
                    ax.set_xticklabels(series.index, rotation=45, fontsize=8)
                else:
                    ax.set_xticklabels([])
            else:
                # Not every emotion co-occurs with every topic; leave the cell empty.
                ax.set_xticks([])

            if i == 0:
                ax.set_title(f'{topic}', fontsize=10)
            if j == 0:
                ax.set_ylabel(f'{emotion}', fontsize=10)
            if y_limits is not None:
                ax.set_ylim(*y_limits)
            ax.grid(True, linestyle='--', alpha=0.4)
            ax.tick_params(axis='both', which='major', labelsize=8)

    # Hide grid cells beyond the available emotions/topics.
    for ax in axes[len(emotions):, :].flat:
        ax.set_axis_off()
    for ax in axes[:len(emotions), len(topics):].flat:
        ax.set_axis_off()

    fig.tight_layout()
    return fig