Spaces:
Running
Running
File size: 6,177 Bytes
d94c1ca ddf987f d94c1ca ddf987f d94c1ca ddf987f d94c1ca ddf987f d94c1ca ddf987f d94c1ca ddf987f d94c1ca b610a5a d94c1ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import gradio as gr
import torch
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
import os
from slices.core import SLICES
from pymatgen.core.structure import Structure
from pymatgen.io.cif import CifWriter
from pymatgen.io.ase import AseAtomsAdaptor
from ase.io import write as ase_write
import tempfile
import time
# Set the number of threads PyTorch uses (CPU intra-op parallelism).
# NOTE(review): presumably chosen to match the "2 CPUs" mentioned in the UI text — confirm.
torch.set_num_threads(2)
def load_quantized_model(model_path):
    """Load a MatterGPT checkpoint and return its dynamically quantized form.

    Args:
        model_path: Location passed straight to
            ``MatterGPTWrapper.from_pretrained``.

    Returns:
        The object produced by ``torch.quantization.quantize_dynamic`` when
        applied to the loaded model with ``torch.nn.Linear`` as the targeted
        layer type and ``torch.qint8`` as the dtype.
    """
    base_model = MatterGPTWrapper.from_pretrained(model_path)
    # The model is placed on CPU and switched to eval mode before quantizing.
    base_model.to('cpu')
    base_model.eval()

    target_layers = {torch.nn.Linear}
    return torch.quantization.quantize_dynamic(
        base_model, target_layers, dtype=torch.qint8
    )
# Load and quantize the model from the current directory.
model_path = "./"
quantized_model = load_quantized_model(model_path)
# load_quantized_model() already moves the model to CPU and puts it in eval
# mode before quantizing; the same two calls are repeated on the result.
quantized_model.to("cpu")
quantized_model.eval()

# Load the tokenizer
tokenizer_path = "Voc_prior"
tokenizer = SimpleTokenizer(tokenizer_path)

# Initialize SLICES backend
try:
    backend = SLICES(relax_model="chgnet", fmax=0.4, steps=25)
except Exception as e:
    # Deliberate best-effort fallback so the app still starts, but report the
    # reason instead of silently switching to relax_model=None.
    print(
        "SLICES backend with relax_model='chgnet' could not be created "
        f"({e!r}); falling back to relax_model=None."
    )
    backend = SLICES(relax_model=None)
def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
    """Generate one token sequence conditioned on two properties and decode it.

    Args:
        quantized_model: Object exposing ``generate(x, prop=..., max_length=...,
            temperature=..., do_sample=..., top_k=..., top_p=...)``.
        tokenizer: Object exposing a ``stoi`` mapping that contains ``'>'`` and
            a ``decode`` method that accepts a list of token ids.
        formation_energy: Target formation energy; converted with ``float``.
        band_gap: Target band gap; converted with ``float``.
        max_length: Forwarded unchanged to ``generate``.
        temperature: Forwarded unchanged to ``generate``.
        do_sample: Forwarded unchanged to ``generate``.
        top_k: Forwarded unchanged to ``generate``.
        top_p: Forwarded unchanged to ``generate``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the first generated sequence.
    """
    # Conditioning tensor of shape (1, 2): [formation_energy, band_gap].
    targets = [float(formation_energy), float(band_gap)]
    prop_tensor = torch.tensor([targets], dtype=torch.float32)

    # Generation is seeded with the single '>' token.
    start_symbol = '>'
    start_id = tokenizer.stoi[start_symbol]
    prompt_ids = torch.tensor([[start_id]], dtype=torch.long)

    with torch.no_grad():
        sequences = quantized_model.generate(
            prompt_ids,
            prop=prop_tensor,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
        )

    first_sequence = sequences[0].tolist()
    return tokenizer.decode(first_sequence)
def generate_slices(formation_energy, band_gap, temperature=1.2, do_sample=True, top_k=0, top_p=0.9):
    """Generate a SLICES string with the module-level model and tokenizer.

    The sampling settings were previously hard-coded; they are now optional
    parameters whose defaults reproduce the earlier behaviour exactly.

    Args:
        formation_energy: Target formation energy passed to the generator.
        band_gap: Target band gap passed to the generator.
        temperature: Sampling temperature forwarded to the generator.
        do_sample: Sampling flag forwarded to the generator.
        top_k: Top-k setting forwarded to the generator.
        top_p: Top-p setting forwarded to the generator.

    Returns:
        The decoded string returned by ``generate_slices_quantized``.
    """
    # The generation length is taken from the model's configured block size.
    max_length = quantized_model.config.block_size
    return generate_slices_quantized(
        quantized_model,
        tokenizer,
        formation_energy,
        band_gap,
        max_length,
        temperature,
        do_sample,
        top_k,
        top_p,
    )
def wrap_structure(structure):
    """Wrap all atoms back into the unit cell.

    Each site is replaced in place by a site with the same species whose
    fractional coordinates are reduced modulo 1. Site properties are passed
    through to the replacement so that wrapping does not discard them.

    Args:
        structure: Object that iterates over sites exposing ``species``,
            ``frac_coords`` and ``properties`` and that supports
            ``replace(index, species=..., coords=...,
            coords_are_cartesian=..., properties=...)``; in this file it is
            the structure returned by the SLICES backend.

    Returns:
        The same ``structure`` object, modified in place.
    """
    for index, site in enumerate(structure):
        wrapped_coords = site.frac_coords % 1.0
        structure.replace(
            index,
            species=site.species,
            coords=wrapped_coords,
            coords_are_cartesian=False,
            # replace() builds a brand-new site; without this argument the
            # new site would be created without the old site's properties.
            properties=site.properties,
        )
    return structure
