vlff李飞飞 commited on
Commit
47208d1
·
1 Parent(s): 3cac10f
Files changed (1) hide show
  1. qwen_server/database_server.py +5 -23
qwen_server/database_server.py CHANGED
@@ -76,26 +76,13 @@ def change_checkbox_state(text, cache_file):
76
  return {'result': 'changed'}
77
 
78
 
79
- class BasicAuthMiddleware(BaseHTTPMiddleware):
80
- def __init__(self, app, token: str):
81
- super().__init__(app)
82
- self.required_credentials = token
83
-
84
- async def dispatch(self, request: Request, call_next):
85
- authorization: str = request.headers.get("Authorization")
86
- if authorization:
87
- try:
88
- schema, credentials = authorization.split()
89
- if credentials == self.required_credentials:
90
- return await call_next(request)
91
- except ValueError:
92
- pass
93
-
94
  headers = {'WWW-Authenticate': 'Basic'}
95
  return Response(status_code=401, headers=headers)
96
-
97
-
98
- app.add_middleware(BasicAuthMiddleware, token=os.getenv("ACCESS_TOKEN"))
99
 
100
 
101
  @app.get('/healthz')
@@ -105,8 +92,6 @@ async def healthz(request: Request):
105
 
106
  @app.get('/cachedata/{file_name}')
107
  async def cache_data(file_name: str, access_token: str):
108
- if not access_token or os.getenv("ACCESS_TOKEN") != access_token:
109
- raise HTTPException(401, "the access token is not valid")
110
  cache_file = os.path.join(server_config.path.cache_root, file_name)
111
  lines = []
112
  for line in jsonlines.open(cache_file):
@@ -117,9 +102,6 @@ async def cache_data(file_name: str, access_token: str):
117
  @app.post('/endpoint')
118
  async def web_listening(request: Request):
119
  data = await request.json()
120
- access_token: str = request.headers.get("Authorization")
121
- if not access_token or os.getenv("ACCESS_TOKEN") != access_token:
122
- raise HTTPException(401, "the access token is not valid")
123
  msg_type = data['task']
124
 
125
  cache_file_popup_url = os.path.join(server_config.path.cache_root, 'popup_url.jsonl')
 
76
  return {'result': 'changed'}
77
 
78
 
79
+ @app.middleware("http")
80
+ async def validate_token(request: Request, call_next):
81
+ access_token: str = request.headers.get("Authorization") or request.query_params.get("access_token")
82
+ if not access_token or os.getenv("ACCESS_TOKEN") != access_token:
 
 
 
 
 
 
 
 
 
 
 
83
  headers = {'WWW-Authenticate': 'Basic'}
84
  return Response(status_code=401, headers=headers)
85
+ return await call_next(request)
 
 
86
 
87
 
88
  @app.get('/healthz')
 
92
 
93
  @app.get('/cachedata/{file_name}')
94
  async def cache_data(file_name: str, access_token: str):
 
 
95
  cache_file = os.path.join(server_config.path.cache_root, file_name)
96
  lines = []
97
  for line in jsonlines.open(cache_file):
 
102
  @app.post('/endpoint')
103
  async def web_listening(request: Request):
104
  data = await request.json()
 
 
 
105
  msg_type = data['task']
106
 
107
  cache_file_popup_url = os.path.join(server_config.path.cache_root, 'popup_url.jsonl')