Spaces:
Runtime error
Runtime error
import json | |
import random | |
import string | |
import uuid | |
import time | |
import jwt | |
import datetime | |
import requests | |
import os | |
from flask import Flask, request, jsonify, Request, Response | |
from redis import Redis | |
from utils import nowtime | |
import pay_package | |
from ApiResponse import ApiResponse | |
from flask_cors import CORS | |
from mail import MyEmail | |
# Key used to sign and verify the HS256 JWTs issued by this module. The
# misspelled name is kept because the rest of the file refers to it.
# NOTE(review): the fallback below is a short secret committed to source;
# set SECERT_KEY in the environment for any real deployment.
SECERT_KEY = os.environ.get('SECERT_KEY', "8U2LL1")
MY_OPENAI_API_KEY = os.environ.get('MY_OPENAI_API_KEY')

app = Flask(__name__)
cors = CORS(app)

# Redis connection settings may be overridden through the environment; the
# defaults reproduce the previously hard-coded values. Hosts that were used
# earlier (without a password): 10.254.13.87 and localhost.
redis = Redis(
    host=os.environ.get('REDIS_HOST', '192.168.3.229'),
    port=int(os.environ.get('REDIS_PORT', '6379')),
    # An empty REDIS_PASSWORD means "connect without a password".
    password=os.environ.get('REDIS_PASSWORD', 'lizhen-redis') or None,
)
def generate_verification_code(length=6):
    """Return a random numeric verification code.

    The digits are drawn with the ``secrets`` module because the code acts
    as a short-lived credential and should not be predictable.

    Args:
        length: Number of digits to generate (defaults to 6).

    Returns:
        A string of ``length`` decimal digits.
    """
    # Imported locally so this function does not rely on a file-level import.
    import secrets

    return ''.join(secrets.choice(string.digits) for _ in range(length))
def send_verification_code(email, code):
    """Email a registration verification code to a single recipient.

    Args:
        email: Recipient address; becomes the only entry of the mailer's
            recipient list.
        code: Verification code embedded in the message text.
    """
    mailer = MyEmail()
    # NOTE(review): the sender address looks like a redacted placeholder and
    # the password is literally "todo" -- confirm real credentials are
    # supplied before relying on this.
    mailer.user = "[email protected]"
    mailer.passwd = "todo"
    mailer.to_list = [email]
    mailer.tag = "Chat注册验证码"
    message_text = f"【{code}】Chat邮箱注册验证码,您正在注册Chat账号,请勿泄露。"
    mailer.txt = message_text
    mailer.send()
def send_verification_code_endpoint():
    """Handle a request to email a verification code.

    Reads ``email`` from the JSON body, generates a code, emails it, and
    stores it in Redis under the email address for 300 seconds.

    Returns:
        A JSON response with ``code`` 0 on success, or ``code`` 400 when the
        body is missing, is not a JSON object, or has no ``email``.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    email = payload.get('email')
    if not email:
        # Without this guard a missing email would be passed to the mailer
        # and used as a Redis key.
        return jsonify({'code': 400, 'message': 'Email is required'})

    verification_code = generate_verification_code()
    send_verification_code(email, verification_code)
    # Keep the code for 5 minutes so register/reset_password can check it.
    redis.setex(email, 300, verification_code)
    return jsonify({'code': 0, 'message': 'Verification code sent'})
def register():
    """Handle a user registration request.

    Expects ``email``, ``username``, ``password`` and ``verification_code``
    in the JSON body. The verification code must match the one stored in
    Redis under the email address. On success the user record is written to
    the ``users`` hash (keyed by username) and the stored code is deleted.

    Returns:
        A JSON response with ``code`` 0 on success or ``code`` 400 with a
        message describing the failure.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    email = payload.get('email')
    username = payload.get('username')
    password = payload.get('password')
    verification_code = payload.get('verification_code')

    # Reject incomplete requests early; None values would otherwise be used
    # as Redis keys / hash fields.
    if not all((email, username, password, verification_code)):
        return jsonify({'code': 400, 'message': 'Missing required fields'})

    if is_email_registered(email):
        return jsonify({'code': 400, 'message': '邮箱已被注册'})

    # The stored code is bytes; compare as text so a numeric code sent by the
    # client still matches.
    stored_code = redis.get(email)
    if stored_code is None or str(verification_code) != stored_code.decode('utf-8'):
        return jsonify({'code': 400, 'message': 'Invalid verification code'})

    if redis.hexists('users', username):
        return jsonify({'code': 400, 'message': 'Username already exists'})

    user_id = str(uuid.uuid4())
    # NOTE(review): the password is stored in plaintext. Hashing it requires
    # a coordinated change in login(), so it is only flagged here.
    user_data = {
        'user_id': user_id,
        'username': username,
        'email': email,
        'password': password,
    }
    redis.hset('users', username, json.dumps(user_data))

    # The verification code is single-use.
    redis.delete(email)
    return jsonify({
        'code': 0,
        'message': 'Registration successful'
    })
