Spaces:
Runtime error
Runtime error
from typing import Optional, List, Dict, Any | |
from sqlalchemy import select, update, delete, func | |
from sqlalchemy.ext.asyncio import AsyncSession | |
from app.database.models import * | |
from app.database.base import get_session | |
import json | |
from app.utils.exceptions import DatabaseError, ValidationError | |
async def set_user(tg_id: int) -> Optional[User]:
    """Return the user with ``tg_id``, creating and adding one if absent.

    Raises:
        DatabaseError: wrapping any exception raised while querying or
            adding the user.
    """
    async with get_session() as session:
        try:
            existing = await session.scalar(
                select(User).where(User.tg_id == tg_id)
            )
            if existing:
                return existing
            # No row yet: stage a new User in this session.
            # NOTE(review): no explicit flush/commit here — presumably
            # get_session() commits on exit; confirm.
            created = User(tg_id=tg_id)
            session.add(created)
            print("User added")
            return created
        except Exception as exc:
            raise DatabaseError(f"Error setting user: {str(exc)}")
async def user_register(
    tg_id: int,
    name: str,
    login: str,
    contact: str,
    subscribe: bool
) -> None:
    """Store the registration details of the user identified by ``tg_id``.

    ``subscribe`` is mapped onto ``subscription_status``: truthy becomes
    "active", falsy becomes "inactive".

    Raises:
        DatabaseError: wrapping any exception raised by the update.
    """
    async with get_session() as session:
        try:
            new_values = dict(
                name=name,
                login=login,
                contact=contact,
                subscription_status="active" if subscribe else "inactive",
            )
            stmt = update(User).where(User.tg_id == tg_id).values(**new_values)
            await session.execute(stmt)
        except Exception as err:
            raise DatabaseError(f"Error registering user: {str(err)}")
async def check_login_unique(login: str) -> bool:
    """Return True when no user has taken ``login`` yet.

    Database errors propagate unwrapped (there is no try/except here).
    """
    async with get_session() as session:
        stmt = select(User).where(User.login == login)
        holder = await session.scalar(stmt)
        return holder is None
async def get_catalog() -> Optional[List["Service"]]:
    """Return all active services, or ``None`` when there are none.

    The query selects whole ``Service`` rows, so callers receive ORM
    objects rather than service names.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            query = select(Service).where(Service.is_active == True)
            result = await session.execute(query)
            services = result.scalars().all()
            # An empty catalog is reported as None rather than [].
            return services if services else None
        except Exception as e:
            raise DatabaseError(f"Error getting catalog: {str(e)}")
async def get_service_info(service_idx: str) -> Optional[Service]:
    """Return the active service whose id equals ``service_idx``, else None.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            stmt = select(Service).where(
                Service.id == service_idx,
                Service.is_active == True,
            )
            found = await session.scalar(stmt)
            return found or None
        except Exception as err:
            raise DatabaseError(f"Error getting service info: {str(err)}")
async def add_service(name: str, desc: str, price: int, active: bool = True) -> None:
    """Add a new service to the catalog.

    Args:
        name: stored as ``service_name``.
        desc: stored as ``service_description``.
        price: stored as ``service_price``.
        active: stored as ``is_active``; defaults to True. (The default was
            previously the ``bool`` type object itself, which is not a valid
            column value.)

    Raises:
        DatabaseError: wrapping any exception raised while adding the row.
    """
    async with get_session() as session:
        try:
            service = Service(
                service_name=name,
                service_description=desc,
                service_price=price,
                is_active=active
            )
            # NOTE(review): no explicit commit — presumably get_session()
            # commits on exit; confirm.
            session.add(service)
        except Exception as e:
            raise DatabaseError(f"Error adding service: {str(e)}")
async def edit_service(serv_id: int, param: str, change: Any, active: bool) -> None:
    """Set one field of service ``serv_id`` to ``change``.

    Args:
        serv_id: id of the service to update.
        param: one of 'name', 'desc', 'price'.
        change: new value for the chosen column.
        active: currently unused by this function.

    Raises:
        ValueError: if ``param`` is not a supported field (raised before any
            database access).
        DatabaseError: wrapping any exception raised by the update.
    """
    columns = {
        'name': 'service_name',
        'desc': 'service_description',
        'price': 'service_price'
    }
    target = columns.get(param)
    if target is None:
        raise ValueError(f"Invalid parameter: {param}")
    async with get_session() as session:
        try:
            stmt = (
                update(Service)
                .where(Service.id == serv_id)
                .values({target: change})
            )
            await session.execute(stmt)
        except Exception as err:
            raise DatabaseError(f"Error editing service: {str(err)}")
