File size: 4,266 Bytes
9d615c0
 
352827d
 
9d615c0
 
 
 
aae060c
 
 
 
 
 
 
 
 
71b8ea3
 
9d615c0
95a84bb
9d615c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8cabdd
9d615c0
 
 
 
 
 
 
 
 
ce65deb
 
 
 
 
 
befaba8
ea855b6
ad2034b
 
 
 
 
ce65deb
9d615c0
 
 
 
 
ea855b6
7afd0f0
f15f01c
 
 
98a0d76
 
f15f01c
98a0d76
f15f01c
98a0d76
2afcd7a
f15f01c
 
 
7afd0f0
 
 
ea855b6
 
9d615c0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""This module contains functions for loading a vector store and a ConversationalRetrievalChain, and for querying that chain"""

## May 24: moved imports to langchain_community because the old langchain.* import paths are deprecated

import logging

import wandb
from langchain.chains import ConversationalRetrievalChain
##from langchain.chat_models import ChatOpenAI

from langchain_community.chat_models import ChatOpenAI

##from langchain.embeddings import OpenAIEmbeddings

from langchain_community.embeddings import OpenAIEmbeddings


## deprecated: from langchain.vectorstores import Chroma
from langchain_community.vectorstores import Chroma
from prompts import load_chat_prompt
import pathlib


logger = logging.getLogger(__name__)


def load_vector_store(wandb_run: wandb.run, openai_api_key: str) -> Chroma:
    """Open a persisted Chroma vector store from a Weights & Biases artifact.

    The artifact named by ``wandb_run.config.vector_store_artifact`` (of type
    ``"search_index"``) is fetched through the run and downloaded locally; the
    resulting directory is used as Chroma's persist directory.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run whose config
            provides ``vector_store_artifact``
        openai_api_key (str): The OpenAI API key handed to the embedding function
    Returns:
        Chroma: A Chroma vector store backed by the downloaded artifact directory
    """
    artifact_name = wandb_run.config.vector_store_artifact
    index_artifact = wandb_run.use_artifact(artifact_name, type="search_index")
    persist_dir = index_artifact.download()

    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
    return Chroma(embedding_function=embeddings, persist_directory=persist_dir)


def load_chain(
    wandb_run: wandb.run,
    vector_store: Chroma,
    openai_api_key: str,
    prompt_file: str = "chat_prompt_massa.json",
) -> ConversationalRetrievalChain:
    """Load a ConversationalQA chain from a run config and a vector store.

    Args:
        wandb_run (wandb.run): An active Weights & Biases run. Its config
            supplies ``model_name``, ``chat_temperature``,
            ``max_fallback_retries`` and ``chat_prompt_artifact``.
        vector_store (Chroma): A Chroma vector store object, used as the
            chain's retriever
        openai_api_key (str): The OpenAI API key used by the chat model
        prompt_file (str): File name of the chat prompt inside the downloaded
            ``"prompt"`` artifact. Defaults to ``"chat_prompt_massa.json"``.
    Returns:
        ConversationalRetrievalChain: A "stuff"-type chain configured to also
        return its source documents
    """
    retriever = vector_store.as_retriever()
    llm = ChatOpenAI(
        openai_api_key=openai_api_key,
        model_name=wandb_run.config.model_name,
        temperature=wandb_run.config.chat_temperature,
        max_retries=wandb_run.config.max_fallback_retries,
    )
    chat_prompt_dir = wandb_run.use_artifact(
        wandb_run.config.chat_prompt_artifact, type="prompt"
    ).download()
    qa_prompt = load_chat_prompt(f"{chat_prompt_dir}/{prompt_file}")

    # Route the prompt dump through the module logger (instead of print) so it
    # can be silenced or redirected by the application's logging config.
    logger.debug("qa_prompt = %s", qa_prompt)

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        # The loaded prompt replaces the default prompt of the document-combining step.
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        return_source_documents=True,
    )
    return qa_chain


def get_answer(
    chain: ConversationalRetrievalChain,
    question: str,
    chat_history: list[tuple[str, str]],
    wandb_run: wandb.run,
) -> str:
    """Get an answer from a ConversationalRetrievalChain and log the exchange.

    The question and answer are written through the module logger (by default
    into ``user_input.log``, see the ``basicConfig`` note below), and the
    question is logged to the given Weights & Biases run.

    Args:
        chain (ConversationalRetrievalChain): A ConversationalRetrievalChain object
        question (str): The question to ask
        chat_history (list[tuple[str, str]]): A list of tuples of (question, answer)
        wandb_run (wandb.run): The active Weights & Biases run the question is
            logged to
    Returns:
        str: The chain's answer, prefixed with "Answer:" and a tab character
    """
    # basicConfig only configures the root logger if it has no handlers yet, so
    # this is a no-op after the first call or if the application already set up
    # logging itself.
    logging.basicConfig(
        filename="user_input.log",
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    logger.info("User question: %s", question)
    # Log to the run we were handed rather than whatever run is global.
    wandb_run.log({"question": question})

    result = chain(
        inputs={"question": question, "chat_history": chat_history},
        return_only_outputs=True,
    )
    answer = result["answer"]
    logger.info("Answer: %s", answer)

    return f"Answer:\t{answer}"