Spaces:
Sleeping
Sleeping
File size: 3,993 Bytes
c48497c 4fb4636 56dc0d1 c48497c 56dc0d1 c48497c 4fb4636 49be262 4fb4636 56dc0d1 4fb4636 c48497c 49be262 56dc0d1 49be262 56dc0d1 49be262 56dc0d1 49be262 4fb4636 c48497c 4fb4636 c48497c 4fb4636 c48497c 4fb4636 c48497c 49be262 c48497c 49be262 c48497c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import gradio as gr
import mne
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Instruction-tuned Falcon checkpoint used to phrase the EEG summaries.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# The model is loaded once, at import time, in half precision; device_map="auto"
# delegates parameter placement to transformers.
# NOTE(review): float16 inference may be unsupported or very slow on CPU-only
# hosts — confirm the deployment hardware.
_model_load_kwargs = {
    "device_map": "auto",
    "torch_dtype": torch.float16,
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_kwargs)
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean spectral power inside the inclusive band [fmin, fmax].

    Parameters
    ----------
    psd : array-like
        Power spectral density whose LAST axis is frequency, e.g.
        ``(n_channels, n_freqs)``; a 1-D ``(n_freqs,)`` spectrum also works.
    freqs : array-like
        Frequency (Hz) of each bin along the last axis of ``psd``.
    fmin, fmax : float
        Inclusive band edges in Hz.

    Returns
    -------
    float
        Mean of every PSD value (all channels, all bins) inside the band, or
        ``nan`` when no frequency bin falls inside it.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        # Averaging an empty selection also yields nan, but with a
        # RuntimeWarning; make the "no bins in band" case explicit.
        return float("nan")
    # Ellipsis indexing selects along the frequency axis for any input rank.
    return float(psd[..., freq_mask].mean())
def _csv_to_raw(file_path, default_sfreq, time_col):
    """Build an ``mne.io.RawArray`` from a CSV file (helper for load_eeg_data)."""
    df = pd.read_csv(file_path)

    # Keep the time column (if present) plus every numeric column; anything
    # else (labels, string IDs, ...) cannot be used as a signal channel.
    keep = [
        col
        for col in df.columns
        if col == time_col or pd.api.types.is_numeric_dtype(df[col])
    ]
    df = df[keep]

    has_time = time_col in df.columns
    data_df = df.drop(columns=[time_col]) if has_time else df
    if data_df.shape[1] == 0:
        raise ValueError("No numeric channel columns found in the CSV file.")

    if has_time:
        try:
            time = pd.to_numeric(df[time_col]).to_numpy(dtype=float)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Time column {time_col!r} contains non-numeric values."
            ) from err
        if len(time) < 2:
            raise ValueError("Not enough time points to estimate sampling frequency.")
        step = float(np.mean(np.diff(time)))
        # A zero, negative or NaN step would produce an invalid sfreq that MNE
        # rejects with a much less helpful message.
        if not np.isfinite(step) or step <= 0:
            raise ValueError(
                f"Time column {time_col!r} must be increasing to estimate "
                "the sampling frequency."
            )
        sfreq = 1.0 / step
    else:
        # No explicit time column: assume uniform sampling at default_sfreq.
        sfreq = float(default_sfreq)
        if not np.isfinite(sfreq) or sfreq <= 0:
            raise ValueError(
                "Default sampling frequency must be a positive number of Hz, "
                f"got {default_sfreq!r}."
            )

    ch_names = [str(col) for col in data_df.columns]
    # NOTE(review): values are passed to MNE unscaled; MNE treats EEG data as
    # volts — confirm the unit convention of uploaded CSV files.
    data = data_df.to_numpy(dtype=float).T  # shape: (n_channels, n_samples)
    info = mne.create_info(
        ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * len(ch_names)
    )
    return mne.io.RawArray(data, info)


