Muhammad Eri Setyawan committed on
Commit
cb997de
·
1 Parent(s): 3c04e1d

Finalized model and inference script

Browse files
Files changed (1) hide show
  1. inference.py +54 -0
inference.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ from datetime import datetime
5
+
6
# Location of the saved model and tokenizer files.
MODEL_PATH = "."  # Adjust this to match the path in your HF repo

# Both artifacts are created a single time at import, so every later call to
# predict() reuses the same objects instead of reloading them per request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# eval() switches the module to evaluation mode and returns the module itself.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH).eval()

# Human-readable names for the class indices produced by the model.
label_mapping = {
    0: "Negative",
    1: "Positive",
}
14
+
15
def predict(inputs):
    """
    Classify a single piece of text and report the winning label.

    :param inputs: Mapping with the text to be analyzed under the ``"text"``
        key, e.g., {'text': 'I love this movie'}
    :return: On success, a dictionary with ``data`` (``confidence`` formatted
        as a percentage string, the echoed ``input_text`` and the ``label``
        looked up in ``label_mapping``), ``model_version``, ``status`` and a
        ``timestamp`` taken from ``datetime.now().isoformat()``.
        On failure, a ``(payload, status_code)`` tuple whose payload is
        ``{"error": message}``: status 400 when the input is missing or
        invalid, status 500 when tokenization or inference raises.
    """
    # Validate before entering the broad try block so that caller mistakes
    # are reported as 400 instead of being swallowed into a generic 500.
    try:
        input_text = inputs.get("text")
    except AttributeError:
        # ``inputs`` is not mapping-like (e.g. None or a bare string).
        return {"error": "Invalid input, 'text' key is required"}, 400

    if not input_text:
        return {"error": "Invalid input, 'text' key is required"}, 400

    # Only a single, non-blank string is supported: other types would fail
    # inside the tokenizer or be treated as a batch, and whitespace-only
    # text carries nothing to classify.
    if not isinstance(input_text, str) or not input_text.strip():
        return {"error": "Invalid input, 'text' must be a non-empty string"}, 400

    try:
        # Tokenize the input text, truncating to at most 512 tokens
        tokenized_input = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        # Perform prediction with the model; no gradients are needed here
        with torch.no_grad():
            outputs = model(**tokenized_input)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            confidence, label_idx = torch.max(probabilities, dim=1)
            confidence = confidence.item() * 100  # Convert to percentage
            label = label_mapping[label_idx.item()]

        # Structure the response as a dictionary
        return {
            "data": {
                "confidence": f"{confidence:.2f}%",
                "input_text": input_text,
                "label": label,
            },
            "model_version": "1.0.0",
            "status": "success",
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        # Top-level boundary: report any tokenization/inference failure to
        # the caller as an error payload rather than raising.
        return {"error": str(e)}, 500