"""Feature extraction utilities: BERT [CLS] embeddings, TF-IDF, and token counts."""
import numpy as np
import torch
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from transformers import BertTokenizer, BertModel
from typing import Tuple, Dict, List
import pandas as pd
from tqdm import tqdm
class FeatureExtractor:
    """Extract text features via BERT embeddings, TF-IDF, and token counts."""

    def __init__(self, bert_model_name: str = "bert-base-uncased"):
        """Load the BERT tokenizer/model and configure both vectorizers.

        Args:
            bert_model_name: Hugging Face model identifier to load.
        """
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        # Both vectorizers share the same settings: unigrams + bigrams,
        # English stop words, capped vocabulary.
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            stop_words='english'
        )
        self.count_vectorizer = CountVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            stop_words='english'
        )

    def get_bert_embeddings(self, texts: List[str],
                            batch_size: int = 32,
                            max_length: int = 512) -> np.ndarray:
        """Extract BERT [CLS] embeddings for a list of texts.

        Args:
            texts: Input documents.
            batch_size: Number of texts per forward pass.
            max_length: Token truncation length.

        Returns:
            Array of shape (len(texts), hidden_size). An empty `texts`
            list yields a well-shaped (0, hidden_size) array.
        """
        self.bert_model.eval()
        if not texts:
            # Fix: np.vstack([]) raises on an empty list of batches.
            return np.empty((0, self.bert_model.config.hidden_size),
                            dtype=np.float32)
        # Fix: run batches on whatever device the model lives on; the
        # original fed CPU tensors to the model and called .numpy() on the
        # output, which fails as soon as the model is on CUDA.
        device = next(self.bert_model.parameters()).device
        embeddings = []
        with torch.no_grad():
            for start in tqdm(range(0, len(texts), batch_size)):
                batch_texts = texts[start:start + batch_size]
                encoded = self.bert_tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors='pt'
                )
                encoded = {k: v.to(device) for k, v in encoded.items()}
                outputs = self.bert_model(**encoded)
                # [CLS] token (position 0) as the sentence representation;
                # move to CPU before the NumPy conversion.
                batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                embeddings.append(batch_embeddings)
        return np.vstack(embeddings)

    def get_tfidf_features(self, texts: List[str]) -> np.ndarray:
        """Fit the TF-IDF vectorizer on `texts` and return a dense matrix.

        NOTE(review): this re-fits the vocabulary on every call, so feature
        columns are not comparable across calls with different corpora —
        confirm this is intended before reusing features between datasets.
        """
        return self.tfidf_vectorizer.fit_transform(texts).toarray()

    def get_count_features(self, texts: List[str]) -> np.ndarray:
        """Fit the count vectorizer on `texts` and return a dense matrix.

        NOTE(review): re-fits per call, same caveat as `get_tfidf_features`.
        """
        return self.count_vectorizer.fit_transform(texts).toarray()

    def extract_all_features(self, texts: List[str],
                             use_bert: bool = True,
                             use_tfidf: bool = True,
                             use_count: bool = True) -> Dict[str, np.ndarray]:
        """Extract the selected feature sets.

        Returns:
            Dict keyed by 'bert' / 'tfidf' / 'count' (only the enabled ones),
            each mapping to a (n_texts, n_features) array.
        """
        features = {}
        if use_bert:
            features['bert'] = self.get_bert_embeddings(texts)
        if use_tfidf:
            features['tfidf'] = self.get_tfidf_features(texts)
        if use_count:
            features['count'] = self.get_count_features(texts)
        return features

    def extract_features_from_dataframe(self,
                                        df: pd.DataFrame,
                                        text_column: str,
                                        **kwargs) -> Dict[str, np.ndarray]:
        """Extract features from `df[text_column]`.

        Extra keyword arguments are forwarded to `extract_all_features`
        (e.g. use_bert / use_tfidf / use_count flags).
        """
        texts = df[text_column].tolist()
        return self.extract_all_features(texts, **kwargs)