vives commited on
Commit
2168bad
·
1 Parent(s): 6a4d8b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -40,17 +40,17 @@ def get_transcript(file):
40
  transcript = data['results'].values[1][0]['transcript']
41
  transcript = transcript.lower()
42
  return transcript
43
- def concat_tokens(sentences):
44
- tokens = {'input_ids': [], 'attention_mask': [], 'KPS': {}}
45
- for sentence, values in sentences.items():
46
- weight = values['weight']
47
  # encode each sentence and append to dictionary
48
  new_tokens = tokenizer.encode_plus(sentence, max_length=64,
49
  truncation=True, padding='max_length',
50
  return_tensors='pt')
51
  tokens['input_ids'].append(new_tokens['input_ids'][0])
52
  tokens['attention_mask'].append(new_tokens['attention_mask'][0])
53
- tokens['KPS'][sentence] = weight
54
  # reformat list of tensors into single tensor
55
  tokens['input_ids'] = torch.stack(tokens['input_ids'])
56
  tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
@@ -59,7 +59,7 @@ def concat_tokens(sentences):
59
  """preprocess tags"""
60
  if tags:
61
  tags = [x.lower().strip() for x in tags.split(",")]
62
- tags_tokens = concat_tokens(tags)
63
  tags_tokens.pop("KPS")
64
  with torch.no_grad():
65
  outputs_tags = model(**tags_tokens)
@@ -70,7 +70,22 @@ if tags:
70
 
71
  """Code related with processing text, extracting KPs, and doing distance to tag"""
72
 
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def calculate_weighted_embed_dist(out, tokens, weight, text,kp_dict, idx, exclude_text=False,exclude_words=False):
75
  sim_dict = {}
76
  pools = pool_embeddings_count(out, tokens, idx).detach().numpy()
 
40
  transcript = data['results'].values[1][0]['transcript']
41
  transcript = transcript.lower()
42
  return transcript
43
+
44
+ def concat_tokens_tags(sentences):
45
+ tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
46
+ for sentence in sentences:
47
  # encode each sentence and append to dictionary
48
  new_tokens = tokenizer.encode_plus(sentence, max_length=64,
49
  truncation=True, padding='max_length',
50
  return_tensors='pt')
51
  tokens['input_ids'].append(new_tokens['input_ids'][0])
52
  tokens['attention_mask'].append(new_tokens['attention_mask'][0])
53
+ tokens['KPS'].append(sentence)
54
  # reformat list of tensors into single tensor
55
  tokens['input_ids'] = torch.stack(tokens['input_ids'])
56
  tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
 
59
  """preprocess tags"""
60
  if tags:
61
  tags = [x.lower().strip() for x in tags.split(",")]
62
+ tags_tokens = concat_tokens_tags(tags)
63
  tags_tokens.pop("KPS")
64
  with torch.no_grad():
65
  outputs_tags = model(**tags_tokens)
 
70
 
71
  """Code related with processing text, extracting KPs, and doing distance to tag"""
72
 
73
+ def concat_tokens(sentences):
74
+ tokens = {'input_ids': [], 'attention_mask': [], 'KPS': {}}
75
+ for sentence, values in sentences.items():
76
+ weight = values['weight']
77
+ # encode each sentence and append to dictionary
78
+ new_tokens = tokenizer.encode_plus(sentence, max_length=64,
79
+ truncation=True, padding='max_length',
80
+ return_tensors='pt')
81
+ tokens['input_ids'].append(new_tokens['input_ids'][0])
82
+ tokens['attention_mask'].append(new_tokens['attention_mask'][0])
83
+ tokens['KPS'][sentence] = weight
84
+ # reformat list of tensors into single tensor
85
+ tokens['input_ids'] = torch.stack(tokens['input_ids'])
86
+ tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
87
+ return tokens
88
+
89
  def calculate_weighted_embed_dist(out, tokens, weight, text,kp_dict, idx, exclude_text=False,exclude_words=False):
90
  sim_dict = {}
91
  pools = pool_embeddings_count(out, tokens, idx).detach().numpy()