|
from enums import split_google |
|
from utils import sanitize_filename |
|
|
|
|
|
def setup_app(name_login='google_login', name_app='h2ogpt', verbose=False):
    """Build a FastAPI app that gates a Gradio UI behind Google OAuth login.

    Reads ``GOOGLE_CLIENT_ID``, ``GOOGLE_CLIENT_SECRET`` and ``SECRET_KEY`` from
    the environment, registers an authlib OAuth client named ``google`` (OpenID
    discovery URL, scope ``openid email profile``), adds session middleware and
    the ``/``, ``/login``, ``/auth`` and ``/logout`` routes.

    :param name_login: path segment (no slashes) where the login Gradio page is mounted.
    :param name_app: path segment (no slashes) where the main Gradio app is mounted.
    :param verbose: if True, print a dump of every request handled by these routes.
    :return: ``(app, get_user)``. ``get_user(request)`` returns
        ``name + split_google + email + split_google + picture`` for a logged-in
        session user, or None when no user is stored in the session.
    :raises RuntimeError: if a required environment variable is unset or empty.
    """
    from authlib.integrations.starlette_client import OAuth, OAuthError
    from fastapi import FastAPI, Depends, HTTPException, Request
    from starlette.config import Config
    from starlette.responses import RedirectResponse
    from starlette.middleware.sessions import SessionMiddleware
    import os
    import gradio as gr

    def _require_env(key):
        """Return the non-empty value of environment variable ``key``."""
        # ``os.environ[key]`` would raise a bare KeyError (hiding the hint) when the
        # variable is unset, and ``assert`` checks vanish under ``python -O``.
        value = os.environ.get(key)
        if not value:
            raise RuntimeError("Set env %s" % key)
        return value

    GOOGLE_CLIENT_ID = _require_env('GOOGLE_CLIENT_ID')
    GOOGLE_CLIENT_SECRET = _require_env('GOOGLE_CLIENT_SECRET')
    SECRET_KEY = _require_env('SECRET_KEY')

    app = FastAPI()

    # Credentials are handed to authlib through an explicit Config built from
    # just these two values rather than from the process environment.
    config_data = {'GOOGLE_CLIENT_ID': GOOGLE_CLIENT_ID, 'GOOGLE_CLIENT_SECRET': GOOGLE_CLIENT_SECRET}
    starlette_config = Config(environ=config_data)
    oauth = OAuth(starlette_config)
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'},
    )
    # The handlers below read and write ``request.session``, which requires this middleware.
    app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

    def print_request(request: Request, which='unknown'):
        """Print method, URL, headers, query params and session of ``request``.

        Debug aid only: the dumped headers and session may include cookies and
        user details.
        """
        print("%s Method:" % which, request.method)
        print("%s URL:" % which, str(request.url))
        print("%s Headers:" % which)
        for key, value in request.headers.items():
            print(f" {key}: {value}")
        print("%s Query Parameters:" % which)
        for key, value in request.query_params.items():
            print(f" {key}: {value}")
        print("%s session:" % which, request.session)

    def get_user(request: Request):
        """Return the encoded session user, or None if nobody is logged in.

        :raises HTTPException: 403 if the stored user has no email or the email
            is not marked verified.
        """
        if verbose:
            print_request(request, which='get_user')
        user = request.session.get('user')
        if not user:
            return None
        email = user.get('email')
        # Explicit checks (not ``assert``) so they still run under ``python -O``.
        if not email:
            raise HTTPException(status_code=403, detail="No email")
        if not user.get('email_verified'):
            raise HTTPException(status_code=403, detail="Email not verified: %s" % email)
        picture = user.get('picture', '') or 'None'
        return user['name'] + split_google + email + split_google + picture

    @app.get('/')
    def public(request: Request, user=Depends(get_user)):
        """Redirect to the main app when logged in, otherwise to the login page."""
        if verbose:
            print_request(request, which='public')
        root_url = gr.route_utils.get_root_url(request, "/", None)
        target = name_app if user else name_login
        return RedirectResponse(url=f'{root_url}/{target}/')

    @app.route('/logout')
    async def logout(request: Request):
        """Drop the session user and go back to ``/``."""
        if verbose:
            print_request(request, which='logout')
        request.session.pop('user', None)
        return RedirectResponse(url='/')

    # ``/login`` is registered exactly once: Starlette dispatches to the first
    # matching route, so any second handler on the same path would never run.
    @app.route('/login')
    async def login(request: Request):
        """Start the Google OAuth flow with ``<root_url>/auth`` as the callback."""
        if verbose:
            print_request(request, which='login0')
        root_url = gr.route_utils.get_root_url(request, "/login", None)
        redirect_uri = f"{root_url}/auth"
        print("Redirecting to", redirect_uri)
        return await oauth.google.authorize_redirect(request, redirect_uri)

    @app.route('/auth')
    async def auth(request: Request):
        """OAuth callback: fetch the token and store its ``userinfo`` in the session."""
        if verbose:
            print_request(request, which='auth')
        try:
            access_token = await oauth.google.authorize_access_token(request)
        except OAuthError as e:
            # Log the caught error instance, not the exception class.
            print("Error getting access token", str(e))
            return RedirectResponse(url='/')
        userinfo = dict(access_token).get("userinfo")
        if not userinfo:
            # Without userinfo there is no identity to store; restart at ``/``.
            print("No userinfo in access token")
            return RedirectResponse(url='/')
        request.session['user'] = userinfo
        print(f"Redirecting to /{name_app}")
        return RedirectResponse(url=f'/{name_app}')

    return app, get_user
