File size: 6,323 Bytes
1ccf8ee
c48497c
4fb4636
56dc0d1
1ccf8ee
c48497c
 
 
1ccf8ee
c48497c
 
4fb4636
 
 
 
49be262
4fb4636
 
 
 
56dc0d1
4fb4636
c48497c
1ccf8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49be262
56dc0d1
1ccf8ee
 
 
 
 
56dc0d1
 
 
 
 
 
1ccf8ee
56dc0d1
 
49be262
1ccf8ee
 
49be262
 
1ccf8ee
 
 
 
 
 
49be262
1ccf8ee
 
 
 
 
49be262
1ccf8ee
 
 
 
 
49be262
1ccf8ee
49be262
 
 
 
56dc0d1
 
 
1ccf8ee
56dc0d1
1ccf8ee
 
56dc0d1
 
1ccf8ee
 
 
49be262
4fb4636
c48497c
 
 
4fb4636
 
 
 
 
 
 
 
 
 
 
c48497c
4fb4636
 
 
c48497c
 
 
1ccf8ee
 
 
 
 
 
 
 
cdd0fc4
 
1ccf8ee
 
 
 
 
 
49be262
1ccf8ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c48497c
 
1ccf8ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import mne
import numpy as np
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the instruction-tuned LLM once at import time; analyze_eeg reuses the
# module-level `tokenizer` and `model` for every request.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Permits executing modelling code shipped in the model repository.
    # NOTE(review): this runs remote code — confirm it is still required for
    # the installed transformers version.
    trust_remote_code=True,
    # Half precision to reduce memory. NOTE(review): if no GPU is available,
    # float16 on CPU may be slow or unsupported for some ops — confirm target.
    torch_dtype=torch.float16,
    # Let transformers/accelerate choose device placement for the weights.
    device_map="auto"
)

def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean PSD value inside the inclusive band [fmin, fmax].

    Args:
        psd: Power spectral density array whose LAST axis is frequency, e.g.
            ``(n_channels, n_freqs)`` or ``(n_epochs, n_channels, n_freqs)``.
            A 1-D spectrum is also accepted.
        freqs: 1-D array of frequencies matching the last axis of ``psd``.
        fmin: Lower band edge (inclusive), same units as ``freqs``.
        fmax: Upper band edge (inclusive), same units as ``freqs``.

    Returns:
        The mean over every in-band element (all channels pooled) as a Python
        float, or ``nan`` when no frequency bin falls inside the band.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        # Averaging an empty slice would warn and yield nan; return nan directly.
        return float("nan")
    # Index the last axis so 1-D, 2-D and epoched (3-D) spectra all work.
    return float(psd[..., freq_mask].mean())

def inspect_file(file):
    """Inspect an uploaded file and report which columns it offers.

    Args:
        file: The ``gr.File`` upload — either an object exposing the on-disk
            path as ``.name`` or the path itself (str / PathLike). ``None``
            when nothing was uploaded.

    Returns:
        A ``(message, columns, preview)`` tuple:
          * ``.fif``: an informational message, ``[]`` and a fixed note, since
            the sampling frequency is read from the file.
          * ``.csv``: a message, every column name (numeric or not) and a
            preview of the first rows (markdown when ``tabulate`` is
            installed, plain text otherwise).
          * unreadable CSV / other extensions: an error message, ``[]`` and a
            fallback preview string.
    """
    if file is None:
        return "No file uploaded.", [], "No preview available."
    # Gradio versions differ: some pass a tempfile wrapper, others a path string.
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == ".fif":
        # FIF files are MNE-native; there are no columns to choose from.
        return (
            "FIF file detected. No need for time column selection. Default sampling frequency will be read from file.",
            [],
            "FIF file doesn't require further inspection."
        )

    if file_ext == ".csv":
        # A handful of rows is enough to discover the column names.
        try:
            df = pd.read_csv(file_path, nrows=5)
        except Exception as e:  # UI boundary: report instead of crashing the app
            return f"Error reading CSV: {e}", [], "Could not read CSV preview."

        try:
            preview = df.to_markdown()
        except ImportError:
            # DataFrame.to_markdown depends on the optional `tabulate` package.
            preview = df.to_string()
        return (
            "CSV file detected. Select a time column if available, or leave it blank and specify a default frequency.",
            list(df.columns),
            preview
        )

    return "Unsupported file format.", [], "No preview available."


