Function bodies 170 total
get_reference_type_id function · python · L62-L65 (4 LOC)data/eeg_pretrain_dataset_safe.py
def get_reference_type_id(dataset: str) -> int:
    """Return the integer reference-type ID registered for ``dataset``.

    The dataset name is mapped to a reference name through
    ``DATASET_REFERENCE_TYPES`` (missing datasets map to ``'unknown'``), and
    that name is translated to an ID through ``EEGConfig.REFERENCE_TYPES``,
    falling back to the ID stored under ``'unknown'``.
    """
    reference = DATASET_REFERENCE_TYPES.get(dataset, 'unknown')
    type_ids = EEGConfig.REFERENCE_TYPES
    # Resolved up front, mirroring the eager evaluation of a ``dict.get`` default.
    unknown_id = type_ids['unknown']
    return type_ids.get(reference, unknown_id)


# [listing marker] get_dataset_channel_mask function · python · L68-L88 (21 LOC)data/eeg_pretrain_dataset_safe.py
def get_dataset_channel_mask(dataset: str, max_channels: int = 19) -> np.ndarray:
    """Build a presence mask over the standard channel slots for ``dataset``.

    Args:
        dataset: Dataset name. Datasets missing from ``DATASET_CHANNELS`` are
            treated as having every channel in ``EEGConfig.STANDARD_CHANNELS``.
        max_channels: Length of the returned mask (default 19).

    Returns:
        ``float32`` array of length ``max_channels`` holding 1.0 at the
        standard-channel position of every channel the dataset provides and
        0.0 elsewhere.
    """
    standard = EEGConfig.STANDARD_CHANNELS
    present = DATASET_CHANNELS.get(dataset, standard)

    mask = np.zeros(max_channels, dtype=np.float32)
    for name in present:
        try:
            slot = standard.index(name)
        except ValueError:
            # Channel is not part of the standard layout; it has no slot.
            continue
        # Standard channels beyond the requested mask length are dropped.
        if slot < max_channels:
            mask[slot] = 1.0
    return mask


# [listing marker] get_channel_indices function · python · L91-L106 (16 LOC)data/eeg_pretrain_dataset_safe.py
def get_channel_indices(dataset: str) -> List[int]:
    """List the standard-layout positions of the channels a dataset provides.

    Args:
        dataset: Dataset name. Datasets missing from ``DATASET_CHANNELS`` are
            treated as having every channel in ``EEGConfig.STANDARD_CHANNELS``.

    Returns:
        Positions within ``EEGConfig.STANDARD_CHANNELS``, in the order the
        dataset lists its channels. Channels outside the standard layout are
        left out.
    """
    standard = EEGConfig.STANDARD_CHANNELS
    present = DATASET_CHANNELS.get(dataset, standard)
    return [standard.index(name) for name in present if name in standard]


# [listing marker] get_condition_description function · python · L193-L209 (17 LOC)data/eeg_pretrain_dataset_safe.py
def get_condition_description(dataset: str, recording_state: str) -> str:
    """Return a natural-language description of a recording condition.

    Lookup order:
        1. ``CONDITION_DESCRIPTIONS[(dataset, recording_state)]``
        2. ``CONDITION_DESCRIPTIONS[(dataset, 'unknown')]``
        3. ``DEFAULT_CONDITION_DESCRIPTIONS[dataset]``
        4. the literal ``"EEG recording"``
    """
    candidate_keys = ((dataset, recording_state), (dataset, 'unknown'))
    for candidate in candidate_keys:
        if candidate in CONDITION_DESCRIPTIONS:
            return CONDITION_DESCRIPTIONS[candidate]

    # No per-state entry: use the dataset-wide default, then the generic text.
    if dataset not in DEFAULT_CONDITION_DESCRIPTIONS:
        return "EEG recording"
    return DEFAULT_CONDITION_DESCRIPTIONS[dataset]


# [listing marker] EEGPretrainDatasetSafe class · python · L212-L533 (322 LOC)data/eeg_pretrain_dataset_safe.py
class EEGPretrainDatasetSafe(Dataset):
"""Dataset for EEG pretraining with safe H5 file handling for multi-GPU
This version opens H5 files on demand in each worker process to avoid
file handle corruption issues with multiprocessing.
"""
def __init__(
self,
data_dir: Path,
split: str = 'train',
datasets: List[str] = ['CAUEEG', 'PEARL', 'TD_BRAIN'],
max_length: int = 24000, # 2 minutes @ 200Hz
segment_length: int = 24000, # Length of each preprocessed segment
include_eyes_state: bool = True,
splits_file: Optional[Path] = None,
patch_size: int = 32,
max_channels: int = 19,
):
"""
Args:
data_dir: Root directory containing processed datasets
split: 'train', 'val', or 'test'
datasets: List of datasets to include
max_length: Maximum sequence length in samples (can exceed segment_length
to con__init__ method · python · L219-L291 (73 LOC)data/eeg_pretrain_dataset_safe.py
def __init__(
self,
data_dir: Path,
split: str = 'train',
datasets: List[str] = ['CAUEEG', 'PEARL', 'TD_BRAIN'],
max_length: int = 24000, # 2 minutes @ 200Hz
segment_length: int = 24000, # Length of each preprocessed segment
include_eyes_state: bool = True,
splits_file: Optional[Path] = None,
patch_size: int = 32,
max_channels: int = 19,
):
"""
Args:
data_dir: Root directory containing processed datasets
split: 'train', 'val', or 'test'
datasets: List of datasets to include
max_length: Maximum sequence length in samples (can exceed segment_length
to concatenate consecutive segments for longer context)
segment_length: Length of each preprocessed segment in the H5 files
include_eyes_state: Whether to include eyes open/closed info
splits_file: Optional path to custom splits JSON__getitem__ method · python · L296-L439 (144 LOC)data/eeg_pretrain_dataset_safe.py
def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
file_path, start_seg_idx, n_avail, dataset = self.segments_info[idx]
# Open file fresh each time - safer for multiprocessing
with h5py.File(file_path, 'r') as f:
# Load and concatenate consecutive segments
segments = []
for i in range(n_avail):
seg = f['eeg/segments'][start_seg_idx + i][:]
segments.append(seg)
eeg_data = np.concatenate(segments, axis=-1) # (C, n_avail * segment_length)
# Load channel info from H5 (v6 has channel_coords; v5 does not)
has_v6_coords = 'eeg/channel_coords' in f
if has_v6_coords:
file_channel_coords = f['eeg/channel_coords'][:] # (n_ch, 2)
else:
file_channel_coords = None
file_channel_names = [
x.decode() if isinstance(x, bytes) else x
for x in f['eeg/channel_names'][:]
Repobility's GitHub App fixes findings like these · https://github.com/apps/repobility-bot
_extract_recording_state method · python · L441-L533 (93 LOC)data/eeg_pretrain_dataset_safe.py
def _extract_recording_state(self, f: h5py.File, dataset: str) -> str:
"""Extract detailed recording state from HDF5 file based on dataset
Returns detailed states like:
- 'resting eyes closed'
- 'resting eyes open'
- 'MSIT task'
- 'Sternberg task'
- 'unknown'
"""
# Default state
state = 'unknown'
try:
if dataset == 'AD_EEG':
# AD_EEG has eyes_closed attribute
if 'eyes_closed' in f.attrs:
state = 'resting eyes closed' if f.attrs['eyes_closed'] else 'resting eyes open'
elif dataset == 'CAUEEG':
# CAUEEG clinical EEG — check attrs, default to clinical recording
if 'recording_type' in f.attrs:
state = f.attrs['recording_type']
elif 'eyes_state' in f.attrs:
state = f'resting eyes {f.attrs["eyes_state"collate_batch_eeg function · python · L536-L584 (49 LOC)data/eeg_pretrain_dataset_safe.py
def collate_batch_eeg(batch: List[Dict]) -> Dict[str, torch.Tensor]:
"""Collate function that stacks EEG data properly"""
eeg_data = torch.stack([item['eeg'] for item in batch])
# Convert lists to tensors/lists as appropriate
eyes_states = [item['eyes_state'] for item in batch]
recording_states = [item['recording_state'] for item in batch]
condition_descriptions = [item['condition_description'] for item in batch]
datasets = [item['dataset'] for item in batch]
file_paths = [item['file_path'] for item in batch]
# Reference type IDs as tensor for 2D spatial RoPE
reference_type_ids = torch.tensor(
[item['reference_type_id'] for item in batch],
dtype=torch.long
)
reference_names = [item['reference_name'] for item in batch]
# Channel masks for variable channel support
channel_masks = torch.stack([item['channel_mask'] for item in batch])
channel_indices_list = [item['channel_indices'] for item in batch]
# Chanworker_init_fn function · python · L587-L601 (15 LOC)data/eeg_pretrain_dataset_safe.py
def worker_init_fn(worker_id):
    """Seed the Python and NumPy RNGs and set an h5py global option.

    Args:
        worker_id: Worker index (presumably supplied by the DataLoader);
            it is not used inside this function.
    """
    import random
    import numpy as np
    import torch

    # Derive one seed from torch's current seed, reduced modulo 2**32, and
    # share it between NumPy and the stdlib ``random`` module.
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)

    import h5py
    # NOTE(review): the previous comment described this as selecting the
    # latest file format. ``track_order`` looks like h5py's creation-order
    # tracking switch, not a file-format setting — confirm against h5py docs.
    h5py_config = h5py.get_config()
    h5py_config.track_order = False


