"""Async data-access helpers for users, services, lead magnets, tests and feedback.

Every public coroutine opens its own session via ``get_session()`` and, with
few exceptions, reports failures as ``DatabaseError``.

NOTE(review): most writers here never call ``commit()`` themselves; presumably
``get_session()`` commits on a clean exit and rolls back on error -- confirm
against ``app.database.base``.
"""
import json  # noqa: F401  (kept: other parts of the project may rely on it)
import logging
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession  # noqa: F401

from app.database.models import *  # noqa: F401,F403
from app.database.base import get_session
from app.utils.exceptions import DatabaseError, ValidationError

# Private name so that ``from <this module> import *`` does not export it.
_logger = logging.getLogger(__name__)


async def set_user(tg_id: int) -> Optional[User]:
    """Return the user with the given Telegram ID, creating it if missing.

    Args:
        tg_id: Telegram ID used to look the user up.

    Returns:
        The existing or newly added ``User`` instance.

    Raises:
        DatabaseError: If the lookup or insert fails.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(select(User).where(User.tg_id == tg_id))
            if user is None:
                user = User(tg_id=tg_id)
                session.add(user)
                _logger.info("User added")
            return user
        except Exception as e:
            raise DatabaseError(f"Error setting user: {e}") from e


async def user_register(
    tg_id: int,
    name: str,
    login: str,
    contact: str,
    subscribe: bool
) -> None:
    """Store registration data for the user with the given Telegram ID.

    Args:
        tg_id: Telegram ID of the user to update.
        name: Value written to ``User.name``.
        login: Value written to ``User.login``.
        contact: Value written to ``User.contact``.
        subscribe: ``True`` stores ``"active"`` as subscription status,
            otherwise ``"inactive"``.

    Raises:
        DatabaseError: If the update fails.
    """
    async with get_session() as session:
        try:
            query = update(User).where(User.tg_id == tg_id).values(
                name=name,
                login=login,
                contact=contact,
                subscription_status="active" if subscribe else "inactive",
            )
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error registering user: {e}") from e


async def check_login_unique(login: str) -> bool:
    """Return ``True`` when no user has the given login.

    Raises:
        DatabaseError: If the lookup fails (wrapped for consistency with the
            other helpers in this module).
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User).where(User.login == login)
            )
            return user is None
        except Exception as e:
            raise DatabaseError(f"Error checking login uniqueness: {e}") from e


async def get_catalog() -> Optional[List[Service]]:
    """Return all active services, or ``None`` when there are none.

    Note that full ``Service`` objects are returned, not just their names.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            query = select(Service).where(Service.is_active == True)  # noqa: E712
            result = await session.execute(query)
            services = result.scalars().all()
            return services or None
        except Exception as e:
            raise DatabaseError(f"Error getting catalog: {e}") from e


async def get_service_info(service_idx: str) -> Optional[Service]:
    """Return the active service with the given ID, or ``None`` if absent.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            query = select(Service).where(
                Service.id == service_idx,
                Service.is_active == True  # noqa: E712
            )
            return await session.scalar(query)
        except Exception as e:
            raise DatabaseError(f"Error getting service info: {e}") from e


async def add_service(name: str, desc: str, price: int, active: bool = True) -> None:
    """Add a new service to the catalog.

    Args:
        name: Stored as ``service_name``.
        desc: Stored as ``service_description``.
        price: Stored as ``service_price``.
        active: Stored as ``is_active``. The previous default was the ``bool``
            type object itself (a typo for an annotation); it is now ``True``.

    Raises:
        DatabaseError: If the insert fails.
    """
    async with get_session() as session:
        try:
            service = Service(
                service_name=name,
                service_description=desc,
                service_price=price,
                is_active=active
            )
            session.add(service)
        except Exception as e:
            raise DatabaseError(f"Error adding service: {e}") from e


async def edit_service(serv_id: int, param: str, change: Any, active: bool) -> None:
    """Update a single field of an existing service.

    Args:
        serv_id: ID of the service to update.
        param: One of ``'name'``, ``'desc'``, ``'price'``.
        change: New value for the selected field.
        active: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``param`` is not a supported field key.
        DatabaseError: If the update fails.
    """
    param_mapping = {
        'name': 'service_name',
        'desc': 'service_description',
        'price': 'service_price'
    }

    if param not in param_mapping:
        raise ValueError(f"Invalid parameter: {param}")

    async with get_session() as session:
        try:
            query = update(Service).where(
                Service.id == serv_id
            ).values({param_mapping[param]: change})
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error editing service: {e}") from e


