Update model.py
model.py CHANGED
@@ -1,20 +1,15 @@
 import torch
-import
-from transformers import BertTokenizer, BertForSequenceClassification, TFAutoModel
+from transformers import BertTokenizer, BertForSequenceClassification
 
 
 def get_model():
-    model =
+    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
     return model
 
 
 # Predicting Function
 def predict(model, text, tokenizer):
-    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="
+    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
     outputs = model(**inputs)
-    predictions =
-    return "AI-generated" if predictions.
-
-# Example Usage (commented out as it's not needed for web deployment)
-# user_input = input("Enter the text you want to classify: ")
-# print("Classified as:", predict(user_input))
+    predictions = torch.argmax(outputs.logits, dim=-1)
+    return "AI-generated" if predictions.item() == 1 else "Human-written"
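Below is a minimal usage sketch (assumed, not part of this commit) showing how the updated get_model and predict functions might be called together; it assumes the tokenizer is loaded separately from the same 'bert-base-uncased' checkpoint and that this file is importable as model.py:

# Usage sketch (assumed): classify one piece of text with the updated functions.
from transformers import BertTokenizer
from model import get_model, predict  # assumes this file is saved as model.py

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # assumed to match get_model()'s checkpoint
model = get_model()
model.eval()  # inference mode

print(predict(model, "Example sentence to classify.", tokenizer))
# -> "AI-generated" or "Human-written"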