File size: 1,959 Bytes
18630d4
 
 
 
 
 
 
da03bb7
18630d4
 
 
 
da03bb7
18630d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0afe598
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
import torch
from khmernltk import word_tokenize
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the sequence-classification model from the current directory ("./").
# Extra keyword arguments can be passed to from_pretrained() when needed, e.g.:
#   load_in_8bit=True          - load the weights in 8-bit quantized format
#   torch_dtype=torch.float16  - pick a dtype appropriate for the GPU
#   device_map="cuda:0"        - place the model on a specific device
# .eval() switches the module to evaluation mode and returns the module itself,
# so the result can be bound directly.
model = AutoModelForSequenceClassification.from_pretrained("./").eval()

# The tokenizer is read from the same directory as the model.
tokenizer = AutoTokenizer.from_pretrained("./")

# Maps a predicted class index to its human-readable label; the position of
# each name in the sequence is its class index. Append more names here if the
# model has more classes.
class_labels = dict(enumerate(("non-accident", "accident")))

# Define the inference function
def classify(text):
    """Classify ``text`` and return its label from ``class_labels``.

    The text is segmented with khmernltk's ``word_tokenize``, the resulting
    words are re-joined with single spaces, and that sentence is encoded and
    fed to the module-level ``model``. The index of the largest logit is
    looked up in ``class_labels``.

    Args:
        text: The raw input string (supplied by the Gradio text input).

    Returns:
        The label string stored in ``class_labels`` for the predicted index.

    Raises:
        KeyError: If the predicted class index has no entry in
            ``class_labels``.
    """
    words = word_tokenize(text)
    sent = ' '.join(words)
    print(f'sent : {sent}')

    # Call the tokenizer directly with explicit padding/truncation strategies
    # instead of the deprecated encode_plus(..., pad_to_max_length=True),
    # which only truncated through an implicit legacy fallback.
    encoded = tokenizer(
        sent,
        add_special_tokens=True,      # Add the model's special tokens.
        max_length=512,               # Pad / truncate to this many tokens.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',          # Return PyTorch tensors.
    )
    # Keep the inputs on the same device as the model so inference also works
    # when the model has been placed on a GPU.
    encoded = encoded.to(model.device)

    with torch.no_grad():  # Inference only: no gradients needed.
        # Keyword arguments avoid depending on the positional order of the
        # model's forward() parameters.
        outputs = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
        )

    predicted_index = torch.argmax(outputs.logits, dim=-1).item()
    return class_labels[predicted_index]

# Build the Gradio UI: a single text input whose value is passed to
# classify(), and a single text output that shows the returned label.
interface = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    title="Accident Classification",
    description="Enter a text to classify it.",
)

# Start the app. `inline` is the first parameter of launch(), so it is named
# explicitly rather than passed as a bare positional boolean.
# NOTE(review): if a public share link was intended, that is `share=True` - confirm.
interface.launch(inline=True)