Reyad-Ahmmed committed on
Commit 77847f4 · verified · 1 Parent(s): ed419cf

Update app.py

Files changed (1)
  1. app.py +61 -0
app.py CHANGED
@@ -308,3 +308,64 @@ else:
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_path)
 
 
+# Define the label mappings (this must match the mapping used during training)
+label_mapping = model.config.label_mapping
+label_mapping_reverse = {value: key for key, value in label_mapping.items()}
+
+
+# Function to classify user input typed at the console
+def classify_user_input():
+    while True:
+        user_input = input("Enter a command (or type 'q' to quit): ")
+        if user_input.lower() == 'q':
+            print("Exiting...")
+            break
+
+        # Tokenize and move the encoding to the model's device
+        input_encoding = tokenizer(user_input, padding=True, truncation=True, return_tensors="pt").to(model.device)
+
+        with torch.no_grad():
+            # attention_mask = input_encoding['attention_mask'].clone()
+
+            # Decode each token for inspection; the attention-mask boost for
+            # selected key tokens is left disabled below
+            for idx, token_id in enumerate(input_encoding['input_ids'][0]):
+                word = tokenizer.decode([token_id])
+                print(word)
+                # if word.strip() in ["point", "summarize", "oil", "maintenance"]:  # Target key tokens
+                #     attention_mask[0, idx] = 2  # Increase attention weight for these words
+                # else:
+                #     attention_mask[0, idx] = 0
+            # print(attention_mask)
+            # input_encoding['attention_mask'] = attention_mask
+
+            # output_attentions=True so that output.attentions is populated below
+            output = model(**input_encoding, output_attentions=True)
+            attention = output.attentions  # Get attention scores
+
+            # Apply softmax to the logits to get per-label probabilities (confidence scores)
+            probabilities = F.softmax(output.logits, dim=-1)
+
+            prediction = torch.argmax(output.logits, dim=1).cpu().numpy()
+
+            # Map the predicted class index back to its label
+            print(prediction)
+            predicted_label = label_mapping_reverse[prediction[0]]
+
+            print(f"Predicted intent: {predicted_label}\n")
+
+            # Print the confidence score for each label
+            print("\nLabel Confidence Scores:")
+            for i, label in label_mapping_reverse.items():
+                confidence = probabilities[0][i].item()  # Confidence score for this label
+                print(f"{label}: {confidence:.4f}")
+            print("\n")
+
+
+# Run the interactive loop
+classify_user_input()
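
For reference, the added loop follows the standard Hugging Face sequence-classification inference pattern: tokenize, forward pass under torch.no_grad(), softmax over the logits, argmax to a label index, then map the index back to a label name. Below is a minimal self-contained sketch of the same flow. It assumes a stock public checkpoint (distilbert-base-uncased-finetuned-sst-2-english) and its built-in config.id2label mapping, not this repo's fine-tuned model or its custom config.label_mapping, so treat it as an illustration only.

# Minimal sketch of the classify-with-confidence flow, using a stock
# checkpoint (assumption: NOT this repo's fine-tuned model or label mapping).
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

ckpt = "distilbert-base-uncased-finetuned-sst-2-english"  # stand-in checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSequenceClassification.from_pretrained(ckpt)
model.eval()

enc = tokenizer("summarize the oil maintenance report", return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits           # shape: (1, num_labels)
probs = F.softmax(logits, dim=-1)[0]       # confidence score per label
pred = int(torch.argmax(probs))
print("Predicted:", model.config.id2label[pred])
for i, p in enumerate(probs.tolist()):
    print(f"{model.config.id2label[i]}: {p:.4f}")

The per-label printout mirrors the diff's confidence loop; the only structural difference is that this sketch reads config.id2label, which AutoModelForSequenceClassification populates by default, instead of the custom label_mapping attribute the commit stores on the model config.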