|
|
|
import os
import secrets
import warnings
from datetime import datetime, timedelta, timezone
from typing import Optional

import httpx
import jwt
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import FastAPI, Depends, HTTPException, Request, Form, status
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.middleware.sessions import SessionMiddleware

from database import get_db, get_user_by_email
from models import User
from emailx import send_verification_email, generate_verification_token
|
|
|
|
|
|
|
|
|
GOOGLE_CLIENT_ID = os.getenv('GOOGLE_CLIENT_ID')
GOOGLE_CLIENT_SECRET = os.getenv('GOOGLE_CLIENT_SECRET')

# Key used to sign the JWTs and the session cookie. Falling back to a public
# literal would let anyone forge tokens, so when the ``SecretKey`` environment
# variable is unset (or empty) a random per-process key is generated instead.
# Tokens and sessions signed with it do not survive a restart and are not
# shared between worker processes -- set ``SecretKey`` in any real deployment.
SECRET_KEY = os.getenv('SecretKey')
if not SECRET_KEY:
    warnings.warn(
        "SecretKey environment variable is not set; using a random "
        "per-process signing key. Issued tokens will not survive a restart.",
        RuntimeWarning,
        stacklevel=1,
    )
    SECRET_KEY = secrets.token_urlsafe(32)

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
|
app = FastAPI()

# Signed session cookie; authlib's Starlette client requires a session to
# carry the OAuth state between the redirect and the callback.
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

# authlib's Starlette ``OAuth`` takes an optional *config* object (queried as
# ``config.get(KEY, default=None)``), not the application. Passing ``app``
# made authlib call FastAPI's ``get`` route decorator while resolving client
# settings. Credentials are given explicitly to ``oauth.register``, so no
# config object is needed.
oauth = OAuth()

# bcrypt hashing for local (email/password) accounts.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Extracts the bearer token from the ``Authorization`` header; ``tokenUrl``
# is what the OpenAPI docs advertise as the token endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
class TokenData(BaseModel):
    """Pydantic model wrapping a single encoded token string."""

    # The raw, still-encoded token value.
    token: str
|
|
|
class UserCreate(BaseModel):
    """Input payload for registering a local (email/password) account."""

    # Display name stored on the new User row.
    username: str
    # Login identifier; the verification email is sent here.
    email: str
    # Plain-text password; only its bcrypt hash is persisted.
    password: str
    # Second entry of the password, for confirmation.
    confirm_password: str
|
|
|
# Google OpenID Connect client.
# ``server_metadata_url`` lets authlib discover Google's ``jwks_uri``, which
# it needs to validate the ID token returned for the ``openid`` scope; the
# explicit endpoint URLs below are still used where given.
# ``os.environ[...]`` raises KeyError at import time if the Google
# credentials are not configured.
oauth.register(
    name='google',
    client_id=os.environ['GOOGLE_CLIENT_ID'],
    client_secret=os.environ['GOOGLE_CLIENT_SECRET'],
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    access_token_url='https://accounts.google.com/o/oauth2/token',
    authorize_url='https://accounts.google.com/o/oauth2/auth',
    authorize_params=None,
    api_base_url='https://www.googleapis.com/oauth2/v1/',
    client_kwargs={'scope': 'openid email profile'},
)
|
|
|
|
|
# HTML templates are loaded from ./templates; files under ./static are
# served at the /static URL prefix (route name "static").
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
@app.get("/login/oauth")
async def login_oauth(request: Request):
    """Start Google sign-in by redirecting the browser to Google's consent page.

    The callback URL is generated from the ``auth_callback`` route name.
    """
    # Newer Starlette returns a ``URL`` object from ``url_for``; authlib keeps
    # the redirect URI in the JSON-serialised session, so hand it a str.
    # NOTE(review): behind a TLS-terminating proxy this can yield an http://
    # URL that Google rejects as a redirect URI -- confirm proxy headers.
    redirect_uri = str(request.url_for('auth_callback'))
    return await oauth.google.authorize_redirect(request, redirect_uri)
|
|
|
async def _google_user_info(token: dict) -> dict:
    """Return Google profile claims for an OAuth *token*.

    authlib >= 1.0 validates the ID token inside ``authorize_access_token``
    and exposes its claims as ``token['userinfo']``; when that is absent,
    fall back to the userinfo endpoint under the client's ``api_base_url``.
    """
    claims = token.get('userinfo')
    if claims:
        return dict(claims)
    resp = await oauth.google.get('userinfo', token=token)
    resp.raise_for_status()
    return resp.json()


