xiaohang07 committed on
Commit d94c1ca · verified · 1 Parent(s): fe7e2d0

Upload 9 files

Files changed (9)
  1. Figure1.png +0 -0
  2. README.md +3 -3
  3. Voc_prior +130 -0
  4. app.py +144 -0
  5. config.json +29 -0
  6. mattergpt_wrapper.py +70 -0
  7. model.py +312 -0
  8. pytorch_model.pt +3 -0
  9. requirements.txt +4 -0
Figure1.png ADDED
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: MatterGPT CPU
- emoji: 🏃
+ emoji: 🖼
  colorFrom: purple
- colorTo: yellow
+ colorTo: red
  sdk: gradio
  sdk_version: 4.41.0
  app_file: app.py
@@ -10,4 +10,4 @@ pinned: false
  license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Voc_prior ADDED
@@ -0,0 +1,130 @@
+ S
+ o-o
+ He
+ Dy
+ -o-
+ Ne
+ +-o
+ Re
+ Bi
+ Cu
+ oo+
+ 16
+ Sc
+ --o
+ Nd
+ Lu
+ -+o
+ Te
+ Si
+ o+o
+ Er
+ 1
+ Sr
+ Hg
+ 3
+ oo-
+ 8
+ Ru
+ H
+ Mo
+ Tc
+ 12
+ 11
+ +oo
+ Pb
+ 6
+ In
+ La
+ --+
+ C
+ Sn
+ Se
+ B
+ Ar
+ o--
+ -o+
+ Ga
+ ++o
+ Rh
+ Sm
+ Ir
+ Li
+ Tl
+ 18
+ I
+ Cl
+ Ag
+ Ba
+ Ta
+ Ho
+ Tb
+ As
+ -+-
+ Gd
+ Os
+ O
+ 15
+ ---
+ W
+ F
+ 13
+ Pm
+ K
+ Na
+ 9
+ Eu
+ Ce
+ 14
+ -++
+ 5
+ Ge
+ Yb
+ Al
+ Rb
+ Pd
+ Ni
+ Cd
+ Hf
+ P
+ Zn
+ Ti
+ Nb
+ 0
+ Pr
+ 7
+ Mg
+ Y
+ +-+
+ ooo
+ Pt
+ +--
+ 19
+ Cs
+ N
+ -oo
+ +o-
+ o-+
+ Xe
+ 4
+ o+-
+ Tm
+ 2
+ Cr
+ Fe
+ +o+
+ Zr
+ ++-
+ Kr
+ 10
+ +++
+ Co
+ o++
+ Be
+ Br
+ Mn
+ Ca
+ Au
+ V
+ Sb
+ 17
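
Voc_prior is a plain-text vocabulary with one token per line: element symbols, digit tokens, and SLICES edge-direction labels such as `o-o` and `++-`. The SimpleTokenizer defined in mattergpt_wrapper.py (added below) reads this file, appends the `<` and `>` markers, and builds the string-to-index maps. A minimal round-trip sketch, assuming Voc_prior sits in the working directory:

```python
from mattergpt_wrapper import SimpleTokenizer

# Build the vocabulary maps from Voc_prior; '<' and '>' are appended as
# sequence markers, and the token list is de-duplicated and sorted.
tokenizer = SimpleTokenizer("Voc_prior")

ids = tokenizer.encode("Sr Ti o-o")  # whitespace-separated tokens -> indices
text = tokenizer.decode(ids)         # indices -> space-joined tokens ('<' stripped)
print(ids, "->", text)
```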
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ import torch
+ from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
+ import os
+ from slices.core import SLICES
+ from pymatgen.core.structure import Structure
+ from pymatgen.io.cif import CifWriter
+ from pymatgen.io.ase import AseAtomsAdaptor
+ from ase.io import write as ase_write
+ import tempfile
+ import time
+
+ # Limit the number of CPU threads PyTorch uses
+ torch.set_num_threads(2)
+
+ def load_quantized_model(model_path):
+     model = MatterGPTWrapper.from_pretrained(model_path)
+     model.to('cpu')
+     model.eval()
+     quantized_model = torch.quantization.quantize_dynamic(
+         model, {torch.nn.Linear}, dtype=torch.qint8
+     )
+     return quantized_model
+
+
+ # Load and quantize the model
+ model_path = "./"
+ quantized_model = load_quantized_model(model_path)
+ quantized_model.to("cpu")
+ quantized_model.eval()
+
+ # Load the tokenizer
+ tokenizer_path = "Voc_prior"
+ tokenizer = SimpleTokenizer(tokenizer_path)
+
+ # Initialize the SLICES backend; fall back to no relaxation if CHGNet is unavailable
+ try:
+     backend = SLICES(relax_model="chgnet", fmax=0.4, steps=25)
+ except Exception as e:
+     backend = SLICES(relax_model=None)
+
+
+ def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
+     condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
+     context = '>'
+     x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
+
+     with torch.no_grad():
+         generated = quantized_model.generate(x, prop=condition, max_length=max_length,
+                                              temperature=temperature, do_sample=do_sample,
+                                              top_k=top_k, top_p=top_p)
+
+     return tokenizer.decode(generated[0].tolist())
+
+ def generate_slices(formation_energy, band_gap):
+     return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
+                                      quantized_model.config.block_size, 1.2, True, 0, 0.9)
+
+ def wrap_structure(structure):
+     """Wrap all atoms back into the unit cell."""
+     for i, site in enumerate(structure):
+         frac_coords = site.frac_coords % 1.0
+         structure.replace(i, species=site.species, coords=frac_coords, coords_are_cartesian=False)
+     return structure
+
+ def convert_and_visualize(slices_string):
+     try:
+         structure, energy = backend.SLICES2structure(slices_string)
+
+         # Wrap atoms back into the unit cell
+         structure = wrap_structure(structure)
+
+         # Generate the CIF and save it to a temporary file
+         cif_file = tempfile.NamedTemporaryFile(mode='w', suffix='.cif', delete=False)
+         cif_writer = CifWriter(structure)
+         cif_writer.write_file(cif_file.name)
+
+         # Generate a structure summary
+         summary = f"Formula: {structure.composition.reduced_formula}\n"
+         summary += f"Number of sites: {len(structure)}\n"
+         summary += f"Lattice parameters: a={structure.lattice.a:.3f}, b={structure.lattice.b:.3f}, c={structure.lattice.c:.3f}\n"
+         summary += f"Angles: alpha={structure.lattice.alpha:.2f}, beta={structure.lattice.beta:.2f}, gamma={structure.lattice.gamma:.2f}\n"
+         summary += f"Volume: {structure.volume:.3f} ų\n"
+         summary += f"Density: {structure.density:.3f} g/cm³"
+
+         # Render a structure image with ASE and save it to a temporary file
+         atoms = AseAtomsAdaptor.get_atoms(structure)
+         image_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+         ase_write(image_file.name, atoms, format='png', rotation='10x,10y,10z')
+
+         return cif_file.name, image_file.name, summary, f"Conversion successful. Energy: {energy:.4f} eV/atom", True
+     except Exception as e:
+         return "", "", "", f"Conversion failed. Error: {str(e)}", False
+
+ def generate_and_convert(formation_energy, band_gap):
+     max_attempts = 5
+     start_time = time.time()
+     max_time = 300  # 5 minutes maximum execution time
+
+     for attempt in range(max_attempts):
+         if time.time() - start_time > max_time:
+             return "Exceeded maximum execution time", "", "", "", "Generation and conversion failed due to timeout"
+
+         slices_string = generate_slices(formation_energy, band_gap)
+         cif_file, image_file, structure_summary, status, success = convert_and_visualize(slices_string)
+
+         if success:
+             return slices_string, cif_file, image_file, structure_summary, f"Successful on attempt {attempt + 1}: {status}"
+
+         if attempt == max_attempts - 1:
+             return slices_string, "", "", "", f"Failed after {max_attempts} attempts: {status}"
+
+     return "Failed to generate valid SLICES string", "", "", "", "Generation failed"
+
+ # Create the Gradio interface
+ with gr.Blocks() as iface:
+     gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
+             gr.Markdown("**Enter the desired properties to inversely design a material (encoded as a SLICES string), then decode it into a crystal structure.**")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
+             formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
+             generate_button = gr.Button("Generate")
+
+         with gr.Column(scale=3):
+             slices_output = gr.Textbox(label="Generated SLICES String")
+             cif_output = gr.File(label="Download CIF", file_types=[".cif"])
+             structure_image = gr.Image(label="Structure Visualization")
+             structure_summary = gr.Textbox(label="Structure Summary", lines=6)
+             conversion_status = gr.Textbox(label="Conversion Status")
+
+     generate_button.click(
+         generate_and_convert,
+         inputs=[formation_energy, band_gap],
+         outputs=[slices_output, cif_output, structure_image, structure_summary, conversion_status]
+     )
+
+ iface.launch(share=True)
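
app.py keeps inference practical on CPU by dynamically quantizing the model's Linear layers to int8 before serving. Below is that pattern in isolation, a sketch with a hypothetical stand-in module (`TinyNet` is not part of the Space; load_quantized_model applies the same call to MatterGPTWrapper):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model, only to illustrate the quantization call.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

model = TinyNet().eval()

# Dynamic quantization: nn.Linear weights are stored as int8 and dequantized
# on the fly; activations stay in float. CPU-only, no calibration or retraining.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = qmodel(torch.randn(1, 16))
print(out.shape)  # torch.Size([1, 4])
```

Because dynamic quantization only rewrites the listed module types, load_quantized_model can apply it directly to the loaded checkpoint at startup.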
config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "model_type": "gpt",
+   "architectures": [
+     "GPT"
+   ],
+   "vocab_size": 132,
+   "block_size": 397,
+   "n_layer": 12,
+   "n_head": 12,
+   "n_embd": 768,
+   "num_props": 2,
+   "activation_function": "gelu_new",
+   "resid_pdrop": 0.1,
+   "embd_pdrop": 0.1,
+   "attn_pdrop": 0.1,
+   "layer_norm_epsilon": 1e-5,
+   "initializer_range": 0.02,
+   "summary_type": "cls_index",
+   "summary_use_proj": true,
+   "summary_activation": null,
+   "summary_proj_to_labels": true,
+   "summary_first_dropout": 0.1,
+   "scale_attn_weights": true,
+   "use_cache": true,
+   "bos_token_id": 130,
+   "eos_token_id": 131,
+   "lstm": false,
+   "lstm_layers": 0
+ }
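
This config.json feeds CustomGPTConfig in mattergpt_wrapper.py and, through it, GPTConfig in model.py; vocab_size 132 covers the 130 Voc_prior tokens plus the `<` and `>` markers. A sketch of reading it by hand and instantiating the bare GPT (randomly initialized; the trained weights come from pytorch_model.pt), assuming this commit's files are importable from the working directory:

```python
import json
from model import GPT, GPTConfig

# Read the hyperparameters shipped with this commit.
with open("config.json") as f:
    cfg = json.load(f)

# GPTConfig takes vocab_size and block_size explicitly; the remaining keys
# (n_layer, n_head, n_embd, num_props, lstm, ...) are set via **kwargs.
config = GPTConfig(cfg.pop("vocab_size"), cfg.pop("block_size"), **cfg)
model = GPT(config)

print(sum(p.numel() for p in model.parameters()))  # parameter count
```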
mattergpt_wrapper.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ from model import GPT, GPTConfig  # Import the original model and config classes
+ import json
+
+ class CustomGPTConfig(PretrainedConfig):
+     model_type = "gpt"
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         for key, value in kwargs.items():
+             setattr(self, key, value)
+
+ class MatterGPTWrapper(PreTrainedModel):
+     config_class = CustomGPTConfig
+     base_model_prefix = "gpt"
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = GPT(GPTConfig(**config.__dict__))
+
+     def forward(self, input_ids, attention_mask=None, labels=None, prop=None):
+         return self.model(input_ids, targets=labels, prop=prop)
+
+     def generate(self, input_ids, prop, max_length, num_return_sequences=1, **kwargs):
+         steps = max_length - input_ids.shape[1]
+         return self.model.sample(input_ids, steps, prop=prop, **kwargs)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
+         config_file = f"{pretrained_model_path}/config.json"
+         with open(config_file, 'r') as f:
+             config_dict = json.load(f)
+
+         config = CustomGPTConfig(**config_dict)
+
+         model = cls(config)
+
+         # Load the model weights
+         state_dict = torch.load(f"{pretrained_model_path}/pytorch_model.pt", map_location="cpu")
+         model.model.load_state_dict(state_dict)
+
+         return model
+
+     def save_pretrained(self, save_directory):
+         self.config.save_pretrained(save_directory)
+         torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.pt")
+
+ class SimpleTokenizer:
+     def __init__(self, vocab_file):
+         with open(vocab_file, 'r') as f:
+             self.vocab = f.read().splitlines()
+         self.vocab = sorted(set(self.vocab + ['<', '>']))
+         self.stoi = {ch: i for i, ch in enumerate(self.vocab)}
+         self.itos = {i: ch for i, ch in enumerate(self.vocab)}
+
+     def encode(self, text):
+         return [self.stoi[token] for token in text.split()]
+
+     def decode(self, ids):
+         return " ".join([self.itos[int(i)] for i in ids if i in self.itos]).replace("<", "").strip()
+
+     def __call__(self, text, return_tensors=None):
+         encoded = self.encode(text)
+         if return_tensors == 'pt':
+             import torch
+             return {'input_ids': torch.tensor([encoded])}
+         return {'input_ids': [encoded]}
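
End to end, generation outside the Gradio UI follows the same calls app.py makes; a sketch assuming config.json, pytorch_model.pt, and Voc_prior sit in the current directory:

```python
import torch
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer

# Load the config and weights from the current directory, plus the vocabulary file.
model = MatterGPTWrapper.from_pretrained("./").eval()
tokenizer = SimpleTokenizer("Voc_prior")

# Conditioning vector: (formation energy, band gap), matching num_props = 2.
prop = torch.tensor([[-1.0, 2.0]], dtype=torch.float32)

# Start from the '>' marker and sample up to block_size tokens.
x = torch.tensor([[tokenizer.stoi['>']]], dtype=torch.long)
with torch.no_grad():
    out = model.generate(x, prop=prop, max_length=model.config.block_size,
                         temperature=1.2, do_sample=True, top_k=0, top_p=0.9)

print(tokenizer.decode(out[0].tolist()))  # generated SLICES string
```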
model.py ADDED
@@ -0,0 +1,312 @@
+ # -*- coding: utf-8 -*-
+ # Yan Chen 2023.10
+
+ """
+ GPT model:
+ - the initial stem consists of a combination of token encoding and a positional encoding
+ - the meat of it is a uniform sequence of Transformer blocks
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+ - all blocks feed into a central residual pathway similar to resnets
+ - the final decoder is a linear projection into a vanilla Softmax classifier
+ """
+
+ import math, json
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ class GPTConfig:
+     """ base GPT config, params common to all GPT versions """
+     embd_pdrop = 0.1
+     resid_pdrop = 0.1
+     attn_pdrop = 0.1
+
+     def __init__(self, vocab_size, block_size, **kwargs):
+         self.vocab_size = vocab_size
+         self.block_size = block_size
+         for k, v in kwargs.items():
+             setattr(self, k, v)
+
+ class GPT1Config(GPTConfig):
+     """ GPT-1 like network roughly 125M params """
+     n_layer = 12
+     n_head = 12
+     n_embd = 768
+
+ class CausalSelfAttention(nn.Module):
+     """
+     A vanilla multi-head masked self-attention layer with a projection at the end.
+     It is possible to use torch.nn.MultiheadAttention here but I am including an
+     explicit implementation here to show that there is nothing too scary here.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads
+         self.key = nn.Linear(config.n_embd, config.n_embd)
+         self.query = nn.Linear(config.n_embd, config.n_embd)
+         self.value = nn.Linear(config.n_embd, config.n_embd)
+         # regularization
+         self.attn_drop = nn.Dropout(config.attn_pdrop)
+         self.resid_drop = nn.Dropout(config.resid_pdrop)
+         # output projection
+         self.proj = nn.Linear(config.n_embd, config.n_embd)
+         # causal mask to ensure that attention is only applied to the left in the input sequence
+         num = int(bool(config.num_props))
+         # num = 1
+         self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
+                              .view(1, 1, config.block_size + num, config.block_size + num))
+
+         self.n_head = config.n_head
+
+     def forward(self, x, layer_past=None):
+         B, T, C = x.size()
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         attn_save = att
+         att = self.attn_drop(att)
+         y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.resid_drop(self.proj(y))
+         return y, attn_save
+
+ class Block(nn.Module):
+     """ an unassuming Transformer block """
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(config.n_embd)
+         self.ln2 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.mlp = nn.Sequential(
+             nn.Linear(config.n_embd, 4 * config.n_embd),
+             nn.GELU(),
+             nn.Linear(4 * config.n_embd, config.n_embd),
+             nn.Dropout(config.resid_pdrop),
+         )
+
+     def forward(self, x):
+         y, attn = self.attn(self.ln1(x))
+         x = x + y
+         x = x + self.mlp(self.ln2(x))
+         return x, attn
+
+ class GPT(nn.Module):
+     """ the full GPT language model, with a context size of block_size """
+
+     def __init__(self, config):
+         super().__init__()
+         # print(json.dumps(config.__dict__, indent=2))
+         # input embedding stem
+         self.config = config
+         self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+         self.type_emb = nn.Embedding(2, config.n_embd)
+         if config.num_props:
+             self.prop_nn = nn.Linear(config.num_props, config.n_embd)
+
+         self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+         self.drop = nn.Dropout(config.embd_pdrop)
+         # transformer
+         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+         # decoder head
+         self.ln_f = nn.LayerNorm(config.n_embd)
+         self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         self.block_size = config.block_size
+
+         if config.lstm:
+             self.lstm = nn.LSTM(input_size=config.n_embd, hidden_size=config.n_embd, num_layers=config.lstm_layers, dropout=0.3, bidirectional=False)
+         self.apply(self._init_weights)
+
+         # logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+     def get_block_size(self):
+         return self.block_size
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if isinstance(module, nn.Linear) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def configure_optimizers(self, train_config):
+         """
+         This long function is unfortunately doing something very simple and is being very defensive:
+         We are separating out all parameters of the model into two buckets: those that will experience
+         weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+         We are then returning the PyTorch optimizer object.
+         """
+
+         # separate out all parameters to those that will and won't experience regularizing weight decay
+         decay = set()
+         no_decay = set()
+         whitelist_weight_modules = (torch.nn.Linear, torch.nn.LSTM)
+         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+         for mn, m in self.named_modules():
+             for pn, p in m.named_parameters():
+                 fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
+
+                 if pn.endswith('bias') or ('bias' in pn):
+                     # all biases will not be decayed
+                     no_decay.add(fpn)
+                 elif (pn.endswith('weight') or ('weight' in pn)) and isinstance(m, whitelist_weight_modules):
+                     # weights of whitelist modules will be weight decayed
+                     decay.add(fpn)
+                 elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                     # weights of blacklist modules will NOT be weight decayed
+                     no_decay.add(fpn)
+
+         # special case the position embedding parameter in the root GPT module as not decayed
+         no_decay.add('pos_emb')
+
+         # validate that we considered every parameter
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         inter_params = decay & no_decay
+         union_params = decay | no_decay
+         assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+         assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+                                                            % (str(param_dict.keys() - union_params), )
+
+         # create the pytorch optimizer object
+         optim_groups = [
+             {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+         ]
+         optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+         return optimizer
+
+     def forward(self, idx, targets=None, prop=None):
+         b, t = idx.size()
+         assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+
+         if self.config.num_props:
+             assert prop.size(-1) == self.config.num_props, "num_props should be equal to the last dim of the property vector"
+
+         # forward the GPT model
+         token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
+         position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
+         type_embeddings = self.type_emb(torch.ones((b, t), dtype=torch.long, device=idx.device))
+         x = self.drop(token_embeddings + position_embeddings + type_embeddings)
+
+         embed = x
+
+         if self.config.num_props:
+             type_embd = self.type_emb(torch.zeros((b, 1), dtype=torch.long, device=idx.device))
+             if prop.ndim == 2:
+                 p = self.prop_nn(prop.unsqueeze(1))  # for a single property
+             else:
+                 p = self.prop_nn(prop)  # for multiple properties
+             p += type_embd
+             x = torch.cat([p, x], 1)
+
+         # x = self.blocks(x)
+         attn_maps = []
+
+         for layer in self.blocks:
+             x, attn = layer(x)
+             attn_maps.append(attn)
+
+         x = self.ln_f(x)
+         logits = self.head(x)
+
+         if self.config.num_props:
+             num = int(bool(self.config.num_props))
+         else:
+             num = 0
+
+         logits = logits[:, num:, :]
+
+         # if we are given some desired targets also calculate the loss
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))
+
+         return logits, loss, attn_maps, embed  # attn_maps: (num_layers, batch_size, num_heads, max_seq_len, max_seq_len)
+
+
+     @torch.no_grad()
+     def sample(self, x, steps, temperature=1.0, do_sample=False, top_k=None, top_p=None, prop=None):
+         """
+         Take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+         the sequence, feeding the predictions back into the model each time. Clearly the sampling
+         has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+         of block_size, unlike an RNN that has an infinite context window.
+
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         """
+         # model.eval()
+
+         def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+             """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+             Args:
+                 logits: logits distribution shape (batch size x vocabulary size)
+                 top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+                 top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                     Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+             From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+             """
+             top_k = min(top_k, logits.size(-1))  # Safety check
+             if top_k > 0:
+                 # Remove all tokens with a probability less than the last token of the top-k
+                 indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+                 logits[indices_to_remove] = filter_value
+
+             if top_p > 0.0:
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                 # Remove tokens with cumulative probability above the threshold
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 # Shift the indices to the right to keep also the first token above the threshold
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 # scatter sorted tensors to original indexing
+                 indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+                 logits[indices_to_remove] = filter_value
+             return logits
+
+
+         for k in range(steps):
+             x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:]  # crop context if needed
+
+             # forward the model to get the logits for the index in the sequence
+             logits, _, _, _ = self(x_cond, prop=prop)  # for sampling, no target
+
+             # pluck the logits at the final step and scale by desired temperature
+             logits = logits[:, -1, :] / temperature
+
+             # optionally crop the logits to only the top k options OR use nucleus (top-p) filtering
+             # if top_k is not None:
+             #     v, _ = torch.topk(logits, top_k)
+             #     logits[logits < v[:, [-1]]] = -float('Inf')
+             logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)
+
+             # apply softmax to convert logits to (normalized) probabilities
+             probs = F.softmax(logits, dim=-1)
+
+             # sample from the distribution or take the most likely
+             if do_sample:
+                 x_next = torch.multinomial(probs, num_samples=1)
+             else:
+                 _, x_next = torch.topk(probs, k=1, dim=-1)
+
+             # append sampled index to the running sequence and continue
+             x = torch.cat((x, x_next), dim=1)
+
+         return x[:, 1:]
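
For a quick shape check of the bare GPT without the ~350 MB checkpoint, here is a sketch with deliberately small, hypothetical hyperparameters (not the shipped 12-layer, 768-dim configuration):

```python
import torch
from model import GPT, GPTConfig

# Small hypothetical config, just to exercise the forward pass.
config = GPTConfig(vocab_size=132, block_size=32,
                   n_layer=2, n_head=2, n_embd=64,
                   num_props=2, lstm=False, lstm_layers=0)
model = GPT(config).eval()

idx = torch.randint(0, 132, (1, 8))  # (batch, seq_len) token indices
prop = torch.tensor([[-1.0, 2.0]])   # (batch, num_props) conditioning vector

logits, loss, attn_maps, embed = model(idx, prop=prop)
print(logits.shape)    # torch.Size([1, 8, 132]); the property slot is stripped
print(loss)            # None, since no targets were passed
print(len(attn_maps))  # n_layer attention maps, one per Block
```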
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d0071fa7c05449273bfaf60357e2c4f9525c8b8f47e6d313856160312a72b21
+ size 349946009
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ spaces
+ slices==2.0.4