C2MV commited on
Commit
199e65d
·
verified ·
1 Parent(s): 5b7b502

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -37
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, EsmForTokenClassification
4
  import time
5
  from functools import wraps
6
  import sys
@@ -23,25 +23,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
23
  if device == "cpu":
24
  print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
25
 
26
- # Definir el mapeo de clases
27
- class_mapping = {
28
- 0: 'Not Binding Site',
29
- 1: 'Binding Site',
30
- }
31
-
32
  # Cargar el modelo y el tokenizador
33
- model_name = "AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor"
34
 
35
  try:
36
  print("Cargando el tokenizador...")
37
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
38
  except ValueError as e:
39
  print(f"Error al cargar el tokenizador: {e}")
40
  sys.exit(1)
41
 
42
  try:
43
- print("Cargando el modelo de predicción...")
44
- model = EsmForTokenClassification.from_pretrained(model_name)
45
  model.to(device)
46
  except Exception as e:
47
  print(f"Error al cargar el modelo: {e}")
@@ -49,21 +43,26 @@ except Exception as e:
49
 
50
  @spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
51
  @medir_tiempo
52
- def predecir_sitios_arn(secuencias):
53
  """
54
- Función que predice sitios de unión de ARN para las secuencias de proteínas proporcionadas.
55
  """
56
  try:
57
  if not secuencias.strip():
58
- return "Por favor, ingresa una o más secuencias válidas."
59
 
60
  # Separar las secuencias por líneas y eliminar espacios vacíos
61
- secuencias_lista = [seq.strip() for seq in secuencias.strip().split('\n') if seq.strip()]
62
  resultados = []
63
 
64
  for seq in secuencias_lista:
 
 
 
 
 
65
  # Tokenizar la secuencia
66
- inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=1290, return_tensors="pt")
67
  input_ids = inputs["input_ids"].to(device)
68
  attention_mask = inputs["attention_mask"].to(device)
69
 
@@ -71,55 +70,72 @@ def predecir_sitios_arn(secuencias):
71
  with torch.no_grad():
72
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
73
 
 
 
 
 
 
 
74
  # Obtener las predicciones seleccionando la clase con el logit más alto
75
  predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
76
 
77
- # Convertir las predicciones a etiquetas
78
- predicted_labels = [class_mapping.get(pred, "Unknown") for pred in predictions]
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- # Emparejar cada residuo con su etiqueta predicha
81
- residue_to_label = list(zip(list(seq), predicted_labels))
82
 
83
  # Formatear el resultado para mostrarlo en la interfaz
84
  secuencia_resultado = []
85
- for i, (residue, label) in enumerate(residue_to_label):
86
- # Omite los residuos 'PAD' que se agregan durante el padding
87
- if residue != 'PAD':
88
- secuencia_resultado.append(f"Posición {i+1} - {residue}: {label}")
89
-
90
  resultados.append("\n".join(secuencia_resultado))
91
 
 
92
  return "\n\n".join(resultados)
93
 
94
  except Exception as e:
95
  print(f"Error durante la predicción: {e}")
96
- return f"Error al predecir los sitios de ARN: {e}"
97
 
98
  # Definir la interfaz de Gradio
99
- titulo = "ESM-2 para Predicción de Sitios de Unión de ARN"
100
  descripcion = (
101
- "Ingresa una o más secuencias de proteínas (una por línea) y obtén predicciones de sitios de unión de ARN para cada residuo."
102
- " El modelo utilizado es ESM-2, entrenado en el dataset 'S1' de sitios de unión proteína-ARN."
103
  )
104
 
