Spaces:
Sleeping
Sleeping
File size: 8,727 Bytes
49ed5db 64e8c41 49ed5db 9e875d0 49ed5db 9e875d0 49ed5db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import functools
import multiprocessing
from functools import partial
from pathlib import Path

import gradio as gr
import lightning
import torch
from datasets import Dataset
from torch.utils.data import DataLoader

import deepchopper
from deepchopper.deepchopper import default, encode_qual, remove_intervals_and_keep_left, smooth_label_region
from deepchopper.models.llm import (
    tokenize_and_align_labels_and_quals,
)
from deepchopper.utils import (
    summary_predict,
)
def parse_fq_record(text: str):
    """Parse FASTQ text into one dictionary per record.

    The text is stripped, split into lines and consumed four lines at a time
    (header, sequence, separator, quality); the separator line is ignored.

    Args:
        text: Raw FASTQ content. It may hold zero, one or several records.

    Yields:
        A dict with keys ``"id"`` (the header line as-is), ``"seq"`` (the
        sequence line), ``"qual"`` (the quality line passed through
        ``encode_qual`` with ``default.KMER_SIZE``) and ``"target"``
        (always ``[0, 0]`` — presumably a placeholder label because inference
        has no ground truth; confirm).

    Raises:
        ValueError: if the number of lines is not a multiple of four, or a
            record's sequence and quality lines differ in length.
    """
    # splitlines() (unlike split("\n")) also drops the "\r" of CRLF input,
    # which would otherwise end up inside the sequence and quality strings.
    lines = text.strip().splitlines()
    if len(lines) % 4:
        msg = f"FASTQ input must consist of 4-line records, got {len(lines)} line(s)"
        raise ValueError(msg)
    for start in range(0, len(lines), 4):
        record_id, seq, _, qual = lines[start : start + 4]
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if len(seq) != len(qual):
            msg = f"Record {record_id!r}: sequence length {len(seq)} != quality length {len(qual)}"
            raise ValueError(msg)
        yield {
            "id": record_id,
            "seq": seq,
            "qual": encode_qual(qual, default.KMER_SIZE),
            "target": [0, 0],
        }
def load_dataset(text: str, tokenizer):
    """Turn FASTQ ``text`` into a parsed dataset and a tokenized dataset.

    Args:
        text: FASTQ content handed to ``parse_fq_record``.
        tokenizer: Tokenizer forwarded to
            ``tokenize_and_align_labels_and_quals``; its
            ``max_len_single_sentence`` is used as ``max_length``.

    Returns:
        ``(dataset, tokenized_dataset)``. ``dataset`` holds the parsed records
        in torch format. ``tokenized_dataset`` is the mapped dataset with the
        ``id``/``seq``/``qual``/``target`` columns removed.
    """
    raw_records = Dataset.from_generator(parse_fq_record, gen_kwargs={"text": text}).with_format("torch")

    tokenize_fn = partial(
        tokenize_and_align_labels_and_quals,
        tokenizer=tokenizer,
        max_length=tokenizer.max_len_single_sentence,
    )
    # One map worker per CPU core.
    tokenized_records = raw_records.map(tokenize_fn, num_proc=multiprocessing.cpu_count())  # type: ignore
    model_inputs = tokenized_records.remove_columns(["id", "seq", "qual", "target"])

    return raw_records, model_inputs
_TOKENIZER_MODEL_NAME = "hyenadna-small-32k-seqlen"
_PRETRAINED_MODEL_ID = "yangliz5/deepchopper"


