# app.py — Hugging Face Space by Rogerjs
# Last commit: 49be262 "Update app.py" (verified); file size 3.99 kB.
import gradio as gr
import mne
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Hugging Face checkpoint used to turn the numeric EEG summary into prose.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 halves memory on a GPU, but many CPU kernels lack (or are very slow
# at) half precision. bfloat16 keeps the same memory footprint and is the
# dtype the Falcon model card itself uses, so use it when no CUDA device exists.
# NOTE(review): trust_remote_code=True executes code from the model repo;
# recent transformers releases support Falcon natively — confirm it is needed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
    device_map="auto",
)
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean PSD value inside the inclusive band [fmin, fmax].

    Parameters
    ----------
    psd : array-like
        Power spectral density whose LAST axis is frequency, e.g.
        ``(n_freqs,)``, ``(n_channels, n_freqs)`` or
        ``(n_epochs, n_channels, n_freqs)``.
    freqs : array-like, shape (n_freqs,)
        Frequency (Hz) of each bin along the last axis of ``psd``.
    fmin, fmax : float
        Inclusive band edges in Hz.

    Returns
    -------
    float
        Mean over every leading axis (channels, epochs, ...) and over the
        frequency bins that fall inside the band.

    Raises
    ------
    ValueError
        If no frequency bin lies inside the band; averaging an empty selection
        would otherwise yield a silent NaN.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(
            f"No frequency bins between {fmin} and {fmax} Hz; "
            "check the PSD frequency range and resolution."
        )
    # Index only the last (frequency) axis so PSDs of any rank are accepted.
    return float(psd[..., freq_mask].mean())
def _raw_from_csv(file_path, default_sfreq, time_col):
    """Build an ``mne.io.RawArray`` from a CSV with one column per channel.

    Every non-numeric column except ``time_col`` is dropped; every remaining
    column other than ``time_col`` becomes an EEG channel.

    Raises
    ------
    ValueError
        If the time column is non-numeric, has fewer than two samples or does
        not increase on average; if the sampling frequency is not a positive
        finite number; or if no numeric channel columns remain.
    """
    df = pd.read_csv(file_path)
    # Keep numeric columns plus the time column (validated separately below).
    keep = [
        col
        for col in df.columns
        if col == time_col or pd.api.types.is_numeric_dtype(df[col])
    ]
    df = df[keep]

    if time_col in df.columns:
        times = df[time_col]
        if not pd.api.types.is_numeric_dtype(times):
            raise ValueError(f"Time column '{time_col}' must contain numeric values.")
        if len(times) < 2:
            raise ValueError("Not enough time points to estimate sampling frequency.")
        # sfreq is the reciprocal of the mean sample spacing, so the time
        # values are effectively interpreted as seconds (giving sfreq in Hz).
        mean_dt = float(np.mean(np.diff(times.to_numpy())))
        if not mean_dt > 0:  # written this way so NaN is rejected too
            raise ValueError(
                f"Time column '{time_col}' must increase over time "
                f"(mean sample spacing is {mean_dt})."
            )
        sfreq = 1.0 / mean_dt
        data_df = df.drop(columns=[time_col])
    else:
        # No explicit time column: assume uniform sampling at default_sfreq.
        sfreq = default_sfreq
        data_df = df

    if not (np.isfinite(sfreq) and sfreq > 0):
        raise ValueError(
            f"Sampling frequency must be a positive finite number, got {sfreq}."
        )

    ch_names = list(data_df.columns)
    if not ch_names:
        raise ValueError("No numeric channel columns found in the CSV file.")
    # NOTE(review): values reach MNE unscaled, and MNE treats EEG data as
    # volts — confirm the CSV unit (microvolts are common).
    data = data_df.to_numpy(dtype=float).T  # shape: (n_channels, n_samples)
    info = mne.create_info(
        ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * len(ch_names)
    )
    return mne.io.RawArray(data, info)


