|
|
|
from fastapi import FastAPI, Depends, HTTPException, Request, Form, status, Header |
|
from fastapi.responses import RedirectResponse, HTMLResponse |
|
from fastapi.responses import JSONResponse |
|
from fastapi.templating import Jinja2Templates |
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
|
from pydantic import BaseModel |
|
from sqlalchemy.orm import Session |
|
from database import get_db, get_user_by_email |
|
from models import User |
|
from passlib.context import CryptContext |
|
from datetime import datetime, timedelta |
|
import jwt |
|
from emailx import send_verification_email, generate_verification_token |
|
|
|
from fastapi.staticfiles import StaticFiles |
|
|
|
from typing import Optional |
|
import httpx |
|
import os |
|
from starlette.middleware.sessions import SessionMiddleware |
|
from authlib.integrations.starlette_client import OAuth |
|
|
|
|
|
|
|
|
|
# --- Runtime configuration ---------------------------------------------------
# Google OAuth credentials; each value is None when its variable is unset.
GOOGLE_CLIENT_ID = os.environ.get('GOOGLE_CLIENT_ID')
GOOGLE_CLIENT_SECRET = os.environ.get('GOOGLE_CLIENT_SECRET')

# Key used to sign JWTs and to sign the session cookie (see the middleware
# registration below). Read from the 'SecretKey' environment variable.
# NOTE(review): the fallback is a fixed, publicly visible string -- confirm the
# variable is always set wherever this service is deployed.
SECRET_KEY = os.environ.get('SecretKey', 'default_secret')
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# --- Application objects -----------------------------------------------------
app = FastAPI()
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

# Registry of OAuth clients; providers are registered on it further down.
oauth = OAuth()

# Password hashing context (bcrypt only).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Bearer-token dependency; "token" is the URL advertised as the token endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
class TokenData(BaseModel):
    """Pydantic model carrying a single token string."""

    # Required; the model has no other fields.
    token: str
|
|
|
class UserCreate(BaseModel):
    """Registration payload: the fields needed to create a new account."""

    # Display / login name chosen by the user.
    username: str
    # E-mail address of the account.
    email: str
    # Plain-text password as submitted; it is hashed before being stored.
    password: str
|
|
|
|
|
# Register Google as an OAuth / OpenID Connect provider.
# The credentials are read with os.environ[...], so a missing variable raises
# KeyError at import time instead of failing later during a login attempt.
oauth.register(
    name='google',
    client_id=os.environ['GOOGLE_CLIENT_ID'],
    client_secret=os.environ['GOOGLE_CLIENT_SECRET'],
    # OpenID Connect discovery document. It supplies `jwks_uri`, which the
    # client needs to verify the signature of the ID token requested through
    # the `openid` scope; without it ID-token parsing cannot succeed.
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    access_token_url='https://accounts.google.com/o/oauth2/token',
    authorize_url='https://accounts.google.com/o/oauth2/auth',
    authorize_params=None,
    api_base_url='https://www.googleapis.com/oauth2/v1/',
    client_kwargs={'scope': 'openid email profile'},
)

# Serve files from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Jinja2 templates are loaded from ./templates.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
@app.get("/login/oauth")
async def login_oauth(request: Request):
    """Start the Google OAuth flow.

    Builds the absolute URL of the route named ``auth_callback`` and asks the
    registered ``google`` client to redirect the browser, using that URL as
    the redirect URI.
    """
    callback_url = request.url_for('auth_callback')
    google_client = oauth.google
    return await google_client.authorize_redirect(request, callback_url)
