File size: 7,749 Bytes
2476a51 59ac5b7 2c0b3d9 841d578 2c0b3d9 1d62c79 9c6d234 1d62c79 9c6d234 2476a51 1d62c79 0052143 9c6d234 2c0b3d9 9c6d234 fce8ee7 2c0b3d9 9c6d234 81a2e90 67638d2 4775141 2c0b3d9 9c6d234 6b4e76c 9c6d234 6b4e76c 9c6d234 7821ebd 361449c 7821ebd 2c0b3d9 361449c 2c0b3d9 361449c 2c0b3d9 9c6d234 2c0b3d9 063b9b9 841d578 2c0b3d9 9c6d234 2c0b3d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
from fastapi import FastAPI, Depends, HTTPException, Request, Form, status
from fastapi.responses import RedirectResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy.orm import Session
from database import get_db, get_user_by_email
from models import User
from passlib.context import CryptContext
from datetime import datetime, timedelta
import jwt
from emailx import send_verification_email, generate_verification_token
from fastapi.staticfiles import StaticFiles
from typing import Optional
import httpx
import os
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client import OAuth
# Environment variables
# Google OAuth client credentials; os.getenv yields None when a variable is unset.
GOOGLE_CLIENT_ID = os.getenv('GOOGLE_CLIENT_ID')
GOOGLE_CLIENT_SECRET = os.getenv('GOOGLE_CLIENT_SECRET')
# Key used both to sign JWTs (jwt.encode/jwt.decode below) and to sign the
# session cookie handled by SessionMiddleware.
# NOTE(review): falls back to the literal 'default_secret' when the
# 'SecretKey' variable is unset -- confirm it is always set in deployment.
SECRET_KEY = os.getenv('SecretKey', 'default_secret')
# Signing algorithm passed to jwt.encode / jwt.decode.
ALGORITHM = "HS256"
# Default access-token lifetime (minutes) used by create_access_token.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# FastAPI and OAuth setup
app = FastAPI()
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
oauth = OAuth()
# Password context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# OAuth2 scheme
# NOTE(review): tokenUrl points at "token", but the password login route
# defined in this file is "/login" -- confirm which URL clients should use.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class TokenData(BaseModel):
    """Result of validating an access token.

    ``validate_token`` in this module builds this model with a ``username``
    keyword, so the model needs that field; with only a required ``token``
    field every construction there failed validation. Both fields are
    optional so that existing callers passing ``token=...`` and callers
    passing ``username=...`` are each accepted.
    """

    # Raw encoded token, when the caller chooses to echo it back.
    token: Optional[str] = None
    # Value of the token's "sub" claim.
    username: Optional[str] = None
class UserCreate(BaseModel):
    """Payload describing a new account, consumed by ``register_user``."""

    # Display name stored on the new User row.
    username: str
    # Address used for the duplicate check and for the verification email.
    email: str
    # Plain-text password; hashed with pwd_context before being stored.
    password: str
    # Repeated password supplied by the sign-up form.
    # NOTE(review): the model itself does not compare this with `password`.
    confirm_password: str
# OAuth Configuration
# Register Google as an OpenID Connect provider.
#
# server_metadata_url points Authlib at Google's discovery document. That
# document supplies `jwks_uri`, which Authlib needs in order to fetch the
# signing keys used to validate the ID token in the callback; without it the
# ID-token parsing step cannot find the keys.
# NOTE(review): the explicit authorize/token URLs are kept as they were;
# Authlib is expected to prefer them over the discovered endpoints -- confirm
# against the installed Authlib version.
oauth.register(
    name='google',
    client_id=os.environ['GOOGLE_CLIENT_ID'],
    client_secret=os.environ['GOOGLE_CLIENT_SECRET'],
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    access_token_url='https://accounts.google.com/o/oauth2/token',
    authorize_url='https://accounts.google.com/o/oauth2/auth',
    authorize_params=None,
    api_base_url='https://www.googleapis.com/oauth2/v1/',
    client_kwargs={'scope': 'openid email profile'},
)
# Static and template configurations
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates are loaded from the local "templates" directory.
templates = Jinja2Templates(directory="templates")
# OAuth routes
@app.get("/login/oauth")
async def login_oauth(request: Request):
    """Start the Google OAuth flow by redirecting the browser to Google.

    The callback address is derived from the ``auth_callback`` route so it
    follows whatever host the request arrived on.
    """
    # request.url_for may return a URL object rather than a str depending on
    # the Starlette version. Authlib keeps the redirect URI in the session,
    # which is JSON-serialised, so hand it a plain string.
    redirect_uri = str(request.url_for('auth_callback'))
    return await oauth.google.authorize_redirect(request, redirect_uri)
