taozi555 commited on
Commit
462525d
·
verified ·
1 Parent(s): 62c1330

Update moderations.py

Browse files
Files changed (1) hide show
  1. moderations.py +75 -11
moderations.py CHANGED
@@ -1,7 +1,4 @@
1
  from gevent import pywsgi
2
- import dotenv
3
- dotenv.load_dotenv(override=True)
4
-
5
  import sys
6
  import time
7
  import argparse
@@ -10,9 +7,14 @@ from typing import Union
10
  from pydantic import BaseModel
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
  import torch
 
13
  import openedai
14
  import numpy as np
15
  import asyncio
 
 
 
 
16
 
17
 
18
  app = openedai.OpenAIStub()
@@ -35,7 +37,7 @@ labels = ['hate',
35
 
36
  label2id = {l:i for i, l in enumerate(labels)}
37
  id2label = {i:l for i, l in enumerate(labels)}
38
- model_name = "/root/autodl-tmp/duanyu027/moderation_0628"
39
  tokenizer = AutoTokenizer.from_pretrained(model_name)
40
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(labels),id2label=id2label, label2id=label2id, problem_type = "multi_label_classification")
41
  model.to(device)
@@ -44,6 +46,38 @@ model.eval()
44
  # model, {torch.nn.Linear}, dtype=torch.qint8
45
  #)
46
  torch.set_num_threads(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  class ModerationsRequest(BaseModel):
48
  model: str = "text-moderation-latest" # or "text-moderation-stable"
49
  input: Union[str, list[str]]
@@ -121,11 +155,41 @@ async def predict(text, model, tokenizer):
121
  # Main
122
  if __name__ == "__main__":
123
 
124
- args = parse_args(sys.argv[1:])
125
- # start API
126
- print(f'Starting moderations[{device}] API on {args.host}:{args.port}', file=sys.stderr)
127
- app.register_model('text-moderations-latest', 'text-moderations-stable')
128
- app.register_model('text-moderations-005', 'text-moderations-ifmain')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- if not args.test_load:
131
- uvicorn.run(app, host=args.host, port=args.port)
 
 
 
 
 
1
  from gevent import pywsgi
 
 
 
2
  import sys
3
  import time
4
  import argparse
 
7
  from pydantic import BaseModel
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import torch
10
+ import os
11
  import openedai
12
  import numpy as np
13
  import asyncio
14
+ from urllib.parse import urlparse
15
+ import nacos
16
+ import configparser
17
+
18
 
19
 
20
  app = openedai.OpenAIStub()
 
37
 
38
  label2id = {l:i for i, l in enumerate(labels)}
39
  id2label = {i:l for i, l in enumerate(labels)}
40
+ model_name = "duanyu027/moderation_0628"
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(labels),id2label=id2label, label2id=label2id, problem_type = "multi_label_classification")
43
  model.to(device)
 
46
  # model, {torch.nn.Linear}, dtype=torch.qint8
47
  #)
48
  torch.set_num_threads(1)
49
+ def register_service(client,service_name,service_ip,service_port,cluster_name,health_check_interval,weight,http_proxy,domain,protocol,direct_domain):
50
+ try:
51
+ # 初始化 metadata
52
+ metadata = {}
53
+
54
+ # 如果 http_proxy 为 True,添加额外的 metadata 键值对
55
+ if http_proxy:
56
+ metadata["http_proxy"] = True
57
+ if direct_domain:
58
+ metadata["domain"] = f"{protocol}://{service_ip}:{service_port}"
59
+ else:
60
+ metadata["domain"] = f"{domain}/port/{service_port}"
61
+ else:
62
+ metadata["http_proxy"] = False
63
+ metadata["domain"] = f"{protocol}://{service_ip}:{service_port}"
64
+ response = client.add_naming_instance(
65
+ service_name,
66
+ service_ip,
67
+ service_port,
68
+ cluster_name,
69
+ weight,
70
+ metadata,
71
+ enable=True,
72
+ healthy=True,
73
+ ephemeral=True,
74
+ heartbeat_interval=health_check_interval
75
+ )
76
+ return response
77
+ except Exception as e:
78
+ print(f"Error registering service to Nacos: {e}")
79
+ return True
80
+
81
  class ModerationsRequest(BaseModel):
82
  model: str = "text-moderation-latest" # or "text-moderation-stable"
83
  input: Union[str, list[str]]
 
155
  # Main
156
  if __name__ == "__main__":
157
 
158
+ # 创建配置解析器
159
+ config = configparser.ConfigParser()
160
+ # 读取配置文件
161
+ if not config.read('config.ini'):
162
+ raise RuntimeError("配置文件不存在")
163
+ # Nacos server and other configurations
164
+ NACOS_SERVER = config['nacos']['nacos_server']
165
+ NAMESPACE = config['nacos']['namespace']
166
+ CLUSTER_NAME = config['nacos']['cluster_name']
167
+ client = nacos.NacosClient(NACOS_SERVER, namespace=NAMESPACE, username=config['nacos']['username'], password=config['nacos']['password'])
168
+ SERVICE_NAME = config['nacos']['service_name']
169
+ HEALTH_CHECK_INTERVAL = int(config['nacos']['health_check_interval'])
170
+ if config.has_option('nacos', 'weight'):
171
+ WEIGHT = int(config.get('nacos', 'weight'))
172
+ else:
173
+ WEIGHT = 1
174
+ HTTP_PROXY = config.getboolean('server', 'http_proxy')
175
+ DOMAIN = config['server']['domain']
176
+ PROTOCOL = config['server']['protocol']
177
+ DIRECT_DOMAIN = config.getboolean('server', 'direct_domain')
178
+ # Parse AutoDLServiceURL
179
+ autodl_url = os.environ.get('AutoDLServiceURL')
180
+
181
+ if not autodl_url:
182
+ raise RuntimeError("Error: AutoDLServiceURL environment variable is not set.")
183
+
184
+ parsed_url = urlparse(autodl_url)
185
+ SERVICE_IP = parsed_url.hostname
186
+ SERVICE_PORT = parsed_url.port
187
+ if not SERVICE_IP or not SERVICE_PORT:
188
+ raise RuntimeError("Error: Invalid AutoDLServiceURL format.")
189
 
190
+ print(f"Service will be registered with IP: {SERVICE_IP} and Port: {SERVICE_PORT}")
191
+ if not register_service(client,SERVICE_NAME,SERVICE_IP,SERVICE_PORT,CLUSTER_NAME,HEALTH_CHECK_INTERVAL,WEIGHT,HTTP_PROXY,DOMAIN,PROTOCOL,DIRECT_DOMAIN):
192
+ raise RuntimeError("Service is healthy but failed to register.")
193
+ app.register_model('text-moderations-latest', 'text-moderations-stable')
194
+ app.register_model('text-moderations-005', 'text-moderations-ifmain')
195
+ uvicorn.run(app, host="0.0.0.0", port=6006)