File size: 1,019 Bytes
f655f69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel

from src.prompts import ModifyTextPrompt
from src.utils import GPTModels, get_chat_llm


class ModifiedTextOutput(BaseModel):
    """Result model pairing an original text with its modified version."""

    # The text as it was supplied in the chain input (the "text" key).
    text_raw: str
    # The modified text; in this file it is the LLM response parsed to a string.
    text_modified: str


def modify_text_chain(llm_model: GPTModels):
    """Build a runnable that rewrites a text with a chat LLM.

    The returned chain expects a mapping that contains at least a ``"text"``
    key (plus whatever variables the ``ModifyTextPrompt`` templates
    reference). It renders the system/user prompt, calls the chat model at
    temperature 0.0, parses the reply to a plain string, and wraps the
    original and rewritten text in a ``ModifiedTextOutput``.

    Args:
        llm_model: Which model ``get_chat_llm`` should instantiate.

    Returns:
        A runnable chain producing ``ModifiedTextOutput`` instances.
    """

    def _to_output(inputs):
        # ``inputs`` is the original mapping extended with "text_modified".
        return ModifiedTextOutput(
            text_raw=inputs["text"],
            text_modified=inputs["text_modified"],
        )

    chat_model = get_chat_llm(llm_model=llm_model, temperature=0.0)

    system_message = SystemMessagePromptTemplate.from_template(ModifyTextPrompt.SYSTEM)
    user_message = HumanMessagePromptTemplate.from_template(ModifyTextPrompt.USER)
    chat_prompt = ChatPromptTemplate.from_messages([system_message, user_message])

    # Prompt -> model -> raw string reply.
    rewrite_step = chat_prompt | chat_model | StrOutputParser()

    # Keep the incoming keys and attach the reply under "text_modified".
    enriched_inputs = RunnablePassthrough.assign(text_modified=rewrite_step)

    return enriched_inputs | _to_output