import json import logging from dataclasses import dataclass, field from typing import List, Dict, Any, Optional from src.uuid_util.is_valid_uuid import is_valid_uuid logger = logging.getLogger(__name__) @dataclass class PromptItem: """Dataclass to hold a single prompt with tags, UUID, and any extra fields.""" id: str prompt: str tags: List[str] = field(default_factory=list) extras: Dict[str, Any] = field(default_factory=dict) class PromptCatalog: """ A catalog of PromptItem objects, keyed by UUID. Supports loading from one or more JSONL files, each containing one JSON object per line. """ def __init__(self): self._catalog: Dict[str, PromptItem] = {} def load(self, filepath: str) -> None: """ Load prompts from a JSONL file. Each line is expected to have fields like 'id', 'prompt', 'tags', etc. Logs an error if 'id' or 'prompt' is missing/empty, then skips that row. """ with open(filepath, 'r', encoding='utf-8') as f: for line_num, line in enumerate(f, start=1): line = line.strip() if not line: continue try: data = json.loads(line) except json.JSONDecodeError as e: logger.error(f"JSON decode error in {filepath} at line {line_num}: {e}") continue pid = data.get('id') prompt_text = data.get('prompt') if not pid: logger.error(f"Missing 'id' field in {filepath} at line {line_num}. Skipping row.") continue if not prompt_text: logger.error(f"Missing or empty 'prompt' for ID '{pid}' in {filepath} at line {line_num}. Skipping row.") continue if not is_valid_uuid(pid): logger.error(f"Invalid UUID in {filepath} at line {line_num}: '{pid}'. Skipping row.") continue tags = data.get('tags', []) extras = {k: v for k, v in data.items() if k not in ('id', 'prompt', 'tags')} if self._catalog.get(pid): logger.error(f"Duplicate UUID found in {filepath} at line {line_num}: {pid}. Skipping row.") continue item = PromptItem(id=pid, prompt=prompt_text, tags=tags, extras=extras) self._catalog[pid] = item def find(self, prompt_id: str) -> Optional[PromptItem]: """Retrieve a PromptItem by its ID (UUID). Returns None if not found.""" if not is_valid_uuid(prompt_id): raise ValueError(f"Invalid UUID: {prompt_id}") return self._catalog.get(prompt_id) def find_by_tag(self, tag: str) -> List[PromptItem]: """ Return a list of all PromptItems that contain the given tag (case-sensitive match). """ return [item for item in self._catalog.values() if tag in item.tags] def all(self) -> List[PromptItem]: """Return a list of all PromptItems in the order they were inserted.""" return list(self._catalog.values())