|
|
|
|
|
from fastapi import FastAPI, Form, Depends, HTTPException, status |
|
from fastapi.requests import Request |
|
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse |
|
from fastapi.templating import Jinja2Templates |
|
from sqlalchemy.orm import Session |
|
from auth import verify_token, oauth2_scheme, auth_views, register, UserCreate, TokenData, authenticate_user, get_user_by_verification_token, resetpassword |
|
from database import get_db, get_user_by_email |
|
from datetime import timedelta |
|
from typing import Optional |
|
import httpx |
|
|
|
|
|
import os |
|
|
|
|
|
from authlib.integrations.starlette_client import OAuth |
|
from starlette.middleware.sessions import SessionMiddleware |
|
|
|
import secrets
import warnings

# Required at import time: a missing variable raises KeyError here.
my_secret_key = os.environ['my_secret_key']

# Signing key for SessionMiddleware. The session carries OAuth state and the
# "user_info" that /registration_successful turns into an access token, so the
# key must never fall back to a fixed, publicly known value. Without the
# "SecretKey" variable a random per-process key is used: safe, but sessions are
# lost on restart and are not shared between worker processes, so set the
# variable in any real deployment.
SECRET_KEY = os.getenv('SecretKey')
if not SECRET_KEY:
    warnings.warn(
        "SecretKey is not set; using a random per-process session signing key.",
        RuntimeWarning,
    )
    SECRET_KEY = secrets.token_urlsafe(32)

from fastapi.staticfiles import StaticFiles

from authlib.integrations.starlette_client import OAuth

app = FastAPI()

# authlib's Starlette ``OAuth`` takes an optional *config* object that it
# queries with ``config.get(KEY, default=None)``; it is not bound to the
# application. Passing ``app`` would send those lookups to FastAPI's
# ``app.get`` route decorator, so no config object is given here.
oauth = OAuth()

app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

# Google OpenID Connect client used by /login/oauth and /auth/callback.
oauth.register(
    name='google',
    client_id=os.environ['GOOGLE_CLIENT_ID'],
    client_secret=os.environ['GOOGLE_CLIENT_SECRET'],
    access_token_url='https://accounts.google.com/o/oauth2/token',
    authorize_url='https://accounts.google.com/o/oauth2/auth',
    authorize_params=None,
    api_base_url='https://www.googleapis.com/oauth2/v1/',
    # Google's discovery document supplies ``jwks_uri``, which authlib needs
    # to verify the signature of the ID token requested by the "openid" scope.
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)
|
@app.get("/login/oauth")
async def login_oauth(request: Request):
    """Start the Google OAuth flow by redirecting to Google's consent page.

    The callback URL is derived from the ``auth_callback`` route name.
    """
    # Newer Starlette versions return a ``URL`` object from ``url_for``; authlib
    # keeps the redirect URI in the JSON-serialized session, so pass a plain str.
    redirect_uri = str(request.url_for('auth_callback'))
    return await oauth.google.authorize_redirect(request, redirect_uri)