def login():
    """Handle a login request and issue a token.

    Expects ``username`` and ``password`` in the JSON body and compares the
    password with the record stored in the ``users`` Redis hash.

    Returns:
        A JSON response with ``code`` 0 and ``data.token`` on success, or
        ``code`` 400 for an unknown username or a wrong password.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    username = payload.get('username')
    password = payload.get('password')

    if not username:
        return jsonify({'code': 400, 'message': 'Invalid username'})

    raw_user = redis.hget('users', username)
    if not raw_user:
        return jsonify({'code': 400, 'message': 'Invalid username'})

    # Records are written with json.dumps, so parse them as JSON. eval() must
    # never be used on stored, user-influenced data.
    user = json.loads(raw_user.decode('utf-8'))

    # An empty password is rejected outright so it can never match a record
    # whose stored password is missing.
    if not password or password != user.get('password'):
        return jsonify({'code': 400, 'message': 'Invalid password'})

    token = generate_token(user['user_id'], username)
    return jsonify({
        'code': 0,
        'message': 'Login successful',
        'data': {
            'token': token
        }
    })
def protected():
    """Example handler that requires a valid login token.

    Returns:
        ``code`` 401 (with HTTP status 200) when the token from the
        Authorization header is invalid, otherwise ``code`` 0.
    """
    token = parse_token(request)
    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'}), 200
    # Previously this path also returned 401, so a valid token could never
    # succeed.
    return jsonify({'code': 0, 'message': 'Authorized'})
def reset_password():
    """Handle a password reset confirmed by an emailed verification code.

    Expects ``email``, ``verification_code`` and ``new_password`` in the JSON
    body. The code must match the one stored in Redis under the email.

    Returns:
        A JSON response with ``code`` 0 when the password was updated, or
        ``code`` 400 with a message describing the failure.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    email = payload.get('email')
    verification_code = payload.get('verification_code')
    new_password = payload.get('new_password')

    # Without this guard a missing new_password would be stored as None.
    if not all((email, verification_code, new_password)):
        return jsonify({'code': 400, 'message': 'Missing required fields'})

    # The stored code is bytes; compare as text.
    stored_code = redis.get(email)
    if stored_code is None or str(verification_code) != stored_code.decode('utf-8'):
        return jsonify({'code': 400, 'message': 'Invalid verification code'})

    if not is_email_registered(email):
        return jsonify({'code': 400, 'message': '邮箱未注册'})

    update_password(email, new_password)
    # The verification code is single-use.
    redis.delete(email)
    return jsonify({'code': 0, 'message': '密码已更新'})
def logout():
    """Handle a logout request by blacklisting the caller's token.

    A valid token is stored in Redis with the value ``'revoked'``. An
    invalid or missing token is simply ignored; the response is the same in
    both cases.

    Returns:
        A JSON response with ``code`` 0.
    """
    token = parse_token(request)
    # Only a currently valid token needs revoking. The previous condition was
    # inverted, so it stored invalid tokens (and failed on a missing one)
    # while leaving valid tokens usable.
    if validate_token(token):
        redis.set(token, 'revoked')
    return jsonify({'code': 0, 'message': 'Logout successful'})
