"""FastAPI application that puts Google OAuth sign-in in front of a Gradio app.

Routes defined here: ``/`` (entry redirect), ``/login`` (starts the Google
authorization flow), ``/auth`` (OAuth callback that stores the user in the
session) and ``/logout`` (removes the user from the session).  Two Gradio
Blocks apps are mounted: a login page at ``/main`` and the main app at
``/gradio``, which is mounted with ``get_user`` as its ``auth_dependency``.

Configuration is read from the ``GOOGLE_CLIENT_ID``, ``GOOGLE_CLIENT_SECRET``
and ``SECRET_KEY`` environment variables.
"""

import uvicorn
from fastapi import FastAPI, Depends
from starlette.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import Request
import os
from starlette.config import Config
import gradio as gr

# FastAPI application that the OAuth routes and both Gradio apps attach to.
app = FastAPI()

# OAuth settings, read from the process environment. os.environ.get returns
# None for any variable that is not set, so a missing value is not reported
# here; it is passed on as None to the OAuth config / session middleware.
GOOGLE_CLIENT_ID = os.environ.get("GOOGLE_CLIENT_ID")
GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
SECRET_KEY = os.environ.get("SECRET_KEY")

# Set up OAuth: the Google credentials are handed to Authlib through a
# Starlette Config built from an explicit mapping.
# NOTE(review): presumably Authlib looks up GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET
# in this config based on the client name 'google' - confirm against Authlib docs.
config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
starlette_config = Config(environ=config_data)
oauth = OAuth(starlette_config)
# Register the client used by the routes below as `oauth.google`. Endpoints come
# from Google's OpenID Connect discovery document; the requested scope is
# 'openid email profile'.
oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

# Enables `request.session`, which the routes below use to remember the
# logged-in user. SECRET_KEY is passed as the middleware's secret key.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

# Dependency to get the current user
def get_user(request: Request):
    """Return an identifier for the logged-in user, or None when nobody is.

    Reads the ``'user'`` entry of ``request.session`` (a mapping of profile
    fields) and returns its ``'name'`` value.  When the profile has no usable
    ``'name'``, the ``'email'`` value is returned instead so that a valid
    login without a display name does not raise ``KeyError``.

    Args:
        request: The incoming request; only ``request.session`` is read.

    Returns:
        The user's name (or email as a fallback) as stored in the session,
        or None if there is no session user or neither field is present.
    """
    user = request.session.get('user')
    if not user:
        return None
    # .get instead of ['name']: a profile lacking 'name' must not crash the
    # dependency; `or` also skips an empty name in favour of the email.
    return user.get('name') or user.get('email')

@app.get('/')
def public(request: Request, user: dict = Depends(get_user)):
    """Redirect visitors of the site root to the login page mounted at ``/main``.

    Args:
        request: The incoming request (not used by the redirect itself).
        user: Value produced by the ``get_user`` dependency.  It is resolved
            on every request but does not influence the response.

    Returns:
        A ``RedirectResponse`` pointing at ``/main/``.
    """
    # Relative target, like the other redirects in this file, so the response
    # stays on whichever host is serving the app instead of a fixed deployment.
    # TODO(review): decide whether logged-in users should be sent straight to
    # /gradio instead of the login page.
    return RedirectResponse(url='/main/')

@app.route('/logout')
async def logout(request: Request):
    """Drop the stored user from the session and send the browser to ``/``.

    Removing the ``'user'`` key is a no-op when nobody is logged in.
    """
    session = request.session
    session.pop('user', None)
    home = RedirectResponse(url='/')
    return home

@app.route('/login')
async def login(request: Request):
    """Begin the Google authorization flow.

    The callback address handed to Google is this app's root URL (as computed
    by Gradio's ``get_root_url`` for the ``/login`` route) followed by
    ``/auth``, which is the route that completes the flow.
    """
    base_url = gr.route_utils.get_root_url(request, "/login", None)
    callback_url = f"{base_url}/auth"
    print("Redirecting to", callback_url)
    google = oauth.google
    response = await google.authorize_redirect(request, callback_url)
    return response

@app.route('/auth')
async def auth(request: Request):
    """OAuth callback: finish the Google login and remember the user.

    Exchanges the authorization response carried by ``request`` for a token,
    stores the token's ``"userinfo"`` entry in ``request.session['user']`` and
    redirects to ``/gradio``.

    If the token exchange raises ``OAuthError``, or the token response carries
    no userinfo, nothing is stored and the browser is redirected to ``/``.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as error:
        # Bind the raised instance so the log shows the actual failure reason
        # rather than the name of the exception class.
        print("Error getting access token", str(error))
        return RedirectResponse(url='/')

    # .get instead of ["userinfo"]: a token response without userinfo is
    # treated as a failed login instead of surfacing as a KeyError.
    userinfo = dict(token).get("userinfo")
    if not userinfo:
        print("Error getting access token", "token response has no userinfo")
        return RedirectResponse(url='/')

    request.session['user'] = userinfo
    print("Redirecting to /gradio")
    return RedirectResponse(url='/gradio')

# Login page: a Gradio Blocks app containing a single "Login" button whose
# link target is the /login route defined above.
with gr.Blocks() as login_demo:
    btn = gr.Button("Login", link="/login")
    # The block below is a disabled JavaScript redirect handler kept as a
    # comment; nothing in the visible file uses it. It refers to a
    # '/login/huggingface' URL, for which no route is defined in this file.
    # _js_handle_redirect = """
    # (buttonValue) => {
    #     if (buttonValue === BUTTON_DEFAULT_VALUE) {
    #         url = '/login/huggingface' + window.location.search;
    #         if ( window !== window.parent ) {
    #             window.open(url, '_blank');
    #         } else {
    #             window.location.assign(url);
    #         }
    #     } else {
    #         url = '/logout' + window.location.search
    #         window.location.assign(url);
    #     }
    # }
    # """


# Mount the login page under /main. No auth_dependency is passed for this
# mount. `app` is rebound to the value returned by mount_gradio_app.
app = gr.mount_gradio_app(app, login_demo, path="/main")

def greet(request: gr.Request):
    """Build the welcome line for the current user.

    Args:
        request: Gradio request object; only its ``username`` attribute is read.

    Returns:
        The greeting text with ``request.username`` interpolated.
    """
    message = f"Welcome to Gradio, {request.username}"
    return message

# Main page: a Markdown greeting plus a "Logout" button linking to /logout.
with gr.Blocks() as main_demo:
    # Initial placeholder text; the load event below writes greet's output here.
    m = gr.Markdown("Welcome to Gradio!")
    gr.Button("Logout", link="/logout")
    # On page load, call greet with no input components and put its return
    # value into the Markdown component `m`.
    main_demo.load(greet, None, m)

# Mount the main page under /gradio with get_user as the auth dependency.
# NOTE(review): per Gradio's auth_dependency contract, a None result from
# get_user should deny access and a returned string becomes the username
# that greet reads - confirm against the installed Gradio version.
app = gr.mount_gradio_app(app, main_demo, path="/gradio", auth_dependency=get_user)


# When executed as a script, serve the combined app with uvicorn's default settings.
if __name__ == '__main__':
    uvicorn.run(app)