File size: 3,845 Bytes
c48497c
 
4fb4636
56dc0d1
c48497c
 
56dc0d1
c48497c
4fb4636
c48497c
 
4fb4636
 
 
 
 
 
 
 
 
 
 
56dc0d1
4fb4636
c48497c
56dc0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c48497c
56dc0d1
 
4fb4636
 
c48497c
4fb4636
 
c48497c
 
4fb4636
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c48497c
4fb4636
 
 
c48497c
4fb4636
c48497c
 
 
 
56dc0d1
c48497c
 
56dc0d1
 
 
c48497c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import mne
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Off-the-shelf instruction-tuned checkpoint, used as published (no fine-tuning).
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Options forwarded to `from_pretrained` when instantiating the model.
# NOTE(review): float16 weights are requested unconditionally; confirm this is
# acceptable on hosts where `device_map="auto"` ends up placing the model on CPU.
_model_load_options = {
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
    "device_map": "auto",  # let the loader pick CPU/GPU placement
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)

def compute_band_power(psd, freqs, fmin, fmax):
    """Compute the mean power inside the inclusive band ``[fmin, fmax]``.

    Args:
        psd: Array-like of power spectral density values whose LAST axis is
            frequency. Any number of leading axes is accepted (for example a
            single spectrum, ``(n_channels, n_freqs)``, or a stack of those).
        freqs: 1-D array-like with the frequency of each bin along the last
            axis of ``psd``.
        fmin: Lower band edge (inclusive).
        fmax: Upper band edge (inclusive).

    Returns:
        The mean of every PSD value whose frequency lies in the band, taken
        over all leading axes and the selected bins, as a Python ``float``.
        ``nan`` is returned when nothing is selected (no frequency bin falls
        inside the band, or ``psd`` has no rows).
    """
    # Accept plain lists as well as arrays.
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)

    freq_mask = (freqs >= fmin) & (freqs <= fmax)

    # `...` indexes the last (frequency) axis regardless of how many leading
    # axes there are, so the function is not tied to a 2-D layout.
    band_psd = psd[..., freq_mask]

    # Averaging an empty selection would emit a NumPy RuntimeWarning; return
    # the same NaN result explicitly instead.
    if band_psd.size == 0:
        return float("nan")
    return float(band_psd.mean())

def load_eeg_data(file_path):
    """Load EEG data from a FIF or CSV file into an MNE Raw object.

    The format is chosen from the (case-insensitive) file extension:

    * ``.fif`` is read with ``mne.io.read_raw_fif(..., preload=True)``.
    * ``.csv`` is read with pandas. It must contain a ``time`` column; every
      other column is treated as one EEG channel named after its header.
      The sampling frequency is estimated as ``1 / mean(diff(time))``, so the
      time column is assumed to be in seconds and roughly uniformly spaced.

    Args:
        file_path: Path to the recording.

    Returns:
        The Raw object produced by ``mne.io.read_raw_fif`` or
        ``mne.io.RawArray``.

    Raises:
        ValueError: If the extension is unsupported, the CSV has no ``time``
            column, has fewer than two rows, has no channel columns, contains
            non-numeric values, or its time column does not advance (average
            step is zero, negative, or not finite).
    """
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)

    if file_ext != '.csv':
        raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")

    df = pd.read_csv(file_path)
    if 'time' not in df.columns:
        raise ValueError("CSV must contain a 'time' column for timestamps.")

    # Converting to float up front turns non-numeric cells into a ValueError
    # here rather than an obscure failure further down.
    time = df['time'].to_numpy(dtype=float)
    if len(time) < 2:
        raise ValueError("Not enough time points in CSV.")

    # Simplistic estimate: one over the average time step. A constant,
    # decreasing, or NaN-containing time column would otherwise produce an
    # infinite, negative, or NaN sampling rate.
    mean_step = float(np.mean(np.diff(time)))
    if not np.isfinite(mean_step) or mean_step <= 0:
        raise ValueError(
            "CSV 'time' column must have a positive, finite average step (seconds)."
        )
    sfreq = 1.0 / mean_step

    channels = df.drop(columns=['time'])
    ch_names = list(channels.columns)
    if not ch_names:
        raise ValueError("CSV must contain at least one channel column besides 'time'.")

    # MNE expects (n_channels, n_samples).
    # NOTE(review): values are passed through unscaled; MNE treats EEG data as
    # volts, so confirm the CSV is not in microvolts.
    data = channels.to_numpy(dtype=float).T

    info = mne.create_info(
        ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names)
    )
    return mne.io.RawArray(data, info)

def process_eeg(file):
    """Summarize an uploaded EEG recording in plain language.

    Steps: load the recording, compute a Welch PSD between 1 and 40 Hz,
    average it into alpha (8-12 Hz) and beta (13-30 Hz) band powers, and ask
    the module-level language model to explain those numbers.

    Args:
        file: The upload handed over by the UI. Either a path string or an
            object exposing the path as ``.name``; ``None`` when nothing was
            uploaded.

    Returns:
        The text generated by the model (without the prompt), or a short
        message asking for a file when ``file`` is ``None``.

    Raises:
        ValueError: Propagated from ``load_eeg_data`` for unsupported or
            malformed files.
    """
    if file is None:
        return "No file uploaded. Please provide a FIF or CSV file."

    # Depending on the Gradio version/configuration the upload arrives as a
    # path string or as a tempfile-like wrapper with a `.name` attribute.
    file_path = getattr(file, "name", file)
    raw = load_eeg_data(file_path)

    # Compute PSD (Power Spectral Density) between 1 and 40 Hz. Prefer the
    # `Raw.compute_psd` method; fall back to the legacy function on MNE
    # installations that predate it.
    if hasattr(raw, "compute_psd"):
        spectrum = raw.compute_psd(method="welch", fmin=1, fmax=40)
        psd, freqs = spectrum.get_data(return_freqs=True)
    else:
        psd, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=40)

    # Compute simple band powers
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)

    # Describe only what was measured. Scientific notation is used because
    # PSD values of volt-scaled EEG are tiny and would all print as 0.000.
    data_summary = (
        f"Alpha power (8-12 Hz): {alpha_power:.3e}, "
        f"Beta power (13-30 Hz): {beta_power:.3e}."
    )
    if beta_power > 0 and np.isfinite(alpha_power):
        data_summary += f" Alpha/beta power ratio: {alpha_power / beta_power:.2f}."

    # Prepare the prompt for the language model
    prompt = f"""You are a neuroscientist analyzing EEG features.
Data Summary: {data_summary}

Provide a concise, user-friendly interpretation of these findings in simple terms.
"""

    # Tokenize with an attention mask and pass both tensors explicitly, so any
    # extra fields the tokenizer may return are not forwarded to `generate`.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=200,  # budget for the answer only, independent of prompt length
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )

    # A causal LM returns prompt + continuation; keep only the continuation so
    # the user does not see the prompt echoed back.
    generated_tokens = outputs[0][input_ids.shape[-1]:]
    summary = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    return summary

# Blurb shown under the title in the web UI.
_description = (
    "Upload EEG data in FIF (MNE native) or CSV format. "
    "The system extracts basic EEG features and generates "
    "a human-readable summary using an open-source language model."
)

# Single file-upload widget feeding `process_eeg`.
_upload_input = gr.File(label="Upload your EEG data (FIF or CSV)")

# Gradio app: one uploaded file in, one block of text out.
iface = gr.Interface(
    fn=process_eeg,
    inputs=_upload_input,
    outputs="text",
    title="NeuroNarrative-Lite: EEG Summary",
    description=_description,
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()