OjciecTadeusz committed on
Commit
defc45e
·
verified ·
1 Parent(s): e5928ae

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +36 -159
main.py CHANGED
@@ -41,13 +41,46 @@ def generate(item: Item):
41
  )
42
 
43
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
44
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=item.stream, details=item.details, return_full_text=item.return_full_text)
45
- output = ""
 
 
 
 
 
46
 
 
47
  for response in stream:
48
- output += response.token.text
 
 
 
 
49
  return output
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.post("/generate/")
52
  async def generate_text(item: Item):
53
  try:
@@ -56,162 +89,6 @@ async def generate_text(item: Item):
56
  except Exception as e:
57
  raise HTTPException(status_code=500, detail=str(e))
58
 
59
- # from fastapi import FastAPI, HTTPException, Depends
60
- # from fastapi.security.api_key import APIKeyHeader
61
- # from pydantic import BaseModel
62
- # from huggingface_hub import InferenceClient, HfApi
63
- # from typing import List, Optional
64
- # import os
65
- # from dotenv import load_dotenv
66
-
67
- # # Load environment variables
68
- # load_dotenv()
69
-
70
- # # Initialize FastAPI app
71
- # app = FastAPI()
72
-
73
- # # Get HuggingFace token from environment variable
74
- # HF_TOKEN = os.getenv("HF_TOKEN")
75
- # if not HF_TOKEN:
76
- # raise ValueError("HF_TOKEN environment variable is not set")
77
-
78
- # # Setup API key authorization
79
- # API_KEY_NAME = "Authorization"
80
- # api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
81
-
82
- # # Initialize HuggingFace client
83
- # try:
84
- # client = InferenceClient(
85
- # "mistralai/Mixtral-8x7B-Instruct-v0.1",
86
- # token=HF_TOKEN
87
- # )
88
- # # Verify token is valid
89
- # hf_api = HfApi(token=HF_TOKEN)
90
- # hf_api.whoami()
91
- # except Exception as e:
92
- # raise ValueError(f"Failed to initialize HuggingFace client: {str(e)}")
93
-
94
- # class ChatMessage(BaseModel):
95
- # role: str
96
- # content: str
97
-
98
- # class GenerationRequest(BaseModel):
99
- # prompt: str
100
- # message: Optional[str] = None
101
- # system_message: Optional[str] = None
102
- # history: Optional[List[ChatMessage]] = None
103
- # temperature: Optional[float] = 0.7
104
- # top_p: Optional[float] = 0.95
105
-
106
- # def format_prompt(message: str, history: List[ChatMessage] = None, system_message: str = None) -> str:
107
- # prompt = ""
108
-
109
- # if system_message:
110
- # prompt += f"<s>[INST] {system_message} [/INST]</s>"
111
-
112
- # if history:
113
- # for msg in history:
114
- # if msg.role == "user":
115
- # prompt += f"<s>[INST] {msg.content} [/INST]"
116
- # else:
117
- # prompt += f" {msg.content}</s>"
118
-
119
- # prompt += f"<s>[INST] {message} [/INST]"
120
- # return prompt
121
-
122
- # async def verify_token(api_key_header: str = Depends(api_key_header)):
123
- # if not api_key_header.startswith("Bearer "):
124
- # raise HTTPException(
125
- # status_code=401,
126
- # detail="Bearer token missing"
127
- # )
128
- # token = api_key_header.replace("Bearer ", "")
129
- # if token != HF_TOKEN:
130
- # raise HTTPException(
131
- # status_code=401,
132
- # detail="Invalid authentication credentials"
133
- # )
134
- # return token
135
-
136
- # @app.post("/generate/")
137
- # async def generate_text(
138
- # request: GenerationRequest,
139
- # token: str = Depends(verify_token)
140
- # ):
141
- # try:
142
- # message = request.prompt if request.prompt else request.message
143
- # if not message:
144
- # return [
145
- # {
146
- # "msg": "MSG!"
147
- # }
148
- # ]
149
-
150
- # formatted_prompt = format_prompt(
151
- # message=message,
152
- # history=request.history,
153
- # system_message=request.system_message
154
- # )
155
-
156
- # response = client.text_generation(
157
- # formatted_prompt,
158
- # temperature=max(request.temperature, 0.01),
159
- # top_p=request.top_p,
160
- # max_new_tokens=1048,
161
- # do_sample=True,
162
- # return_full_text=False
163
- # )
164
-
165
- # if not response:
166
- # return [
167
- # {
168
- # "detail": [
169
- # {
170
- # # "type": "server_error",
171
- # "loc": ["server"],
172
- # "msg": "No response received from model",
173
- # "input": None
174
- # }
175
- # ]
176
- # }
177
- # ]
178
-
179
- # # Construct the custom JSON response
180
- # return [
181
- # {
182
- # "msg": response
183
- # # "msg": [
184
- # # {
185
- # # # "type": "success",
186
- # # # "loc":[
187
- # # # "body",
188
- # # # "prompt"
189
- # # # ],
190
- # # # "loc": ["body"],
191
- # # # "msg": [
192
- # # # response,
193
- # # # formatted_prompt
194
- # # # ],
195
-
196
- # # }
197
- # # ]
198
- # }
199
- # ]
200
-
201
- # except Exception as e:
202
- # return [
203
- # {
204
- # "detail": [
205
- # {
206
- # "type": "server_error",
207
- # "loc": ["server"],
208
- # "msg": f"Error generating response: {str(e)}",
209
- # "input": None
210
- # }
211
- # ]
212
- # }
213
- # ]
214
-
215
  # @app.get("/health")
216
  # async def health_check():
217
  # return {
 
41
  )
42
 
43
  formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
44
+ stream = client.text_generation(
45
+ formatted_prompt,
46
+ **generate_kwargs,
47
+ stream=item.stream,
48
+ details=item.details,
49
+ return_full_text=item.return_full_text
50
+ )
51
 
52
+ output = ""
53
  for response in stream:
54
+ # Check if response has the attribute 'token'
55
+ if hasattr(response, 'token'):
56
+ output += response.token.text
57
+ else:
58
+ output += response # If not, treat it as a string
59
  return output
60
 
61
+ # def generate(item: Item):
62
+ # temperature = float(item.temperature)
63
+ # if temperature < 1e-2:
64
+ # temperature = 1e-2
65
+ # top_p = float(item.top_p)
66
+
67
+ # generate_kwargs = dict(
68
+ # temperature=temperature,
69
+ # max_new_tokens=1048,
70
+ # top_p=top_p,
71
+ # repetition_penalty=1.0,
72
+ # do_sample=True,
73
+ # seed=42,
74
+ # )
75
+
76
+ # formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
77
+ # stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=item.stream, details=item.details, return_full_text=item.return_full_text)
78
+ # output = ""
79
+
80
+ # for response in stream:
81
+ # output += response.token.text
82
+ # return output
83
+
84
  @app.post("/generate/")
85
  async def generate_text(item: Item):
86
  try:
 
89
  except Exception as e:
90
  raise HTTPException(status_code=500, detail=str(e))
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # @app.get("/health")
93
  # async def health_check():
94
  # return {