File size: 6,177 Bytes
d94c1ca
 
 
 
 
 
 
 
 
 
 
ddf987f
 
 
 
 
 
 
 
 
 
d94c1ca
 
ddf987f
d94c1ca
ddf987f
 
 
d94c1ca
 
 
 
 
 
 
 
 
 
 
ddf987f
 
 
d94c1ca
 
 
 
 
ddf987f
 
 
d94c1ca
 
 
 
ddf987f
 
d94c1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b610a5a
 
d94c1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import logging
import os
import tempfile
import time

import gradio as gr
import torch
from ase.io import write as ase_write
from pymatgen.core.structure import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.cif import CifWriter
from slices.core import SLICES

from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
# Cap the number of CPU threads PyTorch may use; the demo is sized for a
# 2-CPU host.
_TORCH_NUM_THREADS = 2
torch.set_num_threads(_TORCH_NUM_THREADS)
def load_quantized_model(model_path):
    """Load a MatterGPT checkpoint for CPU inference with int8 dynamic quantization.

    Args:
        model_path: Location handed to ``MatterGPTWrapper.from_pretrained``.

    Returns:
        The model produced by ``torch.quantization.quantize_dynamic``, with
        every ``torch.nn.Linear`` layer targeted for ``torch.qint8`` weights.
    """
    float_model = MatterGPTWrapper.from_pretrained(model_path)
    float_model.to('cpu')
    # Switch the float model to eval mode before quantizing it.
    float_model.eval()
    return torch.quantization.quantize_dynamic(
        float_model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
    )


# Load the checkpoint from the current directory and quantize it for CPU
# inference. CPU placement and eval mode are re-applied to the returned
# quantized model as a safeguard.
model_path = "./"
quantized_model = load_quantized_model(model_path)
quantized_model.to("cpu")
quantized_model.eval()

# Load the tokenizer vocabulary.
tokenizer_path = "Voc_prior"
tokenizer = SimpleTokenizer(tokenizer_path)

# Initialize the SLICES backend. Prefer CHGNet relaxation; if that cannot be
# set up, fall back to no relaxation but report why, instead of silently
# discarding the error.
try:
    backend = SLICES(relax_model="chgnet", fmax=0.4, steps=25)
except Exception as e:
    logging.getLogger(__name__).warning(
        "Could not initialize SLICES with CHGNet relaxation (%s); "
        "falling back to relax_model=None.",
        e,
    )
    backend = SLICES(relax_model=None)



def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
    """Generate one SLICES string conditioned on target properties.

    Generation starts from the ``'>'`` token (looked up in ``tokenizer.stoi``).
    The two target properties are passed to ``quantized_model.generate`` via
    ``prop`` as a ``(1, 2)`` float32 tensor ``[[formation_energy, band_gap]]``.
    Generation runs under ``torch.no_grad()``.

    Args:
        quantized_model: Object exposing ``generate(x, prop=..., max_length=...,
            temperature=..., do_sample=..., top_k=..., top_p=...)``.
        tokenizer: Object with a ``stoi`` mapping and a ``decode(ids)`` method.
        formation_energy: Target formation energy; converted with ``float``.
        band_gap: Target band gap; converted with ``float``.
        max_length, temperature, do_sample, top_k, top_p: Forwarded unchanged
            to ``quantized_model.generate``.

    Returns:
        The decoded first sequence of the batch returned by ``generate``.
    """
    target_props = torch.tensor(
        [[float(formation_energy), float(band_gap)]], dtype=torch.float32
    )
    start_token_id = tokenizer.stoi['>']
    prompt = torch.tensor([[start_token_id]], dtype=torch.long)

    with torch.no_grad():
        token_ids = quantized_model.generate(
            prompt,
            prop=target_props,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
        )

    first_sequence = token_ids[0].tolist()
    return tokenizer.decode(first_sequence)

def generate_slices(formation_energy, band_gap):
    """Sample one SLICES string for the requested properties.

    Uses the module-level ``quantized_model`` and ``tokenizer``. The maximum
    length is the model's ``config.block_size``, and sampling is enabled with
    temperature 1.2, top_k 0 and top_p 0.9 (forwarded to the model's
    ``generate``).
    """
    return generate_slices_quantized(
        quantized_model,
        tokenizer,
        formation_energy,
        band_gap,
        max_length=quantized_model.config.block_size,
        temperature=1.2,
        do_sample=True,
        top_k=0,
        top_p=0.9,
    )
def wrap_structure(structure):
    """Wrap every site's fractional coordinates back into the unit cell [0, 1).

    Each site is rewritten via ``structure.replace`` (keeping its species and
    site properties), and the same ``structure`` object is returned.

    Args:
        structure: A mutable structure supporting iteration over sites and
            ``replace(i, species=..., coords=..., coords_are_cartesian=...,
            properties=...)``.

    Returns:
        ``structure`` itself, after modification.
    """
    for i, site in enumerate(structure):
        frac_coords = site.frac_coords % 1.0
        # Float modulo of a tiny negative value (e.g. -1e-17 % 1.0) rounds to
        # exactly 1.0, which is outside [0, 1); map it to the periodic image 0.
        frac_coords[frac_coords >= 1.0] = 0.0
        structure.replace(
            i,
            species=site.species,
            coords=frac_coords,
            coords_are_cartesian=False,
            # replace() builds a fresh site; pass the properties along so
            # they are not dropped.
            properties=site.properties,
        )
    return structure

