RizwanSajad commited on
Commit
c221934
·
verified ·
1 Parent(s): a2bec5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -23
app.py CHANGED
@@ -2,28 +2,38 @@
2
  import gradio as gr
3
  import torch
4
  import pandas as pd
5
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig
6
  from datasets import Dataset
7
  import yfinance as yf
8
  import numpy as np
9
 
10
  # Function to fetch and preprocess ICICI Bank data
11
  def fetch_and_preprocess_data():
12
- # Fetch ICICI Bank data using yfinance
13
- ticker = "ICICIBANK.NS"
14
- data = yf.download(ticker, start="2020-01-01", end="2023-01-01")
15
-
16
- # Calculate technical indicators
17
- data['MA_50'] = data['Close'].rolling(window=50).mean()
18
- data['MA_200'] = data['Close'].rolling(window=200).mean()
19
-
20
- return data
 
 
 
 
 
 
 
21
 
22
  # Function to create and save a custom index for the retriever
23
  def create_custom_index():
24
  # Fetch and preprocess data
25
  data = fetch_and_preprocess_data()
26
 
 
 
 
27
  # Create a dataset for the retriever
28
  dataset = Dataset.from_dict({
29
  "id": [str(i) for i in range(len(data))],
@@ -43,25 +53,36 @@ def create_custom_index():
43
  # Load the fine-tuned RAG model and tokenizer
44
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
45
 
46
- # Create and save the custom index
47
- dataset_path, index_path = create_custom_index()
48
-
49
- # Load the retriever with the custom index
50
- retriever = RagRetriever.from_pretrained(
51
- "facebook/rag-sequence-base",
52
- index_name="custom",
53
- passages_path=dataset_path,
54
- index_path=index_path
55
- )
56
-
57
- # Load the RAG model
58
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
 
 
 
 
 
59
 
60
  # Function to analyze trading data using the RAG model
61
  def analyze_trading_data(question):
 
 
 
62
  # Fetch and preprocess data
63
  data = fetch_and_preprocess_data()
64
 
 
 
 
65
  # Prepare context for the RAG model
66
  context = (
67
  f"ICICI Bank stock data:\n"
 
2
  import gradio as gr
3
  import torch
4
  import pandas as pd
5
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
  from datasets import Dataset
7
  import yfinance as yf
8
  import numpy as np
9
 
10
  # Function to fetch and preprocess ICICI Bank data
11
  def fetch_and_preprocess_data():
12
+ try:
13
+ # Fetch ICICI Bank data using yfinance
14
+ ticker = "ICICIBANK.BO" # Use BSE symbol
15
+ data = yf.download(ticker, start="2020-01-01", end="2023-01-01")
16
+
17
+ if data.empty:
18
+ raise ValueError("No data found for the given symbol.")
19
+
20
+ # Calculate technical indicators
21
+ data['MA_50'] = data['Close'].rolling(window=50).mean()
22
+ data['MA_200'] = data['Close'].rolling(window=200).mean()
23
+
24
+ return data
25
+ except Exception as e:
26
+ print(f"Error fetching data: {e}")
27
+ return pd.DataFrame() # Return an empty DataFrame if fetching fails
28
 
29
  # Function to create and save a custom index for the retriever
30
  def create_custom_index():
31
  # Fetch and preprocess data
32
  data = fetch_and_preprocess_data()
33
 
34
+ if data.empty:
35
+ raise ValueError("No data available to create the index.")
36
+
37
  # Create a dataset for the retriever
38
  dataset = Dataset.from_dict({
39
  "id": [str(i) for i in range(len(data))],
 
53
  # Load the fine-tuned RAG model and tokenizer
54
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
55
 
56
+ try:
57
+ # Create and save the custom index
58
+ dataset_path, index_path = create_custom_index()
59
+
60
+ # Load the retriever with the custom index
61
+ retriever = RagRetriever.from_pretrained(
62
+ "facebook/rag-sequence-base",
63
+ index_name="custom",
64
+ passages_path=dataset_path,
65
+ index_path=index_path
66
+ )
67
+
68
+ # Load the RAG model
69
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
70
+ except Exception as e:
71
+ print(f"Error initializing model or retriever: {e}")
72
+ retriever = None
73
+ model = None
74
 
75
  # Function to analyze trading data using the RAG model
76
  def analyze_trading_data(question):
77
+ if model is None or retriever is None:
78
+ return "Error: Model or retriever is not initialized. Please check the logs."
79
+
80
  # Fetch and preprocess data
81
  data = fetch_and_preprocess_data()
82
 
83
+ if data.empty:
84
+ return "Error: No data available for analysis."
85
+
86
  # Prepare context for the RAG model
87
  context = (
88
  f"ICICI Bank stock data:\n"