xiaohang07 commited on
Commit
ddf987f
·
verified ·
1 Parent(s): e9b3218

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -9,13 +9,23 @@ from pymatgen.io.ase import AseAtomsAdaptor
9
  from ase.io import write as ase_write
10
  import tempfile
11
  import time
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
- # Load the model
15
  model_path = "./"
16
- model = MatterGPTWrapper.from_pretrained(model_path)
17
- model.to("cpu")
18
- model.eval()
19
  # Load the tokenizer
20
  tokenizer_path = "Voc_prior"
21
  tokenizer = SimpleTokenizer(tokenizer_path)
@@ -27,22 +37,23 @@ try:
27
  except Exception as e:
28
  backend = SLICES(relax_model=None)
29
 
30
- def generate_slices_optimized(model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
 
 
31
  condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
32
  context = '>'
33
  x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
34
 
35
  with torch.no_grad():
36
- generated = model.generate(x, prop=condition, max_length=max_length,
37
- temperature=temperature, do_sample=do_sample,
38
- top_k=top_k, top_p=top_p)
39
 
40
  return tokenizer.decode(generated[0].tolist())
41
 
42
  def generate_slices(formation_energy, band_gap):
43
- return generate_slices_optimized(model, tokenizer, formation_energy, band_gap,
44
- model.config.block_size, 1.2, True, 0, 0.9)
45
-
46
  def wrap_structure(structure):
47
  """Wrap all atoms back into the unit cell."""
48
  for i, site in enumerate(structure):
@@ -109,7 +120,7 @@ with gr.Blocks() as iface:
109
  with gr.Row():
110
  with gr.Column():
111
  gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
112
- gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
113
 
114
  with gr.Row():
115
  with gr.Column(scale=2):
 
9
  from ase.io import write as ase_write
10
  import tempfile
11
  import time
12
+ # 设置PyTorch使用的线程数
13
+ torch.set_num_threads(2)
14
+ def load_quantized_model(model_path):
15
+ model = MatterGPTWrapper.from_pretrained(model_path)
16
+ model.to('cpu')
17
+ model.eval()
18
+ quantized_model = torch.quantization.quantize_dynamic(
19
+ model, {torch.nn.Linear}, dtype=torch.qint8
20
+ )
21
+ return quantized_model
22
 
23
 
24
+ # Load and quantize the model
25
  model_path = "./"
26
+ quantized_model = load_quantized_model(model_path)
27
+ quantized_model.to("cpu")
28
+ quantized_model.eval()
29
  # Load the tokenizer
30
  tokenizer_path = "Voc_prior"
31
  tokenizer = SimpleTokenizer(tokenizer_path)
 
37
  except Exception as e:
38
  backend = SLICES(relax_model=None)
39
 
40
+
41
+
42
+ def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
43
  condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
44
  context = '>'
45
  x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
46
 
47
  with torch.no_grad():
48
+ generated = quantized_model.generate(x, prop=condition, max_length=max_length,
49
+ temperature=temperature, do_sample=do_sample,
50
+ top_k=top_k, top_p=top_p)
51
 
52
  return tokenizer.decode(generated[0].tolist())
53
 
54
  def generate_slices(formation_energy, band_gap):
55
+ return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
56
+ quantized_model.config.block_size, 1.2, True, 0, 0.9)
 
57
  def wrap_structure(structure):
58
  """Wrap all atoms back into the unit cell."""
59
  for i, site in enumerate(structure):
 
120
  with gr.Row():
121
  with gr.Column():
122
  gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
123
+ gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure. Take 1-2 mins to finish with 2 cpus**")
124
 
125
  with gr.Row():
126
  with gr.Column(scale=2):