async def delete_service(serv_id: int) -> bool:
    """Remove service ``serv_id``, or deactivate it if it has feedback.

    A service referenced by any ``Feedback`` row is only flagged
    ``is_active=False``; otherwise the row is deleted.

    Returns:
        False if no such service exists, True otherwise.

    Raises:
        DatabaseError: wrapping any exception raised during the operation.
    """
    async with get_session() as session:
        try:
            service = await session.scalar(
                select(Service).where(Service.id == serv_id)
            )
            if not service:
                return False
            referenced = await session.scalar(
                select(Feedback).where(Feedback.service_id == service.id)
            )
            if not referenced:
                await session.delete(service)
            else:
                # Keep the row so existing feedback still points at it.
                await session.execute(
                    update(Service)
                    .where(Service.id == serv_id)
                    .values(is_active=False)
                )
            return True
        except Exception as err:
            raise DatabaseError(f"Error deleting service: {str(err)}")
async def get_leadmagnets() -> Optional[List[str]]:
    """Return the triggers of all active lead magnets, or None if there are none.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            stmt = select(LeadMagnet.trigger).where(LeadMagnet.is_active == True)
            triggers = (await session.execute(stmt)).scalars().all()
            return triggers or None
        except Exception as err:
            raise DatabaseError(f"Error getting lead magnets: {str(err)}")
async def get_leadmagnet_info(trigger: str) -> Optional[LeadMagnet]:
    """Return the active lead magnet with the given ``trigger``, else None.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            stmt = select(LeadMagnet).where(
                LeadMagnet.trigger == trigger,
                LeadMagnet.is_active == True,
            )
            found = await session.scalar(stmt)
            return found or None
        except Exception as err:
            raise DatabaseError(f"Error getting lead magnet info: {str(err)}")
async def add_leadmagnet(trigger: str, content: str, active: bool) -> None:
    """Stage a new ``LeadMagnet`` row in the session.

    Raises:
        DatabaseError: wrapping any exception raised while adding the row.
    """
    async with get_session() as session:
        try:
            session.add(
                LeadMagnet(
                    trigger=trigger,
                    content=content,
                    is_active=active,
                )
            )
        except Exception as err:
            raise DatabaseError(f"Error adding lead magnet: {str(err)}")
async def edit_leadmagnet(name, param, change):
    """Set one field of the lead magnet whose trigger is ``name``.

    ``param`` must be 'trigger', 'content' or 'status' (the latter maps to
    ``is_active``). Nothing happens if no lead magnet has that trigger; when
    one exists, an unknown ``param`` raises KeyError. Database errors
    propagate unwrapped and the change is committed explicitly.
    """
    async with get_session() as session:
        column_for = {
            'trigger': 'trigger',
            'content': 'content',
            'status': 'is_active',
        }
        lookup = await session.execute(
            select(LeadMagnet).where(LeadMagnet.trigger == name)
        )
        if not lookup.scalars().first():
            return
        stmt = (
            update(LeadMagnet)
            .where(LeadMagnet.trigger == name)
            .values({column_for[param]: change})
            .execution_options(synchronize_session="fetch")
        )
        await session.execute(stmt)
        await session.commit()
async def delete_leadmagnet(name: str) -> None:
    """Issue a bulk DELETE for lead magnets whose trigger equals ``name``.

    Raises:
        DatabaseError: wrapping any exception raised by the delete.
    """
    async with get_session() as session:
        try:
            await session.execute(
                delete(LeadMagnet).where(LeadMagnet.trigger == name)
            )
        except Exception as err:
            raise DatabaseError(f"Error deleting lead magnet: {str(err)}")
