dapa / src /database.py
raushan-in's picture
file added
66c0d0c
raw
history blame
2.28 kB
import re
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import AsyncIterator, Optional

from fastapi import Depends
from pydantic import validator
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession

from scams import scam_categories
from settings import settings
# Module-level async engine and session factory, created once at import time.
# The connection string is held as a secret setting, so unwrap it explicitly.
database_url = settings.DATABASE_URL.get_secret_value()

# SQL statement echo logging is enabled only when settings.is_dev() is true.
engine = create_async_engine(
    database_url,
    echo=settings.is_dev(),
)

# Factory producing sqlmodel AsyncSession objects bound to `engine`.
# expire_on_commit=False means loaded attributes are not expired on commit,
# so they stay readable afterwards without another database round-trip.
async_session = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def create_db_and_tables():
    """Create the tables for every model registered on ``SQLModel.metadata``.

    The synchronous ``MetaData.create_all`` is executed through ``run_sync`` on
    a connection opened with ``engine.begin()``, so the DDL is committed when the
    ``async with`` block exits without an error.
    """
    async with engine.begin() as connection:
        create_all = SQLModel.metadata.create_all
        await connection.run_sync(create_all)
@asynccontextmanager
async def get_session() -> AsyncIterator[AsyncSession]:
    """Provide an async database session for the duration of an ``async with`` block.

    Usage::

        async with get_session() as session:
            ...

    The session is created by the module's ``async_session`` factory and is
    closed when the ``async with`` block exits, whether normally or because an
    exception propagated out of it.

    NOTE(review): because this function is wrapped in ``@asynccontextmanager``,
    calling it returns a context-manager object rather than a session. FastAPI's
    ``Depends`` expects a plain (async) generator function for yield-style
    dependencies, so passing this directly to ``Depends`` would likely inject the
    context manager instead of a session. Confirm how callers use it.
    """
    # AsyncSession's own async context manager closes the session on exit,
    # so an extra try/finally around the yield is unnecessary.
    async with async_session() as session:
        yield session
class Scammer(SQLModel, table=True):
    """Scammer ORM model.

    Each row stores a reported scammer's mobile number, the scam category ID,
    the reporter's summary of the scam, the reporter's mobile number and the
    time the report was created.

    NOTE(review): SQLModel documents that models declared with ``table=True``
    skip pydantic validation when instantiated directly (``Scammer(...)``), so
    the validators below may only run via explicit validation / a separate
    non-table input model. Confirm which path callers use.
    """

    # Optional so the value can be omitted before insert; an integer primary
    # key is normally assigned by the database.
    id: Optional[int] = Field(default=None, primary_key=True)
    scammer_mobile: str = Field(index=True, description="Scammer mobile number")
    scam_id: int = Field(description="Scam ID of the scam type")
    reporter_ordeal: str = Field(description="Summary of the scam")
    reporter_mobile: str = Field(description="Reporter mobile number")
    # Naive datetime holding the current UTC time -- the same value
    # datetime.utcnow() produced, without its Python 3.12+ deprecation warning.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        description="Timestamp of report creation",
    )

    @validator("scammer_mobile", "reporter_mobile", pre=True)
    def validate_mobile_number(cls, value: object) -> str:
        """Validate a mobile number and return it unchanged.

        Accepted shape: ``+``, a 1-3 digit country code, an optional single
        hyphen, then 6-14 digits. This is E.164-like but looser than E.164,
        which forbids separators and caps the total at 15 digits.

        Raises:
            ValueError: if ``value`` is not a string or does not match.
        """
        # pre=True runs before type coercion, so value may not be a str yet;
        # reject it explicitly instead of letting re raise a TypeError.
        if not isinstance(value, str):
            raise ValueError(f"Invalid mobile number: {value}")
        pattern = r"\+\d{1,3}-?\d{6,14}"
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing newline, which let e.g. "+911234567890\n" pass unchanged.
        if re.fullmatch(pattern, value) is None:
            raise ValueError(f"Invalid mobile number: {value}")
        return value

    @validator("scam_id")
    def validate_scam_id(cls, value: int) -> int:
        """Ensure ``scam_id`` is one of the keys of ``scam_categories``.

        Raises:
            ValueError: listing the valid IDs when ``value`` is not a known key.
        """
        valid_ids = scam_categories.keys()
        if value not in valid_ids:
            raise ValueError(
                f"Invalid scam_id: {value}. Must be one of {list(valid_ids)}."
            )
        return value