File size: 3,213 Bytes
6369972 8628c58 6369972 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import json
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from src.uuid_util.is_valid_uuid import is_valid_uuid
# Module-level logger; PromptCatalog.load() reports every skipped row through it.
logger = logging.getLogger(__name__)


@dataclass
class PromptItem:
    """One prompt entry: its UUID, the prompt text, tags, and leftover fields.

    When built by ``PromptCatalog.load``, ``extras`` holds every key of the
    source record other than 'id', 'prompt', and 'tags'.
    """

    # Identifier of the entry; expected to be a UUID string.
    id: str
    # The prompt text.
    prompt: str
    # Tag labels; each instance gets its own fresh list when omitted.
    tags: List[str] = field(default_factory=list)
    # Additional record fields; each instance gets its own fresh dict when omitted.
    extras: Dict[str, Any] = field(default_factory=dict)


class PromptCatalog:
    """
    A catalog of PromptItem objects, keyed by UUID.

    Supports loading from one or more JSONL files, each containing one JSON
    object per line. Successive ``load`` calls add to the same catalog; an ID
    that is already present is rejected as a duplicate.
    """

    # Record keys that map onto dedicated PromptItem fields; every other key
    # of a record is collected into PromptItem.extras.
    _KNOWN_KEYS = ('id', 'prompt', 'tags')

    def __init__(self):
        # Insertion-ordered mapping of ID string -> PromptItem.
        self._catalog: Dict[str, PromptItem] = {}

    def load(self, filepath: str) -> None:
        """
        Load prompts from a JSONL file and add them to the catalog.

        Each non-blank line should be a JSON object with at least 'id' (a
        valid UUID string) and 'prompt' (a non-empty string). An optional
        'tags' list becomes PromptItem.tags; any other keys are kept in
        PromptItem.extras.

        Unusable rows are logged as errors and skipped, and loading continues
        with the next line. A row is skipped when it:

        - is not valid JSON or not a JSON object,
        - lacks 'id' or 'prompt',
        - has a non-string 'prompt' or an invalid UUID,
        - or reuses an ID already in the catalog.

        A 'tags' value that is not a list is ignored (with a warning), and
        the prompt is loaded with no tags.

        Args:
            filepath: Path of the JSONL file (UTF-8, with or without a BOM).

        Raises:
            OSError: If the file cannot be opened or read.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        # 'utf-8-sig' decodes plain UTF-8 unchanged but also drops a leading
        # byte-order mark, which str.strip() keeps and json.loads rejects.
        with open(filepath, 'r', encoding='utf-8-sig') as f:
            for line_num, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.error("JSON decode error in %s at line %d: %s", filepath, line_num, e)
                    continue

                item = self._parse_record(data, filepath, line_num)
                if item is None:
                    continue
                if item.id in self._catalog:
                    logger.error(
                        "Duplicate UUID found in %s at line %d: %s. Skipping row.",
                        filepath, line_num, item.id,
                    )
                    continue
                self._catalog[item.id] = item

    def _parse_record(self, data: Any, filepath: str, line_num: int) -> Optional[PromptItem]:
        """
        Validate one decoded JSONL record and build a PromptItem from it.

        Logs the reason and returns None when the record is unusable. It does
        not check for duplicate IDs; that needs catalog state and is done in
        load().
        """
        # A line can be valid JSON yet be an array, string, or number; only
        # an object has named fields (and a .get() method).
        if not isinstance(data, dict):
            logger.error(
                "Expected a JSON object in %s at line %d, got %s. Skipping row.",
                filepath, line_num, type(data).__name__,
            )
            return None

        pid = data.get('id')
        prompt_text = data.get('prompt')
        if not pid:
            logger.error("Missing 'id' field in %s at line %d. Skipping row.", filepath, line_num)
            return None
        if not prompt_text:
            logger.error(
                "Missing or empty 'prompt' for ID '%s' in %s at line %d. Skipping row.",
                pid, filepath, line_num,
            )
            return None
        if not isinstance(prompt_text, str):
            logger.error(
                "Non-string 'prompt' for ID '%s' in %s at line %d. Skipping row.",
                pid, filepath, line_num,
            )
            return None
        # JSON ids such as 123 cannot be UUID strings; checking the type first
        # keeps the catalog key consistent with PromptItem.id: str.
        if not isinstance(pid, str) or not is_valid_uuid(pid):
            logger.error("Invalid UUID in %s at line %d: '%s'. Skipping row.", filepath, line_num, pid)
            return None

        tags = data.get('tags')
        if tags is None:
            tags = []
        elif not isinstance(tags, list):
            # A bare string would make find_by_tag() do substring matching
            # ("fo" in "foo"), and other scalars would make it raise TypeError.
            logger.warning(
                "Ignoring non-list 'tags' for ID '%s' in %s at line %d.",
                pid, filepath, line_num,
            )
            tags = []

        extras = {k: v for k, v in data.items() if k not in self._KNOWN_KEYS}
        return PromptItem(id=pid, prompt=prompt_text, tags=tags, extras=extras)

    def find(self, prompt_id: str) -> Optional[PromptItem]:
        """
        Retrieve a PromptItem by its ID (UUID).

        The ID is matched exactly as it appeared in the loaded file.

        Returns:
            The matching PromptItem, or None if no item has that ID.

        Raises:
            ValueError: If ``prompt_id`` is not a valid UUID.
        """
        if not is_valid_uuid(prompt_id):
            raise ValueError(f"Invalid UUID: {prompt_id}")
        return self._catalog.get(prompt_id)

    def find_by_tag(self, tag: str) -> List[PromptItem]:
        """
        Return a list of all PromptItems that contain the given tag
        (case-sensitive match), in insertion order.
        """
        return [item for item in self._catalog.values() if tag in item.tags]

    def all(self) -> List[PromptItem]:
        """Return a list of all PromptItems in the order they were inserted."""
        return list(self._catalog.values())
|