from transformers import pipeline
from rcsbsearchapi import TextQuery, AttributeQuery, Query
from rcsbsearchapi.search import Sort, SequenceQuery
import os
from dotenv import load_dotenv
from shiny import App, render, ui, reactive
import pandas as pd
import warnings
import re
from  UniprotKB_P_Sequence_RCSB_API_test import ProteinQuery, ProteinSearchEngine
import plotly.graph_objects as go
from shinywidgets import output_widget, render_widget
# Install a blanket 'ignore' filter so no Python warning is shown for the
# lifetime of the process (this also hides deprecation notices from the
# imported libraries).
warnings.filterwarnings('ignore')

# Load environment variables from .env file
load_dotenv()

class PDBSearchAssistant:
    """Turn a free-text request into an RCSB PDB search.

    A text2text model is asked to extract structured hints (resolution,
    sequence, PDB ID, experimental method) from the request.  Those hints are
    cross-checked against the request itself and combined into a single
    AND-ed RCSB query.
    """

    # Whole-word keywords that signal a resolution constraint.  The value is
    # the comparison direction the keyword implies, or None when it only
    # signals that resolution is being discussed.
    _RESOLUTION_TERMS = {
        'better': 'less',
        'best': 'less',
        'highest': 'less',
        'good': 'less',
        'fine': 'less',
        'worse': 'greater',
        'worst': 'greater',
        'lowest': 'greater',
        'poor': 'greater',
        'resolution': None,
        'å': None,
        'angstrom': None,
        'angstroms': None,
        'than': None,
        'under': 'less',
        'below': 'less',
        'above': 'greater',
        'over': 'greater'
    }

    # A whole token that is a resolution value: a number with an optional
    # Å/angstrom unit, or a decimal followed by a plain "a" (e.g. "2.5a").
    _RESOLUTION_VALUE = re.compile(r'\d+(?:\.\d+)?(?:å|angstroms?)?|\d+\.\d+a')

    # Punctuation trimmed from the edges of a token before it is classified.
    _TOKEN_PUNCTUATION = '.,;:!?()[]{}<>="\''

    # "Label: value" pairs in the model output.  A value runs until the next
    # label or the end of the text, so single-line and multi-line answers are
    # both understood.
    _FIELD_PATTERN = re.compile(
        r'(Resolution|Sequence|PDB_ID|Method)\s*:\s*(.*?)'
        r'(?=\s*(?:Resolution|Sequence|PDB_ID|Method)\s*:|\s*$)',
        re.DOTALL | re.IGNORECASE,
    )

    # Classic four-character PDB entry ID: a digit followed by three
    # alphanumerics.
    _PDB_ID_PATTERN = re.compile(r'[0-9][A-Za-z0-9]{3}')

    _AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')
    _MIN_SEQUENCE_LENGTH = 25

    # The prompt asks the model for short method names, but RCSB's
    # `exptl.method` attribute stores full names, so an exact match on the
    # short form finds nothing.  "NMR" maps to solution NMR, the common case.
    _METHOD_ALIASES = {
        'X-RAY': 'X-RAY DIFFRACTION',
        'XRAY': 'X-RAY DIFFRACTION',
        'X-RAY CRYSTALLOGRAPHY': 'X-RAY DIFFRACTION',
        'EM': 'ELECTRON MICROSCOPY',
        'CRYO-EM': 'ELECTRON MICROSCOPY',
        'CRYOEM': 'ELECTRON MICROSCOPY',
        'NMR': 'SOLUTION NMR',
    }

    def __init__(self, model_name="google/flan-t5-large"):
        """Load the text2text pipeline and define the extraction prompt.

        Args:
            model_name: HuggingFace model identifier passed to ``pipeline``.
        """
        # Set up HuggingFace pipeline with better model
        self.pipe = pipeline(
            "text2text-generation",
            model=model_name,
            max_new_tokens=512,
            temperature=0.3,
            torch_dtype="auto",
            device="cpu"
        )
        
        self.prompt_template = """
            Extract specific search parameters from the query, if present:
            1. Resolution cutoff (in Å)
            2. Sequence information
            3. Specific PDB ID
            4. Experimental method (X-RAY, EM, NMR)

            Format:
            Resolution: [maximum resolution in Å, if mentioned]
            Sequence: [any sequence mentioned]
            PDB_ID: [specific PDB ID if mentioned]
            Method: [experimental method if mentioned]

            Examples:
            Query: "Find X-ray structures better than 2.5Å resolution"
            Resolution: 2.5
            Sequence: none
            PDB_ID: none
            Method: X-RAY

            Query: "Show me NMR structures of kinases"
            Resolution: none
            Sequence: none
            PDB_ID: none
            Method: NMR

            Now analyze:
            Query: {query}
            """

    @classmethod
    def _detect_resolution_intent(cls, query):
        """Return ``(has_resolution_query, direction)`` for the user query.

        Keywords are matched as whole tokens, so words that merely contain
        one ("overexpressed", "ethanol") are not treated as resolution terms.
        Digits embedded in names ("p53", "HIV-1") are ignored for the same
        reason.
        """
        tokens = {
            token.strip(cls._TOKEN_PUNCTUATION)
            for token in query.lower().split()
        }
        has_resolution_query = False
        direction = "less"
        # Iterate the table (not the query) so that, when several directional
        # keywords are present, the one listed last in the table wins.
        for term, term_direction in cls._RESOLUTION_TERMS.items():
            if term in tokens:
                has_resolution_query = True
                if term_direction:
                    direction = term_direction
        if any(cls._RESOLUTION_VALUE.fullmatch(token) for token in tokens):
            has_resolution_query = True
        return has_resolution_query, direction

    @classmethod
    def _strip_resolution_terms(cls, text):
        """Drop resolution keywords and values from *text*, token by token."""
        kept = []
        for token in text.split():
            core = token.strip(cls._TOKEN_PUNCTUATION)
            if core in cls._RESOLUTION_TERMS or cls._RESOLUTION_VALUE.fullmatch(core):
                continue
            kept.append(token)
        return ' '.join(kept)

    @classmethod
    def _parse_llm_response(cls, response):
        """Extract the labelled fields from the model output.

        Returns a dict with keys ``resolution``, ``sequence``, ``pdb_id`` and
        ``method``.  A field is None when it is absent, empty, or given as
        "none" / "n/a".
        """
        params = {'resolution': None, 'sequence': None, 'pdb_id': None, 'method': None}
        for label, raw_value in cls._FIELD_PATTERN.findall(response):
            value = raw_value.strip()
            if value.lower() in ('', 'none', 'n/a'):
                continue
            params[label.lower()] = value
        return params

    @classmethod
    def _find_sequence_in_query(cls, query):
        """Return the first word of *query* that looks like a protein sequence.

        A word qualifies when it has at least 25 characters, uses only the 20
        standard amino-acid letters, and is more than 80% upper case.
        """
        for word in query.split():
            candidate = word.strip(cls._TOKEN_PUNCTUATION)
            if (len(candidate) >= cls._MIN_SEQUENCE_LENGTH and
                    all(c in cls._AMINO_ACIDS for c in candidate.upper()) and
                    sum(c.isupper() for c in candidate) / len(candidate) > 0.8):
                return candidate
        return None

    def search_pdb(self, query):
        """Search the RCSB PDB for structures matching a free-text *query*.

        Args:
            query: The user's request, e.g. "Human hemoglobin C resolution
                better than 2.5Å".

        Returns:
            A list of ``{'PDB ID': <identifier>}`` dicts.  The list is empty
            when nothing matches or when any step fails; failures are printed
            rather than raised.
        """
        try:
            # Get search parameters from LLM
            formatted_prompt = self.prompt_template.format(query=query)
            response = self.pipe(formatted_prompt)[0]['generated_text']
            print("Generated parameters:", response)

            has_resolution_query, resolution_direction = self._detect_resolution_intent(query)
            params = self._parse_llm_response(response)

            # Only trust a model-supplied resolution when the query itself
            # talks about resolution.
            resolution_limit = None
            if has_resolution_query and params['resolution']:
                number = re.search(r'\d+(?:\.\d+)?', params['resolution'])
                if number:
                    resolution_limit = float(number.group())

            method = params['method']
            if method:
                method = self._METHOD_ALIASES.get(method.upper(), method.upper())

            # A sequence typed in the query takes precedence over whatever
            # the model extracted.
            sequence = self._find_sequence_in_query(query) or params['sequence']
            if sequence:
                sequence = ''.join(sequence.split())
                if len(sequence) < self._MIN_SEQUENCE_LENGTH:
                    print("Warning: Sequence must be at least 25 residues long. Skipping sequence search.")
                    sequence = None
                elif not set(sequence.upper()) <= self._AMINO_ACIDS:
                    print("Warning: Sequence contains non-amino-acid characters. Skipping sequence search.")
                    sequence = None

            # Build search query
            queries = []

            if sequence:
                print(f"Adding sequence search with identity 100% for sequence: {sequence}")
                sequence_query = SequenceQuery(
                    sequence,
                    identity_cutoff=1.0,  # 100% identity
                    evalue_cutoff=1,
                    sequence_type="protein"
                )
                queries.append(sequence_query)
            else:
                # No usable sequence (including one that was just rejected):
                # fall back to a title search on the cleaned query text.
                clean_query = query.lower()
                if has_resolution_query:
                    clean_query = self._strip_resolution_terms(clean_query)
                clean_query = ' '.join(clean_query.split())

                print("Cleaned query:", clean_query)

                if clean_query:
                    text_query = AttributeQuery(
                        attribute="struct.title",
                        operator="contains_phrase",
                        value=clean_query
                    )
                    queries.append(text_query)

            # Add resolution filter if specified
            if resolution_limit:
                operator = "less_or_equal" if resolution_direction == "less" else "greater_or_equal"
                print(f"Adding resolution filter: {operator} {resolution_limit}Å")
                resolution_query = AttributeQuery(
                    attribute="rcsb_entry_info.resolution_combined",
                    operator=operator,
                    value=resolution_limit
                )
                queries.append(resolution_query)

            # Add PDB ID search if specified.  The ID replaces every other
            # filter, so accept it only when it is well-formed and the user
            # actually typed it.
            pdb_id = params['pdb_id']
            if pdb_id:
                candidate = pdb_id.strip('[]"\' ').upper()
                if self._PDB_ID_PATTERN.fullmatch(candidate) and candidate.lower() in query.lower():
                    print(f"Searching for specific PDB ID: {candidate}")
                    id_query = AttributeQuery(
                        attribute="rcsb_id",
                        operator="exact_match",
                        value=candidate
                    )
                    queries = [id_query]  # Override other queries for direct PDB ID search
                else:
                    print(f"Ignoring PDB ID not present in the query: {pdb_id}")

            # Add experimental method filter if specified
            if method:
                print(f"Adding experimental method filter: {method}")
                method_query = AttributeQuery(
                    attribute="exptl.method",
                    operator="exact_match",
                    value=method
                )
                queries.append(method_query)

            if not queries:
                return []

            # Combine queries with AND operator
            final_query = queries[0]
            for q in queries[1:]:
                final_query = final_query & q

            print("Final query:", final_query)

            # Execute search
            session = final_query.exec()
            results = []

            try:
                for entry in session:
                    # Entries are handled both as plain ID strings and as
                    # objects exposing `.identifier`.
                    identifier = entry if isinstance(entry, str) else entry.identifier
                    results.append({'PDB ID': identifier})
            except Exception as e:
                # Best effort: keep the IDs collected before the failure.
                print(f"Error processing results: {str(e)}")

            print(f"Found {len(results)} structures")
            return results

        except Exception as e:
            # Top-level boundary for the UI: report the problem and return
            # an empty result instead of crashing the app.
            print(f"Error during search: {str(e)}")
            print(f"Error type: {type(e)}")
            return []

