mkoot007 committed on
Commit
6b717c4
·
1 Parent(s): 679bc5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -7
app.py CHANGED
@@ -3,36 +3,70 @@ import streamlit as st
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import re
5
  import torch
 
 
6
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
7
  model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
8
 
9
- def analyze_text(text):
 
10
  text = re.sub(r"[^\w\s]", "", text)
11
  text = text.lower()
 
 
12
  encoded_text = tokenizer(text, truncation=True, padding=True, return_tensors='pt')
13
 
 
14
  with torch.no_grad():
15
  output = model(**encoded_text)
16
- predictions = output.logits.argmax(-1).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- if predictions == 0:
19
- return "Job Interview Related"
20
- else:
21
- return "Not Job Interview Related"
22
  st.title("Job Interview Message Analyzer")
23
 
24
  uploaded_file = st.file_uploader("Upload CSV file")
25
  user_input = st.text_input("Enter text")
26
 
27
  if uploaded_file:
 
28
  data = pd.read_csv(uploaded_file)
 
 
29
  results = []
30
  for message in data["message"]:
31
  result = analyze_text(message)
32
  results.append(result)
33
- data["Job_Interview_Related"] = results
 
 
 
 
 
34
  st.dataframe(data)
 
 
35
  elif user_input:
 
36
  result = analyze_text(user_input)
37
  st.write(f"Message Classification: {result}")
38
  else:
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import re
5
  import torch
6
+
7
+ # Load the pre-trained model and tokenizer
8
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
9
  model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
10
 
11
def analyze_text(text, confidence_threshold=0.6):
    """Classify a message as job-interview related or not.

    Parameters
    ----------
    text : str
        Raw message text; punctuation is stripped and it is lowercased
        before encoding.
    confidence_threshold : float, optional
        Minimum softmax probability required to accept the model's
        "job interview related" prediction (default 0.6).

    Returns
    -------
    str
        Either "Job Interview Related" or "Not Job Interview Related".
        (The original version implicitly returned None whenever the
        confidence was at or below the threshold; this now always
        returns an explicit label.)
    """
    # Preprocess: drop punctuation, normalize case.
    text = re.sub(r"[^\w\s]", "", text)
    text = text.lower()

    # Encode the text for the transformer model.
    # NOTE(review): relies on module-level `tokenizer` and `model`;
    # the checkpoint is a base distilbert whose classification head is
    # untrained, so label semantics should be verified/fine-tuned.
    encoded_text = tokenizer(text, truncation=True, padding=True, return_tensors='pt')

    # Classify without tracking gradients (inference only).
    with torch.no_grad():
        output = model(**encoded_text)
    logits = output.logits
    predictions = logits.argmax(-1).item()
    confidence = torch.softmax(logits, dim=1)[0][predictions].item()

    if confidence > confidence_threshold and predictions == 0:
        return "Job Interview Related"
    # Bug fix: low-confidence (or class-1) inputs previously fell through
    # and returned None; always return a string label instead.
    return "Not Job Interview Related"
30
+
31
def count_job_related_messages(data):
    """Tally classification results for every message in *data*.

    Expects a DataFrame-like object with a "message" column. Each
    message is classified via ``analyze_text``; anything that is not
    labeled "Job Interview Related" counts toward the second tally.

    Returns a ``(job_related_count, not_job_related_count)`` tuple.
    """
    labels = [analyze_text(message) for message in data["message"]]
    job_related = sum(1 for label in labels if label == "Job Interview Related")
    return job_related, len(labels) - job_related
43
 
44
+ # Streamlit application
 
 
 
45
  st.title("Job Interview Message Analyzer")
46
 
47
  uploaded_file = st.file_uploader("Upload CSV file")
48
  user_input = st.text_input("Enter text")
49
 
50
  if uploaded_file:
51
+ # Read the CSV file
52
  data = pd.read_csv(uploaded_file)
53
+
54
+ # Analyze messages
55
  results = []
56
  for message in data["message"]:
57
  result = analyze_text(message)
58
  results.append(result)
59
+
60
+ data["Job Interview Related"] = results
61
+
62
+ # Count job-related messages
63
+ job_related_count, not_job_related_count = count_job_related_messages(data)
64
+
65
  st.dataframe(data)
66
+ st.write(f"Job Interview Related Messages: {job_related_count}")
67
+ st.write(f"Not Job Interview Related Messages: {not_job_related_count}")
68
  elif user_input:
69
+ # Analyze user-input text
70
  result = analyze_text(user_input)
71
  st.write(f"Message Classification: {result}")
72
  else: