File size: 5,896 Bytes
3943768 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from enums import split_google
from utils import sanitize_filename
def setup_app(name_login='google_login', name_app='h2ogpt', verbose=False, force_https=False):
    """Build a FastAPI app that gates a Gradio UI behind Google OAuth login.

    Reads ``GOOGLE_CLIENT_ID``, ``GOOGLE_CLIENT_SECRET`` and ``SECRET_KEY``
    from the environment, registers an authlib ``google`` OpenID client,
    installs signed-cookie session middleware, and adds the routes ``/``,
    ``/login``, ``/auth`` and ``/logout``.

    Args:
        name_login: Path segment (no slashes) where the login page is expected
            to be mounted; ``/`` redirects there when no user is in session.
        name_app: Path segment where the main app is expected to be mounted;
            ``/`` and ``/auth`` redirect there once a user is logged in.
        verbose: If True, dump request details from every handler.
        force_https: If True, ``/login`` builds the OAuth callback from the
            ``auth`` route URL with its scheme forced to ``https`` (useful when
            TLS terminates upstream). If False (default, the previously
            effective behaviour), the callback is derived from Gradio's
            root-URL detection.

    Returns:
        ``(app, get_user)``: the FastAPI app and the session-user dependency.
        ``get_user`` returns ``name + split_google + email + split_google +
        picture`` for a logged-in user with a verified email, else None.

    Raises:
        RuntimeError: if a required environment variable is unset or empty.
    """
    from urllib.parse import urlparse, urlunparse
    import os

    from authlib.integrations.starlette_client import OAuth, OAuthError
    from fastapi import FastAPI, Depends, Request
    from starlette.config import Config
    from starlette.responses import RedirectResponse
    from starlette.middleware.sessions import SessionMiddleware
    import gradio as gr

    def _require_env(name):
        # Explicit check rather than ``assert``: survives ``python -O`` and an
        # unset variable reports the same message as an empty one.
        value = os.environ.get(name)
        if not value:
            raise RuntimeError("Set env %s" % name)
        return value

    GOOGLE_CLIENT_ID = _require_env('GOOGLE_CLIENT_ID')
    GOOGLE_CLIENT_SECRET = _require_env('GOOGLE_CLIENT_SECRET')
    SECRET_KEY = _require_env('SECRET_KEY')

    app = FastAPI()

    # Set up OAuth with an explicit config so credentials come only from the
    # values validated above.
    config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
    starlette_config = Config(environ=config_data)
    oauth = OAuth(starlette_config)
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'},
    )
    app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

    def print_request(request: Request, which='unknown'):
        """Debug-print method, URL, headers, query params and session.

        NOTE: headers include cookies; only called when ``verbose`` is set.
        """
        print("%s Method:" % which, request.method)
        print("%s URL:" % which, str(request.url))
        print("%s Headers:" % which)
        for key, value in request.headers.items():
            print(f" {key}: {value}")
        print("%s Query Parameters:" % which)
        for key, value in request.query_params.items():
            print(f" {key}: {value}")
        print("%s session:" % which, request.session)

    # Dependency to get the current user
    def get_user(request: Request):
        """Return the session user's identity string, or None if not logged in.

        A session user without an email, or whose email is not verified, is
        treated as unauthenticated (None) rather than raising.
        """
        if verbose:
            print_request(request, which='get_user')
        user = request.session.get('user')
        if not user:
            return None
        email = user.get('email')
        if not email:
            print("No email")
            return None
        # Security check: must not be an ``assert`` (stripped under -O).
        if not user.get('email_verified'):
            print("Email not verified: %s" % email)
            return None
        picture = user.get('picture', '') or 'None'
        name = user.get('name') or ''
        return name + split_google + email + split_google + picture

    @app.get('/')
    def public(request: Request, user=Depends(get_user)):
        """Redirect to the app when logged in, otherwise to the login page."""
        if verbose:
            print_request(request, which='public')
        root_url = gr.route_utils.get_root_url(request, "/", None)
        if user:
            return RedirectResponse(url=f'{root_url}/{name_app}/')
        return RedirectResponse(url=f'{root_url}/{name_login}/')

    @app.route('/logout')
    async def logout(request: Request):
        """Drop the session user and go back to ``/``."""
        if verbose:
            print_request(request, which='logout')
        request.session.pop('user', None)
        return RedirectResponse(url='/')

    # Single /login route: Starlette dispatches to the first matching route,
    # so registering /login twice would silently make the second one dead.
    @app.route('/login')
    async def login(request: Request):
        """Start the Google OAuth flow with a callback to ``/auth``."""
        if verbose:
            print_request(request, which='login0')
        if force_https:
            parsed_url = urlparse(str(request.url_for('auth')))
            redirect_uri = urlunparse(parsed_url._replace(scheme='https'))
        else:
            root_url = gr.route_utils.get_root_url(request, "/login", None)
            redirect_uri = f"{root_url}/auth"
        print("Redirecting to", redirect_uri)
        return await oauth.google.authorize_redirect(request, redirect_uri)

    @app.route('/auth')
    async def auth(request: Request):
        """OAuth callback: store the userinfo in the session, go to the app."""
        if verbose:
            print_request(request, which='auth')
        try:
            access_token = await oauth.google.authorize_access_token(request)
        except OAuthError as error:
            # Log the caught error itself, not the exception class.
            print("Error getting access token", str(error))
            return RedirectResponse(url='/')
        userinfo = dict(access_token).get("userinfo")
        if not userinfo:
            print("No userinfo in access token")
            return RedirectResponse(url='/')
        request.session['user'] = userinfo
        print(f"Redirecting to /{name_app}")
        return RedirectResponse(url=f'/{name_app}')

    return app, get_user
def login_gradio(**kwargs):
    """Build the Gradio login page: an optional logo and a centred button.

    Keyword Args:
        page_title: Text prefixed to the button label (required).
        visible_h2ogpt_logo: If truthy, render ``markdown_logo`` above the
            button (optional; treated as False when absent).
        markdown_logo: Markdown for the logo; read only when it is shown.

    Returns:
        The ``gr.Blocks`` login demo. Clicking the button opens ``/login``
        (with the current query string appended) in a new browser tab.
    """
    import gradio as gr

    # ``const`` keeps ``url`` local to the handler instead of assigning an
    # undeclared (implicit global) variable, which throws in strict mode.
    _js_redirect = """
    () => {
        const url = '/login' + window.location.search;
        window.open(url, '_blank');
    }
    """

    login_demo = gr.Blocks()
    with login_demo:
        if kwargs.get('visible_h2ogpt_logo'):
            gr.Markdown(kwargs['markdown_logo'])
        with gr.Row():
            # Empty flanking columns centre the button.
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=1):
                btn = gr.Button("%s Google Auth Login" % kwargs['page_title'])
            with gr.Column(scale=1):
                pass
        # Event listeners must be attached inside the Blocks context.
        btn.click(None, js=_js_redirect)
    return login_demo
def get_app(demo, app_kwargs=None, **login_kwargs):
    """Mount the login page and ``demo`` behind Google OAuth on one FastAPI app.

    Args:
        demo: The Gradio app to protect; mounted at ``/<name_app>`` with
            ``get_user`` as its auth dependency.
        app_kwargs: Extra kwargs forwarded to ``gr.mount_gradio_app`` for
            ``demo``. Defaults to a fresh empty dict on each call.
        **login_kwargs: Forwarded to ``login_gradio``; ``page_title`` is
            required and also determines the app's mount path.

    Returns:
        The FastAPI app with both Gradio apps mounted.
    """
    import gradio as gr

    if app_kwargs is None:
        # Fresh dict per call rather than a shared mutable default.
        app_kwargs = {}
    name_login = 'google_login'
    # Fall back to setup_app's default segment so a title that sanitizes to ""
    # cannot mount the app at the site root on top of the "/" redirect route.
    name_app = sanitize_filename(login_kwargs['page_title']).replace('/', '').lower() or 'h2ogpt'
    app, get_user = setup_app(name_login=name_login,
                              name_app=name_app,
                              verbose=False,  # can set to True to debug
                              )
    login_app = gr.mount_gradio_app(app, login_gradio(**login_kwargs), f"/{name_login}")
    main_app = gr.mount_gradio_app(login_app, demo, path=f"/{name_app}",
                                   auth_dependency=get_user,
                                   app_kwargs=app_kwargs)
    return main_app
|