async def get_tests() -> Optional[List["Test"]]:
    """Return all active tests, or ``None`` when there are none.

    The query selects whole ``Test`` rows, so callers receive ORM objects
    rather than test names.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            query = select(Test).where(Test.is_active == True)
            result = await session.execute(query)
            tests = result.scalars().all()
            # No active tests is reported as None rather than [].
            return tests if tests else None
        except Exception as e:
            raise DatabaseError(f"Error getting tests: {str(e)}")
async def add_test_wo_points(
    name: str,
    test_type: str,
    desc: str,
    status: bool,
    completion_message: str
) -> None:
    """Stage a new ``Test`` row (intended for tests without a points system).

    Raises:
        DatabaseError: wrapping any exception raised while adding the row.
    """
    async with get_session() as session:
        try:
            fields = dict(
                test_name=name,
                test_type=test_type,
                test_description=desc,
                is_active=status,
                completion_message=completion_message,
            )
            session.add(Test(**fields))
        except Exception as err:
            raise DatabaseError(f"Error adding test: {str(err)}")
async def add_question_vars_wo_points(test_name: str, text: str) -> None:
    """Attach a question to test ``test_name`` (no points system).

    ``text`` must contain exactly one ``***`` separator: the part before it is
    the question, the part after it the answer variants (both stripped).

    Raises:
        DatabaseError: on any failure. NOTE: the ValidationErrors raised for
            an unknown test or a malformed ``text`` are caught by the broad
            handler below and re-raised as DatabaseError.
    """
    async with get_session() as session:
        try:
            owner = await session.scalar(
                select(Test).where(Test.test_name == test_name)
            )
            if not owner:
                raise ValidationError(f"Test {test_name} not found")
            pieces = text.split('***')
            if len(pieces) != 2:
                raise ValidationError("Invalid question format")
            prompt, variants = pieces
            session.add(
                TestQuestion(
                    test_id=owner.id,
                    question_content=prompt.strip(),
                    question_variants=variants.strip(),
                    # Empty JSON for non-pointed questions.
                    question_points="{}",
                )
            )
        except Exception as err:
            raise DatabaseError(f"Error adding question: {str(err)}")
async def add_test_result_w_points(test_name: str, text: str) -> None:
    """Attach a point-range result to test ``test_name``.

    ``text`` must be exactly two lines: ``"<min>-<max>"`` followed by the
    result text.

    Raises:
        DatabaseError: on any failure, including the ValidationErrors for an
            unknown test or malformed ``text`` (they are re-wrapped by the
            broad handler) and ValueErrors from parsing the range.
    """
    async with get_session() as session:
        try:
            owner = await session.scalar(
                select(Test).where(Test.test_name == test_name)
            )
            if not owner:
                raise ValidationError(f"Test {test_name} not found")
            rows = text.split('\n')
            if len(rows) != 2:
                raise ValidationError("Invalid result format")
            range_line, verdict_line = rows
            low, high = map(int, range_line.strip().split('-'))
            session.add(
                TestResult(
                    test_id=owner.id,
                    min_points=low,
                    max_points=high,
                    result_text=verdict_line.strip(),
                )
            )
        except Exception as err:
            raise DatabaseError(f"Error adding test result: {str(err)}")
async def delete_test(t_id: int) -> None:
    """Delete test ``t_id`` if it exists; a missing test is silently ignored.

    Raises:
        DatabaseError: wrapping any exception raised during the delete.
    """
    async with get_session() as session:
        try:
            doomed = await session.scalar(select(Test).where(Test.id == t_id))
            if not doomed:
                return
            # Presumably the ORM cascade removes related questions/results —
            # confirm against the model definitions.
            await session.delete(doomed)
        except Exception as err:
            raise DatabaseError(f"Error deleting test: {str(err)}")
async def get_test(t_id: int) -> Optional[Dict[str, Any]]:
    """Return an active test together with its questions and results.

    Returns:
        None if no active test has id ``t_id``; otherwise a dict with keys
        "id", "test", "questions" and "results".

    Raises:
        DatabaseError: wrapping any exception raised by the queries.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(Test.id == t_id, Test.is_active == True)
            )
            if not test:
                return None
            question_rows = await session.execute(
                select(TestQuestion).where(TestQuestion.test_id == test.id)
            )
            questions = question_rows.scalars().all()
            result_rows = await session.execute(
                select(TestResult).where(TestResult.test_id == test.id)
            )
            results = result_rows.scalars().all()
            return {
                "id": t_id,
                "test": test,
                "questions": questions,
                "results": results,
            }
        except Exception as err:
            raise DatabaseError(f"Error getting test: {str(err)}")
