|
from transformers import pipeline |
|
from rcsbsearchapi import TextQuery, AttributeQuery, Query |
|
from rcsbsearchapi.search import Sort, SequenceQuery |
|
import os |
|
from dotenv import load_dotenv |
|
from shiny import App, render, ui, reactive |
|
import pandas as pd |
|
import warnings |
|
import re |
|
from UniprotKB_P_Sequence_RCSB_API_test import ProteinQuery, ProteinSearchEngine |
|
import plotly.graph_objects as go |
|
from shinywidgets import output_widget, render_widget |
|
# Suppress all Python warnings for the whole process (every category,
# every module), not just for this file.
warnings.filterwarnings('ignore')

# Load environment variables from a .env file (python-dotenv). Nothing in the
# code visible here reads them (``os`` is imported but unused) -- presumably
# an imported module needs them; verify before removing.
load_dotenv()
|
|
|
|
|
|
|
|
|
class PDBSearchAssistant:
    """Translate a natural-language query into an RCSB PDB search and run it.

    A local text2text model (``google/flan-t5-large`` by default) proposes
    search parameters -- resolution cutoff, sequence, PDB ID and
    experimental method -- from the query. A small model can copy its
    few-shot examples or invent values, so each proposed parameter is only
    used when the query text itself supports it.
    """

    # Keywords that mark a query as resolution-related, mapped to the
    # comparison they imply: "less" -> resolution <= cutoff, "greater" ->
    # resolution >= cutoff, None -> flags the query without a direction.
    # If several directional keywords match, the one listed last wins.
    RESOLUTION_TERMS = {
        'better': 'less',
        'best': 'less',
        'highest': 'less',
        'good': 'less',
        'fine': 'less',
        'worse': 'greater',
        'worst': 'greater',
        'lowest': 'greater',
        'poor': 'greater',
        'resolution': None,
        'å': None,
        'angstrom': None,
        'than': None,
        'under': 'less',
        'below': 'less',
        'above': 'greater',
        'over': 'greater',
    }

    # Whole-word (optionally plural) patterns for the ASCII keywords, so that
    # "over" does not fire on "recoverin" and stripping "best" does not mangle
    # "bestrophin". The "å" unit is matched anywhere because it is usually
    # glued to a number ("2.5å"), where there is no word boundary.
    _TERM_PATTERNS = {
        term: re.compile(
            rf'\b{re.escape(term)}s?\b' if term.isascii() else re.escape(term)
        )
        for term in RESOLUTION_TERMS
    }

    # A standalone number, optionally followed by an Ångström unit ("2.5",
    # "3 å", "1.8 angstroms"). The look-arounds ignore digits that belong to
    # names such as "p53", "cdk2" or "4hhb".
    _NUMBER_RE = re.compile(r'(?<![\w.])\d+(?:\.\d+)?\s*(?:å|angstroms?)?(?!\w)')

    # One "Label: value" field of the model output. A value ends at a newline,
    # at the end of the text, or where the next label starts, so output that
    # puts every field on a single line is parsed as well.
    _FIELD_RE = re.compile(
        r'(Resolution|Sequence|PDB_ID|Method)[ \t]*:[ \t]*'
        r'(.*?)(?=[ \t]*(?:Resolution|Sequence|PDB_ID|Method)[ \t]*:|\n|$)'
    )

    # PDB identifier shape: a digit followed by three alphanumerics.
    _PDB_ID_RE = re.compile(r'\b[0-9][A-Za-z0-9]{3}\b')

    # The 20 standard amino-acid one-letter codes.
    _AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')

    # Shortest sequence that is submitted to the sequence search.
    MIN_SEQUENCE_LENGTH = 25

    # Values of the RCSB ``rcsb_entry_info.experimental_method`` attribute,
    # each with a pattern recognising that method in free text (used on both
    # the model's label and the user's query). The model's short labels
    # (X-RAY / EM / NMR) are not valid ``exptl.method`` values, which are
    # full names such as "X-RAY DIFFRACTION".
    _METHOD_PATTERNS = {
        'X-ray': re.compile(r'\bx[- ]?ray\b', re.IGNORECASE),
        'EM': re.compile(r'\b(?:cryo[- ]?)?em\b|electron\s+microscop', re.IGNORECASE),
        'NMR': re.compile(r'\bnmr\b', re.IGNORECASE),
    }

    def __init__(self, model_name="google/flan-t5-large"):
        """Load the parameter-extraction model and the few-shot prompt.

        Args:
            model_name: Hugging Face model id for a text2text-generation
                pipeline; run on CPU.
        """
        self.pipe = pipeline(
            "text2text-generation",
            model=model_name,
            max_new_tokens=512,
            temperature=0.3,
            torch_dtype="auto",
            device="cpu",
        )

        # Few-shot prompt; ``{query}`` is filled in by ``search_pdb``.
        self.prompt_template = """
Extract specific search parameters from the query, if present:
1. Resolution cutoff (in Å)
2. Sequence information
3. Specific PDB ID
4. Experimental method (X-RAY, EM, NMR)

Format:
Resolution: [maximum resolution in Å, if mentioned]
Sequence: [any sequence mentioned]
PDB_ID: [specific PDB ID if mentioned]
Method: [experimental method if mentioned]

Examples:
Query: "Find X-ray structures better than 2.5Å resolution"
Resolution: 2.5
Sequence: none
PDB_ID: none
Method: X-RAY

Query: "Show me NMR structures of kinases"
Resolution: none
Sequence: none
PDB_ID: none
Method: NMR

Now analyze:
Query: {query}
"""

    def search_pdb(self, query):
        """Search the RCSB PDB for structures matching a free-text query.

        Criteria, combined with AND:
          * a >= 25-residue sequence in the query -> 100 % identity sequence
            search; otherwise the query (minus resolution wording) is matched
            as a phrase against ``struct.title``;
          * a resolution cutoff, only when the query mentions resolution;
          * an experimental method, only when the query names it.
        A PDB ID that appears in the query replaces the sequence/title and
        resolution criteria.

        Args:
            query: Free-text request, e.g.
                "Human hemoglobin C resolution better than 2.5Å".

        Returns:
            A list of ``{'PDB ID': identifier}`` dicts, one per hit. Returns
            ``[]`` for a blank query, when no criterion can be derived, or on
            any error (the error is printed).
        """
        try:
            # Nothing to search for; skip the (slow) model call entirely.
            if not query or not query.strip():
                return []

            formatted_prompt = self.prompt_template.format(query=query)
            response = self.pipe(formatted_prompt)[0]['generated_text']
            print("Generated parameters:", response)
            fields = self._parse_model_fields(response)

            query_lower = query.lower()
            has_resolution_query, resolution_direction = self._resolution_intent(query_lower)
            resolution_limit = (
                self._parse_resolution(fields.get('Resolution'))
                if has_resolution_query else None
            )
            # A sequence written directly in the query beats the model's guess.
            sequence = (
                self._sequence_from_query(query)
                or self._sequence_from_model(fields.get('Sequence'), query)
            )
            pdb_id = self._pdb_id_from_model(fields.get('PDB_ID'), query)
            method = self._method_from_model(fields.get('Method'), query)

            if sequence and len(sequence) < self.MIN_SEQUENCE_LENGTH:
                print(f"Warning: Sequence must be at least {self.MIN_SEQUENCE_LENGTH} "
                      "residues long. Skipping sequence search.")
                sequence = None

            queries = []
            if sequence:
                print(f"Adding sequence search with identity 100% for sequence: {sequence}")
                queries.append(SequenceQuery(
                    sequence,
                    identity_cutoff=1.0,
                    evalue_cutoff=1,
                    sequence_type="protein",
                ))
            else:
                # Also reached when a too-short sequence was rejected above,
                # so such queries still fall back to a title search.
                clean_query = self._clean_text_query(query_lower, has_resolution_query)
                print("Cleaned query:", clean_query)
                if clean_query:
                    queries.append(AttributeQuery(
                        attribute="struct.title",
                        operator="contains_phrase",
                        value=clean_query,
                    ))

            if resolution_limit and has_resolution_query:
                comparison = "less_or_equal" if resolution_direction == "less" else "greater_or_equal"
                print(f"Adding resolution filter: {comparison} {resolution_limit}Å")
                queries.append(AttributeQuery(
                    attribute="rcsb_entry_info.resolution_combined",
                    operator=comparison,
                    value=resolution_limit,
                ))

            if pdb_id:
                # An explicit ID is the whole search; drop the other criteria.
                print(f"Searching for specific PDB ID: {pdb_id}")
                queries = [AttributeQuery(
                    attribute="rcsb_id",
                    operator="exact_match",
                    value=pdb_id,
                )]

            if method:
                print(f"Adding experimental method filter: {method}")
                queries.append(AttributeQuery(
                    attribute="rcsb_entry_info.experimental_method",
                    operator="exact_match",
                    value=method,
                ))

            if not queries:
                return []

            final_query = queries[0]
            for extra in queries[1:]:
                final_query = final_query & extra
            print("Final query:", final_query)

            session = final_query.exec()
            results = []
            try:
                for entry in session:
                    # Entries may be plain identifier strings or objects
                    # exposing ``.identifier``; accept both.
                    identifier = entry if isinstance(entry, str) else entry.identifier
                    results.append({'PDB ID': identifier})
            except Exception as e:
                # Keep the hits collected before the failure.
                print(f"Error processing results: {str(e)}")

            print(f"Found {len(results)} structures")
            return results

        except Exception as e:
            print(f"Error during search: {str(e)}")
            print(f"Error type: {type(e)}")
            return []

    @classmethod
    def _parse_model_fields(cls, response):
        """Return ``{label: value}`` for each field the model actually filled in.

        Values of "none", "n/a" or "" are skipped; a later occurrence of a
        label overrides an earlier one.
        """
        fields = {}
        for match in cls._FIELD_RE.finditer(response):
            value = match.group(2).strip()
            if value.lower() not in ('none', 'n/a', ''):
                fields[match.group(1)] = value
        return fields

    @classmethod
    def _resolution_intent(cls, query_lower):
        """Return ``(mentions_resolution, direction)`` for a lower-cased query.

        ``direction`` is "less" unless a keyword implies "greater".
        """
        mentioned = False
        direction = "less"
        for term, implied in cls.RESOLUTION_TERMS.items():
            if cls._TERM_PATTERNS[term].search(query_lower):
                mentioned = True
                if implied:
                    direction = implied
        if cls._NUMBER_RE.search(query_lower):
            mentioned = True
        return mentioned, direction

    @staticmethod
    def _parse_resolution(value):
        """Return the first number in ``value`` as a float, or None if there is none."""
        if not value:
            return None
        match = re.search(r'\d*\.?\d+', value)
        return float(match.group()) if match else None

    @classmethod
    def _sequence_from_query(cls, query):
        """Return the first query token that looks like a protein sequence, else None.

        A token qualifies when it is at least ``MIN_SEQUENCE_LENGTH`` long,
        uses only amino-acid letters, and is more than 80 % upper case.
        """
        for word in query.split():
            if (len(word) >= cls.MIN_SEQUENCE_LENGTH
                    and set(word.upper()) <= cls._AMINO_ACIDS
                    and sum(c.isupper() for c in word) / len(word) > 0.8):
                return word
        return None

    @classmethod
    def _sequence_from_model(cls, value, query):
        """Return the model's sequence upper-cased, or None if it is not in ``query``.

        The candidate must consist of amino-acid letters and occur in the
        query (whitespace and case ignored), which rejects invented sequences.
        """
        if not value:
            return None
        candidate = ''.join(value.split()).upper()
        if not set(candidate) <= cls._AMINO_ACIDS:
            return None
        if candidate not in ''.join(query.split()).upper():
            return None
        return candidate or None

    @classmethod
    def _pdb_id_from_model(cls, value, query):
        """Return the model's PDB ID upper-cased, or None unless it is well-formed and in ``query``."""
        if not value:
            return None
        match = cls._PDB_ID_RE.search(value)
        if not match:
            return None
        pdb_id = match.group().upper()
        if not re.search(rf'\b{re.escape(pdb_id)}\b', query, re.IGNORECASE):
            return None
        return pdb_id

    @classmethod
    def _method_from_model(cls, value, query):
        """Map the model's method label to an RCSB value, or None.

        Returns None when the label is unrecognised or the query does not
        mention that method itself.
        """
        if not value:
            return None
        for canonical, pattern in cls._METHOD_PATTERNS.items():
            if pattern.search(value) and pattern.search(query):
                return canonical
        return None

    @classmethod
    def _clean_text_query(cls, query_lower, has_resolution_query):
        """Strip resolution numbers and keywords (if relevant) and collapse whitespace."""
        clean = query_lower
        if has_resolution_query:
            clean = cls._NUMBER_RE.sub('', clean)
            for pattern in cls._TERM_PATTERNS.values():
                clean = pattern.sub('', clean)
        return ' '.join(clean.split())
