import logging
import os
from typing import Iterator

import gradio as gr

from dialog import get_dialog_box
from gateway import check_server_health, request_generation

# Set up logging
logging.basicConfig(level=logging.INFO)

# CONSTANTS
# Maximum number of new tokens; defaults to 2048 if the environment variable is not set.
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", 2048))

# Validate environment variables
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
    raise EnvironmentError("API_ENDPOINT is not set.")

MODEL_NAME: str = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise EnvironmentError("MODEL_NAME is not set.")

# Get API Key
API_KEY = os.getenv("API_KEY")
if not API_KEY:  # simple presence check; the gateway rejects invalid keys
    raise EnvironmentError("API_KEY is not set.")

# Build the request header once instead of declaring it multiple times
HEADER = {"x-api-key": API_KEY}


def toggle_ui():
    """Toggle the visibility of the main UI and the error dialog based on server health.

    Returns:
        Updates that show the main UI and hide the dialog when the server is healthy,
        and the reverse when it is not.
    """
    health = check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API, header=HEADER)
    if health:
        return gr.update(visible=True), gr.update(visible=False)  # Show main UI, hide dialog
    return gr.update(visible=False), gr.update(visible=True)  # Hide main UI, show dialog


def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming response, and emit it to the UI.

    Args:
        message (str): input message from the user
        chat_history (list[tuple[str, str]]): entire chat history of the session
        system_prompt (str): system prompt
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the
            number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next-token
            probabilities. Defaults to 0.6.
        frequency_penalty (float, optional): penalizes tokens proportionally to how often
            they have already appeared in the generated text. Defaults to 0.0.
        presence_penalty (float, optional): penalizes tokens that have appeared at all in
            the generated text so far. Defaults to 0.0.

    Yields:
        Iterator[str]: streaming responses to the UI
    """
    # Accumulate streamed chunks and yield the growing response so the UI updates incrementally
    outputs = []
    for text in request_generation(
        header=HEADER,
        message=message,
        system_prompt=system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            value=(
                "You are a highly capable AI assistant. Provide accurate, concise, and "
                "fact-based responses that are directly relevant to the user's query. "
                "Avoid speculation, ensure logical consistency, and maintain clarity in "
                "longer outputs. Keep answers well-structured and under 1200 tokens "
                "unless explicitly requested otherwise."
            ),
            lines=3,
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=2048,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Frequency penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
        gr.Slider(
            label="Presence penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Plan a three-day trip to Washington DC for Cherry Blossom Festival."],
        ["Compose a short, joyful musical piece for kids celebrating spring sunshine and blossom."],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True) as demo:
    # Get the server status before displaying the UI
    visibility = check_server_health(CLOUD_GATEWAY_API, header=HEADER)

    # Container for the main interface
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown(
            """
            # Gemma 3 27b Instruct
            This Space is an Alpha release that demonstrates the
            [Gemma-3-27B-It](https://huggingface.co/google/gemma-3-27b-it) model running on
            AMD MI300 infrastructure. The Space is provided under the Google Gemma 3
            [License](https://ai.google.dev/gemma/terms). Feel free to play with it!
            """
        )
        chat_interface.render()

    # Dialog box using Markdown for the error message
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message
        get_dialog_box()

    # Timer to check server health every 10 seconds and update the UI
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])

if __name__ == "__main__":
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()
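
# A minimal sketch of the environment this script expects before launch. The variable
# names come from the code above; all values below are placeholders (hypothetical host,
# key, and limits), not real endpoints or credentials:
#
#   export API_ENDPOINT="https://<your-gateway-host>/v1"
#   export MODEL_NAME="google/gemma-3-27b-it"
#   export API_KEY="<your-api-key>"
#   export MAX_NEW_TOKENS=2048
#   export QUEUE=20               # hypothetical queue size
#   export CONCURRENCY_LIMIT=4    # hypothetical concurrency limit
#
#   python app.py                 # assuming this file is saved as app.py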