Santipab commited on
Commit
bde3dc5
·
verified ·
1 Parent(s): dc4adb2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import PyPDFLoader
2
+ from langchain_community.document_loaders import WebBaseLoader
3
+ from langchain_community.document_loaders import PyPDFLoader
4
+ from langchain_community.vectorstores import Chroma
5
+ from langchain_ollama import embeddings
6
+ from langchain_ollama import ChatOllama
7
+ from langchain_core.runnables import RunnablePassthrough
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain.output_parsers import PydanticOutputParser
11
+ from langchain.text_splitter import CharacterTextSplitter
12
+ from sentence_transformers import SentenceTransformer
13
+ from aift.multimodal import textqa
14
+ from aift import setting
15
+ from langchain_community.document_loaders import TextLoader
16
+ from langchain_text_splitters import CharacterTextSplitter
17
+ import streamlit as st
18
+
19
+ class CustomEmbeddings:
20
+ def __init__(self, model_name="mrp/simcse-model-m-bert-thai-cased"):
21
+ """
22
+ Initialize the embedding model using SentenceTransformer.
23
+ :param model_name: Name of the pre-trained model
24
+ """
25
+ self.model = SentenceTransformer(model_name)
26
+
27
+ def embed_query(self, text):
28
+ """
29
+ Generate embeddings for a single query.
30
+ :param text: Input text to embed
31
+ :return: Embedding vector as a Python list
32
+ """
33
+ embedding = self.model.encode([text])
34
+ return embedding[0].tolist() # Convert NumPy array to list
35
+
36
+ async def aembed_query(self, text):
37
+ """
38
+ Asynchronous version of `embed_query`.
39
+ :param text: Input text to embed
40
+ :return: Embedding vector as a Python list
41
+ """
42
+ return self.embed_query(text)
43
+
44
+ def embed_documents(self, texts):
45
+ """
46
+ Generate embeddings for multiple documents.
47
+ :param texts: List of input texts to embed
48
+ :return: List of embedding vectors as Python lists
49
+ """
50
+ embeddings = self.model.encode(texts)
51
+ return [embedding.tolist() for embedding in embeddings]
52
+
53
+ async def aembed_documents(self, texts):
54
+ """
55
+ Asynchronous version of `embed_documents`.
56
+ :param texts: List of input texts to embed
57
+ :return: List of embedding vectors as Python lists
58
+ """
59
+ return self.embed_documents(texts)
60
+
61
+ # Set Pathumma API Key
62
+ setting.set_api_key('T69FqnYgOdreO5G0nZaM8gHcjo1sifyU')
63
+
64
+ # Define a simple wrapper for Pathumma
65
+ class PathummaModel:
66
+ def __init__(self):
67
+ pass
68
+
69
+ def generate(self, instruction: str, return_json: bool = False):
70
+ response = textqa.generate(instruction=instruction, return_json=return_json)
71
+ if return_json:
72
+ return response.get("content", "")
73
+ return response
74
+
75
+ def __call__(self, input: str):
76
+ return self.generate(input, return_json=False)
77
+
78
+ # Initialize Pathumma Model
79
+ model_local = PathummaModel()
80
+
81
+ # Load the document, split it into chunks, embed each chunk and load it into the vector store.
82
+ raw_documents = TextLoader('./mainn.txt').load()
83
+ text_splitter = CharacterTextSplitter(chunk_size=7500, chunk_overlap=0)
84
+ documents = text_splitter.split_documents(raw_documents)
85
+
86
+ # 2. Convert documents to Embeddings and store them
87
+ vectorstore = Chroma.from_documents(
88
+ documents=documents,
89
+ collection_name="rag-chroma",
90
+ embedding=CustomEmbeddings(model_name="mrp/simcse-model-m-bert-thai-cased"),
91
+ )
92
+ retriever = vectorstore.as_retriever()
93
+
94
+ after_rag_template = """ตอบคำถามโดยพิจารณาจากบริบทต่อไปนี้เท่านั้น:
95
+ {context}
96
+ คำถาม: {question}
97
+ """
98
+ after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)
99
+
100
+ # Query retriever for context and pass to Pathumma
101
+ def system_call(text_input):
102
+ question = text_input
103
+ retrieved_context = retriever.invoke(question)
104
+ context = "\n".join([doc.page_content for doc in retrieved_context])
105
+
106
+ after_rag_chain = after_rag_prompt.invoke({
107
+ "context": context,
108
+ "question": question,
109
+ })
110
+ response = model_local(after_rag_chain)
111
+ st.write("response")
112
+ st.write(response)
113
+ system_call("ผมชื่ออะไรเหรอ")