"""
License:
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
In no event shall the authors or copyright holders be liable
for any claim, damages or other liability, whether in an action of contract, tort or otherwise,
arising from, out of or in connection with the software or the use or
other dealings in the software.
Copyright (c) 2024 pi19404. All rights reserved.
Authors:
pi19404 <[email protected]>
"""
"""
Gradio Interface for Shield Gemma LLM Evaluator
This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator.
It allows users to input JSON data and select various options to evaluate the content
for policy violations.
Functions:
my_inference_function: The main inference function to process input data and return results.
"""
import gradio as gr
from gradio_client import Client
import json
import threading
import os
from collections import OrderedDict
API_TOKEN = os.getenv("API_TOKEN")
lock = threading.Lock()

#client = Client("pi19404/ai-worker", hf_token=API_TOKEN)
# Create an OrderedDict to store clients, limited to 15 entries
client_cache = OrderedDict()
MAX_CACHE_SIZE = 15

default_client = Client("pi19404/ai-worker", hf_token=API_TOKEN)
def get_client_for_ip(ip_address, x_ip_token):
    """Return a cached Client for this caller, creating one if needed.

    The per-caller key is the X-IP-Token header when present, otherwise the
    client IP address. Clients are kept in a small LRU cache.
    """
    if x_ip_token is None:
        x_ip_token = ip_address
    #print("ipaddress is ", x_ip_token)
    if x_ip_token is None:
        # No usable key: fall back to the shared default client
        new_client = default_client
    else:
        if x_ip_token in client_cache:
            # Move the accessed item to the end (most recently used)
            client_cache.move_to_end(x_ip_token)
            return client_cache[x_ip_token]
        # Create a new client with the caller's token forwarded
        new_client = Client("pi19404/ai-worker", hf_token=API_TOKEN, headers={"X-IP-Token": x_ip_token})
        # Add to cache, removing the oldest entry if necessary
        if len(client_cache) >= MAX_CACHE_SIZE:
            client_cache.popitem(last=False)
        client_cache[x_ip_token] = new_client
    return new_client
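# Illustration (commented out) of the OrderedDict-based LRU pattern used by
# get_client_for_ip: a lookup moves the key to the end, and popitem(last=False)
# evicts the least recently used entry once the cache is full. The keys and
# values here are placeholders.
#
#   cache = OrderedDict([("a", 1), ("b", 2), ("c", 3)])
#   cache.move_to_end("a")       # "a" becomes the most recently used entry
#   cache.popitem(last=False)    # evicts "b", now the oldest entry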
def set_client_for_session(request: gr.Request):
    """Resolve the per-session Client from the incoming Gradio request."""
    # Collect all request headers and log them to the console
    all_headers = {header: value for header, value in request.headers.items()}
    print("All request headers:")
    print(json.dumps(all_headers, indent=2))

    x_ip_token = request.headers.get('x-ip-token', None)
    ip_address = request.client.host
    print("ip address is", ip_address)

    # Return only the client; demo.load() stores it in the gr.State component
    return get_client_for_ip(ip_address, x_ip_token)
# The "gradio/text-to-image" space is a ZeroGPU space
def my_inference_function(client, input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    The main inference function to process input data and return results.

    Args:
        client (Client): The per-session gradio_client Client used to reach the worker Space.
        input_data (str or dict): The input data in JSON format.
        output_data (str or dict): The model response to evaluate (may be empty).
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used.

    Returns:
        str: The output data in JSON format.
    """
    with lock:
        try:
            result = client.predict(
                input_data=input_data,
                output_data=output_data,
                mode=mode,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                model_size=model_size,
                api_name="/my_inference_function"
            )
            print(result)
            return result
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
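# Commented-out local sanity check for my_inference_function. It goes through
# the shared default_client, so it would hit the remote worker Space; the input
# string and settings are illustrative guesses rather than a required format.
#
#   sample = my_inference_function(
#       default_client,
#       input_data="<user prompt to screen>",
#       output_data="",
#       mode="scoring",
#       max_length=150,
#       max_new_tokens=1024,
#       model_size="2B",
#   )
#   print(sample)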
with gr.Blocks() as demo:
    gr.Markdown("## LLM Safety Evaluation")
    client = gr.State()

    with gr.Tab("ShieldGemma2"):
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        mode_input = gr.Dropdown(choices=["scoring", "generative"], label="Prediction Mode")
        max_length_input = gr.Number(label="Max Length", value=150)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024)
        model_size_input = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[client, input_text, output_text, mode_input, max_length_input, max_new_tokens_input, model_size_input],
            outputs=response_text
        )
# with gr.Tab("API Input"):
# api_input = gr.JSON(label="Input JSON")
# mode_input_api = gr.Dropdown(choices=["scoring", "generative"], label="Mode")
# max_length_input_api = gr.Number(label="Max Length", value=150)
# max_new_tokens_input_api = gr.Number(label="Max New Tokens", value=None)
# model_size_input_api = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
# api_output = gr.JSON(label="Output JSON")
# api_button = gr.Button("Submit")
# api_button.click(fn=my_inference_function, inputs=[api_input, api_output,mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)
    demo.load(set_client_for_session, None, client)

demo.launch(share=True)