# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import mne | |
import numpy as np | |
import pandas as pd | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import os | |
# Instruction-tuned Falcon checkpoint that turns the EEG band-power numbers
# into a plain-language summary. Both objects are created once, at import
# time, and reused by every call to ``process_eeg``.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Let transformers/accelerate decide where to place the weights.
    device_map="auto",
    # Half-precision weights need half the memory of float32 ones.
    torch_dtype=torch.float16,
    # NOTE(review): this executes modeling code downloaded from the Hub repo;
    # recent transformers releases support Falcon natively, so this flag may
    # be removable — confirm against the installed transformers version.
    trust_remote_code=True,
)
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean spectral power inside the inclusive band [fmin, fmax].

    Parameters
    ----------
    psd : array-like
        Power spectral density whose LAST axis is frequency, e.g.
        (n_channels, n_freqs) or (n_epochs, n_channels, n_freqs).
    freqs : array-like
        1-D frequencies (Hz) matching the last axis of ``psd``.
    fmin, fmax : float
        Inclusive band edges in Hz.

    Returns
    -------
    float
        Mean of ``psd`` over all leading axes and every frequency bin in the
        band, or ``nan`` when no frequency bin falls inside the band.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        # Averaging an empty selection would emit a RuntimeWarning and yield
        # nan anyway; return nan explicitly and quietly instead.
        return float("nan")
    # Select on the last axis so 1-D, 2-D (channels x freqs) and 3-D
    # (epochs x channels x freqs) spectra are all supported.
    return float(psd[..., freq_mask].mean())
def _raw_from_csv(file_path, default_sfreq, time_col):
    """Build an EEG ``RawArray`` from a CSV file (helper for ``load_eeg_data``)."""
    df = pd.read_csv(file_path)
    # Keep the time column (whatever its dtype) plus every numeric column;
    # other columns (labels, annotations, ...) cannot become EEG channels.
    keep = [
        col
        for col in df.columns
        if col == time_col or pd.api.types.is_numeric_dtype(df[col])
    ]
    df = df[keep]

    if time_col in df.columns:
        try:
            time = df[time_col].to_numpy(dtype=float)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Time column {time_col!r} must contain numeric time stamps."
            ) from err
        if len(time) < 2:
            raise ValueError("Not enough time points to estimate sampling frequency.")
        step = float(np.mean(np.diff(time)))
        # A zero, negative or NaN interval would give MNE an invalid sfreq.
        if not np.isfinite(step) or step <= 0:
            raise ValueError(
                f"Cannot estimate sampling frequency from time column {time_col!r}: "
                f"mean sample interval is {step!r}; values must increase over time."
            )
        sfreq = 1.0 / step
        data_df = df.drop(columns=[time_col])
    else:
        # No explicit time column: assume uniform sampling at default_sfreq.
        sfreq = float(default_sfreq)
        data_df = df

    if data_df.shape[1] == 0:
        raise ValueError("CSV file contains no numeric channel columns.")

    ch_names = [str(col) for col in data_df.columns]
    # MNE expects (n_channels, n_samples); the CSV rows are samples.
    data = data_df.to_numpy(dtype=float).T
    # NOTE(review): MNE treats EEG samples as volts; if the CSV stores
    # microvolts they should be scaled by 1e-6 — confirm with the data source.
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types="eeg")
    return mne.io.RawArray(data, info)


