C2MV commited on
Commit
32f8cf2
verified
1 Parent(s): f4870a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -65
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModel
4
  import time
5
  from functools import wraps
6
  import sys
7
- import spaces # Aseg煤rate de que este m贸dulo est茅 disponible y correctamente instalado
8
 
9
  # Decorador para medir el tiempo de ejecuci贸n
10
  def medir_tiempo(func):
@@ -24,118 +24,86 @@ if device == "cpu":
24
  print("Advertencia: CUDA no est谩 disponible. Se usar谩 la CPU, lo que puede ser lento.")
25
 
26
  # Cargar el modelo y el tokenizador
27
- model_name = "yangheng/OmniGenome"
28
 
29
  try:
30
  print("Cargando el tokenizador...")
31
- tokenizer = AutoTokenizer.from_pretrained(model_name)
32
  except ValueError as e:
33
  print(f"Error al cargar el tokenizador: {e}")
34
  sys.exit(1)
35
 
36
  try:
37
  print("Cargando el modelo...")
38
- model = AutoModel.from_pretrained(model_name)
39
  model.to(device)
40
  except Exception as e:
41
  print(f"Error al cargar el modelo: {e}")
42
  sys.exit(1)
43
 
44
- @spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
45
  @medir_tiempo
46
- def predecir_estructura_rna(secuencias):
47
  """
48
- Funci贸n que predice estructuras secundarias de ARN a partir de secuencias de ARN proporcionadas.
49
  """
50
  try:
51
  if not secuencias.strip():
52
- return "Por favor, ingresa una o m谩s secuencias de ARN v谩lidas."
53
 
54
  # Separar las secuencias por l铆neas y eliminar espacios vac铆os
55
  secuencias_lista = [seq.strip().upper() for seq in secuencias.strip().split('\n') if seq.strip()]
56
  resultados = []
57
 
 
 
 
58
  for seq in secuencias_lista:
59
- # Validar la secuencia de ARN
60
- if not all(residue in 'AUCG' for residue in seq):
61
- resultados.append(f"Secuencia inv谩lida: {seq}. Solo se permiten los nucle贸tidos A, U, C y G.")
62
  continue
63
 
64
- # Tokenizar la secuencia
65
- inputs = tokenizer(seq, return_tensors="pt")
66
- input_ids = inputs["input_ids"].to(device)
67
- attention_mask = inputs["attention_mask"].to(device)
68
-
69
- # Aplicar el modelo para obtener los logits
70
- with torch.no_grad():
71
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
72
-
73
- # Asumimos que el modelo devuelve logits para cada nucle贸tido que indican la estructura secundaria
74
- # Debes ajustar esto seg煤n la arquitectura espec铆fica de OmniGenome
75
-
76
- # Por ejemplo, supongamos que el modelo tiene una cabeza de clasificaci贸n con N etiquetas
77
- # donde cada etiqueta representa una clase de estructura secundaria (e.g., Helix, Loop, etc.)
78
-
79
- # Obtener las predicciones seleccionando la clase con el logit m谩s alto
80
- predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
81
-
82
- # Definir el mapeo de clases seg煤n la documentaci贸n del modelo OmniGenome
83
- # Este mapeo debe ajustarse a las clases espec铆ficas que OmniGenome predice
84
- # Por ejemplo:
85
- structure_mapping = {
86
- 0: 'Helix',
87
- 1: 'Loop',
88
- 2: 'Bulge',
89
- 3: 'Internal Loop',
90
- # Agrega m谩s clases si es necesario
91
- }
92
-
93
- # Convertir las predicciones num茅ricas a etiquetas legibles
94
- predicted_structures = [structure_mapping.get(pred, "Unknown") for pred in predictions]
95
-
96
- # Emparejar cada nucle贸tido con su etiqueta de estructura predicha
97
- nucleotide_to_structure = list(zip(list(seq), predicted_structures))
98
 
99
- # Formatear el resultado para mostrarlo en la interfaz
100
- secuencia_resultado = []
101
- for i, (nucleotide, structure) in enumerate(nucleotide_to_structure):
102
- secuencia_resultado.append(f"Posici贸n {i+1} - {nucleotide}: {structure}")
103
 
104
- # Unir las predicciones en un solo string
105
- resultados.append("\n".join(secuencia_resultado))
106
 
107
- # Unir los resultados de todas las secuencias separadas por dos saltos de l铆nea
108
  return "\n\n".join(resultados)
109
 
110
  except Exception as e:
111
  print(f"Error durante la predicci贸n: {e}")
112
- return f"Error al predecir las estructuras de ARN: {e}"
113
 
114
  # Definir la interfaz de Gradio
115
- titulo = "OmniGenome: Predicci贸n de Estructuras Secundarias de ARN"
116
  descripcion = (
117
- "Ingresa una o m谩s secuencias de ARN (una por l铆nea) y obt茅n predicciones de estructuras secundarias para cada nucle贸tido."
118
- " El modelo utilizado es OmniGenome, un modelo de fundamentos basado en transformadores para alineaci贸n secuencia-estructura en tareas gen贸micas."
119
  )