async def change_test_status(t_id: int, status: Any) -> None:
    """Set the ``is_active`` flag of test ``t_id``.

    Args:
        t_id: id of the test to update.
        status: either a real ``bool``, or the UI label "Да" ("Yes"), which
            activates the test; any other non-bool value deactivates it.
            (Previously a bool ``True`` — as the old annotation suggested —
            was compared to "Да" and therefore deactivated the test.)

    Raises:
        DatabaseError: wrapping any exception raised by the update.
    """
    if isinstance(status, bool):
        is_active = status
    else:
        is_active = True if status == "Да" else False
    async with get_session() as session:
        try:
            query = update(Test).where(
                Test.id == t_id
            ).values(is_active=is_active)
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error changing test status: {str(e)}")
async def add_feedback(
    user_id: int,
    service_name: str,
    rating: int,
    review: str
) -> None:
    """Stage a new, unread feedback entry for the service named ``service_name``.

    Raises:
        DatabaseError: on any failure; the ValidationError for an unknown
            service is re-wrapped by the broad handler.
    """
    async with get_session() as session:
        try:
            target = await session.scalar(
                select(Service).where(Service.service_name == service_name)
            )
            if not target:
                raise ValidationError(f"Service {service_name} not found")
            entry = Feedback(
                user_id=user_id,
                service_id=target.id,
                rating=rating,
                review=review,
                is_new=True,
            )
            session.add(entry)
        except Exception as err:
            raise DatabaseError(f"Error adding feedback: {str(err)}")
async def get_new_feedback() -> Optional[List[Feedback]]:
    """Return every feedback row still flagged ``is_new``, or None if none.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            rows = await session.execute(
                select(Feedback).where(Feedback.is_new == True)
            )
            unread = rows.scalars().all()
            return unread or None
        except Exception as err:
            raise DatabaseError(f"Error getting new feedback: {str(err)}")
async def mark_feedback_as_read(feedback_id: int) -> None:
    """Clear the ``is_new`` flag of feedback ``feedback_id``.

    Raises:
        DatabaseError: wrapping any exception raised by the update.
    """
    async with get_session() as session:
        try:
            stmt = (
                update(Feedback)
                .where(Feedback.id == feedback_id)
                .values(is_new=False)
            )
            await session.execute(stmt)
        except Exception as err:
            raise DatabaseError(f"Error marking feedback as read: {str(err)}")
async def get_user_info(tg_id: int) -> Optional[User]:
    """Return the user whose Telegram ID is ``tg_id``, or None.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            return await session.scalar(select(User).where(User.tg_id == tg_id))
        except Exception as err:
            raise DatabaseError(f"Error getting user info: {str(err)}")
