|
from flask import Flask, render_template, request, jsonify |
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
import torch |
|
|
|
app = Flask(__name__)

# Fine-tuned weights. Presumably a state_dict saved from a
# BertForSequenceClassification fine-tuned from 'bert-base-uncased' -- confirm.
MODEL_PATH = "bert_classifier_three_labeled.pth"
BASE_MODEL_NAME = 'bert-base-uncased'

# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; nothing in this file moves the model to another device.
model_state_dict = torch.load(MODEL_PATH, map_location="cpu")

# The classification head must have as many outputs as the checkpoint's head,
# otherwise load_state_dict raises a size-mismatch error (from_pretrained
# defaults to a 2-label head). Read the label count from the checkpoint when
# it is available; otherwise keep the library default.
_classifier_weight = model_state_dict.get("classifier.weight")
_head_kwargs = (
    {"num_labels": _classifier_weight.shape[0]}
    if _classifier_weight is not None
    else {}
)

model = BertForSequenceClassification.from_pretrained(BASE_MODEL_NAME, **_head_kwargs)
model.load_state_dict(model_state_dict)
# Inference only: eval() disables dropout so identical prompts produce
# identical probabilities.
model.eval()
# NOTE(review): load_state_dict copies the tensors into the model, so keeping
# model_state_dict referenced holds a second copy of the weights in memory;
# `del model_state_dict` if nothing else in the file uses it.

tokenizer = BertTokenizer.from_pretrained(BASE_MODEL_NAME)
|
|
|
def predict(prompt):
    """Classify *prompt* and return its per-label probabilities.

    Args:
        prompt: Text to classify.

    Returns:
        A list of floats, one per output of the model's classification head,
        obtained by a softmax over the logits (so they sum to ~1).
    """
    # truncation=True caps the encoding at the tokenizer's model_max_length
    # (512 tokens for bert-base-uncased); longer inputs would otherwise exceed
    # BERT's position embeddings and raise during the forward pass.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # The batch holds a single prompt; return its probability row.
    return probs[0].tolist()
|
|
|
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the classification page.

    Any non-POST request renders ``index.html`` with ``result=None``.
    A POST reads the ``prompt`` form field, classifies it with ``predict``
    and hands the returned probability list to the template as ``result``.
    If the ``prompt`` field is missing, ``request.form`` raises and Flask
    answers with 400 Bad Request.
    """
    if request.method != 'POST':
        return render_template('index.html', result=None)

    submitted_text = request.form['prompt']
    return render_template('index.html', result=predict(submitted_text))
|
|
|
if __name__ == '__main__':
    import os

    # Debug mode is opt-in (FLASK_DEBUG=1) rather than hard-coded:
    # - The Werkzeug interactive debugger lets anyone who can reach an error
    #   page execute arbitrary Python.
    # - The reloader re-runs this module in a child process, loading the
    #   BERT model a second time.
    app.run(debug=os.environ.get("FLASK_DEBUG", "0") == "1")
|
|