import streamlit as st
import pandas as pd
import numpy as np
import torch
import nltk
import os
import tempfile
import base64
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
from nltk.tokenize import word_tokenize
import pdfplumber
import PyPDF2
from docx import Document
import csv
from datasets import load_dataset
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import faiss
import re
# Download NLTK resources
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
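# Hedge for newer NLTK releases (3.8.2+), which moved the Punkt tables into a
# separate 'punkt_tab' resource; on older versions this download is a no-op.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')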
# Set page configuration
st.set_page_config(
    page_title="AI Resume Screener",
    page_icon="🎯",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- Global Device and Model Loading Section ---
# Initialize session state keys for all models, their loading status/errors, and app data
keys_to_initialize = {
    'embedding_model': None, 'embedding_model_error': None,
    'cross_encoder': None, 'cross_encoder_error': None,
    'qwen3_1_7b_tokenizer': None, 'qwen3_1_7b_tokenizer_error': None,
    'qwen3_1_7b_model': None, 'qwen3_1_7b_model_error': None,
    'results': [], 'resume_texts': [], 'file_names': [], 'current_job_description': ""
    # Add any other app-specific session state keys here if needed
}
for key, default_value in keys_to_initialize.items():
    if key not in st.session_state:
        st.session_state[key] = default_value
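# Streamlit re-runs this script on every user interaction, so the heavyweight
# models below are cached in st.session_state and loaded at most once per session.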
# Load Embedding Model (BAAI/bge-large-en-v1.5)
if st.session_state.embedding_model is None and st.session_state.embedding_model_error is None:
    print("[Global Init] Attempting to load Embedding Model (BAAI/bge-large-en-v1.5) with device_map='auto'...")
    try:
        st.session_state.embedding_model = SentenceTransformer(
            'BAAI/bge-large-en-v1.5',
            device_map="auto"
        )
        print("[Global Init] Embedding Model (BAAI/bge-large-en-v1.5) LOADED with device_map='auto'.")
    except Exception as e:
        if "device_map" in str(e).lower() and "unexpected keyword argument" in str(e).lower():
            print("⚠️ [Global Init] device_map='auto' not supported for SentenceTransformer. Falling back to default device handling.")
            try:
                st.session_state.embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
                print("[Global Init] Embedding Model (BAAI/bge-large-en-v1.5) LOADED (fallback device handling).")
            except Exception as e_fallback:
                error_msg = f"Failed to load Embedding Model (fallback): {str(e_fallback)}"
                print(f"❌ [Global Init] {error_msg}")
                st.session_state.embedding_model_error = error_msg
        else:
            error_msg = f"Failed to load Embedding Model: {str(e)}"
            print(f"❌ [Global Init] {error_msg}")
            st.session_state.embedding_model_error = error_msg
# Load Cross-Encoder Model (ms-marco-MiniLM-L6-v2)
if st.session_state.cross_encoder is None and st.session_state.cross_encoder_error is None:
    print("[Global Init] Attempting to load Cross-Encoder Model (ms-marco-MiniLM-L6-v2) with device_map='auto'...")
    try:
        st.session_state.cross_encoder = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L6-v2',
            device_map="auto"
        )
        print("[Global Init] Cross-Encoder Model (ms-marco-MiniLM-L6-v2) LOADED with device_map='auto'.")
    except Exception as e:
        if "device_map" in str(e).lower() and "unexpected keyword argument" in str(e).lower():
            print("⚠️ [Global Init] device_map='auto' not supported for CrossEncoder. Falling back to default device handling.")
            try:
                st.session_state.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
                print("[Global Init] Cross-Encoder Model (ms-marco-MiniLM-L6-v2) LOADED (fallback device handling).")
            except Exception as e_fallback:
                error_msg = f"Failed to load Cross-Encoder Model (fallback): {str(e_fallback)}"
                print(f"❌ [Global Init] {error_msg}")
                st.session_state.cross_encoder_error = error_msg
        else:
            error_msg = f"Failed to load Cross-Encoder Model: {str(e)}"
            print(f"❌ [Global Init] {error_msg}")
            st.session_state.cross_encoder_error = error_msg
# Load Qwen3-1.7B Tokenizer
if st.session_state.qwen3_1_7b_tokenizer is None and st.session_state.qwen3_1_7b_tokenizer_error is None:
    print("[Global Init] Loading Qwen3-1.7B Tokenizer...")
    try:
        st.session_state.qwen3_1_7b_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")
        print("[Global Init] Qwen3-1.7B Tokenizer Loaded.")
    except Exception as e:
        error_msg = f"Failed to load Qwen3-1.7B Tokenizer: {str(e)}"
        print(f"❌ [Global Init] {error_msg}")
        st.session_state.qwen3_1_7b_tokenizer_error = error_msg
# Load Qwen3-1.7B Model
if st.session_state.qwen3_1_7b_model is None and st.session_state.qwen3_1_7b_model_error is None:
    print("[Global Init] Loading Qwen3-1.7B Model (attempting with device_map='auto')...")
    try:
        st.session_state.qwen3_1_7b_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen3-1.7B",
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True  # if required by this specific model
        )
        print("[Global Init] Qwen3-1.7B Model Loaded with device_map='auto'.")
    except Exception as e_dev_map:
        print(f"⚠️ [Global Init] Failed to load Qwen3-1.7B with device_map='auto': {str(e_dev_map)}")
        print("[Global Init] Retrying Qwen3-1.7B load without device_map (will use default single device)...")
        try:
            st.session_state.qwen3_1_7b_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen3-1.7B",
                torch_dtype="auto",
                # No device_map here; let Hugging Face decide or use CUDA if available
                trust_remote_code=True  # if required
            )
            print("[Global Init] Qwen3-1.7B Model Loaded (fallback device handling).")
        except Exception as e_fallback:
            error_msg = f"Failed to load Qwen3-1.7B Model (fallback): {str(e_fallback)}"
            print(f"❌ [Global Init] {error_msg}")
            st.session_state.qwen3_1_7b_model_error = error_msg
# --- End of Global Model Loading Section ---
# --- Class Definitions and Helper Functions ---
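# Optional helper (a sketch, not called anywhere yet): frees Python and CUDA
# memory between heavy pipeline stages, using the `gc` and `torch` imports above.
def cleanup_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()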
def generate_qwen3_response(prompt, tokenizer, model, max_new_tokens=200):
    """Generate a single-turn chat completion from the Qwen3-1.7B model."""
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True  # As per Qwen3-1.7B docs for thinking mode
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )
    # Keep only the newly generated tokens; drop the echoed prompt.
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    response = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    return response
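# Example usage (a sketch, assuming the Qwen3 model and tokenizer loaded above):
#   summary = generate_qwen3_response(
#       "Summarize this resume in one sentence: <resume text>",
#       st.session_state.qwen3_1_7b_tokenizer,
#       st.session_state.qwen3_1_7b_model,
#   )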
class ResumeScreener:  # Ensure this class definition is BEFORE it's instantiated
    def __init__(self):
        # The models are loaded once globally; the screener only holds references.
        print("[ResumeScreener] Initializing with references to globally loaded models...")
        self.embedding_model = st.session_state.get('embedding_model')
        self.cross_encoder = st.session_state.get('cross_encoder')
        if self.embedding_model:
            print("[ResumeScreener] Embedding model reference set.")
        else:
            print("[ResumeScreener] Embedding model not available (check loading errors).")
        if self.cross_encoder:
            print("[ResumeScreener] Cross-encoder model reference set.")
        else:
            print("[ResumeScreener] Cross-encoder model not available (check loading errors).")
        print("[ResumeScreener] Initialization complete.")

    # Other methods of ResumeScreener: extract_text_from_file, get_embedding,
    # calculate_bm25_scores, advanced_pipeline_ranking, faiss_recall, cross_encoder_rerank,
    # add_bm25_scores, add_intent_scores, analyze_intent, calculate_final_scores, extract_skills.
    # All methods must be indented within the class.
    def extract_text_from_file(self, file_path, file_type):
        """Extract raw text from a PDF, DOCX, TXT, or CSV file."""
        try:
            if file_type == "pdf":
                with open(file_path, 'rb') as file:
                    with pdfplumber.open(file) as pdf:
                        text = ""
                        for page in pdf.pages:
                            text += page.extract_text() or ""
                    # Fall back to PyPDF2 when pdfplumber extracts nothing.
                    if not text.strip():
                        file.seek(0)
                        reader = PyPDF2.PdfReader(file)
                        text = ""
                        for page_num in range(len(reader.pages)):
                            text += reader.pages[page_num].extract_text() or ""
                return text
            elif file_type == "docx":
                doc = Document(file_path)
                return " ".join([paragraph.text for paragraph in doc.paragraphs])
            elif file_type == "txt":
                with open(file_path, 'r', encoding='utf-8') as file:
                    return file.read()
            elif file_type == "csv":
                with open(file_path, 'r', encoding='utf-8') as file:
                    csv_reader = csv.reader(file)
                    return " ".join([" ".join(row) for row in csv_reader])
            return ""  # Unsupported file type
        except Exception as e:
            st.error(f"Error extracting text from {file_path}: {str(e)}")
            return ""
    def get_embedding(self, text):
        if self.embedding_model is None:
            st.error("Embedding model is not available!")
            return np.zeros(1024)  # bge-large-en-v1.5 produces 1024-dim vectors
        try:
            # Short texts are treated as queries and get the BGE query-instruction prefix.
            if len(text) < 500:
                text = "Represent this sentence for searching relevant passages: " + text
            text = text[:8192] if text else ""
            embedding = self.embedding_model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
            return embedding
        except Exception as e:
            st.error(f"Error generating embedding: {str(e)}")
            return np.zeros(1024)
    def calculate_bm25_scores(self, resume_texts, job_description):
        try:
            job_tokens = word_tokenize(job_description.lower())
            # Keep empty texts as empty token lists so the returned scores stay
            # index-aligned with resume_texts (filtering them out would shift indices).
            corpus = [word_tokenize(text.lower()) if text and text.strip() else [] for text in resume_texts]
            if not any(corpus):
                return [0.0] * len(resume_texts)
            bm25 = BM25Okapi(corpus)
            scores = bm25.get_scores(job_tokens)
            return scores.tolist()
        except Exception as e:
            st.error(f"Error calculating BM25 scores: {str(e)}")
            return [0.0] * len(resume_texts)
    def advanced_pipeline_ranking(self, resume_texts, job_description):
        print("[Pipeline] Advanced Pipeline Ranking started.")
        if not resume_texts:
            return []
        st.info("🔍 Stage 1: FAISS Recall - Finding top candidates...")
        top_50_indices = self.faiss_recall(resume_texts, job_description, top_k=50)
        st.info("🎯 Stage 2: Cross-Encoder Re-ranking - Selecting top candidates...")
        top_20_results = self.cross_encoder_rerank(resume_texts, job_description, top_50_indices, top_k=20)
        st.info("🔤 Stage 3: BM25 Keyword Matching...")
        top_20_with_bm25 = self.add_bm25_scores(resume_texts, job_description, top_20_results)
        st.info("🤖 Stage 4: LLM Intent Analysis (Qwen3-1.7B)...")
        top_20_with_intent = self.add_intent_scores(resume_texts, job_description, top_20_with_bm25)
        st.info("🏆 Stage 5: Final Combined Ranking...")
        final_results = self.calculate_final_scores(top_20_with_intent)
        print("[Pipeline] Advanced Pipeline Ranking finished.")
        return final_results[:st.session_state.get('top_k', 5)]
    def faiss_recall(self, resume_texts, job_description, top_k=50):
        print("[faiss_recall] Method started.")
        st.text("FAISS Recall: Embedding job description...")
        job_embedding = self.get_embedding(job_description)
        st.text(f"FAISS Recall: Embedding {len(resume_texts)} resumes...")
        resume_embeddings = []
        progress_bar = st.progress(0)
        for i, text in enumerate(resume_texts):
            if text:
                embedding = self.embedding_model.encode(text[:8192], convert_to_numpy=True, normalize_embeddings=True)
                resume_embeddings.append(embedding)
            else:
                resume_embeddings.append(np.zeros(1024))
            progress_bar.progress((i + 1) / len(resume_texts))
        progress_bar.empty()
        resume_embeddings_np = np.array(resume_embeddings).astype('float32')
        if resume_embeddings_np.ndim == 1:  # Handle the single-resume case
            resume_embeddings_np = resume_embeddings_np.reshape(1, -1)
        if resume_embeddings_np.size == 0:
            print("[faiss_recall] No resume embeddings to add to FAISS index.")
            return []  # Or handle error appropriately
        dimension = resume_embeddings_np.shape[1]
        # The embeddings are L2-normalized, so inner product equals cosine similarity.
        index = faiss.IndexFlatIP(dimension)
        index.add(resume_embeddings_np)
        job_embedding_np = job_embedding.reshape(1, -1).astype('float32')
        scores, indices = index.search(job_embedding_np, min(top_k, len(resume_texts)))
        return indices[0].tolist()
    def cross_encoder_rerank(self, resume_texts, job_description, top_50_indices, top_k=20):
        print("[cross_encoder_rerank] Method started.")
        if not self.cross_encoder:
            st.error("Cross-encoder model is not available!")
            return [(idx, 0.0) for idx in top_50_indices[:top_k]]
        pairs = []
        valid_indices = []
        for idx in top_50_indices:
            if idx < len(resume_texts) and resume_texts[idx]:
                job_snippet = job_description[:512]
                resume_snippet = resume_texts[idx][:512]
                pairs.append([job_snippet, resume_snippet])
                valid_indices.append(idx)
        if not pairs:
            return [(idx, 0.0) for idx in top_50_indices[:top_k]]
        st.text(f"Cross-Encoder: Preparing {len(pairs)} pairs for re-ranking...")
        scores = []
        batch_size = 8
        progress_bar = st.progress(0)
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i+batch_size]
            batch_scores = self.cross_encoder.predict(batch)
            scores.extend(batch_scores)
            progress_bar.progress(min(1.0, (i + batch_size) / len(pairs)))
        progress_bar.empty()
        indexed_scores = list(zip(valid_indices, scores))
        indexed_scores.sort(key=lambda x: x[1], reverse=True)
        return indexed_scores[:top_k]
    def add_bm25_scores(self, resume_texts, job_description, top_20_results):
        st.text("BM25: Calculating keyword scores...")
        top_20_texts = [resume_texts[idx] for idx, _ in top_20_results]
        bm25_scores_raw = self.calculate_bm25_scores(top_20_texts, job_description)
        # Normalize BM25 into the [0.1, 0.2] band so it acts as a bounded boost
        # in the final weighted combination (0.15 when scores are uninformative).
        if bm25_scores_raw and max(bm25_scores_raw) > 0:
            max_bm25, min_bm25 = max(bm25_scores_raw), min(bm25_scores_raw)
            if max_bm25 > min_bm25:
                normalized_bm25 = [0.1 + 0.1 * (s - min_bm25) / (max_bm25 - min_bm25) for s in bm25_scores_raw]
            else:
                normalized_bm25 = [0.15] * len(bm25_scores_raw)
        else:
            normalized_bm25 = [0.15] * len(top_20_results)
        results_with_bm25 = []
        for i, (idx, cross_score) in enumerate(top_20_results):
            results_with_bm25.append((idx, cross_score, normalized_bm25[i] if i < len(normalized_bm25) else 0.15))
        return results_with_bm25
    def add_intent_scores(self, resume_texts, job_description, top_20_with_bm25):
        st.text(f"LLM Intent: Analyzing intent for {len(top_20_with_bm25)} candidates (Qwen3-1.7B)...")
        results_with_intent = []
        progress_bar = st.progress(0)
        for i, (idx, cross_score, bm25_score) in enumerate(top_20_with_bm25):
            intent_score = self.analyze_intent(resume_texts[idx], job_description)
            results_with_intent.append((idx, cross_score, bm25_score, intent_score))
            progress_bar.progress((i + 1) / len(top_20_with_bm25))
        progress_bar.empty()
        return results_with_intent
    def analyze_intent(self, resume_text, job_description):
        print("[analyze_intent] Analyzing intent for one resume (Qwen3-1.7B)...")
        st.text("LLM Intent: Analyzing intent (Qwen3-1.7B)...")
        try:
            resume_snippet = resume_text[:15000]
            job_snippet = job_description[:5000]
            # The full prompt should interpolate job_snippet and resume_snippet.
            prompt = f"""You are given a job description and a candidate's resume... (rest of prompt)"""
            # ... (rest of analyze_intent, using st.session_state.qwen3_1_7b_tokenizer and _model)
            response_text = generate_qwen3_response(
                prompt,
                st.session_state.qwen3_1_7b_tokenizer,
                st.session_state.qwen3_1_7b_model,
                max_new_tokens=20000  # thinking mode can emit long chains of thought
            )
            # ... (parsing logic for response_text) ...
            thinking_content = "No detailed thought process extracted."
            intent_decision_part = response_text
            think_start_tag = "<think>"
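            # Minimal continuation sketch (an assumption, based on the variable
            # names above): Qwen3's thinking mode wraps its reasoning in
            # <think>...</think>, so the decision is whatever follows the closing tag.
            think_end_tag = "</think>"
            start = response_text.find(think_start_tag)
            end = response_text.find(think_end_tag)
            if start != -1 and end != -1 and end > start:
                thinking_content = response_text[start + len(think_start_tag):end].strip()
                intent_decision_part = response_text[end + len(think_end_tag):].strip()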