Rohil Bansal committed
Commit ee7ea09 · 1 Parent(s): dd0bddd

dataframe improved
course_search/app/gradio_app.py CHANGED
@@ -21,7 +21,7 @@ class CourseSearchApp:
         """Initialize RAG system and load data"""
         try:
             # Construct path to data file
-            data_path = Path(__file__).parent.parent.parent / 'data' / 'courses_with_embeddings.pkl'
+            data_path = Path(__file__).parent.parent.parent / 'data' / 'courses.pkl'
 
             if not data_path.exists():
                 raise FileNotFoundError(f"Data file not found at: {data_path}")
@@ -30,9 +30,23 @@ class CourseSearchApp:
             df = pd.read_pickle(str(data_path))
             logger.info(f"Loaded {len(df)} courses from {data_path}")
 
+            # Validate DataFrame
+            if len(df) == 0:
+                raise ValueError("Empty DataFrame loaded")
+
+            required_columns = ['title', 'description', 'curriculum', 'url']
+            missing_columns = [col for col in required_columns if col not in df.columns]
+            if missing_columns:
+                raise ValueError(f"Missing required columns: {missing_columns}")
+
             # Initialize RAG system
             self.rag_system = RAGSystem()
-            self.rag_system.load_and_process_data(df)
+
+            # Create cache directory
+            cache_dir = data_path.parent / 'cache'
+            cache_dir.mkdir(exist_ok=True)
+
+            self.rag_system.load_and_process_data(df, cache_dir=cache_dir)
             logger.info("Components loaded successfully")
 
         except Exception as e:
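The app now loads the raw `data/courses.pkl` (embeddings are computed and cached by `RAGSystem` rather than shipped inside the pickle) and validates the schema before wiring up the RAG system. As a reference for what the loader now expects, here is a minimal sketch of a fixture that satisfies the new checks; the helper name and sample rows are illustrative, not part of this commit:

```python
# Hypothetical test helper (not in this commit): builds a minimal
# data/courses.pkl containing the four columns the loader now requires.
from pathlib import Path
import pandas as pd

def make_courses_pickle(data_dir: Path) -> Path:
    df = pd.DataFrame({
        "title": ["Intro to Python", "Linear Algebra"],
        "description": ["Python basics for beginners.", "Vectors, matrices, and more."],
        "curriculum": ["Week 1: syntax; Week 2: functions", "Week 1: vectors; Week 2: matrices"],
        "url": ["https://example.com/python", "https://example.com/linalg"],
    })
    data_dir.mkdir(parents=True, exist_ok=True)
    path = data_dir / "courses.pkl"
    df.to_pickle(path)  # read back by pd.read_pickle in load_components
    return path
```

A pickle missing any of `title`, `description`, `curriculum`, or `url` now fails fast with a `ValueError` instead of surfacing later inside the RAG system.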
course_search/search_system/rag_system.py CHANGED
@@ -9,71 +9,90 @@ import logging
 from typing import List, Dict
 import os
 from dotenv import load_dotenv
+from pathlib import Path
+import numpy as np
+import faiss
+from sentence_transformers import SentenceTransformer
 
 logger = logging.getLogger(__name__)
 
 class RAGSystem:
     def __init__(self):
-        """Initialize the RAG system with LangChain components"""
-        load_dotenv()
-
-        # Initialize embedding model
-        self.embeddings = HuggingFaceEmbeddings(
-            model_name="sentence-transformers/all-MiniLM-L6-v2"
-        )
-
-        # Initialize text splitter for chunking
-        self.text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=500,
-            chunk_overlap=50
-        )
-
-        self.vector_store = None
-        self.qa_chain = None
+        """Initialize the RAG system"""
+        try:
+            self.model = SentenceTransformer('all-MiniLM-L6-v2')
+            self.embeddings = None
+            self.index = None
+            self.df = None
+            logger.info("RAG system initialized successfully")
+        except Exception as e:
+            logger.error(f"Error initializing RAG system: {str(e)}")
+            raise
 
-    def load_and_process_data(self, df: pd.DataFrame) -> None:
-        """
-        Load course data and create vector store
-        """
+    def load_and_process_data(self, df: pd.DataFrame, cache_dir: Path = None):
+        """Load and process the course data with caching support"""
         try:
-            # Prepare documents from DataFrame
-            loader = DataFrameLoader(
-                data_frame=df,
-                page_content_column="description"
-            )
-            documents = loader.load()
-            for doc, row in zip(documents, df.itertuples()):
-                doc.metadata = {
-                    "title": row.title,
-                    "url": row.url,
-                    # Add other metadata fields as needed
-                }
-
-            # Split documents into chunks
-            splits = self.text_splitter.split_documents(documents)
-
-            # Create vector store
-            self.vector_store = FAISS.from_documents(
-                splits,
-                self.embeddings
-            )
+            # Validate input
+            if df is None or len(df) == 0:
+                raise ValueError("Empty or None DataFrame provided")
 
-            # Initialize QA chain
-            llm = HuggingFaceHub(
-                repo_id="google/flan-t5-small",
-                huggingfacehub_api_token=os.getenv('HUGGINGFACE_API_TOKEN')
-            )
+            required_columns = ['title', 'description', 'curriculum', 'url']
+            missing_columns = [col for col in required_columns if col not in df.columns]
+            if missing_columns:
+                raise ValueError(f"Missing required columns: {missing_columns}")
 
-            self.qa_chain = RetrievalQA.from_chain_type(
-                llm=llm,
-                chain_type="stuff",
-                retriever=self.vector_store.as_retriever()
+            self.df = df
+            vector_dimension = 384  # dimension for all-MiniLM-L6-v2
+
+            # Try loading from cache first
+            if cache_dir is not None:
+                cache_dir.mkdir(exist_ok=True)
+                embeddings_path = cache_dir / 'course_embeddings.npy'
+                index_path = cache_dir / 'faiss_index.bin'
+
+                if embeddings_path.exists() and index_path.exists():
+                    logger.info("Loading cached embeddings and index...")
+                    try:
+                        self.embeddings = np.load(str(embeddings_path))
+                        self.index = faiss.read_index(str(index_path))
+                        logger.info("Successfully loaded cached data")
+                        return
+                    except Exception as e:
+                        logger.warning(f"Failed to load cache: {e}. Computing new embeddings...")
+
+            # Compute new embeddings
+            logger.info("Computing course embeddings...")
+            texts = [
+                f"{row['title']}. {row['description']}"
+                for _, row in df.iterrows()
+            ]
+
+            if not texts:
+                raise ValueError("No texts to encode")
+
+            self.embeddings = self.model.encode(
+                texts,
+                show_progress_bar=True,
+                convert_to_numpy=True
             )
 
-            logger.info("RAG system initialized successfully")
+            if self.embeddings.size == 0:
+                raise ValueError("Failed to generate embeddings")
+
+            # Create and populate FAISS index
+            self.index = faiss.IndexFlatL2(vector_dimension)
+            self.index.add(self.embeddings.astype('float32'))
+
+            # Save to cache if directory provided
+            if cache_dir is not None:
+                logger.info("Saving embeddings and index to cache...")
+                np.save(str(embeddings_path), self.embeddings)
+                faiss.write_index(self.index, str(index_path))
+
+            logger.info(f"Successfully processed {len(df)} courses")
 
         except Exception as e:
-            logger.error(f"Error initializing RAG system: {str(e)}")
+            logger.error(f"Error processing data: {str(e)}")
             raise
 
     def search_courses(self, query: str, top_k: int = 5) -> List[Dict]:
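The hunk ends at the `search_courses` signature, so the committed body is not visible in this diff. Against the structures the commit sets up (`self.model`, `self.index`, `self.df`), a query method would typically look like the following sketch; the distance handling and result fields are assumptions, not the committed code:

```python
# Illustrative sketch only -- not the committed search_courses body.
# Assumes the attributes populated by load_and_process_data above
# and the module's existing List/Dict imports from typing.
def search_courses(self, query: str, top_k: int = 5) -> List[Dict]:
    if self.index is None or self.df is None:
        raise RuntimeError("Call load_and_process_data() before searching")

    # Embed the query with the same model used for the course texts.
    query_vector = self.model.encode([query], convert_to_numpy=True)

    # IndexFlatL2 returns squared L2 distances and row positions in self.df.
    distances, indices = self.index.search(query_vector.astype('float32'), top_k)

    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx < 0:  # FAISS pads with -1 when fewer than top_k vectors exist
            continue
        row = self.df.iloc[int(idx)]
        results.append({
            'title': row['title'],
            'description': row['description'],
            'url': row['url'],
            'distance': float(dist),  # smaller distance = closer match
        })
    return results
```

Note that on a cache hit `load_and_process_data` returns before cross-checking the cached index against the current DataFrame, so the sketch assumes the cache and `self.df` were built from the same `courses.pkl`.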