def pdbsummary(name):
    """Return a plain-text report of the structures found for *name*.

    A ProteinQuery for *name* (with max_resolution 5.0) is run through a new
    ProteinSearchEngine.  Each result contributes one numbered section with
    its PDB ID, resolution, method, title, release date, sequence length and
    sequence.  The report is an empty string when there are no results.
    """
    engine = ProteinSearchEngine()
    hits = engine.search(ProteinQuery(name, max_resolution=5.0))

    # Collect one formatted section per hit, then join them in one pass.
    sections = []
    for rank, hit in enumerate(hits, start=1):
        sections.append(
            f"\n{rank}. PDB ID : {hit.pdb_id}\n"
            f"\nResolution : {hit.resolution:.2f} A \n"
            f"Method : {hit.method}\n Title : {hit.title}\n"
            f"Release Date : {hit.release_date}\n Sequence length: {len(hit.sequence)} aa\n"
            f"    Sequence:\n {hit.sequence}\n"
        )

    return "".join(sections)

def create_interactive_table(df):
    """Render *df* as a Plotly table figure.

    Args:
        df: DataFrame to display; every column becomes a table column.

    Returns:
        A ``go.Figure`` holding a single ``go.Table`` trace, or an empty
        ``go.Figure`` when *df* has no rows or no columns.
    """
    if df.empty:
        return go.Figure()

    columns = list(df.columns)

    # Relative column widths follow the longest rendered cell per column.
    # Cells are measured via str() so numeric or missing values do not raise
    # TypeError, as calling len() on them directly would.
    column_widths = [
        max(len(str(value)) for value in df[col])
        for col in columns
    ]

    # Create interactive table
    table = go.Figure(data=[go.Table(
        header=dict(
            values=columns,
            fill_color='paleturquoise',
            align='left',
            font=dict(size=14),
        ),
        cells=dict(
            values=[df[col] for col in columns],
            align='left',
            font=dict(size=13),
            height=30
        ),
        columnwidth=column_widths
    )])

    # Update table layout
    table.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        height=400,
        autosize=True
    )

    return table