@app.get("/auth/callback")
async def auth_callback(request: Request, db: Session = Depends(get_db)):
    """Finish Google sign-in and issue an access-token cookie.

    Exchanges the authorization code, creates a verified User for a
    first-time email, then redirects to ``/protected`` with an httponly
    ``access_token`` cookie holding ``"Bearer <jwt>"``.

    Raises:
        HTTPException: 401 if the OAuth exchange or profile fetch fails;
            400 if Google returns no email or an explicitly unverified one.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
        user_info = await _google_user_info(token)
    except (OAuthError, httpx.HTTPError) as exc:
        # Denied consent, state mismatch, or a failed call to Google.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Google authentication failed",
        ) from exc

    email = user_info.get('email')
    if not email:
        raise HTTPException(status_code=400, detail="Google account has no email address")
    # ID-token claims use ``email_verified``; the v1 userinfo endpoint uses
    # ``verified_email``. Linking to an existing account by an unverified
    # address would allow account takeover.
    verified = user_info.get('email_verified', user_info.get('verified_email'))
    if verified is False:
        raise HTTPException(status_code=400, detail="Google email address is not verified")

    request.session["user_info"] = user_info

    db_user = get_user_by_email(db, email)
    if not db_user:
        db_user = User(
            email=email,
            username=user_info.get('name') or email,
            is_verified=True,
        )
        db.add(db_user)
        db.commit()
        db.refresh(db_user)

    access_token = create_access_token(
        data={"sub": db_user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    response = RedirectResponse(url="/protected")
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Encode *data* as a signed JWT that expires after *expires_delta*.

    Args:
        data: Claims to embed (e.g. ``{"sub": email}``); not mutated.
        expires_delta: Token lifetime; ``None`` means
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        The encoded token as a ``str``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = dict(data)  # copy so the caller's dict never gains "exp"
    # Timezone-aware UTC; naive ``datetime.utcnow()`` is deprecated (3.12+).
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    # PyJWT < 2 returns bytes, which would render as "b'...'" in the cookie.
    if isinstance(encoded_jwt, bytes):
        encoded_jwt = encoded_jwt.decode("ascii")
    return encoded_jwt
|
|
|
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode a bearer JWT and return its ``sub`` claim.

    Tokens minted by this module carry the user's email as ``sub``.

    Raises:
        HTTPException: 401 if the token is expired, fails validation, or
            has no ``sub`` claim.
    """
    # RFC 6750: bearer-auth 401 responses should advertise the scheme.
    auth_headers = {"WWW-Authenticate": "Bearer"}
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has expired",
            headers=auth_headers,
        ) from None
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers=auth_headers,
        ) from None

    subject = payload.get("sub")
    if subject is None:
        # A correctly signed token without a subject identifies nobody;
        # returning None would let callers treat it as authenticated.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers=auth_headers,
        )
    return subject
|
|
|
def authenticate_user(db: Session, email: str, password: str):
    """Return the User for *email* if *password* matches its stored hash.

    Raises:
        HTTPException: 401 "Incorrect email or password" for an unknown
            email, an account without a password (e.g. created via Google
            sign-in), or a wrong password.
    """
    user = get_user_by_email(db, email)
    if not user or not user.hashed_password:
        # Spend the same hashing time as a real check so response timing
        # does not reveal which emails have accounts.
        pwd_context.dummy_verify()
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password")
    if not pwd_context.verify(password, user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password")
    return user
|
|
|
def register_user(user_data: UserCreate, db: Session):
    """Create an unverified local account and email it a verification link.

    Args:
        user_data: Registration payload; ``password`` must equal
            ``confirm_password``.
        db: Active SQLAlchemy session.

    Returns:
        The persisted ``User``.

    Raises:
        HTTPException: 400 if the email is already registered or the two
            password fields differ.
    """
    if get_user_by_email(db, user_data.email):
        raise HTTPException(status_code=400, detail="Email already registered")
    if user_data.password != user_data.confirm_password:
        raise HTTPException(status_code=400, detail="Passwords do not match")

    hashed_password = pwd_context.hash(user_data.password)
    verification_token = generate_verification_token(user_data.email)
    verification_link = f"http://gregniuki-loginauth.hf.space/verify?token={verification_token}"

    new_user = User(
        email=user_data.email,
        username=user_data.username,
        hashed_password=hashed_password,
        email_verification_token=verification_token,
    )
    db.add(new_user)
    try:
        # Flush first so any database constraint violation surfaces before
        # an email is sent; commit only after the send succeeds, so a failed
        # send leaves no half-registered account blocking re-registration.
        db.flush()
        send_verification_email(user_data.email, verification_link)
    except Exception:
        db.rollback()
        raise
    db.commit()
    db.refresh(new_user)
    return new_user
|
|
|
def verify_email(verification_token: str, db: Session = Depends(get_db)):
    """Mark the account that owns *verification_token* as verified.

    On success the token is cleared so it cannot be reused.

    Raises:
        HTTPException: 400 when no account holds the token, or when that
            account is already verified.
    """
    account = get_user_by_verification_token(db, verification_token)

    failure = None
    if not account:
        failure = "Invalid verification token"
    elif account.is_verified:
        failure = "Email already verified"
    if failure is not None:
        raise HTTPException(status_code=400, detail=failure)

    account.is_verified = True
    account.email_verification_token = None
    db.commit()
    return {"message": "Email verification successful"}
|
|
|
def get_user_by_verification_token(db: Session, verification_token: str):
    """Return the first User whose stored token equals *verification_token*, else None."""
    matches = db.query(User).filter(User.email_verification_token == verification_token)
    return matches.first()
|
|
|
def reset_password(user: User, db: Session):
    """Issue a password-reset token for *user* and email them the reset link.

    NOTE(review): the token is stored in ``email_verification_token``, the
    same column ``verify_email`` looks tokens up by. Issuing a reset
    therefore replaces any pending email-verification token, and a reset
    token is also accepted by ``verify_email``. Consider a dedicated column.
    """
    reset_token = generate_verification_token(user.email)

    # Persist before emailing so every link sent matches a stored token; if
    # the commit fails, no dead link goes out.
    user.email_verification_token = reset_token
    db.commit()

    reset_link = f"http://gregniuki-loginauth.hf.space/reset-password?token={reset_token}"
    send_verification_email(user.email, reset_link)
|
|
|
def get_current_user(token: str = Depends(verify_token)):
    """FastAPI dependency yielding the authenticated subject.

    ``verify_token`` has already decoded the bearer JWT, so *token* here is
    its ``sub`` claim, which this module sets to the user's email.
    """
    current_email = token
    return current_email