import json
import logging
import requests
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Set up logging
logging.basicConfig(level=logging.INFO)


def check_server_health(cloud_gateway_api: str, header: dict) -> bool:
    """
    Probe the gateway's model-info endpoint to check whether the server is reachable.

    Args:
        cloud_gateway_api: Base URL of the cloud gateway API to probe.
        header: Authorization header for the request.

    Returns:
        True if the server responds successfully, False otherwise.
    """
    try:
        response = requests.get(
            cloud_gateway_api + "model/info",
            headers=header,
            verify=False,
        )
        response.raise_for_status()
        return True
    except requests.RequestException as e:
        logging.error(f"Failed to check server health: {e}")
        return False
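# Illustrative call, for reference only: the gateway URL and bearer token below are
# placeholders, not values defined in this module.
#
#     headers = {"Authorization": "Bearer <token>"}
#     if check_server_health("https://<gateway-host>/api/", headers):
#         logging.info("Cloud gateway is reachable.")

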
def request_generation(
    header: dict,
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request a streaming generation from the cloud gateway API. Uses the requests
    module with stream=True so tokens can be yielded as the LLM produces them.

    Args:
        header: Authorization header for the API.
        message: Prompt from the user.
        system_prompt: System prompt to set for the conversation.
        cloud_gateway_api: API endpoint to send the request to.
        model_name: Name of the model that should serve the request.
        max_new_tokens: Maximum number of tokens to generate, ignoring the number
            of tokens in the prompt.
        temperature: Value used to modulate the next-token probabilities.
        frequency_penalty: Penalty applied to tokens proportional to how often they
            already appear in the text; 0.0 means no penalty.
        presence_penalty: Penalty applied to tokens that have appeared at all;
            0.0 means no penalty.

    Yields:
        The generated text, one streamed chunk at a time.
    """
    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }
    try:
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()
        # Attach the conversation ID returned by the gateway (key X-Conversation-ID)
        # so the streaming endpoint knows which conversation to read from.
        header["X-Conversation-ID"] = response.json()["conversationId"]
        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
                if chunk:
                    # Decode the chunk from bytes to a string before parsing it as JSON.
                    chunk_str = chunk.decode("utf-8")
                    # Strip the `data: ` prefix if it exists (the loop runs twice in
                    # case the prefix appears doubled).
                    for _ in range(2):
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: ") :]
                    # Stop when the stream signals completion.
                    if chunk_str.strip() == "[DONE]":
                        break
                    # Parse the chunk into a JSON object.
                    try:
                        chunk_json = json.loads(chunk_str)
                        # Extract the "content" field from the first choice's delta.
                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get(
                                "content", ""
                            )
                        else:
                            content = ""
                        # Yield the generated content as it is streamed.
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that cannot be decoded as JSON.
                        continue
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."