ducdatit2002 committed on
Commit
0033618
·
verified ·
1 Parent(s): 82df4a4

Delete api.py

Browse files
Files changed (1) hide show
  1. api.py +0 -201
api.py DELETED
@@ -1,201 +0,0 @@
1
- # demo_phobert_api.py
2
- # -*- coding: utf-8 -*-
3
-
4
- from fastapi import FastAPI
5
- from pydantic import BaseModel
6
- import torch
7
- import re
8
- import json
9
- import emoji
10
- from underthesea import word_tokenize
11
- from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
12
-
13
# Create the FastAPI application; title/description/version appear in the
# auto-generated OpenAPI docs (/docs).
# NOTE(review): the description string was mojibake (UTF-8 mis-decoded as
# cp874) in the scraped source; restored to proper Vietnamese UTF-8.
app = FastAPI(
    title="PhoBERT Emotion Classification API",
    description="API dự đoán cảm xúc của câu tiếng Việt sử dụng PhoBERT.",
    version="1.0",
)
19
-
20
###############################################################################
# EMOJI MAPPING - must stay identical to the training pipeline
###############################################################################
# Maps individual emoji characters to bracketed emotion tags consumed by the
# tokenizer (e.g. "😀" -> "[joy]").
# NOTE(review): the scraped source was mojibake (emoji UTF-8 bytes rendered as
# cp874/cp1252 text); the emojis below were reconstructed byte-by-byte. A few
# bytes were dropped by the mis-decoding, so 😁, 😍, 😝, 😐, 😏, 🤐 and 🙁 were
# inferred from the surrounding emoji groups — verify against the training copy.
emoji_mapping = {
    "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
    "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
    "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
    "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
    "🤑": "[satisfaction]",
    "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
    "😏": "[sarcasm]",
    "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
    "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
    "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
    "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
    "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
    "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
    "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
    "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
    "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
    "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
}
42
-
43
###############################################################################
# PREPROCESSING HELPERS (mirrors the training pipeline)
###############################################################################
def replace_emojis(sentence, emoji_mapping):
    """Replace mapped emojis with their emotion tags, drop unmapped emojis.

    Characters found in *emoji_mapping* become their bracketed tag; characters
    that `emoji.is_emoji` recognises but that have no mapping are removed;
    everything else is kept unchanged.
    """
    kept = []
    for ch in sentence:
        tag = emoji_mapping.get(ch)
        if tag is not None:
            kept.append(tag)
        elif not emoji.is_emoji(ch):
            kept.append(ch)
    return "".join(kept)
54
-
55
def remove_profanity(sentence):
    """Remove whole-word profanity (case-insensitive) from *sentence*.

    Tokens are produced by whitespace splitting; a token is dropped when its
    lowercase form is in the blocklist. Words are re-joined with single spaces.
    """
    # NOTE(review): list was mojibake in the scraped source; restored to UTF-8.
    # "đù mé" contains a space so it can never match a single split() token —
    # kept for parity with the training code, but confirm the intent.
    profane_words = {"loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"}
    return " ".join(w for w in sentence.split() if w.lower() not in profane_words)
60
-
61
def remove_special_characters(sentence):
    """Strip the characters ^ * @ # & $ % < > ~ { } | \\ from *sentence*."""
    # One C-level pass; the deletion set matches the original regex class.
    return sentence.translate(str.maketrans("", "", "^*@#&$%<>~{}|\\"))
63
-
64
def normalize_whitespace(sentence):
    """Collapse runs of whitespace into single spaces and trim both ends."""
    tokens = sentence.split()
    return " ".join(tokens)
66
-
67
def remove_repeated_characters(sentence):
    """Collapse any character repeated 3+ times in a row to a single one.

    e.g. "coool" -> "col"; exactly two repeats are left alone.
    """
    collapse_runs = re.compile(r"(.)\1{2,}")
    return collapse_runs.sub(r"\1", sentence)
69
-
70
def replace_numbers(sentence):
    """Replace every run of digits with the literal token "[number]"."""
    digit_run = re.compile(r"\d+")
    return digit_run.sub("[number]", sentence)
72
-
73
def tokenize_underthesea(sentence):
    """Word-segment *sentence* with underthesea and join tokens with spaces."""
    return " ".join(word_tokenize(sentence))
