query / app.py
lkjjj26's picture
dockerfile
b89b36a
raw
history blame
14.7 kB
from transformers import pipeline
from rcsbsearchapi import TextQuery, AttributeQuery, Query
from rcsbsearchapi.search import Sort, SequenceQuery
import os
from dotenv import load_dotenv
from shiny import App, render, ui, reactive
import pandas as pd
import warnings
import re
from UniprotKB_P_Sequence_RCSB_API_test import ProteinQuery, ProteinSearchEngine
import plotly.graph_objects as go
from shinywidgets import output_widget, render_widget
# Suppress every warning for the whole process. NOTE(review): this also hides
# deprecation notices raised by the imported libraries.
warnings.filterwarnings('ignore')
# Load environment variables from .env file
load_dotenv()
# Disabled: redirecting the Transformers model cache to a local directory.
# os.environ["TRANSFORMERS_CACHE"] = "./transformers_cache"
# os.makedirs("./transformers_cache", exist_ok=True)
class PDBSearchAssistant:
    """Translate a free-text question into an RCSB PDB search.

    A text2text model extracts structured hints (resolution cutoff, sequence,
    PDB ID, experimental method) from the question. Those hints are combined
    with keyword heuristics applied to the raw question to build the query.
    """

    # Shortest sequence accepted for a sequence search.
    MIN_SEQUENCE_LENGTH = 25
    # The 20 standard amino-acid one-letter codes.
    AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

    # Words that signal a resolution constraint. The value is the comparison
    # direction the word implies, or None when it implies no direction.
    # Queries are lower-cased before matching, hence the lower-case a-ring.
    RESOLUTION_TERMS = {
        'better': 'less',
        'best': 'less',
        'highest': 'less',
        'good': 'less',
        'fine': 'less',
        'worse': 'greater',
        'worst': 'greater',
        'lowest': 'greater',
        'poor': 'greater',
        'resolution': None,
        '\u00e5': None,
        'angstrom': None,
        'angstroms': None,
        'than': None,
        'under': 'less',
        'below': 'less',
        'above': 'greater',
        'over': 'greater'
    }

    # The prompt asks the model for short method labels, but `exptl.method`
    # is matched exactly against the full method names used by RCSB.
    METHOD_ALIASES = {
        'X-RAY': ('X-RAY DIFFRACTION',),
        'XRAY': ('X-RAY DIFFRACTION',),
        'X-RAY CRYSTALLOGRAPHY': ('X-RAY DIFFRACTION',),
        'EM': ('ELECTRON MICROSCOPY',),
        'CRYO-EM': ('ELECTRON MICROSCOPY',),
        'CRYOEM': ('ELECTRON MICROSCOPY',),
        'NMR': ('SOLUTION NMR', 'SOLID-STATE NMR'),
    }

    # Whole-word matchers, so that e.g. "than" does not fire inside
    # "methanol" and "over" does not fire inside "overexpressed".
    _TERM_PATTERNS = {
        term: re.compile(r'\b' + re.escape(term) + r'\b')
        for term in RESOLUTION_TERMS
    }
    _ANY_TERM_RE = re.compile(
        r'\b(?:' + '|'.join(map(re.escape, RESOLUTION_TERMS)) + r')\b'
    )
    # A standalone number, optionally followed by an angstrom unit. Digits
    # embedded in a token (the "53" of "p53") are deliberately not matched.
    _NUMBER_RE = re.compile(
        r'(?<![\w.])\d+(?:\.\d+)?'
        r'(?P<unit>\s*(?:\u00e5|angstroms?)\b)?'
        r'(?!\w|\.\d)'
    )
    _FLOAT_RE = re.compile(r'\d+(?:\.\d+)?|\.\d+')
    # Classic four-character PDB identifier: a digit plus three alphanumerics.
    _PDB_ID_RE = re.compile(r'[0-9][A-Za-z0-9]{3}')
    # "Key: value" pairs, whether the model emits them on separate lines or
    # all on a single line.
    _FIELD_RE = re.compile(
        r'(?P<key>Resolution|Sequence|PDB_ID|Method):\s*'
        r'(?P<value>.*?)'
        r'(?=\s*(?:Resolution|Sequence|PDB_ID|Method):|\s*$)',
        re.DOTALL,
    )
    _EMPTY_VALUES = frozenset({'', 'none', 'n/a'})
    _WORD_PUNCTUATION = '.,;:!?()[]{}"\''

    def __init__(self, model_name="google/flan-t5-large"):
        """Load the text2text pipeline and prepare the extraction prompt.

        Args:
            model_name: Hugging Face model identifier passed to `pipeline`.
        """
        # Set up HuggingFace pipeline with better model
        self.pipe = pipeline(
            "text2text-generation",
            model=model_name,
            max_new_tokens=512,
            temperature=0.3,
            torch_dtype="auto",
            device="cpu"
        )
        self.prompt_template = (
            "\n"
            "Extract specific search parameters from the query, if present:\n"
            "1. Resolution cutoff (in \u00c5)\n"
            "2. Sequence information\n"
            "3. Specific PDB ID\n"
            "4. Experimental method (X-RAY, EM, NMR)\n"
            "Format:\n"
            "Resolution: [maximum resolution in \u00c5, if mentioned]\n"
            "Sequence: [any sequence mentioned]\n"
            "PDB_ID: [specific PDB ID if mentioned]\n"
            "Method: [experimental method if mentioned]\n"
            "Examples:\n"
            "Query: \"Find X-ray structures better than 2.5\u00c5 resolution\"\n"
            "Resolution: 2.5\n"
            "Sequence: none\n"
            "PDB_ID: none\n"
            "Method: X-RAY\n"
            "Query: \"Show me NMR structures of kinases\"\n"
            "Resolution: none\n"
            "Sequence: none\n"
            "PDB_ID: none\n"
            "Method: NMR\n"
            "Now analyze:\n"
            "Query: {query}\n"
        )

    @classmethod
    def _detect_resolution(cls, query):
        """Return `(has_resolution_query, direction)` for the raw question.

        `direction` is "less" or "greater"; when several directional words are
        present, the one listed last in `RESOLUTION_TERMS` wins. A bare number
        counts as a resolution only when it carries an angstrom unit.
        """
        query_lower = query.lower()
        has_resolution_query = False
        direction = "less"
        for term, pattern in cls._TERM_PATTERNS.items():
            if pattern.search(query_lower):
                has_resolution_query = True
                implied = cls.RESOLUTION_TERMS[term]
                if implied:
                    direction = implied
        if any(m.group('unit') for m in cls._NUMBER_RE.finditer(query_lower)):
            has_resolution_query = True
        return has_resolution_query, direction

    @classmethod
    def _parse_llm_response(cls, response):
        """Extract the four fields from the model reply.

        Returns a dict with keys 'resolution' (float or None), 'sequence',
        'pdb_id' and 'method' (str or None). Values of "none"/"n/a" and
        values that cannot be interpreted are reported as None.
        """
        fields = {}
        for match in cls._FIELD_RE.finditer(response):
            value = match.group('value').strip()
            if value.lower() not in cls._EMPTY_VALUES:
                fields[match.group('key')] = value

        parsed = {'resolution': None, 'sequence': None,
                  'pdb_id': None, 'method': None}

        number = cls._FLOAT_RE.search(fields.get('Resolution', ''))
        if number:
            parsed['resolution'] = float(number.group())

        candidate_id = fields.get('PDB_ID', '').strip(cls._WORD_PUNCTUATION + ' ')
        # Ignore anything that is not shaped like a PDB ID; otherwise a
        # hallucinated value would replace every other search criterion.
        if cls._PDB_ID_RE.fullmatch(candidate_id):
            parsed['pdb_id'] = candidate_id.upper()

        if 'Method' in fields:
            parsed['method'] = fields['Method'].upper()
        if 'Sequence' in fields:
            parsed['sequence'] = fields['Sequence']
        return parsed

    @classmethod
    def _find_sequence(cls, query):
        """Return the first word of `query` that looks like a protein sequence.

        A word qualifies when it has at least `MIN_SEQUENCE_LENGTH` letters,
        all of them standard amino-acid codes, and is more than 80% upper
        case. Surrounding punctuation is ignored. Returns None otherwise.
        """
        for raw_word in query.split():
            word = raw_word.strip(cls._WORD_PUNCTUATION)
            if len(word) < cls.MIN_SEQUENCE_LENGTH:
                continue
            if not set(word.upper()) <= cls.AMINO_ACIDS:
                continue
            if sum(c.isupper() for c in word) / len(word) > 0.8:
                return word
        return None

    @classmethod
    def _usable_sequence(cls, candidate):
        """Validate a model-reported sequence; return it cleaned up, or None."""
        if not candidate:
            return None
        compact = ''.join(candidate.split()).upper()
        if len(compact) < cls.MIN_SEQUENCE_LENGTH:
            print("Warning: Sequence must be at least 25 residues long. Skipping sequence search.")
            return None
        if not set(compact) <= cls.AMINO_ACIDS:
            print("Warning: Sequence contains non amino-acid characters. Skipping sequence search.")
            return None
        return compact

    @classmethod
    def _clean_text_query(cls, query, has_resolution_query):
        """Lower-case `query` and strip resolution numbers and keywords.

        Stripping only happens when `has_resolution_query` is true, and only
        whole words / standalone numbers are removed.
        """
        cleaned = query.lower()
        if has_resolution_query:
            cleaned = cls._NUMBER_RE.sub(' ', cleaned)
            cleaned = cls._ANY_TERM_RE.sub(' ', cleaned)
        return ' '.join(cleaned.split())

    @classmethod
    def _method_values(cls, method):
        """Map a short method label to the `exptl.method` value(s) to match."""
        return cls.METHOD_ALIASES.get(method, (method,))

    def search_pdb(self, query):
        """Run a PDB search for a free-text question.

        Args:
            query: The user's question.

        Returns:
            A list of `{'PDB ID': identifier}` dicts. An empty list is
            returned when nothing can be searched for or when any step fails;
            failures are printed rather than raised.
        """
        try:
            # Get search parameters from LLM
            formatted_prompt = self.prompt_template.format(query=query)
            response = self.pipe(formatted_prompt)[0]['generated_text']
            print("Generated parameters:", response)

            has_resolution_query, resolution_direction = self._detect_resolution(query)
            params = self._parse_llm_response(response)
            # The model's resolution is trusted only if the question itself
            # talks about resolution.
            resolution_limit = params['resolution'] if has_resolution_query else None
            pdb_id = params['pdb_id']
            method = params['method']
            # A sequence typed in the question takes precedence over the
            # model's guess.
            sequence = self._find_sequence(query) or self._usable_sequence(params['sequence'])

            # Build search query
            queries = []
            if sequence:
                print(f"Adding sequence search with identity 100% for sequence: {sequence}")
                queries.append(SequenceQuery(
                    sequence,
                    identity_cutoff=1.0,  # 100% identity
                    evalue_cutoff=1,
                    sequence_type="protein"
                ))
            else:
                # No usable sequence: fall back to a title search.
                clean_query = self._clean_text_query(query, has_resolution_query)
                print("Cleaned query:", clean_query)
                if clean_query:
                    queries.append(AttributeQuery(
                        attribute="struct.title",
                        operator="contains_phrase",
                        value=clean_query
                    ))

            # Add resolution filter if specified
            if resolution_limit:
                operator = "less_or_equal" if resolution_direction == "less" else "greater_or_equal"
                print(f"Adding resolution filter: {operator} {resolution_limit}\u00c5")
                queries.append(AttributeQuery(
                    attribute="rcsb_entry_info.resolution_combined",
                    operator=operator,
                    value=resolution_limit
                ))

            # Add PDB ID search if specified
            if pdb_id:
                print(f"Searching for specific PDB ID: {pdb_id}")
                # Override other queries for direct PDB ID search
                queries = [AttributeQuery(
                    attribute="rcsb_id",
                    operator="exact_match",
                    value=pdb_id
                )]

            # Add experimental method filter if specified
            if method:
                method_values = self._method_values(method)
                print(f"Adding experimental method filter: {', '.join(method_values)}")
                method_query = None
                for value in method_values:
                    alternative = AttributeQuery(
                        attribute="exptl.method",
                        operator="exact_match",
                        value=value
                    )
                    # Several accepted values are OR-ed together.
                    method_query = alternative if method_query is None else method_query | alternative
                queries.append(method_query)

            if not queries:
                return []

            # Combine queries with AND operator
            final_query = queries[0]
            for q in queries[1:]:
                final_query = final_query & q
            print("Final query:", final_query)

            # Execute search
            session = final_query.exec()
            results = []
            try:
                for entry in session:
                    # Entries may be plain identifier strings or objects
                    # exposing `.identifier`.
                    identifier = entry if isinstance(entry, str) else entry.identifier
                    results.append({'PDB ID': identifier})
            except Exception as e:
                # Keep whatever was collected before the failure.
                print(f"Error processing results: {str(e)}")
            print(f"Found {len(results)} structures")
            return results
        except Exception as e:
            print(f"Error during search: {str(e)}")
            print(f"Error type: {type(e)}")
            return []