def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """Load an EEG recording from a FIF or CSV file into an MNE Raw object.

    Args:
        file_path: Path to a ``.fif`` or ``.csv`` file (extension matched
            case-insensitively).
        default_sfreq: Sampling rate in Hz used for CSV files when no time
            column is used, or when it has fewer than two usable stamps.
        time_col: Name of the CSV time column. Ignored when falsy or absent
            from the file. NOTE(review): values are assumed to be in seconds.

    Returns:
        ``mne.io.read_raw_fif(..., preload=True)`` for FIF files, or an
        ``mne.io.RawArray`` for CSV files in which every numeric non-time
        column becomes an ``'eeg'`` channel.

    Raises:
        ValueError: unsupported extension, a CSV without numeric channel
            columns, or a time column that is non-numeric or not increasing.
    """
    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if file_ext == '.csv':
        return _load_csv_raw(file_path, default_sfreq, time_col)
    raise ValueError("Unsupported file format. Provide a FIF or CSV file.")


def _load_csv_raw(file_path, default_sfreq, time_col):
    """Build an MNE RawArray from a CSV whose numeric columns are channels."""
    df = pd.read_csv(file_path)

    use_time = bool(time_col) and time_col in df.columns
    data_df = df.drop(columns=[time_col]) if use_time else df

    # Non-numeric columns (labels, markers, timestamps) cannot be channels.
    channel_cols = [
        col for col in data_df.columns
        if pd.api.types.is_numeric_dtype(data_df[col])
    ]
    if not channel_cols:
        raise ValueError("No numeric channel columns found in the CSV file.")

    if use_time:
        sfreq = _estimate_sfreq(df[time_col], time_col, default_sfreq)
    else:
        sfreq = default_sfreq

    # NOTE(review): MNE expects EEG in volts; CSV exports are often in
    # microvolts — confirm the units of the source data.
    data = data_df[channel_cols].to_numpy(dtype=float).T  # (n_channels, n_samples)
    ch_names = [str(col) for col in channel_cols]
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
    return mne.io.RawArray(data, info)


def _estimate_sfreq(time_values, time_col, default_sfreq):
    """Return ``1 / mean(step)`` of a numeric time column (pandas Series).

    Falls back to *default_sfreq* when fewer than two finite stamps exist;
    raises ValueError when the column is non-numeric or not increasing.
    """
    if not pd.api.types.is_numeric_dtype(time_values):
        raise ValueError(
            f"Time column {time_col!r} is not numeric; choose another column "
            "or leave it blank to use the default sampling frequency."
        )
    steps = np.diff(time_values.to_numpy(dtype=float))
    # Drop steps touching missing (NaN) stamps so they don't poison the mean.
    steps = steps[np.isfinite(steps)]
    if steps.size == 0:
        return default_sfreq
    mean_step = steps.mean()
    if mean_step <= 0:
        raise ValueError(
            f"Time column {time_col!r} must be increasing to derive a sampling frequency."
        )
    return float(1.0 / mean_step)

