Update app.py
app.py CHANGED
@@ -9,23 +9,13 @@ from pymatgen.io.ase import AseAtomsAdaptor
 from ase.io import write as ase_write
 import tempfile
 import time
-# Set the number of threads PyTorch uses
-torch.set_num_threads(2)
-def load_quantized_model(model_path):
-    model = MatterGPTWrapper.from_pretrained(model_path)
-    model.to('cpu')
-    model.eval()
-    quantized_model = torch.quantization.quantize_dynamic(
-        model, {torch.nn.Linear}, dtype=torch.qint8
-    )
-    return quantized_model
 
 
-# Load
+# Load the model
 model_path = "./"
-
-
-
+model = MatterGPTWrapper.from_pretrained(model_path)
+model.to("cpu")
+model.eval()
 # Load the tokenizer
 tokenizer_path = "Voc_prior"
 tokenizer = SimpleTokenizer(tokenizer_path)
@@ -37,23 +27,22 @@ try:
 except Exception as e:
     backend = SLICES(relax_model=None)
 
-
-
-def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
+def generate_slices_optimized(model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
     condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
     context = '>'
     x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
 
     with torch.no_grad():
-        generated =
-
-
+        generated = model.generate(x, prop=condition, max_length=max_length,
+                                   temperature=temperature, do_sample=do_sample,
+                                   top_k=top_k, top_p=top_p)
 
     return tokenizer.decode(generated[0].tolist())
 
 def generate_slices(formation_energy, band_gap):
-    return
-
+    return generate_slices_optimized(model, tokenizer, formation_energy, band_gap,
+                                     model.config.block_size, 1.2, True, 0, 0.9)
+
 def wrap_structure(structure):
     """Wrap all atoms back into the unit cell."""
     for i, site in enumerate(structure):
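
For reference, a minimal usage sketch of the generation path after this change, assuming app.py has loaded the model and tokenizer at module level as shown in the diff; the formation-energy and band-gap values below are illustrative placeholders, not values taken from this commit.

# Illustrative call into the updated pipeline: generate_slices() delegates to
# generate_slices_optimized() with the sampling settings fixed in this commit
# (max_length=model.config.block_size, temperature=1.2, do_sample=True, top_k=0, top_p=0.9).
slices_string = generate_slices(formation_energy=-2.0, band_gap=1.5)
print(slices_string)  # decoded token sequence returned by tokenizer.decode()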
|