from flask import Flask, request, jsonify, Response from functools import wraps import uuid import json from typing import List, Optional from pydantic import BaseModel, ValidationError import logging from API_provider import API_Inference from core_logic import ( check_api_key_validity, update_request_count, get_rate_limit_status, get_subscription_status, get_available_models, get_model_info, ) app = Flask(__name__) logging.basicConfig(level=logging.DEBUG) class Message(BaseModel): role: str content: str class ChatCompletionRequest(BaseModel): model: str messages: List[Message] stream: Optional[bool] = False max_tokens: Optional[int] = 4000 temperature: Optional[float] = 0.5 top_p: Optional[float] = 0.95 def get_api_key(): auth_header = request.headers.get('Authorization') if not auth_header or not auth_header.startswith('Bearer '): return None return auth_header.split(' ')[1] def requires_api_key(func): @wraps(func) def decorated(*args, **kwargs): api_key = get_api_key() if not api_key: return jsonify({'detail': 'Not authenticated'}), 401 kwargs['api_key'] = api_key return func(*args, **kwargs) return decorated @app.route('/') def index(): return 'Hello, World!' 
@app.route('/chat/completions', methods=['POST', 'GET'])
@requires_api_key
def chat_completions(api_key):
    """OpenAI-compatible chat completions endpoint.

    Validates the JSON body against ChatCompletionRequest, checks the API key
    and rate limit, maps friendly model aliases to provider model names,
    runs inference (streaming SSE or a single JSON response), and charges
    the key the model's credit cost.

    Returns: 400 on invalid body/model, 401 on bad key, 500 on inference error.
    """
    # Local import so this edit does not depend on the file's import block.
    import time

    logging.info("Received request for chat completions")
    try:
        # Parse and validate the request body.
        try:
            data = request.get_json()
            chat_request = ChatCompletionRequest(**data)
        except ValidationError as e:
            return jsonify({'detail': e.errors()}), 400

        # Authenticate and rate-limit.
        is_valid, error_message = check_api_key_validity(api_key)
        if not is_valid:
            return jsonify({'detail': error_message}), 401

        messages = [{"role": msg.role, "content": msg.content} for msg in chat_request.messages]

        # Reject unknown models before spending any credits.
        model_info = get_model_info(chat_request.model)
        if not model_info:
            return jsonify({'detail': 'Invalid model specified'}), 400

        # Friendly alias -> provider model name.
        model_mapping = {
            "meta-llama-405b-turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "claude-3.5-sonnet": "claude-3-sonnet-20240229",
        }
        model_name = model_mapping.get(chat_request.model, chat_request.model)

        # Per-request credit cost; unknown models cost 0.
        credits_reduction = {
            "gpt-4o": 1,
            "claude-3-sonnet-20240229": 1,
            "gemini-1.5-pro": 1,
            "gemini-1-5-flash": 1,
            "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 1,
            "o1-mini": 2,
            "o1-preview": 3,
        }.get(model_name, 0)

        if chat_request.stream:
            def generate():
                try:
                    for chunk in API_Inference(messages, model=model_name, stream=True,
                                               max_tokens=chat_request.max_tokens,
                                               temperature=chat_request.temperature,
                                               top_p=chat_request.top_p):
                        payload = json.dumps({'choices': [{'delta': {'content': chunk}}]})
                        yield f"data: {payload}\n\n"
                    # Charge only after the stream completed successfully,
                    # then terminate with a clean SSE sentinel (no trailing
                    # text after [DONE] — extra data breaks SSE clients).
                    update_request_count(api_key, credits_reduction)
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    yield f"data: [ERROR] {str(e)}\n\n"

            return Response(generate(), mimetype='text/event-stream')

        response = API_Inference(messages, model=model_name, stream=False,
                                 max_tokens=chat_request.max_tokens,
                                 temperature=chat_request.temperature,
                                 top_p=chat_request.top_p)
        update_request_count(api_key, credits_reduction)

        # Rough token accounting by whitespace split (not a real tokenizer).
        prompt_tokens = sum(len(msg['content'].split()) for msg in messages)
        completion_tokens = len(response.split())
        total_tokens = prompt_tokens + completion_tokens

        return jsonify({
            "id": f"chatcmpl-{str(uuid.uuid4())}",
            "object": "chat.completion",
            # Unix epoch seconds (uuid1 time is 100-ns ticks since 1582 — wrong base).
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens
            },
            "credits_used": credits_reduction
        })
    except Exception as e:
        # Boundary handler: surface the error to the client as a 500.
        return jsonify({'detail': str(e)}), 500


@app.route('/rate_limit/status', methods=['GET'])
@requires_api_key
def get_rate_limit_status_endpoint(api_key):
    """Return the caller's current rate-limit state (key check skips the limit itself)."""
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_rate_limit_status(api_key))


@app.route('/subscription/status', methods=['GET'])
@requires_api_key
def get_subscription_status_endpoint(api_key):
    """Return the caller's subscription details."""
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify(get_subscription_status(api_key))


@app.route('/models', methods=['GET'])
@requires_api_key
def get_available_models_endpoint(api_key):
    """List available model ids in an OpenAI-style {"data": [...]} envelope."""
    is_valid, error_message = check_api_key_validity(api_key, check_rate_limit=False)
    if not is_valid:
        return jsonify({'detail': error_message}), 401
    return jsonify({"data": [{"id": model} for model in get_available_models().values()]})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)