def analyze_eeg(file, default_sfreq, time_col):
    """Compute alpha/beta band power for an uploaded EEG file and ask the LLM to explain it.

    Args:
        file: The ``gr.File`` upload (object with ``.name`` or a path), or None.
        default_sfreq: Fallback sampling rate in Hz, as text from the UI.
        time_col: Optional CSV time column name selected in the dropdown.

    Returns:
        The LLM's interpretation (prompt text excluded), or a short message
        when no file was uploaded, the sampling rate is invalid, or the data
        could not be loaded/analysed.
    """
    if file is None:
        return "No file uploaded."

    try:
        sfreq = float(default_sfreq)
    except (TypeError, ValueError):
        return f"Invalid default sampling frequency: {default_sfreq!r}"
    if not (np.isfinite(sfreq) and sfreq > 0):
        return f"Invalid default sampling frequency: {default_sfreq!r}"

    # Gradio versions differ: some pass a tempfile wrapper, others a path string.
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name
    try:
        raw = load_eeg_data(file_path, default_sfreq=sfreq, time_col=time_col)
        psd, freqs = _compute_psd(raw, fmin=1, fmax=40)
    except (ValueError, OSError) as e:
        return f"Could not analyze EEG data: {e}"

    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)
    ratio = alpha_power / beta_power if beta_power > 0 else float("nan")

    # Only report what was measured; `.4g` keeps volt-scale PSDs (~1e-12)
    # from collapsing to 0.000.
    data_summary = (
        f"Alpha (8-12 Hz) power: {alpha_power:.4g}, "
        f"Beta (13-30 Hz) power: {beta_power:.4g}, "
        f"alpha/beta ratio: {ratio:.3g}."
    )

    prompt = f"""You are a neuroscientist analyzing EEG features.
Data Summary: {data_summary}

Provide a concise, user-friendly interpretation of these findings in simple terms.
"""
    return _generate_text(prompt)


def _compute_psd(raw, fmin, fmax):
    """Return ``(psd, freqs)`` from a Welch PSD of *raw* across MNE versions."""
    if hasattr(raw, "compute_psd"):
        # Modern MNE: Raw.compute_psd returns a Spectrum object.
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    # Older MNE only offers the functional API (removed in later releases).
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def _generate_text(prompt, max_new_tokens=200):
    """Sample a completion for *prompt* from the module-level LLM; return only the new text."""
    encoded = tokenizer(prompt, return_tensors="pt")
    # Pass only ids + mask: some tokenizers also emit token_type_ids, which
    # causal LMs such as Falcon reject in generate().
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; drop the echoed prompt tokens.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


#########################
# BUILD THE GRADIO INTERFACE
#########################

# Step 1: Inspect file
def preview_file(file):
    """Gradio callback for "Inspect File": refresh the message, dropdown and preview.

    Returns a 3-tuple matching ``outputs=[msg_output, cols_dropdown,
    preview_output]``. The dropdown slot is a ``gr.update`` so Gradio
    replaces the dropdown's choices with the file's columns and clears its
    selection. A plain dict is not recognised as an update and would be
    applied as the dropdown's value instead.
    """
    msg, cols, preview = inspect_file(file)
    return msg, gr.update(choices=cols, value=None), preview

# Build the UI. Components are created in the order they should appear, and
# event wiring refers to the component variables defined above each call.
with gr.Blocks() as demo:
    gr.Markdown("# NeuroNarrative-Lite: EEG Summary with Flexible Preprocessing")
    gr.Markdown(
        "Upload an EEG file (FIF or CSV). If it's CSV, we will inspect the file and let you choose a time column. "
        "If no suitable time column is found, leave it blank and provide a default sampling frequency."
    )

    # Step 1: upload and inspect; the dropdown is filled from the CSV columns.
    file_input = gr.File(label="Upload your EEG data (FIF or CSV)")
    preview_button = gr.Button("Inspect File")
    msg_output = gr.Markdown()
    cols_dropdown = gr.Dropdown(label="Select Time Column (optional)", interactive=True)
    preview_output = gr.Markdown()

    # preview_file returns (message, dropdown update, preview) in this order.
    preview_button.click(preview_file, inputs=[file_input], outputs=[msg_output, cols_dropdown, preview_output])

    # Step 2: analysis. The sampling rate arrives as text; analyze_eeg converts it.
    default_sfreq_input = gr.Textbox(label="Default Sampling Frequency (Hz) if no time column", value="256")
    analyze_button = gr.Button("Run Analysis")
    result_output = gr.Textbox(label="Analysis Summary")

    # Inputs are passed positionally to analyze_eeg(file, default_sfreq, time_col).
    analyze_button.click(analyze_eeg, 
                          inputs=[file_input, default_sfreq_input, cols_dropdown],
                          outputs=[result_output])

# Launch the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()