async def start_test_attempt(user_id: int, test_id: str) -> Optional[Dict[str, Any]]:
    """Create a new attempt of an active test and return its first question.

    Args:
        user_id: Telegram ID of the user; it is matched against ``User.tg_id``
            and stored as ``TestAttempt.user_id``.
        test_id: id of the test to start.

    Returns:
        None if the test is missing/inactive or the user is unknown.
        Otherwise a dict with "attempt_id", "question" (the lowest-id
        question of the test, or None if it has none) and "total_questions".

    Raises:
        DatabaseError: wrapping any exception raised while talking to the DB.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(
                    Test.id == test_id,
                    Test.is_active == True
                )
            )
            if not test:
                return None
            user = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if not user:
                return None
            attempt = TestAttempt(
                user_id=user_id,
                test_id=test.id
            )
            session.add(attempt)
            await session.flush()  # assigns attempt.id
            question = await session.scalar(
                select(TestQuestion)
                .where(TestQuestion.test_id == test.id)
                .order_by(TestQuestion.id)
            )
            total_questions = await session.scalar(
                select(func.count()).select_from(TestQuestion)
                .where(TestQuestion.test_id == test.id)
            )
            # Read everything needed from ORM objects BEFORE committing: if the
            # session expires instances on commit, touching attempt.id/test.id
            # afterwards triggers an implicit refresh, which AsyncSession
            # cannot perform.
            attempt_id = attempt.id
            await session.commit()
            # NOTE(review): `question` is returned after commit; if the
            # session uses expire_on_commit=True its attributes may be
            # expired for the caller — confirm get_session() configuration.
            return {
                "attempt_id": attempt_id,
                "question": question,
                "total_questions": total_questions
            }
        except Exception as e:
            raise DatabaseError(f"Error starting test: {str(e)}")
def _points_for_answer(variants_blob: str, answer: str) -> int:
    """Return the points a scored question awards for ``answer``.

    ``variants_blob`` holds one variant per line shaped ``"<text>...<points>"``.
    Blank lines and lines without exactly one ``...`` are skipped. The first
    variant whose stripped text equals the stripped part of ``answer`` before
    its first ``...``, and whose points parse as an int, wins; otherwise 0.
    """
    for line in variants_blob.split('\n'):
        cleaned = line.strip()
        if not cleaned:
            continue
        pieces = cleaned.split('...')
        if len(pieces) != 2:
            continue
        label, raw_points = pieces
        if label.strip() != answer.split("...")[0].strip():
            continue
        try:
            return int(raw_points.strip())
        except ValueError:
            # Unparseable points: keep scanning later variants.
            continue
    return 0


async def record_answer(attempt_id: int, question_id: int, answer: str) -> Optional[Dict[str, Any]]:
    """Save an answer for ``attempt_id`` and advance the test.

    Points are awarded only when the test type is "С баллами" (scored).

    Returns:
        ``{"next_question": q}`` when the test has a question with a larger
        id than ``question_id``. Otherwise the attempt's score/result are
        stored and a dict with "completed": True and "result" is returned,
        plus "total_points" for scored tests.

    Raises:
        DatabaseError: on any failure (the session is rolled back first);
            missing attempt/question errors are re-wrapped with the
            "Error recording answer" prefix.
    """
    async with get_session() as session:
        try:
            attempt = await session.scalar(
                select(TestAttempt).where(TestAttempt.id == attempt_id)
            )
            if not attempt:
                raise DatabaseError("Test attempt not found")
            question = await session.scalar(
                select(TestQuestion).where(TestQuestion.id == question_id)
            )
            if not question:
                raise DatabaseError("Question not found")
            test = await session.scalar(
                select(Test).where(Test.id == question.test_id)
            )
            is_scored = test.test_type == "С баллами"

            earned = (
                _points_for_answer(question.question_variants, answer)
                if is_scored
                else 0
            )
            session.add(
                TestAnswer(
                    attempt_id=attempt_id,
                    question_id=question_id,
                    answer_given=answer,
                    points_earned=earned,
                )
            )
            await session.flush()

            following = await session.scalar(
                select(TestQuestion)
                .where(TestQuestion.test_id == test.id)
                .where(TestQuestion.id > question_id)
                .order_by(TestQuestion.id)
            )
            if following:
                await session.commit()
                return {"next_question": following}

            # Last question answered: total up and finalize the attempt.
            answered = await session.scalars(
                select(TestAnswer).where(TestAnswer.attempt_id == attempt_id)
            )
            total_score = sum(row.points_earned for row in answered.all())
            attempt.score = total_score

            if is_scored:
                matched = await session.scalar(
                    select(TestResult)
                    .where(TestResult.test_id == test.id)
                    .where(TestResult.min_points <= total_score)
                    .where(TestResult.max_points >= total_score)
                )
                verdict = matched.result_text if matched else None
                attempt.result = verdict
                outcome = {
                    "completed": True,
                    "total_points": total_score,
                    "result": verdict,
                }
            else:
                outcome = {
                    "completed": True,
                    "result": test.completion_message,
                }
                attempt.result = test.completion_message

            await session.commit()
            return outcome
        except Exception as err:
            await session.rollback()
            raise DatabaseError(f"Error recording answer: {str(err)}")
async def check_user_registered(user_id: int) -> bool:
    """Return True if the user with Telegram ID ``user_id`` has a name set.

    A user with no row at all is reported as not registered (False) instead
    of failing on ``None.name``.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User)
                .where(User.tg_id == user_id)
            )
            return bool(user is not None and user.name)
        except Exception as e:
            raise DatabaseError(f"Error checking user registration: {str(e)}")
async def get_user_test_results(user_login: str) -> List[Dict[str, Any]]:
    """List a user's test attempts, newest ``completed_at`` first.

    Attempts are matched via ``TestAttempt.user_id == User.tg_id``.

    Returns:
        A list of dicts with "test_name", "completed_at", "score" and
        "result". NOTE(review): despite the annotation, the string
        "Пользователь не найден" is returned when no user has
        ``user_login``, and None if the result object is falsy.

    Raises:
        DatabaseError: wrapping any exception raised by the queries.
    """
    async with get_session() as session:
        try:
            owner = await session.scalar(
                select(User).where(User.login == user_login)
            )
            if not owner:
                return "Пользователь не найден"
            rows = await session.execute(
                select(TestAttempt, Test)
                .join(Test)
                .where(TestAttempt.user_id == owner.tg_id)
                .order_by(TestAttempt.completed_at.desc())
            )
            if not rows:
                return None
            history = []
            for attempt, test in rows:
                history.append(
                    {
                        "test_name": test.test_name,
                        "completed_at": attempt.completed_at,
                        "score": attempt.score,
                        "result": attempt.result,
                    }
                )
            return history
        except Exception as err:
            raise DatabaseError(f"Error getting test results: {str(err)}")
