ai-audio-books / src /text_modification_chain.py
Aliaksandr
merge dev into main (#13)
f655f69 unverified
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel
from src.prompts import ModifyTextPrompt
from src.utils import GPTModels, get_chat_llm
class ModifiedTextOutput(BaseModel):
    """Result record pairing an input text with its modified version.

    Instances are produced as the final output of the chain built by
    ``modify_text_chain``.
    """

    # The text exactly as it was supplied under the ``"text"`` input key.
    text_raw: str
    # The string-parsed model response generated for that text.
    text_modified: str
def modify_text_chain(llm_model: GPTModels):
    """Build the runnable pipeline that rewrites a text with a chat model.

    The pipeline renders the ``ModifyTextPrompt`` system/user templates,
    sends them to the chat model, parses the reply to a plain string and
    stores it under ``"text_modified"`` next to the original inputs. The
    last stage packs the result into a ``ModifiedTextOutput``.

    Args:
        llm_model: Identifier of the chat model, forwarded to
            ``get_chat_llm``.

    Returns:
        A runnable whose input mapping must contain a ``"text"`` key and
        whose output is a ``ModifiedTextOutput``.
        NOTE(review): the variables expected by the prompt templates are
        defined in ``src.prompts`` and are not visible here -- presumably
        just ``text``; confirm against ``ModifyTextPrompt.USER``.
    """
    # Temperature is pinned to 0.0 for every call made through this chain.
    chat_model = get_chat_llm(llm_model=llm_model, temperature=0.0)

    message_templates = [
        SystemMessagePromptTemplate.from_template(ModifyTextPrompt.SYSTEM),
        HumanMessagePromptTemplate.from_template(ModifyTextPrompt.USER),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(message_templates)

    # Stage 1: prompt -> model -> plain string.
    rewrite_step = chat_prompt | chat_model | StrOutputParser()

    # Stage 2: keep the incoming keys and add the rewritten text to them.
    with_rewrite = RunnablePassthrough.assign(text_modified=rewrite_step)

    # Stage 3: convert the enriched mapping into the typed result object.
    return with_rewrite | (
        lambda payload: ModifiedTextOutput(
            text_raw=payload["text"],
            text_modified=payload["text_modified"],
        )
    )