Upload 9 files
Files changed:

- Figure1.png +0 -0
- README.md +3 -3
- Voc_prior +130 -0
- app.py +144 -0
- config.json +29 -0
- mattergpt_wrapper.py +70 -0
- model.py +312 -0
- pytorch_model.pt +3 -0
- requirements.txt +4 -0
Figure1.png
ADDED
[Image: Figure1.png — the header figure shown in the app ("De novo crystal generation by MatterGPT targeting desired Eg, Ef")]
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
 title: MatterGPT CPU
-emoji:
+emoji: 🖼
 colorFrom: purple
-colorTo:
+colorTo: red
 sdk: gradio
 sdk_version: 4.41.0
 app_file: app.py
@@ -10,4 +10,4 @@ pinned: false
 license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Voc_prior
ADDED
@@ -0,0 +1,130 @@
S
o-o
He
Dy
-o-
Ne
+-o
Re
Bi
Cu
oo+
16
Sc
--o
Nd
Lu
-+o
Te
Si
o+o
Er
1
Sr
Hg
3
oo-
8
Ru
H
Mo
Tc
12
11
+oo
Pb
6
In
La
--+
C
Sn
Se
B
Ar
o--
-o+
Ga
++o
Rh
Sm
Ir
Li
Tl
18
I
Cl
Ag
Ba
Ta
Ho
Tb
As
-+-
Gd
Os
O
15
---
W
F
13
Pm
K
Na
9
Eu
Ce
14
-++
5
Ge
Yb
Al
Rb
Pd
Ni
Cd
Hf
P
Zn
Ti
Nb
0
Pr
7
Mg
Y
+-+
ooo
Pt
+--
19
Cs
N
-oo
+o-
o-+
Xe
4
o+-
Tm
2
Cr
Fe
+o+
Zr
++-
Kr
10
+++
Co
o++
Be
Br
Mn
Ca
Au
V
Sb
17
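The 130 entries above mix element symbols (H, Si, Fe, ...), SLICES edge-direction labels (three-character strings over o/+/- such as o-o and ++-), and integer node indices. A minimal sketch of how this file becomes the model vocabulary, mirroring SimpleTokenizer in mattergpt_wrapper.py; the '<' and '>' markers are added by the tokenizer itself, consistent with "vocab_size": 132 in config.json:

# Sketch: how Voc_prior is turned into the token vocabulary (mirrors SimpleTokenizer).
with open("Voc_prior") as f:
    tokens = f.read().splitlines()

# The tokenizer adds the '<' and '>' special markers and sorts the result.
vocab = sorted(set(tokens + ['<', '>']))
stoi = {tok: i for i, tok in enumerate(vocab)}

print(len(vocab))    # 132, matching "vocab_size" in config.json
print('>' in stoi)   # True; '>' is the start-of-sequence context used in app.py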
app.py
ADDED
@@ -0,0 +1,144 @@
import gradio as gr
import torch
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
import os
from slices.core import SLICES
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.io.ase import AseAtomsAdaptor
from ase.io import write as ase_write
import tempfile
import time

# Set the number of threads PyTorch uses
torch.set_num_threads(2)

def load_quantized_model(model_path):
    model = MatterGPTWrapper.from_pretrained(model_path)
    model.to('cpu')
    model.eval()
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model


# Load and quantize the model
model_path = "./"
quantized_model = load_quantized_model(model_path)
quantized_model.to("cpu")
quantized_model.eval()

# Load the tokenizer
tokenizer_path = "Voc_prior"
tokenizer = SimpleTokenizer(tokenizer_path)

# Initialize SLICES backend
try:
    backend = SLICES(relax_model="chgnet", fmax=0.4, steps=25)
except Exception as e:
    backend = SLICES(relax_model=None)


def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
    condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
    context = '>'
    x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)

    with torch.no_grad():
        generated = quantized_model.generate(x, prop=condition, max_length=max_length,
                                             temperature=temperature, do_sample=do_sample,
                                             top_k=top_k, top_p=top_p)

    return tokenizer.decode(generated[0].tolist())

def generate_slices(formation_energy, band_gap):
    return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
                                     quantized_model.config.block_size, 1.2, True, 0, 0.9)

def wrap_structure(structure):
    """Wrap all atoms back into the unit cell."""
    for i, site in enumerate(structure):
        frac_coords = site.frac_coords % 1.0
        structure.replace(i, species=site.species, coords=frac_coords, coords_are_cartesian=False)
    return structure

def convert_and_visualize(slices_string):
    try:
        structure, energy = backend.SLICES2structure(slices_string)

        # Wrap atoms back into the unit cell
        structure = wrap_structure(structure)

        # Generate CIF and save to temporary file
        cif_file = tempfile.NamedTemporaryFile(mode='w', suffix='.cif', delete=False)
        cif_writer = CifWriter(structure)
        cif_writer.write_file(cif_file.name)

        # Generate structure summary
        summary = f"Formula: {structure.composition.reduced_formula}\n"
        summary += f"Number of sites: {len(structure)}\n"
        summary += f"Lattice parameters: a={structure.lattice.a:.3f}, b={structure.lattice.b:.3f}, c={structure.lattice.c:.3f}\n"
        summary += f"Angles: alpha={structure.lattice.alpha:.2f}, beta={structure.lattice.beta:.2f}, gamma={structure.lattice.gamma:.2f}\n"
        summary += f"Volume: {structure.volume:.3f} Å³\n"
        summary += f"Density: {structure.density:.3f} g/cm³"

        # Generate structure image using ASE and save to temporary file
        atoms = AseAtomsAdaptor.get_atoms(structure)
        image_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        ase_write(image_file.name, atoms, format='png', rotation='10x,10y,10z')

        return cif_file.name, image_file.name, summary, f"Conversion successful. Energy: {energy:.4f} eV/atom", True
    except Exception as e:
        return "", "", "", f"Conversion failed. Error: {str(e)}", False

def generate_and_convert(formation_energy, band_gap):
    max_attempts = 5
    start_time = time.time()
    max_time = 300  # 5 minutes maximum execution time

    for attempt in range(max_attempts):
        if time.time() - start_time > max_time:
            return "Exceeded maximum execution time", "", "", "", "Generation and conversion failed due to timeout"

        slices_string = generate_slices(formation_energy, band_gap)
        cif_file, image_file, structure_summary, status, success = convert_and_visualize(slices_string)

        if success:
            return slices_string, cif_file, image_file, structure_summary, f"Successful on attempt {attempt + 1}: {status}"

        if attempt == max_attempts - 1:
            return slices_string, "", "", "", f"Failed after {max_attempts} attempts: {status}"

    return "Failed to generate valid SLICES string", "", "", "", "Generation failed"

# Create the Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")

    with gr.Row():
        with gr.Column():
            gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
            gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")

    with gr.Row():
        with gr.Column(scale=2):
            band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
            formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
            generate_button = gr.Button("Generate")

        with gr.Column(scale=3):
            slices_output = gr.Textbox(label="Generated SLICES String")
            cif_output = gr.File(label="Download CIF", file_types=[".cif"])
            structure_image = gr.Image(label="Structure Visualization")
            structure_summary = gr.Textbox(label="Structure Summary", lines=6)
            conversion_status = gr.Textbox(label="Conversion Status")

    generate_button.click(
        generate_and_convert,
        inputs=[formation_energy, band_gap],
        outputs=[slices_output, cif_output, structure_image, structure_summary, conversion_status]
    )

iface.launch(share=True)
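The Gradio button simply calls generate_and_convert, so the same pipeline can be exercised without the UI. A hypothetical snippet, assuming the definitions above have already been run (for example in a notebook, with the final iface.launch() call skipped):

# Hypothetical headless use of the helpers defined in app.py.
slices_string = generate_slices(formation_energy=-1.0, band_gap=2.0)
print("SLICES:", slices_string)

# Decode the SLICES string back into a crystal structure, CIF file and summary.
cif_path, png_path, summary, status, ok = convert_and_visualize(slices_string)
print(status)
if ok:
    print(summary)   # formula, lattice parameters, volume, density

Generation is stochastic (do_sample=True, temperature 1.2, top_p 0.9), which is why generate_and_convert retries up to five times until the SLICES string decodes into a valid structure.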
config.json
ADDED
@@ -0,0 +1,29 @@
{
  "model_type": "gpt",
  "architectures": [
    "GPT"
  ],
  "vocab_size": 132,
  "block_size": 397,
  "n_layer": 12,
  "n_head": 12,
  "n_embd": 768,
  "num_props": 2,
  "activation_function": "gelu_new",
  "resid_pdrop": 0.1,
  "embd_pdrop": 0.1,
  "attn_pdrop": 0.1,
  "layer_norm_epsilon": 1e-5,
  "initializer_range": 0.02,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "summary_activation": null,
  "summary_proj_to_labels": true,
  "summary_first_dropout": 0.1,
  "scale_attn_weights": true,
  "use_cache": true,
  "bos_token_id": 130,
  "eos_token_id": 131,
  "lstm": false,
  "lstm_layers": 0
}
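Only part of this file is consumed by the custom model code: GPTConfig in model.py stores every key as an attribute, and the fields the GPT class actually reads are vocab_size, block_size, n_layer, n_head, n_embd, num_props, the dropout rates and the lstm flags; the remaining Hugging Face-style keys appear to be kept for PretrainedConfig compatibility. A simplified sketch of how the file becomes a model configuration (MatterGPTWrapper.from_pretrained goes through CustomGPTConfig first):

import json
from model import GPTConfig

# Simplified version of the config loading done in MatterGPTWrapper.from_pretrained.
with open("config.json") as f:
    cfg_dict = json.load(f)

cfg = GPTConfig(**cfg_dict)   # vocab_size/block_size explicit, the rest become attributes
print(cfg.n_layer, cfg.n_head, cfg.n_embd)   # 12 12 768
print(cfg.num_props, cfg.lstm)               # 2 False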
mattergpt_wrapper.py
ADDED
@@ -0,0 +1,70 @@
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from model import GPT, GPTConfig  # Import your original model and config classes
import json

class CustomGPTConfig(PretrainedConfig):
    model_type = "gpt"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for key, value in kwargs.items():
            setattr(self, key, value)

class MatterGPTWrapper(PreTrainedModel):
    config_class = CustomGPTConfig
    base_model_prefix = "gpt"

    def __init__(self, config):
        super().__init__(config)
        self.model = GPT(GPTConfig(**config.__dict__))

    def forward(self, input_ids, attention_mask=None, labels=None, prop=None):
        return self.model(input_ids, targets=labels, prop=prop)

    def generate(self, input_ids, prop, max_length, num_return_sequences=1, **kwargs):
        steps = max_length - input_ids.shape[1]
        return self.model.sample(input_ids, steps, prop=prop, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
        config_file = f"{pretrained_model_path}/config.json"
        with open(config_file, 'r') as f:
            config_dict = json.load(f)

        config = CustomGPTConfig(**config_dict)

        model = cls(config)

        # Load the model weights
        state_dict = torch.load(f"{pretrained_model_path}/pytorch_model.pt", map_location="cpu")
        model.model.load_state_dict(state_dict)

        return model

    def save_pretrained(self, save_directory):
        self.config.save_pretrained(save_directory)
        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.pt")

class SimpleTokenizer:
    def __init__(self, vocab_file):
        with open(vocab_file, 'r') as f:
            self.vocab = f.read().splitlines()
        self.vocab = sorted(set(self.vocab + ['<', '>']))
        self.stoi = {ch: i for i, ch in enumerate(self.vocab)}
        self.itos = {i: ch for i, ch in enumerate(self.vocab)}

    def encode(self, text):
        return [self.stoi[token] for token in text.split()]

    def decode(self, ids):
        return " ".join([self.itos[int(i)] for i in ids if i in self.itos]).replace("<", "").strip()

    def __call__(self, text, return_tensors=None):
        encoded = self.encode(text)
        if return_tensors == 'pt':
            import torch
            return {'input_ids': torch.tensor([encoded])}
        return {'input_ids': [encoded]}
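SimpleTokenizer treats a SLICES string as whitespace-separated vocabulary tokens, so encoding and decoding are plain dictionary lookups. A small illustrative round trip (the token sequence here is made up for the example and is not a valid SLICES string):

from mattergpt_wrapper import SimpleTokenizer

tok = SimpleTokenizer("Voc_prior")

# encode() splits on whitespace and maps each token through stoi.
ids = tok.encode("Si O o-o 0 1")
print(ids)

# decode() maps ids back through itos, drops any '<' markers, and rejoins with spaces.
print(tok.decode(ids))   # "Si O o-o 0 1"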
model.py
ADDED
@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# Yan Chen 2023.10

"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math,json
import torch
import torch.nn as nn
from torch.nn import functional as F

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k,v in kwargs.items():
            setattr(self, k, v)

class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        num = int(bool(config.num_props))
        # num = 1
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
                                     .view(1, 1, config.block_size + num, config.block_size + num))

        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        attn_save = att
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, attn_save

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        y, attn = self.attn(self.ln1(x))
        x = x + y
        x = x + self.mlp(self.ln2(x))
        return x, attn

class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()
        #print(json.dumps(config.__dict__, indent=2))
        # input embedding stem
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.type_emb = nn.Embedding(2, config.n_embd)
        if config.num_props:
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)

        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size

        if config.lstm:
            self.lstm = nn.LSTM(input_size = config.n_embd, hidden_size = config.n_embd, num_layers = config.lstm_layers, dropout = 0.3, bidirectional = False)
        self.apply(self._init_weights)

        #logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.LSTM)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias') or ('bias' in pn):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif (pn.endswith('weight') or ('weight' in pn)) and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                           % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None, prop = None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        if self.config.num_props:
            assert prop.size(-1) == self.config.num_props, "Num_props should be equal to last dim of property vector"

        # forward the GPT model
        token_embeddings = self.tok_emb(idx)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :]  # each position maps to a (learnable) vector
        type_embeddings = self.type_emb(torch.ones((b,t), dtype = torch.long, device = idx.device))
        x = self.drop(token_embeddings + position_embeddings + type_embeddings)

        embed = x

        if self.config.num_props:
            type_embd = self.type_emb(torch.zeros((b, 1), dtype = torch.long, device = idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1))  # for single property
            else:
                p = self.prop_nn(prop)  # for multiproperty
            p += type_embd
            x = torch.cat([p, x], 1)

        # x = self.blocks(x)
        attn_maps = []

        for layer in self.blocks:
            x, attn = layer(x)
            attn_maps.append(attn)

        x = self.ln_f(x)
        logits = self.head(x)

        if self.config.num_props:
            num = int(bool(self.config.num_props))
        else:
            num = 0

        logits = logits[:, num:, :]

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, attn_maps, embed  # (num_layers, batch_size, num_heads, max_seq_len, max_seq_len)


    @torch.no_grad()
    def sample(self, x, steps, temperature=1.0, do_sample=False, top_k=None, top_p=None, prop=None):
        """
        Take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
        the sequence, feeding the predictions back into the model each time. Clearly the sampling
        has quadratic complexity unlike an RNN that is only linear, and has a finite context window
        of block_size, unlike an RNN that has an infinite context window.

        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        #model.eval()

        def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
            """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
                Args:
                    logits: logits distribution shape (batch size x vocabulary size)
                    top_k > 0: keep only top k tokens with highest probability (top-k filtering).
                    top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                        Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
                From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
            """
            top_k = min(top_k, logits.size(-1))  # Safety check
            if top_k > 0:
                # Remove all tokens with a probability less than the last token of the top-k
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = filter_value

            if top_p > 0.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the indices to the right to keep also the first token above the threshold
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # scatter sorted tensors to original indexing
                indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
                logits[indices_to_remove] = filter_value
            return logits


        for k in range(steps):
            x_cond = x if x.size(1) <= self.block_size else x[:, -self.block_size:]  # crop context if needed

            # forward the model to get the logits for the index in the sequence
            logits, _, _, _ = self(x_cond, prop = prop)  # for sampling, no target

            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature

            # optionally crop the logits to only the top k options OR using nucleus (top-p) filtering
            #if top_k is not None:
            #    v, _ = torch.topk(logits, top_k)
            #    logits[logits < v[:, [-1]]] = -float('Inf')
            logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)

            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)

            # sample from the distribution or take the most likely
            if do_sample:
                x_next = torch.multinomial(probs, num_samples=1)
            else:
                _, x_next = torch.topk(probs, k=1, dim=-1)

            # append sampled index to the running sequence and continue
            x = torch.cat((x, x_next), dim=1)

        return x[:, 1:]
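The forward pass prepends one property embedding ahead of the token embeddings and strips that slot from the returned logits, and sample() drops the initial context token from its output (return x[:, 1:]). A hypothetical smoke test with a deliberately tiny configuration, not part of the Space (the deployed model uses the 12-layer, 768-dimensional settings from config.json):

import torch
from model import GPT, GPTConfig

# Tiny throwaway configuration just to exercise the shapes.
cfg = GPTConfig(vocab_size=132, block_size=16,
                n_layer=2, n_head=2, n_embd=32,
                num_props=2, lstm=False, lstm_layers=0)
gpt = GPT(cfg)
gpt.eval()

idx = torch.randint(0, cfg.vocab_size, (1, 8))   # (batch, tokens)
prop = torch.tensor([[-1.0, 2.0]])               # [formation energy, band gap]

logits, loss, attn_maps, embed = gpt(idx, prop=prop)
print(logits.shape)   # torch.Size([1, 8, 132]) -- the property slot is already stripped
print(loss)           # None, since no targets were passed

out = gpt.sample(idx[:, :1], steps=8, prop=prop, do_sample=True, top_k=0, top_p=0.9)
print(out.shape)      # torch.Size([1, 8]) -- the conditioning token is dropped by sample()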
pytorch_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d0071fa7c05449273bfaf60357e2c4f9525c8b8f47e6d313856160312a72b21
size 349946009
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torch
transformers
spaces
slices==2.0.4