import os,sys
from openai import OpenAI
import gradio as gr
# install required packages
os.system('pip install -q plotly')
os.system('pip install -q matplotlib')
os.system('pip install dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html')
os.environ["DGLBACKEND"] = "pytorch"
print('Modules installed')
# 필수 라이브러리 임포트
from datasets import load_dataset
import plotly.graph_objects as go
import numpy as np
import py3Dmol
from io import StringIO
import json
import secrets
import copy
import matplotlib.pyplot as plt
from utils.sampler import HuggingFace_sampler
from utils.parsers_inference import parse_pdb
from model.util import writepdb
from utils.inpainting_util import *
# Hugging Face 토큰 설정
ACCESS_TOKEN = os.getenv("HF_TOKEN")
if not ACCESS_TOKEN:
raise ValueError("HF_TOKEN not found in environment variables")
# OpenAI 클라이언트 설정 (Hugging Face 엔드포인트 사용)
client = OpenAI(
base_url="https://api-inference.huggingface.co/v1/",
api_key=ACCESS_TOKEN,
)
# 데이터셋 로드
ds = load_dataset("lamm-mit/protein_secondary_structure_from_PDB",
token=ACCESS_TOKEN)
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
response = ""
for message in client.chat.completions.create(
model="CohereForAI/c4ai-command-r-plus-08-2024",
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
messages=messages,
):
token = message.choices[0].delta.content
response += token
yield response
# 챗봇 및 단백질 생성 관련 함수들
def process_chat(message, history):
messages = [{"role": "user", "content": message}]
response = pipe(messages)[0]['generated_text']
if any(keyword in message.lower() for keyword in ['protein', 'generate', '단백질', '생성']):
relevant_data = search_protein_data(message)
params = extract_parameters(response, relevant_data)
protein_result = generate_protein(params)
explanation = generate_explanation(protein_result, params)
return response + "\n\n" + explanation
return response
def search_protein_data(query):
relevant_entries = []
for entry in ds['train']:
if any(keyword in entry['sequence'].lower() for keyword in query.lower().split()):
relevant_entries.append(entry)
return relevant_entries
def extract_parameters(llm_response, dataset_info):
params = {
'sequence_length': 100,
'helix_bias': 0.02,
'strand_bias': 0.02,
'loop_bias': 0.1,
'hydrophobic_target_score': 0
}
return params
def generate_explanation(result, params):
explanation = f"""
생성된 단백질 분석:
- 길이: {params['sequence_length']} 아미노산
- 구조적 특징:
* 알파 나선 비율: {params['helix_bias']*100}%
* 베타 시트 비율: {params['strand_bias']*100}%
* 루프 구조 비율: {params['loop_bias']*100}%
- 특수 기능: {result.get('special_features', '없음')}
"""
return explanation
def protein_diffusion_model(sequence, seq_len, helix_bias, strand_bias, loop_bias,
secondary_structure, aa_bias, aa_bias_potential,
num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
contigs, pssm, seq_mask, str_mask, rewrite_pdb):
dssp_checkpoint = './SEQDIFF_230205_dssp_hotspots_25mask_EQtasks_mod30.pt'
og_checkpoint = './SEQDIFF_221219_equalTASKS_nostrSELFCOND_mod30.pt'
model_args = copy.deepcopy(args)
# make sampler
S = HuggingFace_sampler(args=model_args)
# get random prefix
S.out_prefix = './tmp/'+secrets.token_hex(nbytes=10).upper()
# set args
S.args['checkpoint'] = None
S.args['dump_trb'] = False
S.args['dump_args'] = True
S.args['save_best_plddt'] = True
S.args['T'] = 20
S.args['strand_bias'] = 0.0
S.args['loop_bias'] = 0.0
S.args['helix_bias'] = 0.0
S.args['potentials'] = None
S.args['potential_scale'] = None
S.args['aa_composition'] = None
# get sequence if entered and make sure all chars are valid
alt_aa_dict = {'B':['D','N'],'J':['I','L'],'U':['C'],'Z':['E','Q'],'O':['K']}
if sequence not in ['',None]:
L = len(sequence)
aa_seq = []
for aa in sequence.upper():
if aa in alt_aa_dict.keys():
aa_seq.append(np.random.choice(alt_aa_dict[aa]))
else:
aa_seq.append(aa)
S.args['sequence'] = aa_seq
elif contigs not in ['',None]:
S.args['contigs'] = [contigs]
else:
S.args['contigs'] = [f'{seq_len}']
L = int(seq_len)
print('DEBUG: ',rewrite_pdb)
if rewrite_pdb not in ['',None]:
S.args['pdb'] = rewrite_pdb.name
if seq_mask not in ['',None]:
S.args['inpaint_seq'] = [seq_mask]
if str_mask not in ['',None]:
S.args['inpaint_str'] = [str_mask]
if secondary_structure in ['',None]:
secondary_structure = None
else:
secondary_structure = ''.join(['E' if x == 'S' else x for x in secondary_structure])
if L < len(secondary_structure):
secondary_structure = secondary_structure[:len(sequence)]
elif L == len(secondary_structure):
pass
else:
dseq = L - len(secondary_structure)
secondary_structure += secondary_structure[-1]*dseq
# potentials
potential_list = []
potential_bias_list = []
if aa_bias not in ['',None]:
potential_list.append('aa_bias')
S.args['aa_composition'] = aa_bias
if aa_bias_potential in ['',None]:
aa_bias_potential = 3
potential_bias_list.append(str(aa_bias_potential))
'''
if target_charge not in ['',None]:
potential_list.append('charge')
if charge_potential in ['',None]:
charge_potential = 1
potential_bias_list.append(str(charge_potential))
S.args['target_charge'] = float(target_charge)
if target_ph in ['',None]:
target_ph = 7.4
S.args['target_pH'] = float(target_ph)
'''
if hydrophobic_target_score not in ['',None]:
potential_list.append('hydrophobic')
S.args['hydrophobic_score'] = float(hydrophobic_target_score)
if hydrophobic_potential in ['',None]:
hydrophobic_potential = 3
potential_bias_list.append(str(hydrophobic_potential))
if pssm not in ['',None]:
potential_list.append('PSSM')
potential_bias_list.append('5')
S.args['PSSM'] = pssm.name
if len(potential_list) > 0:
S.args['potentials'] = ','.join(potential_list)
S.args['potential_scale'] = ','.join(potential_bias_list)
# normalise secondary_structure bias from range 0-0.3
S.args['secondary_structure'] = secondary_structure
S.args['helix_bias'] = helix_bias
S.args['strand_bias'] = strand_bias
S.args['loop_bias'] = loop_bias
# set T
if num_steps in ['',None]:
S.args['T'] = 20
else:
S.args['T'] = int(num_steps)
# noise
if 'normal' in noise:
S.args['sample_distribution'] = noise
S.args['sample_distribution_gmm_means'] = [0]
S.args['sample_distribution_gmm_variances'] = [1]
elif 'gmm2' in noise:
S.args['sample_distribution'] = noise
S.args['sample_distribution_gmm_means'] = [-1,1]
S.args['sample_distribution_gmm_variances'] = [1,1]
elif 'gmm3' in noise:
S.args['sample_distribution'] = noise
S.args['sample_distribution_gmm_means'] = [-1,0,1]
S.args['sample_distribution_gmm_variances'] = [1,1,1]
if secondary_structure not in ['',None] or helix_bias+strand_bias+loop_bias > 0:
S.args['checkpoint'] = dssp_checkpoint
S.args['d_t1d'] = 29
print('using dssp checkpoint')
else:
S.args['checkpoint'] = og_checkpoint
S.args['d_t1d'] = 24
print('using og checkpoint')
for k,v in S.args.items():
print(f"{k} --> {v}")
# init S
S.model_init()
S.diffuser_init()
S.setup()
# sampling loop
plddt_data = []
for j in range(S.max_t):
print(f'on step {j}')
output_seq, output_pdb, plddt = S.take_step_get_outputs(j)
plddt_data.append(plddt)
yield output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
output_seq, output_pdb, plddt = S.get_outputs()
return output_seq, output_pdb, display_pdb(output_pdb), get_plddt_plot(plddt_data, S.max_t)
def get_plddt_plot(plddt_data, max_t):
x = [i+1 for i in range(len(plddt_data))]
fig, ax = plt.subplots(figsize=(15,6))
ax.plot(x,plddt_data,color='#661dbf', linewidth=3,marker='o')
ax.set_xticks([i+1 for i in range(max_t)])
ax.set_yticks([(i+1)/10 for i in range(10)])
ax.set_ylim([0,1])
ax.set_ylabel('model confidence (plddt)')
ax.set_xlabel('diffusion steps (t)')
return fig
def display_pdb(path_to_pdb):
'''
#function to display pdb in py3dmol
'''
pdb = open(path_to_pdb, "r").read()
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb, "pdb")
view.setStyle({'model': -1}, {"cartoon": {'colorscheme':{'prop':'b','gradient':'roygb','min':0,'max':1}}})#'linear', 'min': 0, 'max': 1, 'colors': ["#ff9ef0","#a903fc",]}}})
view.zoomTo()
output = view._make_html().replace("'", '"')
print(view._make_html())
x = f""" {output} """ # do not use ' in this input
return f""""""
'''
return f""""""
'''
def get_motif_preview(pdb_id, contigs):
try:
input_pdb = fetch_pdb(pdb_id=pdb_id.lower() if pdb_id else None)
if input_pdb is None:
return gr.HTML("PDB ID를 입력해주세요"), None
parse = parse_pdb(input_pdb)
output_name = input_pdb
pdb = open(output_name, "r").read()
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb, "pdb")
if contigs in ['',0]:
contigs = ['0']
else:
contigs = [contigs]
print('DEBUG: ',contigs)
pdb_map = get_mappings(ContigMap(parse,contigs))
print('DEBUG: ',pdb_map)
print('DEBUG: ',pdb_map['con_ref_idx0'])
roi = [x[1]-1 for x in pdb_map['con_ref_pdb_idx']]
colormap = {0:'#D3D3D3', 1:'#F74CFF'}
colors = {i+1: colormap[1] if i in roi else colormap[0] for i in range(parse['xyz'].shape[0])}
view.setStyle({"cartoon": {"colorscheme": {"prop": "resi", "map": colors}}})
view.zoomTo()
output = view._make_html().replace("'", '"')
print(view._make_html())
x = f""" {output} """ # do not use ' in this input
return f"""""", output_name
except Exception as e:
return gr.HTML(f"오류가 발생했습니다: {str(e)}"), None
def fetch_pdb(pdb_id=None):
if pdb_id is None or pdb_id == "":
return None
else:
os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_id}.pdb")
return f"{pdb_id}.pdb"
# MSA AND PSSM GUIDANCE
def save_pssm(file_upload):
filename = file_upload.name
orig_name = file_upload.orig_name
if filename.split('.')[-1] in ['fasta', 'a3m']:
return msa_to_pssm(file_upload)
return filename
def msa_to_pssm(msa_file):
# Define the lookup table for converting amino acids to indices
aa_to_index = {'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4, 'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9, 'L': 10,
'K': 11, 'M': 12, 'F': 13, 'P': 14, 'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19, 'X': 20, '-': 21}
# Open the FASTA file and read the sequences
records = list(SeqIO.parse(msa_file.name, "fasta"))
assert len(records) >= 1, "MSA must contain more than one protein sequecne."
first_seq = str(records[0].seq)
aligned_seqs = [first_seq]
# print(aligned_seqs)
# Perform sequence alignment using the Needleman-Wunsch algorithm
aligner = Align.PairwiseAligner()
aligner.open_gap_score = -0.7
aligner.extend_gap_score = -0.3
for record in records[1:]:
alignment = aligner.align(first_seq, str(record.seq))[0]
alignment = alignment.format().split("\n")
al1 = alignment[0]
al2 = alignment[2]
al1_fin = ""
al2_fin = ""
percent_gap = al2.count('-')/ len(al2)
if percent_gap > 0.4:
continue
for i in range(len(al1)):
if al1[i] != '-':
al1_fin += al1[i]
al2_fin += al2[i]
aligned_seqs.append(str(al2_fin))
# Get the length of the aligned sequences
aligned_seq_length = len(first_seq)
# Initialize the position scoring matrix
matrix = np.zeros((22, aligned_seq_length))
# Iterate through the aligned sequences and count the amino acids at each position
for seq in aligned_seqs:
#print(seq)
for i in range(aligned_seq_length):
if i == len(seq):
break
amino_acid = seq[i]
if amino_acid.upper() not in aa_to_index.keys():
continue
else:
aa_index = aa_to_index[amino_acid.upper()]
matrix[aa_index, i] += 1
# Normalize the counts to get the frequency of each amino acid at each position
matrix /= len(aligned_seqs)
print(len(aligned_seqs))
matrix[20:,]=0
outdir = ".".join(msa_file.name.split('.')[:-1]) + ".csv"
np.savetxt(outdir, matrix[:21,:].T, delimiter=",")
return outdir
def get_pssm(fasta_msa, input_pssm):
try:
if input_pssm is not None:
outdir = input_pssm.name
elif fasta_msa is not None:
outdir = save_pssm(fasta_msa)
else:
return gr.Plot(label="파일을 업로드해주세요"), None
pssm = np.loadtxt(outdir, delimiter=",", dtype=float)
fig, ax = plt.subplots(figsize=(15,6))
plt.imshow(torch.permute(torch.tensor(pssm),(1,0)))
return fig, outdir
except Exception as e:
return gr.Plot(label=f"오류가 발생했습니다: {str(e)}"), None
# 히어로 능력치 계산 함수 추가
def calculate_hero_stats(helix_bias, strand_bias, loop_bias, hydrophobic_score):
stats = {
'strength': strand_bias * 20, # 베타시트 구조 기반
'flexibility': helix_bias * 20, # 알파헬릭스 구조 기반
'speed': loop_bias * 5, # 루프 구조 기반
'defense': abs(hydrophobic_score) if hydrophobic_score else 0
}
return stats
def toggle_seq_input(choice):
if choice == "자동 설계":
return gr.update(visible=True), gr.update(visible=False)
else: # "직접 입력"
return gr.update(visible=False), gr.update(visible=True)
def toggle_secondary_structure(choice):
if choice == "슬라이더로 설정":
return (
gr.update(visible=True), # helix_bias
gr.update(visible=True), # strand_bias
gr.update(visible=True), # loop_bias
gr.update(visible=False) # secondary_structure
)
else: # "직접 입력"
return (
gr.update(visible=False), # helix_bias
gr.update(visible=False), # strand_bias
gr.update(visible=False), # loop_bias
gr.update(visible=True) # secondary_structure
)
def create_radar_chart(stats):
# 레이더 차트 생성 로직
categories = list(stats.keys())
values = list(stats.values())
fig = go.Figure(data=go.Scatterpolar(
r=values,
theta=categories,
fill='toself'
))
fig.update_layout(
polar=dict(
radialaxis=dict(
visible=True,
range=[0, 1]
)),
showlegend=False
)
return fig
def generate_hero_description(name, stats, abilities):
# 히어로 설명 생성 로직
description = f"""
히어로 이름: {name}
주요 능력:
- 근력: {'★' * int(stats['strength'] * 5)}
- 유연성: {'★' * int(stats['flexibility'] * 5)}
- 스피드: {'★' * int(stats['speed'] * 5)}
- 방어력: {'★' * int(stats['defense'] * 5)}
특수 능력: {', '.join(abilities)}
"""
return description
def combined_generation(name, strength, flexibility, speed, defense, size, abilities,
sequence, seq_len, helix_bias, strand_bias, loop_bias,
secondary_structure, aa_bias, aa_bias_potential,
num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
contigs, pssm, seq_mask, str_mask, rewrite_pdb):
try:
# protein_diffusion_model 실행
generator = protein_diffusion_model(
sequence=None,
seq_len=size, # 히어로 크기를 seq_len으로 사용
helix_bias=flexibility, # 히어로 유연성을 helix_bias로 사용
strand_bias=strength, # 히어로 강도를 strand_bias로 사용
loop_bias=speed, # 히어로 스피드를 loop_bias로 사용
secondary_structure=None,
aa_bias=None,
aa_bias_potential=None,
num_steps="25",
noise="normal",
hydrophobic_target_score=str(-defense), # 히어로 방어력을 hydrophobic score로 사용
hydrophobic_potential="2",
contigs=None,
pssm=None,
seq_mask=None,
str_mask=None,
rewrite_pdb=None
)
# 마지막 결과 가져오기
final_result = None
for result in generator:
final_result = result
if final_result is None:
raise Exception("생성 결과가 없습니다")
output_seq, output_pdb, structure_view, plddt_plot = final_result
# 히어로 능력치 계산
stats = calculate_hero_stats(flexibility, strength, speed, defense)
# 모든 결과 반환
return (
create_radar_chart(stats), # 능력치 차트
generate_hero_description(name, stats, abilities), # 히어로 설명
output_seq, # 단백질 서열
output_pdb, # PDB 파일
structure_view, # 3D 구조
plddt_plot # 신뢰도 차트
)
except Exception as e:
print(f"Error in combined_generation: {str(e)}")
return (
None,
f"에러: {str(e)}",
None,
None,
gr.HTML("에러가 발생했습니다"),
None
)
with gr.Blocks(theme='ParityError/Interstellar') as demo:
with gr.Row():
# 왼쪽 열: 챗봇 및 컨트롤 패널
with gr.Column(scale=1):
# 챗봇 인터페이스
gr.Markdown("# 🤖 AI 단백질 설계 도우미")
chatbot = gr.Chatbot(height=600)
with gr.Accordion("채팅 설정", open=False):
system_message = gr.Textbox(
value="당신은 단백질 설계를 도와주는 전문가입니다.",
label="시스템 메시지"
)
max_tokens = gr.Slider(
minimum=1,
maximum=2048,
value=512,
step=1,
label="최대 토큰 수"
)
temperature = gr.Slider(
minimum=0.1,
maximum=4.0,
value=0.7,
step=0.1,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-P"
)
# 탭 인터페이스
with gr.Tabs():
with gr.TabItem("🦸♂️ 히어로 디자인"):
gr.Markdown("""
### ✨ 당신만의 특별한 히어로를 만들어보세요!
각 능력치를 조절하면 히어로의 DNA가 자동으로 설계됩니다.
""")
# 히어로 기본 정보
hero_name = gr.Textbox(
label="히어로 이름",
placeholder="당신의 히어로 이름을 지어주세요!",
info="히어로의 정체성을 나타내는 이름을 입력하세요"
)
# 능력치 설정
gr.Markdown("### 💪 히어로 능력치 설정")
with gr.Row():
strength = gr.Slider(
minimum=0.0, maximum=0.05,
label="💪 초강력(근력)",
value=0.02,
info="단단한 베타시트 구조로 강력한 힘을 생성합니다"
)
flexibility = gr.Slider(
minimum=0.0, maximum=0.05,
label="🤸♂️ 유연성",
value=0.02,
info="나선형 알파헬릭스 구조로 유연한 움직임을 가능하게 합니다"
)
with gr.Row():
speed = gr.Slider(
minimum=0.0, maximum=0.20,
label="⚡ 스피드",
value=0.1,
info="루프 구조로 빠른 움직임을 구현합니다"
)
defense = gr.Slider(
minimum=-10, maximum=10,
label="🛡️ 방어력",
value=0,
info="음수: 수중 활동에 특화, 양수: 지상 활동에 특화"
)
# 히어로 크기 설정
hero_size = gr.Slider(
minimum=50, maximum=200,
label="📏 히어로 크기",
value=100,
info="히어로의 전체적인 크기를 결정합니다"
)
# 특수 능력 설정
with gr.Accordion("🌟 특수 능력", open=False):
gr.Markdown("""
특수 능력을 선택하면 히어로의 DNA에 특별한 구조가 추가됩니다.
- 자가 회복: 단백질 구조 복구 능력 강화
- 원거리 공격: 특수한 구조적 돌출부 형성
- 방어막 생성: 안정적인 보호층 구조 생성
""")
special_ability = gr.CheckboxGroup(
choices=["자가 회복", "원거리 공격", "방어막 생성"],
label="특수 능력 선택"
)
# 생성 버튼
create_btn = gr.Button("🧬 히어로 생성!", variant="primary", scale=2)
with gr.TabItem("🧬 히어로 DNA 설계"):
gr.Markdown("""
### 🧪 히어로 DNA 고급 설정
히어로의 유전자 구조를 더 세밀하게 조정할 수 있습니다.
""")
seq_opt = gr.Radio(
["자동 설계", "직접 입력"],
label="DNA 설계 방식",
value="자동 설계"
)
sequence = gr.Textbox(
label="DNA 시퀀스",
lines=1,
placeholder='사용 가능한 아미노산: A,C,D,E,F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y (X는 무작위)',
visible=False
)
seq_len = gr.Slider(
minimum=5.0, maximum=250.0,
label="DNA 길이",
value=100,
visible=True
)
with gr.Accordion(label='🦴 골격 구조 설정', open=True):
gr.Markdown("""
히어로의 기본 골격 구조를 설정합니다.
- 나선형 구조: 유연하고 탄력있는 움직임
- 병풍형 구조: 단단하고 강력한 힘
- 고리형 구조: 빠르고 민첩한 움직임
""")
sec_str_opt = gr.Radio(
["슬라이더로 설정", "직접 입력"],
label="골격 구조 설정 방식",
value="슬라이더로 설정"
)
secondary_structure = gr.Textbox(
label="골격 구조",
lines=1,
placeholder='H:나선형, S:병풍형, L:고리형, X:자동설정',
visible=False
)
with gr.Column():
helix_bias = gr.Slider(
minimum=0.0, maximum=0.05,
label="나선형 구조 비율",
visible=True
)
strand_bias = gr.Slider(
minimum=0.0, maximum=0.05,
label="병풍형 구조 비율",
visible=True
)
loop_bias = gr.Slider(
minimum=0.0, maximum=0.20,
label="고리형 구조 비율",
visible=True
)
with gr.Accordion(label='🧬 DNA 구성 설정', open=False):
gr.Markdown("""
특정 아미노산의 비율을 조절하여 히어로의 특성을 강화할 수 있습니다.
예시: W0.2,E0.1 (트립토판 20%, 글루탐산 10%)
""")
with gr.Row():
aa_bias = gr.Textbox(
label="아미노산 비율",
lines=1,
placeholder='예시: W0.2,E0.1'
)
aa_bias_potential = gr.Textbox(
label="강화 정도",
lines=1,
placeholder='1.0-5.0 사이 값 입력'
)
with gr.Accordion(label='🌍 환경 적응력 설정', open=False):
gr.Markdown("""
히어로의 환경 적응력을 조절합니다.
음수: 수중 활동에 특화, 양수: 지상 활동에 특화
""")
with gr.Row():
hydrophobic_target_score = gr.Textbox(
label="환경 적응 점수",
lines=1,
placeholder='예시: -5 (수중 활동에 특화)'
)
hydrophobic_potential = gr.Textbox(
label="적응력 강화 정도",
lines=1,
placeholder='1.0-2.0 사이 값 입력'
)
with gr.Accordion(label='⚙️ 고급 설정', open=False):
gr.Markdown("""
DNA 생성 과정의 세부 매개변수를 조정합니다.
""")
with gr.Row():
num_steps = gr.Textbox(
label="생성 단계",
lines=1,
placeholder='25 이하 권장'
)
noise = gr.Dropdown(
['normal','gmm2 [-1,1]','gmm3 [-1,0,1]'],
label='노이즈 타입',
value='normal'
)
design_btn = gr.Button("🧬 DNA 설계 생성!", variant="primary", scale=2)
with gr.TabItem("🧪 히어로 유전자 강화"):
gr.Markdown("""
### ⚡ 기존 히어로의 DNA 활용
강력한 히어로의 DNA 일부를 새로운 히어로에게 이식합니다.
""")
gr.Markdown("공개된 히어로 DNA 데이터베이스에서 코드를 찾을 수 있습니다")
pdb_id_code = gr.Textbox(
label="히어로 DNA 코드",
lines=1,
placeholder='기존 히어로의 DNA 코드를 입력하세요 (예: 1DPX)'
)
gr.Markdown("이식하고 싶은 DNA 영역을 선택하고 새로운 DNA를 추가할 수 있습니다")
contigs = gr.Textbox(
label="이식할 DNA 영역",
lines=1,
placeholder='예시: 15,A3-10,20-30'
)
with gr.Row():
seq_mask = gr.Textbox(
label='능력 재설계',
lines=1,
placeholder='선택한 영역의 능력을 새롭게 디자인'
)
str_mask = gr.Textbox(
label='구조 재설계',
lines=1,
placeholder='선택한 영역의 구조를 새롭게 디자인'
)
preview_viewer = gr.HTML()
rewrite_pdb = gr.File(label='히어로 DNA 파일')
preview_btn = gr.Button("🔍 미리보기", variant="secondary")
enhance_btn = gr.Button("⚡ 강화된 히어로 생성!", variant="primary", scale=2)
with gr.TabItem("👑 히어로 가문"):
gr.Markdown("""
### 🏰 위대한 히어로 가문의 유산
강력한 히어로 가문의 특성을 계승하여 새로운 히어로를 만듭니다.
""")
with gr.Row():
with gr.Column():
gr.Markdown("히어로 가문의 DNA 정보가 담긴 파일을 업로드하세요")
fasta_msa = gr.File(label='가문 DNA 데이터')
with gr.Column():
gr.Markdown("이미 분석된 가문 특성 데이터가 있다면 업로드하세요")
input_pssm = gr.File(label='가문 특성 데이터')
pssm = gr.File(label='분석된 가문 특성')
pssm_view = gr.Plot(label='가문 특성 분석 결과')
pssm_gen_btn = gr.Button("✨ 가문 특성 분석", variant="secondary")
inherit_btn = gr.Button("👑 가문의 힘 계승!", variant="primary", scale=2)
# 오른쪽 열: 결과 표시
with gr.Column(scale=1):
gr.Markdown("## 🦸♂️ 히어로 프로필")
hero_stats = gr.Plot(label="능력치 분석")
hero_description = gr.Textbox(label="히어로 특성", lines=3)
gr.Markdown("## 🧬 히어로 DNA 분석 결과")
gr.Markdown("#### ⚡ DNA 안정성 점수")
plddt_plot = gr.Plot(label='안정성 분석')
gr.Markdown("#### 📝 DNA 시퀀스")
output_seq = gr.Textbox(label="DNA 서열")
gr.Markdown("#### 💾 DNA 데이터")
output_pdb = gr.File(label="DNA 파일")
gr.Markdown("#### 🔬 DNA 구조")
output_viewer = gr.HTML()
# 이벤트 연결
# 챗봇 이벤트
msg.submit(process_chat, [msg, chatbot], [chatbot])
clear.click(lambda: None, None, chatbot, queue=False)
# UI 컨트롤 이벤트
seq_opt.change(
fn=toggle_seq_input,
inputs=[seq_opt],
outputs=[seq_len, sequence],
queue=False
)
sec_str_opt.change(
fn=toggle_secondary_structure,
inputs=[sec_str_opt],
outputs=[helix_bias, strand_bias, loop_bias, secondary_structure],
queue=False
)
preview_btn.click(
get_motif_preview,
inputs=[pdb_id_code, contigs],
outputs=[preview_viewer, rewrite_pdb]
)
pssm_gen_btn.click(
get_pssm,
inputs=[fasta_msa, input_pssm],
outputs=[pssm_view, pssm]
)
# 챗봇 기반 단백질 생성 결과 업데이트
def update_protein_display(chat_response):
if "생성된 단백질 분석" in chat_response:
params = extract_parameters_from_chat(chat_response)
result = generate_protein(params)
return {
hero_stats: create_radar_chart(calculate_hero_stats(params)),
hero_description: chat_response,
output_seq: result[0],
output_pdb: result[1],
output_viewer: display_pdb(result[1]),
plddt_plot: result[3]
}
return None
# 각 생성 버튼 이벤트 연결
for btn in [create_btn, design_btn, enhance_btn, inherit_btn]:
btn.click(
combined_generation,
inputs=[
hero_name, strength, flexibility, speed, defense, hero_size, special_ability,
sequence, seq_len, helix_bias, strand_bias, loop_bias,
secondary_structure, aa_bias, aa_bias_potential,
num_steps, noise, hydrophobic_target_score, hydrophobic_potential,
contigs, pssm, seq_mask, str_mask, rewrite_pdb
],
outputs=[
hero_stats,
hero_description,
output_seq,
output_pdb,
output_viewer,
plddt_plot
]
)
# 챗봇 응답에 따른 결과 업데이트
msg.submit(
update_protein_display,
inputs=[chatbot],
outputs=[hero_stats, hero_description, output_seq, output_pdb, output_viewer, plddt_plot]
)
chat_interface = gr.ChatInterface(
respond,
additional_inputs=[
system_message,
max_tokens,
temperature,
top_p,
],
chatbot=chatbot,
)
# 실행
demo.queue()
demo.launch(debug=True)