@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Return the HyenaDNA tokenizer, loading it only on first use."""
    return deepchopper.models.llm.load_tokenizer_from_hyena_model(model_name=_TOKENIZER_MODEL_NAME)


@functools.lru_cache(maxsize=1)
def _get_model():
    """Return the pretrained DeepChopper model, loading it only on first use."""
    return deepchopper.DeepChopper.from_pretrained(_PRETRAINED_MODEL_ID)


def predict(
    text: str,
    smooth_window_size: int = 21,
    min_interval_size: int = 13,
    approved_interval_number: int = 20,
    max_process_intervals: int = 8,  # NOTE(review): an older comment said "default is 4" — presumably elsewhere; confirm
    batch_size: int = 1,
    num_workers: int = 1,
) -> tuple[list[dict[str, int]], list[tuple[str, str | None]]]:
    """Run DeepChopper on a single FASTQ record and report the predicted regions.

    The tokenizer and model are cached after the first call, so repeated
    requests do not reload them.

    Args:
        text: FASTQ text holding exactly one record.
        smooth_window_size: Forwarded to ``smooth_label_region``.
        min_interval_size: Forwarded to ``smooth_label_region``.
        approved_interval_number: Forwarded to ``smooth_label_region``.
        max_process_intervals: A read whose smoothed prediction contains more
            intervals than this is skipped (no output for it).
        batch_size: ``DataLoader`` batch size.
        num_workers: ``DataLoader`` worker processes (0 loads in-process).

    Returns:
        ``(intervals, segments)``. ``intervals`` is a list of
        ``{"start": ..., "end": ...}`` dicts for the predicted regions.
        ``segments`` is a list of ``(subsequence, label)`` pairs in sorted
        interval order, where predicted regions are labelled ``"ada"`` and the
        remaining intervals ``None``. Both lists are empty when no region is
        predicted or too many are.

    Raises:
        ValueError: if the input does not yield exactly one prediction batch.
    """
    tokenizer = _get_tokenizer()
    dataset, tokenized_dataset = load_dataset(text, tokenizer)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    accelerator = "gpu" if torch.cuda.is_available() else "cpu"
    trainer = lightning.pytorch.trainer.Trainer(
        accelerator=accelerator,
        deterministic=False,
        logger=False,
    )
    predicts = trainer.predict(model=_get_model(), dataloaders=dataloader, return_predictions=True)
    # Explicit check instead of `assert` (stripped under `python -O`).
    if len(predicts) != 1:
        msg = f"predict() expects a single FASTQ record per call, got {len(predicts)} prediction batches"
        raise ValueError(msg)

    smooth_interval_json: list[dict[str, int]] = []
    highlighted_text: list[tuple[str, str | None]] = []
    for idx, preds in enumerate(predicts):
        true_prediction, _ = summary_predict(predictions=preds[0], labels=preds[1])
        seq = dataset[idx]["seq"]

        smooth_predict_targets = smooth_label_region(
            true_prediction[0], smooth_window_size, min_interval_size, approved_interval_number
        )
        # Skip reads with no predicted region or with too many of them.
        if not smooth_predict_targets or len(smooth_predict_targets) > max_process_intervals:
            continue

        # Intervals kept by remove_intervals_and_keep_left, merged with the predicted
        # ones and sorted, give the segments to render; predicted ones are labelled "ada".
        _, selected_intervals = remove_intervals_and_keep_left(seq, smooth_predict_targets)
        total_intervals = sorted(selected_intervals + smooth_predict_targets)

        smooth_interval_json.extend({"start": start, "end": end} for start, end, *_ in smooth_predict_targets)
        highlighted_text.extend(
            (seq[interval[0] : interval[1]], "ada" if interval in smooth_predict_targets else None)
            for interval in total_intervals
        )

    return smooth_interval_json, highlighted_text
def process_input(text: str | None, file: str | None):
    """Run DeepChopper on pasted FASTQ text or on an uploaded FASTQ file.

    An uploaded file takes precedence over ``text``. Only its first four lines
    (one FASTQ record) are read, because ``predict`` handles a single record
    per call.

    Args:
        text: FASTQ text from the textbox, or ``None``/empty.
        file: Path of the uploaded file, or ``None``.

    Returns:
        ``(intervals_json, highlighted_segments)`` for the JSON and
        HighlightedText outputs. Both are empty lists when there is no usable
        input.
    """
    from itertools import islice

    if not text and not file:
        # gr.Warning only shows a notification; return early instead of
        # letting predict() fail on empty input.
        gr.Warning("Both text and file are empty")
        return [], []

    if file:
        record_line_count = 4  # one FASTQ record: header, sequence, "+", quality
        with Path(file).open(encoding="utf-8") as handle:
            text = "".join(islice(handle, record_line_count))
        if not text.strip():
            gr.Warning("The uploaded file is empty")
            return [], []

    return predict(text=text)
# Example FASTQ record pre-filled by the "Examples" widget.
_EXAMPLE_FASTQ = (
    "@1065:1135|393d635c-64f0-41ed-8531-12174d8efb28+f6a60069-1fcf-4049-8e7c-37523b4e273f\n"
    "GCAGCTATGAATGCAAGGCCACAAGGTGGATGGAAGAGTTGTGGAACCAAAGAGCTGTCTTCCAGAGAAGATTTCGAGATAAGTCGCCCATCAGTGAACAAGATATTGTTGGTGGCATTTGATGAGAACGTTCCAAGATTATTGACAGATTAGTGAAAAGTAAGATTGAAATCATGACTGACCGTAAGTGGCAAGAAAGGGCTTTTGCCTTTGTAACCTTTGACGACCATGACTCCGTGGATAAGATTGTCATTCAGAATACCATACTGTGAATGGCCACATCTTTATTGTGAAGTTAGAAAAGCCCTGTCAAAGCAAGAGATGAATCAGTGCTTCTCCAGCCAAAGAGGTCGAAGTGGTTCTGGAAACTTTGGTGGTGGTCGTGGAGGTGGTTTCGGTGGGAATGACAACTCGGTCGTGGAGGAAACTTCAGTGGTCGTGGTGGCTTTGGTGGCAGCCGTGGTGGTGGTGGATATGGTGGCAGTGGGGATGGCTATAATGGATTTGGTAATGATGGAAGCAATTTGGAGGTGGTGGAAGCTACAATGATTTTGGGAATTACAACAATCAGTCTTCAAATTTTGGACCCCTAGGAGGAAATTTTGGTAGAAGCTCTGGCCCCATGGCGGTGGAGGCCAAATACTTTTGCAAACCACGAAACCAAGGTGGCTATGGCGGTCCAGCAGCAGCAGTAGCTATGGCAGTGGCAGAAGATTTTAATTAGGAAACAAAGCTTAGCAGGAGAGGAGAGCCAGAGAAGTGACAGGGAAGTACAGGTTACAACAGATTTGTGAACTCAGCCCAAGCACAGTGGTGGCAGGGCCTAGCTGCTACAAAGAAGACATGTTTTAGACAAATACTCATGTGTATGGGCAAAACTTGAGGACTGTATTTGTGACTAACTGTATAACAGGTTATTTTAGTTTCTGTTTGTGGAAAGTGTAAAGCATTCCAACAAAGGTTTTTAATGTAGATTTTTTTTTTTGCACCCCATGCTGTTGATTTGCTAAATGTAACAGTCTGATCGTGACGCTGAATAAATGTCTTTTTTAAAAAAAAAAAAAAGCTCCCTCCCATCCCCTGCTGCTAACTGATCCCATTATATCTAACCTGCCCCCCCATATCACCTGCTCCCGAGCTACCTAAGAACAGCTAAAAGAGCACACCCGCATGTAGCAAAATAGTGGGAAGATTATAGGTAGAGGCGACAAACCTACCGAGCCTGGTGATAGCTGGTTGTCCTAGATAGAATCTTAGTTCAACTTTAAATTTGCCCACAGAACCCTCTAAATCCCCTTGTAAATTTAACTGTTAGTCCAAAGAGGAACAGCTCTTTGGACACTAGGAAAAAACCTTGTAGAGAGTAAAAAATCAACACCCA\n"
    "+\n"
    ".0==?SSSSSSSSSSSH2216<868;SSSSSSSSSQQSRSIIHEDDESSSSSSJIKMGEKISSJJICCBDQ?;;8:;,**(&$'+501)\"#$()+%&&0<5+*/('%'))))'''$##\"\"\"\"%&--$\"\"\"('%)1L3*'')'#\"#&+*$&\"\"#*(&'''+,,<;9<BHGF//.LKORQSK<###%*-89<FSSSSE=BAFHFDB???3313NN?>=ANOSJDCADHGMOQSSD=7>BRRSPIEEEOQSSQ4->LIC7EE045///03IIJQSSSNGE6('.5??@A@=,,EGRSPKJ<==<556GFLLQRANSSSSSSSSG...*%%%(***(%'3@LOOSSSSM...7BCMMSSSSSSSSSSSSSSSDFIPSSSGGGGPOQLIHIL4103HMSILLNOSSSSSSSSSS22CBCGSHHHHSSSSSSSSD??@<<<:DDDSSSSSSSSSSA@6688OSSSSSROJJKLSNNNMSSSSQPOOSOOQSSSSSRRHIHISSRSSSSSSSSSSSJFF=??@SSQRK:424<444FFG///1S@@@ASNNNNPN:4JMDDLPSSSSSSBA?B?@@+'&'BD**8EDEFQPIMLE$$&',79CSJJPSGA+***DN;3-('&(;>6(()/-,,)%')1FRNNJ-:=>GC;&;CHNFFDCEEKJLFA22/27A.....HSQLHL))8<=?JSSSFGSKIHDDCCEFDAA@CFJKLNL>:9/1>>?OSLK@+HPSA;>>>K;;;;SSSSOQLPPMORSSSSSQSSSSSSS=:9**?D889SSRFFEDKJJJEEDKSSSNNOSSS.---,&*++SSSSQRSSSSQPGED<<89<@GJ999:SSKBBBAJHK=SSSJJKNMGHKKHQA<<>OPKFEAACDHJKMORB/)'((6**)15DA99;JSQSSS2())+J))EGMQOMMKJF>?<<AA620..D..,/112SOIIJSQFNEEEOMF?066=>@4,3;B>87FSSSSSSSSSSSSSSS<<::5658@AHMMSSRECC448/=<<>SSCB:5546;<??KF==;;FFEDFHKKJG):C>=>BJHINJFDPPPPPPPPPPPPPP%'*%$%+-%'(-22&&%('''&&&#\"\"%&'+0,,0;:1&\"\"%'(+++8'**(\"$$#&$'**//.3497$\"3CFHLOSSSSR:887:;;FSSRPRSSS4433$#$%&$$-056>@:;>=@?AHEFEC;*EKMSSRSRRDB>=AFRSSSSBSOOPSMDAABHH976951-9DHPQO/---?@ELSSQSRJHKKBKKLSSLINSOSSQSRIMSSSSSS>?MKIINSSGSSSSSSSQQMK544MJKKNKHGGLFFGBDB?EHIKGD?@DHPPIIF555)&(+,ADSSSSRQSSSQSS=9/0JJMSQSOSSO/97=B@=:>"
)

# Styles for the header and footer HTML blocks.
_CUSTOM_CSS = """
.header { text-align: center; margin-bottom: 30px; }
.footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #666; }
"""


def create_gradio_app():
    """Create the DeepChopper Gradio Blocks app.

    The layout has a header, an input column (FASTQ textbox, file upload,
    "Analyze" button), an output column (JSON intervals and highlighted
    sequence), an example record and a footer. Clicking "Analyze" calls
    ``process_input`` with the textbox and file values.

    Returns:
        The unlaunched ``gr.Blocks`` instance.
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=_CUSTOM_CSS) as demo:
        gr.HTML(
            """
            <div class="header">
            <h1>🧬 DeepChopper: DNA Sequence Analysis</h1>
            <p>Analyze DNA sequences and detect Adapter sequences</p>
            </div>
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                # The parser needs a full 4-line FASTQ record, not a bare sequence.
                text_input = gr.Textbox(
                    label="Input FASTQ Record",
                    placeholder="Paste a single FASTQ record here (4 lines: @id, sequence, +, quality)...",
                    lines=10,
                )
                file_input = gr.File(label="Or upload a FASTQ file (only the first record is analyzed)")
                submit_btn = gr.Button("Analyze", variant="primary")

            with gr.Column(scale=1):
                json_output = gr.JSON(label="Detected Adapter Regions")
                highlighted_output = gr.HighlightedText(label="Highlighted Sequence")

        submit_btn.click(
            fn=process_input,
            inputs=[text_input, file_input],
            outputs=[json_output, highlighted_output],
        )

        gr.Examples(
            examples=[[_EXAMPLE_FASTQ]],
            inputs=[text_input],
        )

        gr.HTML(
            """
            <div class="footer">
            <p>DeepChopper - Powered by AI for DNA sequence analysis</p>
            </div>
            """
        )

    return demo
def main():
    """Entry point: build the DeepChopper Gradio UI and start serving it."""
    create_gradio_app().launch()


if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    main()
|