def pdbsummary(name):
    """Build a plain-text report of the structures found for `name`.

    The search is run through `ProteinSearchEngine` with a `ProteinQuery`
    whose `max_resolution` is 5.0. Each hit contributes a numbered section
    listing its PDB ID, resolution, method, title, release date, sequence
    length and sequence. Returns an empty string when there are no hits.
    """
    engine = ProteinSearchEngine()
    hits = engine.search(ProteinQuery(name, max_resolution=5.0))

    # Collect the pieces and join once instead of growing a string.
    pieces = []
    for rank, hit in enumerate(hits, 1):
        pieces.append(f"\n{rank}. PDB ID : {hit.pdb_id}\n")
        pieces.append(f"\nResolution : {hit.resolution:.2f} A \n")
        pieces.append(f"Method : {hit.method}\n Title : {hit.title}\n")
        pieces.append(f"Release Date : {hit.release_date}\n Sequence length: {len(hit.sequence)} aa\n")
        pieces.append(f" Sequence:\n {hit.sequence}\n")
    return "".join(pieces)
def table_column_widths(df):
    """Return, for each column of `df`, the length of its longest cell as text.

    Cells are measured with `len(str(value))`, so numeric, None and other
    non-string values are handled instead of raising `TypeError`.
    """
    return [
        max((len(str(value)) for value in df[col]), default=0)
        for col in df.columns
    ]


