# MODIFY AS REQUIRED import torch import pandas as pd import seaborn as sns import numpy as np import matplotlib.pyplot as plt import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots from sklearn.model_selection import train_test_split from datasets import load_dataset from datasets import Dataset, DatasetDict from transformers import DataCollatorWithPadding from torch.utils.data import DataLoader from transformers import AutoTokenizer, AutoModelForSequenceClassification from torch.optim import AdamW from torch.nn import BCEWithLogitsLoss from transformers import get_scheduler from tqdm.auto import tqdm import evaluate from tqdm import tqdm import logging logging.basicConfig(level=logging.INFO) from text_preprocessing import clean_tweet, clear_reply_mentions, normalizeTweet ''' DATA_PATH = "../../data" PROCESSED_PATH = f"{DATA_PATH}/processed" PROCESSED_PATH_VIRAL = f'{DATA_PATH}/new/processed/viral' PROCESSED_PATH_COVID = f'{DATA_PATH}/new/processed/covid' ''' # Different models BERT_BASE_UNCASED = "bert-base-uncased" BERT_BASE_CASED = "bert-base-cased" ROBERTA_BASE = "roberta-base" BERT_TWEET = "vinai/bertweet-base" # BERT_TWEET_LARGE = "vinai/bertweet-large" DEBERTA_V3 = "microsoft/deberta-v3-base" # TODO: Don't forget to cite papers if you use some model BERT_TINY = "prajjwal1/bert-tiny" TWEET_MAX_LENGTH = 280 # TEST SPLIT RATIO + MODELS (ADD MORE MODELS FROM ABOVE) # MODELS = [BERT_TWEET, BERT_TINY, BERT_BASE_CASED, ROBERTA_BASE] MODELS = [DEBERTA_V3] TEST_RATIO = 0.2 def preprocess_data(dataset): # remove tweets with 0 retweets (to eliminate their effects) #dataset = dataset[dataset.retweet_count > 0] ## UPDATE: Get tweets tweeted by the same user, on the same day he tweeted a viral tweet # Get the date from datetime # normalize() sets all datetimes clock to midnight, which is equivalent as keeping only the date part dataset['date'] = dataset.created_at.dt.normalize() viral_tweets = dataset[dataset.viral] non_viral_tweets = dataset[~dataset.viral] temp = non_viral_tweets.merge(viral_tweets[['author_id', 'date', 'id', 'viral']], on=['author_id', 'date'], suffixes=(None, '_y')) same_day_viral_ids = temp.id_y.unique() same_day_viral_tweets = viral_tweets[viral_tweets.id.isin(same_day_viral_ids)].drop_duplicates(subset=['author_id', 'date']) same_day_non_viral_tweets = temp.drop_duplicates(subset=['author_id', 'date']) logging.info(f"Number of viral tweets tweeted on the same day {len(same_day_viral_tweets)}") logging.info(f"Number of non viral tweets tweeted on the same day {len(same_day_non_viral_tweets)}") dataset = pd.concat([same_day_viral_tweets, same_day_non_viral_tweets], axis=0) dataset = dataset[['id', 'text', 'viral']] # Balance classes to have as many viral as non viral ones #dataset = pd.concat([positives, negatives.sample(n=len(positives))]) #dataset = pd.concat([positives.iloc[:100], negatives.sample(n=len(positives)).iloc[:200]]) # Clean text to prepare for tokenization #dataset = dataset.dropna() dataset.loc[:, "viral"] = dataset.viral.astype(int) # TODO: COMMENT IF YOU WANT TO KEEP TEXT AS IS dataset["cleaned_text"] = dataset.text.apply(lambda x: clean_tweet(x, demojize_emojis=False)) dataset = dataset.dropna() dataset = dataset[['id', 'cleaned_text', 'viral']] return dataset def prepare_dataset(sample_data, balance=False): # Split the train and test data st each has a fixed proportion of viral tweets train_dataset, eval_dataset = train_test_split(sample_data, test_size=TEST_RATIO, random_state=42, stratify=sample_data.viral) # Balance test set if balance: eval_virals = eval_dataset[eval_dataset.viral == 1] eval_non_virals = eval_dataset[eval_dataset.viral == 0] eval_dataset = pd.concat([eval_virals, eval_non_virals.sample(n=len(eval_virals))]) logging.info('{:>5,} training samples with {:>5,} positives and {:>5,} negatives'.format( len(train_dataset), len(train_dataset[train_dataset.viral == 1]), len(train_dataset[train_dataset.viral == 0]))) logging.info('{:>5,} validation samples with {:>5,} positives and {:>5,} negatives'.format( len(eval_dataset), len(eval_dataset[eval_dataset.viral == 1]), len(eval_dataset[eval_dataset.viral == 0]))) train_dataset.to_parquet("train.parquet.gzip", compression='gzip') eval_dataset.to_parquet("test.parquet.gzip", compression='gzip') ds = load_dataset("parquet", data_files={'train': 'train.parquet.gzip', 'test': 'test.parquet.gzip'}) return ds def tokenize_function(example, tokenizer): # Truncate to max length. Note that a tweet's maximum length is 280 # TODO: check dynamic padding: https://huggingface.co/course/chapter3/2?fw=pt#dynamic-padding return tokenizer(example["cleaned_text"], truncation=True) def test_all_models(ds, models=MODELS): models_losses = {} device = torch.device("mps") if torch.mps.is_available() else torch.device("cpu") output = "" for checkpoint in models: torch.mps.empty_cache() tokenizer = AutoTokenizer.from_pretrained(checkpoint) model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2) model.to(device) tokenized_datasets = ds.map(lambda x: tokenize_function(x, tokenizer=tokenizer), batched=True) data_collator = DataCollatorWithPadding(tokenizer=tokenizer) tokenized_datasets = tokenized_datasets.remove_columns(["__index_level_0__", "cleaned_text", "id"]) tokenized_datasets = tokenized_datasets.rename_column("viral", "labels") tokenized_datasets.set_format("torch") batch_size = 32 train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator) eval_dataloader = DataLoader(tokenized_datasets["test"], batch_size=batch_size, collate_fn=data_collator) criterion = BCEWithLogitsLoss() optimizer = AdamW(model.parameters(), lr=5e-5) optimizer = AdamW(model.parameters(), lr=5e-5) num_epochs = 15 num_training_steps = num_epochs * len(train_dataloader) lr_scheduler = get_scheduler( name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps ) progress_bar = tqdm(range(num_training_steps)) exp_loss = None losses = [] model.train() for epoch in range(num_epochs): for batch in train_dataloader: batch = {k: v.to(device) for k, v in batch.items()} outputs = model(**batch) loss = outputs.loss losses.append(loss.item()) loss.backward() if exp_loss is None: exp_loss = loss.cpu().item() else: exp_loss = 0.9 * exp_loss + 0.1 * loss.cpu().item() optimizer.step() lr_scheduler.step() optimizer.zero_grad() progress_bar.update(1) progress_bar.set_postfix({"loss": exp_loss, "epoch": epoch}) torch.save(model.state_dict(), f"models/trained_{checkpoint.replace('/', '_')}.pt") models_losses[checkpoint] = losses metric = evaluate.combine(["accuracy", "recall", "precision", "f1"]) model.eval() for batch in eval_dataloader: batch = {k: v.to(device) for k, v in batch.items()} with torch.no_grad(): outputs = model(**batch) logits = outputs.logits predictions = torch.argmax(logits, dim=-1) metric.add_batch(predictions=predictions, references=batch["labels"]) output += f"checkpoint: {checkpoint}: {metric.compute()}\n" logging.info(output) with open("same_day_as_viral_with_features_train_test_balanced_accuracy.txt", "w") as text_file: text_file.write(output) return models_losses def main(): # DATA FILE SHOULD BE AT THE ROOT WITH THIS SCRIPT all_tweets_labeled = pd.read_parquet(f'final_dataset_since_october_2022.parquet.gzip') dataset = preprocess_data(all_tweets_labeled) ds = prepare_dataset(dataset, balance=False) test_all_models(ds) if __name__ == "__main__": main()