RizwanSajad commited on
Commit
8e804f7
·
verified ·
1 Parent(s): fb2104b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -41
app.py CHANGED
@@ -3,9 +3,7 @@ import gradio as gr
3
  import torch
4
  import pandas as pd
5
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
- from datasets import Dataset
7
  import yfinance as yf
8
- import numpy as np
9
 
10
  # Function to fetch and preprocess ICICI Bank data
11
  def fetch_and_preprocess_data():
@@ -26,51 +24,20 @@ def fetch_and_preprocess_data():
26
  print(f"Error fetching data: {e}")
27
  return pd.DataFrame() # Return an empty DataFrame if fetching fails
28
 
29
- # Function to create and save a custom index for the retriever
30
- def create_custom_index():
31
- # Fetch and preprocess data
32
- data = fetch_and_preprocess_data()
33
-
34
- if data.empty:
35
- raise ValueError("No data available to create the index.")
36
-
37
- # Create a dataset for the retriever
38
- dataset = Dataset.from_dict({
39
- "id": [str(i) for i in range(len(data))],
40
- "text": data.apply(lambda row: f"Date: {row.name}, Close: {row['Close']:.2f}, MA_50: {row['MA_50']:.2f}, MA_200: {row['MA_200']:.2f}", axis=1).tolist(),
41
- "title": [f"ICICI Bank Data {i}" for i in range(len(data))]
42
- })
43
-
44
- # Save the dataset and index
45
- dataset_path = "icici_bank_dataset"
46
- index_path = "icici_bank_index"
47
- dataset.save_to_disk(dataset_path)
48
- print(f"Dataset saved to {dataset_path}")
49
-
50
- # Add FAISS index
51
- dataset.add_faiss_index("text")
52
- dataset.get_index("text").save(index_path)
53
- print(f"FAISS index saved to {index_path}")
54
-
55
- return dataset_path, index_path
56
-
57
  # Load the fine-tuned RAG model and tokenizer
58
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
59
- print("Tokenizer loaded successfully.")
60
-
61
  try:
62
- # Create and save the custom index
63
- dataset_path, index_path = create_custom_index()
64
-
65
- # Load the retriever with the custom index
66
  retriever = RagRetriever.from_pretrained(
67
  "facebook/rag-sequence-base",
68
- index_name="custom",
69
- passages_path=dataset_path,
70
- index_path=index_path
71
  )
72
  print("Retriever loaded successfully.")
73
-
74
  # Load the RAG model
75
  model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
76
  print("Model loaded successfully.")
 
3
  import torch
4
  import pandas as pd
5
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
6
  import yfinance as yf
 
7
 
8
  # Function to fetch and preprocess ICICI Bank data
9
  def fetch_and_preprocess_data():
 
24
  print(f"Error fetching data: {e}")
25
  return pd.DataFrame() # Return an empty DataFrame if fetching fails
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Load the fine-tuned RAG model and tokenizer
 
 
 
28
  try:
29
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
30
+ print("Tokenizer loaded successfully.")
31
+
32
+ # Use a pre-built index (e.g., wiki_dpr) instead of creating a custom index
33
  retriever = RagRetriever.from_pretrained(
34
  "facebook/rag-sequence-base",
35
+ index_name="wiki_dpr",
36
+ passages_path=None,
37
+ index_path=None
38
  )
39
  print("Retriever loaded successfully.")
40
+
41
  # Load the RAG model
42
  model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
43
  print("Model loaded successfully.")