import json
import logging

import requests
import urllib3

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Set up logging
logging.basicConfig(level=logging.INFO)


def check_server_health(cloud_gateway_api: str, header: dict) -> bool:
    """
    Use the appropriate API endpoint to check the server health.

    Args:
        cloud_gateway_api: API endpoint to probe.
        header: header carrying the Authorization token.

    Returns:
        True if the server is active, False otherwise.
    """
    try:
        response = requests.get(
            cloud_gateway_api + "model/info",
            headers=header,
            verify=False,
        )
        response.raise_for_status()
        return True
    except requests.RequestException as e:
        logging.error(f"Failed to check server health: {e}")
        return False


def request_generation(
    header: dict,
    message: str,
    system_prompt: str,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request streaming generation from the cloud gateway API. Uses the
    requests module with stream=True to consume token-by-token output
    from the LLM.

    Args:
        header: authorization header for the API.
        message: prompt from the user.
        system_prompt: system prompt to prepend.
        cloud_gateway_api: API endpoint to send the request to.
        model_name: name of the model to route the request to.
        max_new_tokens: maximum number of tokens to generate, ignoring
            the number of tokens in the prompt.
        temperature: the value used to modulate the next-token probabilities.
        frequency_penalty: penalizes new tokens based on how frequently they
            have appeared so far. 0.0 means no penalty.
        presence_penalty: penalizes new tokens that have already appeared at
            least once. 0.0 means no penalty.

    Yields:
        Generated text content, one streamed chunk at a time.
    """
    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }

    try:
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        response.raise_for_status()

        # Pass the conversation ID to the streaming endpoint via the
        # X-Conversation-ID header.
        header["X-Conversation-ID"] = response.json()["conversationId"]

        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as stream_response:
            for chunk in stream_response.iter_lines():
                if chunk:
                    # Convert the chunk from bytes to a string.
                    chunk_str = chunk.decode("utf-8")

                    # Strip the "data: " prefix; it can appear more than once
                    # on a single line.
                    while chunk_str.startswith("data: "):
                        chunk_str = chunk_str[len("data: "):]

                    # Stop on the end-of-stream sentinel.
                    if chunk_str.strip() == "[DONE]":
                        break

                    # Parse the chunk as a JSON object.
                    try:
                        chunk_json = json.loads(chunk_str)
                        # Extract the "content" field from the choices.
                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get(
                                "content", ""
                            )
                        else:
                            content = ""
                        # Yield the generated content as it is streamed.
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that are not valid JSON.
                        continue

    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."
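

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The base URL, bearer token, and model
# name below are hypothetical placeholders, not values defined by this
# module or by any specific gateway deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    base_url = "https://gateway.example.com/api/v1/"  # hypothetical endpoint
    auth_header = {"Authorization": "Bearer <YOUR_TOKEN>"}  # placeholder token

    # Probe the gateway first, then stream a completion token by token.
    if check_server_health(base_url, auth_header):
        for token in request_generation(
            header=auth_header,
            message="Explain the benefits of streaming LLM responses.",
            system_prompt="You are a helpful assistant.",
            cloud_gateway_api=base_url,
            model_name="example-model",  # hypothetical model identifier
        ):
            print(token, end="", flush=True)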