webhook / gradio_webhooks.py
Wauplin's picture
Wauplin HF staff
refacto + add example
d9c353a
raw
history blame
3.92 kB
import hmac
import os
from pathlib import Path
from typing import Literal, Optional, Set, Union

import gradio as gr
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
class GradioWebhookApp:
    """Gradio landing page plus webhook endpoints served by the same FastAPI app.

    The Gradio UI only renders a markdown page (by default the repo README). Webhooks
    are registered as POST routes on the underlying FastAPI server with `add_webhook`.
    When a webhook secret is configured, requests to those routes must carry a matching
    `X-Webhook-Secret` header.

    ```py
    from gradio_webhooks import GradioWebhookApp

    app = GradioWebhookApp()

    @app.add_webhook("/test_webhook")
    async def hello():
        return {"in_gradio": True}

    app.ready()
    ```
    """

    def __init__(
        self,
        landing_path: Union[str, Path] = "README.md",
        webhook_secret: Optional[str] = None,
    ) -> None:
        """Build and launch (non-blocking) the Gradio app, then install the secret check.

        Args:
            landing_path: Markdown file rendered as the landing page. When the file is
                named `README.md`, a leading YAML front-matter block is removed first.
            webhook_secret: Expected value of the `X-Webhook-Secret` header. Falls back
                to the `WEBHOOK_SECRET` environment variable; if neither is set, webhook
                endpoints are left unauthenticated and a warning is printed.
        """
        # Use README.md as landing page or provide any markdown file
        landing_content = self._read_landing_content(Path(landing_path))

        # Simple gradio app with landing content
        block = gr.Blocks()
        with block:
            gr.Markdown(landing_content)

        # Launch gradio app:
        # - as non-blocking so that webhooks can be added afterwards
        # - as shared if launched locally (to receive webhooks)
        app, _, _ = block.launch(prevent_thread_lock=True, share=not block.is_space)
        self.gradio_app = block
        self.fastapi_app = app
        self.webhook_paths: Set[str] = set()

        # Add auth middleware to check the "X-Webhook-Secret" header
        self._webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        if self._webhook_secret is None:
            print(
                "\nWebhook secret is not defined. This means your webhook endpoints will be open to everyone."
            )
            print(
                "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
                "\n\t`app = GradioWebhookApp(webhook_secret='my_secret', ...)`"
            )
            print(
                "For more details about Webhook secrets, please refer to https://huggingface.co/docs/hub/webhooks#webhook-secret."
            )
        # NOTE(review): this middleware is installed after the server was launched; some
        # Starlette versions may refuse to add middleware once the app has started
        # handling requests — confirm against the pinned gradio/starlette versions.
        app.middleware("http")(self._webhook_secret_middleware)

    @staticmethod
    def _read_landing_content(landing_path: Path) -> str:
        """Return the markdown to display, minus the YAML front matter of a README.md."""
        content = landing_path.read_text(encoding="utf-8")
        if landing_path.name == "README.md":
            # Only remove a *leading* "---\n...\n---" block. Splitting on every "---"
            # would also drop everything before a horizontal rule in the README body.
            if content.startswith("---"):
                _, closing, body = content[3:].partition("\n---")
                if closing:
                    content = body
            content = content.strip()
        return content

    def add_webhook(self, path: str):
        """Decorator registering `path` as a POST webhook endpoint on the server app.

        Webhook senders deliver events with POST requests (and `ready()` advertises the
        endpoints as POST), so the route is registered for that method.
        """
        self.webhook_paths.add(path)
        return self.fastapi_app.post(path)

    def ready(self) -> None:
        """Print the webhook URLs, then block the main thread to keep the app running."""
        share_url = self.gradio_app.share_url
        base_url = share_url if share_url is not None else self.gradio_app.local_url
        url = base_url.rstrip("/")
        print("\nWebhooks are correctly setup and ready to use:")
        # Sorted so the listing is stable between runs (a set has no defined order).
        print("\n".join(f" - POST {url}{webhook}" for webhook in sorted(self.webhook_paths)))
        print(f"Checkout {url}/docs for more details.")
        self.gradio_app.block_thread()

    async def _webhook_secret_middleware(self, request: Request, call_next):
        """Check the "X-Webhook-Secret" header on every request to a webhook path.

        Returns a 401 response when the header is missing and a 403 response when it
        does not match the configured secret; otherwise forwards the request. Responses
        are returned rather than raising `HTTPException`, because exceptions raised in
        an HTTP middleware bypass FastAPI's exception handlers and surface as a 500.
        """
        if request.url.path in self.webhook_paths and self._webhook_secret is not None:
            request_secret = request.headers.get("x-webhook-secret")
            if request_secret is None:
                return JSONResponse({"detail": "Unauthorized"}, status_code=401)
            # Constant-time comparison so the secret cannot be probed via response timing.
            if not hmac.compare_digest(
                request_secret.encode("utf-8"), self._webhook_secret.encode("utf-8")
            ):
                return JSONResponse({"detail": "Forbidden"}, status_code=403)
        return await call_next(request)
class WebhookPayloadEvent(BaseModel):
    """`event` section of a webhook payload: what happened and to what."""

    # Kind of change that triggered the webhook. "move" is also emitted by the Hub
    # (see the huggingface_hub webhook payload types); it is accepted here so those
    # payloads validate instead of being rejected.
    action: Literal["create", "update", "delete", "move"]
    # What the event applies to; kept as a free-form string.
    scope: str
class WebhookPayloadRepo(BaseModel):
    """`repo` section of a webhook payload: the repository the event relates to."""

    # Repository type; only these three values validate.
    type: Literal["dataset", "model", "space"]
    # Repository name (presumably the full "namespace/name" id — TODO confirm).
    name: str
    # Whether the repository is private.
    private: bool
class WebhookPayloadDiscussion(BaseModel):
    """`discussion` section of a webhook payload (discussion or pull-request events)."""

    # Discussion / pull request number within the repository.
    num: int
    # True when the discussion is a pull request.
    isPullRequest: bool
    # Current status. "draft" is part of the Hub's discussion status set (see the
    # huggingface_hub webhook payload types), so it is accepted alongside the others.
    status: Literal["open", "closed", "merged", "draft"]
class WebhookPayload(BaseModel):
    """Top-level webhook payload body, parsed into typed sections."""

    event: WebhookPayloadEvent
    repo: WebhookPayloadRepo
    # Only present for discussion-related events. The explicit `= None` default keeps
    # the field optional under pydantic v2, where `Optional[...]` alone is required.
    discussion: Optional[WebhookPayloadDiscussion] = None