async def delete_service(serv_id: int) -> bool:
    """Delete a service, or deactivate it when feedback refers to it.

    Returns:
        ``False`` when no service has the given ID, ``True`` otherwise.

    Raises:
        DatabaseError: If any query fails.
    """
    async with get_session() as session:
        try:
            service = await session.scalar(
                select(Service).where(Service.id == serv_id)
            )
            if not service:
                return False

            # Only existence matters, so fetch a single ID instead of a row.
            feedback_id = await session.scalar(
                select(Feedback.id)
                .where(Feedback.service_id == service.id)
                .limit(1)
            )

            if feedback_id is not None:
                # Keep the row so existing feedback still has its service.
                await session.execute(
                    update(Service)
                    .where(Service.id == serv_id)
                    .values(is_active=False)
                )
            else:
                await session.delete(service)
            return True
        except Exception as e:
            raise DatabaseError(f"Error deleting service: {e}") from e


async def get_leadmagnets() -> Optional[List[str]]:
    """Return the triggers of all active lead magnets, or ``None`` if none.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            query = select(LeadMagnet.trigger).where(
                LeadMagnet.is_active == True  # noqa: E712
            )
            result = await session.execute(query)
            magnets = result.scalars().all()
            return magnets or None
        except Exception as e:
            raise DatabaseError(f"Error getting lead magnets: {e}") from e


async def get_leadmagnet_info(trigger: str) -> Optional[LeadMagnet]:
    """Return the active lead magnet with the given trigger, or ``None``.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            query = select(LeadMagnet).where(
                LeadMagnet.trigger == trigger,
                LeadMagnet.is_active == True  # noqa: E712
            )
            return await session.scalar(query)
        except Exception as e:
            raise DatabaseError(f"Error getting lead magnet info: {e}") from e


async def add_leadmagnet(trigger: str, content: str, active: bool) -> None:
    """Add a new lead magnet.

    Raises:
        DatabaseError: If the insert fails.
    """
    async with get_session() as session:
        try:
            magnet = LeadMagnet(
                trigger=trigger,
                content=content,
                is_active=active
            )
            session.add(magnet)
        except Exception as e:
            raise DatabaseError(f"Error adding lead magnet: {e}") from e


async def edit_leadmagnet(name: str, param: str, change: Any) -> None:
    """Update a single field of the lead magnet identified by its trigger.

    Args:
        name: Current trigger of the lead magnet to update.
        param: One of ``'trigger'``, ``'content'``, ``'status'``.
        change: New value for the selected field.

    Raises:
        ValueError: If ``param`` is not a supported field key.
        DatabaseError: If the update fails.
    """
    param_mapping = {'trigger': 'trigger', 'content': 'content', 'status': 'is_active'}

    # Validate before touching the database, mirroring edit_service.
    if param not in param_mapping:
        raise ValueError(f"Invalid parameter: {param}")

    async with get_session() as session:
        try:
            # An UPDATE matching no rows is a no-op, so no prior SELECT is needed.
            update_query = (
                update(LeadMagnet)
                .where(LeadMagnet.trigger == name)
                .values({param_mapping[param]: change})
                .execution_options(synchronize_session="fetch")
            )
            await session.execute(update_query)
            await session.commit()
        except Exception as e:
            raise DatabaseError(f"Error editing lead magnet: {e}") from e


async def delete_leadmagnet(name: str) -> None:
    """Delete the lead magnet with the given trigger.

    Raises:
        DatabaseError: If the delete fails.
    """
    async with get_session() as session:
        try:
            query = delete(LeadMagnet).where(LeadMagnet.trigger == name)
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error deleting lead magnet: {e}") from e


