File size: 4,758 Bytes
6369972 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""
Based on short description, make a longer description.
PROMPT> python -m src.fiction.fiction_writer
"""
import json
import time
import logging
from math import ceil
from typing import Optional
from dataclasses import dataclass
from pydantic import BaseModel, Field
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.llms.llm import LLM
logger = logging.getLogger(__name__)
class BookDraft(BaseModel):
book_title: str = Field(description="Human readable title.")
overview: str = Field(description="What is this about?")
elaborate: str = Field(description="Details")
background_story: str = Field(description="What is the background story.")
blurb: str = Field(description="The back cover of the book. Immediately capture the readers attention.")
goal: str = Field(description="What is the goal.")
main_characters: list[str] = Field(description="List of characters in the story and their background story.")
character_flaws: list[str] = Field(description="Character flaws relevant to the story.")
plot_devices: list[str] = Field(description="Items that appear in the story.")
possible_plot_ideas: list[str] = Field(description="List of story directions.")
challenges: list[str] = Field(description="Things that could go wrong or be difficult.")
chapter_title_list: list[str] = Field(description="Name of each chapter.")
final_story: str = Field(description="Based on the above, what is the final story.")
@dataclass
class FictionWriter:
"""
Given a short text, elaborate on it.
"""
query: str
response: dict
metadata: dict
@classmethod
def execute(cls, llm: LLM, query: str, system_prompt: Optional[str]) -> 'FictionWriter':
"""
Invoke LLM to write a fiction based on the query.
"""
if not isinstance(llm, LLM):
raise ValueError("Invalid LLM instance.")
if not isinstance(query, str):
raise ValueError("Invalid query.")
chat_message_list = []
if system_prompt:
chat_message_list.append(
ChatMessage(
role=MessageRole.SYSTEM,
content=system_prompt,
)
)
chat_message_list.append(ChatMessage(
role=MessageRole.USER,
content=query
))
start_time = time.perf_counter()
sllm = llm.as_structured_llm(BookDraft)
try:
chat_response = sllm.chat(chat_message_list)
except Exception as e:
logger.error(f"FictionWriter failed to chat with LLM: {e}")
raise ValueError(f"Failed to chat with LLM: {e}")
json_response = json.loads(chat_response.message.content)
end_time = time.perf_counter()
duration = int(ceil(end_time - start_time))
metadata = dict(llm.metadata)
metadata["llm_classname"] = llm.class_name()
metadata["duration"] = duration
result = FictionWriter(
query=query,
response=json_response,
metadata=metadata
)
return result
def raw_response_dict(self, include_metadata=True, include_query=True) -> dict:
d = self.response.copy()
if include_metadata:
d['metadata'] = self.metadata
if include_query:
d['query'] = self.query
return d
if __name__ == "__main__":
from src.llm_factory import get_llm
from src.prompt.prompt_catalog import PromptCatalog
import os
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
system_prompt = "You are a fiction writer that has been given a short description to elaborate on."
system_prompt = "You are a non-fiction writer that has been given a short description to elaborate on."
prompt_catalog = PromptCatalog()
prompt_catalog.load(os.path.join(os.path.dirname(__file__), 'data', 'simple_fiction_prompts.jsonl'))
prompt_item = prompt_catalog.find("0e8e9b9d-95dd-4632-b47c-dcc4625a556d")
if not prompt_item:
raise ValueError("Prompt item not found.")
query = prompt_item.prompt
llm = get_llm("ollama-llama3.1") # works
# llm = get_llm("openrouter-paid-gemini-2.0-flash-001") # works
# llm = get_llm("ollama-qwen")
# llm = get_llm("ollama-phi")
# llm = get_llm("deepseek-chat")
print(f"System: {system_prompt}")
print(f"\n\nQuery: {query}")
result = FictionWriter.execute(llm, query, system_prompt)
print("\n\nResponse:")
print(json.dumps(result.raw_response_dict(include_query=False), indent=2))
|