def convert_and_visualize(slices_string):
    """Decode a SLICES string into a structure and write CIF/PNG artifacts.

    The string is converted with ``backend.SLICES2structure``, the resulting
    structure is wrapped into the unit cell, written to a temporary ``.cif``
    file, summarised as text, and rendered to a temporary ``.png`` file.

    Temporary file handles are closed as soon as the path has been reserved,
    so no descriptors are leaked and the writers open the path themselves.
    If any step fails, temporary files created so far are removed.

    Args:
        slices_string: SLICES string to decode.

    Returns:
        A 5-tuple ``(cif_path, image_path, summary, status, success)``. On
        failure the first three items are empty strings, ``status`` carries
        the error text and ``success`` is ``False``.
    """
    created_paths = []
    try:
        structure, energy = backend.SLICES2structure(slices_string)

        # Wrap atoms back into the unit cell
        structure = wrap_structure(structure)

        # Reserve a temporary .cif path, close the handle, then write the CIF.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.cif', delete=False) as cif_file:
            cif_path = cif_file.name
        created_paths.append(cif_path)
        CifWriter(structure).write_file(cif_path)

        # Generate structure summary
        lattice = structure.lattice
        summary = "\n".join([
            f"Formula: {structure.composition.reduced_formula}",
            f"Number of sites: {len(structure)}",
            f"Lattice parameters: a={lattice.a:.3f}, b={lattice.b:.3f}, c={lattice.c:.3f}",
            f"Angles: alpha={lattice.alpha:.2f}, beta={lattice.beta:.2f}, gamma={lattice.gamma:.2f}",
            f"Volume: {structure.volume:.3f} Å³",
            f"Density: {structure.density:.3f} g/cm³",
        ])

        # Render the structure with ASE into a temporary .png file.
        atoms = AseAtomsAdaptor.get_atoms(structure)
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as image_file:
            image_path = image_file.name
        created_paths.append(image_path)
        ase_write(image_path, atoms, format='png', rotation='10x,10y,10z')

        status = f"Conversion successful. Energy: {energy:.4f} eV/atom"
        return cif_path, image_path, summary, status, True
    except Exception as e:
        # Boundary handler for the UI: report the error as a status string and
        # do not leave half-written temporary files behind.
        for path in created_paths:
            try:
                os.remove(path)
            except OSError:
                pass
        return "", "", "", f"Conversion failed. Error: {str(e)}", False
def generate_and_convert(formation_energy, band_gap):
    """Generate SLICES strings until one converts into a structure.

    Up to five generate-and-convert attempts are made. Before each attempt the
    elapsed wall-clock time is compared against a 300 second budget.

    Args:
        formation_energy: Target formation energy forwarded to
            ``generate_slices``.
        band_gap: Target band gap forwarded to ``generate_slices``.

    Returns:
        A 5-tuple ``(slices_string, cif_file, image_file, structure_summary,
        status_message)``. When no attempt succeeds, the file, image and
        summary entries are empty strings.
    """
    # NOTE(review): empty strings are returned for the file/image outputs on
    # failure — confirm the UI components accept "" as "no value".
    attempt_limit = 5
    time_budget = 300  # 5 minutes maximum execution time
    began_at = time.time()

    attempts_made = 0
    while attempts_made < attempt_limit:
        # The budget is only checked before an attempt starts.
        elapsed = time.time() - began_at
        if elapsed > time_budget:
            return "Exceeded maximum execution time", "", "", "", "Generation and conversion failed due to timeout"

        slices_string = generate_slices(formation_energy, band_gap)
        outcome = convert_and_visualize(slices_string)
        cif_file, image_file, structure_summary, status, success = outcome
        attempts_made += 1

        if success:
            return slices_string, cif_file, image_file, structure_summary, f"Successful on attempt {attempts_made}: {status}"

        if attempts_made == attempt_limit:
            return slices_string, "", "", "", f"Failed after {attempt_limit} attempts: {status}"

    return "Failed to generate valid SLICES string", "", "", "", "Generation failed"
# Create the Gradio interface
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a flattened copy of this file — confirm it matches the intended layout.
with gr.Blocks() as iface:
    gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")

    # Header row: overview figure plus usage notes.
    with gr.Row():
        with gr.Column():
            gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
            gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
            gr.Markdown("**Allow 1-2 minutes for completion using 2 CPUs.**")

    # Main row: property inputs on the left, generated results on the right.
    with gr.Row():
        with gr.Column(scale=2):
            band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
            formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
            generate_button = gr.Button("Generate")

        with gr.Column(scale=3):
            slices_output = gr.Textbox(label="Generated SLICES String")
            cif_output = gr.File(label="Download CIF", file_types=[".cif"])
            structure_image = gr.Image(label="Structure Visualization")
            structure_summary = gr.Textbox(label="Structure Summary", lines=6)
            conversion_status = gr.Textbox(label="Conversion Status")

    # Inputs are listed in the order of generate_and_convert's parameters
    # (formation_energy, band_gap); outputs follow the order of its 5-tuple.
    generate_button.click(
        generate_and_convert,
        inputs=[formation_energy, band_gap],
        outputs=[slices_output, cif_output, structure_image, structure_summary, conversion_status]
    )

iface.launch(share=True)