async def get_tests() -> Optional[List[Test]]:
    """Return all active tests, or ``None`` when there are none.

    Note that full ``Test`` objects are returned, not just their names.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            query = select(Test).where(Test.is_active == True)  # noqa: E712
            result = await session.execute(query)
            tests = result.scalars().all()
            return tests or None
        except Exception as e:
            raise DatabaseError(f"Error getting tests: {e}") from e


async def add_test_wo_points(
    name: str,
    test_type: str,
    desc: str,
    status: bool,
    completion_message: str
) -> None:
    """Add a new test that does not use the points system.

    Raises:
        DatabaseError: If the insert fails.
    """
    async with get_session() as session:
        try:
            test = Test(
                test_name=name,
                test_type=test_type,
                test_description=desc,
                is_active=status,
                completion_message=completion_message
            )
            session.add(test)
        except Exception as e:
            raise DatabaseError(f"Error adding test: {e}") from e


async def add_question_vars_wo_points(test_name: str, text: str) -> None:
    """Add a question with its variants to a test.

    Args:
        test_name: Name of the test the question belongs to.
        text: Question and variants separated by exactly one ``'***'``.

    Raises:
        DatabaseError: If the test is unknown, the format is invalid or the
            insert fails. NOTE(review): the ``ValidationError`` raised below is
            caught by the generic handler and surfaces as ``DatabaseError``;
            kept so the exception type callers see does not change.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(Test.test_name == test_name)
            )
            if not test:
                raise ValidationError(f"Test {test_name} not found")

            # Split text into question and variants
            parts = text.split('***')
            if len(parts) != 2:
                raise ValidationError("Invalid question format")

            question = TestQuestion(
                test_id=test.id,
                question_content=parts[0].strip(),
                question_variants=parts[1].strip(),
                question_points="{}"  # Empty JSON for non-pointed questions
            )
            session.add(question)
        except Exception as e:
            raise DatabaseError(f"Error adding question: {e}") from e


async def add_test_result_w_points(test_name: str, text: str) -> None:
    """Add a result text for a points range of a test.

    Args:
        test_name: Name of the test the result belongs to.
        text: Two lines: ``"<min>-<max>"`` followed by the result text.

    Raises:
        DatabaseError: If the test is unknown, the format is invalid or the
            insert fails (validation errors are wrapped as well).
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(Test.test_name == test_name)
            )
            if not test:
                raise ValidationError(f"Test {test_name} not found")

            parts = text.split('\n')
            if len(parts) != 2:
                raise ValidationError("Invalid result format")

            point_range = parts[0].strip()
            min_points, max_points = map(int, point_range.split('-'))

            result = TestResult(
                test_id=test.id,
                min_points=min_points,
                max_points=max_points,
                result_text=parts[1].strip()
            )
            session.add(result)
        except Exception as e:
            raise DatabaseError(f"Error adding test result: {e}") from e


async def delete_test(t_id: int) -> None:
    """Delete the test with the given ID; a missing test is a no-op.

    Raises:
        DatabaseError: If the lookup or delete fails.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(Test.id == t_id)
            )
            if test:
                # Related rows are expected to be removed by the model's cascade.
                await session.delete(test)
        except Exception as e:
            raise DatabaseError(f"Error deleting test: {e}") from e


async def get_test(t_id: int) -> Optional[Dict[str, Any]]:
    """Return an active test together with its questions and results.

    Returns:
        ``None`` when no active test has the given ID, otherwise a dict with
        the keys ``"id"``, ``"test"``, ``"questions"`` and ``"results"``.

    Raises:
        DatabaseError: If any query fails.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(
                    Test.id == t_id,
                    Test.is_active == True  # noqa: E712
                )
            )
            if not test:
                return None

            questions = (await session.execute(
                select(TestQuestion).where(TestQuestion.test_id == test.id)
            )).scalars().all()
            results = (await session.execute(
                select(TestResult).where(TestResult.test_id == test.id)
            )).scalars().all()

            return {
                "id": t_id,
                "test": test,
                "questions": questions,
                "results": results
            }
        except Exception as e:
            raise DatabaseError(f"Error getting test: {e}") from e


async def change_test_status(t_id: int, status: Union[bool, str]) -> None:
    """Set the active flag of a test.

    Args:
        t_id: ID of the test to update.
        status: Either a real ``bool`` or the string ``"Да"`` (meaning active);
            any other string deactivates the test. Previously a ``True`` bool
            was compared with ``"Да"`` and therefore deactivated the test.

    Raises:
        DatabaseError: If the update fails.
    """
    is_active = status if isinstance(status, bool) else status == "Да"
    async with get_session() as session:
        try:
            query = update(Test).where(
                Test.id == t_id
            ).values(is_active=is_active)
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error changing test status: {e}") from e


async def add_feedback(
    user_id: int,
    service_name: str,
    rating: int,
    review: str
) -> None:
    """Add new, unread feedback for the service with the given name.

    Raises:
        DatabaseError: If the service is unknown or the insert fails
            (the internal ``ValidationError`` is wrapped as well).
    """
    async with get_session() as session:
        try:
            service = await session.scalar(
                select(Service).where(Service.service_name == service_name)
            )
            if not service:
                raise ValidationError(f"Service {service_name} not found")

            feedback = Feedback(
                user_id=user_id,
                service_id=service.id,
                rating=rating,
                review=review,
                is_new=True
            )
            session.add(feedback)
        except Exception as e:
            raise DatabaseError(f"Error adding feedback: {e}") from e


async def get_new_feedback() -> Optional[List[Feedback]]:
    """Return all feedback entries still flagged as new, or ``None`` if none.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            query = select(Feedback).where(Feedback.is_new == True)  # noqa: E712
            result = await session.execute(query)
            feedback = result.scalars().all()
            return feedback or None
        except Exception as e:
            raise DatabaseError(f"Error getting new feedback: {e}") from e