def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """Load an EEG recording into an MNE ``Raw`` object.

    The format is chosen from the (case-insensitive) file extension:

    * ``.fif`` / ``.fif.gz``: read with ``mne.io.read_raw_fif`` (preloaded).
    * ``.csv``: non-numeric columns other than ``time_col`` are dropped and
      every remaining column except ``time_col`` becomes an EEG channel.
      If ``time_col`` is present, the sampling frequency is estimated as
      ``1 / mean(diff(time))``; otherwise ``default_sfreq`` is used.

    Parameters
    ----------
    file_path : str or path-like
        Path to the recording.
    default_sfreq : float
        Sampling frequency in Hz for CSVs without a ``time_col`` column.
    time_col : str or None
        Name of the CSV time column.

    Returns
    -------
    mne.io.Raw
        The loaded recording.

    Raises
    ------
    ValueError
        For an unsupported extension, a CSV without numeric channel columns,
        a non-numeric or too-short time column, or a time column whose mean
        sample interval is not positive and finite.
    """
    path_lower = str(file_path).lower()
    if path_lower.endswith((".fif", ".fif.gz")):
        return mne.io.read_raw_fif(file_path, preload=True)
    if path_lower.endswith(".csv"):
        return _raw_from_csv(file_path, default_sfreq, time_col)
    raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")
def _welch_psd(raw, fmin, fmax):
    """Return ``(psd, freqs)`` for ``raw`` via Welch's method across MNE versions."""
    if hasattr(raw, "compute_psd"):
        # Spectrum-based API (MNE >= 1.2); the standalone psd_welch function
        # was deprecated and later removed from mne.time_frequency.
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def process_eeg(file, default_sfreq, time_col):
    """Summarise an uploaded EEG recording in plain language.

    The recording is loaded with ``load_eeg_data``, its power spectral
    density between 1 and 40 Hz is estimated with Welch's method, mean
    alpha (8-12 Hz) and beta (13-30 Hz) power are computed, and the Falcon
    model is asked for a user-friendly interpretation of those numbers.

    Parameters
    ----------
    file : str, path-like, or object with a ``.name`` path attribute
        The Gradio upload; depending on the Gradio version this may be a
        temp-file path or a tempfile wrapper exposing ``.name``.
    default_sfreq : str or float
        Sampling frequency (Hz) for CSVs without a time column; passed
        through ``float``.
    time_col : str
        Name of the CSV time column.

    Returns
    -------
    str
        The model's generated interpretation, without the prompt text.

    Raises
    ------
    ValueError
        If no file was uploaded, ``default_sfreq`` is not numeric, or
        ``load_eeg_data`` rejects the file.
    """
    if file is None:
        raise ValueError("Please upload an EEG file (FIF or CSV).")
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name
    raw = load_eeg_data(file_path, default_sfreq=float(default_sfreq), time_col=time_col)

    psd, freqs = _welch_psd(raw, fmin=1, fmax=40)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)
    # The unit-free ratio stays readable whatever the absolute PSD scale.
    ratio = alpha_power / beta_power if beta_power else float("nan")

    # Report only measured quantities. ``.3g`` keeps significant digits for
    # very small powers (e.g. V^2/Hz), which ``.3f`` would print as 0.000.
    data_summary = (
        f"Alpha power: {alpha_power:.3g}, Beta power: {beta_power:.3g}. "
        f"Alpha/beta power ratio: {ratio:.2f}."
    )
    prompt = (
        "You are a neuroscientist analyzing EEG features.\n"
        f"Data Summary: {data_summary}\n"
        "Provide a concise, user-friendly interpretation of these findings in simple terms.\n"
    )

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Pass ids and mask explicitly: some tokenizers also return
    # token_type_ids, which Falcon's generate() rejects.
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=200,  # max_length would count the prompt tokens too
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the prompt is not echoed.
    prompt_len = encoded["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
# Gradio front end: the three input components are forwarded, in order, as
# the (file, default_sfreq, time_col) arguments of process_eeg, and the
# returned string is shown as plain text.
iface = gr.Interface(
    fn=process_eeg,
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=(
        "Upload EEG data in FIF or CSV format. "
        "If CSV, either include a 'time' column or specify a default sampling frequency. "
        "Non-numeric columns will be removed (except the chosen time column)."
    ),
    inputs=[
        gr.File(label="Upload your EEG data (FIF or CSV)"),
        gr.Textbox(
            label="Default Sampling Frequency if no time column (Hz)",
            value="256",
        ),
        gr.Textbox(label="Time column name (if exists)", value="time"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # Start the web server only when this file is run directly.
    iface.launch()