# Trading_App / app.py
# Author: RizwanSajad — last commit 96b2c8f ("Update app.py", verified)
import gradio as gr
import torch
import pandas as pd
import yfinance as yf
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Select the compute device: default to the CPU and switch to CUDA only when
# PyTorch reports an available GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
# Function to fetch and preprocess ICICI Bank data
def fetch_and_preprocess_data(ticker="ICICIBANK.BO", start="2020-01-01", end="2023-01-01"):
    """Download daily price history and add 50- and 200-day moving averages.

    Args:
        ticker: Yahoo Finance symbol to download. Defaults to ICICI Bank on
            the BSE ("ICICIBANK.BO").
        start: Start date of the download window, "YYYY-MM-DD".
        end: End date of the download window passed to ``yf.download``,
            "YYYY-MM-DD".

    Returns:
        The downloaded price DataFrame with two extra columns: ``MA_50`` and
        ``MA_200``. These are simple rolling means of ``Close`` over 50 and
        200 rows. On any failure, including an empty download, the error is
        printed and an empty DataFrame is returned, so callers must check
        ``.empty``.
    """
    try:
        data = yf.download(ticker, start=start, end=end)
        if data.empty:
            raise ValueError("No data found for the given symbol.")
        # Recent yfinance releases can return (field, ticker) MultiIndex
        # columns even for a single symbol. data['Close'] would then be a
        # DataFrame instead of a Series. Flatten to the field level so the
        # column math below and callers' scalar lookups work with either layout.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)
        # Calculate moving averages; with the default min_periods the first
        # window-1 rows of each average are NaN.
        data["MA_50"] = data["Close"].rolling(window=50).mean()
        data["MA_200"] = data["Close"].rolling(window=200).mean()
        return data
    except Exception as e:
        # Best-effort fetch: report the problem and hand back an empty frame
        # rather than raising into the UI callback.
        print(f"Error fetching data: {e}")
        return pd.DataFrame()  # Return empty DataFrame if fetching fails
# Load the RAG model and tokenizer with error handling.
# All three handles are bound to None up front, so every name exists after this
# block even when loading fails early (e.g. in the tokenizer call). Code that
# checks for None then sees a consistent state.
# NOTE(review): index_name="wiki_dpr" without use_dummy_dataset=True presumably
# makes RagRetriever fetch the full Wikipedia DPR index, which is very large —
# confirm this is intended for the deployment target.
tokenizer = None
retriever = None
model = None
try:
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
    print("Tokenizer loaded successfully.")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-base",
        index_name="wiki_dpr",
        passages_path=None,
        index_path=None,
    )
    print("Retriever loaded successfully.")
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-sequence-base", retriever=retriever
    ).to(device)
    print("Model loaded successfully.")
except Exception as e:
    # Top-level boundary: log and continue so the UI can still start and report
    # the failure per request. Any partially loaded retriever/model is
    # discarded.
    print(f"Error loading model or retriever: {e}")
    retriever = None
    model = None
# Function to analyze trading data
def _latest_value(data, column):
    """Return the last value of ``data[column]`` as a Python float.

    ``data[column]`` may be a one-column DataFrame rather than a Series when
    the frame has MultiIndex columns. In that case ``.iloc[-1]`` would return
    a Series that the ``:.2f`` format spec cannot handle.
    """
    values = data[column]
    if isinstance(values, pd.DataFrame):
        values = values.iloc[:, 0]
    return float(values.iloc[-1])


def analyze_trading_data(question):
    """Answer a free-text question about ICICI Bank using the RAG model.

    Builds a short text context from the latest close price and the 50/200-day
    moving averages returned by ``fetch_and_preprocess_data``. The question is
    combined with that context, tokenized, and passed to ``model.generate``.

    Args:
        question: The user's question text.

    Returns:
        The decoded model answer. On a blank question, an uninitialized model,
        missing data, or a generation failure, returns a message starting with
        "Error:" instead of raising.
    """
    # Reject empty input before doing any network or model work.
    if question is None or not str(question).strip():
        return "Error: Please enter a question."
    if model is None or retriever is None:
        return "Error: Model or retriever is not initialized. Please check the logs."
    # Fetch and preprocess data
    data = fetch_and_preprocess_data()
    if data.empty:
        return "Error: No data available for analysis."
    # The stock figures reach the model only as part of this input text.
    context = (
        "ICICI Bank stock data:\n"
        f"Latest Close Price: {_latest_value(data, 'Close'):.2f}\n"
        f"50-Day Moving Average: {_latest_value(data, 'MA_50'):.2f}\n"
        f"200-Day Moving Average: {_latest_value(data, 'MA_200'):.2f}\n"
    )
    input_text = f"Question: {question}\nContext: {context}"
    try:
        inputs = tokenizer(
            input_text, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        outputs = model.generate(inputs["input_ids"])
    except Exception as e:
        # Keep the function's error-string contract instead of surfacing a
        # traceback to the UI.
        return f"Error: Failed to generate an answer: {e}"
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio interface: one free-text question box wired to analyze_trading_data,
# plain-text output, and a few clickable sample questions.
_EXAMPLE_QUESTIONS = [
    "What is the current trend of ICICI Bank stock?",
    "Is the 50-day moving average above the 200-day moving average?",
    "What is the latest closing price of ICICI Bank?",
]
iface = gr.Interface(
    title="ICICI Bank Trading Analysis",
    description="Ask any question about ICICI Bank's trading data and get a detailed analysis.",
    fn=analyze_trading_data,
    inputs="text",
    outputs="text",
    examples=_EXAMPLE_QUESTIONS,
)
def main():
    """Launch the Gradio app defined by the module-level ``iface``."""
    iface.launch()


# Launch the app only when run as a script, not on import.
if __name__ == "__main__":
    main()