@app.get("/auth/callback")
async def auth_callback(request: Request, db: Session = Depends(get_db)):
    """Complete the Google OAuth flow and sign the user in.

    Exchanges the authorization code for tokens, reads the user's claims,
    creates a local ``User`` row on first login (marked verified), and
    redirects to ``/protected`` with the access token in an HTTP-only cookie.

    Raises:
        HTTPException: 400 when the provider response carries no email.
    """
    token = await oauth.google.authorize_access_token(request)

    # NOTE(review): newer Authlib releases are expected to attach the parsed
    # ID-token claims under "userinfo"; fall back to the explicit parse call
    # used by older releases when that key is absent.
    user_info = token.get("userinfo")
    if not user_info:
        user_info = await oauth.google.parse_id_token(request, token)

    email = user_info.get("email")
    if not email:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Google account did not provide an email address",
        )

    # Store a plain dict so the session stays JSON-serialisable.
    request.session["user_info"] = dict(user_info)

    db_user = db.query(User).filter(User.email == email).first()
    if not db_user:
        # "name" is not guaranteed to be present; fall back to the email so
        # account creation does not fail with a KeyError.
        db_user = User(
            email=email,
            username=user_info.get("name") or email,
            is_verified=True,
        )
        db.add(db_user)
        db.commit()
        db.refresh(db_user)

    access_token = create_access_token(
        data={"sub": db_user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    response = RedirectResponse(url="/protected")
    response.set_cookie(key="access_token", value=f"Bearer {access_token}", httponly=True)
    return response
@app.get("/", response_class=HTMLResponse)
async def landing(request: Request):
    """Render the ``landing.html`` template for the site root."""
    context = {"request": request}
    return templates.TemplateResponse("landing.html", context)
def verify_password(plain_password, hashed_password):
    """Return the result of checking *plain_password* against *hashed_password*.

    The comparison is delegated to the module-level ``pwd_context``.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Hash *password* with the module-level ``pwd_context`` and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed
def authenticate_user(db: Session, username: str, password: str):
    """Look a user up by username and check the supplied password.

    Returns the matching ``User`` on success, ``False`` otherwise.

    NOTE(review): a second ``authenticate_user`` defined further down in this
    module rebinds the name, so this username-based version is not the one
    reached at runtime.
    """
    candidate = db.query(User).filter(User.username == username).first()
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
def create_access_token(data: dict, expires_delta: timedelta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed; copied, so the caller's dict is not modified.
        expires_delta: Lifetime of the token, counted from now.

    Returns:
        The token encoded with ``SECRET_KEY`` and ``ALGORITHM``.
    """
    # Local import: the file-level import only brings in datetime/timedelta.
    from datetime import timezone

    claims = data.copy()
    # Use a timezone-aware UTC timestamp; datetime.utcnow() is deprecated and
    # produces a naive value that is easy to misinterpret as local time.
    claims["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str = Depends(oauth2_scheme)):
    """Decode a bearer token and return its ``sub`` claim.

    Raises:
        HTTPException: 401 when the token has expired, cannot be decoded or
            verified, or carries no ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
        ) from exc
    except jwt.InvalidTokenError as exc:
        # Malformed or tampered tokens must be rejected as unauthorized
        # instead of escaping as an unhandled error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
        ) from exc

    subject = payload.get("sub")
    if subject is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    return subject
def validate_token(token: str):
    """Decode *token* and return a ``TokenData`` describing it.

    Raises:
        HTTPException: 401 when the token cannot be decoded or verified, or
            when it has no ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.InvalidTokenError as exc:
        # The file imports PyJWT (`import jwt`), whose decode errors derive
        # from jwt.InvalidTokenError; the bare name `JWTError` is not defined
        # in this module.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
        ) from exc

    username = payload.get("sub")
    if username is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
    # Pass the token as well as the subject so construction succeeds when
    # TokenData declares `token` as a required field.
    return TokenData(token=token, username=username)
@app.get("/token/validate")
async def token_validate(token: str = Depends(oauth2_scheme)):
    """Validate the bearer token on the request and return the result of ``validate_token``."""
    token_data = validate_token(token)
    return token_data
@app.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Handle the password login form by delegating to ``login_for_access_token``."""
    username, password = form_data.username, form_data.password
    token_response = await login_for_access_token(username, password, db)
    return token_response
async def login_for_access_token(username: str, password: str, db: Session):
    """Authenticate the credentials and return a bearer-token response dict.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``authenticate_user`` yields a falsy result.
    """
    user = authenticate_user(db, username, password)
    if user:
        # The token subject is the account's username.
        token = create_access_token(data={"sub": user.username})
        return {"access_token": token, "token_type": "bearer"}
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
def authenticate_user(db: Session, email: str, password: str):
    """Look a user up by email and check the supplied password.

    Returns the matching user on success.

    Raises:
        HTTPException: 401 when no user has that email or the password does
            not verify.

    NOTE(review): this definition rebinds the name of an earlier
    ``authenticate_user`` in this module that looked users up by username.
    """
    account = get_user_by_email(db, email)
    if account and pwd_context.verify(password, account.hashed_password):
        return account
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect email or password")
def register_user(user_data: UserCreate, db: Session):
    """Create a new, unverified user and email them a verification link.

    Args:
        user_data: Sign-up payload (username, email, password, confirmation).
        db: Database session used to check for duplicates and persist the row.

    Returns:
        The newly created ``User`` after it has been committed and refreshed.

    Raises:
        HTTPException: 400 when the two passwords differ or the email is
            already registered.
    """
    # The payload carries a confirmation field; reject mismatches before
    # doing any other work.
    if user_data.password != user_data.confirm_password:
        raise HTTPException(status_code=400, detail="Passwords do not match")

    if get_user_by_email(db, user_data.email):
        raise HTTPException(status_code=400, detail="Email already registered")

    hashed_password = pwd_context.hash(user_data.password)
    verification_token = generate_verification_token(user_data.email)
    verification_link = f"http://gregniuki-loginauth.hf.space/verify?token={verification_token}"
    # The email goes out before the row is written, so a failure to send
    # leaves no half-registered account behind.
    send_verification_email(user_data.email, verification_link)

    new_user = User(
        email=user_data.email,
        username=user_data.username,
        hashed_password=hashed_password,
        email_verification_token=verification_token,
    )
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return new_user
def verify_email(verification_token: str, db: Session = Depends(get_db)):
    """Mark the account that owns *verification_token* as verified.

    Returns:
        A dict with a success ``message``.

    Raises:
        HTTPException: 400 when no account holds the token, or when the
            account is already verified.
    """
    account = get_user_by_verification_token(db, verification_token)
    if not account:
        raise HTTPException(status_code=400, detail="Invalid verification token")
    if account.is_verified:
        raise HTTPException(status_code=400, detail="Email already verified")

    # Flip the flag and clear the stored token in the same commit.
    account.is_verified = True
    account.email_verification_token = None
    db.commit()

    result = {"message": "Email verification successful"}
    return result
def get_user_by_verification_token(db: Session, verification_token: str):
    """Return the first user whose stored verification token equals *verification_token*.

    Returns ``None`` when no user matches, and also when the token is empty
    or ``None``: comparing the column against ``None`` would otherwise match
    any user that simply has no pending token.
    """
    if not verification_token:
        return None
    return db.query(User).filter(User.email_verification_token == verification_token).first()
def reset_password(user: User, db: Session):
    """Email *user* a password-reset link and store the token on their row.

    The token is saved in the same ``email_verification_token`` column that
    email verification uses.
    """
    recipient = user.email
    verification_token = generate_verification_token(recipient)
    send_verification_email(
        user.email,
        f"http://gregniuki-loginauth.hf.space/reset-password?token={verification_token}",
    )
    # Persist the token only after the email has been handed off.
    user.email_verification_token = verification_token
    db.commit()
def get_current_user(token: str = Depends(verify_token)):
    """Return the value produced by the ``verify_token`` dependency.

    ``verify_token`` yields the token's ``sub`` claim, so despite the
    parameter name this is the authenticated subject, not the raw token.
    """
    # The stray trailing "|" after the return value was removed; it left the
    # statement syntactically incomplete.
    return token