|
|
|
def _format_resolution(value):
    """Format a resolution with two decimals, or "N/A" if it is missing or non-numeric."""
    try:
        return f"{float(value):.2f}"
    except (TypeError, ValueError):
        return "N/A"


def pdbsummary(name, max_resolution=5.0):
    """Build a plain-text summary of the structures ``ProteinSearchEngine`` finds for ``name``.

    Args:
        name: Value passed to ``ProteinQuery`` (the server passes a PDB ID).
        max_resolution: Resolution cutoff forwarded to ``ProteinQuery``;
            defaults to the previously hard-coded 5.0.

    Returns:
        One numbered entry per structure with its ID, resolution, method,
        title, release date, sequence length and sequence; "" when the search
        returns nothing. A missing resolution is shown as "N/A" and a missing
        sequence as empty instead of raising.
    """
    search_engine = ProteinSearchEngine()
    query = ProteinQuery(name, max_resolution=max_resolution)
    results = search_engine.search(query)

    parts = []
    for i, structure in enumerate(results, 1):
        sequence = structure.sequence or ""
        parts.append(f"\n{i}. PDB ID : {structure.pdb_id}\n")
        parts.append(f"\nResolution : {_format_resolution(structure.resolution)} A \n")
        parts.append(f"Method : {structure.method}\n Title : {structure.title}\n")
        parts.append(f"Release Date : {structure.release_date}\n Sequence length: {len(sequence)} aa\n")
        parts.append(f" Sequence:\n {sequence}\n")
    return "".join(parts)
