import os
import re
import shutil
import traceback
import uuid

import faiss
import numpy as np
import torch
from accelerate import Accelerator
from flask import Flask, jsonify, request
from flask_cors import CORS
from sentence_transformers import SentenceTransformer, util
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    pipeline,
)
from werkzeug.utils import secure_filename

# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
  
# Load Sentence-BERT embedding model; all-MiniLM-L6-v2 is lightweight and
# produces 384-dimensional embeddings (this must match the FAISS index
# dimension used below).
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')

# Global variables for the model/tokenizer cache
model = None
tokenizer = None
accelerator = Accelerator()
loaded_models = {}

# Load model with accelerator
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    revision="d7645e1",
    device=accelerator.device  # Ensures correct device placement
)

# Move model to correct device
classifier.model = accelerator.prepare(classifier.model)
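
# The zero-shot pipeline returns a dict with "labels" and "scores" sorted by
# descending score; classify_intent below picks the highest-scoring label.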
 
# Define upload directory and FAISS index file
UPLOAD_DIR = "/app/uploads"
faiss_index_file = os.path.join(UPLOAD_DIR, "faiss_index.bin")

# Ensure upload directory exists and has write permissions
try:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    if not os.access(UPLOAD_DIR, os.W_OK):
        print(f"Fixing permissions for {UPLOAD_DIR}...")
        os.chmod(UPLOAD_DIR, 0o777)
    print(f"Uploads directory is ready: {UPLOAD_DIR}")
except PermissionError as e:
    print(f"PermissionError: {e}. Try adjusting directory ownership or running with elevated permissions.")
    
document_store = {}

def initialize_faiss():
    """Return a fresh, empty 384-dim FAISS index (all-MiniLM-L6-v2 dimension).

    Any previously saved index file is read only for diagnostics: documents
    are re-embedded at load time, so every path below returns a new index.
    """
    if os.path.exists(faiss_index_file):
        try:
            print(f"FAISS index file {faiss_index_file} exists, attempting to load it.")
            index = faiss.read_index(faiss_index_file)
            print(f"FAISS index loaded with {index.ntotal} vectors; reinitializing.")
        except Exception as e:
            print(f"Error loading FAISS index: {e}, reinitializing a new index.")
    else:
        print(f"FAISS index file {faiss_index_file} does not exist, initializing a new index.")

    # IndexIDMap allows custom 64-bit document IDs on top of a flat L2 index.
    # (Moving the index to GPU via faiss.index_cpu_to_all_gpus is possible
    # but disabled here.)
    return faiss.IndexIDMap(faiss.IndexFlatL2(384))

def save_faiss_index(index):
    try:
        # The index lives on the CPU (the GPU path is disabled above), so it
        # can be serialized directly; a GPU index would first need to be moved
        # back with faiss.index_gpu_to_cpu before faiss.write_index.
        print(f"Saving FAISS index to {faiss_index_file}.")
        faiss.write_index(index, faiss_index_file)
        print(f"FAISS index successfully saved to {faiss_index_file}.")
    except Exception as e:
        print(f"Error saving FAISS index: {e}")

# Initialize FAISS index
index = initialize_faiss()

# Save FAISS index after modifications
save_faiss_index(index)
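
# Illustrative query against the index (hypothetical text; assumes documents
# have been embedded and added). encode() returns a (1, 384) float array and
# search() returns distances D and document IDs I for the top-k hits:
#   query = bertmodel.encode(["example question"]).astype("float32")
#   D, I = index.search(query, 1)  # I[0][0] is -1 when the index is empty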

# Load document store and populate FAISS index
knowledgebase_file = os.path.join(UPLOAD_DIR, "knowledgebase1.txt")  # Note the "1" in the filename

def load_document_store():
    """Load the knowledgebase file into document_store, mapping FAISS IDs to text."""
    global document_store
    document_store = {}  # Reset document store
    all_texts = []

    if os.path.exists(knowledgebase_file):
        with open(knowledgebase_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        for i, line in enumerate(lines):
            text = line.strip()
            if text:
                document_store[i] = {"text": text}  # Store text mapped to FAISS ID
                all_texts.append(text)  # Collect all texts for embedding

        print(f"Loaded {len(document_store)} documents into document_store.")
    else:
        print(f"Error: {knowledgebase_file} not found!")

    if not all_texts:
        return  # Nothing to embed

    # Re-embed from scratch: reset the index first so repeated loads do not
    # accumulate duplicate IDs for the same lines.
    index.reset()
    embeddings = bertmodel.encode(all_texts).astype("float32")

    # Add embeddings to FAISS index under their line-number IDs
    index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
    print(f"Added {len(all_texts)} document embeddings to FAISS index.")

def load_document_store_once(file_path):
    """Load a single uploaded file into document_store, mapping FAISS IDs to text."""
    global document_store
    document_store = {}  # Reset document store
    all_texts = []
    file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    if os.path.exists(file_location):
        with open(file_location, "r", encoding="utf-8") as f:
            lines = f.readlines()

        for i, line in enumerate(lines):
            text = line.strip()
            if text:
                document_store[i] = {"text": text}  # Store text mapped to FAISS ID
                all_texts.append(text)  # Collect all texts for embedding

        print(f"Loaded {len(document_store)} documents into document_store.")
    else:
        print(f"Error: {file_location} not found!")

    if not all_texts:
        return  # Nothing to embed

    # Reset before re-adding so IDs from a previous load do not collide
    index.reset()
    embeddings = bertmodel.encode(all_texts).astype("float32")

    # Add embeddings to FAISS index under their line-number IDs
    index.add_with_ids(embeddings, np.array(list(document_store.keys()), dtype=np.int64))
    print(f"Added {len(all_texts)} document embeddings to FAISS index.")

 
# Upload a single document into the FAISS index. Currently unused: the
# /upload route below rebuilds the store with load_document_store_once instead.
def upload_document(file_path, embed_model):
    try:
        # Generate unique document ID
        doc_id = uuid.uuid4().int % (2**63 - 1)
        
        # Ensure the file is saved to the correct directory with secure handling
        file_location = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        print(f"Saving file to: {file_location}")  # Log the location
        
        # Safely copy the file to the upload directory
        shutil.copy(file_path, file_location)
        
        # Read the content of the uploaded file
        try:
            with open(file_location, "r", encoding="utf-8") as f:
                text = f.read()
        except Exception as e:
            print(f"Error reading file {file_location}: {e}")
            return {"error": f"Error reading file: {e}"}, 507  # Error while reading file
        
        # Embed the text and add it to the FAISS index
        try:
            # Ensure the embedding model is valid
            if embed_model is None:
                raise ValueError("Embedding model is not initialized properly.")
            
            vector = embed_model.encode(text).astype("float32")
            print(f"Generated vector for document {doc_id}: {vector}")  # Log vector
            
            index.add_with_ids(np.array([vector]), np.array([doc_id], dtype=np.int64))
            document_store[doc_id] = {"path": file_location, "text": text}
            
            # Log FAISS index file path
            print(f"Saving FAISS index to: {faiss_index_file}")  # Log the file path
            
            # Save the FAISS index after adding the document
            try:
                faiss.write_index(index, faiss_index_file)
                print(f"Document uploaded with doc_id: {doc_id}")
                return {"message": f"Document uploaded with doc_id: {doc_id}"}, 200
            except Exception as e:
                print(f"Error saving FAISS index: {e}")
                return {"error": f"Error saving FAISS index: {e}"}, 500  # Error while saving FAISS index
            
        except Exception as e:
            print(f"Error during document upload: {e}")
            return {"error": f"Error during document upload: {e}"}, 509  # Error during embedding or FAISS processing
    
    except Exception as e:
        print(f"Unexpected error: {e}")
        return {"error": f"Unexpected error: {e}"}, 500  # General error

@app.route("/list_uploads", methods=["GET"])
def list_uploaded_files():
    try:
        # Ensure the upload directory exists
        if not os.path.exists(UPLOAD_DIR):
            return jsonify({"error": "Upload directory does not exist"}), 400

        # List all files in the upload directory
        files = os.listdir(UPLOAD_DIR)
        
        if not files:
            return jsonify({"message": "No files found in the upload directory"}), 200
        
        return jsonify({"files": files}), 200
    except Exception as e:
        return jsonify({"error": f"Error listing files: {e}"}), 504

@app.route("/upload", methods=["POST"])
def handle_upload():
    # Check if the request contains the file
    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400
    
    file = request.files["file"]
    
    # Ensure the filename is safe and construct the full file path
    file_path = os.path.join(UPLOAD_DIR, file.filename)
    
    # Ensure the upload directory exists and has correct permissions
    try:
        os.makedirs(UPLOAD_DIR, exist_ok=True)  # Ensure the directory exists
        if not os.access(UPLOAD_DIR, os.W_OK):  # Check write permissions
            os.chmod(UPLOAD_DIR, 0o777)
    except PermissionError as e:
        return jsonify({"error": f"Permission error with upload directory: {e}"}), 501

    try:
        # Save the file to the upload directory
        file.save(file_path)
        load_document_store()  # Reload FAISS index
        # Now that the document is uploaded, call load_document_store()
        print(f"File uploaded successfully. Calling load_document_store()...") 
    except Exception as e:
        return jsonify({"error": f"Error saving file: {e}"}), 502
    
    # Process the document using the upload_document function
    try:
        load_document_store_once(file_path)
    #    upload_document(file_path, bertmodel)  # Assuming 'bertmodel' is defined elsewhere
    except Exception as e:
        return jsonify({"error": f"Error processing file: {e}"}), 503
    
    # Return success response
    return jsonify({"message": "File uploaded and processed successfully"}), 200

def get_model_and_tokenizer(model_id: str):
    """
    Load and cache the model and tokenizer for the given model_id.
    """
    global model, tokenizer  # Declare global variables to modify them within the function
    if model_id not in loaded_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            model = accelerator.prepare(model)
            loaded_models[model_id] = (model, tokenizer)
        except Exception as e:
            print("Error loading model:")
            print(traceback.format_exc())  # Logs the full error traceback
            raise e  # Reraise the exception to stop execution
    return loaded_models[model_id]
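
# Illustrative use (hypothetical model id; repeated calls reuse the cache):
#   model, tokenizer = get_model_and_tokenizer("openai-community/gpt2")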
        
        
# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        return match.group(0).strip()
    return user_input
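
# Illustrative use (hypothetical input): the regex captures everything after
# the word "sentence" followed by ":" or a space, else falls back to the input.
#   extract_core_sentence("Fix this sentence: He go to school")
#   -> "He go to school"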

def classify_intent(user_input):
    """
    Classify the intent of the user input using zero-shot classification.
    """
    candidate_labels = [
        "grammar correction", "information request", "task completion", 
        "dialog continuation", "personal opinion", "product inquiry",
        "feedback request", "recommendation request", "clarification request", 
        "affirmation or agreement", "real-time data request", "current information"
    ]
    result = classifier(user_input, candidate_labels)
    highest_score_index = result['scores'].index(max(result['scores']))
    highest_label = result['labels'][highest_score_index]
    return highest_label
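
# Illustrative use (hypothetical input; the predicted label depends on the
# bart-large-mnli checkpoint):
#   classify_intent("Can you fix this sentence: he go to school")
#   -> "grammar correction"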


# Reformulate the prompt based on the classified intent
def reformulate_prompt(user_input, intent_label):
    """
    Reformulate the prompt based on the classified intent.
    """
    core_sentence = extract_core_sentence(user_input)
    prompt_templates = {
        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
        "information request": f"Provide information about: {core_sentence}",
        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
        "product inquiry": f"Provide details about the product: {core_sentence}",
        "feedback request": f"Please provide feedback on: {core_sentence}",
        "recommendation request": f"Recommend something related to: {core_sentence}",
        "clarification request": f"Clarify the following: {core_sentence}",
        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
    }
    return prompt_templates.get(intent_label, "Input does not require a defined action.")
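
# Illustrative use: intents without a template (e.g. "task completion" or
# "real-time data request") fall through to the default message.
#   reformulate_prompt("Fix this sentence: He go to school", "grammar correction")
#   -> "Fix the grammar in this sentence: He go to school"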

# Seed chat history as (user, assistant) pairs; generate_response extends it.
chat_history = [
    ("Hi there, how are you?", "I am fine. How are you?"),
    ("Tell me a joke!", "The capital of France is Paris."),
    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
]

def generate_response(user_input, model_id):
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device  # Get the device from the accelerator
       
        # Append chat history
        func_caller = []
          
        query_vector = bertmodel.encode(user_input).reshape(1, -1).astype("float32")
        D, I = index.search(query_vector, 1)

        # Retrieve document
        retrieved_id = I[0][0]
        retrieved_knowledge = (
            document_store.get(retrieved_id, {}).get("text", "No relevant information found.")
            if retrieved_id != -1 else "No relevant information found."
        )
        
        # Construct the knowledge prompt
        prompt = f"Use the following knowledge:\n{retrieved_knowledge}"
        
        # Log the prompt (you can change this to a logging library if needed)
        print(f"Generated prompt: {prompt}")  # <-- Log the prompt here
        
        # Add the retrieved knowledge to the prompt
        func_caller.append({"role": "system", "content": prompt}) 

        for msg in chat_history:
            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
 
        highest_label_result = classify_intent(user_input) 

        # Reformulated prompt based on intent classification
        reformulated_prompt = reformulate_prompt(user_input, highest_label_result)
          
        func_caller.append({"role": "user", "content": f'{reformulated_prompt}'})
        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])

        # Sampling is enabled only for open-ended intents; everything else
        # decodes greedily. (The classified label is used here; the module had
        # an unset highest_label global that never triggered sampling.)
        is_dialog = (highest_label_result == "dialog continuation")
        is_recommendation = (highest_label_result == "recommendation request")
        generation_config = GenerationConfig(
            do_sample=(is_dialog or is_recommendation),
            temperature=0.7 if is_dialog else (0.2 if is_recommendation else None),
            top_k=5 if is_recommendation else None,
            max_length=150,
            repetition_penalty=1.2,
            length_penalty=1.0,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        
        # Generate the response
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        gpt_output = model.generate(
            gpt_inputs["input_ids"],
            attention_mask=gpt_inputs["attention_mask"],
            max_new_tokens=50,
            generation_config=generation_config,
        )
        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)
        # Strip the echoed prompt, then split the remainder into candidate sentences
        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()
        ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]
        if not ai_response:
            ai_response = [final_response.strip()]  # Fallback so the ranking below has a candidate
       
        # Encode the prompt and candidates
        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)
        
        # Compute similarity scores between prompt and each candidate
        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
        
        # Pick the candidate with the highest similarity score
        best_index = int(similarities.argmax())
        best_response = ai_response[best_index]

        # For open-ended dialog, trim the response to at most three sentences
        if highest_label_result == "dialog continuation":
            sentences = best_response.split('. ')
            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response
 
        # Append this turn to the chat history as a (user, assistant) tuple,
        # matching the format the history loop above expects.
        chat_history.append((user_input, best_response))

        return best_response
        
    except Exception as e:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise e
 
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "meta-llama/Llama-3.1-8B-Instruct")
        
        #model_id = data.get("model_id", "openai-community/gpt2-large")
        
        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)

        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
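
# Illustrative request (hypothetical message body):
#   curl -X POST http://localhost:7860/send_message \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Hello there!", "model_id": "meta-llama/Llama-3.1-8B-Instruct"}'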