def purchase():
    """Handle the purchase of a payment package for the authenticated user.

    Reads ``package_id`` from the JSON body and the token from the
    Authorization header, then records the package for the user.

    Returns:
        A JSON response with ``code`` 0 on success, ``code`` 401 for an
        invalid token, or ``code`` 400 for an unknown package, an unknown
        user, a repeated purchase, or an attempted downgrade.
    """
    package_id = request.json.get('package_id')
    token = parse_token(request)

    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    package = pay_package.get_package_by_id(package_id)
    if not package:
        return jsonify({'code': 400, 'message': 'Invalid package ID'})

    user_id = get_user_id_from_token(token)
    if not user_id:
        return jsonify({'code': 400, 'message': 'User not found'})

    # The restrictions below only apply while the current package is active.
    if not is_package_expired(user_id):
        if has_purchased_package(user_id, package_id):
            return jsonify({'code': 400, 'message': 'Package already purchased'})
        # Compare as text so an integer id from the JSON body is treated the
        # same way has_purchased_package() treats it.
        if has_purchased_advanced_package(user_id) and str(package_id) == '1':
            return jsonify({'code': 400, 'message': 'Cannot purchase lower level package'})

    store_user_package(user_id, package)
    return jsonify({'code': 0, 'message': 'Purchase successful'})
def validate():
    """Report whether the authenticated user may still chat with a model.

    Reads ``model`` from the JSON body and the token from the Authorization
    header.

    Returns:
        A JSON response: ``code`` 401 for an invalid token, ``code`` 400 for
        an unknown user, a user without a package, or an exhausted quota, and
        ``code`` 0 when quota remains.
    """
    bearer = parse_token(request)
    requested_model = request.json.get('model')

    if not validate_token(bearer):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    uid = get_user_id_from_token(bearer)
    if not uid:
        return jsonify({'code': 400, 'message': 'User not found'})

    plan = get_user_package(uid)
    if not plan:
        return jsonify({'code': 400, 'message': 'User has not purchased any package'})

    if exceeded_chat_limit(uid, plan, requested_model):
        outcome = {'code': 400, 'message': 'Chat limit exceeded'}
    else:
        outcome = {'code': 0, 'message': 'Chat limit not exceeded'}
    return jsonify(outcome)
def proxy_chat_completions():
    """Forward a chat-completion request to the OpenAI API and charge quota.

    The caller must present a valid token and own a package. The JSON body is
    forwarded unchanged; when it sets ``stream`` the upstream response is
    streamed back in 8192-byte chunks, otherwise the parsed JSON is returned.

    One chat is deducted per successful upstream call:
      * ``gpt-3.5-turbo`` uses the user's free count first, if any remains;
      * ``gpt-4`` uses ``advanced_chat_limit``;
      * anything else uses ``basic_chat_limit``.

    Returns:
        The upstream response (body and status code), or a JSON error with
        ``code`` 401/400 when authentication or quota checks fail.
    """
    token = parse_token(request)
    model = request.json.get('model')

    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    user_id = get_user_id_from_token(token)
    if not user_id:
        return jsonify({'code': 400, 'message': 'User not found'})

    package = get_user_package(user_id)
    if not package:
        return jsonify({'code': 400, 'message': 'User has not purchased any package'})

    if exceeded_chat_limit(user_id, package, model):
        limit_messages = {
            'gpt-3.5-turbo': 'model3.5基础访问次数已用完',
            'gpt-4': 'model4高级访问次数已用完',
        }
        if model in limit_messages:
            return jsonify({'code': 400, 'message': limit_messages[model]})

    data = request.get_json()
    stream = bool(data.get('stream'))
    headers = {
        'Authorization': f'Bearer {MY_OPENAI_API_KEY}',
        'Content-Type': 'application/json'
    }

    # TLS certificate verification is left at the library default (enabled);
    # the streaming branch previously disabled it with verify=False.
    # NOTE(review): no timeout is set, so a stalled upstream blocks the worker.
    response = requests.post(
        'https://api.openai.com/v1/chat/completions',
        json=data, headers=headers, stream=stream)

    # Only charge the user when the upstream call actually succeeded.
    if response.ok:
        if model == 'gpt-3.5-turbo' and get_free_count(user_id) > 0:
            redis.hincrby(f'user:{user_id}:free', 'basic_chat_count', -1)
        elif model == 'gpt-4':
            # gpt-4 is limited by advanced_chat_limit, so that is the counter
            # that must be decremented (it used to hit basic_chat_limit).
            redis.hincrby(f'user:{user_id}:package', 'advanced_chat_limit', -1)
        else:
            redis.hincrby(f'user:{user_id}:package', 'basic_chat_limit', -1)

    if stream:
        return Response(
            response.iter_content(chunk_size=8192),
            status=response.status_code,
            content_type=response.headers['content-type'])
    return response.json(), response.status_code