async def mark_feedback_as_read(feedback_id: int) -> None:
    """Clear the ``is_new`` flag of the given feedback entry.

    Raises:
        DatabaseError: If the update fails.
    """
    async with get_session() as session:
        try:
            query = update(Feedback).where(
                Feedback.id == feedback_id
            ).values(is_new=False)
            await session.execute(query)
        except Exception as e:
            raise DatabaseError(f"Error marking feedback as read: {e}") from e


async def get_user_info(tg_id: int) -> Optional[User]:
    """Return the user with the given Telegram ID, or ``None`` if absent.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            return await session.scalar(
                select(User).where(User.tg_id == tg_id)
            )
        except Exception as e:
            raise DatabaseError(f"Error getting user info: {e}") from e


async def start_test_attempt(user_id: int, test_id: str) -> Optional[Dict[str, Any]]:
    """Create a test attempt and return its first question.

    Args:
        user_id: Telegram ID of the user taking the test.
        test_id: ID of the test to start.

    Returns:
        ``None`` when the test is missing/inactive or the user is unknown,
        otherwise a dict with ``"attempt_id"``, ``"question"`` (first question
        by ID, possibly ``None``) and ``"total_questions"``.

    Raises:
        DatabaseError: If any database operation fails.
    """
    async with get_session() as session:
        try:
            test = await session.scalar(
                select(Test).where(
                    Test.id == test_id,
                    Test.is_active == True  # noqa: E712
                )
            )
            if not test:
                return None

            user = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if not user:
                return None

            attempt = TestAttempt(
                user_id=user_id,
                test_id=test.id
            )
            session.add(attempt)
            await session.flush()  # Assigns attempt.id

            question = await session.scalar(
                select(TestQuestion)
                .where(TestQuestion.test_id == test.id)
                .order_by(TestQuestion.id)
                .limit(1)
            )
            total_questions = await session.scalar(
                select(func.count()).select_from(TestQuestion)
                .where(TestQuestion.test_id == test.id)
            )

            # Read everything needed BEFORE committing: a commit may expire
            # loaded attributes, and refreshing them lazily is not possible
            # under AsyncSession.
            attempt_id = attempt.id
            await session.commit()

            return {
                "attempt_id": attempt_id,
                "question": question,
                "total_questions": total_questions
            }
        except Exception as e:
            raise DatabaseError(f"Error starting test: {e}") from e


def _points_for_answer(variants_text: str, answer: str) -> int:
    """Return the points of the variant matching ``answer``, or 0.

    Each non-empty line of ``variants_text`` is expected to look like
    ``"<variant text>...<points>"``; lines of any other shape, or with
    non-integer points, are skipped. Only the part of ``answer`` before its
    first ``"..."`` is compared.
    """
    chosen = answer.split("...")[0].strip()
    for line in variants_text.split('\n'):
        line = line.strip()
        if not line:
            continue
        pieces = line.split('...')
        if len(pieces) != 2:
            continue
        variant_text, points_str = pieces
        if variant_text.strip() != chosen:
            continue
        try:
            return int(points_str.strip())
        except ValueError:
            # Malformed points: keep looking, as the inline loop used to.
            continue
    return 0


async def record_answer(attempt_id: int, question_id: int, answer: str) -> Optional[Dict[str, Any]]:
    """Record an answer and return the next question or the final result.

    Args:
        attempt_id: ID of the running test attempt.
        question_id: ID of the question being answered.
        answer: Answer text as given by the user.

    Returns:
        ``{"next_question": ...}`` while questions remain. Otherwise a dict
        with ``"completed": True`` and ``"result"``, plus ``"total_points"``
        for tests of type ``"С баллами"``.

    Raises:
        DatabaseError: If the attempt, question or test is missing, or any
            database operation fails. The session is rolled back first.
    """
    async with get_session() as session:
        try:
            attempt = await session.scalar(
                select(TestAttempt).where(TestAttempt.id == attempt_id)
            )
            if not attempt:
                raise DatabaseError("Test attempt not found")

            question = await session.scalar(
                select(TestQuestion).where(TestQuestion.id == question_id)
            )
            if not question:
                raise DatabaseError("Question not found")

            test = await session.scalar(
                select(Test).where(Test.id == question.test_id)
            )
            if not test:
                raise DatabaseError("Test not found")

            is_scored = test.test_type == "С баллами"
            points = (
                _points_for_answer(question.question_variants, answer)
                if is_scored else 0
            )

            session.add(TestAnswer(
                attempt_id=attempt_id,
                question_id=question_id,
                answer_given=answer,
                points_earned=points
            ))
            await session.flush()

            next_question = await session.scalar(
                select(TestQuestion)
                .where(TestQuestion.test_id == test.id)
                .where(TestQuestion.id > question_id)
                .order_by(TestQuestion.id)
                .limit(1)
            )
            if next_question:
                await session.commit()
                return {"next_question": next_question}

            # No next question: the test is complete, so total up the score.
            earned = await session.scalars(
                select(TestAnswer.points_earned)
                .where(TestAnswer.attempt_id == attempt_id)
            )
            total_score = sum(value or 0 for value in earned)
            attempt.score = total_score

            if is_scored:
                result = await session.scalar(
                    select(TestResult)
                    .where(TestResult.test_id == test.id)
                    .where(TestResult.min_points <= total_score)
                    .where(TestResult.max_points >= total_score)
                )
                result_text = result.result_text if result else None
                attempt.result = result_text
                result_dict = {
                    "completed": True,
                    "total_points": total_score,
                    "result": result_text
                }
            else:
                result_dict = {
                    "completed": True,
                    "result": test.completion_message
                }
                attempt.result = test.completion_message

            await session.commit()
            return result_dict
        except Exception as e:
            await session.rollback()
            raise DatabaseError(f"Error recording answer: {e}") from e


async def check_user_registered(user_id: int) -> bool:
    """Return ``True`` when the user exists and has a non-empty name.

    A missing user now yields ``False`` instead of failing on ``None.name``.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User)
                .where(User.tg_id == user_id)
            )
            _logger.debug("User found: %s", user)
            return bool(user is not None and user.name)
        except Exception as e:
            raise DatabaseError(f"Error checking user registration: {e}") from e