120
 
121
  iface = gr.Interface(
122
- fn=predecir_estructura_rna,
123
  inputs=gr.Textbox(
124
  lines=10,
125
- placeholder="Escribe tus secuencias de ARN aqu铆, una por l铆nea (solo A, U, C, G)...",
126
- label="Secuencias de ARN"
127
  ),
128
- outputs=gr.Textbox(label="Predicciones de Estructuras Secundarias de ARN"),
129
  title=titulo,
130
  description=descripcion,
131
  examples=[
132
  [
133
- "AUGGCUACUUUCG",
134
- "GCGCGAUCGACGUAGCUAGC"
135
  ],
136
  [
137
- "AUAUGCGGUAUCGUACGUA",
138
- "GGAUACGUGAUCGUAGCAGU"
139
  ]
140
  ],
141
  cache_examples=False,
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import pipeline
4
  import time
5
  from functools import wraps
6
  import sys
7
+ from multimolecule import RnaTokenizer, RnaFmModel # Importar clases espec铆ficas de multimolecule
8
 
9
  # Decorador para medir el tiempo de ejecuci贸n
10
  def medir_tiempo(func):
 
24
  print("Advertencia: CUDA no est谩 disponible. Se usar谩 la CPU, lo que puede ser lento.")
25
 
26
  # Cargar el modelo y el tokenizador
27
+ model_name = "multimolecule/mrnafm"
28
 
29
  try:
30
  print("Cargando el tokenizador...")
31
+ tokenizer = RnaTokenizer.from_pretrained(model_name)
32
  except ValueError as e:
33
  print(f"Error al cargar el tokenizador: {e}")
34
  sys.exit(1)
35
 
36
  try:
37
  print("Cargando el modelo...")
38
+ model = RnaFmModel.from_pretrained(model_name)
39
  model.to(device)
40
  except Exception as e:
41
  print(f"Error al cargar el modelo: {e}")
42
  sys.exit(1)
43
 
 
44
  @medir_tiempo
45
+ def predecir_fill_mask(secuencias):
46
  """
47
+ Funci贸n que realiza una predicci贸n de Fill-Mask para las secuencias de ARN proporcionadas.
48
  """
49
  try:
50
  if not secuencias.strip():
51
+ return "Por favor, ingresa una o m谩s secuencias de ARN v谩lidas con <mask> para predecir."
52
 
53
  # Separar las secuencias por l铆neas y eliminar espacios vac铆os
54
  secuencias_lista = [seq.strip().upper() for seq in secuencias.strip().split('\n') if seq.strip()]
55
  resultados = []
56
 
57
+ # Crear el pipeline de fill-mask utilizando el tokenizador y modelo cargados
58
+ fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
59
+
60
  for seq in secuencias_lista:
61
+ # Asegurarse de que la secuencia contenga al menos un <mask>
62
+ if "<MASK>" not in seq and "<mask>" not in seq:
63
+ resultados.append(f"Secuencia sin token <mask>: {seq}. Agrega <mask> donde desees predecir.")
64
  continue
65
 
66
+ # Realizar la predicci贸n de fill-mask
67
+ predictions = fill_mask(seq)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Formatear las predicciones
70
+ pred_str = ""
71
+ for pred in predictions:
72
+ pred_str += f"Predicci贸n: {pred['sequence']}, Score: {pred['score']:.4f}\n"
73
 
74
+ resultados.append(f"Secuencia: {seq}\n{pred_str}")
 
75
 
 
76
  return "\n\n".join(resultados)
77
 
78
  except Exception as e:
79
  print(f"Error durante la predicci贸n: {e}")
80
+ return f"Error al realizar la predicci贸n: {e}"
81
 
82
  # Definir la interfaz de Gradio
83
+ titulo = "OmniGenome: Predicci贸n de Fill-Mask para Secuencias de ARN"
84
  descripcion = (
85
+ "Ingresa una o m谩s secuencias de ARN (una por l铆nea) con un token <mask> donde deseas realizar la predicci贸n."
86
+ " El modelo utilizado es mRNA-FM de MultiMolecule, un modelo pre-entrenado de lenguaje para secuencias de ARN."
87
  )
88
 
89
  iface = gr.Interface(
90
+ fn=predecir_fill_mask,
91
  inputs=gr.Textbox(
92
  lines=10,
93
+ placeholder="Escribe tus secuencias de ARN aqu铆, una por l铆nea, incluyendo <mask> donde desees predecir...",
94
+ label="Secuencias de ARN con <mask>"
95
  ),
96
+ outputs=gr.Textbox(label="Predicciones de Fill-Mask"),
97
  title=titulo,
98
  description=descripcion,
99
  examples=[
100
  [
101
+ "AUGGCUACUUU<mask>G",
102
+ "GCGCGAU<mask>CGACGUAGCUAGC"
103
  ],
104
  [
105
+ "AUAUGCGGUAUCGU<mask>GUA",
106
+ "GGAUACGUGAU<mask>GCUAGCAGU"
107
  ]
108
  ],
109
  cache_examples=False,