Upload 4 files
- rag.py +96 -0
- readme.md +91 -0
- requirements.txt +12 -0
- shl_scraper.py +183 -0
rag.py
ADDED
@@ -0,0 +1,96 @@
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import chromadb
import uuid
import numpy as np

# === STEP 1: Preprocessing CSV & Chunking ===
def pre_processing_csv(csv_path):
    df = pd.read_csv(csv_path)
    df.fillna("", inplace=True)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=50
    )

    documents = []
    metadatas = []

    for idx, row in df.iterrows():
        # Combine multiple fields for better context
        combined_text = f"""
        Test Name: {row.get('Test Name', '')}
        Description: {row.get('Description', '')}
        Remote Testing: {row.get('Remote Testing', '')}
        Adaptive/IRT: {row.get('Adaptive/IRT', '')}
        Test Type: {row.get('Test Type', '')}
        """

        chunks = text_splitter.split_text(combined_text)

        for chunk in chunks:
            documents.append(chunk)
            metadatas.append({
                "Test Name": row.get('Test Name', ''),
                "Test Link": row.get('Test Link', ''),
                "Remote Testing": row.get('Remote Testing', ''),
                "Adaptive/IRT": row.get('Adaptive/IRT', ''),
                "Test Type": row.get('Test Type', ''),
                "row_id": idx
            })

    return documents, metadatas

# === STEP 2: Embed and Store in ChromaDB ===
def build_chroma_store(documents, metadatas, client=None):
    if client is None:
        client = chromadb.Client()
    collection = client.create_collection(name="shl_test_catalog")
    print("🔍 Embedding documents...")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(documents, show_progress_bar=True)

    print("📥 Adding to ChromaDB...")
    collection.add(
        documents=documents,
        embeddings=[e.tolist() for e in embeddings],
        ids=[str(uuid.uuid4()) for _ in range(len(documents))],
        metadatas=metadatas
    )

    return collection, model

# === STEP 3: Query the RAG Model ===
def ask_query(query, model, collection, k=10):
    print(f"\n💬 Query: {query}")
    query_embedding = model.encode(query)

    # Get more results than needed for diversity
    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=k*2  # Get more results for diversity
    )

    # Process results to ensure diversity
    seen_tests = set()
    final_results = []

    for i in range(len(results['documents'][0])):
        doc = results['documents'][0][i]
        meta = results['metadatas'][0][i]
        test_name = meta['Test Name']

        # Skip if we've already seen this test
        if test_name in seen_tests:
            continue

        seen_tests.add(test_name)
        final_results.append((doc, meta))

        # Stop if we have enough diverse results
        if len(final_results) >= k:
            break

    return final_results
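Taken together, the three helpers above form a small pipeline: chunk the scraped CSV, embed and index the chunks in ChromaDB, then query with de-duplication by test name. The upload does not include a driver script, but a minimal sketch of how the pipeline might be run end to end (assuming `shl_products.csv` has already been produced by `shl_scraper.py` below) could look like this:

```python
# Hypothetical driver -- not part of the uploaded files.
from rag import pre_processing_csv, build_chroma_store, ask_query

# 1. Chunk the scraped catalog into documents plus per-chunk metadata.
documents, metadatas = pre_processing_csv("shl_products.csv")

# 2. Embed the chunks and index them in an in-memory ChromaDB collection.
collection, model = build_chroma_store(documents, metadatas)

# 3. Retrieve up to 10 distinct assessments for a free-text query.
results = ask_query("Java developer with strong collaboration skills",
                    model, collection, k=10)
for doc, meta in results:
    print(meta["Test Name"], "->", meta["Test Link"])
```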
readme.md
ADDED
@@ -0,0 +1,91 @@
# SHL Assessment Retrieval System

## Overview

The **SHL Assessment Retrieval System** is a web application for querying and retrieving relevant assessments from the SHL product catalog. It uses a Retrieval-Augmented Generation (RAG) pipeline to return accurate, contextually relevant assessments for each user query. The frontend is built with Streamlit, and ChromaDB provides efficient vector storage and retrieval.

## Features

- **Data Scraping**: Automatically scrapes assessment data from the SHL product catalog.
- **Data Processing**: Preprocesses and chunks the scraped data for efficient querying.
- **Embedding Model**: Uses a `SentenceTransformer` model to embed queries and documents.
- **Diverse Query Results**: Returns diverse, relevant results for each query.
- **User-Friendly Interface**: Built with Streamlit for an interactive experience.

## Technologies Used

- Python
- Streamlit
- Pandas
- Sentence Transformers
- ChromaDB
- BeautifulSoup (for web scraping)
- Requests

## Installation

### Prerequisites

Make sure you have Python 3.7 or higher installed. You can download it from [python.org](https://www.python.org/downloads/).

### Clone the Repository

```bash
git clone https://github.com/yourusername/shl-assessment-retrieval.git
cd shl-assessment-retrieval
```

### Install Dependencies

Install the required packages with pip. Creating a virtual environment first is recommended.

```bash
# Create a virtual environment (optional)
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`

# Install dependencies
pip install -r requirements.txt
```

## Usage

### Scraping Data

Before querying the assessments, scrape the data from the SHL product catalog by running the `shl_scraper.py` script:

```bash
python shl_scraper.py
```

This creates a CSV file named `shl_products.csv` containing the scraped assessment data.

### Running the Streamlit App

Once the data is scraped, run the Streamlit app:

```bash
streamlit run app.py
```

Open your web browser and navigate to `http://localhost:8501` to access the application.

### Querying Assessments

- Enter your query in the input box and click the "Submit" button.
- The application displays relevant assessments based on your query.
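The Streamlit front end (`app.py`) is referenced above but not included in this upload. Purely as an illustrative sketch, an app built on the `rag.py` helpers might look like the following; the function names come from `rag.py`, while the widget layout and caching choice are assumptions rather than the actual implementation:

```python
# app.py -- illustrative sketch only; the real app.py is not part of this upload.
import streamlit as st
from rag import pre_processing_csv, build_chroma_store, ask_query

@st.cache_resource
def load_index():
    # Build the vector store once per session from the scraped catalog.
    documents, metadatas = pre_processing_csv("shl_products.csv")
    return build_chroma_store(documents, metadatas)

st.title("SHL Assessment Retrieval System")
collection, model = load_index()

query = st.text_input("Describe the role or skills you want to assess")
if st.button("Submit") and query:
    for doc, meta in ask_query(query, model, collection, k=10):
        st.subheader(meta["Test Name"])
        st.write(doc)
        st.write(f"[View assessment]({meta['Test Link']})")
```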
## Code Structure

```
shl-assessment-retrieval/
│
├── app.py             # Streamlit application for querying assessments
├── rag.py             # RAG model implementation for data processing and querying
├── shl_scraper.py     # Web scraper for fetching assessment data
├── evaluate.py        # Evaluation script for assessing model performance
├── requirements.txt   # List of dependencies
└── README.md          # Project documentation
```
requirements.txt
ADDED
@@ -0,0 +1,12 @@
requests==2.31.0
beautifulsoup4==4.12.2
pandas==2.1.4
sentence-transformers==2.2.2
transformers==4.36.2
torch==2.2.0
huggingface-hub==0.19.4
streamlit
chromadb
langchain
setuptools
shl_scraper.py
ADDED
@@ -0,0 +1,183 @@
import requests
from bs4 import BeautifulSoup
import pandas as pd
import time
from typing import List, Dict
import logging
import urllib.parse
from sentence_transformers import SentenceTransformer
import torch

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class SHLScraper:
    def __init__(self):
        self.base_url = "https://www.shl.com/solutions/products/product-catalog/"
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        # Initialize the embedding model (note: not used elsewhere in this scraper)
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        except Exception as e:
            logging.error(f"Error initializing embedding model: {e}")
            self.embedding_model = None

    def get_page_content(self, start: int, type_num: int) -> str:
        """Fetch page content with given start and type parameters."""
        params = {
            'start': start,
            'type': type_num
        }
        try:
            response = requests.get(self.base_url, params=params, headers=self.headers)
            response.raise_for_status()
            return response.text
        except requests.RequestException as e:
            logging.error(f"Error fetching page: {e}")
            return ""

    def check_yes_no(self, cell) -> str:
        """Check if a cell contains a yes or no indicator based on CSS classes."""
        yes_span = cell.find('span', class_='catalogue__circle -yes')
        no_span = cell.find('span', class_='catalogue__circle -no')

        if yes_span:
            return "Yes"
        elif no_span:
            return "No"
        return ""

    def get_test_link(self, cell) -> str:
        """Extract the href link from the test name cell."""
        link = cell.find('a')
        if link and 'href' in link.attrs:
            return link['href']
        return ""

    def get_test_description(self, test_link: str) -> str:
        """Fetch and extract the description from a test's detail page."""
        if not test_link:
            return ""

        # Construct full URL if it's a relative path
        if test_link.startswith('/'):
            test_link = urllib.parse.urljoin("https://www.shl.com", test_link)

        try:
            logging.info(f"Fetching description for: {test_link}")
            response = requests.get(test_link, headers=self.headers)
            response.raise_for_status()

            soup = BeautifulSoup(response.text, 'html.parser')

            # Initialize description parts
            description_parts = []

            # Try to find main description
            desc_div = soup.find('div', class_='product-description')
            if desc_div:
                description_parts.append(desc_div.get_text(strip=True))

            # Try to find additional details
            details_div = soup.find('div', class_='product-details')
            if details_div:
                description_parts.append(details_div.get_text(strip=True))

            # Try to find features
            features_div = soup.find('div', class_='product-features')
            if features_div:
                description_parts.append(features_div.get_text(strip=True))

            # Try to find benefits
            benefits_div = soup.find('div', class_='product-benefits')
            if benefits_div:
                description_parts.append(benefits_div.get_text(strip=True))

            # Try to find meta description as fallback
            if not description_parts:
                meta_desc = soup.find('meta', {'name': 'description'})
                if meta_desc and 'content' in meta_desc.attrs:
                    description_parts.append(meta_desc['content'])

            # Combine all parts with proper spacing
            full_description = " | ".join(filter(None, description_parts))

            time.sleep(1)  # Be respectful with requests
            return full_description

        except requests.RequestException as e:
            logging.error(f"Error fetching description from {test_link}: {e}")
            return ""

    def extract_table_data(self, html_content: str) -> List[Dict]:
        """Extract table data from HTML content."""
        if not html_content:
            return []

        soup = BeautifulSoup(html_content, 'html.parser')
        tables = soup.find_all('table')

        all_data = []
        for table in tables:
            rows = table.find_all('tr')
            for row in rows[1:]:  # Skip header row
                cols = row.find_all('td')
                if len(cols) >= 4:  # Ensure we have all columns
                    test_link = self.get_test_link(cols[0])
                    data = {
                        'Test Name': cols[0].get_text(strip=True),
                        'Test Link': test_link,
                        'Description': self.get_test_description(test_link),
                        'Remote Testing': self.check_yes_no(cols[1]),
                        'Adaptive/IRT': self.check_yes_no(cols[2]),
                        'Test Type': cols[3].get_text(strip=True)
                    }
                    all_data.append(data)
        return all_data

    def scrape_all_tables(self, max_pages: int = 10):
        """Scrape tables from multiple pages."""
        all_data = []

        for start in range(0, max_pages * 12, 12):  # Each page has 12 items
            for type_num in range(1, 9):  # Types 1-8
                logging.info(f"Scraping page with start={start}, type={type_num}")

                html_content = self.get_page_content(start, type_num)
                if not html_content:
                    continue

                page_data = self.extract_table_data(html_content)
                if page_data:
                    all_data.extend(page_data)
                    logging.info(f"Found {len(page_data)} items on this page")

                # Add delay to be respectful to the server
                time.sleep(1)

        return all_data

    def save_to_csv(self, data: List[Dict], filename: str = 'shl_products.csv'):
        """Save scraped data to CSV file."""
        if not data:
            logging.warning("No data to save")
            return

        df = pd.DataFrame(data)
        df.to_csv(filename, index=False)
        logging.info(f"Saved {len(data)} records to {filename}")

def main():
    scraper = SHLScraper()
    logging.info("Starting SHL product catalog scraping...")

    data = scraper.scrape_all_tables()
    logging.info(f"Total records scraped: {len(data)}")

    scraper.save_to_csv(data)
    logging.info("Scraping completed!")

if __name__ == "__main__":
    main()