from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from aift.multimodal import textqa
from aift import setting
import streamlit as st

class CustomEmbeddings:
    def __init__(self, model_name="mrp/simcse-model-m-bert-thai-cased"):
        """

        Initialize the embedding model using SentenceTransformer.

        :param model_name: Name of the pre-trained model

        """
        self.model = SentenceTransformer(model_name)

    def embed_query(self, text):
        """

        Generate embeddings for a single query.

        :param text: Input text to embed

        :return: Embedding vector as a Python list

        """
        embedding = self.model.encode([text])
        return embedding[0].tolist()  # Convert NumPy array to list

    async def aembed_query(self, text):
        """

        Asynchronous version of `embed_query`.

        :param text: Input text to embed

        :return: Embedding vector as a Python list

        """
        return self.embed_query(text)

    def embed_documents(self, texts):
        """

        Generate embeddings for multiple documents.

        :param texts: List of input texts to embed

        :return: List of embedding vectors as Python lists

        """
        embeddings = self.model.encode(texts)
        return [embedding.tolist() for embedding in embeddings]

    async def aembed_documents(self, texts):
        """

        Asynchronous version of `embed_documents`.

        :param texts: List of input texts to embed

        :return: List of embedding vectors as Python lists

        """
        return self.embed_documents(texts)
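
# Usage sketch (not called in this app): the class satisfies LangChain's
# embeddings interface by returning plain Python lists, e.g.
#   emb = CustomEmbeddings()
#   vec = emb.embed_query("สวัสดี")  # list[float]; 768-dim for a BERT-base model (assumption)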

# Set the Pathumma (AI for Thai) API key. In a real deployment this should be
# read from an environment variable rather than hard-coded in source.
setting.set_api_key('T69FqnYgOdreO5G0nZaM8gHcjo1sifyU')

# A thin callable wrapper around the Pathumma text-QA endpoint
class PathummaModel:
    def generate(self, instruction: str, return_json: bool = False):
        # With return_json=True the API returns a dict; unwrap its "content" field
        response = textqa.generate(instruction=instruction, return_json=return_json)
        if return_json:
            return response.get("content", "")
        return response

    def __call__(self, input: str):
        return self.generate(input, return_json=False)

# Initialize Pathumma Model
model_local = PathummaModel()
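
# Usage sketch: the wrapper behaves like a plain function, e.g.
#   answer = model_local("สวัสดีครับ")  # returns the generated answer as text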

# 1. Load the document and split it into chunks
raw_documents = TextLoader('./mainn.txt').load()
text_splitter = CharacterTextSplitter(chunk_size=7500, chunk_overlap=0)
documents = text_splitter.split_documents(raw_documents)
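
# Note: CharacterTextSplitter splits on "\n\n" by default, so with
# chunk_size=7500 a short file may remain a single chunk.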

# 2. Convert the chunks to embeddings and store them in Chroma
vectorstore = Chroma.from_documents(
    documents=documents,
    collection_name="rag-chroma",
    embedding=CustomEmbeddings(model_name="mrp/simcse-model-m-bert-thai-cased"),
)
retriever = vectorstore.as_retriever()
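
# Sketch: the retriever performs a similarity search and returns the top-k
# matching Documents; k is tunable, e.g.
#   retriever = vectorstore.as_retriever(search_kwargs={"k": 4})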

# Prompt (Thai): "Answer the question based only on the following context: {context} / Question: {question}"
after_rag_template = """ตอบคำถามโดยพิจารณาจากบริบทต่อไปนี้เท่านั้น:

{context}

คำถาม: {question}

"""
after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)

# Query the retriever for context and pass the filled prompt to Pathumma
def system_call(text_input):
    question = text_input
    retrieved_context = retriever.invoke(question)
    context = "\n".join([doc.page_content for doc in retrieved_context])

    # Fill the prompt template, then convert the resulting PromptValue to a
    # plain string, since the Pathumma wrapper expects raw text
    after_rag_prompt_value = after_rag_prompt.invoke({
        "context": context,
        "question": question,
    })
    response = model_local(after_rag_prompt_value.to_string())
    st.write("Response:")
    st.write(response)

system_call("ผมชื่ออะไรเหรอ")  # Thai: "What's my name?"