Spaces:
Running
Running
Commit
·
8f58f9b
1
Parent(s):
0f05fa7
feat: implement rate limiting for video generation requests
Browse files
app.py
CHANGED
@@ -16,6 +16,8 @@ from openai import OpenAI
|
|
16 |
import base64
|
17 |
from google.cloud import vision
|
18 |
from google.oauth2 import service_account
|
|
|
|
|
19 |
|
20 |
dotenv.load_dotenv()
|
21 |
|
@@ -25,6 +27,13 @@ SCRIPT_DIR = Path(__file__).parent
|
|
25 |
MODAL_ENDPOINT = os.getenv('FAL_MODAL_ENDPOINT')
|
26 |
MODAL_AUTH_TOKEN = os.getenv('MODAL_AUTH_TOKEN')
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
loras = [
|
29 |
{
|
30 |
"image": "https://huggingface.co/Remade-AI/Crash-zoom-out/resolve/main/example_videos/1.gif",
|
@@ -576,11 +585,30 @@ def update_selection(evt: gr.SelectData):
|
|
576 |
sentence = f"Selected LoRA: {selected_lora['title']}"
|
577 |
return selected_lora['id'], sentence
|
578 |
|
579 |
-
async def handle_generation(image_input, subject, selected_index, progress=gr.Progress(track_tqdm=True)):
|
580 |
try:
|
581 |
if selected_index is None:
|
582 |
raise gr.Error("You must select a LoRA before proceeding.")
|
583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
# First, moderate the prompt
|
585 |
prompt_moderation = await moderate_prompt(subject)
|
586 |
print(f"Prompt moderation result: {prompt_moderation}")
|
@@ -823,6 +851,30 @@ css = '''
|
|
823 |
}
|
824 |
'''
|
825 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
826 |
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate", text_size="lg")) as demo:
|
827 |
selected_index = gr.State(None)
|
828 |
current_generation_id = gr.State(None)
|
@@ -892,6 +944,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="
|
|
892 |
|
893 |
subject = gr.Textbox(label="Describe your subject", placeholder="Cat toy")
|
894 |
|
|
|
|
|
|
|
895 |
with gr.Row():
|
896 |
button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
|
897 |
audio_button = gr.Button("Add Audio 🔒", interactive=False)
|
@@ -997,6 +1052,13 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="
|
|
997 |
inputs=None,
|
998 |
outputs=None
|
999 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1000 |
|
1001 |
def get_modal_auth_headers():
|
1002 |
"""Get authentication headers for Modal API requests"""
|
@@ -1008,6 +1070,32 @@ def get_modal_auth_headers():
|
|
1008 |
'Content-Type': 'application/json'
|
1009 |
}
|
1010 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1011 |
if __name__ == "__main__":
|
1012 |
demo.queue(default_concurrency_limit=20)
|
1013 |
demo.launch(ssr_mode=False, share=True)
|
|
|
16 |
import base64
|
17 |
from google.cloud import vision
|
18 |
from google.oauth2 import service_account
|
19 |
+
import time
|
20 |
+
from collections import defaultdict, deque
|
21 |
|
22 |
dotenv.load_dotenv()
|
23 |
|
|
|
27 |
MODAL_ENDPOINT = os.getenv('FAL_MODAL_ENDPOINT')
|
28 |
MODAL_AUTH_TOKEN = os.getenv('MODAL_AUTH_TOKEN')
|
29 |
|
30 |
+
# Rate limiting configuration
|
31 |
+
RATE_LIMIT_GENERATIONS = int(os.getenv('RATE_LIMIT_GENERATIONS', '5')) # Default 5 generations per hour
|
32 |
+
RATE_LIMIT_WINDOW = int(os.getenv('RATE_LIMIT_WINDOW', '3600')) # Default 1 hour in seconds
|
33 |
+
|
34 |
+
# In-memory rate limiting storage (for production, consider Redis)
|
35 |
+
user_generations = defaultdict(deque)
|
36 |
+
|
37 |
loras = [
|
38 |
{
|
39 |
"image": "https://huggingface.co/Remade-AI/Crash-zoom-out/resolve/main/example_videos/1.gif",
|
|
|
585 |
sentence = f"Selected LoRA: {selected_lora['title']}"
|
586 |
return selected_lora['id'], sentence
|
587 |
|
588 |
+
async def handle_generation(image_input, subject, selected_index, request: gr.Request, progress=gr.Progress(track_tqdm=True)):
|
589 |
try:
|
590 |
if selected_index is None:
|
591 |
raise gr.Error("You must select a LoRA before proceeding.")
|
592 |
|
593 |
+
# Check rate limit first
|
594 |
+
user_identifier = get_user_identifier(request)
|
595 |
+
is_allowed, remaining, reset_time = check_rate_limit(user_identifier)
|
596 |
+
|
597 |
+
if not is_allowed:
|
598 |
+
minutes = reset_time // 60
|
599 |
+
seconds = reset_time % 60
|
600 |
+
time_str = f"{minutes}m {seconds}s" if minutes > 0 else f"{seconds}s"
|
601 |
+
# Re-enable button on rate limit
|
602 |
+
yield None, None, gr.update(visible=False), gr.update(value="Generate", interactive=True)
|
603 |
+
raise gr.Error(f"Rate limit exceeded. Go to https://app.remade.ai for more generations and effects. Otherwise, you can generate {RATE_LIMIT_GENERATIONS} videos per hour. Try again in {time_str}.")
|
604 |
+
|
605 |
+
# Record this generation attempt
|
606 |
+
record_generation(user_identifier)
|
607 |
+
|
608 |
+
# Show remaining generations to user
|
609 |
+
if remaining > 0:
|
610 |
+
print(f"User {user_identifier} has {remaining} generations remaining this hour")
|
611 |
+
|
612 |
# First, moderate the prompt
|
613 |
prompt_moderation = await moderate_prompt(subject)
|
614 |
print(f"Prompt moderation result: {prompt_moderation}")
|
|
|
851 |
}
|
852 |
'''
|
853 |
|
854 |
+
def get_user_identifier(request: gr.Request) -> str:
    """Return the key used to bucket a visitor for rate limiting.

    The key is the ``host`` attribute of ``request.client``. When the request
    is falsy, or either attribute is missing, the literal ``"unknown"`` is
    returned instead.
    """
    # NOTE(review): if the app is served behind a reverse proxy, the peer
    # address may be the proxy's rather than the visitor's, which would put
    # many visitors in one bucket -- confirm against the deployment.
    fallback = "unknown"

    if not request or not hasattr(request, 'client'):
        return fallback

    client = request.client
    if not hasattr(client, 'host'):
        return fallback

    return client.host
|
859 |
+
|
860 |
+
def get_rate_limit_status(request: gr.Request) -> str:
    """Build the rate-limit status text shown to the current visitor.

    Args:
        request: The Gradio request of the current session; it is passed to
            ``get_user_identifier`` to obtain the rate-limit key.

    Returns:
        One of:
        * a warning with the remaining wait time when the visitor is blocked,
        * a low-quota notice when two or fewer generations remain,
        * a normal remaining-count message otherwise,
        * ``"✅ Ready to generate"`` if the status could not be determined.
    """
    try:
        user_identifier = get_user_identifier(request)
        is_allowed, remaining, reset_time = check_rate_limit(user_identifier)
    except Exception as exc:
        # Best effort: this text is informational only, so a lookup failure
        # must not break page load. Only ``Exception`` is caught so that
        # SystemExit / KeyboardInterrupt still propagate.
        print(f"Could not determine rate limit status: {exc}")
        return "✅ Ready to generate"

    # Branch on the allow flag itself rather than on ``reset_time > 0``: a
    # blocked visitor whose wait rounds to 0 seconds is still blocked and
    # must not be shown the "0 generations remaining" message.
    if not is_allowed:
        minutes, seconds = divmod(reset_time, 60)
        time_str = f"{minutes}m {seconds}s" if minutes > 0 else f"{seconds}s"
        return f"⚠️ Rate limit reached. Try again in {time_str}"

    if remaining <= 2:
        return f"⚡ {remaining} generations remaining this hour"

    return f"✅ {remaining} generations remaining this hour"
|
877 |
+
|
878 |
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="indigo", neutral_hue="slate", text_size="lg")) as demo:
|
879 |
selected_index = gr.State(None)
|
880 |
current_generation_id = gr.State(None)
|
|
|
944 |
|
945 |
subject = gr.Textbox(label="Describe your subject", placeholder="Cat toy")
|
946 |
|
947 |
+
# Rate limit status display
|
948 |
+
rate_limit_status = gr.Markdown("✅ Ready to generate", elem_id="rate_limit_status")
|
949 |
+
|
950 |
with gr.Row():
|
951 |
button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
|
952 |
audio_button = gr.Button("Add Audio 🔒", interactive=False)
|
|
|
1052 |
inputs=None,
|
1053 |
outputs=None
|
1054 |
)
|
1055 |
+
|
1056 |
+
# Update rate limit status on page load
|
1057 |
+
demo.load(
|
1058 |
+
fn=get_rate_limit_status,
|
1059 |
+
inputs=None,
|
1060 |
+
outputs=[rate_limit_status]
|
1061 |
+
)
|
1062 |
|
1063 |
def get_modal_auth_headers():
|
1064 |
"""Get authentication headers for Modal API requests"""
|
|
|
1070 |
'Content-Type': 'application/json'
|
1071 |
}
|
1072 |
|
1073 |
+
def check_rate_limit(user_identifier: str) -> tuple[bool, int, int]:
    """
    Check whether a user may start another generation (sliding window).

    Timestamps in ``user_generations[user_identifier]`` that are at least
    ``RATE_LIMIT_WINDOW`` seconds old are discarded; the remaining count is
    compared against ``RATE_LIMIT_GENERATIONS``. This function never records
    a generation itself -- callers do that via ``record_generation``.

    Args:
        user_identifier: Key identifying the caller in ``user_generations``.

    Returns:
        (is_allowed, remaining_generations, reset_time_seconds)

        * allowed: ``(True, remaining, 0)``
        * blocked: ``(False, 0, wait)`` where ``wait`` is the number of whole
          seconds (rounded up, at least 1) until the oldest entry expires.
    """
    import math  # local import keeps this block self-contained

    current_time = time.time()

    # .get() instead of indexing: indexing a defaultdict would create (and
    # keep forever) an empty deque for every visitor who merely checks status.
    user_queue = user_generations.get(user_identifier)

    if user_queue is not None:
        # Remove entries that have aged out of the time window
        while user_queue and current_time - user_queue[0] >= RATE_LIMIT_WINDOW:
            user_queue.popleft()
        if not user_queue:
            # Drop emptied queues so the in-memory store does not grow with
            # every identifier ever seen.
            user_generations.pop(user_identifier, None)

    used = len(user_queue) if user_queue else 0

    if used >= RATE_LIMIT_GENERATIONS:
        if not user_queue:
            # Only reachable when the limit is configured as <= 0; there is
            # no oldest entry to wait for, so report a full window.
            return False, 0, RATE_LIMIT_WINDOW
        # Round up: truncating could report "0 seconds" while the user is
        # in fact still blocked.
        seconds_left = user_queue[0] + RATE_LIMIT_WINDOW - current_time
        return False, 0, max(1, math.ceil(seconds_left))

    return True, RATE_LIMIT_GENERATIONS - used, 0
|
1093 |
+
|
1094 |
+
def record_generation(user_identifier: str):
    """Append the current timestamp to the user's generation history.

    The entry is stored in the module-level ``user_generations`` mapping
    under ``user_identifier``.
    """
    timestamps = user_generations[user_identifier]
    timestamps.append(time.time())
|
1098 |
+
|
1099 |
if __name__ == "__main__":
|
1100 |
demo.queue(default_concurrency_limit=20)
|
1101 |
demo.launch(ssr_mode=False, share=True)
|