File size: 2,267 Bytes
cd0e571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# utils/plotting.py

import matplotlib.pyplot as plt
import pandas as pd

def plot_stock_data(data, buy_signals=None, sell_signals=None, title="Stock Data with Indicators"):
    """
    Plots stock data with SMAs, Bollinger Bands, and buy/sell signals.

    Parameters:
    - data: DataFrame containing the stock data with 'Close', 'SMA_21', 'SMA_50', 'BB_Upper', and 'BB_Lower' columns.
      Its index is used as the x-axis.
    - buy_signals: Optional DataFrame or Series marking buy points. Only its index is used; every label
      in it must also be present in data.index. None or an empty object draws nothing.
    - sell_signals: Optional DataFrame or Series marking sell points. Same requirements as buy_signals.
    - title: Title of the plot.

    Returns:
    - None. The figure is displayed with plt.show().

    Raises:
    - KeyError: if a required column is missing from data, or a signal label is not found in data.index.
    """
    # Create a new figure and set the size
    plt.figure(figsize=(14, 7))

    # Plot the closing price
    plt.plot(data.index, data['Close'], label='Close Price', color='skyblue', linewidth=2)

    # Plot SMAs
    plt.plot(data.index, data['SMA_21'], label='21-period SMA', color='orange', linewidth=1.5)
    plt.plot(data.index, data['SMA_50'], label='50-period SMA', color='green', linewidth=1.5)

    # Plot Bollinger Bands
    plt.plot(data.index, data['BB_Upper'], label='Upper Bollinger Band', color='grey', linestyle='--', linewidth=1)
    plt.plot(data.index, data['BB_Lower'], label='Lower Bollinger Band', color='grey', linestyle='--', linewidth=1)

    # Highlight buy and sell signals at the closing price of each signal date
    signal_styles = (
        (buy_signals, '^', 'green', 'Buy Signal'),
        (sell_signals, 'v', 'red', 'Sell Signal'),
    )
    for signals, marker, color, label in signal_styles:
        # len() rather than truthiness: bool() of a DataFrame/Series is ambiguous and raises.
        # Skipping empty inputs also keeps a pointless entry out of the legend.
        if signals is None or len(signals) == 0:
            continue
        signal_dates = signals.index
        plt.scatter(
            signal_dates,
            # Single .loc call selects rows and column together (no chained indexing).
            data.loc[signal_dates, 'Close'],
            marker=marker,
            color=color,
            label=label,
            alpha=1,
            # Scatter collections default to a lower zorder than lines, which would
            # hide the markers behind the price line; draw them on top instead.
            zorder=3,
        )

    # Customize the plot
    plt.title(title)
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.legend()
    plt.grid(True)

    # Show the plot
    plt.show()

# Example usage:
# This assumes `data` DataFrame is already loaded with the required columns including 'Close', 'SMA_21', 'SMA_50', 'BB_Upper', 'BB_Lower'.
# `buy_signals` and `sell_signals` DataFrames/Series should be indexed by the signal dates, and those dates must also appear in the index of `data`.
# Due to the nature of this example, actual data is not provided here. To test this function, ensure you have a DataFrame with the appropriate structure.