File size: 9,017 Bytes
bb8e493
 
 
 
 
ddf8ec6
bb8e493
e35273c
231d2b5
 
 
bb8e493
 
 
 
 
 
 
231d2b5
 
 
bb8e493
231d2b5
 
 
bb8e493
 
 
231d2b5
 
 
 
284c0f7
231d2b5
 
 
 
 
4721a1c
 
231d2b5
ddf8ec6
231d2b5
4721a1c
8cf19de
 
 
231d2b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721a1c
8cf19de
231d2b5
 
 
 
 
 
 
 
 
d632349
 
231d2b5
 
d632349
 
231d2b5
 
9f05250
231d2b5
 
 
8cf19de
231d2b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd29ab
8cf19de
 
 
 
231d2b5
e384a9f
 
8cf19de
 
 
 
e384a9f
8cf19de
d632349
233b98c
8cf19de
d469f0d
8cf19de
d469f0d
284c0f7
 
d469f0d
 
231d2b5
d469f0d
8cf19de
 
 
e384a9f
bb8e493
ddf8ec6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from accelerate import Accelerator
import re
import traceback
from transformers import pipeline 
from sentence_transformers import SentenceTransformer, util
  
# Set the HF_HOME environment variable to a writable directory
os.environ["HF_HOME"] = "/workspace/huggingface_cache"

app = Flask(__name__)

# Enable CORS for specific origins
CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
  
# Load zero-shot classification pipeline
classifier = pipeline("zero-shot-classification")

 # Load Sentence-BERT model
bertmodel = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight, efficient model; choose larger if needed
        
# Global variables for model and tokenizer
model = None
tokenizer = None
accelerator = Accelerator() 
highest_label = None 
loaded_models = {}
 
def get_model_and_tokenizer(model_id: str):
    """
    Load and cache the model and tokenizer for the given model_id.
    """
    global model, tokenizer  # Declare global variables to modify them within the function
    if model_id not in loaded_models:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(model_id)
            model = accelerator.prepare(model)
            loaded_models[model_id] = (model, tokenizer)
        except Exception as e:
            print("Error loading model:")
            print(traceback.format_exc())  # Logs the full error traceback
            raise e  # Reraise the exception to stop execution
    return loaded_models[model_id]
        
        
# Extract the core sentence needing grammar correction
def extract_core_sentence(user_input):
    """
    Extract the core sentence needing grammar correction from the user input.
    """
    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
    if match:
        return match.group(0).strip()
    return user_input

def classify_intent(user_input):
    """
    Classify the intent of the user input using zero-shot classification.
    """
    candidate_labels = [
        "grammar correction", "information request", "task completion", 
        "dialog continuation", "personal opinion", "product inquiry",
        "feedback request", "recommendation request", "clarification request", 
        "affirmation or agreement", "real-time data request", "current information"
    ]
    result = classifier(user_input, candidate_labels)
    highest_score_index = result['scores'].index(max(result['scores']))
    highest_label = result['labels'][highest_score_index]
    return highest_label


# Reformulate the prompt based on intent
# Function to generate reformulated prompts
def reformulate_prompt(user_input, intent_label):
    """
    Reformulate the prompt based on the classified intent.
    """
    core_sentence = extract_core_sentence(user_input)
    prompt_templates = {
        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
        "information request": f"Provide information about: {core_sentence}",
        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
        "product inquiry": f"Provide details about the product: {core_sentence}",
        "feedback request": f"Please provide feedback on: {core_sentence}",
        "recommendation request": f"Recommend something related to: {core_sentence}",
        "clarification request": f"Clarify the following: {core_sentence}",
        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
    }
    return prompt_templates.get(intent_label, "Input does not require a defined action.")

chat_history = [
            ("Hi there, how are you?", "I am fine. How are you?"),
            ("Tell me a joke!", "The capital of France is Paris."),
            ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
            ]
    
    
def generate_response(user_input, model_id):
    try:
        model, tokenizer = get_model_and_tokenizer(model_id)
        device = accelerator.device  # Get the device from the accelerator
       
        # Append chat history
        func_caller = []

        for msg in chat_history:
            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
 
        highest_label_result = classify_intent(user_input) 

        # Reformulated prompt based on intent classification
        reformulated_prompt = reformulate_prompt(user_input, highest_label_result)
          
        func_caller.append({"role": "user", "content": f'{reformulated_prompt}'})
        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])

        #prompt = user_input
        #device = accelerator.device  # Automatically uses GPU or CPU based on accelerator setup
        
        generation_config = GenerationConfig(
            do_sample=(highest_label == "dialog continuation" or highest_label == "recommendation request"),  # True if dialog continuation, else False
            temperature=0.7 if highest_label == "dialog continuation" else (0.2 if highest_label == "recommendation request" else None),  # Set temperature for specific intents 
            top_k = 5 if highest_label == "recommendation request" else None, 
            #attention_mask=attention_mask,
            max_length=150, 
            repetition_penalty=1.2, 
            length_penalty=1.0, 
            no_repeat_ngram_size=2, 
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            #stop_sequences=["User:", "Assistant:", "\n"],
            )
        
        # Generate response
        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
        gpt_output = model.generate(gpt_inputs["input_ids"], max_new_tokens=50, generation_config=generation_config)
        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)
        # Extract AI's response only (omit the prompt)
        #ai_response2 = final_response.replace(reformulated_prompt, "").strip()
        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()
        #ai_response = re.split(r'(?<=\w[.!?]) +', ai_response)
        ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]
       
        # Encode the prompt and candidates
        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)
        
        # Compute similarity scores between prompt and each candidate
        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
        
        # Find the candidate with the highest similarity score
        
        best_index = similarities.argmax()
        best_response = ai_response[best_index]
        
        # Assuming best_response is already defined and contains the generated response
        
        if highest_label == "dialog continuation":
            # Split the response into sentences
            sentences = best_response.split('. ')
            # Take the first three sentences and join them back together
            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response
 
        # Append the user's message to the chat history
        chat_history.append({'role': 'user', 'content': user_input})
        chat_history.append({'role': 'assistant', 'content': best_response})

        return best_response
        
    except Exception as e:
        print("Error in generate_response:")
        print(traceback.format_exc())  # Logs the full traceback
        raise e
 
@app.route("/send_message", methods=["POST"])
def handle_post_request():
    try:
        data = request.get_json()
        if data is None:
            return jsonify({"error": "No JSON data provided"}), 400

        message = data.get("inputs", "No message provided.")
        model_id = data.get("model_id", "openai-community/gpt2-large")

        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)

        return jsonify({
            "received_message": model_response,
            "model_id": model_id,
            "status": "POST request successful!"
        })
    
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc())  # Logs the full traceback
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)