|
|
|
def create_interactive_table(df):
    """Render a DataFrame as a Plotly table figure.

    Args:
        df: Rows to display. Cell values are passed through unchanged (the
            server puts HTML links in the "PDB ID" column).

    Returns:
        A ``go.Figure`` containing one ``go.Table`` (400 px high, no
        margins), or an empty ``go.Figure`` when ``df`` is empty.
    """
    if df.empty:
        return go.Figure()

    # Relative column widths from the longest cell text. Values are
    # stringified first so numeric or NaN cells do not break ``len``; the
    # floor of 1 keeps an all-empty column from getting zero width.
    column_widths = [
        max(int(df[col].astype(str).str.len().max()), 1)
        for col in df.columns
    ]

    table = go.Figure(data=[go.Table(
        header=dict(
            values=list(df.columns),
            fill_color='paleturquoise',
            align='left',
            font=dict(size=14),
        ),
        cells=dict(
            values=[df[col] for col in df.columns],
            align='left',
            font=dict(size=13),
            height=30,
        ),
        columnwidth=column_widths,
    )])

    table.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        height=400,
        autosize=True,
    )
    return table
|
|
|
|
|
# Page layout: query box with example queries, a search button, a summary of
# the applied search, and the results table with a CSV download button.
app_ui = ui.page_fluid(
    ui.tags.head(
        ui.tags.style("""
            .table a {
                color: #0d6efd;
                text-decoration: none;
            }
            .table a:hover {
                color: #0a58ca;
                text-decoration: underline;
            }
        """)
    ),
    ui.h2("Advanced PDB Structure Search Tool"),
    ui.row(
        ui.column(
            12,
            ui.input_text("query", "Search Query", value="Human insulin"),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.p("Example queries:"),
            ui.tags.ul(
                ui.tags.li("Human hemoglobin C resolution better than 2.5Å"),
                ui.tags.li("Find structures containing sequence MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYKNL"),
            ),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.input_action_button("search", "Search", class_="btn-primary"),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.h4("Search Parameters:"),
            ui.output_text("search_conditions"),
        )
    ),
    ui.row(
        ui.column(
            12,
            ui.h4("Top 10 Results:"),
            output_widget("results_table"),
            ui.download_button("download", "Download Results"),
        )
    ),
)
|
|
|
def server(input, output, session):
    """Shiny server: run a search on each click and render the results.

    The full hit list drives the result count and the CSV download. Only
    the first ``max_displayed`` hits are shown in the table, matching the
    "Top 10 Results" heading, because each displayed row triggers a
    ``pdbsummary`` lookup.
    """
    max_displayed = 10
    assistant = PDBSearchAssistant()
    results_store = reactive.Value([])
    # Text of the last executed search. The conditions panel reports this,
    # not whatever is currently typed in the box.
    searched_query = reactive.Value("")

    def _summary_link(pdb_id):
        """Return an HTML link to the RCSB 3D view of ``pdb_id``, labelled with its summary."""
        try:
            label = pdbsummary(pdb_id) or pdb_id
        except Exception as e:
            # One failed lookup should not take down the whole table.
            print(f"Error summarising {pdb_id}: {str(e)}")
            label = pdb_id
        return f'<a href="https://www.rcsb.org/3d-view/{pdb_id}" target="_blank">{label}</a>'

    @reactive.Effect
    @reactive.event(input.search)
    def _():
        query = input.query()
        searched_query.set(query)
        results_store.set(assistant.search_pdb(query=query))

    # Registered once at session start (not inside the click handler) and
    # re-rendered whenever ``results_store`` changes.
    @output
    @render_widget
    def results_table():
        df = pd.DataFrame(results_store.get()[:max_displayed])
        if not df.empty:
            df['PDB ID'] = df['PDB ID'].apply(_summary_link)
        return create_interactive_table(df)

    @output
    @render.text
    def search_conditions():
        results = results_store.get()
        return f"""
Applied Search Conditions:
- Query: {searched_query.get()}
- Total structures found: {len(results)}
"""

    @output
    @render.download(filename="pdb_search_results.csv")
    def download():
        # A *returned* string is treated as a file path by render.download;
        # the CSV text must be yielded as the download content.
        yield pd.DataFrame(results_store.get()).to_csv(index=False)
|
|
|
# Shiny application combining the page layout and the server function.
app = App(app_ui, server)

if __name__ == "__main__":
    # nest_asyncio patches asyncio to allow nested event loops -- presumably
    # so the app can also be started where a loop is already running (e.g. a
    # notebook); confirm whether it is still needed.
    import nest_asyncio

    nest_asyncio.apply()
    # Listen on all network interfaces, port 7860.
    app.run(host="0.0.0.0", port=7860)