ducdatit2002 committed on
Commit
2d5c298
·
verified ·
1 Parent(s): d1e6bcf

Upload api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. api.py +201 -0
api.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo_phobert_api.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ import torch
7
+ import re
8
+ import json
9
+ import emoji
10
+ from underthesea import word_tokenize
11
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
12
+
13
+ # Khởi tạo FastAPI app
14
# FastAPI application object. The keyword arguments below are handed to
# FastAPI as the service title, description and version.
app = FastAPI(
    title="PhoBERT Emotion Classification API",
    description="API dự đoán cảm xúc của câu tiếng Việt sử dụng PhoBERT.",
    version="1.0",
)
19
+
20
+ ###############################################################################
21
+ # TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN
22
+ ###############################################################################
23
# Emoji -> emotion-tag table consumed by replace_emojis.
# The section header states that this table was copied verbatim from the
# training script, so the entries and their order are kept unchanged; they are
# only regrouped here by tag (emojis are whitespace-separated inside a group).
_EMOJI_GROUPS = (
    ("[joy]", "😀 😃 😄 😁 😆 😅 😂 🤣"),
    ("[love]", "🙂 🙃 😉 😊 😇 🥰 😍 🤩 😘 😗 ☺ 😚 😙"),
    ("[satisfaction]", "😋 😛 😜 🤪 😝 🤑"),
    ("[neutral]", "🤐 🤨 😐 😑 😶"),
    ("[sarcasm]", "😏"),
    ("[disappointment]", "😒 🙄 😬"),
    ("[sadness]", "😔 😪 😢 😭 😥 😓"),
    ("[tiredness]", "😩 😫 🥱"),
    ("[discomfort]", "🤤 🤢 🤮 🤧 🥵 🥶 🥴 😵 🤯"),
    ("[confused]", "😕 😟 🙁 ☹"),
    ("[surprise]", "😮 😯 😲 😳"),
    ("[pleading]", "🥺"),
    ("[fear]", "😦 😧 😨 😰 😱"),
    ("[confusion]", "😖 😣 😞"),
    ("[anger]", "😤 😡 😠 🤬"),
    ("[mischievous]", "😈 👿"),
)

# Flatten the groups into the {emoji: tag} dict the rest of the file expects.
emoji_mapping = {
    symbol: tag for tag, symbols in _EMOJI_GROUPS for symbol in symbols.split()
}
42
+
43
+ ###############################################################################
44
+ # HÀM XỬ LÝ (COPY TỪ FILE TRAIN)
45
+ ###############################################################################
46
def replace_emojis(sentence, emoji_mapping):
    """Replace mapped emojis with their tag and drop all other emojis.

    The sentence is processed one character at a time:
    - a character that is a key of ``emoji_mapping`` becomes the mapped value;
    - any other character for which ``emoji.is_emoji`` is true is removed;
    - everything else is kept as is.

    Args:
        sentence: Text to process.
        emoji_mapping: Dict from single-character emoji to replacement text.

    Returns:
        The rewritten sentence as a new string.
    """

    def _convert(symbol):
        # The mapping wins over the generic emoji check.
        if symbol in emoji_mapping:
            return emoji_mapping[symbol]
        # Unmapped emojis contribute nothing to the output.
        return "" if emoji.is_emoji(symbol) else symbol

    return "".join(map(_convert, sentence))
54
+
55
# Tokens dropped by remove_profanity. Built once at import time so every call
# does a constant-time set lookup instead of rebuilding and scanning a list.
# NOTE(review): "đù mé" contains a space, so it can never equal a single
# whitespace-delimited token and is effectively inert. It is kept so the word
# list stays the same as the original (the section header says these helpers
# were copied from the training script).
_PROFANE_WORDS = frozenset(
    ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
)


