TruthCheck / src /data /feature_extractor.py
adnaan05's picture
Initial commit for Hugging Face Space
469c254
raw
history blame
3.3 kB
import numpy as np
import torch
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from transformers import BertTokenizer, BertModel
from typing import Tuple, Dict, List
import pandas as pd
from tqdm import tqdm
class FeatureExtractor:
def __init__(self, bert_model_name: str = "bert-base-uncased"):
self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
self.bert_model = BertModel.from_pretrained(bert_model_name)
self.tfidf_vectorizer = TfidfVectorizer(
max_features=5000,
ngram_range=(1, 2),
stop_words='english'
)
self.count_vectorizer = CountVectorizer(
max_features=5000,
ngram_range=(1, 2),
stop_words='english'
)
def get_bert_embeddings(self, texts: List[str],
batch_size: int = 32,
max_length: int = 512) -> np.ndarray:
"""Extract BERT embeddings for a list of texts."""
self.bert_model.eval()
embeddings = []
with torch.no_grad():
for i in tqdm(range(0, len(texts), batch_size)):
batch_texts = texts[i:i + batch_size]
# Tokenize and prepare input
encoded = self.bert_tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors='pt'
)
# Get BERT embeddings
outputs = self.bert_model(**encoded)
# Use [CLS] token embeddings as sentence representation
batch_embeddings = outputs.last_hidden_state[:, 0, :].numpy()
embeddings.append(batch_embeddings)
return np.vstack(embeddings)
def get_tfidf_features(self, texts: List[str]) -> np.ndarray:
"""Extract TF-IDF features from texts."""
return self.tfidf_vectorizer.fit_transform(texts).toarray()
def get_count_features(self, texts: List[str]) -> np.ndarray:
"""Extract Count Vectorizer features from texts."""
return self.count_vectorizer.fit_transform(texts).toarray()
def extract_all_features(self, texts: List[str],
use_bert: bool = True,
use_tfidf: bool = True,
use_count: bool = True) -> Dict[str, np.ndarray]:
"""Extract all features from texts."""
features = {}
if use_bert:
features['bert'] = self.get_bert_embeddings(texts)
if use_tfidf:
features['tfidf'] = self.get_tfidf_features(texts)
if use_count:
features['count'] = self.get_count_features(texts)
return features
def extract_features_from_dataframe(self,
df: pd.DataFrame,
text_column: str,
**kwargs) -> Dict[str, np.ndarray]:
"""Extract features from a dataframe's text column."""
texts = df[text_column].tolist()
return self.extract_all_features(texts, **kwargs)