KameronB's picture
Update README.md
b37695a verified
|
raw
history blame
1.62 kB
---
license: mit
language:
  - en
---
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
import pandas as pd

# Tokenizer matching the 'roberta-base' checkpoint used for the model below.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# RoBERTa encoder with a 2-class sequence-classification head, placed on the
# GPU when CUDA is available and on the CPU otherwise.
# NOTE(review): 'roberta-base' is a plain pretrained encoder, so the
# classification head is presumably freshly initialised here; load fine-tuned
# weights before trusting predictions — confirm against the training code.
model = RobertaForSequenceClassification.from_pretrained(
    'roberta-base',
    num_labels=2,
).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))




def predict_description(model, tokenizer, text, max_length=512):
    """Classify a single text and return the index of the predicted class.

    Args:
        model: Sequence-classification model whose forward pass returns an
            object with a ``logits`` attribute. It is put into eval mode and
            moved to CUDA when available (``nn.Module.to`` moves it in place).
        tokenizer: Tokenizer matching ``model``; it is called with
            ``return_tensors='pt'`` and must return a mapping of tensors.
        text: The string to classify.
        max_length: Maximum number of tokens, special tokens included; longer
            inputs are truncated.

    Returns:
        int: Position of the highest logit along the last axis.
    """
    model.eval()  # inference behaviour (e.g. dropout disabled)

    # Run on the GPU whenever one is available, as before.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # A single sequence needs no padding: padding to max_length only appends
    # masked-out tokens that cost a full 512-token forward pass.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_token_type_ids=False,
        return_tensors="pt",
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # softmax is monotonic, so argmax over the raw logits picks the same class.
    return int(torch.argmax(logits, dim=-1).item())


# Class index -> human-readable label, in the order used by the original example.
CLASS_LABELS = ['INCIDENT', 'REQUEST']

# Print the result: a bare expression statement is discarded when run as a script.
print(CLASS_LABELS[predict_description(model, tokenizer, "My ID card is not being detected.")])