def handle_pick_up_free_chat_count():
    """Handle a request to claim the daily free chat allowance.

    Returns:
        A JSON response: ``code`` 401 for an invalid token, ``code`` 400 for
        an unknown user or when today's allowance was already claimed, and
        ``code`` 0 when the allowance was granted. The ``message`` texts of
        the claim results are unchanged; the ``code`` field was added so the
        payload has the same shape as every other handler's.
    """
    token = parse_token(request)
    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    user_id = get_user_id_from_token(token)
    if not user_id:
        return jsonify({'code': 400, 'message': 'User not found'})

    if pick_up_free_chat_count(user_id):
        return jsonify({'code': 0, 'message': '领取成功'})
    return jsonify({'code': 400, 'message': '您今天已经领取了'})
def packageOnSales():
    """Return the packages on sale to an authenticated user.

    Returns:
        A JSON response: ``code`` 401 for an invalid token, ``code`` 400 for
        an unknown user, otherwise ``code`` 0 with ``pay_package.packages``
        as the data.
    """
    token = parse_token(request)
    if not validate_token(token):
        return jsonify({'code': 401, 'message': 'Invalid token'})

    user_id = get_user_id_from_token(token)
    if not user_id:
        return jsonify({'code': 400, 'message': 'User not found'})

    # The message used to read 'Login successful', a leftover from login().
    response = ApiResponse(
        code=0,
        message='Success',
        data=pay_package.packages)
    return jsonify(response.to_json())
def parse_token(request: Request):
    """Extract the token from a request's ``Authorization`` header.

    The ``Bearer`` scheme is matched case-insensitively and any amount of
    whitespace between the scheme and the token is accepted.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The token following ``Bearer``; an empty string when the scheme is
        present without a token; the raw header value when it does not use
        the ``Bearer`` scheme; or ``None`` when the header is absent.
    """
    header = request.headers.get('Authorization')
    if header is None:
        return None

    parts = header.split()
    if parts and parts[0].lower() == 'bearer':
        return parts[1] if len(parts) > 1 else ''
    # No "Bearer" prefix: treat the whole header value as the token.
    return header
