hmrizal committed (verified)
Commit e9a5be2 · 1 Parent(s): a61644e

Update app.py

Files changed (1): app.py (+26, -34)
app.py CHANGED
@@ -6,9 +6,8 @@ import pandas as pd
  from langchain.document_loaders.csv_loader import CSVLoader
  from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.vectorstores import FAISS
- from langchain.llms import HuggingFacePipeline
+ from langchain.llms import CTransformers
  from langchain.chains import ConversationalRetrievalChain
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
  # Global model cache
  MODEL_CACHE = {
@@ -20,36 +19,19 @@ MODEL_CACHE = {
  os.makedirs("user_data", exist_ok=True)
 
  def initialize_model_once():
-     """Initialize model once using pipeline API"""
+     """Initialize model once using CTransformers API"""
      with MODEL_CACHE["init_lock"]:
          if MODEL_CACHE["model"] is None:
-             # Load model from Hugging Face Hub
-             model_id = "meta-llama/Llama-2-7b-chat-hf"
-
-             # Tokenizer
-             tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
-
-             # Model with low precision
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_id,
-                 token=os.environ.get("HF_TOKEN"),
-                 device_map="auto",
-                 load_in_8bit=True  # Quantize model to 8-bit precision
-             )
-
-             # Create pipeline
-             pipe = pipeline(
-                 "text-generation",
-                 model=model,
-                 tokenizer=tokenizer,
+             # Load Mistral-7B-Instruct-v0.2.Q4_K_M.gguf model
+             MODEL_CACHE["model"] = CTransformers(
+                 model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+                 model_file="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
+                 model_type="mistral",
                  max_new_tokens=512,
                  temperature=0.2,
                  top_p=0.9,
                  repetition_penalty=1.2
              )
-
-             # Create LangChain wrapper
-             MODEL_CACHE["model"] = HuggingFacePipeline(pipeline=pipe)
 
      return MODEL_CACHE["model"]
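Note on the new loader: in most classic LangChain releases, CTransformers expects generation settings inside a `config` dict rather than as top-level keyword arguments, so the `max_new_tokens=512, ...` kwargs above may be rejected or ignored depending on the pinned version. A minimal standalone sketch of the same load, assuming `langchain` and `ctransformers` are installed (the prompt string is only an illustration):

    from langchain.llms import CTransformers

    # Same GGUF checkpoint as the commit; Q4_K_M is a 4-bit quantization
    # that runs on CPU and needs no HF_TOKEN.
    llm = CTransformers(
        model="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        model_file="mistral-7b-instruct-v0.2.Q4_K_M.gguf",
        model_type="mistral",
        config={
            "max_new_tokens": 512,
            "temperature": 0.2,
            "top_p": 0.9,
            "repetition_penalty": 1.2,
        },
    )

    # Mistral-Instruct expects the [INST] ... [/INST] chat format.
    print(llm("[INST] Summarize what a FAISS index is in one sentence. [/INST]"))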
 
@@ -91,7 +73,7 @@ class ChatBot:
          db_path = f"{self.user_dir}/db_faiss"
          embeddings = HuggingFaceEmbeddings(
              model_name='sentence-transformers/all-MiniLM-L6-v2',
-             model_kwargs={'device': 'auto'}
+             model_kwargs={'device': 'cpu'}  # Explicitly set to CPU
          )
 
          db = FAISS.from_documents(data, embeddings)
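The 'cpu' fix matters because 'auto' is not a valid torch device string for sentence-transformers, so the old value could fail at embedding time. A hypothetical end-to-end sketch of this indexing path (the CSV path here is made up for illustration):

    from langchain.document_loaders.csv_loader import CSVLoader
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    # One Document per CSV row.
    data = CSVLoader(file_path="user_data/example.csv", encoding="utf-8").load()

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )

    # Embed every row and store the vectors in a local FAISS index.
    db = FAISS.from_documents(data, embeddings)
    db.save_local("user_data/db_faiss")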
@@ -105,7 +87,8 @@ class ChatBot:
              llm = initialize_model_once()
              self.chain = ConversationalRetrievalChain.from_llm(
                  llm=llm,
-                 retriever=db.as_retriever(search_kwargs={"k": 4})
+                 retriever=db.as_retriever(search_kwargs={"k": 4}),
+                 return_source_documents=True
              )
              print("Chain created successfully")
          except Exception as e:
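With return_source_documents=True, the chain's output dict gains a "source_documents" list alongside "answer", which is what the new source-attribution block further down relies on. A hypothetical invocation (the question text is illustrative):

    # `chain` as built above; ConversationalRetrievalChain takes a question
    # plus the running chat history and returns a dict.
    result = chain({"question": "What columns does the file have?", "chat_history": []})

    print(result["answer"])

    # Present only because return_source_documents=True was set:
    for doc in result.get("source_documents", []):
        print(doc.page_content[:100])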
@@ -115,7 +98,7 @@ class ChatBot:
          file_info = f"CSV berhasil dimuat dengan {df.shape[0]} baris dan {len(df.columns)} kolom. Kolom: {', '.join(df.columns.tolist())}"
          self.chat_history.append(("System", file_info))
 
-         return "File CSV berhasil diproses! Anda dapat mulai chat dengan model Llama 2."
+         return "File CSV berhasil diproses! Anda dapat mulai chat dengan Mistral 7B."
      except Exception as e:
          import traceback
          print(traceback.format_exc())
@@ -131,6 +114,15 @@ class ChatBot:
 
          # Update chat history
          answer = result["answer"]
+
+         # Optional: Add source info to answer
+         sources = result.get("source_documents", [])
+         if sources:
+             source_text = "\n\nSumber:\n"
+             for i, doc in enumerate(sources[:2], 1):  # Limit to top 2 sources
+                 source_text += f"{i}. {doc.page_content[:100]}...\n"
+             answer += source_text
+
          self.chat_history.append((message, answer))
 
          return answer
@@ -141,12 +133,12 @@ class ChatBot:
 
  # UI Code dan handler functions sama seperti sebelumnya
  def create_gradio_interface():
-     with gr.Blocks(title="Chat with CSV using Llama2 🦙") as interface:
+     with gr.Blocks(title="Chat with CSV using Mistral 7B") as interface:
          session_id = gr.State(lambda: str(uuid.uuid4()))
          chatbot_state = gr.State(lambda: None)
 
-         gr.HTML("<h1 style='text-align: center;'>Chat with CSV using Llama2 🦙</h1>")
-         gr.HTML("<h3 style='text-align: center;'>Asisten analisis CSV yang powerfull</h3>")
+         gr.HTML("<h1 style='text-align: center;'>Chat with CSV using Mistral 7B</h1>")
+         gr.HTML("<h3 style='text-align: center;'>Asisten analisis CSV yang powerful</h3>")
 
          with gr.Row():
              with gr.Column(scale=1):
@@ -158,11 +150,11 @@ def create_gradio_interface():
 
          with gr.Accordion("Informasi Model", open=False):
              gr.Markdown("""
-             **Model**: Llama-2-7b-chat-hf
+             **Model**: Mistral-7B-Instruct-v0.2-GGUF
 
              **Fitur**:
-             - Dioptimalkan untuk analisis data dan percakapan
-             - Menggunakan API Hugging Face untuk efisiensi
+             - GGUF model yang dioptimalkan untuk CPU
+             - Efisien untuk analisis data dan percakapan
              - Manajemen sesi per pengguna
              """)
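On the "Manajemen sesi per pengguna" (per-user session management) point: each browser session gets its own UUID via gr.State, which keys the per-user working directory and chain. A stripped-down sketch of that pattern, with a placeholder handler standing in for the real chain (everything here is illustrative, not part of the commit):

    import uuid
    import gradio as gr

    def respond(message, chat_history, session_id):
        # Placeholder: the real app routes `message` through the retrieval chain.
        answer = f"(session {session_id[:8]}) echo: {message}"
        return "", chat_history + [(message, answer)]

    with gr.Blocks(title="Chat with CSV using Mistral 7B") as demo:
        # The lambda runs once per browser session, so every visitor
        # gets an independent ID (and therefore independent state).
        session_id = gr.State(lambda: str(uuid.uuid4()))
        chatbot = gr.Chatbot()
        msg = gr.Textbox(placeholder="Ask something about your CSV")
        msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot])

    demo.launch()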
 
 