YALCINKAYA committed
Commit 231d2b5 · verified · 1 Parent(s): 5c97f96

Update app.py

Files changed (1)
  1. app.py +145 -54
app.py CHANGED
@@ -6,7 +6,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from accelerate import Accelerator
 import re
 import traceback
-
+from transformers import pipeline
+from sentence_transformers import SentenceTransformer, util
+
 # Set the HF_HOME environment variable to a writable directory
 os.environ["HF_HOME"] = "/workspace/huggingface_cache"
 
@@ -14,79 +16,167 @@ app = Flask(__name__)
 
 # Enable CORS for specific origins
 CORS(app, resources={r"/send_message": {"origins": ["http://localhost:3000", "https://main.dbn2ikif9ou3g.amplifyapp.com"]}})
+
+# Load zero-shot classification pipeline
+classifier = pipeline("zero-shot-classification")
 
+# Load Sentence-BERT model
+bertmodel = SentenceTransformer('all-MiniLM-L6-v2') # Lightweight, efficient model; choose larger if needed
+
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
-accelerator = Accelerator()
-
+accelerator = Accelerator()
+highest_label = None
+loaded_models = {}
+
 def get_model_and_tokenizer(model_id: str):
-    global model, tokenizer
-    if model is None or tokenizer is None:
+    """
+    Load and cache the model and tokenizer for the given model_id.
+    """
+    global model, tokenizer # Declare global variables to modify them within the function
+    if model_id not in loaded_models:
        try:
-            print(f"Loading tokenizer for model_id: {model_id}")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
-            tokenizer.pad_token = tokenizer.eos_token
-
-            print(f"Loading model for model_id: {model_id}")
-
-            bnb_config = BitsAndBytesConfig(
-                load_in_4bit=True, bnb_4bit_quant_type="nf4",
-                bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True
-            )
-
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id, quantization_config=bnb_config, device_map="auto"
-            )
-
-            model.config.use_cache = False
-            model.config.pretraining_tp = 1
-            model.config.pad_token_id = tokenizer.eos_token_id # Fix padding issue
-
-            # Use accelerator.prepare() to handle device assignment (no need to move model manually)
+            model = AutoModelForCausalLM.from_pretrained(model_id)
            model = accelerator.prepare(model)
-
+            loaded_models[model_id] = (model, tokenizer)
        except Exception as e:
            print("Error loading model:")
            print(traceback.format_exc()) # Logs the full error traceback
            raise e # Reraise the exception to stop execution
-
+    return loaded_models[model_id]
+
+
+# Extract the core sentence needing grammar correction
+def extract_core_sentence(user_input):
+    """
+    Extract the core sentence needing grammar correction from the user input.
+    """
+    match = re.search(r"(?<=sentence[: ]).+", user_input, re.IGNORECASE)
+    if match:
+        return match.group(0).strip()
+    return user_input
+
+def classify_intent(user_input):
+    """
+    Classify the intent of the user input using zero-shot classification.
+    """
+    candidate_labels = [
+        "grammar correction", "information request", "task completion",
+        "dialog continuation", "personal opinion", "product inquiry",
+        "feedback request", "recommendation request", "clarification request",
+        "affirmation or agreement", "real-time data request", "current information"
+    ]
+    result = classifier(user_input, candidate_labels)
+    highest_score_index = result['scores'].index(max(result['scores']))
+    highest_label = result['labels'][highest_score_index]
+    return highest_label
+
+
+# Reformulate the prompt based on intent
+# Function to generate reformulated prompts
+def reformulate_prompt(user_input, intent_label):
+    """
+    Reformulate the prompt based on the classified intent.
+    """
+    core_sentence = extract_core_sentence(user_input)
+    prompt_templates = {
+        "grammar correction": f"Fix the grammar in this sentence: {core_sentence}",
+        "information request": f"Provide information about: {core_sentence}",
+        "dialog continuation": f"Continue the conversation based on the previous dialog:\n{core_sentence}\n",
+        "personal opinion": f"What is your personal opinion on: {core_sentence}?",
+        "product inquiry": f"Provide details about the product: {core_sentence}",
+        "feedback request": f"Please provide feedback on: {core_sentence}",
+        "recommendation request": f"Recommend something related to: {core_sentence}",
+        "clarification request": f"Clarify the following: {core_sentence}",
+        "affirmation or agreement": f"Affirm or agree with the statement: {core_sentence}",
+    }
+    return prompt_templates.get(intent_label, "Input does not require a defined action.")
+
+chat_history = [
+    ("Hi there, how are you?", "I am fine. How are you?"),
+    ("Tell me a joke!", "The capital of France is Paris."),
+    ("Can you tell me another joke?", "Why don't scientists trust atoms? Because they make up everything!"),
+]
+
+
 def generate_response(user_input, model_id):
    try:
-        get_model_and_tokenizer(model_id)
-        prompt = formatted_prompt(user_input)
-        #prompt = user_input
-        device = accelerator.device # Automatically uses GPU or CPU based on accelerator setup
+        model, tokenizer = get_model_and_tokenizer(model_id)
+        device = accelerator.device # Get the device from the accelerator
+
+        # Append chat history
+        func_caller = []
+
+        for msg in chat_history:
+            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
+            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
+
+        # Reformulated prompt based on intent classification
+        reformulated_prompt = reformulate_prompt(user_input, highest_label)
+
+        func_caller.append({"role": "user", "content": f'{reformulated_prompt}'})
+        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])
 
+        #prompt = user_input
+        #device = accelerator.device # Automatically uses GPU or CPU based on accelerator setup
+
        generation_config = GenerationConfig(
-            do_sample=False, # Disable sampling for deterministic output
-            top_p=0.0, # Prevents sampling lower probability tokens
-            top_k=1, # Forces picking the most likely token at each step
-            temperature=0.0, # No randomness in token selection
-            repetition_penalty=1.3, # Helps prevent hallucinations
-            max_new_tokens=50, # Adjust based on dataset response length
-            pad_token_id=tokenizer.eos_token_id # Ensures proper token padding
-        )
-
-        inputs = tokenizer(prompt, return_tensors="pt").to(device)
-        # No need to move model here, as it's already dispatched to the correct device
-
-        outputs = model.generate(**inputs, generation_config=generation_config)
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        # Clean up response
-        #cleaned_response = re.sub(r"(User:|Assistant:)", "", response).strip()
-        #return cleaned_response.split("\n")[0]
-        return response
+            do_sample=(highest_label == "dialog continuation" or highest_label == "recommendation request"), # True if dialog continuation, else False
+            temperature=0.7 if highest_label == "dialog continuation" else (0.2 if highest_label == "recommendation request" else None), # Set temperature for specific intents
+            top_k = 5 if highest_label == "recommendation request" else None,
+            #attention_mask=attention_mask,
+            max_length=150,
+            repetition_penalty=1.2,
+            length_penalty=1.0,
+            no_repeat_ngram_size=2,
+            num_return_sequences=1,
+            pad_token_id=tokenizer.eos_token_id,
+            #stop_sequences=["User:", "Assistant:", "\n"],
+        )
+
+        # Generate response
+        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+        gpt_output = model.generate(gpt_inputs["input_ids"], max_new_tokens=50, generation_config=generation_config)
+        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)
+        # Extract AI's response only (omit the prompt)
+        #ai_response2 = final_response.replace(reformulated_prompt, "").strip()
+        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()
+        #ai_response = re.split(r'(?<=\w[.!?]) +', ai_response)
+        ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]
+
+        # Encode the prompt and candidates
+        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
+        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)
+
+        # Compute similarity scores between prompt and each candidate
+        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
+
+        # Find the candidate with the highest similarity score
+
+        best_index = similarities.argmax()
+        best_response = ai_response[best_index]
+
+        # Assuming best_response is already defined and contains the generated response
+
+        if highest_label == "dialog continuation":
+            # Split the response into sentences
+            sentences = best_response.split('. ')
+            # Take the first three sentences and join them back together
+            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response
+
+        # Append the user's message to the chat history
+        chat_history.append({'role': 'user', 'content': user_input})
+        chat_history.append({'role': 'assistant', 'content': best_response})
+
+        return best_response
 
    except Exception as e:
        print("Error in generate_response:")
        print(traceback.format_exc()) # Logs the full traceback
        raise e
-
-def formatted_prompt(question)-> str:
-    return f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
-
+
 @app.route("/send_message", methods=["POST"])
 def handle_post_request():
    try:
@@ -95,7 +185,7 @@ def handle_post_request():
            return jsonify({"error": "No JSON data provided"}), 400
 
        message = data.get("inputs", "No message provided.")
-        model_id = data.get("model_id", "YALCINKAYA/opsgenius_ultra")
+        model_id = data.get("model_id", "meta-llama/Llama-3.1-8B-Instruct")
 
        print(f"Processing request with model_id: {model_id}")
        model_response = generate_response(message, model_id)
@@ -105,6 +195,7 @@ def handle_post_request():
            "model_id": model_id,
            "status": "POST request successful!"
        })
+
    except Exception as e:
        print("Error handling POST request:")
        print(traceback.format_exc()) # Logs the full traceback
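
The new code path first runs the user message through a zero-shot classifier and then rewrites it into an instruction-style prompt via reformulate_prompt. Below is a minimal, self-contained sketch of that step, assuming the same candidate labels as the commit; since no checkpoint is pinned, pipeline("zero-shot-classification") falls back to the library's default NLI model. The example input is hypothetical, and the sketch skips the extract_core_sentence step and fills the template with the raw input.

from transformers import pipeline

# No checkpoint is pinned in the commit, so the pipeline uses its default NLI model.
classifier = pipeline("zero-shot-classification")

candidate_labels = [
    "grammar correction", "information request", "task completion",
    "dialog continuation", "personal opinion", "product inquiry",
    "feedback request", "recommendation request", "clarification request",
    "affirmation or agreement", "real-time data request", "current information",
]

user_input = "Please fix this sentence: she go to school yesterday"  # hypothetical input

result = classifier(user_input, candidate_labels)
# The pipeline returns labels sorted by descending score, so the first label is the
# same one the commit finds by taking the argmax over result['scores'].
intent = result["labels"][0]

templates = {
    "grammar correction": f"Fix the grammar in this sentence: {user_input}",
    "information request": f"Provide information about: {user_input}",
}
prompt = templates.get(intent, "Input does not require a defined action.")
print(intent, "->", prompt)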
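
After generation, the response is split into sentences and re-ranked against the prompt with Sentence-BERT; the sentence with the highest cosine similarity is returned. A standalone sketch of that re-ranking step, using the same all-MiniLM-L6-v2 checkpoint as the commit; the candidate sentences here are made up for illustration.

from sentence_transformers import SentenceTransformer, util

bertmodel = SentenceTransformer("all-MiniLM-L6-v2")

prompt = "user: Fix the grammar in this sentence: she go to school yesterday"
candidates = [  # illustrative candidates, standing in for the split model output
    "She went to school yesterday.",
    "Why don't scientists trust atoms?",
    "The capital of France is Paris.",
]

# Encode the prompt and every candidate sentence into the same embedding space.
prompt_embedding = bertmodel.encode(prompt, convert_to_tensor=True)
candidate_embeddings = bertmodel.encode(candidates, convert_to_tensor=True)

# Cosine similarity between the prompt and each candidate; shape (1, len(candidates)).
similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]

best_index = int(similarities.argmax())
print(candidates[best_index])

Because candidates are compared against the prompt text itself, this favors the sentence that stays closest to the reformulated instruction rather than the longest or most fluent one.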
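
The /send_message contract is otherwise unchanged by this commit apart from the default model_id, which now falls back to meta-llama/Llama-3.1-8B-Instruct. A hypothetical client call, assuming the Flask app is served on localhost:5000 (the run configuration is outside this diff):

import requests

# Hypothetical local URL; host and port depend on how the Flask app is launched.
url = "http://localhost:5000/send_message"

payload = {
    "inputs": "Please fix this sentence: she go to school yesterday",
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # optional; this is the new default
}

resp = requests.post(url, json=payload, timeout=120)
print(resp.status_code)
data = resp.json()
# "model_id" and "status" are visible in the diff; the field carrying the model's
# reply sits outside the shown hunk, so it is not assumed here.
print(data.get("model_id"), data.get("status"))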