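# ProteinGenesis / app.py
# Committed by aiqcamp -- last commit "Update app.py" (cebc63d, verified), 36.4 kB.
# (These lines were page chrome copied from the Hugging Face file viewer:
# file header, avatar caption and the "raw" / "history" / "blame" links. They
# are kept as comments so the module parses.)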
import os,sys
# Install runtime dependencies at start-up, before the third-party imports
# below. The exit status of each os.system call is not checked.
os.system('pip install plotly')  # install plotly
os.system('pip install matplotlib')  # install matplotlib
# dgl 1.0.2 wheel built for CUDA 11.6, fetched from DGL's wheel index
os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
# select PyTorch as DGL's backend
os.environ["DGLBACKEND"] = "pytorch"
print('Modules installed')
import plotly.graph_objects as go
import numpy as np
import gradio as gr
import py3Dmol
from io import StringIO
import json
import secrets
import copy
import matplotlib.pyplot as plt
from utils.sampler import HuggingFace_sampler
from utils.parsers_inference import parse_pdb
from model.util import writepdb
from utils.inpainting_util import *
# install environment goods
# NOTE(review): this repeats the dgl install, the DGLBACKEND setting and the
# "Modules installed" message from the top of the file; the second pip run is
# redundant and only lengthens start-up.
os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
os.environ["DGLBACKEND"] = "pytorch"
print('Modules installed')
# Previously disabled install commands (removed): a cu113 dgl wheel, a
# PROTEIN_GENERATOR/requirements.txt install, pinned gradio 3.36.1 /
# gradio_client 0.2.7, and several numpy pins ("numpy<2", upgrade, 1.24.1).
# Fetch the two diffusion checkpoints on first start-up; files already present
# in the working directory are reused. A failed wget is reported instead of
# printing a false "Successfully Downloaded". Any partial file it left behind
# is removed, so the exists() check does not skip the download next time.
for _ckpt_num, _ckpt_file in enumerate(
        ('SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt',
         'SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'),
        start=1):
    if os.path.exists('./' + _ckpt_file):
        continue
    print(f'Downloading model weights {_ckpt_num}')
    _status = os.system(
        'wget http://files.ipd.uw.edu/pub/sequence_diffusion/checkpoints/' + _ckpt_file)
    if _status == 0:
        print('Successfully Downloaded')
    else:
        if os.path.exists(_ckpt_file):
            os.remove(_ckpt_file)
        print(f'Failed to download {_ckpt_file} (os.system status {_status})')
import numpy as np
import gradio as gr
import py3Dmol
from io import StringIO
import json
import secrets
import copy
import matplotlib.pyplot as plt
from utils.sampler import HuggingFace_sampler
from utils.parsers_inference import parse_pdb
from model.util import writepdb
from utils.inpainting_util import *
plt.rcParams.update({'font.size': 13})

# Base sampler configuration. protein_diffusion_model deep-copies it on every
# call and overrides T and the *_bias values per request.
with open('./tmp/args.json','r') as f:
    args = json.load(f)

# manually set checkpoint to load, plus output and sampling defaults
args.update(
    checkpoint=None,
    dump_trb=False,
    dump_args=True,
    save_best_plddt=True,
    T=25,
    strand_bias=0.0,
    loop_bias=0.0,
    helix_bias=0.0,
)
def _fit_secondary_structure(secondary_structure, L):
    """Normalise a user-supplied secondary-structure string.

    '' or None yields None. Otherwise every 'S' is rewritten to 'E'
    (presumably the model's strand code; the UI lets users type 'S').

    When the design length ``L`` is known, the string is truncated to ``L``
    or padded to ``L`` by repeating its last character. ``L`` is None in
    contig mode, where the length is not computed here; the string is then
    returned unchanged apart from the S->E mapping.
    """
    if secondary_structure in ['', None]:
        return None
    ss = ''.join('E' if x == 'S' else x for x in secondary_structure)
    if L is None:
        return ss
    if len(ss) >= L:
        return ss[:L]
    return ss + ss[-1] * (L - len(ss))


def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
                            secondary_structure, aa_bias, aa_bias_potential,
                            num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
                            contigs, pssm, seq_mask, str_mask, rewrite_pdb):
    """Run the sequence-diffusion sampler, streaming intermediate results.

    This is a generator. After every diffusion step it yields
    ``(output_seq, output_pdb_path, viewer_html, plddt_figure)``. It then
    yields once more with the sampler's final outputs (``S.get_outputs()``).
    A plain ``return value`` would be discarded by anything iterating the
    generator.

    Design mode, in priority order:
      1. An explicit ``sequence``. The codes B/J/U/Z/O are replaced by a
         random residue from ``alt_aa_dict``.
      2. ``contigs``.
      3. Otherwise, an unconstrained design of ``int(seq_len)`` residues.

    The DSSP checkpoint is used when a secondary structure is given or
    helix+strand+loop bias > 0; otherwise the original checkpoint is used.
    Optional inputs equal to '' or None are treated as "not provided". File
    inputs (``rewrite_pdb``, ``pssm``) may be path strings or objects
    exposing ``.name``.
    """
    dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
    og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'

    # private copy so per-request settings never mutate the module-level args
    model_args = copy.deepcopy(args)

    # make sampler
    S = HuggingFace_sampler(args=model_args)
    # random output prefix (20 upper-case hex chars) under ./tmp/
    S.out_prefix = './tmp/'+secrets.token_hex(nbytes=10).upper()

    # reset request-dependent args to neutral defaults
    S.args['checkpoint'] = None
    S.args['dump_trb'] = False
    S.args['dump_args'] = True
    S.args['save_best_plddt'] = True
    S.args['T'] = 20
    S.args['strand_bias'] = 0.0
    S.args['loop_bias'] = 0.0
    S.args['helix_bias'] = 0.0
    S.args['potentials'] = None
    S.args['potential_scale'] = None
    S.args['aa_composition'] = None

    # --- design mode ------------------------------------------------------
    # one-letter codes mapped to the residues they may stand for
    alt_aa_dict = {'B':['D','N'],'J':['I','L'],'U':['C'],'Z':['E','Q'],'O':['K']}
    L = None  # design length; stays None in contig mode
    if sequence not in ['',None]:
        L = len(sequence)
        aa_seq = []
        for aa in sequence.upper():
            if aa in alt_aa_dict:
                aa_seq.append(np.random.choice(alt_aa_dict[aa]))
            else:
                aa_seq.append(aa)
        S.args['sequence'] = aa_seq
    elif contigs not in ['',None]:
        S.args['contigs'] = [contigs]
    else:
        L = int(seq_len)
        # build the contig from the int so a float slider value (100.0) gives "100"
        S.args['contigs'] = [f'{L}']

    print('DEBUG: ',rewrite_pdb)
    if rewrite_pdb not in ['',None]:
        S.args['pdb'] = getattr(rewrite_pdb, 'name', rewrite_pdb)
    if seq_mask not in ['',None]:
        S.args['inpaint_seq'] = [seq_mask]
    if str_mask not in ['',None]:
        S.args['inpaint_str'] = [str_mask]

    secondary_structure = _fit_secondary_structure(secondary_structure, L)

    # --- guiding potentials -------------------------------------------------
    potential_list = []
    potential_bias_list = []
    if aa_bias not in ['',None]:
        potential_list.append('aa_bias')
        S.args['aa_composition'] = aa_bias
        if aa_bias_potential in ['',None]:
            aa_bias_potential = 3
        potential_bias_list.append(str(aa_bias_potential))
    # (A charge potential -- target charge / pH -- was disabled here; this
    # function takes no charge inputs.)
    if hydrophobic_target_score not in ['',None]:
        potential_list.append('hydrophobic')
        S.args['hydrophobic_score'] = float(hydrophobic_target_score)
        if hydrophobic_potential in ['',None]:
            hydrophobic_potential = 3
        potential_bias_list.append(str(hydrophobic_potential))
    if pssm not in ['',None]:
        potential_list.append('PSSM')
        potential_bias_list.append('5')
        S.args['PSSM'] = getattr(pssm, 'name', pssm)
    if len(potential_list) > 0:
        S.args['potentials'] = ','.join(potential_list)
        S.args['potential_scale'] = ','.join(potential_bias_list)

    # secondary-structure conditioning (biases are passed through unscaled)
    S.args['secondary_structure'] = secondary_structure
    S.args['helix_bias'] = helix_bias
    S.args['strand_bias'] = strand_bias
    S.args['loop_bias'] = loop_bias

    # number of diffusion steps
    S.args['T'] = 20 if num_steps in ['',None] else int(num_steps)

    # noise distribution: standard normal or a 2/3-component mixture
    if 'normal' in noise:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = [0]
        S.args['sample_distribution_gmm_variances'] = [1]
    elif 'gmm2' in noise:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = [-1,1]
        S.args['sample_distribution_gmm_variances'] = [1,1]
    elif 'gmm3' in noise:
        S.args['sample_distribution'] = noise
        S.args['sample_distribution_gmm_means'] = [-1,0,1]
        S.args['sample_distribution_gmm_variances'] = [1,1,1]

    if secondary_structure not in ['',None] or helix_bias+strand_bias+loop_bias > 0:
        S.args['checkpoint'] = dssp_checkpoint
        S.args['d_t1d'] = 29
        print('using dssp checkpoint')
    else:
        S.args['checkpoint'] = og_checkpoint
        S.args['d_t1d'] = 24
        print('using og checkpoint')

    for k,v in S.args.items():
        print(f"{k} --> {v}")

    # init S
    S.model_init()
    S.diffuser_init()
    S.setup()

    # sampling loop: stream each step's result
    plddt_data = []
    for j in range(S.max_t):
        print(f'on step {j}')
        output_seq, output_pdb, plddt = S.take_step_get_outputs(j)
        plddt_data.append(plddt)
        yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)

    # Final outputs are yielded, not returned: a generator's return value never
    # reaches consumers that iterate it (Gradio, combined_generation).
    output_seq, output_pdb, plddt = S.get_outputs()
    yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
def get_plddt_plot(plddt_data, max_t):
    """Plot model confidence (pLDDT) for each diffusion step completed so far.

    Args:
        plddt_data: one confidence value per completed step, in step order.
        max_t: total number of diffusion steps; sets the x-axis ticks.

    Returns:
        A matplotlib ``Figure`` suitable for a ``gr.Plot`` output.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots``, so it is not registered with pyplot's global figure
    manager. This runs on every diffusion step, and pyplot figures are never
    freed unless explicitly closed. It also avoids shared pyplot "current
    figure" state between concurrent requests.
    """
    from matplotlib.figure import Figure

    steps = list(range(1, len(plddt_data) + 1))
    fig = Figure(figsize=(15, 6))
    ax = fig.subplots()
    ax.plot(steps, plddt_data, color='#661dbf', linewidth=3, marker='o')
    ax.set_xticks(list(range(1, max_t + 1)))
    ax.set_yticks([(i + 1) / 10 for i in range(10)])
    ax.set_ylim([0, 1])
    ax.set_ylabel('model confidence (plddt)')
    ax.set_xlabel('diffusion steps (t)')
    return fig
def display_pdb(path_to_pdb):
    """Render a PDB file as an embeddable py3Dmol viewer.

    The structure is drawn as a cartoon coloured by the B-factor column, on a
    'roygb' gradient over 0-1. Presumably the writer stores per-residue
    confidence there -- TODO confirm.

    Args:
        path_to_pdb: path of the PDB file to show.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds the viewer page.
    """
    with open(path_to_pdb, "r") as fh:
        pdb = fh.read()
    view = py3Dmol.view(width=500, height=500)
    view.addModel(pdb, "pdb")
    view.setStyle({'model': -1}, {"cartoon": {'colorscheme': {'prop': 'b', 'gradient': 'roygb', 'min': 0, 'max': 1}}})
    view.zoomTo()
    # srcdoc below is delimited by single quotes, so none may appear inside it
    output = view._make_html().replace("'", '"')
    # NOTE(review): the "opening" tag is written as </center>. Browsers drop
    # stray end tags, so this wrapper does not actually centre the viewer.
    x = f"""<!DOCTYPE html><html></center> {output} </center></html>"""  # do not use ' in this input
    return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
def get_motif_preview(pdb_id, contigs):
    """Preview which residues of a PDB entry a contig string refers to.

    Downloads ``pdb_id`` via ``fetch_pdb`` and renders it with py3Dmol.
    Residues referenced by ``contigs`` (via ContigMap / get_mappings) are
    drawn pink; all others are grey.

    Returns:
        ``(viewer_iframe_html, pdb_path)`` on success, or
        ``(gr.HTML(message), None)`` when no ID is given or anything fails.
    """
    try:
        input_pdb = fetch_pdb(pdb_id=pdb_id.lower() if pdb_id else None)
        if input_pdb is None:
            return gr.HTML("PDB ID๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”"), None
        parse = parse_pdb(input_pdb)
        output_name = input_pdb
        with open(output_name, "r") as fh:
            pdb = fh.read()
        view = py3Dmol.view(width=500, height=500)
        view.addModel(pdb, "pdb")
        # no contig string (including None) -> fall back to the contig '0'
        contigs = ['0'] if contigs in ['', 0, None] else [contigs]
        print('DEBUG: ',contigs)
        pdb_map = get_mappings(ContigMap(parse,contigs))
        print('DEBUG: ',pdb_map)
        print('DEBUG: ',pdb_map['con_ref_idx0'])
        # Residues kept by the contigs. x[1] is shifted by -1 to match the
        # 0-based loop index below; this assumes the PDB numbering starts at 1
        # without gaps -- TODO confirm. A set gives O(1) membership tests.
        roi = {x[1]-1 for x in pdb_map['con_ref_pdb_idx']}
        colormap = {0:'#D3D3D3', 1:'#F74CFF'}
        colors = {i+1: colormap[1] if i in roi else colormap[0] for i in range(parse['xyz'].shape[0])}
        view.setStyle({"cartoon": {"colorscheme": {"prop": "resi", "map": colors}}})
        view.zoomTo()
        # srcdoc below is delimited by single quotes, so none may appear inside it
        output = view._make_html().replace("'", '"')
        x = f"""<!DOCTYPE html><html></center> {output} </center></html>"""  # do not use ' in this input
        return f"""<iframe height="500px" width="100%" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""", output_name
    except Exception as e:
        return gr.HTML(f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"), None
def fetch_pdb(pdb_id=None):
    """Download ``{pdb_id}.pdb`` from RCSB into the working directory.

    ``wget -nc`` skips the download when the file already exists. The
    expected filename is returned even if the download failed; callers find
    out when they open it.

    Args:
        pdb_id: PDB identifier, e.g. "1dpx". It comes from a UI textbox.

    Returns:
        The local filename, or None when ``pdb_id`` is None or "".

    Raises:
        ValueError: if ``pdb_id`` is not purely alphanumeric. The ID is
            placed in a URL and a filename, and was previously interpolated
            into a shell command.
    """
    import re
    import subprocess

    if pdb_id is None or pdb_id == "":
        return None
    if not re.fullmatch(r"[A-Za-z0-9]+", pdb_id):
        raise ValueError(f"invalid PDB ID: {pdb_id!r}")
    # argument list, no shell: the ID can never be parsed as shell syntax
    subprocess.run(["wget", "-qnc", f"https://files.rcsb.org/view/{pdb_id}.pdb"], check=False)
    return f"{pdb_id}.pdb"
# MSA AND PSSM GUIDANCE
def save_pssm(file_upload):
    """Return a PSSM CSV path for an uploaded MSA or PSSM file.

    Uploads ending in ``.fasta`` or ``.a3m`` (case-insensitive) are converted
    with ``msa_to_pssm``. Any other upload is returned as-is; get_pssm then
    loads it as a comma-separated PSSM.

    Args:
        file_upload: upload object exposing ``.name`` (its path).

    Returns:
        Path of a PSSM CSV.
    """
    filename = file_upload.name
    # (the old unused `file_upload.orig_name` read raised AttributeError on
    # upload objects without that attribute)
    if filename.rsplit('.', 1)[-1].lower() in ('fasta', 'a3m'):
        return msa_to_pssm(file_upload)
    return filename
def msa_to_pssm(msa_file):
    """Convert an uploaded MSA (FASTA/A3M) into a per-position frequency PSSM.

    The first record is the reference, and each other record is globally
    aligned to it:
      * Records whose aligned form is more than 40% gaps are skipped.
      * Columns where the reference has a gap are dropped, so every kept
        sequence is indexed by reference position.

    Each cell is a residue count divided by the number of kept sequences
    (reference included). The 'X' row is zeroed and the gap row is not
    written.

    Args:
        msa_file: upload object exposing ``.name`` (path to the MSA).

    Returns:
        Path of the written CSV: the input path with its extension replaced
        by ``.csv``. It has one row per reference position and 21 columns.

    Raises:
        ValueError: if the file contains no sequences.

    NOTE(review): ``SeqIO`` and ``Align`` (Biopython) are not among this
    file's visible imports. Confirm they come from the
    ``utils.inpainting_util`` star import; otherwise this raises NameError.

    NOTE(review): the ``alignment.format()`` parsing below assumes the
    three-line (reference / match / query) layout. Newer Biopython releases
    may print labelled, coordinate-annotated and wrapped lines; verify this
    against the installed version.
    """
    # amino acid -> row index in the count matrix
    aa_to_index = {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10,
                   'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20, '-': 21}

    records = list(SeqIO.parse(msa_file.name, "fasta"))
    if not records:
        raise ValueError("MSA must contain at least one protein sequence.")
    first_seq = str(records[0].seq)
    aligned_seqs = [first_seq]

    # global pairwise alignment (PairwiseAligner's default mode) to the reference
    aligner = Align.PairwiseAligner()
    aligner.open_gap_score = -0.7
    aligner.extend_gap_score = -0.3
    for record in records[1:]:
        alignment = aligner.align(first_seq, str(record.seq))[0]
        alignment = alignment.format().split("\n")
        al1 = alignment[0]  # reference row
        al2 = alignment[2]  # this record's row
        percent_gap = al2.count('-') / len(al2)
        if percent_gap > 0.4:
            continue
        # keep only columns where the reference has a residue
        aligned_seqs.append(''.join(b for a, b in zip(al1, al2) if a != '-'))

    aligned_seq_length = len(first_seq)
    matrix = np.zeros((22, aligned_seq_length))
    # count residues per reference position; unknown characters are ignored
    for seq in aligned_seqs:
        for i, amino_acid in enumerate(seq[:aligned_seq_length]):
            aa_index = aa_to_index.get(amino_acid.upper())
            if aa_index is not None:
                matrix[aa_index, i] += 1

    # counts -> frequencies
    matrix /= len(aligned_seqs)
    print(len(aligned_seqs))
    matrix[20:,] = 0  # zero the 'X' and gap rows
    outdir = ".".join(msa_file.name.split('.')[:-1]) + ".csv"
    np.savetxt(outdir, matrix[:21,:].T, delimiter=",")
    return outdir
def get_pssm(fasta_msa, input_pssm):
    """Load a PSSM (or build one from an MSA) and plot it as a heat map.

    An uploaded PSSM CSV (``input_pssm``) takes priority. Otherwise the MSA
    upload is converted via ``save_pssm``.

    Returns:
        ``(figure, csv_path)``, or ``(gr.Plot(label=message), None)`` when
        nothing was uploaded or anything fails.
    """
    from matplotlib.figure import Figure

    try:
        if input_pssm is not None:
            # upload values may be path strings or objects exposing .name
            outdir = getattr(input_pssm, 'name', input_pssm)
        elif fasta_msa is not None:
            outdir = save_pssm(fasta_msa)
        else:
            return gr.Plot(label="ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”"), None
        # ndmin=2 keeps a single-row CSV two-dimensional for the transpose below
        pssm = np.loadtxt(outdir, delimiter=",", dtype=float, ndmin=2)
        # Standalone Figure: no global pyplot state, nothing left open to leak
        fig = Figure(figsize=(15,6))
        ax = fig.subplots()
        # CSV rows are positions (as written by msa_to_pssm); transpose so
        # positions run along the x-axis
        ax.imshow(pssm.T)
        return fig, outdir
    except Exception as e:
        return gr.Plot(label=f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"), None
# Hero stat calculation
def calculate_hero_stats(helix_bias, strand_bias, loop_bias, hydrophobic_score,
                         max_hydrophobic_score=10.0):
    """Map the design settings onto hero stats normalised to [0, 1].

    Each bias is divided by its UI slider maximum (strand/helix 0.05, loop
    0.20), which gives the [0, 1] scale that the radar chart (radial range
    [0, 1]) and the 5-star rating expect. The hydrophobic score's magnitude
    is normalised the same way by ``max_hydrophobic_score``; the defense
    slider spans -10..10. The sign is dropped.

    Returns:
        dict with keys 'strength', 'flexibility', 'speed', 'defense', in that
        order.
    """
    return {
        'strength': strand_bias * 20,    # from the beta-sheet (strand) bias
        'flexibility': helix_bias * 20,  # from the alpha-helix bias
        'speed': loop_bias * 5,          # from the loop bias
        'defense': abs(hydrophobic_score) / max_hydrophobic_score if hydrophobic_score else 0,
    }
def toggle_seq_input(choice):
    """Swap between the length slider and the free-text sequence box.

    Returns two Gradio updates, for (seq_len, sequence) in that order. The
    slider is shown for the "automatic design" option; the sequence textbox
    is shown for any other choice.
    """
    automatic = choice == "์ž๋™ ์„ค๊ณ„"
    return gr.update(visible=automatic), gr.update(visible=not automatic)
def toggle_secondary_structure(choice):
    """Switch between slider-based and typed secondary-structure input.

    Returns four Gradio updates, for (helix_bias, strand_bias, loop_bias,
    secondary_structure). The three bias sliders are shown for the "set with
    sliders" option; the free-text structure box is shown otherwise.
    """
    use_sliders = choice == "์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •"
    slider_updates = tuple(gr.update(visible=use_sliders) for _ in range(3))
    return slider_updates + (gr.update(visible=not use_sliders),)
def create_radar_chart(stats):
    """Draw the hero's stats as a filled plotly radar (polar) chart.

    There is one spoke per key of ``stats``, in dict order. The radial axis
    is fixed to [0, 1], so values outside that range fall outside the
    plotted scale.
    """
    trace = go.Scatterpolar(
        r=[value for value in stats.values()],
        theta=[label for label in stats],
        fill='toself',
    )
    fixed_radial_axis = dict(radialaxis=dict(visible=True, range=[0, 1]))

    chart = go.Figure(data=trace)
    chart.update_layout(polar=fixed_radial_axis, showlegend=False)
    return chart
def generate_hero_description(name, stats, abilities):
    """Build the plain-text hero profile shown in the UI.

    Args:
        name: hero name.
        stats: dict with 'strength', 'flexibility', 'speed' and 'defense',
            each expected in [0, 1].
        abilities: selected special abilities; None means none were selected.

    Returns:
        The profile text. Each stat becomes 0-5 stars (``int(value * 5)``,
        clamped so an out-of-range stat cannot produce a runaway row).
    """
    def _stars(value):
        return 'โ˜…' * max(0, min(5, int(value * 5)))

    description = f"""
ํžˆ์–ด๋กœ ์ด๋ฆ„: {name}
์ฃผ์š” ๋Šฅ๋ ฅ:
- ๊ทผ๋ ฅ: {_stars(stats['strength'])}
- ์œ ์—ฐ์„ฑ: {_stars(stats['flexibility'])}
- ์Šคํ”ผ๋“œ: {_stars(stats['speed'])}
- ๋ฐฉ์–ด๋ ฅ: {_stars(stats['defense'])}
ํŠน์ˆ˜ ๋Šฅ๋ ฅ: {', '.join(abilities or [])}
"""
    return description
def combined_generation(name, strength, flexibility, speed, defense, size, abilities,
                        sequence, seq_len, helix_bias, strand_bias, loop_bias,
                        secondary_structure, aa_bias, aa_bias_potential,
                        num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
                        contigs, pssm, seq_mask, str_mask, rewrite_pdb):
    """Generate a protein from the hero sliders and build the hero profile.

    Hero settings map onto diffusion settings:
      * size -> sequence length
      * flexibility -> helix bias
      * strength -> strand bias
      * speed -> loop bias
      * -defense -> hydrophobic target score

    NOTE(review): the advanced-design parameters (``sequence`` through
    ``rewrite_pdb``) are accepted so that every tab's button can share this
    handler, but they are currently ignored. The model call below hard-codes
    them.

    Returns:
        A 6-tuple for (radar chart, description, sequence, PDB file,
        structure viewer HTML, pLDDT plot). On failure the description slot
        carries the error message and the remaining slots are cleared.
    """
    import traceback

    try:
        generator = protein_diffusion_model(
            sequence=None,
            seq_len=int(size),        # hero size -> sequence length (int: the contig is built from it)
            helix_bias=flexibility,   # hero flexibility -> helix bias
            strand_bias=strength,     # hero strength -> strand bias
            loop_bias=speed,          # hero speed -> loop bias
            secondary_structure=None,
            aa_bias=None,
            aa_bias_potential=None,
            num_steps="25",
            noise="normal",
            hydrophobic_target_score=str(-defense),  # hero defense -> negated hydrophobic target
            hydrophobic_potential="2",
            contigs=None,
            pssm=None,
            seq_mask=None,
            str_mask=None,
            rewrite_pdb=None,
        )

        # Drain the generator; only its last yielded result is displayed.
        final_result = None
        for result in generator:
            final_result = result
        if final_result is None:
            raise RuntimeError("์ƒ์„ฑ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")
        seq_out, pdb_out, viewer_html, plddt_fig = final_result

        stats = calculate_hero_stats(flexibility, strength, speed, defense)
        return (
            create_radar_chart(stats),                                # stat chart
            generate_hero_description(name, stats, abilities or []),  # hero description
            seq_out,                                                  # protein sequence
            pdb_out,                                                  # PDB file
            viewer_html,                                              # 3D structure
            plddt_fig,                                                # confidence chart
        )
    except Exception as e:
        print(f"Error in combined_generation: {str(e)}")
        traceback.print_exc()  # full stack in the logs; the UI only shows the message
        return (
            None,
            f"์—๋Ÿฌ: {str(e)}",
            None,
            None,
            gr.HTML("์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค"),
            None,
        )
# Gradio UI. The left column holds four input tabs (hero design, advanced
# design, motif grafting, MSA/PSSM guidance); the right column shows the
# results. Every tab's generate button runs combined_generation with the
# same inputs and outputs.
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("# ๐Ÿฆธโ€โ™‚๏ธ ์Šˆํผํžˆ์–ด๋กœ ๋‹จ๋ฐฑ์งˆ ๋งŒ๋“ค๊ธฐ")
            with gr.Tabs():
                with gr.TabItem("๐Ÿฆธโ€โ™‚๏ธ ํžˆ์–ด๋กœ ๋””์ž์ธ"):
                    gr.Markdown("""
### โœจ ๋‹น์‹ ๋งŒ์˜ ํŠน๋ณ„ํ•œ ํžˆ์–ด๋กœ๋ฅผ ๋งŒ๋“ค์–ด๋ณด์„ธ์š”!
๊ฐ ๋Šฅ๋ ฅ์น˜๋ฅผ ์กฐ์ ˆํ•˜๋ฉด ํžˆ์–ด๋กœ์˜ DNA๊ฐ€ ์ž๋™์œผ๋กœ ์„ค๊ณ„๋ฉ๋‹ˆ๋‹ค.
""")
                    # hero basic info
                    hero_name = gr.Textbox(
                        label="ํžˆ์–ด๋กœ ์ด๋ฆ„",
                        placeholder="๋‹น์‹ ์˜ ํžˆ์–ด๋กœ ์ด๋ฆ„์„ ์ง€์–ด์ฃผ์„ธ์š”!",
                        info="ํžˆ์–ด๋กœ์˜ ์ •์ฒด์„ฑ์„ ๋‚˜ํƒ€๋‚ด๋Š” ์ด๋ฆ„์„ ์ž…๋ ฅํ•˜์„ธ์š”"
                    )
                    # stat sliders (ranges mirror the helix/strand/loop bias sliders)
                    gr.Markdown("### ๐Ÿ’ช ํžˆ์–ด๋กœ ๋Šฅ๋ ฅ์น˜ ์„ค์ •")
                    with gr.Row():
                        strength = gr.Slider(
                            minimum=0.0, maximum=0.05,
                            label="๐Ÿ’ช ์ดˆ๊ฐ•๋ ฅ(๊ทผ๋ ฅ)",
                            value=0.02,
                            info="๋‹จ๋‹จํ•œ ๋ฒ ํƒ€์‹œํŠธ ๊ตฌ์กฐ๋กœ ๊ฐ•๋ ฅํ•œ ํž˜์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค"
                        )
                        flexibility = gr.Slider(
                            minimum=0.0, maximum=0.05,
                            label="๐Ÿคธโ€โ™‚๏ธ ์œ ์—ฐ์„ฑ",
                            value=0.02,
                            info="๋‚˜์„ ํ˜• ์•ŒํŒŒํ—ฌ๋ฆญ์Šค ๊ตฌ์กฐ๋กœ ์œ ์—ฐํ•œ ์›€์ง์ž„์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค"
                        )
                    with gr.Row():
                        speed = gr.Slider(
                            minimum=0.0, maximum=0.20,
                            label="โšก ์Šคํ”ผ๋“œ",
                            value=0.1,
                            info="๋ฃจํ”„ ๊ตฌ์กฐ๋กœ ๋น ๋ฅธ ์›€์ง์ž„์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค"
                        )
                        defense = gr.Slider(
                            minimum=-10, maximum=10,
                            label="๐Ÿ›ก๏ธ ๋ฐฉ์–ด๋ ฅ",
                            value=0,
                            info="์Œ์ˆ˜: ์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”, ์–‘์ˆ˜: ์ง€์ƒ ํ™œ๋™์— ํŠนํ™”"
                        )
                    # hero size (used as the sequence length)
                    hero_size = gr.Slider(
                        minimum=50, maximum=200,
                        label="๐Ÿ“ ํžˆ์–ด๋กœ ํฌ๊ธฐ",
                        value=100,
                        info="ํžˆ์–ด๋กœ์˜ ์ „์ฒด์ ์ธ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค"
                    )
                    # special abilities
                    with gr.Accordion("๐ŸŒŸ ํŠน์ˆ˜ ๋Šฅ๋ ฅ", open=False):
                        gr.Markdown("""
ํŠน์ˆ˜ ๋Šฅ๋ ฅ์„ ์„ ํƒํ•˜๋ฉด ํžˆ์–ด๋กœ์˜ DNA์— ํŠน๋ณ„ํ•œ ๊ตฌ์กฐ๊ฐ€ ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค.
- ์ž๊ฐ€ ํšŒ๋ณต: ๋‹จ๋ฐฑ์งˆ ๊ตฌ์กฐ ๋ณต๊ตฌ ๋Šฅ๋ ฅ ๊ฐ•ํ™”
- ์›๊ฑฐ๋ฆฌ ๊ณต๊ฒฉ: ํŠน์ˆ˜ํ•œ ๊ตฌ์กฐ์  ๋Œ์ถœ๋ถ€ ํ˜•์„ฑ
- ๋ฐฉ์–ด๋ง‰ ์ƒ์„ฑ: ์•ˆ์ •์ ์ธ ๋ณดํ˜ธ์ธต ๊ตฌ์กฐ ์ƒ์„ฑ
""")
                        special_ability = gr.CheckboxGroup(
                            choices=["์ž๊ฐ€ ํšŒ๋ณต", "์›๊ฑฐ๋ฆฌ ๊ณต๊ฒฉ", "๋ฐฉ์–ด๋ง‰ ์ƒ์„ฑ"],
                            label="ํŠน์ˆ˜ ๋Šฅ๋ ฅ ์„ ํƒ"
                        )
                    # generate button
                    create_btn = gr.Button("๐Ÿงฌ ํžˆ์–ด๋กœ ์ƒ์„ฑ!", variant="primary", scale=2)

                with gr.TabItem("๐Ÿงฌ ํžˆ์–ด๋กœ DNA ์„ค๊ณ„"):
                    gr.Markdown("""
### ๐Ÿงช ํžˆ์–ด๋กœ DNA ๊ณ ๊ธ‰ ์„ค์ •
ํžˆ์–ด๋กœ์˜ ์œ ์ „์ž ๊ตฌ์กฐ๋ฅผ ๋” ์„ธ๋ฐ€ํ•˜๊ฒŒ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
""")
                    seq_opt = gr.Radio(
                        ["์ž๋™ ์„ค๊ณ„", "์ง์ ‘ ์ž…๋ ฅ"],
                        label="DNA ์„ค๊ณ„ ๋ฐฉ์‹",
                        value="์ž๋™ ์„ค๊ณ„"
                    )
                    sequence = gr.Textbox(
                        label="DNA ์‹œํ€€์Šค",
                        lines=1,
                        placeholder='์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์•„๋ฏธ๋…ธ์‚ฐ: A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y (X๋Š” ๋ฌด์ž‘์œ„)',
                        visible=False
                    )
                    seq_len = gr.Slider(
                        minimum=5.0, maximum=250.0,
                        label="DNA ๊ธธ์ด",
                        value=100,
                        visible=True
                    )
                    with gr.Accordion(label='๐Ÿฆด ๊ณจ๊ฒฉ ๊ตฌ์กฐ ์„ค์ •', open=True):
                        gr.Markdown("""
ํžˆ์–ด๋กœ์˜ ๊ธฐ๋ณธ ๊ณจ๊ฒฉ ๊ตฌ์กฐ๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
- ๋‚˜์„ ํ˜• ๊ตฌ์กฐ: ์œ ์—ฐํ•˜๊ณ  ํƒ„๋ ฅ์žˆ๋Š” ์›€์ง์ž„
- ๋ณ‘ํ’ํ˜• ๊ตฌ์กฐ: ๋‹จ๋‹จํ•˜๊ณ  ๊ฐ•๋ ฅํ•œ ํž˜
- ๊ณ ๋ฆฌํ˜• ๊ตฌ์กฐ: ๋น ๋ฅด๊ณ  ๋ฏผ์ฒฉํ•œ ์›€์ง์ž„
""")
                        sec_str_opt = gr.Radio(
                            ["์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •", "์ง์ ‘ ์ž…๋ ฅ"],
                            label="๊ณจ๊ฒฉ ๊ตฌ์กฐ ์„ค์ • ๋ฐฉ์‹",
                            value="์Šฌ๋ผ์ด๋”๋กœ ์„ค์ •"
                        )
                        secondary_structure = gr.Textbox(
                            label="๊ณจ๊ฒฉ ๊ตฌ์กฐ",
                            lines=1,
                            placeholder='H:๋‚˜์„ ํ˜•, S:๋ณ‘ํ’ํ˜•, L:๊ณ ๋ฆฌํ˜•, X:์ž๋™์„ค์ •',
                            visible=False
                        )
                        with gr.Column():
                            helix_bias = gr.Slider(
                                minimum=0.0, maximum=0.05,
                                label="๋‚˜์„ ํ˜• ๊ตฌ์กฐ ๋น„์œจ",
                                visible=True
                            )
                            strand_bias = gr.Slider(
                                minimum=0.0, maximum=0.05,
                                label="๋ณ‘ํ’ํ˜• ๊ตฌ์กฐ ๋น„์œจ",
                                visible=True
                            )
                            loop_bias = gr.Slider(
                                minimum=0.0, maximum=0.20,
                                label="๊ณ ๋ฆฌํ˜• ๊ตฌ์กฐ ๋น„์œจ",
                                visible=True
                            )
                    # amino-acid composition bias
                    with gr.Accordion(label='๐Ÿงฌ DNA ๊ตฌ์„ฑ ์„ค์ •', open=False):
                        gr.Markdown("""
ํŠน์ • ์•„๋ฏธ๋…ธ์‚ฐ์˜ ๋น„์œจ์„ ์กฐ์ ˆํ•˜์—ฌ ํžˆ์–ด๋กœ์˜ ํŠน์„ฑ์„ ๊ฐ•ํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
์˜ˆ์‹œ: W0.2,E0.1 (ํŠธ๋ฆฝํ† ํŒ 20%, ๊ธ€๋ฃจํƒ์‚ฐ 10%)
""")
                        with gr.Row():
                            aa_bias = gr.Textbox(
                                label="์•„๋ฏธ๋…ธ์‚ฐ ๋น„์œจ",
                                lines=1,
                                placeholder='์˜ˆ์‹œ: W0.2,E0.1'
                            )
                            aa_bias_potential = gr.Textbox(
                                label="๊ฐ•ํ™” ์ •๋„",
                                lines=1,
                                placeholder='1.0-5.0 ์‚ฌ์ด ๊ฐ’ ์ž…๋ ฅ'
                            )
                    # environment adaptation (hydrophobicity target)
                    with gr.Accordion(label='๐ŸŒ ํ™˜๊ฒฝ ์ ์‘๋ ฅ ์„ค์ •', open=False):
                        gr.Markdown("""
ํžˆ์–ด๋กœ์˜ ํ™˜๊ฒฝ ์ ์‘๋ ฅ์„ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.
์Œ์ˆ˜: ์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”, ์–‘์ˆ˜: ์ง€์ƒ ํ™œ๋™์— ํŠนํ™”
""")
                        with gr.Row():
                            hydrophobic_target_score = gr.Textbox(
                                label="ํ™˜๊ฒฝ ์ ์‘ ์ ์ˆ˜",
                                lines=1,
                                placeholder='์˜ˆ์‹œ: -5 (์ˆ˜์ค‘ ํ™œ๋™์— ํŠนํ™”)'
                            )
                            hydrophobic_potential = gr.Textbox(
                                label="์ ์‘๋ ฅ ๊ฐ•ํ™” ์ •๋„",
                                lines=1,
                                placeholder='1.0-2.0 ์‚ฌ์ด ๊ฐ’ ์ž…๋ ฅ'
                            )
                    # diffusion parameters
                    with gr.Accordion(label='โš™๏ธ ๊ณ ๊ธ‰ ์„ค์ •', open=False):
                        gr.Markdown("""
DNA ์ƒ์„ฑ ๊ณผ์ •์˜ ์„ธ๋ถ€ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
""")
                        with gr.Row():
                            num_steps = gr.Textbox(
                                label="์ƒ์„ฑ ๋‹จ๊ณ„",
                                lines=1,
                                placeholder='25 ์ดํ•˜ ๊ถŒ์žฅ'
                            )
                            noise = gr.Dropdown(
                                ['normal','gmm2 [-1,1]','gmm3 [-1,0,1]'],
                                label='๋…ธ์ด์ฆˆ ํƒ€์ž…',
                                value='normal'
                            )
                    design_btn = gr.Button("๐Ÿงฌ DNA ์„ค๊ณ„ ์ƒ์„ฑ!", variant="primary", scale=2)

                with gr.TabItem("๐Ÿงช ํžˆ์–ด๋กœ ์œ ์ „์ž ๊ฐ•ํ™”"):
                    gr.Markdown("""
### โšก ๊ธฐ์กด ํžˆ์–ด๋กœ์˜ DNA ํ™œ์šฉ
๊ฐ•๋ ฅํ•œ ํžˆ์–ด๋กœ์˜ DNA ์ผ๋ถ€๋ฅผ ์ƒˆ๋กœ์šด ํžˆ์–ด๋กœ์—๊ฒŒ ์ด์‹ํ•ฉ๋‹ˆ๋‹ค.
""")
                    gr.Markdown("๊ณต๊ฐœ๋œ ํžˆ์–ด๋กœ DNA ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ์ฝ”๋“œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค")
                    pdb_id_code = gr.Textbox(
                        label="ํžˆ์–ด๋กœ DNA ์ฝ”๋“œ",
                        lines=1,
                        placeholder='๊ธฐ์กด ํžˆ์–ด๋กœ์˜ DNA ์ฝ”๋“œ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š” (์˜ˆ: 1DPX)'
                    )
                    gr.Markdown("์ด์‹ํ•˜๊ณ  ์‹ถ์€ DNA ์˜์—ญ์„ ์„ ํƒํ•˜๊ณ  ์ƒˆ๋กœ์šด DNA๋ฅผ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค")
                    contigs = gr.Textbox(
                        label="์ด์‹ํ•  DNA ์˜์—ญ",
                        lines=1,
                        placeholder='์˜ˆ์‹œ: 15,A3-10,20-30'
                    )
                    with gr.Row():
                        seq_mask = gr.Textbox(
                            label='๋Šฅ๋ ฅ ์žฌ์„ค๊ณ„',
                            lines=1,
                            placeholder='์„ ํƒํ•œ ์˜์—ญ์˜ ๋Šฅ๋ ฅ์„ ์ƒˆ๋กญ๊ฒŒ ๋””์ž์ธ'
                        )
                        str_mask = gr.Textbox(
                            label='๊ตฌ์กฐ ์žฌ์„ค๊ณ„',
                            lines=1,
                            placeholder='์„ ํƒํ•œ ์˜์—ญ์˜ ๊ตฌ์กฐ๋ฅผ ์ƒˆ๋กญ๊ฒŒ ๋””์ž์ธ'
                        )
                    preview_viewer = gr.HTML()
                    rewrite_pdb = gr.File(label='ํžˆ์–ด๋กœ DNA ํŒŒ์ผ')
                    preview_btn = gr.Button("๐Ÿ” ๋ฏธ๋ฆฌ๋ณด๊ธฐ", variant="secondary")
                    enhance_btn = gr.Button("โšก ๊ฐ•ํ™”๋œ ํžˆ์–ด๋กœ ์ƒ์„ฑ!", variant="primary", scale=2)

                with gr.TabItem("๐Ÿ‘‘ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ"):
                    gr.Markdown("""
### ๐Ÿฐ ์œ„๋Œ€ํ•œ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ ์œ ์‚ฐ
๊ฐ•๋ ฅํ•œ ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ ํŠน์„ฑ์„ ๊ณ„์Šนํ•˜์—ฌ ์ƒˆ๋กœ์šด ํžˆ์–ด๋กœ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
""")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("ํžˆ์–ด๋กœ ๊ฐ€๋ฌธ์˜ DNA ์ •๋ณด๊ฐ€ ๋‹ด๊ธด ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜์„ธ์š”")
                            fasta_msa = gr.File(label='๊ฐ€๋ฌธ DNA ๋ฐ์ดํ„ฐ')
                        with gr.Column():
                            gr.Markdown("์ด๋ฏธ ๋ถ„์„๋œ ๊ฐ€๋ฌธ ํŠน์„ฑ ๋ฐ์ดํ„ฐ๊ฐ€ ์žˆ๋‹ค๋ฉด ์—…๋กœ๋“œํ•˜์„ธ์š”")
                            input_pssm = gr.File(label='๊ฐ€๋ฌธ ํŠน์„ฑ ๋ฐ์ดํ„ฐ')
                    pssm = gr.File(label='๋ถ„์„๋œ ๊ฐ€๋ฌธ ํŠน์„ฑ')
                    pssm_view = gr.Plot(label='๊ฐ€๋ฌธ ํŠน์„ฑ ๋ถ„์„ ๊ฒฐ๊ณผ')
                    pssm_gen_btn = gr.Button("โœจ ๊ฐ€๋ฌธ ํŠน์„ฑ ๋ถ„์„", variant="secondary")
                    inherit_btn = gr.Button("๐Ÿ‘‘ ๊ฐ€๋ฌธ์˜ ํž˜ ๊ณ„์Šน!", variant="primary", scale=2)

        # results column
        with gr.Column():
            gr.Markdown("## ๐Ÿฆธโ€โ™‚๏ธ ํžˆ์–ด๋กœ ํ”„๋กœํ•„")
            # stat radar chart
            hero_stats = gr.Plot(label="๋Šฅ๋ ฅ์น˜ ๋ถ„์„")
            # hero description
            hero_description = gr.Textbox(label="ํžˆ์–ด๋กœ ํŠน์„ฑ", lines=3)
            gr.Markdown("## ๐Ÿงฌ ํžˆ์–ด๋กœ DNA ๋ถ„์„ ๊ฒฐ๊ณผ")
            gr.Markdown("#### โšก DNA ์•ˆ์ •์„ฑ ์ ์ˆ˜")
            plddt_plot = gr.Plot(label='์•ˆ์ •์„ฑ ๋ถ„์„')
            gr.Markdown("#### ๐Ÿ“ DNA ์‹œํ€€์Šค")
            output_seq = gr.Textbox(label="DNA ์„œ์—ด")
            gr.Markdown("#### ๐Ÿ’พ DNA ๋ฐ์ดํ„ฐ")
            output_pdb = gr.File(label="DNA ํŒŒ์ผ")
            gr.Markdown("#### ๐Ÿ”ฌ DNA ๊ตฌ์กฐ")
            output_viewer = gr.HTML()

    # event wiring
    seq_opt.change(
        fn=toggle_seq_input,
        inputs=[seq_opt],
        outputs=[seq_len, sequence],
        queue=False
    )
    sec_str_opt.change(
        fn=toggle_secondary_structure,
        inputs=[sec_str_opt],
        outputs=[helix_bias, strand_bias, loop_bias, secondary_structure],
        queue=False
    )
    preview_btn.click(get_motif_preview,[pdb_id_code, contigs],[preview_viewer, rewrite_pdb])
    pssm_gen_btn.click(get_pssm,[fasta_msa,input_pssm],[pssm_view, pssm])

    # Each tab's generate button runs the same handler with the same
    # inputs/outputs, so they are registered in one loop instead of four
    # copies that could drift apart.
    generation_inputs = [
        hero_name, strength, flexibility, speed, defense, hero_size, special_ability,
        sequence, seq_len, helix_bias, strand_bias, loop_bias,
        secondary_structure, aa_bias, aa_bias_potential,
        num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
        contigs, pssm, seq_mask, str_mask, rewrite_pdb,
    ]
    generation_outputs = [
        hero_stats,
        hero_description,
        output_seq,
        output_pdb,
        output_viewer,
        plddt_plot,
    ]
    for generate_btn in (create_btn, design_btn, enhance_btn, inherit_btn):
        generate_btn.click(
            combined_generation,
            inputs=generation_inputs,
            outputs=generation_outputs,
        )

demo.queue()
demo.launch(debug=True)