|
import os |
|
import torch |
|
import stat |
|
import re |
|
import time |
|
import argparse |
|
import numpy as np |
|
|
|
from functools import partial |
|
from typing import List, Tuple |
|
|
|
import torch.distributed as dist |
|
from sat.helpers import print_rank0 |
|
from sat import mpu, get_args, get_tokenizer |
|
from utils import AdvancedBaseStrategy, BeamSearchStrategy |
|
from model_utils import MSAGPT, FineTuneMSAGPT |
|
from utils import chat_api |
|
|
|
|
|
|
|
# Script entry point.
#
# Parses the command line, loads an MSAGPT checkpoint, builds the sampling
# strategy, and then either runs an interactive prompt loop (when
# ``args.input_source == 'chat'``) or delegates directly to ``chat_api``.
if __name__ == "__main__":
    # ---- Script-specific command-line options ---------------------------
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy",
                           help="Type of sampling strategy.")
    py_parser.add_argument("--min-gen-length", type=int, default=0,
                           help="The minimum length each blank should generate.")
    py_parser.add_argument("--max-gen-length", type=int, default=512,
                           help="The maximum length each blank should generate.")
    # NOTE(review): this help text is identical to --print-all-beams and does
    # not describe what --is-valid does; the flag is not read in this script,
    # so the intended meaning should be confirmed with its consumer.
    py_parser.add_argument("--is-valid", action="store_true",
                           help="Print all output generated by beam search strategy.")
    py_parser.add_argument("--print-all-beams", action="store_true",
                           help="Print all output generated by beam search strategy.")
    py_parser.add_argument("--multiline_stream", action="store_true",
                           help="streaming multiline output.")
    py_parser.add_argument("--no-gap", action="store_true",
                           help="do not generate gaps.")
    py_parser.add_argument("--from_pretrained", type=str, default="./checkpoints/MSAGPT",
                           help='pretrained ckpt')
    py_parser.add_argument("--chinese", action='store_true',
                           help='Chinese interface')
    py_parser.add_argument("--stream_chat", action='store_true',
                           help='streaming output')

    py_parser = MSAGPT.add_model_specific_args(py_parser)

    # Options this parser does not know are handed to ``get_args``; the two
    # resulting namespaces are then merged into a single one.
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))

    # ---- Model and tokenizer ---------------------------------------------
    # Only override the checkpoint's model-parallel size when one other than
    # 1 was requested.
    overwrite_args = {}
    if args.model_parallel_size != 1:
        overwrite_args['model_parallel_size'] = args.model_parallel_size
    model, args = MSAGPT.from_pretrained(args.from_pretrained, args, overwrite_args=overwrite_args)
    model.eval()

    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available():
        model = model.to('cuda')

    from utils import proteinglm_tokenizer
    tokenizer = proteinglm_tokenizer()

    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    # Passed to the strategy as ``invalid_slices``; presumably token ids that
    # must never be sampled -- confirm against the strategy implementation.
    invalid_slices = [0, 26, 28, 29, 30, 31, 32]
    if args.no_gap:
        invalid_slices.append(tokenizer.TokenToId('-'))

    # ---- Sampling strategy -----------------------------------------------
    if args.sampling_strategy == "BaseStrategy":
        # Raise instead of assert so the check survives ``python -O``.
        if args.print_all_beams:
            raise ValueError("BaseStrategy don't support print all beams.")
        strategy = AdvancedBaseStrategy(
            batch_size=1,
            invalid_slices=invalid_slices,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            min_gen_length=args.min_gen_length,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            end_tokens=end_tokens,
        )
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            invalid_slices=invalid_slices,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
            deterministic=True,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")

    # ---- Generation ------------------------------------------------------
    if args.input_source == 'chat':
        # Pick the user-facing texts once, according to the interface language.
        if args.chinese:
            banner = '欢迎使用 MSAGPT-CLI ,输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以"<M>"相连),例如:"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG",其中"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG"为主序列,"VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG"为MSA prompt。 stop 终止程序'
            input_prompt = "请输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以'<M>'相连):"
            result_title = '生成的MSA'
        else:
            banner = 'Welcome to MSAGPT-CLI. Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by "<M>"), for example: "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", where "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence, and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" are MSA prompts. Type "stop" to end the program.'
            input_prompt = "Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by '<M>': "
            result_title = 'Virtual MSA'

        if rank == 0:
            print(banner.center(20, "*"))

        with torch.no_grad():
            while True:
                # Only rank 0 talks to the terminal.
                protein_input = None
                if rank == 0:
                    try:
                        protein_input = input(input_prompt)
                    except EOFError:
                        # Closed stdin: treat it like "stop" so that every
                        # rank leaves the loop together instead of blocking
                        # in the broadcast below.
                        protein_input = 'stop'

                if world_size > 1:
                    # broadcast_object_list fills the list in place on the
                    # receiving ranks, so the value has to be read back.
                    shared = [protein_input]
                    torch.distributed.broadcast_object_list(shared, src=0)
                    protein_input = shared[0]

                # Validate before calling .strip(), which would otherwise
                # fail with an unhelpful AttributeError on None.
                if protein_input is None:
                    raise RuntimeError(
                        f"rank {rank} received no input (WORLD_SIZE={world_size}); "
                        "nothing was broadcast from rank 0."
                    )
                protein_input = protein_input.strip()

                if protein_input == 'stop':
                    break

                try:
                    response = chat_api(
                        args=args,
                        query=protein_input,
                        model=model,
                        tokenizer=tokenizer,
                        strategy=strategy,
                    )
                except Exception as e:
                    # Report the failure and leave the interactive loop.
                    print(e)
                    break

                # Results are printed here only by rank 0 and only when
                # streaming output is disabled.
                if rank != 0 or args.stream_chat:
                    continue

                print(result_title.center(20, '*'))
                if args.print_all_beams:
                    for idx, gen in enumerate(response):
                        print(f"Beam: {idx}".center(11, '@'))
                        for line in gen:
                            print(line)
                        print()
                else:
                    # Without --print-all-beams only the first entry is shown.
                    for line in response[0]:
                        print(line)
                    print()
    else:
        # Non-interactive mode: chat_api is called without a query.
        chat_api(
            args=args,
            model=model,
            tokenizer=tokenizer,
            strategy=strategy,
        )