# gateway.py
import json
import logging
import requests
import urllib3
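# TLS verification is disabled for every gateway request below (verify=False),
# so suppress the InsecureRequestWarning that requests would otherwise emit.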
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Set up logging.
logging.basicConfig(level=logging.INFO)

def check_server_health(cloud_gateway_api: str, header: dict) -> bool:
    """
    Probe the gateway's model-info endpoint to check server health.

    Args:
        cloud_gateway_api: Base URL of the gateway API to probe.
        header: Authorization header for the request.

    Returns:
        True if the server is reachable and healthy, False otherwise.
    """
try:
        response = requests.get(
            cloud_gateway_api + "model/info",
            headers=header,
            verify=False,
            timeout=10,  # fail fast instead of hanging if the gateway is unreachable
        )
response.raise_for_status()
return True
except requests.RequestException as e:
logging.error(f"Failed to check server health: {e}")
return False
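

# Example usage (hypothetical gateway URL and token, for illustration only):
#
#     header = {"Authorization": "Bearer <API_TOKEN>"}
#     if check_server_health("https://gateway.example.com/api/v1/", header):
#         print("gateway is up")
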
def request_generation(
header: dict,
message: str,
system_prompt: str,
cloud_gateway_api: str,
model_name: str,
max_new_tokens: int = 1024,
temperature: float = 0.3,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
):
"""
Request streaming generation from the cloud gateway API. Uses the simple requests module with stream=True to utilize
token-by-token generation from LLM.
Args:
header: authorization header for the API.
message: prompt from the user.
system_prompt: system prompt to append.
cloud_gateway_api (str): API endpoint to send the request.
max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in the prompt.
temperature: the value used to module the next token probabilities.
top_p: if set to float<1, only the smallest set of most probable tokens with probabilities that add up to top_p
or higher are kept for generation.
repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
Returns:
"""
payload = {
"model": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": message},
],
"max_tokens": max_new_tokens,
"temperature": temperature,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"stream": True, # Enable streaming
"serving_runtime": "vllm",
}
try:
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
            timeout=10,  # fail fast if the gateway does not accept the request
        )
response.raise_for_status()
        # The gateway returns a conversation ID; forward it on the streaming
        # request via the X-Conversation-ID header.
        header["X-Conversation-ID"] = response.json()["conversationId"]
        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            response.raise_for_status()
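            # The stream uses OpenAI-style server-sent events: each line looks
            # like "data: {json}" and the stream terminates with "data: [DONE]".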
for chunk in response.iter_lines():
if chunk:
                    # Decode the chunk from bytes to a string.
                    chunk_str = chunk.decode("utf-8")
                    # Strip any "data: " prefix (events can arrive double-wrapped,
                    # so keep stripping while the prefix is present).
                    while chunk_str.startswith("data: "):
                        chunk_str = chunk_str[len("data: ") :]
                    # Stop once the end-of-stream sentinel arrives.
                    if chunk_str.strip() == "[DONE]":
                        break
# Parse the chunk into a JSON object
try:
chunk_json = json.loads(chunk_str)
# Extract the "content" field from the choices
if "choices" in chunk_json and chunk_json["choices"]:
content = chunk_json["choices"][0]["delta"].get(
"content", ""
)
else:
content = ""
                        # Yield the generated content as it arrives.
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that are not valid JSON (e.g., keep-alive lines).
                        continue
except requests.RequestException as e:
logging.error(f"Failed to generate response: {e}")
yield "Server not responding. Please try again later."