Spaces:
Running
Running
File size: 1,252 Bytes
69c0ff1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from enum import StrEnum
from httpx import Timeout
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from prompts import SplitTextPrompt
class GPTModels(StrEnum):
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
GPT_4o_MINI = "gpt-4o-mini"
class TextPart(BaseModel):
    # One labelled fragment of text: `text` tagged with the `character` it is
    # attributed to (SplitTextOutput.to_pretty_text renders it as
    # "[character] text").
    # NOTE: documented with comments rather than a docstring on purpose —
    # pydantic copies a class docstring into the model's JSON-schema
    # "description", and this model is nested in the schema handed to
    # with_structured_output, so a docstring would change what the LLM sees.
    character: str  # character label for this fragment
    text: str  # the fragment's text
class SplitTextOutput(BaseModel):
    # Structured result of the split-text chain: a list of character names plus
    # the ordered list of text parts, each tagged with a character.
    # NOTE: comments instead of a class docstring on purpose — pydantic puts a
    # class docstring into the JSON schema that with_structured_output sends to
    # the LLM, which would alter the prompt.
    characters: list[str]
    parts: list[TextPart]

    def to_pretty_text(self) -> str:
        """Render this result as human-readable, newline-separated text.

        The first line is ``characters: `` followed by ``str()`` of the
        character list; every following line is one part, formatted as
        ``[<character>] <text>`` in list order. There is no trailing newline.
        """
        header = f"characters: {self.characters}"
        body = [f"[{part.character}] {part.text}" for part in self.parts]
        return "\n".join([header, *body])
# Default request timeout: 60 s overall, but only 4 s to establish the
# connection (httpx.Timeout(60, connect=4)).
_DEFAULT_TIMEOUT = Timeout(60, connect=4)


def create_split_text_chain(
    llm_model: GPTModels | str,
    *,
    temperature: float = 0.0,
    timeout: Timeout | float | None = _DEFAULT_TIMEOUT,
):
    """Build a prompt -> chat-model chain whose output is a ``SplitTextOutput``.

    The prompt is a system message from ``SplitTextPrompt.SYSTEM`` followed by
    a human message from ``SplitTextPrompt.USER``; the model's reply is parsed
    into a ``SplitTextOutput`` via ``with_structured_output``.

    Args:
        llm_model: Model name passed to ``ChatOpenAI``. A ``GPTModels`` member
            is the usual choice; since ``GPTModels`` is a ``StrEnum``, a plain
            model-name string works the same way.
        temperature: Sampling temperature forwarded to ``ChatOpenAI``.
            Defaults to ``0.0`` (the previously hard-coded value).
        timeout: Request timeout forwarded unchanged to ``ChatOpenAI``.
            Defaults to ``Timeout(60, connect=4)`` (the previously hard-coded
            value).

    Returns:
        The LangChain runnable ``prompt | structured_llm``. Invoke it with the
        input variables the ``SplitTextPrompt`` templates declare.
    """
    llm = ChatOpenAI(model=llm_model, temperature=temperature, timeout=timeout)
    # Separate name: with_structured_output returns a new runnable (not a
    # ChatOpenAI) that parses the reply into a SplitTextOutput instance.
    structured_llm = llm.with_structured_output(SplitTextOutput)
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SplitTextPrompt.SYSTEM),
            HumanMessagePromptTemplate.from_template(SplitTextPrompt.USER),
        ]
    )
    return prompt | structured_llm
|