samithva commited on
Commit
18630d4
·
verified ·
1 Parent(s): e9233a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from khmernltk import word_tokenize
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+
6
+ # Load your model and tokenizer
7
+ model = AutoModelForSequenceClassification.from_pretrained(
8
+ "./final_model",
9
+ # load_in_8bit=True, # Use if you want to load in 8-bit quantized format
10
+ # torch_dtype=torch.float16, # Use appropriate dtype based on your GPU
11
+ # device_map="cuda:0" # Automatically map model to available devices
12
+ )
13
+ tokenizer = AutoTokenizer.from_pretrained("./final_model")
14
+
15
+ # Ensure the model is in evaluation mode
16
+ model.eval()
17
+
18
+ class_labels = {
19
+ 0: "non-accident",
20
+ 1: "accident"
21
+ # Add more labels if you have more classes
22
+ }
23
+
24
+ # Define the inference function
25
+ def classify(text):
26
+
27
+ words = word_tokenize(text)
28
+ sent = ' '.join(words)
29
+ print(f'sent : {sent}')
30
+ encoded_dict = tokenizer.encode_plus(
31
+ sent, # Sentence to encode.
32
+ add_special_tokens = True, # Add '[CLS]' and '[SEP]'
33
+ max_length = 512, # 64 Pad & truncate all sentences.
34
+ pad_to_max_length = True,
35
+ return_attention_mask = True, # Construct attn. masks.
36
+ return_tensors = 'pt', # Return pytorch tensors.
37
+ )
38
+ input_ids = encoded_dict['input_ids']
39
+ attention_masks = encoded_dict['attention_mask']
40
+ with torch.no_grad(): # Disable gradient calculation
41
+ outputs = model(input_ids, attention_masks)
42
+ logits = outputs.logits
43
+ predictions = torch.argmax(logits, dim=-1)
44
+ return class_labels[predictions.item()]
45
+
46
+ # Set up Gradio interface
47
+ interface = gr.Interface(fn=classify,
48
+ inputs="text",
49
+ outputs="text",
50
+ title="Accident Classification",
51
+ description="Enter a text to classify it.")
52
+
53
+ # Launch the interface
54
+ interface.launch()