3v324v23 committed on
Commit
af44f8f
·
1 Parent(s): 6168b0d

rate limiting changes

Browse files
Files changed (1) hide show
  1. app.py +55 -6
app.py CHANGED
@@ -32,13 +32,61 @@ from gradio_client import Client
32
  import json
33
  import threading
34
  import os
 
35
 
36
  API_TOKEN=os.getenv("API_TOKEN")
37
 
38
  lock = threading.Lock()
39
- client = Client("pi19404/ai-worker",hf_token=API_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- def my_inference_function(input_data, output_data,mode, max_length, max_new_tokens, model_size):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  """
43
  The main inference function to process input data and return results.
44
 
@@ -78,8 +126,9 @@ def my_inference_function(input_data, output_data,mode, max_length, max_new_toke
78
 
79
  with gr.Blocks() as demo:
80
  gr.Markdown("## LLM Safety Evaluation")
81
-
82
  with gr.Tab("ShieldGemma2"):
 
83
  input_text = gr.Textbox(label="Input Text")
84
  output_text = gr.Textbox(
85
  label="Response Text",
@@ -100,7 +149,7 @@ with gr.Blocks() as demo:
100
  elem_classes=["wrap-text"]
101
  )
102
  text_button = gr.Button("Submit")
103
- text_button.click(fn=my_inference_function, inputs=[input_text, output_text, mode_input, max_length_input, max_new_tokens_input, model_size_input], outputs=response_text)
104
 
105
  # with gr.Tab("API Input"):
106
  # api_input = gr.JSON(label="Input JSON")
@@ -112,7 +161,7 @@ with gr.Blocks() as demo:
112
  # api_button = gr.Button("Submit")
113
  # api_button.click(fn=my_inference_function, inputs=[api_input, api_output,mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)
114
 
115
- demo.launch(share=True)
116
-
117
 
 
118
 
 
32
  import json
33
  import threading
34
  import os
35
+ from collections import OrderedDict
36
 
37
  API_TOKEN=os.getenv("API_TOKEN")
38
 
39
  lock = threading.Lock()
40
+ #client = Client("pi19404/ai-worker",hf_token=API_TOKEN)
41
+ # Create an OrderedDict to store clients, limited to 15 entries
42
+ client_cache = OrderedDict()
43
+ MAX_CACHE_SIZE = 15
44
+ default_client=Client("pi19404/ai-worker", hf_token=API_TOKEN)
45
+ def get_client_for_ip(ip_address,x_ip_token):
46
+ if x_ip_token is None:
47
+ x_ip_token=ip_address
48
+
49
+ #print("ipaddress is ",x_ip_token)
50
+ if x_ip_token is None:
51
+ new_client=default_client
52
+ else:
53
+
54
+ if x_ip_token in client_cache:
55
+ # Move the accessed item to the end (most recently used)
56
+ client_cache.move_to_end(x_ip_token)
57
+ return client_cache[x_ip_token]
58
+ # Create a new client
59
+ new_client = Client("pi19404/ai-worker", hf_token=API_TOKEN, headers={"X-IP-Token": x_ip_token})
60
+ # Add to cache, removing oldest if necessary
61
+ if len(client_cache) >= MAX_CACHE_SIZE:
62
+ client_cache.popitem(last=False)
63
+ client_cache[x_ip_token] = new_client
64
 
65
+
66
+ return new_client
67
+
68
+ def set_client_for_session(request: gr.Request):
69
+ # Collect all headers in a dictionary
70
+ all_headers = {header: value for header, value in request.headers.items()}
71
+
72
+ # Print headers to console
73
+ print("All request headers:")
74
+ print(json.dumps(all_headers, indent=2))
75
+
76
+ x_ip_token = request.headers.get('x-ip-token',None)
77
+ ip_address = request.client.host
78
+ print("ip address is ",ip_address)
79
+
80
+ client = get_client_for_ip(ip_address,x_ip_token)
81
+
82
+ # Return both the client and the headers
83
+ return client, json.dumps(all_headers, indent=2)
84
+
85
+
86
+ # The "pi19404/ai-worker" space targeted above is a ZeroGPU space
87
+
88
+
89
+ def my_inference_function(client,input_data, output_data,mode, max_length, max_new_tokens, model_size):
90
  """
91
  The main inference function to process input data and return results.
92
 
 
126
 
127
  with gr.Blocks() as demo:
128
  gr.Markdown("## LLM Safety Evaluation")
129
+ client = gr.State()
130
  with gr.Tab("ShieldGemma2"):
131
+
132
  input_text = gr.Textbox(label="Input Text")
133
  output_text = gr.Textbox(
134
  label="Response Text",
 
149
  elem_classes=["wrap-text"]
150
  )
151
  text_button = gr.Button("Submit")
152
+ text_button.click(fn=my_inference_function, inputs=[client,input_text, output_text, mode_input, max_length_input, max_new_tokens_input, model_size_input], outputs=response_text)
153
 
154
  # with gr.Tab("API Input"):
155
  # api_input = gr.JSON(label="Input JSON")
 
161
  # api_button = gr.Button("Submit")
162
  # api_button.click(fn=my_inference_function, inputs=[api_input, api_output,mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)
163
 
164
+ demo.load(set_client_for_session,None,client)
 
165
 
166
+ demo.launch(share=True)
167