jaynopponep's picture
Update model.py
71170a2 verified
raw
history blame
513 Bytes
import torch
from transformers import BertTokenizer, BertForSequenceClassification
def get_model(model_name='bert-base-uncased', num_labels=2):
    """Load a BERT model with a sequence-classification head.

    Args:
        model_name: Checkpoint identifier or local path passed to
            ``BertForSequenceClassification.from_pretrained``.
            Defaults to ``'bert-base-uncased'`` (the previous hard-coded value).
        num_labels: Number of output classes of the classification head.
            Defaults to 2 (the previous hard-coded value).

    Returns:
        The loaded ``BertForSequenceClassification`` instance.
    """
    # NOTE(review): 'bert-base-uncased' is a base checkpoint, so the
    # classification head is presumably freshly initialised and needs
    # fine-tuning before its predictions are meaningful -- confirm.
    return BertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
# Predicting Function
def predict(model, text, tokenizer):
    """Classify text as AI-generated or human-written.

    Args:
        model: A PyTorch classification module whose output exposes
            ``.logits``. Class index 1 is reported as "AI-generated" and any
            other index as "Human-written".
        text: A single string, or a list of strings classified as one batch.
        tokenizer: A callable tokenizer that returns a mapping of PyTorch
            tensors when called with ``return_tensors="pt"``.

    Returns:
        ``"AI-generated"`` or ``"Human-written"`` when the input yields a
        single prediction (a string or a one-element list). For a batch of
        several texts, returns a list of those labels in input order.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Put the input tensors on the model's device; otherwise a GPU-resident
    # model fails with a device-mismatch error. Modules without a `.device`
    # attribute are left as-is.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Inference must run with dropout disabled and without building an
    # autograd graph. Restore the caller's train/eval mode afterwards so this
    # helper has no lasting side effect on the model.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs)
    finally:
        model.train(was_training)

    # flatten() makes the result iterable even when argmax yields a 0-d tensor.
    predictions = torch.argmax(outputs.logits, dim=-1).flatten().tolist()
    labels = ["AI-generated" if p == 1 else "Human-written" for p in predictions]
    # A single prediction keeps the original scalar return; batches get a list.
    return labels[0] if len(labels) == 1 else labels