async def get_user_test_results(user_login: str) -> Union[List[Dict[str, Any]], str]:
    """Return all test attempts of the user with the given login.

    Returns:
        The string ``"Пользователь не найден"`` when no user has that login
        (kept as-is for existing callers); otherwise a list of dicts with
        ``"test_name"``, ``"completed_at"``, ``"score"`` and ``"result"``,
        ordered by ``completed_at`` descending.

    Raises:
        DatabaseError: If any query fails.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User).where(User.login == user_login)
            )
            if not user:
                return "Пользователь не найден"

            attempts = await session.execute(
                select(TestAttempt, Test)
                .join(Test)
                .where(TestAttempt.user_id == user.tg_id)
                .order_by(TestAttempt.completed_at.desc())
            )
            return [
                {
                    "test_name": test.test_name,
                    "completed_at": attempt.completed_at,
                    "score": attempt.score,
                    "result": attempt.result
                }
                for attempt, test in attempts
            ]
        except Exception as e:
            raise DatabaseError(f"Error getting test results: {e}") from e


async def get_user_registration_info(user_id: int) -> str:
    """Return a formatted, user-facing summary of the registration data.

    Returns:
        A multi-line text, or a "not found" message when no user has the
        given Telegram ID.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if not user:
                return "Информация о пользователе не найдена"

            return (
                "📋 Ваша регистрационная информация:\n"
                f"ID: {user.tg_id}\n"
                f"Имя: {user.name or 'Не указано'}\n"
                f"Логин: {user.login or 'Не указано'}\n"
                f"Контакт: {user.contact or 'Не указано'}\n"
                f"Статус подписки: {'Активна' if user.subscription_status == 'active' else 'Неактивна'}"
            )
        except Exception as e:
            raise DatabaseError(f"Error getting user info: {e}") from e