|
|
|
|
|
def login_gradio(**kwargs):
    """Create the Gradio login page holding a single Google sign-in button.

    Keys read from ``kwargs``:
      * ``visible_h2ogpt_logo`` -- when truthy, ``markdown_logo`` is rendered first;
      * ``markdown_logo`` -- Markdown for the logo (read only if the logo is visible);
      * ``page_title`` -- prefix of the button label.

    Pressing the button runs client-side JS that opens ``/login`` plus the
    current query string in a new browser tab.

    :return: the ``gr.Blocks`` page.
    """
    import gradio as gr

    open_login_js = """
    () => {
    url = '/login' + window.location.search;
    window.open(url, '_blank');
    }
    """

    with gr.Blocks() as page:
        if kwargs['visible_h2ogpt_logo']:
            gr.Markdown(kwargs['markdown_logo'])
        with gr.Row():
            # Two empty, equally-weighted side columns keep the button centred.
            with gr.Column(scale=1):
                pass
            with gr.Column(scale=1):
                label = "%s Google Auth Login" % kwargs['page_title']
                login_button = gr.Button(label)
            with gr.Column(scale=1):
                pass
        # No Python callback: the click is handled purely by the JS snippet.
        login_button.click(None, js=open_login_js)
    return page
|
|
|
|
|
def get_app(demo, app_kwargs=None, **login_kwargs):
    """Mount the Google login page and ``demo`` on one FastAPI app.

    :param demo: the main Gradio app. It is mounted at ``/<name_app>`` with
        ``get_user`` (from ``setup_app``) as its auth dependency.
    :param app_kwargs: extra kwargs forwarded to ``gr.mount_gradio_app`` for
        ``demo``; None (the default) means an empty dict.
    :param login_kwargs: forwarded to ``login_gradio``. ``page_title`` is
        required and also determines ``name_app``.
    :return: the FastAPI app with both Gradio apps mounted.
    """
    # Fresh dict per call: a shared ``{}`` default would carry mutations across calls.
    if app_kwargs is None:
        app_kwargs = {}
    name_login = 'google_login'
    # URL segment for the main app: sanitized page title, slashes removed, lowercased.
    name_app = sanitize_filename(login_kwargs['page_title']).replace('/', '').lower()
    app, get_user = setup_app(name_login=name_login,
                              name_app=name_app,
                              verbose=False,
                              )
    import gradio as gr

    login_app = gr.mount_gradio_app(app, login_gradio(**login_kwargs), f"/{name_login}")
    main_app = gr.mount_gradio_app(login_app, demo, path=f"/{name_app}",
                                   auth_dependency=get_user,
                                   app_kwargs=app_kwargs)
    return main_app
|
|