# Shiny UI: a single-column page made of full-width rows.

# CSS that colours the links rendered inside elements with class "table".
_TABLE_LINK_CSS = """
            .table a {
                color: #0d6efd;
                text-decoration: none;
            }
            .table a:hover {
                color: #0a58ca;
                text-decoration: underline;
            }
        """


def _full_width_row(*children):
    """Wrap *children* in a row containing one 12-unit-wide column."""
    return ui.row(ui.column(12, *children))


app_ui = ui.page_fluid(
    ui.tags.head(ui.tags.style(_TABLE_LINK_CSS)),
    ui.h2("Advanced PDB Structure Search Tool"),
    # Free-text query box, pre-filled with a sample query.
    _full_width_row(
        ui.input_text("query", "Search Query", value="Human insulin"),
    ),
    # Static list of example queries.
    _full_width_row(
        ui.p("Example queries:"),
        ui.tags.ul(
            ui.tags.li("Human hemoglobin C resolution better than 2.5Å"),
            ui.tags.li("Find structures containing sequence MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYKNL"),
        ),
    ),
    # Button whose clicks trigger the search on the server side.
    _full_width_row(
        ui.input_action_button("search", "Search", class_="btn-primary"),
    ),
    # Text summary of the executed search.
    _full_width_row(
        ui.h4("Search Parameters:"),
        ui.output_text("search_conditions"),
    ),
    # Result table widget and CSV download button.
    _full_width_row(
        ui.h4("Top 10 Results:"),
        output_widget("results_table"),
        ui.download_button("download", "Download Results"),
    ),
)

