File size: 3,993 Bytes
c48497c
 
4fb4636
56dc0d1
c48497c
 
56dc0d1
c48497c
 
 
4fb4636
 
 
 
49be262
4fb4636
 
 
 
56dc0d1
4fb4636
c48497c
49be262
56dc0d1
49be262
 
 
 
 
56dc0d1
 
 
 
 
 
 
 
49be262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56dc0d1
 
 
 
 
 
 
 
 
49be262
 
4fb4636
c48497c
 
 
4fb4636
 
 
 
 
 
 
 
 
 
 
 
c48497c
4fb4636
 
 
c48497c
4fb4636
c48497c
 
 
 
49be262
 
 
 
 
c48497c
49be262
 
 
 
 
 
c48497c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr
import mne
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Model identifier handed to both `from_pretrained` calls below.
model_name = "tiiuae/falcon-7b-instruct"
# Tokenizer and model are created once at import time and reused by
# process_eeg() for every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): trust_remote_code=True presumably allows code shipped with
    # the model repository to execute locally — confirm the repo is trusted.
    trust_remote_code=True,
    # Request half-precision (float16) weights.
    torch_dtype=torch.float16,
    # NOTE(review): "auto" presumably delegates device placement to the
    # library — confirm against the deployment hardware.
    device_map="auto"
)

