Reyad-Ahmmed committed
Commit 605744d · verified · 1 Parent(s): ab2af46

Update app.py

Files changed (1): app.py +54 -0
app.py CHANGED
@@ -29,6 +29,7 @@ import os
 from flask import Flask, jsonify, request
 import requests
 from fetch_data import fetch_and_update_training_data
+import gradio as gr
 
 # Load configuration file
 with open('config.json', 'r') as config_file:
@@ -308,3 +309,56 @@ else:
     model = AutoModelForSequenceClassification.from_pretrained(model_save_path).to('cpu')
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_path)
 
+# Function to classify user input and return the predicted intent
+# together with a confidence score for every label
+def classify_user_input(user_input):
+    # Tokenize on CPU to match the device the model was loaded on above
+    input_encoding = tokenizer(user_input, padding=True, truncation=True, return_tensors="pt").to('cpu')
+
+    with torch.no_grad():
+        # Disabled experiment: boost the attention-mask weight of key tokens
+        # ("point", "summarize", "oil", "maintenance") before the forward pass
+        output = model(**input_encoding, output_hidden_states=True)
+
+    # Apply softmax to the logits to get per-label probabilities (confidence scores)
+    probabilities = F.softmax(output.logits, dim=-1)
+
+    prediction = torch.argmax(output.logits, dim=1).cpu().numpy()
+
+    # Map the predicted class index back to its label
+    predicted_label = label_mapping_reverse[prediction[0]]
+
+    # Build the text that Gradio will show in the output box
+    lines = [f"Predicted intent: {predicted_label}", "", "Label Confidence Scores:"]
+    for i, label in label_mapping_reverse.items():
+        confidence = probabilities[0][i].item()  # Confidence score for this label
+        lines.append(f"{label}: {confidence:.4f}")
+    return "\n".join(lines)
+
+
+iface = gr.Interface(fn=classify_user_input, inputs="text", outputs="text")
+iface.launch(share=True)
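The handler returns a string, which Gradio renders in the text output component. For a quick sanity check without launching the web UI, the function can also be called directly. The snippet below is a minimal sketch, assuming model, tokenizer, and label_mapping_reverse have already been loaded as in app.py; the sample query is hypothetical.

# Minimal sketch: exercise the classifier directly, without the Gradio UI.
# Assumes model, tokenizer, and label_mapping_reverse are loaded as in app.py;
# the sample query below is a made-up example, not taken from the project.
sample = "summarize the oil maintenance schedule"
print(classify_user_input(sample))

# Note: share=True asks Gradio to open a temporary public URL through its
# tunneling service; call iface.launch() with no arguments to serve only locally.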