File size: 3,191 Bytes
46bc116
663d47c
46bc116
 
96b2c8f
 
 
 
 
46bc116
663d47c
 
c221934
96b2c8f
c221934
 
 
 
 
96b2c8f
c221934
 
 
 
 
 
96b2c8f
46bc116
96b2c8f
c221934
8e804f7
 
 
c221934
 
8e804f7
 
 
c221934
4a05d26
8e804f7
96b2c8f
4a05d26
c221934
96b2c8f
c221934
 
a2bec5b
96b2c8f
663d47c
c221934
 
96b2c8f
663d47c
 
96b2c8f
c221934
 
96b2c8f
 
663d47c
 
 
 
 
 
96b2c8f
663d47c
 
96b2c8f
 
 
 
 
663d47c
46bc116
96b2c8f
663d47c
96b2c8f
663d47c
46bc116
663d47c
46bc116
663d47c
46bc116
 
663d47c
 
 
 
 
 
 
46bc116
 
663d47c
96b2c8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import torch
import pandas as pd
import yfinance as yf
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Function to fetch and preprocess ICICI Bank data
def fetch_and_preprocess_data(ticker="ICICIBANK.BO", start="2020-01-01", end="2023-01-01"):
    """Download price history and add 50- and 200-day moving averages.

    Args:
        ticker: Yahoo Finance symbol to download. Defaults to the ICICI Bank
            BSE symbol.
        start: Start date passed to ``yf.download`` ("YYYY-MM-DD").
        end: End date passed to ``yf.download`` ("YYYY-MM-DD").

    Returns:
        The downloaded DataFrame with two extra columns, ``MA_50`` and
        ``MA_200`` (rolling means of ``Close``). An empty DataFrame is
        returned when the download fails or yields no rows; the error is
        printed rather than raised.
    """
    try:
        data = yf.download(ticker, start=start, end=end)

        if data.empty:
            raise ValueError("No data found for the given symbol.")

        # Newer yfinance releases can return (field, ticker) MultiIndex
        # columns even for a single symbol. Keep only the field level so that
        # data['Close'] is a Series and the values derived from it are scalars.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        # Calculate Moving Averages (NaN until the window is full)
        data['MA_50'] = data['Close'].rolling(window=50).mean()
        data['MA_200'] = data['Close'].rolling(window=200).mean()

        return data
    except Exception as e:
        # Deliberate best-effort boundary: callers test `.empty` instead of
        # handling exceptions.
        print(f"Error fetching data: {e}")
        return pd.DataFrame()  # Return empty DataFrame if fetching fails

# Load the RAG model and tokenizer with error handling
# Start from a known "not loaded" state so all three names are always bound,
# even when the very first load step fails.
tokenizer = None
retriever = None
model = None
try:
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
    print("Tokenizer loaded successfully.")

    # NOTE(review): "wiki_dpr" looks like a dataset name rather than an index
    # name; confirm this value against the RagConfig/RagRetriever docs.
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-base",
        index_name="wiki_dpr",
        passages_path=None,
        index_path=None
    )
    print("Retriever loaded successfully.")

    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever).to(device)
    print("Model loaded successfully.")
except Exception as e:
    # Top-level boundary: report the problem and leave the app importable.
    print(f"Error loading model or retriever: {e}")
    # Drop anything loaded before the failure so the three names stay
    # consistent (all usable, or all None).
    tokenizer = None
    retriever = None
    model = None

def _latest_value(data, column):
    """Return the last row's value of *column* in *data* as a float."""
    values = data[column]
    # With MultiIndex columns the selection is a DataFrame (one column per
    # ticker); take its first column so the result is a scalar that the
    # ':.2f' format spec accepts.
    if isinstance(values, pd.DataFrame):
        values = values.iloc[:, 0]
    return float(values.iloc[-1])


# Function to analyze trading data
def analyze_trading_data(question):
    """Answer a free-text question about ICICI Bank stock with the RAG model.

    Args:
        question: The user's question as text.

    Returns:
        The decoded model output, or a string starting with "Error:" when
        the question is blank, the model/retriever is not loaded, or no
        price data could be fetched.
    """
    # Reject blank input before touching the model.
    if not question or not question.strip():
        return "Error: Please enter a question."

    if model is None or retriever is None:
        return "Error: Model or retriever is not initialized. Please check the logs."

    # Fetch and preprocess data
    data = fetch_and_preprocess_data()

    if data.empty:
        return "Error: No data available for analysis."

    # Prepare context for RAG model
    context = (
        f"ICICI Bank stock data:\n"
        f"Latest Close Price: {_latest_value(data, 'Close'):.2f}\n"
        f"50-Day Moving Average: {_latest_value(data, 'MA_50'):.2f}\n"
        f"200-Day Moving Average: {_latest_value(data, 'MA_200'):.2f}\n"
    )

    # Combine question and context
    input_text = f"Question: {question}\nContext: {context}"

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)

    # Generate answer using the model
    outputs = model.generate(inputs['input_ids'])

    # Decode output
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return answer

# Gradio UI: one text input and one text output, served by
# analyze_trading_data, with a few example questions to click on.
iface = gr.Interface(
    title="ICICI Bank Trading Analysis",
    description="Ask any question about ICICI Bank's trading data and get a detailed analysis.",
    fn=analyze_trading_data,
    inputs="text",
    outputs="text",
    examples=[
        "What is the current trend of ICICI Bank stock?",
        "Is the 50-day moving average above the 200-day moving average?",
        "What is the latest closing price of ICICI Bank?",
    ],
)

# Start the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()