def server(input, output, session):
    """Shiny server function wiring the search assistant to the UI.

    Creates one PDBSearchAssistant per session, runs a search whenever the
    "search" button is clicked, and exposes the results as a table widget, a
    text summary and a CSV download.
    """
    assistant = PDBSearchAssistant()
    # Raw search results (list of {'PDB ID': ...} dicts) from the last search.
    results_store = reactive.Value([])

    @reactive.Effect
    @reactive.event(input.search)
    def _():
        results = assistant.search_pdb(query=input.query())
        results_store.set(results)

        # Convert results to DataFrame and add hyperlinks
        df = pd.DataFrame(results)
        if not df.empty:
            df['PDB ID'] = df['PDB ID'].apply(
                lambda x: f'<a href="https://www.rcsb.org/3d-view/{x}" target="_blank">{x}</a>'
            )

        @output
        @render_widget
        def results_table():
            # NOTE: rows seem to come back sorted by ID, not by relevance
            # rank -- to be confirmed.
            return create_interactive_table(df)

    @output
    @render.text
    def search_conditions():
        results = results_store.get()
        return f"""
        Applied Search Conditions:
        - Query: {input.query()}
        - Total structures found: {len(results)}
        """

    @output
    @render.download(filename="pdb_search_results.csv")
    def download():
        # Fixing the column list keeps the "PDB ID" header in the file even
        # when there are no results.
        df = pd.DataFrame(results_store.get(), columns=['PDB ID'])
        # A string *returned* from a download handler is treated as a path
        # to a file on disk, so the CSV text must be yielded instead.
        yield df.to_csv(index=False)

# Application object combining the page layout with its server logic.
app = App(ui=app_ui, server=server)

if __name__ == "__main__":
    # Imported only when run as a script; presumably needed so app.run() can
    # start inside an environment that already has a running event loop.
    from nest_asyncio import apply as _patch_event_loop

    _patch_event_loop()
    # Listen on all network interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")