Rogerjs's picture
Update app.py
56dc0d1 verified
raw
history blame
3.85 kB
import gradio as gr
import mne
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Load an open-source instruction-tuned LLM once at import time; it is used
# as-is, with no additional training.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): trust_remote_code executes Python code shipped with the Hub
    # repository. Recent transformers releases support Falcon natively, so this
    # flag may be removable -- confirm against the transformers version in use.
    trust_remote_code=True,
    # float16 kernels are missing or very slow on CPU in many torch builds, so
    # use float16 only when a CUDA GPU is present; otherwise bfloat16 keeps the
    # 7B model at roughly half the memory of float32 while running on CPU.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.bfloat16,
    device_map="auto",  # place weights on GPU(s) when available, else CPU (needs `accelerate`)
)
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean spectral power inside the closed band ``[fmin, fmax]``.

    Parameters
    ----------
    psd : array_like
        Power spectral density values whose **last** axis is frequency, e.g.
        ``(n_channels, n_freqs)`` or ``(n_epochs, n_channels, n_freqs)``.
    freqs : array_like
        1-D array of frequencies (Hz) matching the last axis of ``psd``.
    fmin, fmax : float
        Inclusive band edges in Hz.

    Returns
    -------
    float
        Mean over every leading index (channels, epochs, ...) and every
        frequency bin inside the band, or ``nan`` when no bin falls inside
        the band (for example a band above the Nyquist frequency).
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        # Explicit NaN instead of numpy's "Mean of empty slice" RuntimeWarning.
        return float("nan")
    # The Ellipsis selects along the last (frequency) axis regardless of how
    # many leading dimensions the PSD has.
    return float(psd[..., freq_mask].mean())
def load_eeg_data(file_path):
    """Load an EEG recording from a FIF or CSV file into an MNE Raw object.

    Supported formats:

    * ``.fif`` / ``.fif.gz`` -- read with ``mne.io.read_raw_fif`` (preloaded).
    * ``.csv`` -- must contain a ``time`` column (seconds); every other column
      is treated as one EEG channel. The sampling frequency is estimated as
      ``1 / mean(diff(time))``, so the time column is assumed to be uniformly
      spaced; it must be strictly increasing.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path to the recording on disk.

    Returns
    -------
    mne.io.Raw or mne.io.RawArray
        The loaded recording with data in memory.

    Raises
    ------
    ValueError
        If the extension is unsupported, or a CSV file has no ``time`` column,
        fewer than two samples, no channel columns, or a time column that is
        not strictly increasing.
    """
    lower_path = str(file_path).lower()
    if lower_path.endswith((".fif", ".fif.gz")):
        return mne.io.read_raw_fif(file_path, preload=True)
    if not lower_path.endswith(".csv"):
        raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")

    df = pd.read_csv(file_path)
    if 'time' not in df.columns:
        raise ValueError("CSV must contain a 'time' column for timestamps.")
    if len(df) < 2:
        raise ValueError("Not enough time points in CSV.")
    ch_names = [col for col in df.columns if col != 'time']
    if not ch_names:
        raise ValueError("CSV must contain at least one channel column besides 'time'.")

    time = df['time'].to_numpy(dtype=float)
    steps = np.diff(time)
    # A zero, negative or NaN step would yield an infinite, negative or NaN
    # sampling frequency, which MNE cannot use meaningfully.
    if not np.all(steps > 0):
        raise ValueError("CSV 'time' column must be strictly increasing.")
    sfreq = 1.0 / float(np.mean(steps))

    # MNE expects a float array shaped (n_channels, n_samples).
    data = df[ch_names].to_numpy(dtype=float).T
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
    return mne.io.RawArray(data, info)
def _welch_psd(raw, fmin, fmax):
    """Return ``(psd, freqs)`` from Welch's method, working across MNE versions.

    ``mne.time_frequency.psd_welch`` was deprecated in MNE 1.2 and later
    removed; newer versions provide the same computation via
    ``Raw.compute_psd``. The old function is used only as a fallback.
    """
    if hasattr(raw, "compute_psd"):
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def _generate_text(prompt, max_new_tokens=200):
    """Sample a continuation of ``prompt`` from the module-level LLM.

    Returns only the newly generated text (the prompt is not echoed back).
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Pass ids and mask explicitly: some tokenizers also emit token_type_ids,
    # which causal-LM generate() may reject as unused model kwargs.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,  # budget for the answer, independent of prompt length
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # The tokenizer may define no pad token; reuse EOS so generation does not warn.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def process_eeg(file):
    """Summarize an uploaded EEG recording in plain language.

    Loads the file, estimates the 1-40 Hz power spectral density with Welch's
    method, computes mean alpha (8-12 Hz) and beta (13-30 Hz) band power, and
    asks the language model to interpret those numbers.

    Parameters
    ----------
    file : str or file-like
        The value Gradio passes for the ``gr.File`` input: a temp-file wrapper
        exposing ``.name`` (Gradio 3) or a path string (Gradio 4+).

    Returns
    -------
    str
        The model's generated interpretation.

    Raises
    ------
    gradio.Error
        If no file was uploaded.
    ValueError
        Propagated from ``load_eeg_data`` for unsupported or malformed files.
    """
    if file is None:
        raise gr.Error("Please upload an EEG file (FIF or CSV).")
    file_path = getattr(file, "name", file)
    raw = load_eeg_data(file_path)

    psd, freqs = _welch_psd(raw, fmin=1, fmax=40)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)

    # Report only what was measured. MNE works in SI units (volts), so EEG PSD
    # values are typically ~1e-12; fixed-point formatting would print 0.000,
    # hence scientific notation.
    if beta_power > 0:  # also False for NaN (empty band)
        relation = f"Alpha/beta power ratio: {alpha_power / beta_power:.2f}."
    else:
        relation = "Alpha/beta power ratio: undefined (no beta power measured)."
    data_summary = (
        f"Alpha power (8-12 Hz): {alpha_power:.3e}, "
        f"Beta power (13-30 Hz): {beta_power:.3e}. "
        f"{relation}"
    )

    prompt = (
        "You are a neuroscientist analyzing EEG features.\n"
        f"Data Summary: {data_summary}\n"
        "Provide a concise, user-friendly interpretation of these findings in simple terms.\n"
    )
    return _generate_text(prompt)
# Web front end: one EEG file upload in, the generated plain-text summary out.
iface = gr.Interface(
    title="NeuroNarrative-Lite: EEG Summary",
    description=(
        "Upload EEG data in FIF (MNE native) or CSV format. The system extracts "
        "basic EEG features and generates a human-readable summary using an "
        "open-source language model."
    ),
    fn=process_eeg,
    inputs=gr.File(label="Upload your EEG data (FIF or CSV)"),
    outputs="text",
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    iface.launch()