pytorch / pages /21_NLP_Transformer.py
eaglelandsonce's picture
Create 21_NLP_Transformer.py
cd3cab0 verified
raw
history blame
3.14 kB
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_scheduler
from datasets import load_dataset
from tqdm.auto import tqdm
import streamlit as st
import matplotlib.pyplot as plt
# Fetch the "imdb" dataset and pull out its train and test splits.
dataset = load_dataset("imdb")
train_dataset, test_dataset = dataset["train"], dataset["test"]

# Tokenizer for the same "bert-base-uncased" checkpoint the classifier is loaded from.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def preprocess_function(examples):
    """Tokenize the ``"text"`` field of *examples* with the module-level tokenizer.

    Every sequence is truncated and padded to exactly 512 tokens, so all
    encoded rows share one length.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        truncation=True,
        max_length=512,
        padding="max_length",
    )
# Tokenize both splits in batches.
encoded_train_dataset = train_dataset.map(preprocess_function, batched=True)
encoded_test_dataset = test_dataset.map(preprocess_function, batched=True)

# Make each row directly consumable by ``model(**batch)``:
#   * drop the raw "text" column -- strings cannot be collated into tensors
#     and the model's forward pass does not accept a ``text`` argument;
#   * rename "label" to "labels", the keyword the model (and the evaluation
#     code, which reads ``batch["labels"]``) expects;
#   * return torch tensors so the DataLoader yields tensors that support
#     ``.to(device)`` instead of nested Python lists.
encoded_train_dataset = (
    encoded_train_dataset
    .remove_columns(["text"])
    .rename_column("label", "labels")
    .with_format("torch")
)
encoded_test_dataset = (
    encoded_test_dataset
    .remove_columns(["text"])
    .rename_column("label", "labels")
    .with_format("torch")
)

# Shuffle only the training data; evaluation order does not matter.
train_dataloader = DataLoader(encoded_train_dataset, shuffle=True, batch_size=8)
test_dataloader = DataLoader(encoded_test_dataset, batch_size=8)
# Use the GPU when torch reports one as available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained "bert-base-uncased" checkpoint configured for two output labels.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)
model.to(device)

# NOTE(review): transformers' own AdamW appears to be deprecated upstream in
# favour of torch.optim.AdamW -- confirm against the installed version.
optimizer = AdamW(model.parameters(), lr=5e-5)

# "linear" schedule with no warmup, sized to one scheduler step per training batch.
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
# Fine-tune the model, recording one loss value per optimizer step.
loss_values = []
model.train()
for _ in range(num_epochs):
    for batch in train_dataloader:
        # Move every tensor of the batch onto the model's device.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        # Apply the update, advance the schedule, then clear gradients
        # before the next batch.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        loss_values.append(loss.item())
# Define evaluation function
def evaluate(model, dataloader):
    """Return the fraction of examples in *dataloader* that *model* labels correctly.

    Args:
        model: A ``torch.nn.Module`` whose forward pass accepts a batch as
            keyword arguments and returns an object exposing ``logits``.
        dataloader: Iterable of dict batches of tensors; every batch must
            contain a ``"labels"`` entry.

    Returns:
        Accuracy as a float in ``[0, 1]``. ``0.0`` is returned when the
        dataloader yields no examples, instead of dividing by zero.

    The model is switched to eval mode and left that way.
    """
    model.eval()
    # Follow the model's own device rather than a module-level global so the
    # function works with any model; a parameter-less model falls back to CPU.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(target) for k, v in batch.items()}
            labels = batch["labels"]
            predictions = model(**batch).logits.argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0
# Evaluate the model on the test set
accuracy = evaluate(model, test_dataloader)

# Streamlit Interface
st.title("Sentiment Analysis with BERT")
st.write(f"Test Accuracy: {accuracy * 100:.2f}%")

# Plot loss values
st.write("### Training Loss")
# Draw on an explicit Figure/Axes pair and hand that Figure to Streamlit,
# rather than passing the pyplot module and relying on its implicit global
# figure state.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(loss_values, label="Training Loss")
ax.set_xlabel("Training Steps")
ax.set_ylabel("Loss")
ax.legend()
st.pyplot(fig)
# Close the figure once rendered so repeated script runs do not accumulate
# open matplotlib figures.
plt.close(fig)
# Text input for prediction
st.write("### Predict Sentiment")
user_input = st.text_area("Enter text:", "I loved this movie!")
# Skip empty and whitespace-only input: there is nothing to classify.
if user_input and user_input.strip():
    # A single sequence needs no padding (padding only exists to equalise
    # lengths within a batch), so just truncate to the 512-token limit
    # instead of always running the model over 512 positions.
    inputs = tokenizer(user_input, truncation=True, max_length=512, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    # Class index 1 is reported as positive, anything else as negative.
    prediction = outputs.logits.argmax(dim=-1).item()
    sentiment = "Positive" if prediction == 1 else "Negative"
    st.write(f"Sentiment: **{sentiment}**")