Zaixi's picture
Add large file
89c0b51
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import ipywidgets as widgets
class DnaRnaProteinEntityWidget(widgets.VBox):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layout.border = "1px solid black"
self.layout.padding = "1px"
self.layout.margin = "1px"
self.molecule_type_dropdown = widgets.Dropdown(
options=["Protein", "DNA", "RNA"],
value="Protein",
disabled=False,
)
self.recode_name = {
"Protein": "proteinChain",
"DNA": "dnaSequence",
"RNA": "rnaSequence",
}
self.copies_int_text = widgets.BoundedIntText(
value=1, min=1, max=100, step=1, description="Copies:", disabled=False
)
self.sequence_text = widgets.Textarea(
value="",
placeholder="Paste sequence or fasta",
description="Sequence:",
disabled=False,
)
def handle_submit(sender):
item = sender["owner"]
cleaned_value = re.sub(r"[^a-zA-Z]", "", item.value)
item.value = cleaned_value.upper()
self.sequence_text.observe(handle_submit, names="value")
self.add_modification = widgets.Button(
description="modification",
disabled=False,
tooltip="add modification",
icon="plus",
)
self.modifications = list()
self.modifications_id = list()
self.modifications_count = 0
self.add_modification.on_click(self.add_modification_callback)
self.update()
def add_modification_callback(self, b):
self.modifications_count += 1
index = self.modifications_count
modificationType = widgets.Text(
value="CCD_XXX",
placeholder="CCD_XXX",
description="modificationType",
max_length=6,
disabled=False,
)
def handle_submit(sender):
item = sender["owner"]
item.value = item.value.upper()
if item.value[0:4] != "CCD_":
item.value = "CCD_"
item.value = item.value[0:10]
if len(item.value) < 4:
item.value = "CCD_"
item.value = re.sub(r"[^A-Z0-9_]", "", item.value)
modificationType.observe(handle_submit, names="value")
ptmPosition = widgets.BoundedIntText(
value=1, min=1, description="Position:", disabled=False
)
delete_button = widgets.Button(
description="modification",
disabled=False,
button_style="danger",
tooltip="Click to delete this row",
icon="trash",
)
delete_button.on_click(lambda b: self.delete_ptm_row(index))
self.modifications.append(
widgets.HBox([modificationType, ptmPosition, delete_button])
)
self.modifications_id.append(index)
self.update()
def update(self):
self.children = (
[
widgets.HBox(
[
self.molecule_type_dropdown,
self.copies_int_text,
self.sequence_text,
]
)
]
+ [i for i in self.modifications]
+ [self.add_modification]
)
def delete_ptm_row(self, index):
for idx in range(len(self.modifications_id)):
if self.modifications_id[idx] == index:
self.modifications.pop(idx)
self.modifications_id.pop(idx)
break
self.update()
def get_result(self):
result = dict()
sequence_type = self.molecule_type_dropdown.value
recode_name = self.recode_name[sequence_type]
sequence = self.sequence_text.value
result[recode_name] = dict()
result[recode_name]["count"] = self.copies_int_text.value
result[recode_name]["sequence"] = sequence
assert len(result[recode_name]["sequence"]) > 0, "sequence length must > 0"
result[recode_name]["modifications"] = [
{
"modificationType": item.children[0].value,
"Position": item.children[1].value,
}
for item in self.modifications
]
for item in result[recode_name]["modifications"]:
pos = item["Position"]
assert pos <= len(result[recode_name]["sequence"]), "position out of range"
if sequence_type == "Protein":
valid_chars = set("ARNDCQEGHILKMFPSTWYVX")
for char in sequence:
assert (
char in valid_chars
), f"Invalid character '{char}' in sequence. valid set is {valid_chars}"
elif sequence_type == "DNA":
valid_chars = set("AGCTXINU")
for char in sequence:
assert (
char in valid_chars
), f"Invalid character '{char}' in sequence. valid set is {valid_chars}"
elif sequence_type == "RNA":
valid_chars = set("AGCUXIN")
for char in sequence:
assert (
char in valid_chars
), f"Invalid character '{char}' in sequence. valid set is {valid_chars}"
return result
class LigandSmilesEntityWidget(widgets.VBox):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layout.border = "1px solid black"
self.layout.padding = "1px"
self.layout.margin = "1px"
self.molecule_type_dropdown = widgets.Dropdown(
options=["Ligand"],
value="Ligand",
disabled=False,
)
self.copies_int_text = widgets.BoundedIntText(
value=1, min=1, max=100, step=1, description="Copies:", disabled=False
)
self.smiles_text = widgets.Textarea(
value="",
placeholder="SMILES",
description="SMILES:",
disabled=False,
)
self.update()
def update(self):
self.children = [
widgets.HBox(
[self.molecule_type_dropdown, self.copies_int_text, self.smiles_text]
)
]
def get_result(self):
result = dict()
name = "ligand"
result[name] = dict()
result[name][name] = self.smiles_text.value
assert len(result[name][name]) > 0, "SMILES length must > 0"
result[name]["count"] = self.copies_int_text.value
return result
class LigandIonCCDEntityWidget(widgets.VBox):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layout.border = "1px solid black"
self.layout.padding = "1px"
self.layout.margin = "1px"
self.molecule_type_dropdown = widgets.Dropdown(
options=["Ligand", "Ion"],
value="Ligand",
disabled=False,
)
self.copies_int_text = widgets.BoundedIntText(
value=1, min=1, max=100, step=1, description="Copies:", disabled=False
)
self.sequence_text = widgets.Text(
value="CCD_XXX",
placeholder="Paste sequence or fasta",
description="CCD:",
disabled=False,
)
def handle_submit(sender):
item = sender["owner"]
item.value = item.value.upper()
if item.value[0:4] != "CCD_":
item.value = "CCD_"
if len(item.value) < 4:
item.value = "CCD_"
self.sequence_text.observe(handle_submit, names="value")
self.update()
def update(self):
self.children = [
widgets.HBox(
[self.molecule_type_dropdown, self.copies_int_text, self.sequence_text]
)
]
def get_result(self):
result = dict()
if self.molecule_type_dropdown.value == "Ion":
name = "ion"
else:
name = "ligand"
result[name] = dict()
result[name][name] = (
self.sequence_text.value
if name == "ligand"
else self.sequence_text.value[4::]
)
result[name]["count"] = self.copies_int_text.value
return result
class CovalentBondsWidget(widgets.VBox):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layout.border = "1px solid black"
self.layout.padding = "1px"
self.layout.margin = "1px"
self.convalent_bonds = list()
self.convalent_bonds_id = list()
self.convalent_bonds_count = 0
self.add_convalent_bond_buttom = widgets.Button(
description="convalent_bond",
disabled=False,
tooltip="add convalent_bond",
icon="plus",
)
self.add_convalent_bond_buttom.on_click(self.on_add_convalent_bond)
self.update()
def on_add_convalent_bond(self, b):
self.convalent_bonds_count += 1
index = self.convalent_bonds_count
left_entity = widgets.BoundedIntText(
value=1, min=1, description="left entity", disabled=False
)
left_position = widgets.BoundedIntText(
value=1, min=1, description="left pos", disabled=False
)
left_atom = widgets.Text(
value="", placeholder="left atom", description="left atom", disabled=False
)
right_entity = widgets.BoundedIntText(
value=1, min=1, description="right entity", disabled=False
)
right_position = widgets.BoundedIntText(
value=1, min=1, description="right pos", disabled=False
)
right_atom = widgets.Text(
value="", placeholder="right atom", description="right atom", disabled=False
)
delete_button = widgets.Button(
description="convalent bond",
disabled=False,
button_style="danger",
tooltip="Click to delete this row",
icon="trash",
)
delete_button.on_click(lambda b: self.delete_row(index))
self.convalent_bonds.append(
widgets.HBox(
[
left_entity,
left_position,
left_atom,
right_entity,
right_position,
right_atom,
delete_button,
]
)
)
self.convalent_bonds_id.append(index)
self.update()
def delete_row(self, index):
for idx in range(len(self.convalent_bonds_id)):
if self.convalent_bonds_id[idx] == index:
self.convalent_bonds.pop(idx)
self.convalent_bonds_id.pop(idx)
break
self.update()
def update(self):
self.children = [item for item in self.convalent_bonds] + [
self.add_convalent_bond_buttom
]
def get_result(self):
result = [
{
"left_entity": item.children[0].value,
"left_position": item.children[1].value,
"left_atom": item.children[2].value,
"right_entity": item.children[3].value,
"right_position": item.children[4].value,
"right_atom": item.children[5].value,
}
for item in self.convalent_bonds
]
return result
class ModelWidget(widgets.VBox):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layout.border = "1px solid black"
self.layout.padding = "1px"
self.layout.margin = "1px"
self.sample = widgets.BoundedIntText(
value=5,
description="N_sample",
tooltip="diffusion model samples",
disabled=False,
)
self.step = widgets.BoundedIntText(
value=200,
max=500,
description="N_step",
tooltip="diffusion model step",
disabled=False,
)
self.cycle = widgets.BoundedIntText(
value=10, description="N_cycle", tooltip="n cycle of model", disabled=False
)
self.seeds = widgets.Text(
value="1,2,3,4,5",
description="seeds",
tooltip="random seeds",
disabled=False,
)
self.version = widgets.Dropdown(
options=["v1", "v2", "v3", "v4", "v5"],
value="v1",
description="version",
tooltip="model version",
)
self.update()
def update(self):
self.children = [
widgets.HBox([self.sample, self.step, self.cycle, self.seeds, self.version])
]
def get_result(self):
results = dict()
results["model_seeds"] = [int(i) for i in self.seeds.value.split(",")]
results["model_version"] = self.version.value
results["N_sample"] = self.sample.value
results["N_step"] = self.step.value
results["N_cycle"] = self.cycle.value
return results
class ProtenixInputViewer(widgets.VBox):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.layout.border = "1px solid black"
self.layout.padding = "1px"
self.layout.margin = "1px"
self.name = widgets.Text(
value="", placeholder="name", description="name", disabled=False
)
def handle_submit(sender):
item = sender["owner"]
cleaned_value = re.sub(r"[^a-zA-Z0-9_]", "", item.value)
item.value = cleaned_value
self.name.observe(handle_submit, names="value")
self.use_msa = widgets.Checkbox(
value=True, description="use_msa", disabled=False
)
self.atom_confidence = widgets.Checkbox(
value=False, description="atom_confidence", disabled=False
)
self.ligand_smiles_entities = list()
self.ligand_smiles_entities_id = list()
self.ligand_smiles_entities_count = 0
self.dna_rna_protein_entities = list()
self.dna_rna_protein_entities_id = list()
self.dna_rna_protein_entities_count = 0
self.covalent_bonds = CovalentBondsWidget()
self.add_ligand_smiles = widgets.Button(
description="Ligand(SMILES)",
disabled=False,
tooltip="add Ligand(SMILES) Entity",
icon="plus",
)
self.add_ligand_ion_ccd = widgets.Button(
description="Ligand/Ion CCD",
disabled=False,
tooltip="add Ligand/Ion Entity",
icon="plus",
)
self.add_dna_rna_protein = widgets.Button(
description="Dna/Rna/Protein",
disabled=False,
tooltip="add Dna/Rna/Protein Entity",
icon="plus",
)
self.model_parameter = ModelWidget()
self.add_ligand_smiles.on_click(self.add_ligand_smiles_callback)
self.add_ligand_ion_ccd.on_click(self.add_ligand_ion_ccd_callback)
self.add_dna_rna_protein.on_click(self.add_dna_rna_protein_callback)
self.update()
def add_ligand_smiles_callback(self, b):
self.ligand_smiles_entities_count += 1
index = self.ligand_smiles_entities_count
item = LigandSmilesEntityWidget()
delete_button = widgets.Button(
description="entity",
disabled=False,
button_style="danger",
tooltip="Click to delete this row",
icon="trash",
)
delete_button.on_click(lambda b: self.delete_ligand_ion_ccd_row(index))
self.ligand_smiles_entities.append(widgets.HBox([item, delete_button]))
self.ligand_smiles_entities_id.append(index)
self.update()
def add_ligand_ion_ccd_callback(self, b):
self.ligand_smiles_entities_count += 1
index = self.ligand_smiles_entities_count
item = LigandIonCCDEntityWidget()
delete_button = widgets.Button(
description="entity",
disabled=False,
button_style="danger",
tooltip="Click to delete this row",
icon="trash",
)
delete_button.on_click(lambda b: self.delete_ligand_ion_ccd_row(index))
self.ligand_smiles_entities.append(widgets.HBox([item, delete_button]))
self.ligand_smiles_entities_id.append(index)
self.update()
def delete_ligand_ion_ccd_row(self, index):
for idx in range(len(self.ligand_smiles_entities_id)):
if self.ligand_smiles_entities_id[idx] == index:
self.ligand_smiles_entities.pop(idx)
self.ligand_smiles_entities_id.pop(idx)
break
self.update()
def add_dna_rna_protein_callback(self, b):
self.dna_rna_protein_entities_count += 1
index = self.dna_rna_protein_entities_count
item = DnaRnaProteinEntityWidget()
delete_button = widgets.Button(
description="entity",
disabled=False,
button_style="danger",
tooltip="Click to delete this row",
icon="trash",
)
delete_button.on_click(lambda b: self.delete_dna_rna_protein_row(index))
self.dna_rna_protein_entities.append(widgets.HBox([item, delete_button]))
self.dna_rna_protein_entities_id.append(index)
self.update()
def delete_dna_rna_protein_row(self, index):
for idx in range(len(self.dna_rna_protein_entities_id)):
if self.dna_rna_protein_entities_id[idx] == index:
self.dna_rna_protein_entities.pop(idx)
self.dna_rna_protein_entities_id.pop(idx)
break
self.update()
def update(self):
self.children = (
[widgets.HBox([self.name, self.use_msa, self.atom_confidence])]
+ [item for item in self.dna_rna_protein_entities]
+ [item for item in self.ligand_smiles_entities]
+ [self.covalent_bonds]
+ [
widgets.HBox(
[
self.add_dna_rna_protein,
self.add_ligand_smiles,
self.add_ligand_ion_ccd,
]
)
]
+ [self.model_parameter]
)
def get_result(self):
result = dict()
result["name"] = self.name.value
assert len(result["name"]) != 0, "name is empty"
result["use_msa"] = self.use_msa.value
result["atom_confidence"] = self.atom_confidence.value
result["sequences"] = [
item.children[0].get_result() for item in self.dna_rna_protein_entities
] + [item.children[0].get_result() for item in self.ligand_smiles_entities]
result["covalent_bonds"] = self.covalent_bonds.get_result()
model_paramter = self.model_parameter.get_result()
for key in model_paramter:
result[key] = model_paramter[key]
entity_num = len(result["sequences"])
for covalent_bond in result["covalent_bonds"]:
assert (
covalent_bond["left_entity"] <= entity_num
and covalent_bond["left_entity"] > 0
), "covalent bond index out of range"
assert (
covalent_bond["right_entity"] <= entity_num
and covalent_bond["right_entity"] > 0
), "covalent bond index out of range"
assert len(covalent_bond["left_atom"]) != 0, "left atom is empty"
assert len(covalent_bond["right_atom"]) != 0, "right_atom atom is empty"
left_entiy = result["sequences"][covalent_bond["left_entity"] - 1]
left_entiy_seq = left_entiy[[item for item in left_entiy.keys()][0]]
right_entity = result["sequences"][covalent_bond["right_entity"] - 1]
right_entity_seq = right_entity[[item for item in right_entity.keys()][0]]
assert (
covalent_bond["left_position"] <= len(left_entiy_seq)
and covalent_bond["left_position"] >= 1
), "left position out of range"
assert (
covalent_bond["right_position"] <= len(right_entity_seq)
and covalent_bond["right_position"] >= 1
), "right position out of range"
return result