def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """Load an EEG recording from a FIF or CSV file as an MNE Raw object.

    - FIF (``.fif`` / ``.fif.gz``): read with ``mne.io.read_raw_fif`` (preloaded).
    - CSV: non-numeric columns other than ``time_col`` are discarded; all
      remaining columns become EEG channels.
        * If ``time_col`` is present, sfreq = 1 / mean(diff(time)).
        * Otherwise ``default_sfreq`` is used.

    Raises
    ------
    ValueError
        For an unsupported extension, a CSV without numeric channels, a
        non-numeric / too-short / non-increasing time column, or a
        non-positive ``default_sfreq``.
    """
    path = os.fspath(file_path)
    lower_path = path.lower()
    if lower_path.endswith((".fif", ".fif.gz")):
        return mne.io.read_raw_fif(path, preload=True)
    if lower_path.endswith(".csv"):
        return _csv_to_raw(path, default_sfreq, time_col)
    raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")
def _compute_psd(raw, fmin=1.0, fmax=40.0):
    """Return ``(psd, freqs)`` from Welch's method across MNE API versions."""
    if hasattr(raw, "compute_psd"):
        # Newer MNE: the functional psd_welch was deprecated in MNE 1.2 and
        # removed in later releases in favour of Raw.compute_psd().
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        return spectrum.get_data(return_freqs=True)
    # Older MNE releases only offer the functional API.
    return mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)


def _generate_text(prompt, max_new_tokens=200):
    """Sample a completion for *prompt*; return only the newly generated text."""
    encoded = tokenizer(prompt, return_tensors="pt")
    # Pass ids and mask explicitly: some tokenizers also emit token_type_ids,
    # which generate() may reject as unused model kwargs.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,  # budget excludes the prompt tokens
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # The tokenizer may define no pad token; reuse EOS for padding.
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; drop the echoed prompt.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def process_eeg(file, default_sfreq, time_col):
    """Gradio callback: load EEG, compute band powers, and ask the LLM to explain.

    Parameters
    ----------
    file : str | object with ``.name``
        Uploaded file; depending on the Gradio version this is a path string
        or a tempfile wrapper exposing the path as ``.name``.
    default_sfreq : str | float
        Sampling frequency (Hz) used for CSV files without a time column.
    time_col : str
        Name of the CSV time column, if any.

    Returns
    -------
    str
        The computed band-power summary followed by the model's interpretation.

    Raises
    ------
    gradio.Error
        If no file was uploaded or the sampling frequency is not a number.
    """
    if file is None:
        raise gr.Error("Please upload an EEG file (FIF or CSV).")
    file_path = getattr(file, "name", file)
    try:
        sfreq = float(default_sfreq)
    except (TypeError, ValueError) as err:
        raise gr.Error(f"Invalid sampling frequency: {default_sfreq!r}") from err

    raw = load_eeg_data(file_path, default_sfreq=sfreq, time_col=time_col)
    psd, freqs = _compute_psd(raw, fmin=1, fmax=40)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)
    alpha_beta_ratio = alpha_power / beta_power if beta_power else float("nan")

    # Report only what was measured. Scientific notation keeps small absolute
    # powers (e.g. V^2/Hz) from rendering as 0.000.
    data_summary = (
        f"Alpha power (8-12 Hz): {alpha_power:.3e}, "
        f"Beta power (13-30 Hz): {beta_power:.3e}, "
        f"alpha/beta ratio: {alpha_beta_ratio:.3f}."
    )
    prompt = (
        "You are a neuroscientist analyzing EEG features.\n"
        f"Data Summary: {data_summary}\n"
        "Provide a concise, user-friendly interpretation of these findings in simple terms.\n"
    )
    interpretation = _generate_text(prompt, max_new_tokens=200)
    return f"{data_summary}\n\n{interpretation}"
# Gradio front end. The three components below are forwarded positionally to
# process_eeg(file, default_sfreq, time_col).
_upload_input = gr.File(label="Upload your EEG data (FIF or CSV)")
_sfreq_input = gr.Textbox(
    label="Default Sampling Frequency if no time column (Hz)",
    value="256",
)
_time_col_input = gr.Textbox(label="Time column name (if exists)", value="time")

_app_description = (
    "Upload EEG data in FIF or CSV format. "
    "If CSV, either include a 'time' column or specify a default sampling frequency. "
    "Non-numeric columns will be removed (except the chosen time column)."
)

iface = gr.Interface(
    fn=process_eeg,
    inputs=[_upload_input, _sfreq_input, _time_col_input],
    outputs="text",
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=_app_description,
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()
|