76
-
77
# Load abbreviation expansions if abbreviations.json exists next to the app;
# otherwise fall back to an empty mapping (tokens are then left untouched).
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except (OSError, json.JSONDecodeError):
    # Best-effort: a missing or malformed file must not prevent startup.
    abbreviations = {}
83
-
84
def preprocess_sentence(sentence):
    """Run the full training-time preprocessing pipeline on one sentence.

    Order matters and mirrors the training script: lowercase, emoji tags,
    profanity removal, special-character stripping, whitespace normalisation,
    abbreviation expansion, repeated-character collapsing, digit replacement,
    and finally underthesea word segmentation.
    """
    text = sentence.lower()
    text = replace_emojis(text, emoji_mapping)
    text = remove_profanity(text)
    text = remove_special_characters(text)
    text = normalize_whitespace(text)
    # Expand known abbreviations token-by-token.
    # NOTE(review): " ".join(abbreviations[w]) assumes the JSON values are
    # lists of words; a plain-string value would be space-joined per character
    # — confirm the abbreviations.json schema.
    expanded = [
        " ".join(abbreviations[tok]) if tok in abbreviations else tok
        for tok in text.split()
    ]
    text = " ".join(expanded)
    text = remove_repeated_characters(text)
    text = replace_numbers(text)
    return tokenize_underthesea(text)
103
-
104
###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "./checkpoint"  # path to the fine-tuned PhoBERT checkpoint folder
# Prefer GPU when available; the model and all inputs are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)
112
-
113
# Fallback id -> label mapping, in the class order supplied by the model author.
custom_id2label = dict(enumerate([
    "Anger",
    "Disgust",
    "Enjoyment",
    "Fear",
    "Other",
    "Sadness",
    "Surprise",
]))
123
-
124
# Decide which id->label mapping to trust: the checkpoint's own config, unless
# it only contains auto-generated "LABEL_<n>" placeholders, in which case the
# hand-written mapping above is used instead.
if getattr(config, "id2label", None):
    only_placeholders = all(
        name.startswith("LABEL_") for name in config.id2label.values()
    )
    if only_placeholders:
        id2label = custom_id2label
    else:
        # Config keys may arrive as strings (e.g. from JSON); force int keys.
        id2label = {int(idx): name for idx, name in config.id2label.items()}
else:
    id2label = custom_id2label

print("id2label loaded:", id2label)
133
-
134
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
# Inference only: disable dropout etc.
model.eval()
141
-
142
###############################################################################
# PREDICTION
###############################################################################
# Human-readable advice appended to each predicted label.
# NOTE(review): these strings were mojibake (UTF-8 mis-decoded as cp874) in the
# scraped source; restored to proper Vietnamese UTF-8.
label2message = {
    'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
    'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
    'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
    'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
    'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
    'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
    'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}
154
-
155
def predict_text(text: str) -> str:
    """Preprocess *text*, run the classifier, and return a readable verdict.

    Returns a Vietnamese sentence of the form
    "Dự đoán cảm xúc: <label>. <advice>" or, for an id missing from
    ``id2label``, "Nhãn không xác định (id=<n>)".
    NOTE(review): the output strings were mojibake in the scraped source and
    have been restored to UTF-8.
    """
    text_proc = preprocess_sentence(text)
    inputs = tokenizer(
        [text_proc],
        padding=True,
        truncation=True,
        max_length=256,  # same limit as training
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
    pred_id = outputs.logits.argmax(dim=-1).item()

    # Guard clause: unknown class id.
    if pred_id not in id2label:
        return f"Nhãn không xác định (id={pred_id})"

    label = id2label[pred_id]
    message = label2message.get(label, "")
    if message:
        return f"Dự đoán cảm xúc: {label}. {message}"
    return f"Dự đoán cảm xúc: {label}."
178
-
179
###############################################################################
# REQUEST SCHEMA
###############################################################################
class InputText(BaseModel):
    """Request body for /predict: one Vietnamese sentence to classify."""
    text: str
184
-
185
###############################################################################
# API ENDPOINT
###############################################################################
@app.post("/predict")
def predict(input_text: InputText):
    """
    Accept a Vietnamese sentence and return the predicted emotion.

    Response shape: {"result": "<human-readable prediction string>"}.
    """
    result = predict_text(input_text.text)
    return {"result": result}
195
-
196
###############################################################################
# RUN THE API SERVER
###############################################################################
if __name__ == "__main__":
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)