File size: 3,213 Bytes
6369972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8628c58
6369972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from src.uuid_util.is_valid_uuid import is_valid_uuid

logger = logging.getLogger(__name__)

@dataclass
class PromptItem:
    """
    One prompt entry: its identifier, text, tags and any additional fields.

    Attributes:
        id: Identifier of the prompt (a UUID string when produced by a catalog loader).
        prompt: The prompt text itself.
        tags: Labels attached to the prompt. Each instance gets its own list.
        extras: Any further key/value pairs carried alongside the prompt.
            Each instance gets its own dict.
    """
    # Required fields come first: dataclasses forbid non-default fields after defaulted ones.
    id: str
    prompt: str
    # default_factory builds a new container per instance, so no mutable default is shared.
    tags: List[str] = field(default_factory=list)
    extras: Dict[str, Any] = field(default_factory=dict)

class PromptCatalog:
    """
    A catalog of PromptItem objects, keyed by UUID.
    Supports loading from one or more JSONL files, each containing
    one JSON object per line.

    Items live in a dict, so ``all()`` returns them in the order they were
    first accepted across every ``load`` call.
    """

    # Keys consumed into dedicated PromptItem fields; everything else goes to extras.
    _RESERVED_KEYS = frozenset(('id', 'prompt', 'tags'))

    def __init__(self):
        # UUID string -> PromptItem (insertion-ordered).
        self._catalog: Dict[str, PromptItem] = {}

    def load(self, filepath: str) -> None:
        """
        Load prompts from a JSONL file and add them to the catalog.

        Each non-blank line must be a JSON object with a non-empty string
        'id' that is a valid UUID and a non-empty string 'prompt'. An
        optional 'tags' field must be a JSON array; a missing or null
        'tags' means no tags. All other keys are kept in ``PromptItem.extras``.

        Malformed rows are logged at ERROR level and skipped; they never
        abort the load. A row whose id is already in the catalog (from this
        file or an earlier ``load``) is also logged and skipped, so the
        first occurrence wins.

        Args:
            filepath: Path to the JSONL file, read as UTF-8. A leading
                byte-order mark is tolerated.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        # 'utf-8-sig' strips an optional BOM. With plain 'utf-8' the BOM stays in
        # line 1 and json.loads rejects it, silently dropping the first row.
        with open(filepath, 'r', encoding='utf-8-sig') as f:
            for line_num, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error("JSON decode error in %s at line %d: %s", filepath, line_num, e)
                    continue

                # Valid JSON can still be an array/number/string; .get() on those
                # would raise and abort the entire load.
                if not isinstance(data, dict):
                    logger.error(
                        "Expected a JSON object in %s at line %d, got %s. Skipping row.",
                        filepath, line_num, type(data).__name__,
                    )
                    continue

                pid = data.get('id')
                prompt_text = data.get('prompt')

                if not pid:
                    logger.error("Missing 'id' field in %s at line %d. Skipping row.", filepath, line_num)
                    continue
                if not prompt_text:
                    logger.error(
                        "Missing or empty 'prompt' for ID '%s' in %s at line %d. Skipping row.",
                        pid, filepath, line_num,
                    )
                    continue
                if not isinstance(prompt_text, str):
                    logger.error(
                        "Non-string 'prompt' for ID '%s' in %s at line %d. Skipping row.",
                        pid, filepath, line_num,
                    )
                    continue

                # Catalog keys are str; reject non-string ids before the UUID check
                # rather than relying on the validator to cope with other types.
                if not isinstance(pid, str) or not is_valid_uuid(pid):
                    logger.error("Invalid UUID in %s at line %d: '%s'. Skipping row.", filepath, line_num, pid)
                    continue

                tags = data.get('tags')
                if tags is None:
                    tags = []
                elif not isinstance(tags, list):
                    # A string here would turn find_by_tag into substring matching;
                    # any other scalar would make it raise TypeError.
                    logger.error(
                        "Field 'tags' must be a list for ID '%s' in %s at line %d. Skipping row.",
                        pid, filepath, line_num,
                    )
                    continue

                if pid in self._catalog:
                    logger.error("Duplicate UUID found in %s at line %d: %s. Skipping row.", filepath, line_num, pid)
                    continue

                extras = {k: v for k, v in data.items() if k not in self._RESERVED_KEYS}
                self._catalog[pid] = PromptItem(id=pid, prompt=prompt_text, tags=tags, extras=extras)

    def find(self, prompt_id: str) -> Optional[PromptItem]:
        """
        Retrieve a PromptItem by its ID (UUID).

        Args:
            prompt_id: The UUID string to look up.

        Returns:
            The matching PromptItem, or None if no item has that id.

        Raises:
            ValueError: If ``prompt_id`` is not a valid UUID.
        """
        if not is_valid_uuid(prompt_id):
            raise ValueError(f"Invalid UUID: {prompt_id}")
        return self._catalog.get(prompt_id)

    def find_by_tag(self, tag: str) -> List[PromptItem]:
        """
        Return all PromptItems whose tags contain ``tag`` (exact,
        case-sensitive element match), in insertion order.
        """
        return [item for item in self._catalog.values() if tag in item.tags]

    def all(self) -> List[PromptItem]:
        """Return a new list of all PromptItems in the order they were inserted."""
        return list(self._catalog.values())