async def get_all_test_answers() -> List[Dict[str, Any]]:
    """Return every recorded answer with user, test and question details.

    Returns:
        A list of dicts with ``"answer_id"``, ``"user_name"``, ``"test_name"``,
        ``"question"``, ``"answer_given"``, ``"points_earned"`` and
        ``"completed_at"`` (formatted ``dd.mm.YYYY HH:MM``, or ``None`` when
        the attempt has no completion time), ordered by completion time
        descending.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            result = await session.execute(
                select(TestAnswer, TestAttempt, User, Test, TestQuestion)
                .join(TestAttempt, TestAttempt.id == TestAnswer.attempt_id)
                # TestAttempt.user_id holds the Telegram ID everywhere else in
                # this module (start_test_attempt, get_user_test_results), so
                # join on tg_id rather than User.id.
                .join(User, User.tg_id == TestAttempt.user_id)
                .join(Test, Test.id == TestAttempt.test_id)
                .join(TestQuestion, TestQuestion.id == TestAnswer.question_id)
                .order_by(TestAttempt.completed_at.desc())
            )
            answers = result.fetchall()
            _logger.debug("Fetched %d test answers", len(answers))

            return [
                {
                    "answer_id": answer.id,
                    "user_name": user.name,
                    "test_name": test.test_name,
                    "question": question.question_content,
                    "answer_given": answer.answer_given,
                    "points_earned": answer.points_earned,
                    "completed_at": (
                        attempt.completed_at.strftime("%d.%m.%Y %H:%M")
                        if attempt.completed_at else None
                    )
                }
                for answer, attempt, user, test, question in answers
            ]
        except Exception as e:
            raise DatabaseError(f"Error fetching test answers: {e}") from e


async def own_login_check(user_id: int, login: str) -> bool:
    """Return ``True`` when ``login`` is the login of the given user.

    Returns ``False`` when no user has the given Telegram ID.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            user = await session.scalar(
                select(User).where(User.tg_id == user_id)
            )
            if not user:
                return False
            return user.login == login
        except Exception as e:
            raise DatabaseError(f"Error checking login: {e}") from e


async def update_user_data(user_id: int, param: str, change: Any) -> None:
    """Update a single field of the user with the given Telegram ID.

    Args:
        user_id: Telegram ID of the user to update.
        param: One of the user-facing field labels ``'Имя'``, ``'Логин'``,
            ``'Контакт'``, ``'Статус подписки на рассылку'``.
        change: New value for the selected field.

    Raises:
        ValueError: If ``param`` is not a supported label.
        DatabaseError: If the update fails.
    """
    param_mapping = {
        'Имя': 'name',
        'Логин': 'login',
        'Контакт': 'contact',
        'Статус подписки на рассылку': 'subscription_status'
    }

    # Validate before touching the database, mirroring edit_service.
    if param not in param_mapping:
        raise ValueError(f"Invalid parameter: {param}")

    async with get_session() as session:
        try:
            # An UPDATE matching no rows is a no-op, so no prior SELECT is needed.
            update_query = (
                update(User)
                .where(User.tg_id == user_id)
                .values({param_mapping[param]: change})
                .execution_options(synchronize_session="fetch")
            )
            await session.execute(update_query)
            await session.commit()
        except Exception as e:
            raise DatabaseError(f"Error updating user data: {e}") from e


async def get_broadcast_users() -> List[int]:
    """Return the Telegram IDs of all users with an active subscription.

    Raises:
        DatabaseError: If the query fails.
    """
    async with get_session() as session:
        try:
            result = await session.scalars(
                select(User.tg_id)
                .where(User.subscription_status == 'active')
            )
            return list(result.all())
        except Exception as e:
            raise DatabaseError(f"Error fetching broadcast users: {e}") from e