import json
import logging

import requests
import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Setup logging
logging.basicConfig(level=logging.INFO)
def check_server_health(cloud_gateway_api: str, header: dict) -> bool:
    """
    Probe the model info endpoint to check the server health.

    Args:
        cloud_gateway_api: API endpoint to probe.
        header: Header for Authorization.

    Returns:
        True if the server is active, False otherwise.
    """
    try:
        response = requests.get(
            cloud_gateway_api + "model/info",
            headers=header,
            verify=False,
        )
        response.raise_for_status()
        return True
    except requests.RequestException as e:
        logging.error(f"Failed to check server health: {e}")
        return False
def request_generation(
    header: dict,
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request streaming generation from the cloud gateway API. Uses the requests
    module with stream=True to consume token-by-token output from the LLM.

    Args:
        header: authorization header for the API.
        message: prompt from the user.
        system_prompt: system prompt to prepend to the conversation.
        cloud_gateway_api: API endpoint to send the request to.
        model_name: name of the model that serves the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
        temperature: the value used to modulate the next-token probabilities.
        frequency_penalty: penalty applied to tokens based on how often they already appear in the text. 0.0 means no penalty.
        presence_penalty: penalty applied to tokens that have already appeared in the text. 0.0 means no penalty.

    Yields:
        Generated text chunks as they are streamed back from the API.
    """
    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }
    try:
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()

        # Pass the returned conversation ID along in the X-Conversation-ID header
        header["X-Conversation-ID"] = response.json()["conversationId"]

        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
                if chunk:
                    # Decode the chunk from bytes to a string before parsing it as JSON
                    chunk_str = chunk.decode("utf-8")

                    # Strip the "data: " prefix, handling an occasional doubled prefix
                    for _ in range(2):
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: ") :]

                    # Stop once the stream signals completion
                    if chunk_str.strip() == "[DONE]":
                        break

                    # Parse the chunk into a JSON object
                    try:
                        chunk_json = json.loads(chunk_str)
                        # Extract the "content" field from the first choice's delta
                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get("content", "")
                        else:
                            content = ""
                        # Yield the generated content as it is streamed
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that cannot be decoded
                        continue
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."
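# Example usage: a minimal sketch of how the two helpers above could be wired
# together. The gateway URL, bearer token, and model name below are placeholder
# assumptions, not values taken from this Space's configuration.
if __name__ == "__main__":
    base_url = "https://example-gateway.local/api/v1/"  # hypothetical gateway base URL
    auth_header = {"Authorization": "Bearer <token>"}  # hypothetical credentials

    if check_server_health(base_url, auth_header):
        # Stream the response token by token and print it as it arrives
        for token in request_generation(
            header=auth_header,
            message="What is retrieval-augmented generation?",
            system_prompt="You are a helpful assistant.",
            cloud_gateway_api=base_url,
            model_name="example-model",  # assumed model identifier
        ):
            print(token, end="", flush=True)
    else:
        logging.error("Gateway is not reachable.")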