|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
|
import ipywidgets as widgets |
|
|
|
|
|
class DnaRnaProteinEntityWidget(widgets.VBox): |
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layout.border = "1px solid black" |
|
self.layout.padding = "1px" |
|
self.layout.margin = "1px" |
|
self.molecule_type_dropdown = widgets.Dropdown( |
|
options=["Protein", "DNA", "RNA"], |
|
value="Protein", |
|
disabled=False, |
|
) |
|
self.recode_name = { |
|
"Protein": "proteinChain", |
|
"DNA": "dnaSequence", |
|
"RNA": "rnaSequence", |
|
} |
|
|
|
self.copies_int_text = widgets.BoundedIntText( |
|
value=1, min=1, max=100, step=1, description="Copies:", disabled=False |
|
) |
|
|
|
self.sequence_text = widgets.Textarea( |
|
value="", |
|
placeholder="Paste sequence or fasta", |
|
description="Sequence:", |
|
disabled=False, |
|
) |
|
|
|
def handle_submit(sender): |
|
item = sender["owner"] |
|
cleaned_value = re.sub(r"[^a-zA-Z]", "", item.value) |
|
item.value = cleaned_value.upper() |
|
|
|
self.sequence_text.observe(handle_submit, names="value") |
|
|
|
self.add_modification = widgets.Button( |
|
description="modification", |
|
disabled=False, |
|
tooltip="add modification", |
|
icon="plus", |
|
) |
|
|
|
self.modifications = list() |
|
self.modifications_id = list() |
|
self.modifications_count = 0 |
|
|
|
self.add_modification.on_click(self.add_modification_callback) |
|
|
|
self.update() |
|
|
|
def add_modification_callback(self, b): |
|
self.modifications_count += 1 |
|
index = self.modifications_count |
|
modificationType = widgets.Text( |
|
value="CCD_XXX", |
|
placeholder="CCD_XXX", |
|
description="modificationType", |
|
max_length=6, |
|
disabled=False, |
|
) |
|
|
|
def handle_submit(sender): |
|
item = sender["owner"] |
|
item.value = item.value.upper() |
|
if item.value[0:4] != "CCD_": |
|
item.value = "CCD_" |
|
item.value = item.value[0:10] |
|
if len(item.value) < 4: |
|
item.value = "CCD_" |
|
item.value = re.sub(r"[^A-Z0-9_]", "", item.value) |
|
|
|
modificationType.observe(handle_submit, names="value") |
|
|
|
ptmPosition = widgets.BoundedIntText( |
|
value=1, min=1, description="Position:", disabled=False |
|
) |
|
delete_button = widgets.Button( |
|
description="modification", |
|
disabled=False, |
|
button_style="danger", |
|
tooltip="Click to delete this row", |
|
icon="trash", |
|
) |
|
delete_button.on_click(lambda b: self.delete_ptm_row(index)) |
|
self.modifications.append( |
|
widgets.HBox([modificationType, ptmPosition, delete_button]) |
|
) |
|
self.modifications_id.append(index) |
|
self.update() |
|
|
|
def update(self): |
|
self.children = ( |
|
[ |
|
widgets.HBox( |
|
[ |
|
self.molecule_type_dropdown, |
|
self.copies_int_text, |
|
self.sequence_text, |
|
] |
|
) |
|
] |
|
+ [i for i in self.modifications] |
|
+ [self.add_modification] |
|
) |
|
|
|
def delete_ptm_row(self, index): |
|
for idx in range(len(self.modifications_id)): |
|
if self.modifications_id[idx] == index: |
|
self.modifications.pop(idx) |
|
self.modifications_id.pop(idx) |
|
break |
|
self.update() |
|
|
|
def get_result(self): |
|
result = dict() |
|
sequence_type = self.molecule_type_dropdown.value |
|
recode_name = self.recode_name[sequence_type] |
|
sequence = self.sequence_text.value |
|
result[recode_name] = dict() |
|
result[recode_name]["count"] = self.copies_int_text.value |
|
result[recode_name]["sequence"] = sequence |
|
assert len(result[recode_name]["sequence"]) > 0, "sequence length must > 0" |
|
result[recode_name]["modifications"] = [ |
|
{ |
|
"modificationType": item.children[0].value, |
|
"Position": item.children[1].value, |
|
} |
|
for item in self.modifications |
|
] |
|
for item in result[recode_name]["modifications"]: |
|
pos = item["Position"] |
|
assert pos <= len(result[recode_name]["sequence"]), "position out of range" |
|
|
|
if sequence_type == "Protein": |
|
valid_chars = set("ARNDCQEGHILKMFPSTWYVX") |
|
for char in sequence: |
|
assert ( |
|
char in valid_chars |
|
), f"Invalid character '{char}' in sequence. valid set is {valid_chars}" |
|
elif sequence_type == "DNA": |
|
valid_chars = set("AGCTXINU") |
|
for char in sequence: |
|
assert ( |
|
char in valid_chars |
|
), f"Invalid character '{char}' in sequence. valid set is {valid_chars}" |
|
elif sequence_type == "RNA": |
|
valid_chars = set("AGCUXIN") |
|
for char in sequence: |
|
assert ( |
|
char in valid_chars |
|
), f"Invalid character '{char}' in sequence. valid set is {valid_chars}" |
|
return result |
|
|
|
|
|
class LigandSmilesEntityWidget(widgets.VBox): |
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layout.border = "1px solid black" |
|
self.layout.padding = "1px" |
|
self.layout.margin = "1px" |
|
|
|
self.molecule_type_dropdown = widgets.Dropdown( |
|
options=["Ligand"], |
|
value="Ligand", |
|
disabled=False, |
|
) |
|
|
|
self.copies_int_text = widgets.BoundedIntText( |
|
value=1, min=1, max=100, step=1, description="Copies:", disabled=False |
|
) |
|
|
|
self.smiles_text = widgets.Textarea( |
|
value="", |
|
placeholder="SMILES", |
|
description="SMILES:", |
|
disabled=False, |
|
) |
|
|
|
self.update() |
|
|
|
def update(self): |
|
self.children = [ |
|
widgets.HBox( |
|
[self.molecule_type_dropdown, self.copies_int_text, self.smiles_text] |
|
) |
|
] |
|
|
|
def get_result(self): |
|
result = dict() |
|
name = "ligand" |
|
result[name] = dict() |
|
result[name][name] = self.smiles_text.value |
|
assert len(result[name][name]) > 0, "SMILES length must > 0" |
|
result[name]["count"] = self.copies_int_text.value |
|
return result |
|
|
|
|
|
class LigandIonCCDEntityWidget(widgets.VBox): |
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layout.border = "1px solid black" |
|
self.layout.padding = "1px" |
|
self.layout.margin = "1px" |
|
|
|
self.molecule_type_dropdown = widgets.Dropdown( |
|
options=["Ligand", "Ion"], |
|
value="Ligand", |
|
disabled=False, |
|
) |
|
|
|
self.copies_int_text = widgets.BoundedIntText( |
|
value=1, min=1, max=100, step=1, description="Copies:", disabled=False |
|
) |
|
|
|
self.sequence_text = widgets.Text( |
|
value="CCD_XXX", |
|
placeholder="Paste sequence or fasta", |
|
description="CCD:", |
|
disabled=False, |
|
) |
|
|
|
def handle_submit(sender): |
|
item = sender["owner"] |
|
item.value = item.value.upper() |
|
if item.value[0:4] != "CCD_": |
|
item.value = "CCD_" |
|
if len(item.value) < 4: |
|
item.value = "CCD_" |
|
|
|
self.sequence_text.observe(handle_submit, names="value") |
|
|
|
self.update() |
|
|
|
def update(self): |
|
self.children = [ |
|
widgets.HBox( |
|
[self.molecule_type_dropdown, self.copies_int_text, self.sequence_text] |
|
) |
|
] |
|
|
|
def get_result(self): |
|
result = dict() |
|
if self.molecule_type_dropdown.value == "Ion": |
|
name = "ion" |
|
else: |
|
name = "ligand" |
|
result[name] = dict() |
|
result[name][name] = ( |
|
self.sequence_text.value |
|
if name == "ligand" |
|
else self.sequence_text.value[4::] |
|
) |
|
result[name]["count"] = self.copies_int_text.value |
|
return result |
|
|
|
|
|
class CovalentBondsWidget(widgets.VBox): |
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layout.border = "1px solid black" |
|
self.layout.padding = "1px" |
|
self.layout.margin = "1px" |
|
|
|
self.convalent_bonds = list() |
|
self.convalent_bonds_id = list() |
|
self.convalent_bonds_count = 0 |
|
|
|
self.add_convalent_bond_buttom = widgets.Button( |
|
description="convalent_bond", |
|
disabled=False, |
|
tooltip="add convalent_bond", |
|
icon="plus", |
|
) |
|
|
|
self.add_convalent_bond_buttom.on_click(self.on_add_convalent_bond) |
|
|
|
self.update() |
|
|
|
def on_add_convalent_bond(self, b): |
|
self.convalent_bonds_count += 1 |
|
index = self.convalent_bonds_count |
|
|
|
left_entity = widgets.BoundedIntText( |
|
value=1, min=1, description="left entity", disabled=False |
|
) |
|
left_position = widgets.BoundedIntText( |
|
value=1, min=1, description="left pos", disabled=False |
|
) |
|
left_atom = widgets.Text( |
|
value="", placeholder="left atom", description="left atom", disabled=False |
|
) |
|
right_entity = widgets.BoundedIntText( |
|
value=1, min=1, description="right entity", disabled=False |
|
) |
|
right_position = widgets.BoundedIntText( |
|
value=1, min=1, description="right pos", disabled=False |
|
) |
|
right_atom = widgets.Text( |
|
value="", placeholder="right atom", description="right atom", disabled=False |
|
) |
|
delete_button = widgets.Button( |
|
description="convalent bond", |
|
disabled=False, |
|
button_style="danger", |
|
tooltip="Click to delete this row", |
|
icon="trash", |
|
) |
|
delete_button.on_click(lambda b: self.delete_row(index)) |
|
self.convalent_bonds.append( |
|
widgets.HBox( |
|
[ |
|
left_entity, |
|
left_position, |
|
left_atom, |
|
right_entity, |
|
right_position, |
|
right_atom, |
|
delete_button, |
|
] |
|
) |
|
) |
|
self.convalent_bonds_id.append(index) |
|
self.update() |
|
|
|
def delete_row(self, index): |
|
for idx in range(len(self.convalent_bonds_id)): |
|
if self.convalent_bonds_id[idx] == index: |
|
self.convalent_bonds.pop(idx) |
|
self.convalent_bonds_id.pop(idx) |
|
break |
|
self.update() |
|
|
|
def update(self): |
|
self.children = [item for item in self.convalent_bonds] + [ |
|
self.add_convalent_bond_buttom |
|
] |
|
|
|
def get_result(self): |
|
result = [ |
|
{ |
|
"left_entity": item.children[0].value, |
|
"left_position": item.children[1].value, |
|
"left_atom": item.children[2].value, |
|
"right_entity": item.children[3].value, |
|
"right_position": item.children[4].value, |
|
"right_atom": item.children[5].value, |
|
} |
|
for item in self.convalent_bonds |
|
] |
|
return result |
|
|
|
|
|
class ModelWidget(widgets.VBox): |
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layout.border = "1px solid black" |
|
self.layout.padding = "1px" |
|
self.layout.margin = "1px" |
|
|
|
self.sample = widgets.BoundedIntText( |
|
value=5, |
|
description="N_sample", |
|
tooltip="diffusion model samples", |
|
disabled=False, |
|
) |
|
|
|
self.step = widgets.BoundedIntText( |
|
value=200, |
|
max=500, |
|
description="N_step", |
|
tooltip="diffusion model step", |
|
disabled=False, |
|
) |
|
|
|
self.cycle = widgets.BoundedIntText( |
|
value=10, description="N_cycle", tooltip="n cycle of model", disabled=False |
|
) |
|
|
|
self.seeds = widgets.Text( |
|
value="1,2,3,4,5", |
|
description="seeds", |
|
tooltip="random seeds", |
|
disabled=False, |
|
) |
|
|
|
self.version = widgets.Dropdown( |
|
options=["v1", "v2", "v3", "v4", "v5"], |
|
value="v1", |
|
description="version", |
|
tooltip="model version", |
|
) |
|
|
|
self.update() |
|
|
|
def update(self): |
|
self.children = [ |
|
widgets.HBox([self.sample, self.step, self.cycle, self.seeds, self.version]) |
|
] |
|
|
|
def get_result(self): |
|
results = dict() |
|
results["model_seeds"] = [int(i) for i in self.seeds.value.split(",")] |
|
results["model_version"] = self.version.value |
|
results["N_sample"] = self.sample.value |
|
results["N_step"] = self.step.value |
|
results["N_cycle"] = self.cycle.value |
|
return results |
|
|
|
|
|
class ProtenixInputViewer(widgets.VBox): |
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
self.layout.border = "1px solid black" |
|
self.layout.padding = "1px" |
|
self.layout.margin = "1px" |
|
|
|
self.name = widgets.Text( |
|
value="", placeholder="name", description="name", disabled=False |
|
) |
|
|
|
def handle_submit(sender): |
|
item = sender["owner"] |
|
cleaned_value = re.sub(r"[^a-zA-Z0-9_]", "", item.value) |
|
item.value = cleaned_value |
|
|
|
self.name.observe(handle_submit, names="value") |
|
|
|
self.use_msa = widgets.Checkbox( |
|
value=True, description="use_msa", disabled=False |
|
) |
|
self.atom_confidence = widgets.Checkbox( |
|
value=False, description="atom_confidence", disabled=False |
|
) |
|
|
|
self.ligand_smiles_entities = list() |
|
self.ligand_smiles_entities_id = list() |
|
self.ligand_smiles_entities_count = 0 |
|
|
|
self.dna_rna_protein_entities = list() |
|
self.dna_rna_protein_entities_id = list() |
|
self.dna_rna_protein_entities_count = 0 |
|
|
|
self.covalent_bonds = CovalentBondsWidget() |
|
|
|
self.add_ligand_smiles = widgets.Button( |
|
description="Ligand(SMILES)", |
|
disabled=False, |
|
tooltip="add Ligand(SMILES) Entity", |
|
icon="plus", |
|
) |
|
|
|
self.add_ligand_ion_ccd = widgets.Button( |
|
description="Ligand/Ion CCD", |
|
disabled=False, |
|
tooltip="add Ligand/Ion Entity", |
|
icon="plus", |
|
) |
|
|
|
self.add_dna_rna_protein = widgets.Button( |
|
description="Dna/Rna/Protein", |
|
disabled=False, |
|
tooltip="add Dna/Rna/Protein Entity", |
|
icon="plus", |
|
) |
|
|
|
self.model_parameter = ModelWidget() |
|
|
|
self.add_ligand_smiles.on_click(self.add_ligand_smiles_callback) |
|
self.add_ligand_ion_ccd.on_click(self.add_ligand_ion_ccd_callback) |
|
self.add_dna_rna_protein.on_click(self.add_dna_rna_protein_callback) |
|
|
|
self.update() |
|
|
|
def add_ligand_smiles_callback(self, b): |
|
self.ligand_smiles_entities_count += 1 |
|
index = self.ligand_smiles_entities_count |
|
item = LigandSmilesEntityWidget() |
|
delete_button = widgets.Button( |
|
description="entity", |
|
disabled=False, |
|
button_style="danger", |
|
tooltip="Click to delete this row", |
|
icon="trash", |
|
) |
|
delete_button.on_click(lambda b: self.delete_ligand_ion_ccd_row(index)) |
|
self.ligand_smiles_entities.append(widgets.HBox([item, delete_button])) |
|
self.ligand_smiles_entities_id.append(index) |
|
self.update() |
|
|
|
def add_ligand_ion_ccd_callback(self, b): |
|
self.ligand_smiles_entities_count += 1 |
|
index = self.ligand_smiles_entities_count |
|
item = LigandIonCCDEntityWidget() |
|
delete_button = widgets.Button( |
|
description="entity", |
|
disabled=False, |
|
button_style="danger", |
|
tooltip="Click to delete this row", |
|
icon="trash", |
|
) |
|
delete_button.on_click(lambda b: self.delete_ligand_ion_ccd_row(index)) |
|
self.ligand_smiles_entities.append(widgets.HBox([item, delete_button])) |
|
self.ligand_smiles_entities_id.append(index) |
|
self.update() |
|
|
|
def delete_ligand_ion_ccd_row(self, index): |
|
for idx in range(len(self.ligand_smiles_entities_id)): |
|
if self.ligand_smiles_entities_id[idx] == index: |
|
self.ligand_smiles_entities.pop(idx) |
|
self.ligand_smiles_entities_id.pop(idx) |
|
break |
|
self.update() |
|
|
|
def add_dna_rna_protein_callback(self, b): |
|
self.dna_rna_protein_entities_count += 1 |
|
index = self.dna_rna_protein_entities_count |
|
item = DnaRnaProteinEntityWidget() |
|
delete_button = widgets.Button( |
|
description="entity", |
|
disabled=False, |
|
button_style="danger", |
|
tooltip="Click to delete this row", |
|
icon="trash", |
|
) |
|
delete_button.on_click(lambda b: self.delete_dna_rna_protein_row(index)) |
|
self.dna_rna_protein_entities.append(widgets.HBox([item, delete_button])) |
|
self.dna_rna_protein_entities_id.append(index) |
|
self.update() |
|
|
|
def delete_dna_rna_protein_row(self, index): |
|
for idx in range(len(self.dna_rna_protein_entities_id)): |
|
if self.dna_rna_protein_entities_id[idx] == index: |
|
self.dna_rna_protein_entities.pop(idx) |
|
self.dna_rna_protein_entities_id.pop(idx) |
|
break |
|
self.update() |
|
|
|
def update(self): |
|
self.children = ( |
|
[widgets.HBox([self.name, self.use_msa, self.atom_confidence])] |
|
+ [item for item in self.dna_rna_protein_entities] |
|
+ [item for item in self.ligand_smiles_entities] |
|
+ [self.covalent_bonds] |
|
+ [ |
|
widgets.HBox( |
|
[ |
|
self.add_dna_rna_protein, |
|
self.add_ligand_smiles, |
|
self.add_ligand_ion_ccd, |
|
] |
|
) |
|
] |
|
+ [self.model_parameter] |
|
) |
|
|
|
def get_result(self): |
|
result = dict() |
|
result["name"] = self.name.value |
|
assert len(result["name"]) != 0, "name is empty" |
|
result["use_msa"] = self.use_msa.value |
|
result["atom_confidence"] = self.atom_confidence.value |
|
result["sequences"] = [ |
|
item.children[0].get_result() for item in self.dna_rna_protein_entities |
|
] + [item.children[0].get_result() for item in self.ligand_smiles_entities] |
|
result["covalent_bonds"] = self.covalent_bonds.get_result() |
|
model_paramter = self.model_parameter.get_result() |
|
for key in model_paramter: |
|
result[key] = model_paramter[key] |
|
|
|
entity_num = len(result["sequences"]) |
|
for covalent_bond in result["covalent_bonds"]: |
|
assert ( |
|
covalent_bond["left_entity"] <= entity_num |
|
and covalent_bond["left_entity"] > 0 |
|
), "covalent bond index out of range" |
|
assert ( |
|
covalent_bond["right_entity"] <= entity_num |
|
and covalent_bond["right_entity"] > 0 |
|
), "covalent bond index out of range" |
|
|
|
assert len(covalent_bond["left_atom"]) != 0, "left atom is empty" |
|
assert len(covalent_bond["right_atom"]) != 0, "right_atom atom is empty" |
|
|
|
left_entiy = result["sequences"][covalent_bond["left_entity"] - 1] |
|
left_entiy_seq = left_entiy[[item for item in left_entiy.keys()][0]] |
|
right_entity = result["sequences"][covalent_bond["right_entity"] - 1] |
|
right_entity_seq = right_entity[[item for item in right_entity.keys()][0]] |
|
assert ( |
|
covalent_bond["left_position"] <= len(left_entiy_seq) |
|
and covalent_bond["left_position"] >= 1 |
|
), "left position out of range" |
|
assert ( |
|
covalent_bond["right_position"] <= len(right_entity_seq) |
|
and covalent_bond["right_position"] >= 1 |
|
), "right position out of range" |
|
return result |
|
|