Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,13 +9,23 @@ from pymatgen.io.ase import AseAtomsAdaptor
|
|
9 |
from ase.io import write as ase_write
|
10 |
import tempfile
|
11 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
-
# Load the model
|
15 |
model_path = "./"
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
# Load the tokenizer
|
20 |
tokenizer_path = "Voc_prior"
|
21 |
tokenizer = SimpleTokenizer(tokenizer_path)
|
@@ -27,22 +37,23 @@ try:
|
|
27 |
except Exception as e:
|
28 |
backend = SLICES(relax_model=None)
|
29 |
|
30 |
-
|
|
|
|
|
31 |
condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
|
32 |
context = '>'
|
33 |
x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
|
34 |
|
35 |
with torch.no_grad():
|
36 |
-
generated =
|
37 |
-
|
38 |
-
|
39 |
|
40 |
return tokenizer.decode(generated[0].tolist())
|
41 |
|
42 |
def generate_slices(formation_energy, band_gap):
|
43 |
-
return
|
44 |
-
|
45 |
-
|
46 |
def wrap_structure(structure):
|
47 |
"""Wrap all atoms back into the unit cell."""
|
48 |
for i, site in enumerate(structure):
|
@@ -109,7 +120,7 @@ with gr.Blocks() as iface:
|
|
109 |
with gr.Row():
|
110 |
with gr.Column():
|
111 |
gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
|
112 |
-
gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure
|
113 |
|
114 |
with gr.Row():
|
115 |
with gr.Column(scale=2):
|
|
|
9 |
from ase.io import write as ase_write
|
10 |
import tempfile
|
11 |
import time
|
12 |
+
# 设置PyTorch使用的线程数
|
13 |
+
torch.set_num_threads(2)
|
14 |
+
def load_quantized_model(model_path):
|
15 |
+
model = MatterGPTWrapper.from_pretrained(model_path)
|
16 |
+
model.to('cpu')
|
17 |
+
model.eval()
|
18 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
19 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
20 |
+
)
|
21 |
+
return quantized_model
|
22 |
|
23 |
|
24 |
+
# Load and quantize the model
|
25 |
model_path = "./"
|
26 |
+
quantized_model = load_quantized_model(model_path)
|
27 |
+
quantized_model.to("cpu")
|
28 |
+
quantized_model.eval()
|
29 |
# Load the tokenizer
|
30 |
tokenizer_path = "Voc_prior"
|
31 |
tokenizer = SimpleTokenizer(tokenizer_path)
|
|
|
37 |
except Exception as e:
|
38 |
backend = SLICES(relax_model=None)
|
39 |
|
40 |
+
|
41 |
+
|
42 |
+
def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
|
43 |
condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
|
44 |
context = '>'
|
45 |
x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
|
46 |
|
47 |
with torch.no_grad():
|
48 |
+
generated = quantized_model.generate(x, prop=condition, max_length=max_length,
|
49 |
+
temperature=temperature, do_sample=do_sample,
|
50 |
+
top_k=top_k, top_p=top_p)
|
51 |
|
52 |
return tokenizer.decode(generated[0].tolist())
|
53 |
|
54 |
def generate_slices(formation_energy, band_gap):
|
55 |
+
return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
|
56 |
+
quantized_model.config.block_size, 1.2, True, 0, 0.9)
|
|
|
57 |
def wrap_structure(structure):
|
58 |
"""Wrap all atoms back into the unit cell."""
|
59 |
for i, site in enumerate(structure):
|
|
|
120 |
with gr.Row():
|
121 |
with gr.Column():
|
122 |
gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
|
123 |
+
gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure. Take 1-2 mins to finish with 2 cpus**")
|
124 |
|
125 |
with gr.Row():
|
126 |
with gr.Column(scale=2):
|