105
  iface = gr.Interface(
106
- fn=predecir_sitios_arn,
107
  inputs=gr.Textbox(
108
  lines=10,
109
- placeholder="Escribe tus secuencias de proteínas aquí, una por línea...",
110
- label="Secuencias de Proteínas"
111
  ),
112
- outputs=gr.Textbox(label="Predicciones de Sitios de Unión de ARN"),
113
  title=titulo,
114
  description=descripcion,
115
  examples=[
116
  [
117
- "VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK",
118
- "SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF"
119
  ],
120
  [
121
- "MKAILVVLLYTFATANADAVAHVAA",
122
- "GATVQAAEEVTQGVVVVEEVAGGAA"
123
  ]
124
  ],
125
  cache_examples=False,
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
  import time
5
  from functools import wraps
6
  import sys
 
23
  if device == "cpu":
24
  print("Advertencia: CUDA no está disponible. Se usará la CPU, lo que puede ser lento.")
25
 
 
 
 
 
 
 
26
  # Cargar el modelo y el tokenizador
27
+ model_name = "yangheng/OmniGenome"
28
 
29
  try:
30
  print("Cargando el tokenizador...")
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
32
  except ValueError as e:
33
  print(f"Error al cargar el tokenizador: {e}")
34
  sys.exit(1)
35
 
36
  try:
37
+ print("Cargando el modelo...")
38
+ model = AutoModel.from_pretrained(model_name)
39
  model.to(device)
40
  except Exception as e:
41
  print(f"Error al cargar el modelo: {e}")
 
43
 
44
  @spaces.GPU(duration=120) # Decorador para asignar GPU durante 120 segundos
45
  @medir_tiempo
46
+ def predecir_estructura_rna(secuencias):
47
  """
48
+ Función que predice estructuras secundarias de ARN a partir de secuencias de ARN proporcionadas.
49
  """
50
  try:
51
  if not secuencias.strip():
52
+ return "Por favor, ingresa una o más secuencias de ARN válidas."
53
 
54
  # Separar las secuencias por líneas y eliminar espacios vacíos
55
+ secuencias_lista = [seq.strip().upper() for seq in secuencias.strip().split('\n') if seq.strip()]
56
  resultados = []
57
 
58
  for seq in secuencias_lista:
59
+ # Validar la secuencia de ARN
60
+ if not all(residue in 'AUCG' for residue in seq):
61
+ resultados.append(f"Secuencia inválida: {seq}. Solo se permiten los nucleótidos A, U, C y G.")
62
+ continue
63
+
64
  # Tokenizar la secuencia
65
+ inputs = tokenizer(seq, return_tensors="pt")
66
  input_ids = inputs["input_ids"].to(device)
67
  attention_mask = inputs["attention_mask"].to(device)
68
 
 
70
  with torch.no_grad():
71
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
72
 
73
+ # Asumimos que el modelo devuelve logits para cada nucleótido que indican la estructura secundaria
74
+ # Debes ajustar esto según la arquitectura específica de OmniGenome
75
+
76
+ # Por ejemplo, supongamos que el modelo tiene una cabeza de clasificación con N etiquetas
77
+ # donde cada etiqueta representa una clase de estructura secundaria (e.g., Helix, Loop, etc.)
78
+
79
  # Obtener las predicciones seleccionando la clase con el logit más alto
80
  predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
81
 
82
+ # Definir el mapeo de clases según la documentación del modelo OmniGenome
83
+ # Este mapeo debe ajustarse a las clases específicas que OmniGenome predice
84
+ # Por ejemplo:
85
+ structure_mapping = {
86
+ 0: 'Helix',
87
+ 1: 'Loop',
88
+ 2: 'Bulge',
89
+ 3: 'Internal Loop',
90
+ # Agrega más clases si es necesario
91
+ }
92
+
93
+ # Convertir las predicciones numéricas a etiquetas legibles
94
+ predicted_structures = [structure_mapping.get(pred, "Unknown") for pred in predictions]
95
 
96
+ # Emparejar cada nucleótido con su etiqueta de estructura predicha
97
+ nucleotide_to_structure = list(zip(list(seq), predicted_structures))
98
 
99
  # Formatear el resultado para mostrarlo en la interfaz
100
  secuencia_resultado = []
101
+ for i, (nucleotide, structure) in enumerate(nucleotide_to_structure):
102
+ secuencia_resultado.append(f"Posición {i+1} - {nucleotide}: {structure}")
103
+
104
+ # Unir las predicciones en un solo string
 
105
  resultados.append("\n".join(secuencia_resultado))
106
 
107
+ # Unir los resultados de todas las secuencias separadas por dos saltos de línea
108
  return "\n\n".join(resultados)
109
 
110
  except Exception as e:
111
  print(f"Error durante la predicción: {e}")
112
+ return f"Error al predecir las estructuras de ARN: {e}"
113
 
114
  # Definir la interfaz de Gradio
115
+ titulo = "OmniGenome: Predicción de Estructuras Secundarias de ARN"
116
  descripcion = (
117
+ "Ingresa una o más secuencias de ARN (una por línea) y obtén predicciones de estructuras secundarias para cada nucleótido."
118
+ " El modelo utilizado es OmniGenome, un modelo de fundamentos basado en transformadores para alineación secuencia-estructura en tareas genómicas."
119
  )
120
 
121
  iface = gr.Interface(
122
+ fn=predecir_estructura_rna,
123
  inputs=gr.Textbox(
124
  lines=10,
125
+ placeholder="Escribe tus secuencias de ARN aquí, una por línea (solo A, U, C, G)...",
126
+ label="Secuencias de ARN"
127
  ),
128
+ outputs=gr.Textbox(label="Predicciones de Estructuras Secundarias de ARN"),
129
  title=titulo,
130
  description=descripcion,
131
  examples=[
132
  [
133
+ "AUGGCUACUUUCG",
134
+ "GCGCGAUCGACGUAGCUAGC"
135
  ],
136
  [
137
+ "AUAUGCGGUAUCGUACGUA",
138
+ "GGAUACGUGAUCGUAGCAGU"
139
  ]
140
  ],
141
  cache_examples=False,