gufett0 committed
Commit 0b966fb · 1 Parent(s): 44018bc

switched back to langchain

Files changed (5)
  1. app.py +225 -18
  2. appLlama.py +29 -0
  3. backend.py +0 -2
  4. backend2.py +95 -0
  5. requirements.txt +6 -7
app.py CHANGED
@@ -1,29 +1,236 @@
- from backend import handle_query
  import gradio as gr
-

  DESCRIPTION = """\
- # <div style="text-align: center;">Odi, l'assistente ricercatore degli Osservatori</div>


- 👉 Retrieval-Augmented Generation - Ask me anything about the research carried out at the Osservatori.
- """


- chat_interface =gr.ChatInterface(
-     fn=handle_query,
-     chatbot=gr.Chatbot(height=500),
-     textbox=gr.Textbox(placeholder="Chiedimi qualasiasi cosa relativa agli Osservatori", container=False, scale=7),
-     #examples=[["Ciao, in cosa puoi aiutarmi?"],["Dimmi i risultati e le modalità di conduzione del censimento per favore"]]
- )


- with gr.Blocks(css=".gradio-container {background-color: #B9D9EB}") as demo:
-     gr.Markdown(DESCRIPTION)
-     #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()

  if __name__ == "__main__":
-     #progress = gr.Progress(track_tqdm=True)
-     demo.launch()
-
 
+ import os
+ from threading import Thread
+ from typing import Iterator
+ from backend2 import load_documents, prepare_documents, get_context_sources
  import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer

  DESCRIPTION = """\
+ # La Chatbot degli Osservatori
+ """
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ os.environ["MAX_INPUT_TOKEN_LENGTH"] = "8192"
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH"))


+ # Force usage of CPU
+ #device = torch.device("cpu")
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "google/gemma-2-2b-it"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype= torch.float16 if torch.cuda.is_available() else torch.float32,
+ )
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ #tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.use_default_system_prompt = False
+ model.config.sliding_window = 4096
+ #model = model.to(device)
+ model.eval()
+
+ ###------####
+ # rag
+ documents_paths = {
+     'blockchain': 'documents/blockchain',
+     'metaverse': 'documents/metaverso',
+     'payment': 'documents/payment'
+ }
+
+ session_state = {"documents_loaded": False,
+                  "document_db": None,
+                  "original_message": None,
+                  "clarification": False}
+
+ INSTRUCTION_1 = 'In italiano, chiedi molto brevemente se la domanda si riferisce agli "Osservatori Blockchain", "Osservatori Payment" oppure "Osservatori Metaverse".'
+ INSTRUCTION_2 = 'Sei un assistente che risponde in italiano alle domande basandosi solo sulle informazioni fornite nel contesto. Se non trovi informazioni, rispondi "Puoi chiedere maggiori informazioni all\'ufficio di riferimento.". Se invece la domanda è completamente fuori contesto, non rispondere e rammenta il topic del contesto'
+
+ default_error_response = (
+     'Non sono sicuro che tu voglia indirizzare la tua ricerca su una di queste opzioni: '
+     '"Blockchain", "Metaverse", "Payment". '
+     'Per favore utilizza il nome corretto.'
+ )


+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+
+     global context, sources, conversation, session_state
+
+     if not (session_state["documents_loaded"]) and not (session_state["clarification"]):
+
+         conversation = []
+         for user, assistant in chat_history:
+             conversation.extend(
+                 [
+                     {"role": "user", "content": user},
+                     {"role": "assistant", "content": assistant},
+                 ]
+             )
+         conversation.append({"role": "user", "content": f"Domanda: {message} . Comando: {INSTRUCTION_1}" })
+         conversation.append({"role": "assistant", "content": "Ok."})
+         print("debug - CONV1", conversation)
+
+         session_state["original_message"] = message
+         session_state["clarification"] = True
+
+
+     elif session_state["clarification"]:
+
+         message = message.lower()
+         matched_path = None
+
+         for key, path in documents_paths.items():
+             if key in message:
+                 matched_path = path
+                 break
+
+         if matched_path:
+             yield "Fammi cercare tra i miei documenti..."
+             documents = load_documents(matched_path)
+             session_state["document_db"] = prepare_documents(documents)
+             session_state["documents_loaded"] = True
+             yield f"Ecco, ho raccolto informazioni dagli Osservatori {key.capitalize()}. Ora sto elaborando una risposta per te..."
+             context, sources = get_context_sources(session_state["original_message"], session_state["document_db"])
+
+             #conversation = []
+             conversation.append({"role": "user", "content": f"{INSTRUCTION_2}"})
+             for user, assistant in chat_history:
+                 conversation.extend(
+                     [
+                         #{"role": "user", "content": user },
+                         {"role": "assistant", "content": assistant},
+                     ]
+                 )
+             conversation.append({"role": "user", "content": f"Contesto: {context}\n\n Domanda iniziale: {session_state["original_message"]} . Rispondi solo in italiano."})
+             session_state["clarification"] = False
+             print("debug - CONV2", conversation)
+
+         else:
+             print(default_error_response)
+
+     else:
+         conversation = []
+         conversation.append({"role": "user", "content": f"Comandi: {INSTRUCTION_2}"})
+         conversation.append({"role": "assistant", "content": "Va bene."})
+         for user, assistant in chat_history:
+             conversation.extend(
+                 [
+                     {"role": "user", "content": user},
+                     {"role": "assistant", "content": assistant},
+                 ]
+             )
+         conversation.append({"role": "user", "content": f"Contesto: {context}\n\n Nuova domanda: {message} . Rispondi in italiano e seguendo i comandi che ti ho dato prima"})
+         print("debug - CONV3", conversation)
+
+     """ retriever = db.as_retriever()
+     qa = RetrievalQA.from_chain_type(llm=model, chain_type="refine", retriever=retriever, return_source_documents=False)
+     question = "Cosa sono i RWA?"
+     result = qa.run({"query": question})
+     print(result["result"]) """
+
+     # Iterate model output
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)

+     streamer = TextIteratorStreamer(tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)

+     if session_state["documents_loaded"]:
+         outputs.append(f"Fonti utilizzate: {sources}")
+         yield "".join(outputs)
+
+     sources = []
+     print("debug - CHATHISTORY", chat_history)
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Ciao, in cosa puoi aiutarmi?"],
+         ["Ciao, in cosa consiste un piatto di spaghetti?"],
+         ["Ciao, quali sono le aziende che hanno iniziato ad integrare le stablecoins? Fammi un breve sommario."],
+         ["Spiegami la differenza tra mondi virtuali pubblici o privati"],
+         ["Trovami un esempio di progetto B2B"],
+         ["Quali sono le regole europee sui bonifici istantanei?"],
+     ],
+     cache_examples=False,
+ )
+
+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION, elem_classes="centered")
+     chat_interface.render()
+
  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
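The new generate() routes each chat turn through the module-level session_state dict: the first turn asks which Osservatorio the question refers to (INSTRUCTION_1), the next turn matches the reply against documents_paths, builds the FAISS index via backend2 and answers the original question from the retrieved context (INSTRUCTION_2), and later turns reuse the already-loaded context. A rough sketch of that routing, simplified from the diff above (the return labels are illustrative and not part of the commit):

    def route_turn(message, session_state):
        # Turn 1: nothing loaded and no clarification pending -> ask for the topic.
        if not session_state["documents_loaded"] and not session_state["clarification"]:
            session_state["original_message"] = message
            session_state["clarification"] = True
            return "ask_topic"
        # Turn 2: match the reply against documents_paths and build the FAISS db once.
        elif session_state["clarification"]:
            for key, path in documents_paths.items():
                if key in message.lower():
                    session_state["document_db"] = prepare_documents(load_documents(path))
                    session_state["documents_loaded"] = True
                    session_state["clarification"] = False
                    return "answer_original_question"
            return "unrecognized_topic"
        # Later turns: documents already loaded, answer follow-ups from the same context.
        else:
            return "answer_follow_up"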
 
 
appLlama.py ADDED
@@ -0,0 +1,29 @@
+ from backend import handle_query
+ import gradio as gr
+
+
+ DESCRIPTION = """\
+ # <div style="text-align: center;">Odi, l'assistente ricercatore degli Osservatori</div>
+
+
+ 👉 Retrieval-Augmented Generation - Ask me anything about the research carried out at the Osservatori.
+ """
+
+
+ chat_interface =gr.ChatInterface(
+     fn=handle_query,
+     chatbot=gr.Chatbot(height=500),
+     textbox=gr.Textbox(placeholder="Chiedimi qualasiasi cosa relativa agli Osservatori", container=False, scale=7),
+     #examples=[["Ciao, in cosa puoi aiutarmi?"],["Dimmi i risultati e le modalità di conduzione del censimento per favore"]]
+ )
+
+
+ with gr.Blocks(css=".gradio-container {background-color: #B9D9EB}") as demo:
+     gr.Markdown(DESCRIPTION)
+     #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
+ if __name__ == "__main__":
+     #progress = gr.Progress(track_tqdm=True)
+     demo.launch()
+
backend.py CHANGED
@@ -76,7 +76,6 @@ def build_index(path: str):
      nodes = parser.get_nodes_from_documents(documents)
      # Build the vector store index from the nodes
      index = VectorStoreIndex(nodes)
-
      #storage_context = StorageContext.from_defaults()
      #index.storage_context.persist(persist_dir=PERSIST_DIR)

@@ -106,7 +105,6 @@ def handle_query(query_str: str,
          ]
      )

-
      try:

          memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
backend2.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import logging
+ from concurrent.futures import ThreadPoolExecutor
+ from pypdf import PdfReader
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ #from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_huggingface import HuggingFaceEmbeddings
+ import time
+ import torch
+ from dotenv import load_dotenv
+
+ logging.basicConfig(
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     level=logging.DEBUG
+ )
+ logger = logging.getLogger(__name__)
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)  # Suppress Matplotlib debug messages
+
+ load_dotenv()
+
+ logger.debug("Environment variables loaded.")
+
+ def load_single_document(filepath):
+     if filepath.endswith('.pdf'):
+         with open(filepath, 'rb') as file:
+             pdf_reader = PdfReader(file)
+             text = " ".join([page.extract_text() for page in pdf_reader.pages])
+     elif filepath.endswith('.txt'):
+         with open(filepath, 'r', encoding='utf-8') as file:
+             text = file.read()
+     else:
+         logger.warning("Unsupported file type: %s", filepath)
+         return {"content": "", "source": filepath}
+
+     return {"content": text, "source": filepath}
+
+ def load_documents(directory):
+     logger.debug("Loading documents from directory: %s", directory)
+     start_time = time.time()
+     filepaths = [os.path.join(directory, filename) for filename in os.listdir(directory) if filename.endswith('.pdf') or filename.endswith('.txt')]
+
+     if not filepaths:
+         logger.error("No documents found in the directory.")
+     else:
+         logger.debug("Found %d documents", len(filepaths))
+
+     documents = []
+     with ThreadPoolExecutor() as executor:
+         documents = list(executor.map(load_single_document, filepaths))
+
+     end_time = time.time()
+     logger.debug("Loaded %d documents in %.2f seconds.", len(documents), end_time - start_time)
+     return documents
+
+
+ def prepare_documents(documents):
+     logger.debug("Preparing documents for embedding.")
+     start_time = time.time()
+
+     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+     # It splits text into chunks of 1000 characters each with a 150-character overlap.
+     #text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+     texts = text_splitter.create_documents([doc["content"] for doc in documents], metadatas=[{"source": os.path.basename(doc["source"])} for doc in documents])
+     if not texts:
+         logger.error("No texts to embed.")
+         return None
+
+     modelPath = "sentence-transformers/all-MiniLM-l6-v2"
+     model_kwargs = {'device':'mps'}
+     encode_kwargs = {'normalize_embeddings': False}
+     embeddings = HuggingFaceEmbeddings(model_name=modelPath, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs )
+
+     try:
+         db = FAISS.from_documents(texts, embeddings)
+     except Exception as e:
+         logger.error("Error creating FAISS index: %s", e)
+         return None
+
+     end_time = time.time()
+     logger.debug("Documents prepared in %.2f seconds.", end_time - start_time)
+     return db
+
+
+ def get_context_sources(question, db):
+     start_time = time.time()
+
+     docs = db.similarity_search(question, k=3)
+     context = " ".join([doc.page_content for doc in docs])
+     sources = ", ".join(set([doc.metadata['source'] for doc in docs]))
+
+     end_time = time.time()
+     logger.debug("Similarity search done in %.2f seconds.", end_time - start_time)
+
+     return context, sources
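Together these helpers form the retrieval pipeline that app.py calls once the user has named a topic: load the PDF/TXT files from one of the documents_paths folders, split and embed them into a FAISS index, then fetch the top-3 chunks and their source filenames for a question. A minimal usage sketch, with an illustrative directory and question taken from elsewhere in the commit:

    from backend2 import load_documents, prepare_documents, get_context_sources

    docs = load_documents("documents/blockchain")   # list of {"content", "source"} dicts
    db = prepare_documents(docs)                    # FAISS index, or None if embedding fails
    if db is not None:
        context, sources = get_context_sources("Cosa sono i RWA?", db)
        print(sources)                              # comma-separated source filenames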
requirements.txt CHANGED
@@ -3,7 +3,7 @@ llama-index
  llama-index-embeddings-huggingface
  llama-index-llms-huggingface
  llama-index-embeddings-instructor
- sentence-transformers==2.2.2
+ sentence-transformers #==2.2.2
  llama-index-readers-web
  llama-index-readers-file
  gradio
@@ -13,9 +13,8 @@ setuptools
  spaces
  pydantic
  ipython
- keras
- keras-nlp
- tensorflow
- #langchain
- #langchain-community
- #langchain_huggingface
+ #keras
+ #keras-nlp
+ #tensorflow
+ langchain-community
+ langchain-huggingface