File size: 1,446 Bytes
69fc6bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import streamlit as st
import torch
from transformers import BertTokenizer, BertForSequenceClassification

@st.cache_resource
def load_model(model_dir="BERT_GED"):
    """Load the sequence-classification model and its tokenizer from *model_dir*.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused. Without it, every click of "Analyze" reloaded the full model,
    because the script body (and thus ``main``) runs again on each interaction.
    The cached objects are shared, so callers must treat them as read-only.

    Args:
        model_dir: Name or path handed to ``from_pretrained``. Defaults to
            ``"BERT_GED"``, the previously hard-coded value.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForSequenceClassification.from_pretrained(model_dir)
    # Inference only: make sure dropout etc. are disabled.
    model.eval()
    return model, tokenizer

def predict(model, tokenizer, sentence, *, max_length=64):
    """Classify *sentence* as grammatically correct or not.

    Args:
        model: Sequence-classification model; called as
            ``model(input_ids, attention_mask=...)`` and expected to return an
            object with a ``.logits`` tensor.
        tokenizer: Tokenizer providing ``encode_plus`` for *model*.
        sentence: Text to classify.
        max_length: Token length the encoded input is padded/truncated to.
            Defaults to 64, the previously hard-coded value. Tokens beyond
            this length are silently dropped.

    Returns:
        ``"perfect"`` when the predicted class index is 1, otherwise
        ``"not right!!"``.
    """
    encoded = tokenizer.encode_plus(
        sentence,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # The tokenizer yields CPU tensors. Move them next to the model so that a
    # model placed on another device (e.g. a GPU) does not fail with a
    # device mismatch.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # argmax over the last (class) axis; .item() requires a single-sentence batch.
    index = torch.argmax(outputs.logits, dim=-1).item()
    # NOTE(review): assumes label 1 means "grammatical" — confirm against the
    # labels the BERT_GED model was trained with.
    return "perfect" if index == 1 else "not right!!"
    
def main():
    """Render the app and, when "Analyze" is clicked, show the model's verdict.

    A sentence that is empty or whitespace-only produces a warning instead of
    a prediction.
    """
    st.title("Grammatical Correctness Predictor")
    sentence = st.text_area("Sentence to analyze:")

    if not st.button("Analyze"):
        return

    # Whitespace-only text passed the old truthiness check and was sent to the
    # model, yielding a meaningless verdict; treat it like empty input.
    if not (sentence or "").strip():
        st.warning("Please enter a sentence.")
        return

    model, tokenizer = load_model()
    prediction = predict(model, tokenizer, sentence)
    st.write(f'"{sentence}" is grammatically {prediction}')

if __name__ == "__main__":
    # Entry point: build the UI only when this file is executed as the main
    # script, not when it is imported.
    main()