OjciecTadeusz committed
Commit 8345d88 · verified · 1 Parent(s): 491cd52

Update app.py

Files changed (1)
  1. app.py +199 -143
app.py CHANGED
@@ -1,153 +1,209 @@
- import gradio as gr
- from fastapi import FastAPI, Request, HTTPException
- from fastapi.responses import JSONResponse
- import datetime
- import requests
- import os
- import logging
- import toml
-
- # Initialize FastAPI
  app = FastAPI()

- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Load config
- with open("config.toml") as f:
-     config = toml.load(f)
-
- #API_URL = os.getenv('API_URL')
- #API_TOKEN = os.getenv('API_TOKEN')
- # API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2-5-coder-32-a0ab504.hf.space/v1/chat/completions'
- API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2.5-coder-32b-instruct.hf.space/v1/chat/completions'
- headers = {
-     "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}",
-     "Content-Type": "application/json"
- }
-
- def format_chat_response(response_text, prompt_tokens=0, completion_tokens=0):
-     return {
-         "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
-         "object": "chat.completion",
-         "created": int(datetime.datetime.now().timestamp()),
-         "model": "Qwen/Qwen2.5-Coder-32B",
-         "choices": [{
-             "index": 0,
-             "message": {
-                 "role": "assistant",
-                 "content": response_text
-             },
-             "finish_reason": "stop"
-         }],
-         "usage": {
-             "prompt_tokens": prompt_tokens,
-             "completion_tokens": completion_tokens,
-             "total_tokens": prompt_tokens + completion_tokens
-         }
-     }
-
- async def query_model(payload):
-     try:
-         response = requests.post(API_URL, headers=headers, json=payload)
-         response.raise_for_status()
-         return response.json()
-     except requests.exceptions.RequestException as e:
-         logger.error(f"Request failed: {e}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.get("/status")
- async def status():
-     try:
-
-         response_text = os.getenv('HF_API_TOKEN') + "it's working"
-         return JSONResponse(content=format_chat_response(response_text))
-     except Exception as e:
-         logger.error(f"Status check failed: {e}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/v1/chat/completions")
- async def chat_completion(request: Request):
-     try:
-         data = await request.json()
-         messages = data.get("messages", [])
-         if not messages:
-             raise HTTPException(status_code=400, detail="Messages are required")
-
-         payload = {
-             "inputs": {
-                 "messages": messages
-             },
-             "parameters": {
-                 "max_new_tokens": data.get("max_tokens", 2048),
-                 "temperature": data.get("temperature", 0.7),
-                 "top_p": data.get("top_p", 0.95),
-                 "do_sample": True
-             }
-         }
-
-         response = await query_model(payload)
-
-         if isinstance(response, dict) and "error" in response:
-             raise HTTPException(status_code=500, detail=response["error"])
-
-         response_text = response[0]["generated_text"]
-
-         return JSONResponse(content=format_chat_response(response_text))
-     except HTTPException as e:
-         logger.error(f"Chat completion failed: {e.detail}")
-         raise e
-     except Exception as e:
-         logger.error(f"Unexpected error: {e}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- def generate_response(messages):
-     payload = {
-         "inputs": {
-             "messages": messages
-         },
-         "parameters": {
-             "max_new_tokens": 2048,
-             "temperature": 0.7,
-             "top_p": 0.95,
-             "do_sample": True
-         }
-     }
-
-     try:
-         response = requests.post(API_URL, headers=headers, json=payload)
-         response.raise_for_status()
-         result = response.json()
-
-         if isinstance(result, dict) and "error" in result:
-             return f"Error: {result['error']}"
-
-         return result[0]["generated_text"]
-     except requests.exceptions.RequestException as e:
-         logger.error(f"Request failed: {e}")
-         return f"Error: {e}"
-
- def chat_interface(messages):
-     chat_history = []
-     for message in messages:
-         try:
-             response = generate_response([{"role": "user", "content": message}])
-             chat_history.append({"role": "user", "content": message})
-             chat_history.append({"role": "assistant", "content": response})
-         except Exception as e:
-             chat_history.append({"role": "user", "content": message})
-             chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
-     return chat_history
-
- # Create Gradio interface
- def gradio_app():
-     return gr.ChatInterface(chat_interface, type="messages")
-
- # Mount both FastAPI and Gradio
- app = gr.mount_gradio_app(app, gradio_app(), path="/")
-
- # For running with uvicorn directly
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from huggingface_hub import InferenceClient
+ import uvicorn
+
+
  app = FastAPI()

+ client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct")
+
+ class Item(BaseModel):
+     prompt: str
+     history: list
+     system_prompt: str
+     temperature: float = 0.0
+     max_new_tokens: int = 1048
+     top_p: float = 0.15
+     repetition_penalty: float = 1.0
+
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def generate(item: Item):
+     temperature = float(item.temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(item.top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=item.max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=item.repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     for response in stream:
+         output += response.token.text
+     return output
+
+ @app.post("/generate/")
+ async def generate_text(item: Item):
+     return {"response": generate(item)}
+
+
+
+ # import gradio as gr
+ # from fastapi import FastAPI, Request, HTTPException
+ # from fastapi.responses import JSONResponse
+ # import datetime
+ # import requests
+ # import os
+ # import logging
+ # import toml
+
+ # # Initialize FastAPI
+ # app = FastAPI()
+
+ # # Configure logging
+ # logging.basicConfig(level=logging.INFO)
+ # logger = logging.getLogger(__name__)
+
+ # # Load config
+ # with open("config.toml") as f:
+ #     config = toml.load(f)
+
+ # #API_URL = os.getenv('API_URL')
+ # #API_TOKEN = os.getenv('API_TOKEN')
+ # # API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2-5-coder-32-a0ab504.hf.space/v1/chat/completions'
+ # API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2.5-coder-32b-instruct.hf.space/v1/chat/completions'
+ # headers = {
+ #     "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}",
+ #     "Content-Type": "application/json"
+ # }
+
+ # def format_chat_response(response_text, prompt_tokens=0, completion_tokens=0):
+ #     return {
+ #         "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
+ #         "object": "chat.completion",
+ #         "created": int(datetime.datetime.now().timestamp()),
+ #         "model": "Qwen/Qwen2.5-Coder-32B",
+ #         "choices": [{
+ #             "index": 0,
+ #             "message": {
+ #                 "role": "assistant",
+ #                 "content": response_text
+ #             },
+ #             "finish_reason": "stop"
+ #         }],
+ #         "usage": {
+ #             "prompt_tokens": prompt_tokens,
+ #             "completion_tokens": completion_tokens,
+ #             "total_tokens": prompt_tokens + completion_tokens
+ #         }
+ #     }
+
+ # async def query_model(payload):
+ #     try:
+ #         response = requests.post(API_URL, headers=headers, json=payload)
+ #         response.raise_for_status()
+ #         return response.json()
+ #     except requests.exceptions.RequestException as e:
+ #         logger.error(f"Request failed: {e}")
+ #         raise HTTPException(status_code=500, detail=str(e))
+
+ # @app.get("/status")
+ # async def status():
+ #     try:
+
+ #         response_text = os.getenv('HF_API_TOKEN') + "it's working"
+ #         return JSONResponse(content=format_chat_response(response_text))
+ #     except Exception as e:
+ #         logger.error(f"Status check failed: {e}")
+ #         raise HTTPException(status_code=500, detail=str(e))
+
+ # @app.post("/v1/chat/completions")
+ # async def chat_completion(request: Request):
+ #     try:
+ #         data = await request.json()
+ #         messages = data.get("messages", [])
+ #         if not messages:
+ #             raise HTTPException(status_code=400, detail="Messages are required")
+
+ #         payload = {
+ #             "inputs": {
+ #                 "messages": messages
+ #             },
+ #             "parameters": {
+ #                 "max_new_tokens": data.get("max_tokens", 2048),
+ #                 "temperature": data.get("temperature", 0.7),
+ #                 "top_p": data.get("top_p", 0.95),
+ #                 "do_sample": True
+ #             }
+ #         }
+
+ #         response = await query_model(payload)
+
+ #         if isinstance(response, dict) and "error" in response:
+ #             raise HTTPException(status_code=500, detail=response["error"])
+
+ #         response_text = response[0]["generated_text"]
+
+ #         return JSONResponse(content=format_chat_response(response_text))
+ #     except HTTPException as e:
+ #         logger.error(f"Chat completion failed: {e.detail}")
+ #         raise e
+ #     except Exception as e:
+ #         logger.error(f"Unexpected error: {e}")
+ #         raise HTTPException(status_code=500, detail=str(e))
+
+ # def generate_response(messages):
+ #     payload = {
+ #         "inputs": {
+ #             "messages": messages
+ #         },
+ #         "parameters": {
+ #             "max_new_tokens": 2048,
+ #             "temperature": 0.7,
+ #             "top_p": 0.95,
+ #             "do_sample": True
+ #         }
+ #     }
+
+ #     try:
+ #         response = requests.post(API_URL, headers=headers, json=payload)
+ #         response.raise_for_status()
+ #         result = response.json()
+
+ #         if isinstance(result, dict) and "error" in result:
+ #             return f"Error: {result['error']}"
+
+ #         return result[0]["generated_text"]
+ #     except requests.exceptions.RequestException as e:
+ #         logger.error(f"Request failed: {e}")
+ #         return f"Error: {e}"
+
+ # def chat_interface(messages):
+ #     chat_history = []
+ #     for message in messages:
+ #         try:
+ #             response = generate_response([{"role": "user", "content": message}])
+ #             chat_history.append({"role": "user", "content": message})
+ #             chat_history.append({"role": "assistant", "content": response})
+ #         except Exception as e:
+ #             chat_history.append({"role": "user", "content": message})
+ #             chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
+ #     return chat_history
+
+ # # Create Gradio interface
+ # def gradio_app():
+ #     return gr.ChatInterface(chat_interface, type="messages")
+
+ # # Mount both FastAPI and Gradio
+ # app = gr.mount_gradio_app(app, gradio_app(), path="/")
+
+ # # For running with uvicorn directly
+ # if __name__ == "__main__":
+ #     import uvicorn
+ #     uvicorn.run(app, host="0.0.0.0", port=7860)
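
After this change, the app exposes a single POST /generate/ endpoint whose JSON body maps onto the Item model above. As a quick sanity check of the new interface, something along these lines should work; this is a minimal client sketch, not part of the commit: the base URL and payload values are assumptions, and history is a list of [user_prompt, bot_response] pairs as consumed by format_prompt.

import requests

# Hypothetical base URL; point this at the deployed Space instead.
BASE_URL = "http://localhost:7860"

payload = {
    "prompt": "Write a Python function that reverses a string.",
    "history": [],  # prior turns as [user_prompt, bot_response] pairs
    "system_prompt": "You are a helpful coding assistant.",
    "temperature": 0.7,
    "max_new_tokens": 256,
    "top_p": 0.95,
    "repetition_penalty": 1.0,
}

# POST to the new endpoint and print the generated text.
resp = requests.post(f"{BASE_URL}/generate/", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])

Note that the rewritten app.py imports uvicorn but no longer calls it (the old __main__ runner is now commented out), so the server would be started externally, e.g. with `uvicorn app:app --host 0.0.0.0 --port 7860`.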