def compute_band_power(psd, freqs, fmin, fmax):
    """
    Return the mean PSD value inside the inclusive band [fmin, fmax].

    The average is taken over every element whose frequency bin falls in the
    band, i.e. over all leading axes (such as channels) and the selected bins.

    Args:
        psd: Array-like whose LAST axis is aligned with `freqs`. Any number of
            leading axes is accepted (1-D, channels x freqs, ...).
        freqs: 1-D array-like of frequency bin centres.
        fmin: Lower band edge (inclusive).
        fmax: Upper band edge (inclusive).

    Returns:
        The band average as a Python float, or NaN when no frequency bin lies
        inside the band (or `psd` has no elements).
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)

    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    # Ellipsis indexing keeps the 2-D behaviour while also accepting 1-D/N-D input.
    band_psd = psd[..., freq_mask]

    # Averaging an empty selection would emit a NumPy RuntimeWarning before
    # yielding NaN; return NaN explicitly so the outcome is intentional.
    if band_psd.size == 0:
        return float("nan")
    return float(band_psd.mean())

def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """
    Load EEG data from a FIF or CSV file into an MNE Raw object.

    - If FIF: read with ``mne.io.read_raw_fif`` (preloaded).
    - If CSV:
       * Non-numeric columns other than `time_col` are dropped.
       * If `time_col` is present, the sampling frequency is estimated as the
         inverse of the mean spacing between consecutive time values (the
         values are therefore treated as seconds).
       * Otherwise `default_sfreq` is used.
       * Every remaining column becomes one EEG channel.

    Args:
        file_path: Path of the file to read; the (case-insensitive) extension
            selects the reader.
        default_sfreq: Sampling frequency in Hz for CSV files without a
            `time_col` column.
        time_col: Name of the CSV column that holds the sample times.

    Returns:
        The object returned by ``mne.io.read_raw_fif`` (FIF) or an
        ``mne.io.RawArray`` built from the CSV columns.

    Raises:
        ValueError: If the extension is neither ``.fif`` nor ``.csv``, the
            time column has fewer than two samples or does not yield a
            positive, finite sampling frequency, `default_sfreq` is not a
            positive finite number, or no numeric channel column remains.
    """
    _, file_ext = os.path.splitext(file_path)
    file_ext = file_ext.lower()

    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if file_ext != '.csv':
        raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")

    df = pd.read_csv(file_path)

    # Keep the time column (whatever its dtype) plus every numeric column.
    kept_cols = [
        col for col in df.columns
        if col == time_col or pd.api.types.is_numeric_dtype(df[col])
    ]
    df = df[kept_cols]

    if time_col in df.columns:
        # Coerce so that unparsable entries become NaN and are rejected by the
        # finiteness check below instead of failing inside np.diff.
        time = pd.to_numeric(df[time_col], errors="coerce").to_numpy(dtype=float)
        data_df = df.drop(columns=[time_col])

        if len(time) < 2:
            raise ValueError("Not enough time points to estimate sampling frequency.")

        mean_step = float(np.mean(np.diff(time)))
        # A zero, negative or NaN step would give an infinite/negative/NaN
        # sampling frequency, which is meaningless downstream.
        if not np.isfinite(mean_step) or mean_step <= 0:
            raise ValueError(
                f"Cannot estimate sampling frequency from column '{time_col}': "
                "time values must be numeric and increasing."
            )
        sfreq = 1.0 / mean_step
    else:
        # No explicit time column: assume uniform sampling at default_sfreq.
        sfreq = float(default_sfreq)
        if not np.isfinite(sfreq) or sfreq <= 0:
            raise ValueError("Default sampling frequency must be a positive number.")
        data_df = df

    if data_df.shape[1] == 0:
        raise ValueError("No numeric channel columns found in the CSV file.")

    # Channels are all remaining columns; MNE expects (n_channels, n_samples).
    ch_names = [str(col) for col in data_df.columns]
    data = data_df.to_numpy(dtype=float).T

    info = mne.create_info(
        ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names)
    )
    return mne.io.RawArray(data, info)

def process_eeg(file, default_sfreq, time_col):
    """
    Turn an uploaded EEG recording into a short plain-language summary.

    The recording is loaded, a Welch power spectral density is computed over
    1-40 Hz, the PSD is averaged in the alpha (8-12 Hz) and beta (13-30 Hz)
    bands, and the language model is asked to explain those numbers.

    Args:
        file: The upload from the file input. Accepted either as a path
            string or as an object exposing the path through ``.name``
            (the form depends on the Gradio version/configuration).
        default_sfreq: Sampling frequency in Hz (number or numeric string)
            used for CSV files without a time column.
        time_col: Name of the CSV column that holds sample times.

    Returns:
        The text generated by the model, without the prompt.

    Raises:
        ValueError: If no file was provided, `default_sfreq` cannot be
            converted to float, or the file cannot be loaded as EEG data.
    """
    if file is None:
        raise ValueError("No file provided. Please upload a FIF or CSV file.")
    file_path = file if isinstance(file, str) else file.name

    raw = load_eeg_data(file_path, default_sfreq=float(default_sfreq), time_col=time_col)

    # `psd_welch` was removed from recent MNE releases in favour of
    # `Raw.compute_psd`; keep the legacy call only as a fallback.
    if hasattr(raw, "compute_psd"):
        spectrum = raw.compute_psd(method="welch", fmin=1, fmax=40)
        psd, freqs = spectrum.get_data(return_freqs=True)
    else:
        psd, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=40)

    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)

    # Report only measured quantities: the model must not be handed a
    # pre-written conclusion that is independent of the data.
    if beta_power > 0:
        ratio_text = f"{alpha_power / beta_power:.2f}"
    else:
        ratio_text = "undefined"

    # Scientific notation keeps very small PSD values (e.g. volt-scaled data)
    # from being rendered as 0.000.
    data_summary = (
        f"Alpha power: {alpha_power:.3e}, Beta power: {beta_power:.3e}. "
        f"Alpha/beta power ratio: {ratio_text}."
    )

    prompt = f"""You are a neuroscientist analyzing EEG features.
Data Summary: {data_summary}

Provide a concise, user-friendly interpretation of these findings in simple terms.
"""

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = encoded["input_ids"]
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=encoded["attention_mask"],
        # Budget applies to generated tokens only, independent of prompt length.
        max_new_tokens=200,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens; decode only what
    # the model added so the user does not see the prompt echoed back.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    summary = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return summary

# Gradio UI definition: three inputs wired to process_eeg, whose parameters
# (file, default_sfreq, time_col) follow the same order as the list below.
iface = gr.Interface(
    fn=process_eeg,
    inputs=[
        # The uploaded recording; load_eeg_data dispatches on its extension.
        gr.File(label="Upload your EEG data (FIF or CSV)"),
        # Entered as text; process_eeg converts it with float().
        gr.Textbox(label="Default Sampling Frequency if no time column (Hz)", value="256"),
        # Column name looked up in the CSV header by load_eeg_data.
        gr.Textbox(label="Time column name (if exists)", value="time")
    ],
    # The string returned by process_eeg is displayed as plain text.
    outputs="text",
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=(
        "Upload EEG data in FIF or CSV format. "
        "If CSV, either include a 'time' column or specify a default sampling frequency. "
        "Non-numeric columns will be removed (except the chosen time column)."
    )
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()