"""
License:
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
In no event shall the authors or copyright holders be liable
for any claim, damages or other liability, whether in an action of contract,
tort or otherwise, arising from, out of or in connection with the software
or the use or other dealings in the software.

Copyright (c) 2024 pi19404. All rights reserved.

Authors:
    pi19404 <pi19404@gmail.com>
"""


"""
Gradio Interface for Shield Gemma LLM Evaluator

This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator.
It allows users to input JSON data and select various options to evaluate the content
for policy violations.

Functions:
    my_inference_function: The main inference function to process input data and return results.
"""

import json
import os
import threading
import time
from collections import OrderedDict

import gradio as gr
import httpx
from gradio_client import Client

API_TOKEN = os.getenv("API_TOKEN")

lock = threading.Lock()

# LRU cache of Gradio clients keyed by client identity, limited to 15 entries.
client_cache = OrderedDict()
MAX_CACHE_SIZE = 15
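# A minimal sketch of the LRU pattern the cache relies on (illustrative only;
# the real eviction logic lives in get_client_for_ip below):
#
#   cache = OrderedDict()
#   cache["a"], cache["b"] = 1, 2
#   cache.move_to_end("a")        # mark "a" as most recently used
#   cache.popitem(last=False)     # evict "b", the least recently used entry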


def my_inference_function(client, input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    The main inference function to process input data and return results.

    Args:
        client (tuple): Session state whose first element is the Gradio client to use.
        input_data (str or dict): The input data in JSON format.
        output_data (str or dict): The model response to evaluate, if any.
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used ("2B", "9B", or "27B").

    Returns:
        str: The output data in JSON format.
    """
    with lock:
        try:
            result = client[0].predict(
                    input_data=input_data,
                    output_data=output_data,
                    mode=mode,
                    max_length=max_length,
                    max_new_tokens=max_new_tokens,
                    model_size=model_size,
                    api_name="/my_inference_function"
            )
            return result
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
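
# Example call (hypothetical payload; the exact JSON schema is defined by the
# pi19404/ai-worker Space, and default_client is created further below):
#
#   result = my_inference_function(
#       (default_client,),
#       input_data='{"text": "Is this message safe?"}',
#       output_data="",
#       mode="scoring",
#       max_length=150,
#       max_new_tokens=1024,
#       model_size="2B",
#   )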


def wake_up_space_with_retries(space_url, token, retries=5, wait_time=10):
    """
    Attempt to wake up the Hugging Face Space, retrying a number of times
    in case of a delay while the Space wakes from sleep.

    Args:
        space_url (str): The repo id or URL of the Hugging Face Space.
        token (str): The Hugging Face API token.
        retries (int): Number of retries if the Space is sleeping.
        wait_time (int): Time to wait between retries (in seconds).

    Returns:
        Client: A ready Gradio client, or None if the Space never woke up.
    """
    for attempt in range(retries):
        try:
            print(f"Attempt {attempt + 1} to wake up the Space...")

            # Initialize the Gradio client with a 30-second timeout
            client = Client(space_url, hf_token=token, timeout=httpx.Timeout(30.0))

            # Make a lightweight call to wake the Space; the client is wrapped
            # in a tuple because my_inference_function indexes it with [0].
            my_inference_function((client,), "test input", "", "scoring", 10, 10, "2B")

            print("Space is awake and ready!")
            return client

        except httpx.ReadTimeout:
            print(f"Request timed out on attempt {attempt + 1}.")

        except Exception as e:
            print(f"An error occurred on attempt {attempt + 1}: {e}")

        # Wait before retrying
        if attempt < retries - 1:
            print(f"Waiting for {wait_time} seconds before retrying...")
            time.sleep(wait_time)

    print("Space is still not active after multiple attempts.")
    return None
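
# Example: wait up to roughly five minutes for a cold Space (hypothetical
# values; retries and wait_time are the parameters defined above):
#
#   client = wake_up_space_with_retries("pi19404/ai-worker", API_TOKEN,
#                                       retries=10, wait_time=30)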


default_client = wake_up_space_with_retries("pi19404/ai-worker", API_TOKEN)

def get_client_for_ip(ip_address,x_ip_token):
    """
    Retrieve or create a client for the given IP address.

    This function implements a caching mechanism to store up to MAX_CACHE_SIZE clients.
    If a client for the given IP exists in the cache, it's returned and moved to the end
    of the cache (marking it as most recently used). If not, a new client is created,
    added to the cache, and the least recently used client is removed if the cache is full.

    Args:
        ip_address (str): The IP address of the client.
        x_ip_token (str): The X-IP-Token header value for the client.

    Returns:
        Client: A Gradio client instance for the given IP address.
    """

    if x_ip_token is None:
        x_ip_token = ip_address

    if x_ip_token is None:
        # No usable identifier at all; fall back to the shared default client.
        new_client = default_client
    else:
        if x_ip_token in client_cache:
            # Move the accessed item to the end (most recently used)
            client_cache.move_to_end(x_ip_token)
            return client_cache[x_ip_token]

        # Create a new client
        new_client = Client("pi19404/ai-worker", hf_token=API_TOKEN, headers={"X-IP-Token": x_ip_token})
        # Add to cache, evicting the least recently used entry if necessary
        if len(client_cache) >= MAX_CACHE_SIZE:
            client_cache.popitem(last=False)
        client_cache[x_ip_token] = new_client

    return new_client

def set_client_for_session(request: gr.Request):
    """
    Set up a client for the current session and collect request headers.

    This function is called when a new session is initiated. It retrieves or creates
    a client for the session's IP address and collects all request headers for debugging.

    Args:
        request (gr.Request): The Gradio request object for the current session.

    Returns:
        tuple: A tuple containing:
            - Client: The Gradio client instance for the session.
            - str: A JSON string of all request headers.
    """

    # Collect all headers in a dictionary
    all_headers = {header: value for header, value in request.headers.items()}
    
    # Print headers to console
    print("All request headers:")
    print(json.dumps(all_headers, indent=2))

    x_ip_token = request.headers.get('x-ip-token', None)
    ip_address = request.client.host
    print("IP address is", ip_address)

    client = get_client_for_ip(ip_address, x_ip_token)
    
    # Return both the client and the headers
    return client, json.dumps(all_headers, indent=2)


    # The "gradio/text-to-image" space is a ZeroGPU space
    
    


with gr.Blocks() as demo:
    """
    Main Gradio interface setup.

    This block sets up the Gradio interface, including:
    - A State component to store the client for the session.
    - A JSON component to display request headers for debugging.
    - Other UI components (not shown in this snippet).
    - A load event that calls set_client_for_session when the interface is loaded.
    """

    gr.Markdown("## LLM Safety Evaluation")
    client = gr.State()
    with gr.Tab("ShieldGemma2"):
        
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        mode_input = gr.Dropdown(choices=["scoring", "generative"], label="Prediction Mode")
        max_length_input = gr.Number(label="Max Length", value=150)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024)
        model_size_input = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[client, input_text, output_text, mode_input, max_length_input, max_new_tokens_input, model_size_input],
            outputs=response_text,
        )
    
    # with gr.Tab("API Input"):
    #     api_input = gr.JSON(label="Input JSON")
    #     mode_input_api = gr.Dropdown(choices=["scoring", "generative"], label="Mode")
    #     max_length_input_api = gr.Number(label="Max Length", value=150)
    #     max_new_tokens_input_api = gr.Number(label="Max New Tokens", value=None)
    #     model_size_input_api = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
    #     api_output = gr.JSON(label="Output JSON")
    #     api_button = gr.Button("Submit")
    #     api_button.click(fn=my_inference_function, inputs=[api_input, api_output,mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)

    demo.load(set_client_for_session, None, client)

demo.launch(share=True)