def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """Load EEG data from a FIF or CSV file as an MNE Raw object.

    - ``.fif`` / ``.fif.gz``: read with ``mne.io.read_raw_fif`` (preloaded).
    - ``.csv``: non-numeric columns other than ``time_col`` are dropped.
      If ``time_col`` is present it is used to estimate the sampling
      frequency; otherwise ``default_sfreq`` is assumed and every remaining
      column is treated as an EEG channel.

    Raises
    ------
    ValueError
        For unsupported extensions, or for CSV content that cannot yield a
        valid recording (see ``_raw_from_csv``).
    """
    lower_path = str(file_path).lower()
    if lower_path.endswith((".fif", ".fif.gz")):
        return mne.io.read_raw_fif(file_path, preload=True)
    if lower_path.endswith(".csv"):
        return _raw_from_csv(file_path, default_sfreq, time_col)
    raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")
def _compute_welch_psd(raw, fmin, fmax):
    """Return ``(psd, freqs)`` for ``raw`` computed with Welch's method.

    ``mne.time_frequency.psd_welch`` is gone from current MNE releases
    (replaced by ``Raw.compute_psd``); use the new API when available and
    fall back to the old function on older installations.
    """
    if hasattr(raw, "compute_psd"):
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def _generate_interpretation(prompt, max_new_tokens=200):
    """Sample a continuation of ``prompt`` from the module-level LLM.

    Returns only the newly generated text, without the echoed prompt.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids,
        # Single unpadded sequence: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        # Budget counts generated tokens only; max_length would include the prompt.
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # Avoids the "no pad token" warning/fallback during generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def process_eeg(file, default_sfreq, time_col):
    """Load an uploaded EEG file, summarise alpha/beta power, and narrate it.

    Parameters
    ----------
    file : str | os.PathLike | file-like with ``.name`` | None
        The value delivered by the Gradio ``File`` component.
    default_sfreq : str | float
        Sampling frequency (Hz) assumed for CSVs without a time column.
    time_col : str
        Name of the CSV time column, if any.

    Returns
    -------
    str
        The language model's interpretation of the measured band powers.
    """
    if file is None:
        raise gr.Error("Please upload an EEG file (FIF or CSV).")
    # Depending on the Gradio version/config, gr.File delivers either a path
    # or a tempfile-like object exposing the path as ``.name``.
    file_path = file if isinstance(file, (str, os.PathLike)) else file.name
    raw = load_eeg_data(file_path, default_sfreq=float(default_sfreq), time_col=time_col)

    psd, freqs = _compute_welch_psd(raw, fmin=1, fmax=40)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)
    ratio = alpha_power / beta_power if beta_power > 0 else float("inf")
    comparison = "higher" if alpha_power > beta_power else "not higher"

    # Describe only what was measured, never fixed findings. ".3g" keeps very
    # small PSD values (e.g. data in volts) from printing as 0.000.
    data_summary = (
        f"Mean alpha (8-12 Hz) power: {alpha_power:.3g}; "
        f"mean beta (13-30 Hz) power: {beta_power:.3g}; "
        f"alpha/beta power ratio: {ratio:.2f}. "
        f"Alpha power is {comparison} than beta power."
    )
    prompt = (
        "You are a neuroscientist analyzing EEG features.\n"
        f"Data Summary: {data_summary}\n"
        "Provide a concise, user-friendly interpretation of these findings "
        "in simple terms. Base your interpretation only on these numbers.\n"
    )
    return _generate_interpretation(prompt)
# Gradio front end: an EEG upload plus two CSV-parsing options, text output.
iface = gr.Interface(
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=(
        "Upload EEG data in FIF or CSV format. If CSV, either include a "
        "'time' column or specify a default sampling frequency. Non-numeric "
        "columns will be removed (except the chosen time column)."
    ),
    fn=process_eeg,
    # Order must match process_eeg(file, default_sfreq, time_col).
    inputs=[
        gr.File(label="Upload your EEG data (FIF or CSV)"),
        gr.Textbox(value="256", label="Default Sampling Frequency if no time column (Hz)"),
        gr.Textbox(value="time", label="Time column name (if exists)"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()