# [listing marker] create_safe_pretrain_dataloader function · python · L604-L653 (50 LOC)data/eeg_pretrain_dataset_safe.py
def create_safe_pretrain_dataloader(
data_dir: Path,
split: str = 'train',
batch_size: int = 8,
num_workers: int = 4,
shuffle: bool = True,
datasets: List[str] = ['CAUEEG', 'PEARL', 'TD_BRAIN'],
pin_memory: bool = True,
prefetch_factor: int = 2,
persistent_workers: bool = True,
max_length: int = 24000,
segment_length: int = 24000,
include_eyes_state: bool = True,
splits_file: Optional[Path] = None,
patch_size: int = 32,
max_channels: int = 19,
) -> DataLoader:
"""Create a safe dataloader for pretraining that handles multiprocessing correctly"""
# Use the safe dataset
dataset = EEGPretrainDatasetSafe(
data_dir=data_dir,
split=split,
datasets=datasets,
max_length=max_length,
segment_length=segment_length,
include_eyes_state=include_eyes_state,
splits_file=splits_file,
patch_size=patch_size,
max_channels=max_channels,
)
# For distribload_model function · python · L48-L77 (30 LOC)eval/evaluate_full_recording.py
def load_model(checkpoint_path: str, device: str = 'cuda'):
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
config = ckpt.get('config', {})
max_channels = config.get('max_channels', config.get('n_channels', 19))
model = EEGEncoderV3(
n_channels=config.get('n_channels', 19),
hidden_dim=config.get('hidden_dim', 1024),
n_heads=config.get('n_heads', 16),
n_kv_heads=config.get('n_kv_heads', 4),
n_layers=config.get('n_layers', 16),
patch_size=config.get('patch_size', 32),
dropout=0.0,
ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
rope_base=config.get('rope_base', 500_000.0),
n_registers=config.get('n_registers', 8),
use_spatial_rope=config.get('use_spatial_rope', True),
use_modality_tokens=config.get('use_modality_tokens', True),
use_condition_encoding=True,
condition_cache_path='data/condition_embeddings.pt',
max_channeload_full_recording function · python · L84-L135 (52 LOC)eval/evaluate_full_recording.py
def load_full_recording(h5_path: str, max_channels: int = 19):
"""Load all segments from H5 and stitch into continuous recording."""
with h5py.File(h5_path, 'r') as f:
segments = f['eeg/segments'][:]
n_seg = segments.shape[0]
if n_seg == 1:
eeg = segments[0]
else:
parts = [segments[i, :, :STRIDE] for i in range(n_seg - 1)]
parts.append(segments[-1])
eeg = np.concatenate(parts, axis=1)
file_coords = f['eeg/channel_coords'][:] if 'eeg/channel_coords' in f else None
ch_names = [x.decode() if isinstance(x, bytes) else x for x in f['eeg/channel_names'][:]]
file_mask = f['channel_mask'][:] if 'channel_mask' in f else None
n_native = eeg.shape[0]
if n_native < max_channels:
eeg = np.pad(eeg, ((0, max_channels - n_native), (0, 0)))
elif n_native > max_channels:
eeg = eeg[:max_channels]
n_native = max_channels
# Channel mask
if max_channeextract_recording_state function · python · L138-L156 (19 LOC)eval/evaluate_full_recording.py
def extract_recording_state(filepath, dataset):
    """Infer the recording state for a file from its dataset and file name.

    Datasets with a single fixed state are answered from a table. DORTMUND,
    PEARL and TD_BRAIN are decided by substrings of the file stem. Anything
    else yields ``'unknown'``.
    """
    stem = Path(filepath).stem

    fixed_states = {
        'AD_EEG': 'resting eyes closed',
        'DS004504': 'resting eyes closed',
        'PARKINSON_UCSD': 'resting eyes closed',
        'READTBI': 'clinical assessment',
        'CAUEEG': 'resting eyes closed',
        'LEMON': 'resting eyes closed',
        'DEPRESSION': 'resting',
        'SRM': 'resting eyes closed',
    }
    if dataset in fixed_states:
        return fixed_states[dataset]

    if dataset == 'DORTMUND':
        if 'EyesClosed' in stem:
            return 'resting eyes closed'
        return 'resting eyes open'

    if dataset == 'PEARL':
        # First matching marker wins; no match falls through to 'unknown'.
        pearl_markers = (
            ('rest', 'resting'),
            ('msit', 'MSIT task'),
            ('sternberg', 'Sternberg task'),
        )
        for marker, state in pearl_markers:
            if marker in stem:
                return state

    if dataset == 'TD_BRAIN':
        # Only the eyes-open marker is checked; everything else is eyes closed.
        if 'restEO' in stem:
            return 'resting eyes open'
        return 'resting eyes closed'

    return 'unknown'


# [listing marker] embed_full_recording function · python · L160-L191 (32 LOC)eval/evaluate_full_recording.py
def embed_full_recording(model, h5_path, dataset, device, chunk_patches=1500):
"""Embed entire recording → one pooled vector."""
max_ch = model.max_channels
eeg, file_mask, coords = load_full_recording(h5_path, max_channels=max_ch)
if eeg.shape[1] < 32:
return None
if file_mask is not None:
channel_mask = torch.from_numpy(file_mask).unsqueeze(0).to(device)
else:
channel_mask = torch.from_numpy(
get_dataset_channel_mask(dataset, max_channels=max_ch)
).unsqueeze(0).to(device)
channel_coords = torch.from_numpy(coords).unsqueeze(0).to(device)
ref_ids = torch.tensor([get_reference_type_id(dataset)], dtype=torch.long, device=device)
rec_state = extract_recording_state(h5_path, dataset)
condition_keys = [(dataset, rec_state)]
chunk_samples = chunk_patches * 32
all_hidden = []
for start in range(0, eeg.shape[1], chunk_samples):
chunk = eeg[:, start:start + chunk_samples].unsqueeze(0).to(dRepobility — the code-quality scanner for AI-generated software · https://repobility.com
run_classification function · python · L196-L231 (36 LOC)eval/evaluate_full_recording.py
def run_classification(train_emb, train_lab, test_emb, test_lab):
results = {}
# XGBoost
clf = XGBClassifier(n_estimators=100, max_depth=4, use_label_encoder=False,
eval_metric='logloss', verbosity=0)
clf.fit(train_emb, train_lab)
probs = clf.predict_proba(test_emb)[:, 1]
auc = float(roc_auc_score(test_lab, probs))
results['xgboost'] = {
'auc': auc,
'y_true': [int(y) for y in test_lab],
'y_prob': [float(p) for p in probs],
}
# TabPFN
if HAS_TABPFN:
n_comp = min(128, train_emb.shape[1], train_emb.shape[0])
pca = PCA(n_components=n_comp)
tr_pca = pca.fit_transform(train_emb)
te_pca = pca.transform(test_emb)
try:
clf = TabPFNClassifier(device='cpu', N_ensemble_configurations=16,
ignore_pretraining_limits=True)
except TypeError:
try:
clf = TabPFNClassifier(device='cpu', ignore_pretrrun_regression function · python · L234-L245 (12 LOC)eval/evaluate_full_recording.py
def run_regression(train_emb, train_lab, test_emb, test_lab):
    """Fit an XGBoost regressor on the train split and score it on the test split.

    Returns:
        ``{'xgboost': {...}}`` with the Pearson correlation between targets
        and predictions (``pearson_r``), its ``p_value``, and the raw
        ``y_true`` / ``y_pred`` values as plain floats. When either the
        targets or the predictions are (near-)constant, ``pearson_r`` is
        reported as 0.0 and ``p_value`` as 1.0 without calling ``pearsonr``.
    """
    regressor = XGBRegressor(n_estimators=100, max_depth=4, verbosity=0)
    regressor.fit(train_emb, train_lab)
    predictions = regressor.predict(test_emb)

    constant_input = np.std(test_lab) < 1e-8 or np.std(predictions) < 1e-8
    if constant_input:
        corr, p_val = 0.0, 1.0
    else:
        raw_corr, raw_p = pearsonr(test_lab, predictions)
        corr, p_val = float(raw_corr), float(raw_p)

    report = {
        'pearson_r': corr,
        'p_value': p_val,
        'y_true': [float(target) for target in test_lab],
        'y_pred': [float(value) for value in predictions],
    }
    return {'xgboost': report}


# [listing marker] resolve_h5_path function · python · L250-L273 (24 LOC)eval/evaluate_full_recording.py
def resolve_h5_path(
    path_str: str,
    dataset: Optional[str] = None,
    v6_base: str = '/home/dlinsley/eeg_pretrain_v2/data/processed_v6',
) -> Optional[str]:
    """Resolve a split file path to an actual H5 file on disk.

    Candidates are tried in order and the first one that exists is returned:
        1. ``path_str`` as given
        2. ``remap_to_v5(path_str)``
        3. ``<v6_base>/<dataset>/<file name of path_str>``
        4. ``<v6_base>/<dataset>/<stem of path_str>_eeg.h5``

    Candidates 3 and 4 are only tried when ``dataset`` is truthy.

    Args:
        path_str: Path as written in the split file.
        dataset: Dataset sub-directory to search under ``v6_base``.
        v6_base: Root of the v6 processed data. Defaults to the location that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        The resolved path as a string, or ``None`` when no candidate exists.
    """
    direct = Path(path_str)
    if direct.exists():
        return str(direct)

    v5 = remap_to_v5(path_str)
    if Path(v5).exists():
        return v5

    if dataset:
        dataset_dir = Path(v6_base) / dataset
        v6_candidates = (
            dataset_dir / direct.name,
            # Some v6 files carry an extra ``_eeg`` suffix before the extension.
            dataset_dir / (direct.stem + '_eeg.h5'),
        )
        for candidate in v6_candidates:
            if candidate.exists():
                return str(candidate)
    return None


# [listing marker] load_classification_v3_task function · python · L276-L314 (39 LOC)eval/evaluate_full_recording.py
def load_classification_v3_task(task: str):
"""Load classification_v3 or tdbrain task → list of (h5_path, label, dataset, patient_id)."""
# Try classification_v3 first, then tdbrain
for splits_dir in ['data/splits/classification_v3', 'data/splits/tdbrain_downstream']:
sp = Path(f'{splits_dir}/{task}/splits.json')
if sp.exists():
break
else:
return None
with open(sp) as f:
data = json.load(f)
datasets = data.get('datasets', ['unknown'])
entries = {}
for split in ['train', 'test']:
items = []
split_data = data.get(split, {})
if isinstance(split_data, dict):
for label_name, files in split_data.items():
label = 1 if label_name == 'positive' else 0
for fpath in files:
# Fix tdbrain paths
if 'processed_v4' in fpath:
fpath = fix_tdbrain_path(fpath)
ds = None
load_regression_task function · python · L317-L345 (29 LOC)eval/evaluate_full_recording.py
def load_regression_task(task: str):
"""Load regression_v2 task → list of (h5_path, label, dataset, patient_id)."""
sp = Path(f'data/splits/regression_v2/{task}/splits.json')
if not sp.exists():
return None
with open(sp) as f:
data = json.load(f)
target_keys = {'age': 'age', 'mmse': 'mmse', 'bdi': 'bdi', 'rpm': 'rpm',
'naart': 'naart', 'disease_duration': 'disease_duration'}
target_key = target_keys.get(task, task)
entries = {}
for split in ['train', 'test']:
items = []
for entry in data.get(split, []):
ds = entry.get('dataset', 'unknown')
if ds in EXCLUDED_DATASETS:
continue
val = entry.get(target_key)
if val is None:
continue
fpath = f"data/processed_v3_minimal/{ds}/{entry['file']}"
resolved = resolve_h5_path(fpath, ds)
if resolved:
pid = entry.get('patient_id', Path(entload_new_task function · python · L348-L370 (23 LOC)eval/evaluate_full_recording.py
def load_new_task(splits_dir: str, task: str):
    """Load a new-format task (dortmund/srm/depression) from its splits file.

    Reads ``data/splits/<splits_dir>/<task>/splits.json`` and, for the
    ``train`` and ``test`` splits, keeps only the entries whose H5 file
    exists under the v6 processed-data root.

    Returns:
        ``None`` when the splits file is missing. Otherwise a dict mapping
        ``'train'`` and ``'test'`` to lists of
        ``(h5_path, label, dataset, patient_id)`` tuples. ``dataset`` is the
        first path component of the entry's ``file`` field and ``patient_id``
        is that file's stem.
    """
    splits_path = Path(f'data/splits/{splits_dir}/{task}/splits.json')
    if not splits_path.exists():
        return None
    with open(splits_path) as handle:
        payload = json.load(handle)

    root = Path('/home/dlinsley/eeg_pretrain_v2/data/processed_v6')
    result = {}
    for split_name in ('train', 'test'):
        rows = []
        for record in payload.get(split_name, []):
            rel_file = record['file']
            target = record['label']
            dataset_name = rel_file.split('/')[0]
            absolute = root / rel_file
            # Entries whose recording is not on disk are silently skipped.
            if absolute.exists():
                rows.append((str(absolute), target, dataset_name, Path(rel_file).stem))
        result[split_name] = rows
    return result


# [listing marker] embed_and_aggregate function · python · L375-L399 (25 LOC)eval/evaluate_full_recording.py
def embed_and_aggregate(model, items, device):
    """Embed every recording and average the embeddings per patient.

    Args:
        model: Encoder forwarded to ``embed_full_recording``.
        items: Iterable of ``(h5_path, label, dataset, patient_id)`` tuples.
        device: Device forwarded to ``embed_full_recording``.

    Returns:
        ``(embeddings, labels)`` arrays with one row per patient, or
        ``(None, None)`` when no recording produced an embedding. A patient's
        label is the one attached to the first of its recordings that
        embedded successfully. Patients whose mean embedding contains NaN
        are dropped, with a warning printed.
    """
    # patient_id -> (label, [embedding, ...]); insertion order fixes row order.
    per_patient = {}
    for h5_path, label, dataset, pid in tqdm(items, leave=False):
        vector = embed_full_recording(model, h5_path, dataset, device)
        if vector is None:
            continue
        per_patient.setdefault(pid, (label, []))[1].append(vector)

    if not per_patient:
        return None, None

    emb_arr = np.array([np.mean(vectors, axis=0) for _, vectors in per_patient.values()])
    lab_arr = np.array([label for label, _ in per_patient.values()])

    # Drop any patients with NaN embeddings
    keep = ~np.isnan(emb_arr).any(axis=1)
    if not keep.all():
        n_bad = (~keep).sum()
        print(f"    WARNING: dropping {n_bad} patients with NaN embeddings")
        emb_arr, lab_arr = emb_arr[keep], lab_arr[keep]
    return emb_arr, lab_arr


# [listing marker] evaluate_task function · python · L402-L428 (27 LOC)eval/evaluate_full_recording.py
def evaluate_task(model, task_name, entries, device, is_regression=False):
"""Full pipeline: embed → aggregate → classify/regress."""
print(f" [{task_name}]")
results = {}
for split in ['train', 'test']:
if split not in entries or not entries[split]:
return {}
emb, lab = embed_and_aggregate(model, entries[split], device)
if emb is None:
return {}
results[split] = (emb, lab)
n_pos = int((lab == 1).sum()) if not is_regression else len(lab)
print(f" {split}: {len(lab)} patients" +
(f" ({n_pos}+/{len(lab)-n_pos}-)" if not is_regression else f" (range [{lab.min():.1f}, {lab.max():.1f}])"))
train_emb, train_lab = results['train']
test_emb, test_lab = results['test']
if is_regression:
res = run_regression(train_emb, train_lab, test_emb, test_lab)
print(f" → r={res['xgboost']['pearson_r']:.3f}")
else:
res = run_classification(train_emb, train_laSource: Repobility analyzer · https://repobility.com
main function · python · L463-L507 (45 LOC)eval/evaluate_full_recording.py
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', required=True)
parser.add_argument('--device', default='cuda')
parser.add_argument('--output', default='eval/full_recording_results.json')
args = parser.parse_args()
model, config = load_model(args.checkpoint, args.device)
all_results = {}
# 1. Classification_v3 + tdbrain tasks
print("\n=== Classification Tasks ===")
for task in CLASSIFICATION_TASKS:
entries = load_classification_v3_task(task)
if entries is None:
print(f" [{task}] SKIPPED (no splits)")
continue
res = evaluate_task(model, task, entries, args.device, is_regression=False)
all_results[task] = res
# 2. Regression tasks
print("\n=== Regression Tasks ===")
for task in REGRESSION_TASKS:
entries = load_regression_task(task)
if entries is None:
print(f" [{task}] SKIPPED (no splits)")
continue
extract_patient_id function · python · L55-L79 (25 LOC)eval/evaluate_v3_linear_probe.py
def extract_patient_id(filepath: str) -> str:
    """Derive the patient-level ID used for aggregation from a file path.

    The file stem is matched against a few dataset-specific prefixes, and
    the first pattern that matches supplies the ID. Otherwise the whole stem
    is the ID.

    Examples:
        .../ASZED/patient_108_s1_phase1.h5           → patient_108
        .../CAUEEG/patient_00018.h5                  → patient_00018
        .../READTBI/110800044AR1_..._201504151548.h5 → 110800044AR1
        .../AD_EEG/patient_001.h5                    → patient_001
        .../DS004504/patient_001.h5                  → patient_001
    """
    stem = Path(filepath).stem  # file name without the .h5 extension

    prefix_patterns = (
        r'(patient_\d+)_s\d+',    # ASZED: patient_NNN_sN_phaseN
        r'(\d{9}AR\d)_',          # READTBI: SITE_SUBJECT_DATETIME
        r'(patient_\d+)_rest',    # TD_BRAIN: patient_NNNNNNN_restEC/restEO
    )
    for pattern in prefix_patterns:
        hit = re.match(pattern, stem)
        if hit:
            return hit.group(1)

    # No session/state suffix recognised (e.g. CAUEEG, AD_EEG): use the stem.
    return stem


# [listing marker] extract_dataset_from_path function · python · L82-L95 (14 LOC)eval/evaluate_v3_linear_probe.py
def extract_dataset_from_path(filepath: str) -> str:
    """Return the dataset name encoded in a processed-data file path.

    The dataset is the path component directly after a
    ``processed_v3_minimal`` / ``processed_v4`` / ``processed_v5`` directory,
    for example::

        data/processed_v3_minimal/CAUEEG/patient_00018.h5  → CAUEEG
        data/processed_v3_minimal/READTBI/110800044AR1_...h5 → READTBI

    When no such directory is followed by another component, the name of the
    file's parent directory is returned.
    """
    data_roots = ('processed_v3_minimal', 'processed_v4', 'processed_v5')
    components = Path(filepath).parts
    # Pair every component with its successor; a root in last position has
    # no successor and is therefore never paired.
    for component, successor in zip(components, components[1:]):
        if component in data_roots:
            return successor
    return Path(filepath).parent.name


# [listing marker] extract_recording_state function · python · L98-L125 (28 LOC)eval/evaluate_v3_linear_probe.py
def extract_recording_state(filepath: str, dataset: str) -> str:
    """Extract recording state from filename, matching training dataset logic.

    Datasets with a constant state are answered from a table. PEARL and
    TD_BRAIN are decided by the first marker substring found in the file
    stem. Everything else, including PEARL/TD_BRAIN files without a known
    marker, is ``'unknown'``.
    """
    stem = Path(filepath).stem

    constant_states = {
        'AD_EEG': 'unknown',  # Would need H5 attrs; use default
        'CAUEEG': 'unknown',
        'READTBI': 'clinical assessment',
        'DS004504': 'resting eyes closed',
        'PARKINSON_UCSD': 'resting eyes closed',
    }
    if dataset in constant_states:
        return constant_states[dataset]

    # (substring of the file stem, resulting state), checked in order.
    stem_markers = {
        'PEARL': (
            ('rest', 'resting'),
            ('msit', 'MSIT task'),
            ('sternberg', 'Sternberg task'),
        ),
        'TD_BRAIN': (
            ('restEC', 'resting eyes closed'),
            ('restEO', 'resting eyes open'),
        ),
    }
    for marker, state in stem_markers.get(dataset, ()):
        if marker in stem:
            return state
    return 'unknown'


# [listing marker] ClassificationDataset class · python · L130-L237 (108 LOC)eval/evaluate_v3_linear_probe.py
class ClassificationDataset(Dataset):
"""Loads EEG segments from v3 split file paths."""
def __init__(
self,
file_paths: List[str],
label: int,
max_length: int = 48000,
n_channels: int = 19,
):
self.max_length = max_length
self.n_channels = n_channels
# (h5_path, seg_idx, patient_id, label, dataset, recording_state)
self.segments_info = []
for rel_path in file_paths:
h5_path = Path(rel_path)
if not h5_path.exists():
print(f"Warning: File not found: {h5_path}")
continue
patient_id = extract_patient_id(rel_path)
dataset = extract_dataset_from_path(rel_path)
recording_state = extract_recording_state(rel_path, dataset)
try:
with h5py.File(h5_path, 'r') as f:
n_segments = f['eeg/segments'].shape[0]
for seg_idx in range(n_segments):
__init__ method · python · L133-L161 (29 LOC)eval/evaluate_v3_linear_probe.py
def __init__(
self,
file_paths: List[str],
label: int,
max_length: int = 48000,
n_channels: int = 19,
):
self.max_length = max_length
self.n_channels = n_channels
# (h5_path, seg_idx, patient_id, label, dataset, recording_state)
self.segments_info = []
for rel_path in file_paths:
h5_path = Path(rel_path)
if not h5_path.exists():
print(f"Warning: File not found: {h5_path}")
continue
patient_id = extract_patient_id(rel_path)
dataset = extract_dataset_from_path(rel_path)
recording_state = extract_recording_state(rel_path, dataset)
try:
with h5py.File(h5_path, 'r') as f:
n_segments = f['eeg/segments'].shape[0]
for seg_idx in range(n_segments):
self.segments_info.append(
(str(h5_path), seg_idx, __getitem__ method · python · L166-L237 (72 LOC)eval/evaluate_v3_linear_probe.py
def __getitem__(self, idx):
file_path, segment_idx, patient_id, label, dataset, recording_state = \
self.segments_info[idx]
with h5py.File(file_path, 'r') as f:
eeg_data = f['eeg/segments'][segment_idx][:] # (C, T)
# Load channel info for variable-channel support
if 'eeg/channel_coords' in f:
file_coords = f['eeg/channel_coords'][:]
else:
file_coords = None
file_ch_names = [
x.decode() if isinstance(x, bytes) else x
for x in f['eeg/channel_names'][:]
]
if 'channel_mask' in f:
file_mask = f['channel_mask'][:]
else:
file_mask = None
C, T = eeg_data.shape
if C > self.n_channels:
eeg_data = eeg_data[:self.n_channels, :]
elif C < self.n_channels:
eeg_data = np.pad(eeg_data, ((0, self.n_channels - C), (0, 0)))
iRegressionDataset class · python · L240-L348 (109 LOC)eval/evaluate_v3_linear_probe.py
class RegressionDataset(Dataset):
"""Loads EEG segments for regression tasks from regression_v2 split format."""
def __init__(
self,
entries: List[dict],
target_key: str,
max_length: int = 48000,
n_channels: int = 19,
use_v5: bool = True,
):
self.max_length = max_length
self.n_channels = n_channels
# (h5_path, seg_idx, patient_id, target_value, dataset, recording_state)
self.segments_info = []
for entry in entries:
dataset = entry['dataset']
if dataset in EXCLUDED_DATASETS:
continue
v3_path = f"data/processed_v3_minimal/{dataset}/{entry['file']}"
h5_path = Path(remap_to_v5(v3_path) if use_v5 else v3_path)
if not h5_path.exists():
continue
patient_id = entry.get('patient_id', Path(entry['file']).stem)
target_val = entry[target_key]
if target_val is None:
Provenance: Repobility (https://repobility.com) — every score reproducible from /scan/
__init__ method · python · L243-L277 (35 LOC)eval/evaluate_v3_linear_probe.py
def __init__(
self,
entries: List[dict],
target_key: str,
max_length: int = 48000,
n_channels: int = 19,
use_v5: bool = True,
):
self.max_length = max_length
self.n_channels = n_channels
# (h5_path, seg_idx, patient_id, target_value, dataset, recording_state)
self.segments_info = []
for entry in entries:
dataset = entry['dataset']
if dataset in EXCLUDED_DATASETS:
continue
v3_path = f"data/processed_v3_minimal/{dataset}/{entry['file']}"
h5_path = Path(remap_to_v5(v3_path) if use_v5 else v3_path)
if not h5_path.exists():
continue
patient_id = entry.get('patient_id', Path(entry['file']).stem)
target_val = entry[target_key]
if target_val is None:
continue
recording_state = extract_recording_state(str(h5_path), dataset)
try:
__getitem__ method · python · L282-L348 (67 LOC)eval/evaluate_v3_linear_probe.py
def __getitem__(self, idx):
file_path, segment_idx, patient_id, target_val, dataset, recording_state = \
self.segments_info[idx]
with h5py.File(file_path, 'r') as f:
eeg_data = f['eeg/segments'][segment_idx][:]
if 'eeg/channel_coords' in f:
file_coords = f['eeg/channel_coords'][:]
else:
file_coords = None
file_ch_names = [
x.decode() if isinstance(x, bytes) else x
for x in f['eeg/channel_names'][:]
]
if 'channel_mask' in f:
file_mask = f['channel_mask'][:]
else:
file_mask = None
C, T = eeg_data.shape
if C > self.n_channels:
eeg_data = eeg_data[:self.n_channels, :]
elif C < self.n_channels:
eeg_data = np.pad(eeg_data, ((0, self.n_channels - C), (0, 0)))
if T > self.max_length:
eeg_data = eeg_data[:, :self.max_load_v3_model function · python · L353-L385 (33 LOC)eval/evaluate_v3_linear_probe.py
def load_v3_model(checkpoint_path: str, device: str = 'cuda',
condition_cache_path: str = 'data/condition_embeddings.pt'):
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
config = ckpt.get('config', {})
max_channels = config.get('max_channels', config.get('n_channels', 19))
model = EEGEncoderV3(
n_channels=config.get('n_channels', 19),
hidden_dim=config.get('hidden_dim', 1024),
n_heads=config.get('n_heads', 16),
n_kv_heads=config.get('n_kv_heads', 4),
n_layers=config.get('n_layers', 16),
patch_size=config.get('patch_size', 32),
dropout=0.0,
ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
rope_base=config.get('rope_base', 500_000.0),
n_registers=config.get('n_registers', 8),
use_spatial_rope=config.get('use_spatial_rope', True),
use_modality_tokens=config.get('use_modality_tokens', True),
use_condition_encoding=Truecreate_random_v3_model function · python · L388-L408 (21 LOC)eval/evaluate_v3_linear_probe.py
def create_random_v3_model(config: dict, device: str = 'cuda',
                           condition_cache_path: str = 'data/condition_embeddings.pt'):
    """Build an ``EEGEncoderV3`` from ``config`` without loading checkpoint weights.

    Args:
        config: Hyper-parameter mapping; missing keys use the defaults listed
            in the body.
        device: Device the model is moved to.
        condition_cache_path: Passed through to the encoder unchanged.

    Returns:
        The encoder, moved to ``device`` and switched to eval mode. Dropout is
        fixed at 0.0 and condition encoding is always enabled.
    """
    configurable_defaults = (
        ('n_channels', 19),
        ('hidden_dim', 1024),
        ('n_heads', 16),
        ('n_kv_heads', 4),
        ('n_layers', 16),
        ('patch_size', 32),
        ('ffn_multiplier', 8.0 / 3.0),
        ('rope_base', 500_000.0),
        ('n_registers', 8),
        ('use_spatial_rope', True),
        ('use_modality_tokens', True),
    )
    encoder_kwargs = {key: config.get(key, fallback) for key, fallback in configurable_defaults}
    # max_channels falls back to the channel count when not given explicitly.
    encoder_kwargs['max_channels'] = config.get('max_channels', encoder_kwargs['n_channels'])

    model = EEGEncoderV3(
        dropout=0.0,
        use_condition_encoding=True,
        condition_cache_path=condition_cache_path,
        **encoder_kwargs,
    )
    model.to(device).eval()
    return model


# [listing marker] extract_embeddings function · python · L414-L483 (70 LOC)eval/evaluate_v3_linear_probe.py
def extract_embeddings(
model: EEGEncoderV3,
dataloader: DataLoader,
device: str,
pooling: str = 'mean',
vqvae: Optional['VQVAE'] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
"""Extract embeddings. Returns (embeddings, labels, patient_ids).
Matches training pipeline: applies channel_mask, passes reference_type_ids
and condition_keys to model.encode(). No per-channel normalization (training
does not apply any).
Args:
pooling: 'mean' for mean over all patch positions,
'last' for the last patch position (best for causal models).
vqvae: Optional VQ-VAE model. If provided, hidden states are quantized
and reconstructed before pooling (measures VQ-VAE information loss).
"""
all_emb, all_lab, all_pid = [], [], []
for batch in tqdm(dataloader, desc=" Embedding", leave=False):
eeg = batch['eeg'].to(device)
# Channel mask — zero out absent channels (same as training eval_collate_fn function · python · L486-L500 (15 LOC)eval/evaluate_v3_linear_probe.py
def eval_collate_fn(batch: List[Dict]) -> Dict:
    """Collate per-sample dicts into one batch dict for evaluation.

    Tensor fields (``eeg``, ``channel_mask`` and, when present,
    ``channel_coords``) are stacked along a new leading dimension. ``label``
    and ``reference_type_id`` become tensors, the latter as ``torch.long``
    under the key ``reference_type_ids``. ``patient_id``, ``dataset`` and
    ``recording_state`` stay plain lists. ``channel_coords`` is included only
    if the first sample carries it.
    """
    def column(key):
        # One field gathered across the batch, in batch order.
        return [sample[key] for sample in batch]

    collated = {
        'eeg': torch.stack(column('eeg')),
        'label': torch.tensor(column('label')),
        'patient_id': column('patient_id'),
        'channel_mask': torch.stack(column('channel_mask')),
        'reference_type_ids': torch.tensor(column('reference_type_id'), dtype=torch.long),
        'dataset': column('dataset'),
        'recording_state': column('recording_state'),
    }
    if 'channel_coords' in batch[0]:
        collated['channel_coords'] = torch.stack(column('channel_coords'))
    return collated


# [listing marker] aggregate_to_patient function · python · L505-L559 (55 LOC)eval/evaluate_v3_linear_probe.py
def aggregate_to_patient(
embeddings: np.ndarray,
labels: np.ndarray,
patient_ids: List[str],
segment_selection: str = 'all',
fixed_segments: int = 0,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
"""Average embeddings per patient.
Args:
segment_selection: 'all' uses all segments, 'first_half' uses first half
of segments per patient, 'second_half' uses second half. Segments are
ordered chronologically (as stored in the H5 file).
fixed_segments: If > 0, subsample each patient to exactly this many
segments before averaging. If 0 (default), use all segments.
Use -1 to auto-detect: subsample to min segment count across patients.
"""
patient_data = defaultdict(lambda: {'embs': [], 'label': None})
for emb, lab, pid in zip(embeddings, labels, patient_ids):
patient_data[pid]['embs'].append(emb)
patient_data[pid]['label'] = lab
# Determine fixed segment count if auto-d_compute_metrics function · python · L564-L574 (11 LOC)eval/evaluate_v3_linear_probe.py
def _compute_metrics(test_lab, pred, pred_proba):
    """Compute the binary-classification metrics, each rounded to 4 decimals.

    Args:
        test_lab: Ground-truth labels.
        pred: Hard class predictions.
        pred_proba: Scores passed to ``roc_auc_score``.

    Returns:
        Dict with ``auc``, ``balanced_accuracy``, ``f1`` (binary average) and
        ``accuracy``. ``auc`` is NaN when ``roc_auc_score`` raises
        ``ValueError`` (presumably when the labels contain a single class —
        confirm against scikit-learn docs).
    """
    try:
        auc = roc_auc_score(test_lab, pred_proba)
    except ValueError:
        auc = float('nan')

    raw_scores = {
        'auc': auc,
        'balanced_accuracy': balanced_accuracy_score(test_lab, pred),
        'f1': f1_score(test_lab, pred, average='binary'),
        'accuracy': accuracy_score(test_lab, pred),
    }
    return {name: round(score, 4) for name, score in raw_scores.items()}


# [listing marker] Repobility's GitHub App fixes findings like these · https://github.com/apps/repobility-bot
classify_xgboost function · python · L577-L602 (26 LOC)eval/evaluate_v3_linear_probe.py
def classify_xgboost(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    """Fit an XGBoost classifier on standardised embeddings and score it.

    The scaler is fitted on the training embeddings only and then applied
    to both splits.

    Args:
        train_emb: Training feature matrix.
        train_lab: Training labels (assumed to be 0/1 so that their sum
            counts the positives — TODO confirm).
        test_emb: Test feature matrix.
        test_lab: Test labels used for scoring.

    Returns:
        The metrics dict produced by ``_compute_metrics`` on the test split.
    """
    normalizer = StandardScaler()
    x_train = normalizer.fit_transform(train_emb)
    x_test = normalizer.transform(test_emb)

    # Class re-weighting: negatives per positive in the training labels.
    # max(..., 1) keeps the ratio finite when there are no positives.
    positives = int(train_lab.sum())
    negatives = len(train_lab) - positives

    hyperparams = {
        'n_estimators': 200,
        'max_depth': 6,
        'learning_rate': 0.1,
        'scale_pos_weight': negatives / max(positives, 1),
        'eval_metric': 'logloss',
        'use_label_encoder': False,
        'verbosity': 0,
    }
    model = xgb.XGBClassifier(**hyperparams)
    model.fit(x_train, train_lab)

    hard_pred = model.predict(x_test)
    # Column 1 of predict_proba is used as the positive-class score.
    positive_proba = model.predict_proba(x_test)[:, 1]
    return _compute_metrics(test_lab, hard_pred, positive_proba)


# Next listing entry: classify_tabpfn function · python · L605-L630 (26 LOC) · eval/evaluate_v3_linear_probe.py
def classify_tabpfn(
train_emb: np.ndarray, train_lab: np.ndarray,
test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
scaler = StandardScaler()
train_s = scaler.fit_transform(train_emb)
test_s = scaler.transform(test_emb)
# StandardScaler produces NaN for zero-variance features (std=0 → div by 0).
# This happens with heavily masked data (e.g. TBI with 4/19 channels).
# Replace NaN with 0 (zero-variance features carry no information).
train_s = np.nan_to_num(train_s, nan=0.0)
test_s = np.nan_to_num(test_s, nan=0.0)
# TabPFN works best with fewer features — PCA to 128 dims
n_components = min(128, train_s.shape[0] - 1, train_s.shape[1])
pca = PCA(n_components=n_components)
train_s = pca.fit_transform(train_s)
test_s = pca.transform(test_s)
clf = TabPFNClassifier()
clf.fit(train_s, train_lab)
pred = clf.predict(test_s)
pred_proba = clf.predict_proba(test_s)[:, 1]
return _compute_metrics(test_compute_regression_metrics function · python · L635-L642 (8 LOC)eval/evaluate_v3_linear_probe.py
def _compute_regression_metrics(y_true, y_pred):
    """Summarise regression quality as a dict of rounded scores.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted target values.

    Returns:
        Dict with ``'pearson_r'``, ``'r2'`` and ``'mae'`` (each cast to
        ``float`` and rounded to 4 decimal places) plus ``'n'``, the length
        of ``y_true``.
    """
    # Only the correlation coefficient is reported; the p-value is unused.
    correlation, _pvalue = pearsonr(y_true, y_pred)
    determination = r2_score(y_true, y_pred)
    abs_error = mean_absolute_error(y_true, y_pred)

    summary = {
        'pearson_r': round(float(correlation), 4),
        'r2': round(float(determination), 4),
        'mae': round(float(abs_error), 4),
    }
    summary['n'] = len(y_true)
    return summary


# Next listing entry: regress_xgboost function · python · L645-L661 (17 LOC) · eval/evaluate_v3_linear_probe.py
def regress_xgboost(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    """Fit an XGBoost regressor on standardised embeddings and score it.

    The scaler is fitted on the training embeddings only and then applied
    to both splits.

    Args:
        train_emb: Training feature matrix.
        train_lab: Training regression targets.
        test_emb: Test feature matrix.
        test_lab: Test regression targets used for scoring.

    Returns:
        The metrics dict produced by ``_compute_regression_metrics`` on the
        test split.
    """
    standardizer = StandardScaler()
    x_train = standardizer.fit_transform(train_emb)
    x_test = standardizer.transform(test_emb)

    booster_kwargs = {
        'n_estimators': 200,
        'max_depth': 6,
        'learning_rate': 0.1,
        'verbosity': 0,
    }
    regressor = xgb.XGBRegressor(**booster_kwargs)
    regressor.fit(x_train, train_lab)

    predicted = regressor.predict(x_test)
    return _compute_regression_metrics(test_lab, predicted)


# Next listing entry: regress_ridge function · python · L664-L675 (12 LOC) · eval/evaluate_v3_linear_probe.py
def regress_ridge(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    """Fit a ridge regressor (``alpha=1.0``) on standardised embeddings.

    The scaler is fitted on the training embeddings only and then applied
    to both splits.

    Args:
        train_emb: Training feature matrix.
        train_lab: Training regression targets.
        test_emb: Test feature matrix.
        test_lab: Test regression targets used for scoring.

    Returns:
        The metrics dict produced by ``_compute_regression_metrics`` on the
        test split.
    """
    standardizer = StandardScaler()
    x_train = standardizer.fit_transform(train_emb)
    x_test = standardizer.transform(test_emb)

    ridge = Ridge(alpha=1.0)
    ridge.fit(x_train, train_lab)

    predicted = ridge.predict(x_test)
    return _compute_regression_metrics(test_lab, predicted)


# Next listing entry: remap_to_v5 function · python · L683-L739 (57 LOC) · eval/evaluate_v3_linear_probe.py
def remap_to_v5(v3_path: str) -> str:
"""Remap a v3_minimal (or v4) file path to the equivalent v5 path.
The model was trained on v5 data (robust z-score normalization), so
evaluation should also use v5 data to eliminate preprocessing mismatch.
Filename mapping per dataset:
CAUEEG: patient_00018.h5 → 00018.h5
AD_EEG: patient_001.h5 → sub-001_task-eyesclosed_eeg.h5
DS004504: patient_001.h5 → sub-001_task-eyesclosed_eeg.h5
PEARL: patient_01_msit.h5 → sub-01_task-msit_eeg.h5
READTBI: <same>.h5 → <same>.h5
TD_BRAIN: patient_ID_task.h5 → sub-ID_ses-1_task-TASK_eeg.h5
PARKINSON_UCSD: patient_hc10.h5 → sub-hc10_ses-hc_task-rest_eeg.h5
patient_pd11.h5 → sub-pd11_ses-on_task-rest_eeg.h5
"""
p = Path(v3_path)
dataset = extract_dataset_from_path(v3_path)
stem = p.stem # e.g. 'patient_00018'
if dfix_tdbrain_path function · python · L742-L744 (3 LOC)eval/evaluate_v3_linear_probe.py
def fix_tdbrain_path(path: str) -> str:
    """Point tdbrain_downstream paths at processed_v3_minimal.

    Those paths reference a non-existent ``processed_v4`` directory; every
    occurrence of ``'processed_v4/'`` is rewritten to
    ``'processed_v3_minimal/'``. Paths without that substring are returned
    unchanged.

    Args:
        path: File path as stored in the splits file.

    Returns:
        The path with the directory name substituted.
    """
    stale_dir = 'processed_v4/'
    existing_dir = 'processed_v3_minimal/'
    return path.replace(stale_dir, existing_dir)


# Next listing entry: evaluate_classification_task function · python · L789-L850 (62 LOC) · eval/evaluate_v3_linear_probe.py
def evaluate_classification_task(
model: EEGEncoderV3,
task: str,
splits_dir: Path,
device: str,
max_length: int,
batch_size: int,
fix_paths: bool = False,
pooling: str = 'mean',
segment_selection: str = 'all',
fixed_segments: int = 0,
use_v5: bool = True,
vqvae: Optional['VQVAE'] = None,
) -> Dict[str, Dict[str, float]]:
splits_file = splits_dir / task / 'splits.json'
with open(splits_file) as f:
splits = json.load(f)
results = {}
for split in ['train', 'test']:
pos_files = splits[split]['positive']
neg_files = splits[split]['negative']
if fix_paths:
pos_files = [fix_tdbrain_path(f) for f in pos_files]
neg_files = [fix_tdbrain_path(f) for f in neg_files]
if use_v5:
pos_files = [remap_to_v5(f) for f in pos_files]
neg_files = [remap_to_v5(f) for f in neg_files]
n_ch = getattr(model, 'max_channels', 19)
pos_data = ClasRepobility — the code-quality scanner for AI-generated software · https://repobility.com
evaluate_regression_task function · python · L853-L909 (57 LOC)eval/evaluate_v3_linear_probe.py
def evaluate_regression_task(
model: EEGEncoderV3,
task: str,
device: str,
max_length: int,
batch_size: int,
pooling: str = 'mean',
segment_selection: str = 'all',
fixed_segments: int = 0,
use_v5: bool = True,
vqvae: Optional['VQVAE'] = None,
) -> Dict[str, Dict[str, float]]:
splits_file = REGRESSION_SPLITS_DIR / task / 'splits.json'
with open(splits_file) as f:
splits = json.load(f)
target_key = REGRESSION_TARGET_KEYS.get(task, task)
results = {}
for split in ['train', 'test']:
entries = splits[split]
n_ch = getattr(model, 'max_channels', 19)
dataset = RegressionDataset(entries, target_key, max_length=max_length,
n_channels=n_ch, use_v5=use_v5)
if len(dataset) == 0:
print(f" {split.capitalize():>5}: 0 segments — skipping")
return {}
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
_run_eval function · python · L912-L924 (13 LOC)eval/evaluate_v3_linear_probe.py
def _run_eval(model, task_list, eval_fn, device, max_length, batch_size, **kwargs):
"""Run evaluation across a list of tasks with a given eval function."""
task_results = {}
for task in task_list:
print(f"\n [{task}]")
try:
metrics = eval_fn(model, task, device=device,
max_length=max_length, batch_size=batch_size, **kwargs)
if metrics:
task_results[task] = metrics
except Exception as e:
print(f" ERROR: {e}")
return task_results_run_single_eval function · python · L927-L963 (37 LOC)eval/evaluate_v3_linear_probe.py
def _run_single_eval(model, clf_tasks, tdb_tasks, reg_tasks, device, args,
pooling='mean', segment_selection='all', fixed_segments=0,
vqvae=None):
"""Run a single eval pass with given pooling and segment selection."""
use_v5 = getattr(args, 'use_v5', True)
results = {}
if clf_tasks:
clf_results = _run_eval(
model, clf_tasks,
lambda m, t, **kw: evaluate_classification_task(
m, t, CLASSIFICATION_SPLITS_DIR,
pooling=pooling, segment_selection=segment_selection,
fixed_segments=fixed_segments, use_v5=use_v5, vqvae=vqvae, **kw),
device, args.max_length, args.batch_size)
results.update(clf_results)
if tdb_tasks:
tdb_results = _run_eval(
model, tdb_tasks,
lambda m, t, **kw: evaluate_classification_task(
m, t, TDBRAIN_SPLITS_DIR, fix_paths=True,
pooling=pooling, segment_page 1 / 4next ›