def generate_token(user_id, username):
    """Create a signed HS256 JWT carrying the user's id and name.

    No ``exp`` claim is set here, so the token carries no expiry of its own.
    To add one, include e.g.
    ``'exp': nowtime() + datetime.timedelta(days=30)`` in the claims.

    Args:
        user_id: Stored in the ``user_id`` claim.
        username: Stored in the ``username`` claim.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    claims = dict(user_id=user_id, username=username)
    return jwt.encode(claims, SECERT_KEY, algorithm='HS256')
def validate_token(token):
    """Check whether a token is an acceptable login token.

    A token is accepted when it decodes with the module's signing key
    (PyJWT also rejects an expired ``exp`` claim during decoding), carries
    both ``user_id`` and ``username`` claims, and has not been stored in
    Redis with the value ``'revoked'`` (which is what logout writes).

    Args:
        token: The encoded token, or ``None``/empty when none was supplied.

    Returns:
        ``True`` when the token is valid, otherwise ``False``.
    """
    if not token:
        return False

    try:
        payload = jwt.decode(token, SECERT_KEY, algorithms=['HS256'])
    except jwt.InvalidTokenError:
        # Covers malformed tokens, bad signatures and expired tokens; the
        # token itself is deliberately not logged.
        return False

    if 'user_id' not in payload or 'username' not in payload:
        return False

    # Honour the logout blacklist.
    return redis.get(token) != b'revoked'
def get_user_id_from_token(token):
    """Return the ``user_id`` claim of a token, or ``None`` if it is unusable.

    Args:
        token: The encoded token.

    Returns:
        The value of the ``user_id`` claim, or ``None`` when the token is
        expired, malformed or otherwise invalid.
    """
    try:
        claims = jwt.decode(token, SECERT_KEY, algorithms=['HS256'])
    except jwt.InvalidTokenError:
        # InvalidTokenError is the parent of both ExpiredSignatureError and
        # DecodeError, so expired and undecodable tokens both end up here.
        return None
    return claims.get('user_id')
def get_user_id_by_username(username):
    """Look up a user's id by username in the ``users`` Redis hash.

    Args:
        username: Field name inside the ``users`` hash.

    Returns:
        The ``user_id`` from the stored JSON record, or ``None`` when no
        record exists for the username.
    """
    raw_record = redis.hget('users', username)
    if not raw_record:
        return None
    record = json.loads(raw_record.decode('utf-8'))
    return record.get('user_id')
def store_user_package(user_id, package):
    """Write a purchased package into the user's package hash in Redis.

    If the user already has a package, its remaining basic and advanced chat
    limits are added to the new package's limits before storing.

    Args:
        user_id: Used to build the key ``user:<user_id>:package``.
        package: Mapping with ``id``, ``title``, ``basic_chat_limit`` and
            ``advanced_chat_limit``.
    """
    existing = get_user_package(user_id)
    basic_total = package['basic_chat_limit']
    advanced_total = package['advanced_chat_limit']
    if existing:
        # Values read back from Redis are bytes keyed by bytes.
        basic_total = basic_total + int(existing.get(b'basic_chat_limit', 0))
        advanced_total = advanced_total + int(existing.get(b'advanced_chat_limit', 0))

    key = f'user:{user_id}:package'

    def put(field, value):
        redis.hset(key, field, value)

    put('id', package['id'])
    put('title', package['title'])
    put('basic_chat_limit', basic_total)
    put('advanced_chat_limit', advanced_total)
    # Package expiry is currently disabled:
    # expiration = int(time.time()) + package['expiration']
    # redis.expireat(user_package_key, expiration)
def get_user_package(user_id):
    """Return every field of the user's package hash from Redis.

    Args:
        user_id: Used to build the key ``user:<user_id>:package``.

    Returns:
        The result of ``HGETALL`` for that key (empty when nothing is stored).
    """
    return redis.hgetall(f'user:{user_id}:package')
def has_purchased_package(user_id, package_id):
    """Tell whether the user's stored package has the given id.

    Args:
        user_id: Used to build the key ``user:<user_id>:package``.
        package_id: Package id to compare against; compared as text.

    Returns:
        ``True`` when the stored package id equals ``package_id``; ``False``
        otherwise, including when the user has no stored package.
    """
    stored_id = redis.hget(f'user:{user_id}:package', 'id')
    if stored_id is None:
        # No package on record; previously this raised AttributeError.
        return False
    return stored_id.decode('utf-8') == str(package_id)
def has_purchased_advanced_package(user_id):
    """Tell whether the user's stored package is the advanced one (id ``'2'``).

    Args:
        user_id: Used to build the key ``user:<user_id>:package``.

    Returns:
        ``True`` when the stored package id is ``'2'``; ``False`` otherwise,
        including when the user has no stored package.
    """
    stored_id = redis.hget(f'user:{user_id}:package', 'id')
    if stored_id is None:
        # No package on record; previously this raised AttributeError.
        return False
    return stored_id.decode('utf-8') == '2'
def is_package_expired(user_id):
    """Report whether the user's package key has no remaining TTL in Redis.

    NOTE(review): Redis TTL returns -1 for a key without an expiry and -2 for
    a missing key, and store_user_package() currently never sets an expiry.
    With the ``<= 0`` test, a stored package therefore always counts as
    expired -- confirm whether that is intended.

    Args:
        user_id: Used to build the key ``user:<user_id>:package``.

    Returns:
        ``True`` when the key's TTL is zero or negative.
    """
    remaining_seconds = redis.ttl(f'user:{user_id}:package')
    return remaining_seconds <= 0
def get_package_expiration(user_id):
    """Return the Redis TTL of the user's package key.

    Args:
        user_id: Used to build the key ``user:<user_id>:package``.

    Returns:
        Whatever ``redis.ttl`` reports for that key.
    """
    return redis.ttl(f'user:{user_id}:package')
def exceeded_chat_limit(user_id, package, model):
    """Decide whether the user has run out of chats for the given model.

    Only ``'gpt-3.5-turbo'`` and ``'gpt-4'`` are limited; any other model
    name is never reported as exceeded.

    Args:
        user_id: Used to look up the user's free chat count.
        package: The user's package hash as read from Redis (bytes keys).
        model: Requested model name.

    Returns:
        ``True`` when the relevant quota is used up, otherwise ``False``.
    """
    if model == 'gpt-3.5-turbo':
        remaining_basic = int(package.get(b'basic_chat_limit', 0))
        print('basic_chat_limit:', remaining_basic)
        # Free chats are spent before the package's basic quota.
        if get_free_count(user_id) > 0:
            return False
        return remaining_basic <= 0

    if model == 'gpt-4':
        return int(package.get(b'advanced_chat_limit', 0)) <= 0

    return False
def is_email_registered(email):
    """Tell whether any stored user record has the given email address.

    All records are fetched with a single ``HVALS`` call instead of one
    ``HGET`` per username.

    Args:
        email: Address to look for.

    Returns:
        ``True`` when a record in the ``users`` hash has this email;
        ``False`` otherwise, including when ``email`` is ``None``.
    """
    if email is None:
        return False
    return any(
        json.loads(raw_record.decode('utf-8')).get('email') == email
        for raw_record in redis.hvals('users')
    )
def update_password(email, new_password):
    """Replace the password of the first user record with the given email.

    All records are fetched with a single ``HGETALL`` call instead of one
    ``HGET`` per username.

    Args:
        email: Address identifying the user.
        new_password: Value written to the record's ``password`` field.

    Returns:
        ``True`` when a record was updated; ``False`` when no record matches
        or ``email`` is ``None``.
    """
    if email is None:
        return False
    for username, raw_record in redis.hgetall('users').items():
        user_data = json.loads(raw_record.decode('utf-8'))
        if user_data.get('email') == email:
            user_data['password'] = new_password
            redis.hset('users', username, json.dumps(user_data))
            return True
    return False
def get_user_free_data(user_id):
    """Return every field of the user's free-chat hash from Redis.

    Args:
        user_id: Used to build the key ``user:<user_id>:free``.

    Returns:
        The result of ``HGETALL`` for that key (empty when nothing is stored).
    """
    return redis.hgetall(f'user:{user_id}:free')
def initialize_user_free_data(user_id):
    """Create the user's free-chat hash in Redis with its initial values.

    The initial state is zero free chats, today's date as the last gift
    time, and the "claimed" flag set to ``b'false'``.

    Args:
        user_id: Used to build the key ``user:<user_id>:free``.

    Returns:
        The dict of initial values (bytes keys) that was written.
    """
    defaults = {
        b'basic_chat_count': 0,
        b'last_gift_time': str(datetime.date.today()),
        b'has_pick_up_free': b'false',  # whether the free chats were claimed
    }
    key = f'user:{user_id}:free'
    # One HSET per field, in a fixed order.
    for field in (b'basic_chat_count', b'last_gift_time', b'has_pick_up_free'):
        redis.hset(key, field, defaults[field])
    return defaults
def pick_up_free_chat_count(user_id, basic_free_count=5):
    """Grant the user's daily free chats unless already claimed today.

    When granted, the free chat count is set (not added) to
    ``basic_free_count``, the last gift time becomes today's date, and the
    "claimed" flag is set to ``'true'``.

    Args:
        user_id: Used to build the key ``user:<user_id>:free``.
        basic_free_count: Number of free chats to grant (defaults to 5).

    Returns:
        ``True`` when the chats were granted, ``False`` when the user had
        already claimed them today.
    """
    # Computed once so every comparison and write uses the same date.
    today = str(datetime.date.today())

    free_data = get_user_free_data(user_id)
    if not free_data:
        free_data = initialize_user_free_data(user_id)

    def as_text(value):
        # Values read from Redis are bytes, while freshly initialised data and
        # the fallbacks below are str; accept both.
        return value.decode() if isinstance(value, bytes) else str(value)

    already_claimed = as_text(
        free_data.get(b'has_pick_up_free', b'false')).lower() == 'true'
    last_gift_day = as_text(free_data.get(b'last_gift_time', today))

    if already_claimed and last_gift_day == today:
        return False

    user_free_name = f'user:{user_id}:free'
    redis.hset(user_free_name, b'basic_chat_count', basic_free_count)
    redis.hset(user_free_name, b'last_gift_time', today)
    redis.hset(user_free_name, b'has_pick_up_free', 'true')
    return True
def get_free_count(user_id):
    """Return the user's remaining free chat count as an integer.

    Args:
        user_id: Identifies the user's free-chat hash.

    Returns:
        The stored ``basic_chat_count``, or 0 when none is stored.
    """
    stored = get_user_free_data(user_id)
    return int(stored.get(b'basic_chat_count', 0))
if __name__ == '__main__':
    # Debug mode enables the interactive debugger, which must not be exposed
    # on a reachable host, so it is opt-in via FLASK_DEBUG=1 instead of being
    # always on.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1')