alex-remade commited on
Commit
8f58f9b
·
1 Parent(s): 0f05fa7

feat: implement rate limiting for video generation requests

Browse files
Files changed (1) hide show
  1. app.py +89 -1
app.py CHANGED
@@ -16,6 +16,8 @@ from openai import OpenAI
16
  import base64
17
  from google.cloud import vision
18
  from google.oauth2 import service_account
 
 
19
 
20
  dotenv.load_dotenv()
21
 
@@ -25,6 +27,13 @@ SCRIPT_DIR = Path(__file__).parent
25
  MODAL_ENDPOINT = os.getenv('FAL_MODAL_ENDPOINT')
26
  MODAL_AUTH_TOKEN = os.getenv('MODAL_AUTH_TOKEN')
27
 
 
 
 
 
 
 
 
28
  loras = [
29
  {
30
  "image": "https://huggingface.co/Remade-AI/Crash-zoom-out/resolve/main/example_videos/1.gif",
@@ -576,11 +585,30 @@ def update_selection(evt: gr.SelectData):
576
  sentence = f"Selected LoRA: {selected_lora['title']}"
577
  return selected_lora['id'], sentence
578
 
579
- async def handle_generation(image_input, subject, selected_index, progress=gr.Progress(track_tqdm=True)):
580
  try:
581
  if selected_index is None:
582
  raise gr.Error("You must select a LoRA before proceeding.")
583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  # First, moderate the prompt
585
  prompt_moderation = await moderate_prompt(subject)
586
  print(f"Prompt moderation result: {prompt_moderation}")
@@ -823,6 +851,30 @@ css = '''
823
  }
824
  '''
825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
826
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate", text_size="lg")) as demo:
827
  selected_index = gr.State(None)
828
  current_generation_id = gr.State(None)
@@ -892,6 +944,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="
892
 
893
  subject = gr.Textbox(label="Describe your subject", placeholder="Cat toy")
894
 
 
 
 
895
  with gr.Row():
896
  button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
897
  audio_button = gr.Button("Add Audio 🔒", interactive=False)
@@ -997,6 +1052,13 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="
997
  inputs=None,
998
  outputs=None
999
  )
 
 
 
 
 
 
 
1000
 
1001
  def get_modal_auth_headers():
1002
  """Get authentication headers for Modal API requests"""
@@ -1008,6 +1070,32 @@ def get_modal_auth_headers():
1008
  'Content-Type': 'application/json'
1009
  }
1010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1011
  if __name__ == "__main__":
1012
  demo.queue(default_concurrency_limit=20)
1013
  demo.launch(ssr_mode=False, share=True)
 
16
  import base64
17
  from google.cloud import vision
18
  from google.oauth2 import service_account
19
+ import time
20
+ from collections import defaultdict, deque
21
 
22
  dotenv.load_dotenv()
23
 
 
27
  MODAL_ENDPOINT = os.getenv('FAL_MODAL_ENDPOINT')
28
  MODAL_AUTH_TOKEN = os.getenv('MODAL_AUTH_TOKEN')
29
 
