Commit 9cfeab8
Lora committed · 1 Parent(s): c289bbc

add requirements, sense vecs, lm head

requirements.txt ADDED
@@ -0,0 +1,71 @@
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==4.2.2
+ anyio==3.6.2
+ async-timeout==4.0.2
+ attrs==22.2.0
+ certifi @ file:///Users/cbousseau/work/recipes/ci_py311/certifi_1677903144932/work/certifi
+ charset-normalizer==3.1.0
+ click==8.1.3
+ contourpy==1.0.7
+ cycler==0.11.0
+ entrypoints==0.4
+ fastapi==0.95.0
+ ffmpy==0.3.0
+ filelock==3.10.7
+ fonttools==4.39.3
+ frozenlist==1.3.3
+ fsspec==2023.3.0
+ gradio==3.24.1
+ gradio_client==0.0.7
+ h11==0.14.0
+ httpcore==0.16.3
+ httpx==0.23.3
+ huggingface-hub==0.13.3
+ idna==3.4
+ Jinja2==3.1.2
+ jsonschema==4.17.3
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.0
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.1
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ networkx==3.1
+ numpy==1.24.2
+ orjson==3.8.9
+ packaging==23.0
+ pandas==2.0.0
+ Pillow==9.5.0
+ pydantic==1.10.7
+ pydub==0.25.1
+ pyparsing==3.0.9
+ pyrsistent==0.19.3
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ regex==2023.3.23
+ requests==2.28.2
+ rfc3986==1.5.0
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.26.1
+ sympy==1.11.1
+ tokenizers==0.13.3
+ toolz==0.12.0
+ torch==2.0.0
+ tqdm==4.65.0
+ transformers==4.27.4
+ typing_extensions==4.5.0
+ tzdata==2023.3
+ uc-micro-py==1.0.1
+ urllib3==1.26.15
+ uvicorn==0.21.1
+ websockets==11.0
+ yarl==1.8.2
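
The pins above cover a Gradio/FastAPI front end plus PyTorch 2.0 and transformers 4.27; note that the certifi entry references a local file:// path from the author's machine and will need to be replaced with a normal version pin on other systems. As a minimal sketch (not part of this commit), the file can be installed into the current interpreter's environment like so:

# Sketch only: install the pinned requirements with pip from within Python;
# equivalent to running `pip install -r requirements.txt` in a fresh virtualenv.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])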
senses/all_vecs_mtx.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f0c9de5688dd793470c40ebc3b49c29be6ddbf9a38804bca64512940671e129
+ size 2470232826
senses/lm_head.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f94054e64b4d1a07e18443769df4d3b9e346c00b02ffe4e9579e8313034dac24
+ size 154411755
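
Both .pt files are stored through Git LFS, so the diff above records only pointer files (oid and size); the actual tensors are roughly 2.5 GB and 150 MB. A minimal sketch (not part of this commit) of fetching and loading them with the pinned huggingface-hub library follows; the repository id below is a placeholder to be replaced with this repo's actual id.

# Sketch only: download the LFS-backed tensors and load them with torch.
import torch
from huggingface_hub import hf_hub_download

REPO_ID = "user/backpack-senses-demo"  # hypothetical repo id, replace with the real one

vecs_path = hf_hub_download(repo_id=REPO_ID, filename="senses/all_vecs_mtx.pt")
lm_head_path = hf_hub_download(repo_id=REPO_ID, filename="senses/lm_head.pt")

vecs = torch.load(vecs_path)        # sense vectors (about 2.5 GB on disk)
lm_head = torch.load(lm_head_path)  # output embedding matrix (about 150 MB on disk)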
senses/use_senses.py ADDED
@@ -0,0 +1,44 @@
+ """Visualize some sense vectors"""
+
+ import argparse
+
+ import torch
+ import transformers
+
+ def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
+     """
+     Prints the top-scoring (and lowest-scoring) vocabulary items for each sense of `word`.
+     """
+     if contents is None:
+         print(word)
+         token_id = tokenizer(word)['input_ids'][0]
+         contents = vecs[token_id]  # sense vectors for this token, torch.Size([16, 768])
+
+     for i in range(contents.shape[0]):
+         print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
+         # Score sense i against the whole vocabulary: [768] @ [768, 50257] -> [50257]
+         logits = contents[i, :] @ lm_head.t()
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         print('~~~Positive~~~')
+         for j in range(count):
+             print(tokenizer.decode(sorted_indices[j]), '\t', '{:.2f}'.format(sorted_logits[j].item()))
+         print('~~~Negative~~~')
+         for j in range(count):
+             print(tokenizer.decode(sorted_indices[-j-1]), '\t', '{:.2f}'.format(sorted_logits[-j-1].item()))
+     return contents
+
+ argp = argparse.ArgumentParser()
+ argp.add_argument('vecs_path')
+ argp.add_argument('lm_head_path')
+ args = argp.parse_args()
+
+ # Load tokenizer and parameters
+ tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
+ vecs = torch.load(args.vecs_path)
+ lm_head = torch.load(args.lm_head_path)
+
+ visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)
+
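
The script takes the two paths added in this commit as positional arguments, e.g. python senses/use_senses.py senses/all_vecs_mtx.pt senses/lm_head.pt, and then prompts for a word interactively. Below is a minimal sketch (not part of this commit) of doing the same scoring for a single sense programmatically, assuming the two files have been downloaded locally; the word 'science' is just an illustrative choice.

# Sketch only: rank the vocabulary against one sense of one word.
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
vecs = torch.load('senses/all_vecs_mtx.pt')    # per-token sense vectors, [vocab_size, 16, 768]
lm_head = torch.load('senses/lm_head.pt')      # output embedding matrix, [50257, 768]

token_id = tokenizer('science')['input_ids'][0]
senses = vecs[token_id]                        # [16, 768]

logits = senses[0] @ lm_head.t()               # scores of sense 0 against the vocabulary
top = torch.topk(logits, 5)
print([tokenizer.decode(int(i)) for i in top.indices])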