Update moderations.py
Browse files- moderations.py +75 -11
moderations.py
CHANGED
@@ -1,7 +1,4 @@
|
|
1 |
from gevent import pywsgi
|
2 |
-
import dotenv
|
3 |
-
dotenv.load_dotenv(override=True)
|
4 |
-
|
5 |
import sys
|
6 |
import time
|
7 |
import argparse
|
@@ -10,9 +7,14 @@ from typing import Union
|
|
10 |
from pydantic import BaseModel
|
11 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
12 |
import torch
|
|
|
13 |
import openedai
|
14 |
import numpy as np
|
15 |
import asyncio
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
app = openedai.OpenAIStub()
|
@@ -35,7 +37,7 @@ labels = ['hate',
|
|
35 |
|
36 |
label2id = {l:i for i, l in enumerate(labels)}
|
37 |
id2label = {i:l for i, l in enumerate(labels)}
|
38 |
-
model_name = "
|
39 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
40 |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(labels),id2label=id2label, label2id=label2id, problem_type = "multi_label_classification")
|
41 |
model.to(device)
|
@@ -44,6 +46,38 @@ model.eval()
|
|
44 |
# model, {torch.nn.Linear}, dtype=torch.qint8
|
45 |
#)
|
46 |
torch.set_num_threads(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
class ModerationsRequest(BaseModel):
|
48 |
model: str = "text-moderation-latest" # or "text-moderation-stable"
|
49 |
input: Union[str, list[str]]
|
@@ -121,11 +155,41 @@ async def predict(text, model, tokenizer):
|
|
121 |
# Main
|
122 |
if __name__ == "__main__":
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
1 |
from gevent import pywsgi
|
|
|
|
|
|
|
2 |
import sys
|
3 |
import time
|
4 |
import argparse
|
|
|
7 |
from pydantic import BaseModel
|
8 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
9 |
import torch
|
10 |
+
import os
|
11 |
import openedai
|
12 |
import numpy as np
|
13 |
import asyncio
|
14 |
+
from urllib.parse import urlparse
|
15 |
+
import nacos
|
16 |
+
import configparser
|
17 |
+
|
18 |
|
19 |
|
20 |
app = openedai.OpenAIStub()
|
|
|
37 |
|
38 |
label2id = {l:i for i, l in enumerate(labels)}
|
39 |
id2label = {i:l for i, l in enumerate(labels)}
|
40 |
+
model_name = "duanyu027/moderation_0628"
|
41 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
42 |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(labels),id2label=id2label, label2id=label2id, problem_type = "multi_label_classification")
|
43 |
model.to(device)
|
|
|
46 |
# model, {torch.nn.Linear}, dtype=torch.qint8
|
47 |
#)
|
48 |
torch.set_num_threads(1)
|
49 |
+
def register_service(client,service_name,service_ip,service_port,cluster_name,health_check_interval,weight,http_proxy,domain,protocol,direct_domain):
|
50 |
+
try:
|
51 |
+
# 初始化 metadata
|
52 |
+
metadata = {}
|
53 |
+
|
54 |
+
# 如果 http_proxy 为 True,添加额外的 metadata 键值对
|
55 |
+
if http_proxy:
|
56 |
+
metadata["http_proxy"] = True
|
57 |
+
if direct_domain:
|
58 |
+
metadata["domain"] = f"{protocol}://{service_ip}:{service_port}"
|
59 |
+
else:
|
60 |
+
metadata["domain"] = f"{domain}/port/{service_port}"
|
61 |
+
else:
|
62 |
+
metadata["http_proxy"] = False
|
63 |
+
metadata["domain"] = f"{protocol}://{service_ip}:{service_port}"
|
64 |
+
response = client.add_naming_instance(
|
65 |
+
service_name,
|
66 |
+
service_ip,
|
67 |
+
service_port,
|
68 |
+
cluster_name,
|
69 |
+
weight,
|
70 |
+
metadata,
|
71 |
+
enable=True,
|
72 |
+
healthy=True,
|
73 |
+
ephemeral=True,
|
74 |
+
heartbeat_interval=health_check_interval
|
75 |
+
)
|
76 |
+
return response
|
77 |
+
except Exception as e:
|
78 |
+
print(f"Error registering service to Nacos: {e}")
|
79 |
+
return True
|
80 |
+
|
81 |
class ModerationsRequest(BaseModel):
|
82 |
model: str = "text-moderation-latest" # or "text-moderation-stable"
|
83 |
input: Union[str, list[str]]
|
|
|
155 |
# Main
|
156 |
if __name__ == "__main__":
|
157 |
|
158 |
+
# 创建配置解析器
|
159 |
+
config = configparser.ConfigParser()
|
160 |
+
# 读取配置文件
|
161 |
+
if not config.read('config.ini'):
|
162 |
+
raise RuntimeError("配置文件不存在")
|
163 |
+
# Nacos server and other configurations
|
164 |
+
NACOS_SERVER = config['nacos']['nacos_server']
|
165 |
+
NAMESPACE = config['nacos']['namespace']
|
166 |
+
CLUSTER_NAME = config['nacos']['cluster_name']
|
167 |
+
client = nacos.NacosClient(NACOS_SERVER, namespace=NAMESPACE, username=config['nacos']['username'], password=config['nacos']['password'])
|
168 |
+
SERVICE_NAME = config['nacos']['service_name']
|
169 |
+
HEALTH_CHECK_INTERVAL = int(config['nacos']['health_check_interval'])
|
170 |
+
if config.has_option('nacos', 'weight'):
|
171 |
+
WEIGHT = int(config.get('nacos', 'weight'))
|
172 |
+
else:
|
173 |
+
WEIGHT = 1
|
174 |
+
HTTP_PROXY = config.getboolean('server', 'http_proxy')
|
175 |
+
DOMAIN = config['server']['domain']
|
176 |
+
PROTOCOL = config['server']['protocol']
|
177 |
+
DIRECT_DOMAIN = config.getboolean('server', 'direct_domain')
|
178 |
+
# Parse AutoDLServiceURL
|
179 |
+
autodl_url = os.environ.get('AutoDLServiceURL')
|
180 |
+
|
181 |
+
if not autodl_url:
|
182 |
+
raise RuntimeError("Error: AutoDLServiceURL environment variable is not set.")
|
183 |
+
|
184 |
+
parsed_url = urlparse(autodl_url)
|
185 |
+
SERVICE_IP = parsed_url.hostname
|
186 |
+
SERVICE_PORT = parsed_url.port
|
187 |
+
if not SERVICE_IP or not SERVICE_PORT:
|
188 |
+
raise RuntimeError("Error: Invalid AutoDLServiceURL format.")
|
189 |
|
190 |
+
print(f"Service will be registered with IP: {SERVICE_IP} and Port: {SERVICE_PORT}")
|
191 |
+
if not register_service(client,SERVICE_NAME,SERVICE_IP,SERVICE_PORT,CLUSTER_NAME,HEALTH_CHECK_INTERVAL,WEIGHT,HTTP_PROXY,DOMAIN,PROTOCOL,DIRECT_DOMAIN):
|
192 |
+
raise RuntimeError("Service is healthy but failed to register.")
|
193 |
+
app.register_model('text-moderations-latest', 'text-moderations-stable')
|
194 |
+
app.register_model('text-moderations-005', 'text-moderations-ifmain')
|
195 |
+
uvicorn.run(app, host="0.0.0.0", port=6006)
|