def remove_profanity(sentence):
    """Remove profane tokens from *sentence*.

    The sentence is split on whitespace; a token is dropped when its
    lower-cased form is in the profanity list. Remaining tokens are re-joined
    with single spaces, so runs of whitespace are collapsed as a side effect.
    Tokens with attached punctuation (e.g. "vl,") are not matched.

    Args:
        sentence: Text to filter.

    Returns:
        The filtered sentence.
    """
    return " ".join(
        word for word in sentence.split() if word.lower() not in _PROFANE_WORDS
    )
60
+
61
# Matches any single one of: ^ * @ # & $ % < > ~ { } | \
_SPECIAL_CHARS_RE = re.compile(r"[\^\*@#&$%<>~{}|\\]")


def remove_special_characters(sentence):
    """Return *sentence* with every ^ * @ # & $ % < > ~ { } | \\ character deleted.

    Square brackets are not in the set, so tags such as "[joy]" are untouched.
    """
    return _SPECIAL_CHARS_RE.sub("", sentence)
63
+
64
def normalize_whitespace(sentence):
    """Collapse every run of whitespace to one space and trim both ends."""
    chunks = sentence.split()
    return " ".join(chunks)
66
+
67
# One character followed by at least two more copies of itself (3+ in a row).
_REPEATED_CHAR_RE = re.compile(r"(.)\1{2,}")


def remove_repeated_characters(sentence):
    """Shrink any run of three or more identical characters to a single one.

    Runs of exactly two identical characters are left unchanged.
    """
    return _REPEATED_CHAR_RE.sub(r"\1", sentence)
69
+
70
# One or more consecutive digits.
_DIGITS_RE = re.compile(r"\d+")


def replace_numbers(sentence):
    """Replace every run of digits in *sentence* with the literal "[number]"."""
    return _DIGITS_RE.sub("[number]", sentence)
72
+
73
def tokenize_underthesea(sentence):
    """Tokenize *sentence* with underthesea's ``word_tokenize``.

    Returns:
        The tokens produced by ``word_tokenize`` joined with single spaces.
    """
    return " ".join(word_tokenize(sentence))
76
+
77
# Load the optional abbreviation table from abbreviations.json (looked up in
# the current working directory). Any failure falls back to an empty table so
# the service still starts; only a missing file is treated as silent.
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except FileNotFoundError:
    # The file is optional: without it no abbreviation expansion is performed.
    abbreviations = {}
except Exception as exc:
    # Unreadable or malformed file: keep the best-effort fallback, but report
    # the problem instead of hiding it.
    print(f"Warning: could not load abbreviations.json ({exc!r}); abbreviations disabled.")
    abbreviations = {}
else:
    # Lookups are done as `abbreviations[word]`, which needs a JSON object.
    if not isinstance(abbreviations, dict):
        print("Warning: abbreviations.json does not contain a JSON object; abbreviations disabled.")
        abbreviations = {}
83
+
84
def preprocess_sentence(sentence):
    """Run the full text-cleaning pipeline on one sentence.

    Steps, in order: lower-case, emoji replacement, profanity removal,
    special-character removal, whitespace normalization, abbreviation
    expansion, repeated-character collapsing, number replacement and
    underthesea word tokenization.

    Args:
        sentence: Raw input text.

    Returns:
        The cleaned, tokenized sentence as a single string.
    """
    text = sentence.lower()
    text = replace_emojis(text, emoji_mapping)
    text = remove_profanity(text)
    text = remove_special_characters(text)
    text = normalize_whitespace(text)

    # Expand abbreviations token by token; unknown tokens pass through.
    # NOTE(review): " ".join(...) assumes each abbreviation value is a sequence
    # of words; a plain string value would be split into single characters.
    # Confirm against the schema of abbreviations.json.
    expanded = [
        " ".join(abbreviations[token]) if token in abbreviations else token
        for token in text.split()
    ]
    text = " ".join(expanded)

    text = remove_repeated_characters(text)
    text = replace_numbers(text)
    return tokenize_underthesea(text)
103
+
104
+ ###############################################################################
105
+ # LOAD CHECKPOINT
106
+ ###############################################################################
107
# Folder holding the checkpoint, and the device used for inference (GPU when
# torch reports CUDA as available, CPU otherwise).
checkpoint_dir = "./checkpoint"
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)

