ethiotech4848 committed on
Commit
b1dd449
·
verified ·
1 Parent(s): 2c2b2b5

Create flask_app.py

Browse files
Files changed (1) hide show
  1. flask_app.py +131 -0
flask_app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, Response, jsonify, stream_with_context
2
+ from flask_cors import CORS
3
+ import json
4
+
5
+ from typegpt_api import generate, model_mapping, simplified_models
6
+ from api_info import developer_info, model_providers
7
+
8
+ app = Flask(__name__)
9
+
10
+ # Set up CORS middleware if needed
11
+ CORS(app, resources={
12
+ r"/*": {
13
+ "origins": "*",
14
+ "allow_credentials": True,
15
+ "methods": ["*"],
16
+ "headers": ["*"]
17
+ }
18
+ })
19
+
20
+ @app.route("/health_check", methods=['GET'])
21
+ def health_check():
22
+ return jsonify({"status": "OK"})
23
+
24
+ @app.route("/models", methods=['GET'])
25
+ def get_models():
26
+ try:
27
+ response = {
28
+ "object": "list",
29
+ "data": []
30
+ }
31
+ for provider, info in model_providers.items():
32
+ for model in info["models"]:
33
+ response["data"].append({
34
+ "id": model,
35
+ "object": "model",
36
+ "provider": provider,
37
+ "description": info["description"]
38
+ })
39
+
40
+ return jsonify(response)
41
+ except Exception as e:
42
+ return jsonify({"error": str(e)}), 500
43
+
44
+ @app.route("/chat/completions", methods=['POST'])
45
+ def chat_completions():
46
+ # Receive the JSON payload
47
+ try:
48
+ body = request.get_json()
49
+ except Exception as e:
50
+ return jsonify({"error": "Invalid JSON payload"}), 400
51
+
52
+ # Extract parameters
53
+ model = body.get("model")
54
+ messages = body.get("messages")
55
+ temperature = body.get("temperature", 0.7)
56
+ top_p = body.get("top_p", 1.0)
57
+ n = body.get("n", 1)
58
+ stream = body.get("stream", False)
59
+ stop = body.get("stop")
60
+ max_tokens = body.get("max_tokens")
61
+ presence_penalty = body.get("presence_penalty", 0.0)
62
+ frequency_penalty = body.get("frequency_penalty", 0.0)
63
+ logit_bias = body.get("logit_bias")
64
+ user = body.get("user")
65
+ timeout = 30 # or set based on your preference
66
+
67
+ # Validate required parameters
68
+ if not model:
69
+ return jsonify({"error": "The 'model' parameter is required."}), 400
70
+ if not messages:
71
+ return jsonify({"error": "The 'messages' parameter is required."}), 400
72
+
73
+ # Call the generate function
74
+ try:
75
+ if stream:
76
+ def generate_stream():
77
+ response = generate(
78
+ model=model,
79
+ messages=messages,
80
+ temperature=temperature,
81
+ top_p=top_p,
82
+ n=n,
83
+ stream=True,
84
+ stop=stop,
85
+ max_tokens=max_tokens,
86
+ presence_penalty=presence_penalty,
87
+ frequency_penalty=frequency_penalty,
88
+ logit_bias=logit_bias,
89
+ user=user,
90
+ timeout=timeout,
91
+ )
92
+
93
+ for chunk in response:
94
+ yield f"data: {json.dumps(chunk)}\n\n"
95
+ yield "data: [DONE]\n\n"
96
+
97
+ return Response(
98
+ stream_with_context(generate_stream()),
99
+ mimetype="text/event-stream",
100
+ headers={
101
+ "Cache-Control": "no-cache",
102
+ "Connection": "keep-alive",
103
+ "Transfer-Encoding": "chunked"
104
+ }
105
+ )
106
+ else:
107
+ response = generate(
108
+ model=model,
109
+ messages=messages,
110
+ temperature=temperature,
111
+ top_p=top_p,
112
+ n=n,
113
+ stream=False,
114
+ stop=stop,
115
+ max_tokens=max_tokens,
116
+ presence_penalty=presence_penalty,
117
+ frequency_penalty=frequency_penalty,
118
+ logit_bias=logit_bias,
119
+ user=user,
120
+ timeout=timeout,
121
+ )
122
+ return jsonify(response)
123
+ except Exception as e:
124
+ return jsonify({"error": str(e)}), 500
125
+
126
+ @app.route("/developer_info", methods=['GET'])
127
+ def get_developer_info():
128
+ return jsonify(developer_info)
129
+
130
+ if __name__ == "__main__":
131
+ app.run(host="0.0.0.0", port=8000)