def _reserve_temp_path(suffix):
    """Create an empty, persistent temporary file and return its path.

    The OS-level descriptor is closed immediately, so no handle leaks and
    writers that reopen the path themselves (CifWriter, ase.io.write) work on
    every platform. The file is not auto-deleted because its path is returned
    to the Gradio outputs.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _summarize_structure(structure):
    """Return a multi-line, human-readable summary of ``structure``."""
    lattice = structure.lattice
    return (
        f"Formula: {structure.composition.reduced_formula}\n"
        f"Number of sites: {len(structure)}\n"
        f"Lattice parameters: a={lattice.a:.3f}, b={lattice.b:.3f}, c={lattice.c:.3f}\n"
        f"Angles: alpha={lattice.alpha:.2f}, beta={lattice.beta:.2f}, gamma={lattice.gamma:.2f}\n"
        f"Volume: {structure.volume:.3f} ų\n"
        f"Density: {structure.density:.3f} g/cm³"
    )


def convert_and_visualize(slices_string):
    """Decode a SLICES string into a crystal structure and render its outputs.

    The string is decoded with ``backend.SLICES2structure``, wrapped into the
    unit cell, written to a temporary CIF file, summarized, and rendered to a
    temporary PNG image with ASE.

    Args:
        slices_string: SLICES string produced by the generator.

    Returns:
        A 5-tuple ``(cif_path, png_path, summary, status, success)``. On any
        error the first three entries are empty strings, ``status`` holds the
        error message and ``success`` is False.
    """
    created_paths = []
    try:
        structure, energy = backend.SLICES2structure(slices_string)

        # Wrap atoms back into the unit cell.
        structure = wrap_structure(structure)

        cif_path = _reserve_temp_path('.cif')
        created_paths.append(cif_path)
        CifWriter(structure).write_file(cif_path)

        summary = _summarize_structure(structure)

        atoms = AseAtomsAdaptor.get_atoms(structure)
        png_path = _reserve_temp_path('.png')
        created_paths.append(png_path)
        ase_write(png_path, atoms, format='png', rotation='10x,10y,10z')

        # NOTE(review): the backend may not provide an energy in every
        # configuration (presumably when no relax model is set; confirm).
        # Don't let formatting None turn a valid structure into a failure.
        energy_text = "N/A" if energy is None else f"{energy:.4f}"
        return cif_path, png_path, summary, f"Conversion successful. Energy: {energy_text} eV/atom", True
    except Exception as e:
        # UI boundary: any decode/relax/render error becomes a status message
        # so the caller can retry with a new SLICES string.
        for path in created_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; the original error is reported below
        return "", "", "", f"Conversion failed. Error: {str(e)}", False

def generate_and_convert(formation_energy, band_gap, max_attempts=5, max_time=300):
    """Generate SLICES strings until one decodes into a structure, or give up.

    Each attempt samples a new SLICES string with ``generate_slices`` and
    tries to decode/render it with ``convert_and_visualize``. The loop stops
    at the first success or after ``max_attempts`` tries. It also stops once
    ``max_time`` seconds have elapsed; this is checked before each attempt, so
    an attempt that is already running is not interrupted.

    Args:
        formation_energy: Target formation energy (eV/atom).
        band_gap: Target band gap (eV).
        max_attempts: Maximum number of generate/convert attempts.
        max_time: Wall-clock budget in seconds, checked before each attempt.

    Returns:
        A 5-tuple matching the Gradio outputs:
        ``(slices_text, cif_path, image_path, summary, status)``. On failure
        the CIF and image entries are ``None``, so Gradio clears those
        components instead of trying to load a file named ``""``.
    """
    # Monotonic clock: elapsed time is unaffected by system clock changes.
    start_time = time.monotonic()
    slices_string = None
    status = ""

    for attempt in range(1, max_attempts + 1):
        if time.monotonic() - start_time > max_time:
            return ("Exceeded maximum execution time", None, None, "",
                    "Generation and conversion failed due to timeout")

        slices_string = generate_slices(formation_energy, band_gap)
        cif_file, image_file, summary, status, success = convert_and_visualize(slices_string)

        if success:
            return slices_string, cif_file, image_file, summary, f"Successful on attempt {attempt}: {status}"

    if slices_string is None:
        # Only reachable when max_attempts < 1: nothing was generated at all.
        return "Failed to generate valid SLICES string", None, None, "", "Generation failed"

    return slices_string, None, None, "", f"Failed after {max_attempts} attempts: {status}"

# Build the Gradio UI. It has a header with an overview figure and usage
# notes, a column of property inputs, and a column of outputs (SLICES string,
# CIF download, rendered image, summary and status).
with gr.Blocks() as iface:
    gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")

    # Overview figure and usage notes in a single full-width column.
    with gr.Row(), gr.Column():
        gr.Image(
            "Figure1.png",
            label="De novo crystal generation by MatterGPT targeting desired Eg, Ef",
            width=1000,
            height=300,
        )
        gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
        gr.Markdown("**Allow 1-2 minutes for completion using 2 CPUs.**")

    with gr.Row():
        # Inputs: target properties and the trigger button.
        with gr.Column(scale=2):
            band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
            formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
            generate_button = gr.Button("Generate")

        # Outputs, listed in the order generate_and_convert returns them.
        with gr.Column(scale=3):
            slices_output = gr.Textbox(label="Generated SLICES String")
            cif_output = gr.File(label="Download CIF", file_types=[".cif"])
            structure_image = gr.Image(label="Structure Visualization")
            structure_summary = gr.Textbox(label="Structure Summary", lines=6)
            conversion_status = gr.Textbox(label="Conversion Status")

    generate_button.click(
        fn=generate_and_convert,
        # Order matches generate_and_convert(formation_energy, band_gap).
        inputs=[formation_energy, band_gap],
        outputs=[
            slices_output,
            cif_output,
            structure_image,
            structure_summary,
            conversion_status,
        ],
    )

# share=True asks Gradio to also create a public share link.
iface.launch(share=True)