"""RabbitMQ helper clients built on pika's blocking connection adapter."""

import asyncio
import functools
import json
import logging
import os
import ssl
import time
from enum import Enum
from typing import Dict, Optional, Literal, Union

import msgpack
import pika
from pika.exceptions import AMQPConnectionError
from pydantic import BaseModel, field_validator

logger = logging.getLogger(__name__)


def sync(f):
    """Wrap coroutine function *f* so it can be called synchronously.

    The wrapper drives the coroutine to completion on the current event
    loop. If the calling thread has no usable loop (none set, or the
    current one is closed), a fresh loop is created and installed first,
    because ``asyncio.get_event_loop()`` no longer creates one implicitly
    on recent Python versions.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current event loop in this thread.
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(f(*args, **kwargs))

    return wrapper


class Priority(Enum):
    """Message priority levels; the value is used as the AMQP priority."""

    LOW = 1
    NORMAL = 5
    HIGH = 10


class Headers(BaseModel):
    """Application headers attached to a published message."""

    job_id: str
    priority: Priority
    task_type: Optional[str] = None

    @field_validator('priority', mode='before')
    @classmethod
    def _convert_priority(cls, value):
        """Accept either a ``Priority`` member or its name (e.g. ``"HIGH"``)."""
        if isinstance(value, Priority):
            return value
        return Priority[value]


class RabbitMQConfig(BaseModel):
    """Connection settings for a RabbitMQ broker."""

    host: str
    port: int
    username: str
    password: str
    protocol: str


class BasicPikaClient:
    """Base client owning one blocking pika connection and one channel.

    Connection settings are read from the ``RABBITMQ_USER``,
    ``RABBITMQ_PASSWD``, ``RABBITMQ_HOST`` and ``RABBITMQ_PORT``
    environment variables. Constructing an instance connects immediately.
    """

    def __init__(self):
        self.username = os.environ.get('RABBITMQ_USER', '')
        self.password = os.environ.get('RABBITMQ_PASSWD', '')
        self.host = os.environ.get('RABBITMQ_HOST', 'localhost')
        self.port = os.environ.get('RABBITMQ_PORT', 5672)
        self.protocol = "amqp"

        self._init_connection_parameters()
        self._connect()

    def _connect(self, max_tries: int = 20, retry_delay: float = 5):
        """Open the connection and channel, retrying on failure.

        Args:
            max_tries: Number of connection attempts before giving up.
            retry_delay: Seconds to wait between two attempts.

        Raises:
            AMQPConnectionError: If every attempt failed; the last
                underlying error is attached as the cause.
        """
        last_error = None
        for attempt in range(1, max_tries + 1):
            try:
                self.connection = pika.BlockingConnection(self.parameters)
                self.channel = self.connection.channel()
                if self.connection.is_open:
                    return
                last_error = AMQPConnectionError(
                    "connection was not open after being established"
                )
            except Exception as e:
                # Deliberately broad: any failure while connecting is retried.
                last_error = e
            logger.warning(
                "RabbitMQ connection attempt %d/%d failed: %r",
                attempt, max_tries, last_error,
            )
            # Do not sleep after the final attempt; fail right away instead.
            if attempt < max_tries:
                time.sleep(retry_delay)
        raise AMQPConnectionError(last_error) from last_error

    def _init_connection_parameters(self):
        """Build ``self.parameters`` from the host/port/credential attributes."""
        if any([self.username, self.password]):
            self.credentials = pika.PlainCredentials(self.username, self.password)
            self.parameters = pika.ConnectionParameters(
                host=self.host,
                port=int(self.port),
                virtual_host="/",
                credentials=self.credentials,
            )
        else:
            # No credentials configured: rely on pika's default credentials.
            self.parameters = pika.ConnectionParameters(
                self.host,
                int(self.port),
                "/",
            )

        if self.protocol == "amqps":
            # SSL Context for TLS configuration of Amazon MQ for RabbitMQ
            # NOTE(review): per the ssl module docs, a context built this way
            # does not verify the broker certificate or hostname by default.
            # Consider ssl.create_default_context() together with
            # SSLOptions(server_hostname=self.host) once it is confirmed that
            # every deployment presents a verifiable certificate.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
            ssl_context.set_ciphers("ECDHE+AESGCM:!ECDSA")
            self.parameters.ssl_options = pika.SSLOptions(context=ssl_context)

    def check_connection(self):
        """Ensure both the connection and the channel are usable.

        Reconnects when the connection is missing or closed. When only the
        channel is closed, a new channel is opened on the live connection.
        """
        connection = getattr(self, "connection", None)
        if not connection or connection.is_closed:
            self._connect()
            return
        channel = getattr(self, "channel", None)
        if channel is None or channel.is_closed:
            self.channel = connection.channel()

    def close(self):
        """Close the channel and the connection, skipping whatever is already closed."""
        channel = getattr(self, "channel", None)
        if channel is not None and channel.is_open:
            channel.close()
        connection = getattr(self, "connection", None)
        if connection is not None and connection.is_open:
            connection.close()

    def declare_queue(
        self, queue_name, exclusive: bool = False, max_priority: int = 10
    ):
        """Declare a durable queue.

        Args:
            queue_name: Name of the queue to declare.
            exclusive: Passed through to ``queue_declare``.
            max_priority: Value for the ``x-max-priority`` queue argument;
                the argument is omitted when this is not positive.
        """
        self.check_connection()
        logger.debug("Trying to declare queue(%s)...", queue_name)
        arguments = {"x-max-priority": max_priority} if max_priority > 0 else None
        self.channel.queue_declare(
            queue=queue_name,
            exclusive=exclusive,
            durable=True,
            arguments=arguments,
        )

    def declare_exchange(self, exchange_name: str, exchange_type: str = "direct"):
        """Declare an exchange of the given type."""
        self.check_connection()
        self.channel.exchange_declare(
            exchange=exchange_name, exchange_type=exchange_type
        )

    def bind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
        """Bind *queue_name* to *exchange_name* under *routing_key*."""
        self.check_connection()
        self.channel.queue_bind(
            exchange=exchange_name, queue=queue_name, routing_key=routing_key
        )

    def unbind_queue(self, exchange_name: str, queue_name: str, routing_key: str):
        """Remove the binding of *queue_name* to *exchange_name* for *routing_key*."""
        # Same connection guard as bind_queue.
        self.check_connection()
        self.channel.queue_unbind(
            queue=queue_name, exchange=exchange_name, routing_key=routing_key
        )


class BasicMessageSender(BasicPikaClient):
    """Client that publishes persistent messages to an exchange."""

    # Encoding applied by send_message; see encode_message for the options.
    message_encoding_type: Literal['bytes', 'json'] = 'json'

    def encode_message(self, body: Union[Dict, str], encoding_type: str = "bytes"):
        """Encode *body* for publishing.

        Args:
            body: Payload to encode.
            encoding_type: ``"bytes"`` packs the body with msgpack;
                ``"json"`` dumps dicts to a JSON string and returns any
                other body unchanged.

        Raises:
            NotImplementedError: For any other *encoding_type*.
        """
        if encoding_type == "bytes":
            return msgpack.packb(body)
        if encoding_type == "json":
            if isinstance(body, dict):
                return json.dumps(body)
            return body
        raise NotImplementedError

    def send_message(
        self,
        exchange_name: str,
        routing_key: str,
        body: Union[Dict, str],
        headers: Optional[Headers],
    ):
        """Publish *body* to *exchange_name* with persistent delivery mode.

        Args:
            exchange_name: Target exchange.
            routing_key: Routing key for the message.
            body: Payload; encoded according to ``message_encoding_type``.
            headers: Optional headers. When given, the message priority is
                taken from ``headers.priority`` and the fields are sent as
                the AMQP header table.
        """
        body = self.encode_message(body=body, encoding_type=self.message_encoding_type)

        priority = None
        header_table = None
        if headers:
            priority = headers.priority.value
            header_table = headers.model_dump()
            # model_dump() leaves the enum member in place, which pika cannot
            # encode in a header table. Send the member name instead; the
            # Headers validator accepts names when rebuilding the model.
            header_table["priority"] = headers.priority.name

        properties = pika.BasicProperties(
            delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE,
            priority=priority,
            headers=header_table,
        )

        self.check_connection()
        self.channel.basic_publish(
            exchange=exchange_name,
            routing_key=routing_key,
            body=body,
            properties=properties,
        )
        logger.debug(
            "Sent message. Exchange: %s, Routing Key: %s, Body: %s",
            exchange_name, routing_key, body[:128],
        )


class BasicMessageReceiver(BasicPikaClient):
    """Client that fetches or consumes messages from a queue."""

    def __init__(self):
        super().__init__()
        # Consumer tag returned by basic_consume; None while not consuming.
        self.channel_tag = None

    def decode_message(self, body):
        """Decode a JSON message body.

        Raises:
            NotImplementedError: If *body* is not ``bytes``.
        """
        if isinstance(body, bytes):
            return json.loads(body)
        raise NotImplementedError

    def get_message(self, queue_name: str, auto_ack: bool = False):
        """Fetch a single message from *queue_name* with ``basic_get``.

        Returns:
            The ``(method_frame, header_frame, body)`` tuple, or ``None``
            when no message was returned.
        """
        self.check_connection()
        method_frame, header_frame, body = self.channel.basic_get(
            queue=queue_name, auto_ack=auto_ack
        )
        if not method_frame:
            logger.debug("No message returned")
            return None
        logger.debug("%s, %s, %s", method_frame, header_frame, body)
        return method_frame, header_frame, body

    def consume_messages(self, queue, callback, auto_ack: bool = True):
        """Register *callback* as a consumer on *queue* and start consuming.

        Args:
            queue: Queue to consume from.
            callback: Passed to ``basic_consume`` as ``on_message_callback``.
            auto_ack: Passed through to ``basic_consume``; defaults to
                ``True``.
        """
        self.check_connection()
        self.channel_tag = self.channel.basic_consume(
            queue=queue, on_message_callback=callback, auto_ack=auto_ack
        )
        logger.debug(" [*] Waiting for messages. To exit press CTRL+C")
        self.channel.start_consuming()

    def cancel_consumer(self):
        """Cancel the consumer registered by ``consume_messages``, if any."""
        if self.channel_tag is None:
            logger.error("Do not cancel a non-existing job")
            return
        self.channel.basic_cancel(self.channel_tag)
        self.channel_tag = None