def create_interactive_table(df):
    """Render `df` as a Plotly table figure.

    Args:
        df: DataFrame whose columns become the table columns.

    Returns:
        An empty `go.Figure` when `df` is empty, otherwise a figure holding a
        single `go.Table` whose relative column widths follow the longest
        cell of each column.
    """
    if df.empty:
        return go.Figure()

    # Create interactive table
    table = go.Figure(data=[go.Table(
        header=dict(
            values=list(df.columns),
            fill_color='paleturquoise',
            align='left',
            font=dict(size=14),
        ),
        cells=dict(
            values=[df[col] for col in df.columns],
            align='left',
            font=dict(size=13),
            height=30
        ),
        columnwidth=table_column_widths(df)
    )])

    # Update table layout
    table.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        height=400,
        autosize=True
    )
    return table
# Simplified Shiny app UI definition: a query box with example queries, a
# search button, a summary of the applied conditions, and the results table
# with a CSV download button.
app_ui = ui.page_fluid(
    ui.tags.head(
        ui.tags.style("""
            .table a {
                color: #0d6efd;
                text-decoration: none;
            }
            .table a:hover {
                color: #0a58ca;
                text-decoration: underline;
            }
        """)
    ),
    ui.h2("Advanced PDB Structure Search Tool"),
    ui.row(
        ui.column(
            12,
            ui.input_text("query", "Search Query", value="Human insulin"),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.p("Example queries:"),
            ui.tags.ul(
                # The angstrom sign was mis-encoded in the original example.
                ui.tags.li("Human hemoglobin C resolution better than 2.5\u00c5"),
                ui.tags.li("Find structures containing sequence MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYKNL"),
            ),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.input_action_button("search", "Search", class_="btn-primary"),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.h4("Search Parameters:"),
            ui.output_text("search_conditions"),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.h4("Top 10 Results:"),
            output_widget("results_table"),
            ui.download_button("download", "Download Results")
        )
    )
)
def server(input, output, session):
    """Shiny server function: wire the search button to the outputs.

    Clicking "Search" runs `PDBSearchAssistant.search_pdb` on the current
    query and stores the hits; the table, the summary text and the CSV
    download all read from that stored list.
    """
    # The UI labels the table "Top 10 Results"; capping the rows also bounds
    # the number of `pdbsummary` look-ups, which run once per displayed row.
    max_table_rows = 10

    assistant = PDBSearchAssistant()
    results_store = reactive.Value([])

    @reactive.Effect
    @reactive.event(input.search)
    def _():
        results_store.set(assistant.search_pdb(query=input.query()))

    # Registered once at server level (not inside the effect); it re-renders
    # whenever `results_store` changes.
    @output
    @render_widget
    def results_table():
        # Convert results to DataFrame and add hyperlinks
        df = pd.DataFrame(results_store.get()).head(max_table_rows)
        if not df.empty:
            df['PDB ID'] = df['PDB ID'].apply(
                lambda x: f'<a href="https://www.rcsb.org/3d-view/{x}" target="_blank">{pdbsummary(x)}</a>'
            )
        # Rows appear to come back ordered by ID, not by relevance rank.
        return create_interactive_table(df)

    @output
    @render.text
    def search_conditions():
        results = results_store.get()
        return f"""
Applied Search Conditions:
- Query: {input.query()}
- Total structures found: {len(results)}
"""

    @output
    @render.download(filename="pdb_search_results.csv")
    def download():
        # A download handler must yield the file content; a returned string
        # would be interpreted as the path of a file to send.
        yield pd.DataFrame(results_store.get()).to_csv(index=False)
# Application object picked up by the Shiny runner.
app = App(app_ui, server)


def main():
    """Serve the app on all interfaces, port 7860, when run as a script."""
    # Imported lazily: only needed when launching directly.
    import nest_asyncio

    nest_asyncio.apply()
    app.run(host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()