|
|
|
@app.get("/auth/callback")
async def auth_callback(request: Request, db: Session = Depends(get_db)):
    """Finish the Google OAuth flow and log the user in.

    Exchanges the authorization code for tokens and reads the ID-token claims,
    storing them in the session. On first sign-in it creates a verified local
    user. It then redirects to the protected page with an ``access_token``
    cookie holding ``"Bearer <jwt>"``.

    Raises:
        HTTPException(400): Google reported an error (e.g. the user denied
            consent) or the ID token carries no email address.
    """
    # Local import: only this handler needs it, and authlib is already a
    # dependency of this module.
    from authlib.integrations.starlette_client import OAuthError

    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as exc:
        raise HTTPException(status_code=400, detail="Google sign-in failed") from exc

    # Recent authlib releases already parse the ID token into token["userinfo"];
    # older ones need the explicit parse_id_token(request, token) call.
    user_info = token.get('userinfo')
    if not user_info:
        user_info = await oauth.google.parse_id_token(request, token)

    email = user_info.get('email')
    if not email:
        raise HTTPException(status_code=400, detail="Google account has no email address")

    request.session["user_info"] = user_info

    # NOTE(review): ``User`` does not appear in this module's visible imports;
    # confirm where it is brought into scope.
    db_user = db.query(User).filter(User.email == email).first()
    if not db_user:
        # The "name" claim may be absent, so read it leniently.
        db_user = User(email=email, username=user_info.get('name'), is_verified=True)
        db.add(db_user)
        db.commit()
        db.refresh(db_user)

    access_token = auth_views.create_access_token(
        data={"sub": db_user.email},
        expires_delta=timedelta(minutes=auth_views.ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    url = app.url_path_for("get_protected")
    response = RedirectResponse(url)
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
|
|
|
|
|
# Jinja2 environment used by every HTML route in this module.
templates = Jinja2Templates(directory="templates")

# Serve the ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
""" |
|
def create_assessment( |
|
project_id: str, recaptcha_key: str, token: str, recaptcha_action: str |
|
) -> Assessment: |
|
|
|
|
|
client = recaptchaenterprise_v1.RecaptchaEnterpriseServiceClient() |
|
|
|
# Set the properties of the event to be tracked. |
|
event = recaptchaenterprise_v1.Event() |
|
event.site_key = recaptcha_key |
|
event.token = token |
|
|
|
assessment = recaptchaenterprise_v1.Assessment() |
|
assessment.event = event |
|
|
|
project_name = f"projects/{project_id}" |
|
|
|
# Build the assessment request. |
|
request = recaptchaenterprise_v1.CreateAssessmentRequest() |
|
request.assessment = assessment |
|
request.parent = project_name |
|
|
|
response = client.create_assessment(request) |
|
|
|
# Check if the token is valid. |
|
if not response.token_properties.valid: |
|
print( |
|
"The CreateAssessment call failed because the token was " |
|
+ "invalid for the following reasons: " |
|
+ str(response.token_properties.invalid_reason) |
|
) |
|
return |
|
|
|
# Check if the expected action was executed. |
|
if response.token_properties.action != recaptcha_action: |
|
print(
"The action attribute in your reCAPTCHA tag does "
+ "not match the action you are expecting to score"
)
|
return |
|
else: |
|
# Get the risk score and the reason(s). |
|
# For more information on interpreting the assessment, see: |
|
# https://cloud.google.com/recaptcha-enterprise/docs/interpret-assessment |
|
for reason in response.risk_analysis.reasons: |
|
print(reason) |
|
print( |
|
"The reCAPTCHA score for this token is: " |
|
+ str(response.risk_analysis.score) |
|
) |
|
# Get the assessment name (ID). Use this to annotate the assessment. |
|
assessment_name = client.parse_assessment_path(response.name).get("assessment") |
|
print(f"Assessment name: {assessment_name}") |
|
return response |
|
""" |
|
|
|
@app.post("/verify-google-token")
async def verify_google_token(token_data: TokenData, db: Session = Depends(get_db)):
    """Log a user in with a Google ID token obtained by the front end.

    The token is checked with Google's ``tokeninfo`` endpoint. It must have
    been issued for this app's ``GOOGLE_CLIENT_ID`` and carry a verified email
    address. On first sign-in a verified local user is created; an existing
    unverified one is marked verified. The response is a 303 redirect to the
    protected page, with the access token both in the query string and in the
    ``access_token`` cookie.

    Raises:
        HTTPException(400): Google rejects the token, its audience is another
            client, or it has no verified email address.
    """
    # Async client so the lookup does not block the event loop; ``params``
    # URL-encodes the token instead of pasting it into the URL.
    async with httpx.AsyncClient() as client:
        google_response = await client.get(
            'https://www.googleapis.com/oauth2/v3/tokeninfo',
            params={'id_token': token_data.token},
        )
    if google_response.status_code != 200:
        raise HTTPException(status_code=400, detail="Invalid Google token")

    google_user_info = google_response.json()

    # tokeninfo only proves that Google issued the token. The token must also
    # be meant for this app; otherwise an ID token minted for any other site
    # could be replayed here to log in as its owner.
    if google_user_info.get('aud') != os.environ['GOOGLE_CLIENT_ID']:
        raise HTTPException(status_code=400, detail="Invalid Google token")

    email = google_user_info.get('email')
    # tokeninfo reports email_verified as the string "true" or "false".
    if not email or str(google_user_info.get('email_verified')).lower() != 'true':
        raise HTTPException(status_code=400, detail="Google account email is not verified")

    # NOTE(review): ``User`` does not appear in this module's visible imports;
    # confirm where it is brought into scope.
    db_user = db.query(User).filter(User.email == email).first()
    if not db_user:
        db_user = User(email=email, is_verified=True, username=google_user_info.get('name'))
        db.add(db_user)
        db.commit()
        db.refresh(db_user)
    elif not db_user.is_verified:
        # NOTE(review): this also verifies an account that may have been
        # registered with a password by someone else before the real owner
        # signed in with Google; consider invalidating that password here.
        db_user.is_verified = True
        db.commit()

    access_token = auth_views.create_access_token(
        data={"sub": db_user.email},
        expires_delta=timedelta(minutes=auth_views.ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    url = app.url_path_for("get_protected")
    response = RedirectResponse(f"{url}?token={access_token}", status_code=status.HTTP_303_SEE_OTHER)
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
|
|
|
def get_current_user(token: str = Depends(verify_token)):
    """FastAPI dependency that passes through the value ``verify_token`` produced.

    Raises:
        HTTPException(401): ``verify_token`` produced a falsy value.
    """
    if token:
        return token
    raise HTTPException(status_code=401, detail="Token not valid")
|
|
|
@app.get("/", response_class=HTMLResponse)
async def landing(request: Request):
    """Render the public landing page (``landing.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("landing.html", context)
|
|
|
|
|
|
|
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_400_BAD_REQUEST |
|
from jwt import ExpiredSignatureError, InvalidTokenError |
|
|
|
@app.get("/login", response_class=HTMLResponse)
async def login(request: Request, db: Session = Depends(get_db)):
    """Show the login form, skipping it for a user who is already logged in.

    If the ``access_token`` cookie (stored as ``"Bearer <jwt>"``) holds a token
    that ``verify_token`` accepts for an existing, verified user, a fresh token
    is put in the cookie and the browser is redirected to the protected page.

    Any problem with the cookie is treated as "not logged in": wrong format,
    an expired or otherwise invalid token, or an unknown or unverified user.
    In that case the stale cookie is cleared and the login form is rendered.
    Raising an error here instead would leave a user with an expired cookie
    unable to reach the page where they log in again.
    """
    cookie = request.cookies.get("access_token")
    db_user = None

    if cookie:
        scheme, _, raw_token = cookie.partition(" ")
        user_email = None
        if scheme == "Bearer" and raw_token:
            try:
                user_email = verify_token(raw_token)
            except (InvalidTokenError, HTTPException):
                # PyJWT's InvalidTokenError is also the base class of
                # ExpiredSignatureError.
                user_email = None
        if user_email:
            # Same lookup get_protected performs for the token's subject.
            db_user = get_user_by_email(db, user_email)

    if db_user is not None and db_user.is_verified:
        new_access_token = auth_views.create_access_token(
            data={"sub": db_user.email},
            expires_delta=timedelta(minutes=auth_views.ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        url = app.url_path_for("get_protected")
        response = RedirectResponse(url)
        response.set_cookie(key="access_token", value=f"Bearer {new_access_token}", httponly=True)
        return response

    response = templates.TemplateResponse("login.html", {"request": request})
    if cookie:
        # Drop the unusable cookie so it is not re-checked on every visit.
        response.delete_cookie("access_token")
    return response
|
|
|
|
|
|
|
@app.post("/login")
async def login_post(
    request: Request,
    email: str = Form(...),
    password: str = Form(...),
    recaptcha_token: str = Form(...),
    db: Session = Depends(get_db),
):
    """Authenticate an email/password login protected by reCAPTCHA.

    On success, issues an access token and redirects (303) to the protected
    page, with the token both in the query string and in the ``access_token``
    cookie. Bad credentials re-render ``login.html`` with an error message.

    Raises:
        HTTPException(400): reCAPTCHA verification failed, a field is empty,
            or the account's email has not been verified.
    """
    # NOTE(review): the fallback secret is committed to source control and
    # should be rotated; provide RECAPTCHA_SECRET_KEY in the environment and
    # remove the fallback once every deployment sets it.
    recaptcha_secret = os.getenv('RECAPTCHA_SECRET_KEY', '6LeSJgwpAAAAAJrLrvlQYhRsOjf2wKXee_Jc4Z-k')
    recaptcha_url = 'https://www.google.com/recaptcha/api/siteverify'
    recaptcha_data = {
        'secret': recaptcha_secret,
        'response': recaptcha_token,
    }

    async with httpx.AsyncClient() as client:
        recaptcha_response = await client.post(recaptcha_url, data=recaptcha_data)

    # The verification payload is not printed: stdout is not a log sink.
    recaptcha_result = recaptcha_response.json()
    if not recaptcha_result.get('success', False):
        raise HTTPException(status_code=400, detail="reCAPTCHA validation failed.")

    # Form(...) only guarantees presence; empty strings still get through.
    if not email or not password:
        raise HTTPException(status_code=400, detail="Invalid email or password")

    user = authenticate_user(db, email, password)
    if user and user.is_verified:
        access_token = auth_views.create_access_token(
            data={"sub": user.email},
            expires_delta=timedelta(minutes=auth_views.ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        url = app.url_path_for("get_protected")
        response = RedirectResponse(f"{url}?token={access_token}", status_code=status.HTTP_303_SEE_OTHER)
        response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
        return response

    if user and not user.is_verified:
        raise HTTPException(
            status_code=400,
            detail="You must verify your email before accessing this resource.",
        )

    return templates.TemplateResponse(
        "login.html",
        {"request": request, "error_message": "Invalid email or password"},
    )
|
@app.get("/register", response_class=HTMLResponse)
async def register_get(request: Request):
    """Render the empty registration form (``register.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("register.html", context)
|
|
|
|
|
@app.post("/register", response_class=HTMLResponse)
async def register_post(
    request: Request,
    username: str = Form(...),
    email: str = Form(...),
    password: str = Form(...),
    confirm_password: str = Form(...),
    recaptcha_token: str = Form(...),
    db: Session = Depends(get_db),
):
    """Create an account from the registration form.

    The reCAPTCHA token is verified first (HTTP 400 on failure). The form is
    re-rendered with an ``error_message`` when the passwords do not match,
    ``register`` raises an ``HTTPException``, or ``UserCreate`` rejects a
    field. On success, the new user's username and email are stored in the
    session and the browser is sent to ``/registration_successful``.
    """
    # NOTE(review): the fallback secret is committed to source control and
    # should be rotated; provide RECAPTCHA_SECRET_KEY in the environment and
    # remove the fallback once every deployment sets it.
    recaptcha_secret = os.getenv('RECAPTCHA_SECRET_KEY', '6LeSJgwpAAAAAJrLrvlQYhRsOjf2wKXee_Jc4Z-k')
    recaptcha_url = 'https://www.google.com/recaptcha/api/siteverify'
    recaptcha_data = {
        'secret': recaptcha_secret,
        'response': recaptcha_token,
    }

    async with httpx.AsyncClient() as client:
        recaptcha_response = await client.post(recaptcha_url, data=recaptcha_data)

    # The verification payload is not printed: stdout is not a log sink.
    recaptcha_result = recaptcha_response.json()
    if not recaptcha_result.get('success', False):
        raise HTTPException(status_code=400, detail="reCAPTCHA validation failed.")

    if password != confirm_password:
        return templates.TemplateResponse("register.html", {
            "request": request,
            "error_message": "Passwords do not match.",
        })

    try:
        user = UserCreate(username=username, email=email, password=password, confirm_password=confirm_password)
        registered_user = register(user, db)
    except HTTPException as e:
        return templates.TemplateResponse("register.html", {
            "request": request,
            "error_message": e.detail,
        })
    except ValueError:
        # pydantic's ValidationError subclasses ValueError. A malformed field
        # re-renders the form instead of surfacing as a 500.
        return templates.TemplateResponse("register.html", {
            "request": request,
            "error_message": "Please check the registration details and try again.",
        })

    request.session["user_info"] = {"username": registered_user.username, "email": registered_user.email}

    return RedirectResponse("/registration_successful", status_code=status.HTTP_302_FOUND)
|
|
|
|
|
@app.get("/registration_successful", response_class=HTMLResponse)
async def registration_successful(request: Request, db: Session = Depends(get_db)):
    """Log in the user recorded in the session and send them to ``/protected``.

    Reads ``user_info["email"]`` from the session and looks that user up. It
    then issues an access token in the ``access_token`` cookie
    (``"Bearer <jwt>"``) and redirects to ``/protected``.

    NOTE(review): if newly registered accounts stay unverified until /verify
    is visited, /protected will reject this redirect; confirm the intended
    landing page.

    Raises:
        HTTPException(401): the session holds no ``user_info``.
        HTTPException(404): no user matches the session's email.
    """
    session_user = request.session.get("user_info")
    if not session_user:
        raise HTTPException(status_code=401, detail="User not authenticated")

    db_user = db.query(User).filter(User.email == session_user["email"]).first()
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")

    lifetime = timedelta(minutes=auth_views.ACCESS_TOKEN_EXPIRE_MINUTES)
    token = auth_views.create_access_token(data={"sub": db_user.email}, expires_delta=lifetime)

    redirect = RedirectResponse(url="/protected")
    redirect.set_cookie(key="access_token", value=f"Bearer {token}", httponly=True)
    return redirect
|
|
|
|
|
@app.get("/verify", response_class=HTMLResponse)
async def verify_email(token: str, db: Session = Depends(get_db)):
    """Confirm a user's email address from a verification link.

    Marks the user owning ``token`` as verified and clears the stored token.
    It then redirects to ``/protected`` with a fresh access token in the
    query string.

    Raises:
        HTTPException(400): the token matches no user, or the user is
            already verified.
    """
    user = get_user_by_verification_token(db, token)
    if not user:
        raise HTTPException(status_code=400, detail="Invalid verification token")
    if user.is_verified:
        raise HTTPException(status_code=400, detail="Email already verified")

    # Mark verified and clear the token so the link cannot be reused.
    user.is_verified, user.email_verification_token = True, None
    db.commit()

    lifetime = timedelta(minutes=auth_views.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = auth_views.create_access_token(data={"sub": user.email}, expires_delta=lifetime)
    return RedirectResponse(url=f"/protected?token={access_token}")
|
|
|
|
|
|
|
@app.get("/protected", response_class=HTMLResponse)
async def get_protected(
    request: Request,
    db: Session = Depends(get_db),
    token: Optional[str] = None,
):
    """Render the protected page for an authenticated, verified user.

    The JWT comes from the ``token`` query parameter when given, otherwise
    from the ``access_token`` cookie. The login handlers store that cookie as
    ``"Bearer <jwt>"``, so the prefix is stripped before verification.

    Raises:
        HTTPException(401): no token was supplied, or it does not resolve to
            an existing, verified user.
    """
    raw_token = token or request.cookies.get("access_token") or ""
    # /login calls verify_token with the bare JWT, so the cookie's "Bearer "
    # prefix must be removed; without this, redirects that rely on the
    # cookie alone could never authenticate.
    if raw_token.startswith("Bearer "):
        raw_token = raw_token[len("Bearer "):]
    if not raw_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")

    user_email = verify_token(raw_token)

    db_user = get_user_by_email(db, user_email)
    if db_user is None or not db_user.is_verified:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or not verified in the database")

    return templates.TemplateResponse("protected.html", {"request": request, "user": db_user.username})
|
|
|
@app.get("/reset-password")
async def reset_password_get(request: Request, token: str):
    """Render the choose-a-new-password form, embedding the reset ``token``."""
    context = {"request": request, "token": token}
    return templates.TemplateResponse("reset-password.html", context)
|
|
|
@app.post("/password-reset-request")
async def password_reset_request(email: str = Form(...), db: Session = Depends(get_db)):
    """Trigger ``resetpassword`` for the user with ``email``, if one exists.

    The reply is identical whether or not the address is registered, so the
    endpoint does not reveal which emails have accounts.
    """
    account = get_user_by_email(db, email)
    if account:
        resetpassword(account, db)

    reply = {"message": "Password reset link sent if the email is registered with us."}
    return reply
|
@app.get("/password-reset-request", response_class=HTMLResponse)
async def password_reset_form(request: Request):
    """Render the form where a user enters their email to request a reset."""
    context = {"request": request}
    return templates.TemplateResponse("password_reset_request.html", context)
|
|
|
from fastapi import Form |
|
|
|
@app.post("/reset-password")
async def reset_password(token: str = Form(...), new_password: str = Form(...), db: Session = Depends(get_db)):
    """Set a new password for the user identified by a reset token.

    The token is looked up with ``get_user_by_verification_token`` and cleared
    after use, so each reset link works once.

    Raises:
        HTTPException(400): the token matches no user, or the new password is
            empty or whitespace only.
    """
    user = get_user_by_verification_token(db, token)
    if not user:
        raise HTTPException(status_code=400, detail="Invalid or expired token")

    # Form(...) only requires the field to be present. Without this check,
    # an empty or whitespace-only value would become the account's password.
    if not new_password.strip():
        raise HTTPException(status_code=400, detail="New password must not be empty")

    user.hashed_password = auth_views.pwd_context.hash(new_password)
    user.email_verification_token = None
    db.commit()

    return {"message": "Password successfully reset."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|