DGSpitzer commited on
Commit
cc3b35f
·
1 Parent(s): edb5a8c

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +50 -0
utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import httpx
4
+
5
+ from constants import MUBERT_TAGS, MUBERT_LICENSE, MUBERT_MODE, MUBERT_TOKEN
6
+
7
+
8
+ def get_mubert_tags_embeddings(w2v_model):
9
+ return w2v_model.encode(MUBERT_TAGS)
10
+
11
+
12
+ def get_pat(email: str):
13
+ r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess',
14
+ json={
15
+ "method": "GetServiceAccess",
16
+ "params": {
17
+ "email": email,
18
+ "license": MUBERT_LICENSE,
19
+ "token": MUBERT_TOKEN,
20
+ "mode": MUBERT_MODE,
21
+ }
22
+ })
23
+
24
+ rdata = json.loads(r.text)
25
+ assert rdata['status'] == 1, "probably incorrect e-mail"
26
+ pat = rdata['data']['pat']
27
+ return pat
28
+
29
+
30
+ def find_similar(em, embeddings, method='cosine'):
31
+ scores = []
32
+ for ref in embeddings:
33
+ if method == 'cosine':
34
+ scores.append(1 - np.dot(ref, em) / (np.linalg.norm(ref) * np.linalg.norm(em)))
35
+ if method == 'norm':
36
+ scores.append(np.linalg.norm(ref - em))
37
+ return np.array(scores), np.argsort(scores)
38
+
39
+
40
+ def get_tags_for_prompts(w2v_model, mubert_tags_embeddings, prompts, top_n=3, debug=False):
41
+ prompts_embeddings = w2v_model.encode(prompts)
42
+ ret = []
43
+ for i, pe in enumerate(prompts_embeddings):
44
+ scores, idxs = find_similar(pe, mubert_tags_embeddings)
45
+ top_tags = MUBERT_TAGS[idxs[:top_n]]
46
+ top_prob = 1 - scores[idxs[:top_n]]
47
+ if debug:
48
+ print(f"Prompt: {prompts[i]}\nTags: {', '.join(top_tags)}\nScores: {top_prob}\n\n\n")
49
+ ret.append((prompts[i], list(top_tags)))
50
+ return ret