File size: 5,896 Bytes
3943768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from enums import split_google
from utils import sanitize_filename


def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or empty.  An explicit raise is
            used instead of ``assert`` so the check also runs under ``python -O``.
    """
    import os
    value = os.environ.get(name)
    if not value:
        raise RuntimeError("Set env %s" % name)
    return value


def setup_app(name_login='google_login', name_app='h2ogpt', verbose=False):
    """Build a FastAPI app that gates access behind Google OAuth login.

    Registered routes:
      * ``/``       -- redirects to ``{root_url}/{name_app}/`` when ``get_user``
                       yields a user, else to ``{root_url}/{name_login}/``.
      * ``/logout`` -- removes the session user and redirects to ``/``.
      * ``/login``  -- starts the Google OAuth flow with ``{root_url}/auth`` as
                       the redirect URI.
      * ``/auth``   -- OAuth callback; stores the token's ``userinfo`` entry in
                       the session and redirects to ``/{name_app}``.

    Required environment variables: GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET, and
    SECRET_KEY (passed as the SessionMiddleware ``secret_key``).

    Args:
        name_login: URL path segment where the login UI is expected to be mounted.
        name_app: URL path segment where the main app is expected to be mounted.
        verbose: if True, print method/URL/headers/query/session for each request.

    Returns:
        ``(app, get_user)``. ``get_user(request)`` returns
        ``name + split_google + email + split_google + picture`` for a session
        user that has an email and ``email_verified`` set, and ``None`` otherwise.

    Raises:
        RuntimeError: if a required environment variable is missing or empty.
    """
    # Validate configuration before the heavy web/OAuth imports, so a
    # misconfigured deployment fails fast with a clear message.
    google_client_id = _require_env('GOOGLE_CLIENT_ID')
    google_client_secret = _require_env('GOOGLE_CLIENT_SECRET')
    secret_key = _require_env('SECRET_KEY')

    from authlib.integrations.starlette_client import OAuth, OAuthError
    from fastapi import FastAPI, Depends, Request
    from starlette.config import Config
    from starlette.responses import RedirectResponse
    from starlette.middleware.sessions import SessionMiddleware
    import gradio as gr

    app = FastAPI()

    # Set up OAuth.  Authlib's Starlette client reads GOOGLE_CLIENT_ID /
    # GOOGLE_CLIENT_SECRET from this config for the client registered as 'google'.
    oauth = OAuth(Config(environ={'GOOGLE_CLIENT_ID': google_client_id,
                                  'GOOGLE_CLIENT_SECRET': google_client_secret}))
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'},
    )
    app.add_middleware(SessionMiddleware, secret_key=secret_key)

    def print_request(request: Request, which='unknown'):
        """Debug helper: dump method, URL, headers, query params and session."""
        print("%s Method:" % which, request.method)
        print("%s URL:" % which, str(request.url))

        print("%s Headers:" % which)
        for key, value in request.headers.items():
            print(f"    {key}: {value}")

        print("%s Query Parameters:" % which)
        for key, value in request.query_params.items():
            print(f"    {key}: {value}")

        print("%s session:" % which, request.session)

    # Dependency to get the current user
    def get_user(request: Request):
        """Return the encoded identity of a verified session user, else None."""
        if verbose:
            print_request(request, which='get_user')
        user = request.session.get('user')
        if not user:
            return None
        email = user.get('email')
        # Explicit checks instead of assert, so an account without an email or
        # with an unverified email is rejected even under ``python -O``.
        # Returning None makes '/' send the browser to the login page.
        if not email:
            print("No email")
            return None
        if not user.get('email_verified'):
            print("Email not verified: %s" % email)
            return None
        name = user.get('name') or ''
        picture = user.get('picture', '') or 'None'
        return name + split_google + email + split_google + picture

    @app.get('/')
    def public(request: Request, user=Depends(get_user)):
        if verbose:
            print_request(request, which='public')
        root_url = gr.route_utils.get_root_url(request, "/", None)
        if user:
            return RedirectResponse(url=f'{root_url}/{name_app}/')
        return RedirectResponse(url=f'{root_url}/{name_login}/')

    @app.route('/logout')
    async def logout(request: Request):
        if verbose:
            print_request(request, which='logout')
        request.session.pop('user', None)
        return RedirectResponse(url='/')

    # Only one '/login' route is registered: Starlette dispatches to the first
    # matching route, so a second registration on the same path is never reached.
    @app.route('/login')
    async def login(request: Request):
        if verbose:
            print_request(request, which='login0')
        root_url = gr.route_utils.get_root_url(request, "/login", None)
        redirect_uri = f"{root_url}/auth"
        print("Redirecting to", redirect_uri)
        return await oauth.google.authorize_redirect(request, redirect_uri)

    @app.route('/auth')
    async def auth(request: Request):
        if verbose:
            print_request(request, which='auth')
        try:
            token = await oauth.google.authorize_access_token(request)
        except OAuthError as error:
            # Report the actual error instance, not the exception class.
            print("Error getting access token", str(error))
            return RedirectResponse(url='/')
        userinfo = dict(token).get("userinfo")
        if not userinfo:
            print("Error getting access token", "no userinfo in token")
            return RedirectResponse(url='/')
        request.session['user'] = userinfo
        print(f"Redirecting to /{name_app}")
        return RedirectResponse(url=f'/{name_app}')

    return app, get_user


def login_gradio(**kwargs):
    """Build the Gradio login page: an optional logo plus a centered login button.

    Keyword Args:
        page_title: text placed before "Google Auth Login" on the button (required).
        visible_h2ogpt_logo: if truthy, render ``markdown_logo`` above the button;
            treated as False when omitted.
        markdown_logo: Markdown for the logo; only read when the logo is shown.

    Returns:
        The ``gr.Blocks`` login UI.  Clicking the button opens ``/login`` plus
        the current page's query string in a new browser tab.
    """
    import gradio as gr
    login_demo = gr.Blocks()
    with login_demo:
        if kwargs.get('visible_h2ogpt_logo'):
            gr.Markdown(kwargs['markdown_logo'])
        # Three equal-width columns with empty sides center the button.
        with gr.Row():
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=1):
                btn = gr.Button("%s Google Auth Login" % kwargs['page_title'])
            with gr.Column(scale=1):
                pass
        # Pure client-side handler (no Python fn).  `const` keeps `url` local
        # instead of creating an implicit global on `window`.
        _js_redirect = """
            () => {
                const url = '/login' + window.location.search;
                window.open(url, '_blank');
            }
            """
        btn.click(None, js=_js_redirect)
    return login_demo


def get_app(demo, app_kwargs=None, **login_kwargs):
    """Mount a login page and *demo* behind Google OAuth on one FastAPI app.

    Args:
        demo: the Gradio app to protect.  It is mounted at ``/<name_app>``, where
            ``name_app`` is ``sanitize_filename(page_title)`` with ``/`` removed,
            lower-cased.
        app_kwargs: extra keyword arguments forwarded as ``app_kwargs`` to
            ``gr.mount_gradio_app`` for *demo*; ``None`` (default) means ``{}``.
        **login_kwargs: forwarded to ``login_gradio``; must include ``page_title``.

    Returns:
        The FastAPI app with the login UI mounted at ``/google_login`` and *demo*
        mounted with ``get_user`` from ``setup_app`` as its ``auth_dependency``.
    """
    # None sentinel instead of a mutable ``{}`` default shared across calls.
    if app_kwargs is None:
        app_kwargs = {}
    name_login = 'google_login'
    name_app = sanitize_filename(login_kwargs['page_title']).replace('/', '').lower()
    app, get_user = setup_app(name_login=name_login,
                              name_app=name_app,
                              verbose=False,  # can set to True to debug
                              )
    import gradio as gr
    login_app = gr.mount_gradio_app(app, login_gradio(**login_kwargs), f"/{name_login}")
    main_app = gr.mount_gradio_app(login_app, demo, path=f"/{name_app}",
                                   auth_dependency=get_user,
                                   app_kwargs=app_kwargs)
    return main_app