pytorch / pages /21_NLP_Transformer.py
eaglelandsonce's picture
Update pages/21_NLP_Transformer.py
5a1cec5 verified
raw
history blame
3.35 kB
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_scheduler
from datasets import load_dataset
from torch.utils.data import DataLoader
import streamlit as st
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# Load the pre-trained model and tokenizer from Hugging Face once per server
# process. Streamlit re-executes this whole script on every widget interaction;
# without caching, each rerun would reload fresh pretrained weights and silently
# discard any fine-tuning done with the "Train" button.
@st.cache_resource
def _load_model_and_tokenizer():
    """Return ``(tokenizer, model, device)`` with the model already on ``device``.

    ``st.cache_resource`` returns the same objects on every rerun (shared by all
    sessions), so training mutates this one model in place and later
    predictions see the fine-tuned weights.
    """
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    bert_model.to(target_device)
    return bert_tokenizer, bert_model, target_device


tokenizer, model, device = _load_model_and_tokenizer()
# Streamlit interface
st.title("Sentiment Analysis with BERT")

# Training setup: hyperparameters chosen in the sidebar.
st.sidebar.title("Training Setup")
num_epochs = st.sidebar.slider("Number of Epochs", 1, 5, 3)
batch_size = st.sidebar.slider("Batch Size", 4, 32, 8)
# Float sliders default to step=0.01, which is larger than the whole
# 1e-6..1e-3 range and makes the slider unusable; use a step matching the
# displayed precision instead.
learning_rate = st.sidebar.slider("Learning Rate", 1e-6, 1e-3, 5e-5, step=1e-6, format="%.6f")
# Tokenize a small IMDB sample once and cache it across reruns. hash_funcs is
# kept from the original setup; it only affects how arguments are hashed, and
# this function takes none.
@st.cache_data(hash_funcs={tokenizer.__class__: id})
def load_encoded_dataset():
    """Return a tokenized ~1% sample of the IMDB training split.

    The IMDB train split is ordered by label (negative reviews first), so a
    plain ``train[:1%]`` slice contains only one class. A seeded shuffle before
    selecting gives a mixed but reproducible sample of the same size.

    Columns are formatted as torch tensors: ``input_ids``, ``attention_mask``
    and ``labels`` (renamed from ``label``, the keyword the model expects).
    """
    full_train = load_dataset("imdb", split="train")
    sample_size = max(1, len(full_train) // 100)
    dataset = full_train.shuffle(seed=42).select(range(sample_size))

    def preprocess_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

    encoded_dataset = dataset.map(preprocess_function, batched=True)
    encoded_dataset = encoded_dataset.rename_column("label", "labels")
    encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return encoded_dataset


def load_and_preprocess_data():
    """Return a shuffled DataLoader over the cached dataset.

    The loader is built outside the cached function because the global
    ``batch_size`` is not part of the cache key. Caching the loader itself would
    freeze the batch size chosen on the first run.
    """
    return DataLoader(load_encoded_dataset(), shuffle=True, batch_size=batch_size)


train_dataloader = load_and_preprocess_data()
# Sidebar placeholder for the training-status message.
training_completed = st.sidebar.empty()

# Training loop: runs only on the script rerun triggered by the Train button.
if st.sidebar.button("Train"):
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are set to that class's former defaults to stay close to
    # the previous optimizer behavior.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )
    # tqdm only writes to the server console; st.progress is visible in the app.
    progress_bar = st.progress(0.0)
    loss_values = []
    completed_steps = 0
    model.train()
    for _ in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
            # max(1, ...) guards against an empty dataloader (zero total steps).
            progress_bar.progress(min(1.0, completed_steps / max(1, num_training_steps)))
            loss_values.append(loss.item())
    training_completed.success("Training completed")

    # Plot loss values on an explicit figure and close it afterwards, so
    # figures do not accumulate in matplotlib's global registry across reruns.
    st.write("### Training Loss")
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(loss_values, label="Training Loss")
    ax.set_xlabel("Training Steps")
    ax.set_ylabel("Loss")
    ax.legend()
    st.pyplot(fig)
    plt.close(fig)
def _predict_sentiment(text):
    """Classify *text* with the current model and return its sentiment label.

    The text is tokenized exactly like the training data (padded/truncated to
    128 tokens). The model is switched to eval mode and run without gradient
    tracking. Class index 1 maps to "Positive"; anything else maps to "Negative".
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
    model.eval()
    with torch.no_grad():
        logits = model(**on_device).logits
    predicted_class = logits.argmax(dim=-1).item()
    if predicted_class == 1:
        return "Positive"
    return "Negative"


# Text input for prediction
st.write("### Predict Sentiment")
user_input = st.text_area("Enter text:", "I loved this movie!")
if user_input:
    st.write(f"Sentiment: **{_predict_sentiment(user_input)}**")