async def get_user_registration_info(user_id: int) -> str:
    """Render the registration details of user ``user_id`` as a text message.

    Returns a "not found" message when no user has that Telegram ID.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if user is None or not user:
                return "Информация о пользователе не найдена"
            subscription = 'Активна' if user.subscription_status == 'active' else 'Неактивна'
            message = (
                "📋 Ваша регистрационная информация:\n"
                f"ID: {user.tg_id}\n"
                f"Имя: {user.name or 'Не указано'}\n"
                f"Логин: {user.login or 'Не указано'}\n"
                f"Контакт: {user.contact or 'Не указано'}\n"
                f"Статус подписки: {subscription}"
            )
            return message
        except Exception as err:
            raise DatabaseError(f"Error getting user info: {str(err)}")
async def get_all_test_answers() -> List[Dict[str, Any]]:
    """Return every recorded test answer with its attempt, user, test and question.

    Rows are ordered by attempt completion time, newest first.

    Returns:
        A list of dicts with "answer_id", "user_name", "test_name",
        "question", "answer_given", "points_earned" and "completed_at"
        (formatted "DD.MM.YYYY HH:MM", or None when the attempt has no
        completion time).

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            result = await session.execute(
                select(TestAnswer, TestAttempt, User, Test, TestQuestion)
                .join(TestAttempt, TestAttempt.id == TestAnswer.attempt_id)
                # Attempts store the user's Telegram ID in TestAttempt.user_id,
                # so join on User.tg_id rather than the surrogate User.id.
                .join(User, User.tg_id == TestAttempt.user_id)
                .join(Test, Test.id == TestAttempt.test_id)
                .join(TestQuestion, TestQuestion.id == TestAnswer.question_id)
                .order_by(TestAttempt.completed_at.desc())
            )
            answers = result.fetchall()
            return [
                {
                    "answer_id": answer.id,
                    "user_name": user.name,
                    "test_name": test.test_name,
                    "question": question.question_content,
                    "answer_given": answer.answer_given,
                    "points_earned": answer.points_earned,
                    "completed_at": (
                        attempt.completed_at.strftime("%d.%m.%Y %H:%M")
                        if attempt.completed_at is not None
                        else None
                    ),
                }
                for answer, attempt, user, test, question in answers
            ]
        except Exception as e:
            raise DatabaseError(f"Error fetching test answers: {str(e)}")
async def own_login_check(user_id: int, login: str) -> bool:
    """Return True if ``login`` is the current login of user ``user_id``.

    An unknown user yields False.

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            owner = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if owner:
                return owner.login == login
            return False
        except Exception as err:
            raise DatabaseError(f"Error checking login: {str(err)}")
async def update_user_data(user_id: int, param: str, change: Any) -> None:
    """Set one profile field of user ``user_id`` and commit.

    ``param`` is a UI label: 'Имя' (name), 'Логин' (login), 'Контакт'
    (contact) or 'Статус подписки на рассылку' (subscription_status).
    Nothing happens for an unknown user; for an existing user an unknown
    label raises KeyError. Database errors propagate unwrapped.
    """
    async with get_session() as session:
        column_for = {
            'Имя': 'name',
            'Логин': 'login',
            'Контакт': 'contact',
            'Статус подписки на рассылку': 'subscription_status',
        }
        lookup = await session.execute(
            select(User).where(User.tg_id == user_id)
        )
        if not lookup.scalars().first():
            return
        stmt = (
            update(User)
            .where(User.tg_id == user_id)
            .values({column_for[param]: change})
            .execution_options(synchronize_session="fetch")
        )
        await session.execute(stmt)
        await session.commit()
async def get_broadcast_users() -> List[int]:
    """Return the Telegram IDs of users whose subscription is "active".

    Raises:
        DatabaseError: wrapping any exception raised by the query.
    """
    async with get_session() as session:
        try:
            stmt = select(User.tg_id).where(User.subscription_status == 'active')
            subscribers = await session.scalars(stmt)
            return subscribers.all()
        except Exception as err:
            raise DatabaseError(f"Error fetching broadcast users: {str(err)}")