30
+ # Rate limiting configuration
31
+ RATE_LIMIT_GENERATIONS = int(os.getenv('RATE_LIMIT_GENERATIONS', '5')) # Default 5 generations per hour
32
+ RATE_LIMIT_WINDOW = int(os.getenv('RATE_LIMIT_WINDOW', '3600')) # Default 1 hour in seconds
33
+
34
+ # In-memory rate limiting storage (for production, consider Redis)
35
+ user_generations = defaultdict(deque)
36
+
37
  loras = [
38
  {
39
  "image": "https://huggingface.co/Remade-AI/Crash-zoom-out/resolve/main/example_videos/1.gif",
 
585
  sentence = f"Selected LoRA: {selected_lora['title']}"
586
  return selected_lora['id'], sentence
587
 
588
+ async def handle_generation(image_input, subject, selected_index, request: gr.Request, progress=gr.Progress(track_tqdm=True)):
589
  try:
590
  if selected_index is None:
591
  raise gr.Error("You must select a LoRA before proceeding.")
592
 
593
+ # Check rate limit first
594
+ user_identifier = get_user_identifier(request)
595
+ is_allowed, remaining, reset_time = check_rate_limit(user_identifier)
596
+
597
+ if not is_allowed:
598
+ minutes = reset_time // 60
599
+ seconds = reset_time % 60
600
+ time_str = f"{minutes}m {seconds}s" if minutes > 0 else f"{seconds}s"
601
+ # Re-enable button on rate limit
602
+ yield None, None, gr.update(visible=False), gr.update(value="Generate", interactive=True)
603
+ raise gr.Error(f"Rate limit exceeded. Go to https://app.remade.ai for more generations and effects. Otherwise, you can generate {RATE_LIMIT_GENERATIONS} videos per hour. Try again in {time_str}.")
604
+
605
+ # Record this generation attempt
606
+ record_generation(user_identifier)
607
+
608
+ # Show remaining generations to user
609
+ if remaining > 0:
610
+ print(f"User {user_identifier} has {remaining} generations remaining this hour")
611
+
612
  # First, moderate the prompt
613
  prompt_moderation = await moderate_prompt(subject)
614
  print(f"Prompt moderation result: {prompt_moderation}")
 
851
  }
852
  '''
853
 
854
+ def get_user_identifier(request: gr.Request) -> str:
855
+ """Get user identifier from request (IP address)"""
856
+ if request and hasattr(request, 'client') and hasattr(request.client, 'host'):
857
+ return request.client.host
858
+ return "unknown"
859
+
860
+ def get_rate_limit_status(request: gr.Request) -> str:
861
+ """Get current rate limit status for display to user"""
862
+ try:
863
+ user_identifier = get_user_identifier(request)
864
+ is_allowed, remaining, reset_time = check_rate_limit(user_identifier)
865
+
866
+ if remaining == 0 and reset_time > 0:
867
+ minutes = reset_time // 60
868
+ seconds = reset_time % 60
869
+ time_str = f"{minutes}m {seconds}s" if minutes > 0 else f"{seconds}s"
870
+ return f"⚠️ Rate limit reached. Try again in {time_str}"
871
+ elif remaining <= 2:
872
+ return f"⚡ {remaining} generations remaining this hour"
873
+ else:
874
+ return f"✅ {remaining} generations remaining this hour"
875
+ except:
876
+ return "✅ Ready to generate"
877
+
878
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate", text_size="lg")) as demo:
879
  selected_index = gr.State(None)
880
  current_generation_id = gr.State(None)
 
944
 
945
  subject = gr.Textbox(label="Describe your subject", placeholder="Cat toy")
946
 
947
+ # Rate limit status display
948
+ rate_limit_status = gr.Markdown("✅ Ready to generate", elem_id="rate_limit_status")
949
+
950
  with gr.Row():
951
  button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
952
  audio_button = gr.Button("Add Audio 🔒", interactive=False)
 
1052
  inputs=None,
1053
  outputs=None
1054
  )
1055
+
1056
+ # Update rate limit status on page load
1057
+ demo.load(
1058
+ fn=get_rate_limit_status,
1059
+ inputs=None,
1060
+ outputs=[rate_limit_status]
1061
+ )
1062
 
1063
  def get_modal_auth_headers():
1064
  """Get authentication headers for Modal API requests"""
 
1070
  'Content-Type': 'application/json'
1071
  }
1072
 
1073
+ def check_rate_limit(user_identifier: str) -> tuple[bool, int, int]:
1074
+ """
1075
+ Check if user has exceeded rate limit
1076
+ Returns: (is_allowed, remaining_generations, reset_time_seconds)
1077
+ """
1078
+ current_time = time.time()
1079
+ user_queue = user_generations[user_identifier]
1080
+
1081
+ # Remove old entries outside the time window
1082
+ while user_queue and current_time - user_queue[0] > RATE_LIMIT_WINDOW:
1083
+ user_queue.popleft()
1084
+
1085
+ # Check if user has exceeded limit
1086
+ if len(user_queue) >= RATE_LIMIT_GENERATIONS:
1087
+ # Calculate when the oldest entry will expire
1088
+ reset_time = int(user_queue[0] + RATE_LIMIT_WINDOW - current_time)
1089
+ return False, 0, reset_time
1090
+
1091
+ remaining = RATE_LIMIT_GENERATIONS - len(user_queue)
1092
+ return True, remaining, 0
1093
+
1094
+ def record_generation(user_identifier: str):
1095
+ """Record a new generation for the user"""
1096
+ current_time = time.time()
1097
+ user_generations[user_identifier].append(current_time)
1098
+
1099
  if __name__ == "__main__":
1100
  demo.queue(default_concurrency_limit=20)
1101
  demo.launch(ssr_mode=False, share=True)