# Hugging Face Space status shown when this file was captured: Runtime error
import torch | |
import yfinance as yf | |
import matplotlib.pyplot as plt | |
import mplfinance as mpf | |
from PIL import Image, ImageDraw, ImageFont | |
import gradio as gr | |
import datetime | |
import logging | |
from transformers import AutoProcessor, AutoModelForPreTraining | |
import tempfile | |
import os | |
import spaces | |
import pandas as pd | |
# Logging: DEBUG and above are written to debug.log with a timestamped format
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    filename='debug.log',
    level=logging.DEBUG,
    format=_LOG_FORMAT,
)

# The processor and the model are loaded once, at import time, from the same
# checkpoint identifier
_CHECKPOINT = "mobenta/chart_analysis"
processor = AutoProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForPreTraining.from_pretrained(_CHECKPOINT)
def predict(image, input_text):
    """Return the model's generated text for a chart image and a text prompt.

    The image is converted to RGB, encoded together with ``input_text`` by the
    module-level ``processor``, and handed to the module-level ``model``'s
    ``generate`` with a budget of 512 new tokens. The first ``prompt_length``
    positions of the generated sequence are dropped before decoding, and the
    decoded string for the first (only) batch item is returned.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model.to(device)

    rgb_image = image.convert("RGB")
    encoded = processor(text=input_text, images=rgb_image, return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Number of prompt positions to skip when decoding the generated ids
    prompt_length = model_inputs['input_ids'].shape[1]
    generated = model.generate(**model_inputs, max_new_tokens=512)
    completion_ids = generated[:, prompt_length:]

    decoded = processor.batch_decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
def fetch_stock_data(ticker='TSLA', start='2010-01-01', end=None, interval='1d'):
    """Download price history for ``ticker`` through ``yfinance``.

    Args:
        ticker: Symbol passed to ``yf.Ticker``.
        start: Start date string passed to ``Ticker.history``.
        end: End date string. ``None`` or an empty string (what a cleared
            text box sends) means "today", formatted as ``YYYY-MM-DD``.
        interval: Bar interval passed to ``Ticker.history``.

    Returns:
        The non-empty DataFrame returned by ``Ticker.history``.

    Raises:
        ValueError: If the download returns no rows.
        Exception: Any error raised by ``yfinance`` is logged and re-raised.
    """
    # `not end` (rather than `end is None`) so a blank string from the UI also
    # falls back to today's date instead of reaching yfinance as "".
    if not end:
        end = datetime.datetime.now().strftime('%Y-%m-%d')
    try:
        logging.debug(f"Fetching data for {ticker} from {start} to {end} with interval {interval}")
        stock = yf.Ticker(ticker)
        data = stock.history(start=start, end=end, interval=interval)
        if data.empty:
            logging.warning(f"No data fetched for {ticker} in the range {start} to {end}")
            raise ValueError(f"No data available for {ticker} in the range {start} to {end}")
        logging.debug(f"Fetched data with {len(data)} rows")
        return data
    except Exception as e:
        # Log here for the debug file, then let the caller decide what to show
        logging.error(f"Error fetching data: {e}")
        raise
def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indicators=None):
    """Render a candlestick chart with optional indicators and save it as an image.

    The chart is drawn with ``mplfinance`` (candles plus volume), saved to
    ``filename`` at 300 dpi, annotated with a few text metrics in the lower
    left corner, enlarged to three times its size, and written back to the
    same path.

    Args:
        data: Price DataFrame; the code reads its ``Close``, ``High``, ``Low``
            and ``Volume`` columns.
        ticker: Symbol used in the title and the metrics text.
        filename: Path the image is written to (overwritten in place).
        timeframe: Label shown in the chart title.
        indicators: Iterable of indicator names to draw. Recognised names are
            ``'RSI'``, ``'SMA21'``, ``'SMA50'``, ``'SMA200'``, ``'VWAP'`` and
            ``'Bollinger Bands'``. ``None`` means no indicators.

    Raises:
        Exception: Any error while plotting or post-processing is logged and
            re-raised.
    """
    # A missing selection means "no indicators"; normalising here keeps every
    # membership test below safe (the old code crashed on `x in None`).
    indicators = indicators or []
    try:
        logging.debug(f"Creating chart for {ticker} with timeframe {timeframe} and saving to {filename}")
        title = f"{ticker.upper()} Price Data (Timeframe: {timeframe})"
        plt.rcParams["axes.titlesize"] = 10
        my_style = mpf.make_mpf_style(base_mpf_style='charles')

        close = data['Close']
        addplot = []
        # Latest value of each selected SMA, kept so the metrics text can
        # reuse it instead of recomputing the rolling mean.
        sma_latest = {}

        if 'RSI' in indicators:
            # 14-period RSI from simple rolling means of gains and losses
            delta = close.diff(1)
            gain = delta.where(delta > 0, 0)
            loss = -delta.where(delta < 0, 0)
            avg_gain = gain.rolling(window=14).mean()
            avg_loss = loss.rolling(window=14).mean()
            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))
            # Panel 0 is price and panel 1 is volume, so RSI gets panel 2
            addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))

        for name, window, color in (
            ('SMA21', 21, 'purple'),
            ('SMA50', 50, 'orange'),
            ('SMA200', 200, 'brown'),
        ):
            if name not in indicators:
                continue
            logging.debug(f"Calculating SMA {window}")
            sma = close.rolling(window=window).mean()
            addplot.append(mpf.make_addplot(sma, color=color, linestyle='dashed'))
            sma_latest[f"SMA {window}"] = sma.iloc[-1]

        if 'VWAP' in indicators:
            logging.debug("Calculating VWAP")
            # Cumulative volume-weighted typical price over the whole range
            typical_price = (data['High'] + data['Low'] + close) / 3
            vwap = (data['Volume'] * typical_price).cumsum() / data['Volume'].cumsum()
            addplot.append(mpf.make_addplot(vwap, color='blue', linestyle='dashed'))

        if 'Bollinger Bands' in indicators:
            logging.debug("Calculating Bollinger Bands")
            mid_band = close.rolling(window=20).mean()
            std = close.rolling(window=20).std()
            upper_band = mid_band + (std * 2)
            lower_band = mid_band - (std * 2)
            addplot.append(mpf.make_addplot(upper_band, color='green', linestyle='dashed'))
            addplot.append(mpf.make_addplot(lower_band, color='green', linestyle='dashed'))

        fig, _axes = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
        try:
            fig.suptitle(title, y=0.98)
            fig.savefig(filename, dpi=300)
        finally:
            # Always release the figure, even if saving fails
            plt.close(fig)

        # Text metrics drawn onto the saved image
        metrics = {
            "Ticker": ticker,
            "Latest Close": f"${close.iloc[-1]:,.2f}",
            "Volume": f"{data['Volume'].iloc[-1]:,.0f}"
        }
        for label, value in sma_latest.items():
            metrics[label] = f"${value:,.2f}"

        line_height = 20
        bottom_margin = 30
        with Image.open(filename) as image:
            draw = ImageDraw.Draw(image)
            font = ImageFont.load_default()
            # Anchor the block from the bottom edge according to how many lines
            # there are; a fixed start offset pushed lines past the image
            # bottom once more than three metrics were shown.
            y_text = max(0, image.height - bottom_margin - line_height * (len(metrics) - 1))
            for key, value in metrics.items():
                # NOTE(review): white text may be hard to see if the chart
                # style has a light background — confirm against the rendered output.
                draw.text((10, y_text), f"{key}: {value}", font=font, fill=(255, 255, 255))
                y_text += line_height

            # Enlarge 3x; the resized copy is independent of the opened file
            new_size = (image.width * 3, image.height * 3)
            resized_image = image.resize(new_size, Image.LANCZOS)

        resized_image.save(filename)
        logging.debug(f"Resized image with timeframe {timeframe} and ticker {ticker} saved to {filename}")
    except Exception as e:
        logging.error(f"Error creating or resizing chart: {e}")
        raise
def combine_images(image_paths, output_path='combined_chart.png'):
    """Paste several images side by side into one image and save it.

    Images are placed left to right in the given order, top-aligned, on an
    RGB canvas as wide as the sum of their widths and as tall as the tallest
    one.

    Args:
        image_paths: Iterable of paths of the images to combine.
        output_path: Path the combined image is written to.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``image_paths`` is empty.
        Exception: Any error while opening, pasting or saving is logged and
            re-raised.
    """
    paths = list(image_paths)
    images = []
    try:
        logging.debug(f"Combining images {paths} into {output_path}")
        if not paths:
            # Fail with a clear message instead of the opaque error max()
            # raises on an empty sequence.
            raise ValueError("No images to combine")

        # Open one at a time so that, if a later open fails, the ones already
        # opened are still closed by the finally clause.
        for path in paths:
            images.append(Image.open(path))

        total_width = sum(img.width for img in images)
        max_height = max(img.height for img in images)
        combined_image = Image.new('RGB', (total_width, max_height))

        x_offset = 0
        for img in images:
            combined_image.paste(img, (x_offset, 0))
            x_offset += img.width

        combined_image.save(output_path)
        logging.debug(f"Combined image saved to {output_path}")
        return output_path
    except Exception as e:
        logging.error(f"Error combining images: {e}")
        raise
    finally:
        for img in images:
            img.close()
def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
    """Build the chart(s) for the requested tickers and run the model on them.

    Args:
        ticker1, ticker2, ticker3, ticker4: Ticker symbols; blank or
            whitespace-only entries are ignored.
        start_date, end_date: Date strings forwarded to ``fetch_stock_data``.
        query: Text prompt forwarded to ``predict``.
        analysis_type: ``'Comparative Analysis'`` combines the charts of all
            given tickers side by side (when there is more than one); any
            other value analyses only the first given ticker.
        interval: Bar interval forwarded to ``fetch_stock_data``.
        indicators: Indicator names forwarded to ``create_stock_chart``.

    Returns:
        A ``(text, image_path)`` tuple. On failure the text is an error
        message and the path is ``None``; with no tickers the text is
        ``"No tickers provided."`` and the path is ``None``.
    """
    temp_paths = []      # every temp file created during this call
    result_path = None   # the one file handed back to the UI; must survive cleanup
    try:
        logging.debug(f"Starting gradio_interface with tickers: {ticker1}, {ticker2}, {ticker3}, {ticker4}, start_date: {start_date}, end_date: {end_date}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
        tickers = [
            ticker.strip()
            for ticker in (ticker1, ticker2, ticker3, ticker4)
            if ticker and ticker.strip()
        ]
        if not tickers:
            return "No tickers provided.", None

        comparative = analysis_type == 'Comparative Analysis'
        if not comparative:
            # Only the first chart is ever shown in single-ticker mode, so do
            # not download and render the others.
            tickers = tickers[:1]

        chart_paths = []
        for ticker in tickers:
            data = fetch_stock_data(ticker, start=start_date, end=end_date, interval=interval)
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_chart:
                chart_path = temp_chart.name
            # Register before drawing so a failed render is still cleaned up
            temp_paths.append(chart_path)
            create_stock_chart(data, ticker, chart_path, timeframe=interval, indicators=indicators)
            chart_paths.append(chart_path)

        if comparative and len(chart_paths) > 1:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_combined_chart:
                final_path = temp_combined_chart.name
            temp_paths.append(final_path)
            combine_images(chart_paths, final_path)
        else:
            final_path = chart_paths[0]

        with Image.open(final_path) as chart_image:
            insights = predict(chart_image, query)

        result_path = final_path
        return insights, result_path
    except Exception as e:
        logging.error(f"Error in Gradio interface: {e}")
        return f"Error processing image or query: {e}", None
    finally:
        # The temp files were created with delete=False, so nothing removes
        # them automatically; drop everything except the returned image.
        for path in temp_paths:
            if path == result_path:
                continue
            try:
                os.remove(path)
            except OSError:
                logging.warning(f"Could not remove temporary file {path}")
def gradio_app():
    """Build the Gradio Blocks UI, wire the Analyze button, and launch the app.

    The components are created in display order: an introductory Markdown
    block, a row of four ticker text boxes, a row with the date range and
    interval, a row with the indicator and analysis-type selectors, then the
    query box, the Analyze button, and the image and text outputs. Clicking
    the button calls ``gradio_interface`` with the ten inputs and routes its
    ``(text, image)`` result to the two outputs. ``demo.launch()`` is called
    before returning.
    """
    with gr.Blocks() as demo:
        gr.Markdown("""
        ## 📈Stock Analysis Dashboard
        This application provides a comprehensive stock analysis tool that allows users to input up to four stock tickers, specify date ranges, and select various financial indicators. The core functionalities include:
        1. **Data Fetching and Chart Creation**: Historical stock data is fetched from Yahoo Finance, and candlestick charts are generated with optional financial indicators like RSI, SMA, VWAP, and Bollinger Bands.
        2. **Text Analysis and Insights Generation**: The application uses a pre-trained model based on the **Paligema** architecture to analyze the input chart and text query, generating insightful analysis based on the provided financial data and context.
        3. **User Interface**: Users can interactively select stocks, date ranges, intervals, and indicators. The app also supports the analysis of single tickers or comparative analysis across multiple tickers.
        4. **Logging and Debugging**: Detailed logging helps in debugging and tracking the application's processes.
        5. **Enhanced Image Processing**: The app adds financial metrics and annotations to the generated charts, ensuring clear presentation of data.
        This tool leverages the Paligema model to provide detailed insights into stock market trends, offering an interactive and educational experience for users.
        """)
        # Row 1: the four ticker inputs (pre-filled with example symbols)
        with gr.Row():
            ticker1 = gr.Textbox(label="Primary Ticker", value="GC=F")
            ticker2 = gr.Textbox(label="Secondary Ticker", value="CL=F")
            ticker3 = gr.Textbox(label="Third Ticker", value="SPY")
            ticker4 = gr.Textbox(label="Fourth Ticker", value="EURUSD=X")
        # Row 2: date range and bar interval; the end date default is computed
        # once, when the UI is built
        with gr.Row():
            start_date = gr.Textbox(label="Start Date", value="2022-01-01")
            end_date = gr.Textbox(label="End Date", value=datetime.datetime.now().strftime('%Y-%m-%d'))
            interval = gr.Dropdown(label="Interval", choices=['1d', '1wk', '1mo'], value='1d')
        # Row 3: indicator selection and analysis mode
        with gr.Row():
            indicators = gr.CheckboxGroup(label="Indicators", choices=['RSI', 'SMA21', 'SMA50', 'SMA200', 'VWAP', 'Bollinger Bands'], value=['SMA21', 'SMA50'])
            analysis_type = gr.Radio(label="Analysis Type", choices=['Single Ticker', 'Comparative Analysis'], value='Single Ticker')
        query = gr.Textbox(label="Analysis Query", value="Analyze the price trends.")
        analyze_button = gr.Button("Analyze")
        output_image = gr.Image(label="Stock Chart")
        output_text = gr.Textbox(label="Generated Insights", lines=5)
        # Input order must match gradio_interface's parameter order; the output
        # order matches its (text, image path) return tuple
        analyze_button.click(
            fn=gradio_interface,
            inputs=[ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators],
            outputs=[output_text, output_image]
        )
    demo.launch()
# Build and launch the UI only when this file is run as a script
if __name__ == "__main__":
    gradio_app()