|
|
|
@app.get("/auth/callback")
async def auth_callback(request: Request, db: Session = Depends(get_db)):
    """Complete the Google OAuth flow and sign the user in.

    Exchanges the authorization response for a token, reads the user's
    identity from the ID token, creates a local (already verified) account on
    first login, and redirects to ``/protected`` with the access token stored
    in an HTTP-only cookie.

    Raises:
        HTTPException: 401 when the OAuth exchange fails, 400 when the
            identity returned by Google carries no e-mail address.
    """
    # Imported locally so this handler does not depend on the module's import
    # block being changed.
    from authlib.integrations.starlette_client import OAuthError

    try:
        token = await oauth.google.authorize_access_token(request)
        user_info = await oauth.google.parse_id_token(request, token)
    except OAuthError as exc:
        # A denied consent screen or a state mismatch should produce a clean
        # 401 rather than an unhandled exception (HTTP 500).
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Google authentication failed",
        ) from exc

    email = user_info.get('email') if user_info else None
    if not email:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Google account did not provide an email address",
        )

    request.session["user_info"] = user_info

    db_user = db.query(User).filter(User.email == email).first()
    if not db_user:
        # `.get` avoids a KeyError for Google profiles without a display name.
        db_user = User(email=email, username=user_info.get('name'), is_verified=True)
        db.add(db_user)
        db.commit()
        db.refresh(db_user)

    # Use the shared lifetime constant instead of a second hard-coded "30".
    access_token = create_access_token(
        data={"sub": db_user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    response = RedirectResponse(url="/protected")
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
|
|
|
@app.post("/login")
async def login(
    request: Request,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
    recaptcha_token: str = Form(...),
):
    """Handle the login form: reCAPTCHA check, credential check, cookie issue.

    Args:
        request: Incoming request; needed to re-render ``login.html`` on a
            failed attempt (it was referenced but never declared before).
        form_data: Submitted ``username`` (the e-mail address) and ``password``.
        db: Database session.
        recaptcha_token: reCAPTCHA response token posted with the form.

    Returns:
        A 303 redirect to the protected page with the access token stored in
        an HTTP-only cookie, or the login page with an error message when the
        credentials are wrong.

    Raises:
        HTTPException: 400 when reCAPTCHA validation fails, when a field is
            empty, or when the account's e-mail is not yet verified.
    """
    # Reuse the shared helper rather than duplicating the HTTP call and secret.
    if not await verify_recaptcha(recaptcha_token):
        raise HTTPException(status_code=400, detail="reCAPTCHA validation failed.")

    if not form_data.username or not form_data.password:
        raise HTTPException(status_code=400, detail="Invalid email or password")

    try:
        user = authenticate_user(db, form_data.username, form_data.password)
    except HTTPException:
        # The `authenticate_user` bound at call time raises on bad credentials
        # instead of returning a falsy value; treat both the same way so the
        # user gets the login page back with an error.
        user = None

    if not user:
        return templates.TemplateResponse(
            "login.html",
            {
                "request": request,
                "error_message": "Invalid email or password",
                "google_oauth_url": request.url_for("login_oauth"),
            },
        )

    if not user.is_verified:
        raise HTTPException(
            status_code=400,
            detail="You must verify your email before accessing this resource.",
        )

    access_token = create_access_token(
        data={"sub": user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    # The token travels only in the cookie. Appending it to the redirect URL
    # would expose it in browser history, logs and Referer headers, and the
    # protected route reads the header or cookie, not the query string.
    url = app.url_path_for("get_protected")
    response = RedirectResponse(url, status_code=status.HTTP_303_SEE_OTHER)
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
|
|
|
@app.get("/login", response_class=HTMLResponse)
async def login(request: Request, db: Session = Depends(get_db)):
    """Show the login page, or skip it when a valid session cookie exists.

    If the ``access_token`` cookie holds a valid token for a verified user, a
    fresh token is issued and the browser is redirected to the protected page.
    In every other case (no cookie, expired or invalid token, unknown or
    unverified user) the login form is rendered, and a cookie that failed the
    check is removed so it does not block future visits.
    """
    cookie_value = request.cookies.get("access_token")
    discard_cookie = False

    if cookie_value:
        try:
            # verify_token strips the "Bearer " prefix itself and raises
            # HTTPException for expired or otherwise invalid tokens.
            user_email = verify_token(cookie_value)
        except HTTPException:
            user_email = None

        db_user = None
        if user_email:
            db_user = db.query(User).filter(User.email == user_email).first()

        if db_user is not None and db_user.is_verified:
            new_access_token = create_access_token(
                data={"sub": db_user.email},
                expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            )
            response = RedirectResponse(app.url_path_for("get_protected"))
            response.set_cookie(
                key="access_token",
                value=f"Bearer {new_access_token}",
                httponly=True,
            )
            return response

        # The cookie did not identify a usable account; fall through to the
        # form instead of returning an error from the login page itself.
        discard_cookie = True

    google_oauth_url = request.url_for("login_oauth")
    page = templates.TemplateResponse(
        "login.html",
        {"request": request, "google_oauth_url": google_oauth_url},
    )
    if discard_cookie:
        page.delete_cookie("access_token")
    return page
|
|
|
@app.get("/register/google")
async def register_google(request: Request):
    """Start Google sign-up by redirecting through the ``google`` OAuth client.

    The redirect URI is the absolute URL of the route named ``auth_callback``.
    """
    callback_url = request.url_for('auth_callback')
    google_client = oauth.google
    return await google_client.authorize_redirect(request, callback_url)
|
|
|
@app.get("/auth/callback")
async def auth_callback(request: Request, db: Session = Depends(get_db)):
    """Finish Google sign-up: ensure a local account exists, then redirect.

    Looks the user up by the e-mail address in the ID token, creates an
    already-verified account when none exists, stores the account's username
    and e-mail in the session, and redirects to ``/registration_successful``.

    NOTE(review): an earlier handler in this file is registered for the same
    path and function name; confirm which of the two is meant to serve
    ``/auth/callback``.
    """
    token = await oauth.google.authorize_access_token(request)
    user_info = await oauth.google.parse_id_token(request, token)

    user = db.query(User).filter(User.email == user_info['email']).first()
    if user is None:
        user = User(
            email=user_info['email'],
            username=user_info.get('name'),
            is_verified=True,
        )
        db.add(user)
        db.commit()
        db.refresh(user)

    # Populate the session from whichever account was found or created. The
    # previous code read a variable that only existed on the "created" path
    # and therefore failed for users who already had an account.
    request.session["user_info"] = {"username": user.username, "email": user.email}

    return RedirectResponse(url="/registration_successful")
|
|
|
# GET is accepted in addition to POST because the OAuth callback reaches this
# route through a browser redirect, which arrives as a GET request.
@app.api_route("/registration_successful", methods=["GET", "POST"], response_class=HTMLResponse)
async def registration_successful(request: Request, db: Session = Depends(get_db)):
    """Issue an access-token cookie for the user stored in the session.

    Reads ``user_info`` from the session, loads the matching account by
    e-mail, creates an access token for it and redirects to ``/login`` with
    the token in an HTTP-only cookie.

    Raises:
        HTTPException: 401 when the session holds no ``user_info``; 404 when
            no account matches the stored e-mail address.
    """
    user_info = request.session.get("user_info")
    if not user_info:
        raise HTTPException(status_code=401, detail="User not authenticated")

    email = user_info["email"]
    db_user = db.query(User).filter(User.email == email).first()
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")

    # Use the module-level lifetime constant (the former `auth_views.` prefix
    # referred to a name that is not defined in this file).
    access_token = create_access_token(
        data={"sub": db_user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    # 303 makes the browser follow with GET even when this request was a POST,
    # so the login page is shown instead of an empty form re-submission.
    response = RedirectResponse(url="/login", status_code=status.HTTP_303_SEE_OTHER)
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
|
|
|
async def verify_recaptcha(recaptcha_token: str) -> bool:
    """Check a reCAPTCHA response token against Google's siteverify endpoint.

    Args:
        recaptcha_token: The token produced by the reCAPTCHA widget.

    Returns:
        True only when the verification service answers with a truthy
        ``success`` field. An empty token, a network failure or an unreadable
        response all count as a failed verification.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not recaptcha_token:
        return False

    # SECURITY: the secret should come from the environment. The literal is
    # kept only as a fallback so existing deployments keep working; it is
    # exposed in source control and should be rotated and removed.
    recaptcha_secret = os.getenv(
        'RECAPTCHA_SECRET',
        '6LeSJgwpAAAAAJrLrvlQYhRsOjf2wKXee_Jc4Z-k',
    )
    recaptcha_url = 'https://www.google.com/recaptcha/api/siteverify'
    recaptcha_data = {
        'secret': recaptcha_secret,
        'response': recaptcha_token,
    }

    try:
        async with httpx.AsyncClient() as client:
            recaptcha_response = await client.post(recaptcha_url, data=recaptcha_data)
        recaptcha_result = recaptcha_response.json()
    except (httpx.HTTPError, ValueError):
        # Transport errors and non-JSON bodies must not surface as HTTP 500.
        logger.exception("reCAPTCHA verification request failed")
        return False

    logger.debug("reCAPTCHA verification result: %s", recaptcha_result)

    if not isinstance(recaptcha_result, dict):
        return False
    return bool(recaptcha_result.get('success', False))
|
|
|
@app.get("/verify", response_class=HTMLResponse)
async def verify_email(token: str, db: Session = Depends(get_db)):
    """Confirm an e-mail address from a verification link and sign the user in.

    Marks the account owning ``token`` as verified, clears the stored token,
    and redirects to ``/protected`` with a new access token in a cookie.

    Raises:
        HTTPException: 400 when the token matches no account or the account
            is already verified.
    """
    account = get_user_by_verification_token(db, token)
    if not account:
        raise HTTPException(status_code=400, detail="Invalid verification token")
    if account.is_verified:
        raise HTTPException(status_code=400, detail="Email already verified")

    # Flip the flag and drop the stored verification token.
    account.is_verified = True
    account.email_verification_token = None
    db.commit()

    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    jwt_value = create_access_token(data={"sub": account.email}, expires_delta=lifetime)

    redirect = RedirectResponse(url="/protected")
    redirect.set_cookie(
        key="access_token",
        value=f"Bearer {jwt_value}",
        httponly=True,
        secure=True,
        samesite='Lax',
    )
    return redirect
|
|
|
def is_username_available(db: Session, username: str) -> bool:
    """Return True when no stored user has exactly this username."""
    existing = db.query(User).filter(User.username == username).first()
    return existing is None
|
|
|
@app.get("/register", response_class=HTMLResponse)
async def register_get(request: Request):
    """Render the registration page, passing it the Google OAuth start URL."""
    context = {
        "request": request,
        "google_oauth_url": request.url_for("login_oauth"),
    }
    return templates.TemplateResponse("register.html", context)
|
|
|
@app.post("/register")
async def register_post(
    request: Request,
    username: str = Form(...),
    email: str = Form(...),
    password: str = Form(...),
    confirm_password: str = Form(...),
    recaptcha_token: str = Form(...),
    db: Session = Depends(get_db)
):
    """Process the registration form.

    On success returns JSON with an access token and the URL of the protected
    page. A failed reCAPTCHA, mismatching passwords, or an HTTPException
    raised while registering re-render ``register.html`` with an error
    message.

    Raises:
        HTTPException: 400 when the username is already taken.
    """

    def render_error(message):
        # All form-level failures re-render the same template.
        return templates.TemplateResponse(
            "register.html",
            {"request": request, "error_message": message},
        )

    captcha_passed = await verify_recaptcha(recaptcha_token)
    if not captcha_passed:
        return render_error("reCAPTCHA validation failed.")

    if password != confirm_password:
        return render_error("Passwords do not match.")

    user_data = UserCreate(username=username, email=email, password=password)
    if not is_username_available(db, user_data.username):
        raise HTTPException(status_code=400, detail="Username already taken")

    try:
        registered_user = register_user(user_data, db)
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            data={"sub": registered_user.email},
            expires_delta=lifetime,
        )
        url = app.url_path_for("get_protected")
        payload = {
            "access_token": access_token,
            "redirect_url": url,
        }
        return JSONResponse(content=payload)
    except HTTPException as e:
        return render_error(e.detail)
|
|
|
@app.get("/", response_class=HTMLResponse)
async def landing(request: Request):
    """Render the landing page template."""
    context = {"request": request}
    return templates.TemplateResponse("landing.html", context)
|
|
|
def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against ``hashed_password`` via the password context."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
|
|
def get_password_hash(password):
    """Hash ``password`` with the module's password context and return the hash."""
    digest = pwd_context.hash(password)
    return digest
|
|
|
def authenticate_user(db: Session, email: str, password: str):
    """Look a user up by e-mail address and check the password.

    Args:
        db: Database session.
        email: E-mail address identifying the account.
        password: Plain-text password to verify.

    Returns:
        The matching ``User`` when the credentials are valid, otherwise
        ``False``.

    NOTE(review): a second function with this name is defined later in the
    file and replaces this one at import time; confirm which is intended.
    """
    # Filter on the `email` parameter; the previous code referenced a name
    # (`username`) that does not exist in this function.
    user = db.query(User).filter(User.email == email).first()
    if not user or not verify_password(password, user.hashed_password):
        return False
    return user
|
|
|
|
|
def create_access_token(data: dict, expires_delta: timedelta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)):
    """Create a signed JWT containing ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the dict is copied, never modified.
        expires_delta: Lifetime of the token. ``None`` falls back to
            ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated and returns a
    # naive value that is easy to misinterpret as local time.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode a JWT and return the e-mail address stored in its ``sub`` claim.

    A leading ``"Bearer "`` prefix is accepted and stripped.

    Raises:
        HTTPException: 401 when the token is expired, cannot be decoded, or
            has no ``sub`` claim.
    """
    if token.startswith("Bearer "):
        pieces = token.split(" ")
        token = pieces[1]

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token has expired")
    except jwt.PyJWTError as e:
        print(f"JWT decoding error: {e}")
        raise HTTPException(status_code=401, detail="Could not validate credentials")
    else:
        subject = claims.get("sub")
        if subject is None:
            raise HTTPException(status_code=401, detail="Invalid authentication credentials")
        return subject
|
|
|
def validate_token(token: str):
    """Validate a JWT and return it wrapped in ``TokenData``.

    Args:
        token: The encoded JWT (without a ``"Bearer "`` prefix).

    Returns:
        ``TokenData`` holding the validated token.

    Raises:
        HTTPException: 401 when the token cannot be decoded, is expired, or
            carries no ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        # Base class of the decoding errors raised by the imported `jwt`
        # module; the former bare `JWTError` name was never defined here.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")

    if payload.get("sub") is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")

    # TokenData declares a single `token` field, so it is built from the
    # token itself; passing `username=` failed model validation.
    return TokenData(token=token)
|
|
|
@app.get("/token/validate")
async def token_validate(token: str = Depends(oauth2_scheme)):
    """Validate the bearer token supplied with the request and return the result."""
    result = validate_token(token)
    return result
|
|
|
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange form credentials for an access token.

    Delegates to ``login_for_access_token`` with the submitted username and
    password.

    NOTE(review): another POST handler for ``/login`` is registered earlier in
    this file; confirm which one is expected to serve the route.
    """
    submitted_name = form_data.username
    submitted_password = form_data.password
    return await login_for_access_token(submitted_name, submitted_password, db)
|
|
|
async def login_for_access_token(username: str, password: str, db: Session):
    """Authenticate a user and return a bearer-token response body.

    Args:
        username: Login identifier (the account's e-mail address).
        password: Plain-text password.
        db: Database session.

    Returns:
        ``{"access_token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when authentication fails.
    """
    # Use this function's own parameters; the previous code read a
    # `form_data` object that does not exist in this scope.
    user = authenticate_user(db, username, password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = create_access_token(data={"sub": user.email})
    return {"access_token": access_token, "token_type": "bearer"}
|
|
|
def authenticate_user(db: Session, username: str, password: str):
    """Return the user whose e-mail is ``username`` if ``password`` matches.

    Raises:
        HTTPException: 401 when no such user exists or the password does not
            verify against the stored hash.
    """
    account = get_user_by_email(db, username)
    if account and pwd_context.verify(password, account.hashed_password):
        return account
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password")
|
|
|
def register_user(user_data: UserCreate, db: Session):
    """Create a new account and e-mail its verification link.

    The verification e-mail is sent before the user row is written.

    Returns:
        The newly stored ``User``, refreshed from the database.

    Raises:
        HTTPException: 400 when the e-mail address is already registered.
    """
    already_registered = get_user_by_email(db, user_data.email)
    if already_registered:
        raise HTTPException(status_code=400, detail="Email already registered")

    password_hash = pwd_context.hash(user_data.password)
    token = generate_verification_token(user_data.email)

    # Link the recipient follows to confirm the address.
    verify_link = f"http://gregniuki-loginauth.hf.space/verify?token={token}"
    send_verification_email(user_data.email, verify_link)

    account = User(
        email=user_data.email,
        username=user_data.username,
        hashed_password=password_hash,
        email_verification_token=token,
    )
    db.add(account)
    db.commit()
    db.refresh(account)
    return account
|
|
|
@app.get("/protected", response_class=HTMLResponse)
async def get_protected(
    request: Request,
    db: Session = Depends(get_db),
    authorization: Optional[str] = Header(None),
    token: Optional[str] = None
):
    """Render the protected page for an authenticated, verified user.

    The token is taken from the ``Authorization: Bearer ...`` header when
    present, otherwise from the ``access_token`` cookie. The ``token`` query
    parameter is accepted for compatibility but is not used as a credential.

    Raises:
        HTTPException: 401 when no token is supplied, the scheme is not
            ``Bearer``, the token is invalid or expired, or the user does not
            exist or is not verified.
    """
    if authorization:
        scheme, _, token = authorization.partition(' ')
        if scheme.lower() != 'bearer':
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication scheme")
    else:
        token = request.cookies.get("access_token")

    if not token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")

    try:
        user_email = verify_token(token)
        db_user = get_user_by_email(db, user_email)
    except HTTPException:
        # Keep the status and detail chosen by verify_token instead of
        # replacing them with the exception's string form.
        raise
    except Exception as exc:
        # Unexpected failures still deny access, but without echoing internal
        # error text back to the client.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
        ) from exc

    if db_user is None or not db_user.is_verified:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or not verified in the database")

    # Rendering happens outside the try block so template errors are not
    # reported to the client as authentication failures.
    return templates.TemplateResponse("protected.html", {"request": request, "user": db_user.username})
|
|
|
def verify_email(verification_token: str, db: Session = Depends(get_db)):
    """Mark the account owning ``verification_token`` as verified.

    Returns:
        ``{"message": "Email verification successful"}`` after committing.

    Raises:
        HTTPException: 400 when the token matches no account or the account
            is already verified.
    """
    account = get_user_by_verification_token(db, verification_token)
    if not account:
        raise HTTPException(status_code=400, detail="Invalid verification token")
    if account.is_verified:
        raise HTTPException(status_code=400, detail="Email already verified")

    # Set the flag and clear the stored token, then persist.
    account.is_verified = True
    account.email_verification_token = None
    db.commit()

    outcome = {"message": "Email verification successful"}
    return outcome
|
|
|
def get_user_by_verification_token(db: Session, verification_token: str):
    """Return the first user whose stored verification token matches, else None."""
    query = db.query(User).filter(User.email_verification_token == verification_token)
    return query.first()
|
|
|
def reset_password(user: User, db: Session):
    """E-mail a password-reset link to ``user`` and store the token on the account.

    The link is sent first; the token is then saved in the user's
    ``email_verification_token`` field and committed.
    """
    token = generate_verification_token(user.email)
    link = f"http://gregniuki-loginauth.hf.space/reset-password?token={token}"
    send_verification_email(user.email, link)

    user.email_verification_token = token
    db.commit()
|
|
|
def get_current_user(token: str = Depends(verify_token)):
    """Dependency that yields whatever ``verify_token`` resolved.

    Despite the parameter name, the value is the e-mail address that
    ``verify_token`` returns from the token's ``sub`` claim.
    """
    resolved = token
    return resolved