# Fallback id -> label mapping, used when the checkpoint config carries no
# usable label names.
custom_id2label = {
    0: "Anger",
    1: "Disgust",
    2: "Enjoyment",
    3: "Fear",
    4: "Other",
    5: "Sadness",
    6: "Surprise",
}

# Prefer the labels stored in the config. Fall back to custom_id2label when
# the config has no (or an empty) id2label, or when every label in it starts
# with the placeholder prefix "LABEL_".
_config_labels = getattr(config, "id2label", None)
if _config_labels and any(
    not name.startswith("LABEL_") for name in _config_labels.values()
):
    # Keys are coerced to int so they can be compared with the argmax index.
    id2label = {int(idx): name for idx, name in _config_labels.items()}
else:
    id2label = custom_id2label

print("id2label loaded:", id2label)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint_dir, config=config
)
model.to(device)
# Switch the model to evaluation mode before serving predictions.
model.eval()
141
+
142
+ ###############################################################################
143
+ # HÀM PREDICT
144
+ ###############################################################################
145
# Message appended to the prediction text for each emotion label; looked up
# by predict_text via label2message.get(label, "").
label2message = {
    "Anger": "Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.",
    "Disgust": "Hãy tránh xa những thứ khiến bạn không thích.",
    "Enjoyment": "Chúc mừng bạn có một ngày tuyệt vời!",
    "Fear": "Hãy đối mặt với nỗi sợ để vượt qua chúng.",
    "Other": "Cảm xúc của bạn hiện tại không được phân loại rõ ràng.",
    "Sadness": "Hãy tìm kiếm sự hỗ trợ khi cần thiết.",
    "Surprise": "Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.",
}
154
+
155
def predict_text(text: str) -> str:
    """Classify the emotion of *text* and format the result as a sentence.

    The text is preprocessed with ``preprocess_sentence``, tokenized
    (truncated to at most 256 tokens), run through the model without gradient
    tracking, and the arg-max class index is translated through ``id2label``.

    Args:
        text: Raw input sentence.

    Returns:
        A Vietnamese result string containing the predicted label, followed by
        the matching ``label2message`` entry when one exists; or a string
        reporting the raw class id when it is not present in ``id2label``.
    """
    cleaned = preprocess_sentence(text)
    encoded = tokenizer(
        [cleaned],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits
        pred_id = logits.argmax(dim=-1).item()

    # Unknown class index: report the raw id.
    if pred_id not in id2label:
        return f"Nhãn không xác định (id={pred_id})"

    label = id2label[pred_id]
    message = label2message.get(label, "")
    if not message:
        return f"Dự đoán cảm xúc: {label}."
    return f"Dự đoán cảm xúc: {label}. {message}"
178
+
179
+ ###############################################################################
180
+ # ĐỊNH NGHĨA MODEL INPUT
181
+ ###############################################################################
182
class InputText(BaseModel):
    # Request body accepted by the /predict endpoint.
    # (Documented with a comment rather than a docstring so the generated
    # schema description is not changed.)

    # The sentence to classify.
    text: str
184
+
185
+ ###############################################################################
186
+ # API ENDPOINT
187
+ ###############################################################################
188
@app.post("/predict")
def predict(input_text: InputText):
    """
    Nhận một câu tiếng Việt và trả về dự đoán cảm xúc.
    """
    # The docstring above is kept as written because FastAPI publishes it as
    # the endpoint description. In English: "Receive a Vietnamese sentence and
    # return the predicted emotion."
    # The response is a JSON object with one key, "result", holding the text
    # produced by predict_text.
    return {"result": predict_text(input_text.text)}
195
+
196
+ ###############################################################################
197
+ # CHẠY API SERVER
198
+ ###############################################################################
199
# Start a uvicorn server for `app` when this file is executed as a script.
if __name__ == "__main__":
    import uvicorn

    # Host 0.0.0.0 makes the server listen on every network interface.
    uvicorn.run(app, host="0.0.0.0", port=8000)