← back to drewlinsley__tmp_bs_mcp

Function bodies 170 total

All specs Real LLM only Function bodies
get_reference_type_id function · python · L62-L65 (4 LOC)
data/eeg_pretrain_dataset_safe.py
def get_reference_type_id(dataset: str) -> int:
    """Map a dataset name to its reference-type ID.

    The dataset is first translated to a reference name through
    DATASET_REFERENCE_TYPES (falling back to 'unknown'); that name is then
    looked up in EEGConfig.REFERENCE_TYPES, again falling back to the entry
    registered under 'unknown'.
    """
    id_table = EEGConfig.REFERENCE_TYPES
    reference = DATASET_REFERENCE_TYPES.get(dataset, 'unknown')
    return id_table.get(reference, id_table['unknown'])
get_dataset_channel_mask function · python · L68-L88 (21 LOC)
data/eeg_pretrain_dataset_safe.py
def get_dataset_channel_mask(dataset: str, max_channels: int = 19) -> np.ndarray:
    """Build a presence mask over the standard channel list for a dataset.

    Args:
        dataset: Dataset name. Datasets without an entry in DATASET_CHANNELS
            are treated as having every channel in
            EEGConfig.STANDARD_CHANNELS.
        max_channels: Length of the returned mask (default 19).

    Returns:
        float32 array of length ``max_channels`` holding 1.0 at the standard
        position of each channel the dataset provides and 0.0 elsewhere.
        Channels that are not in the standard list, or whose standard
        position is >= ``max_channels``, are ignored.
    """
    standard = EEGConfig.STANDARD_CHANNELS
    present = DATASET_CHANNELS.get(dataset, standard)

    mask = np.zeros(max_channels, dtype=np.float32)
    for name in present:
        if name not in standard:
            continue
        position = standard.index(name)
        if position < max_channels:
            mask[position] = 1.0
    return mask
get_channel_indices function · python · L91-L106 (16 LOC)
data/eeg_pretrain_dataset_safe.py
def get_channel_indices(dataset: str) -> List[int]:
    """Return the standard-list positions of a dataset's channels.

    Args:
        dataset: Dataset name. Datasets without an entry in DATASET_CHANNELS
            are treated as having every channel in
            EEGConfig.STANDARD_CHANNELS.

    Returns:
        Positions within EEGConfig.STANDARD_CHANNELS, listed in the order the
        dataset declares its channels. Channels that are not in the standard
        list are skipped.
    """
    standard = EEGConfig.STANDARD_CHANNELS
    present = DATASET_CHANNELS.get(dataset, standard)
    return [standard.index(name) for name in present if name in standard]
get_condition_description function · python · L193-L209 (17 LOC)
data/eeg_pretrain_dataset_safe.py
def get_condition_description(dataset: str, recording_state: str) -> str:
    """Return a natural-language description of a recording condition.

    Lookup order:
      1. CONDITION_DESCRIPTIONS[(dataset, recording_state)]
      2. CONDITION_DESCRIPTIONS[(dataset, 'unknown')]
      3. DEFAULT_CONDITION_DESCRIPTIONS[dataset]
      4. the literal "EEG recording"
    """
    for candidate in ((dataset, recording_state), (dataset, 'unknown')):
        if candidate in CONDITION_DESCRIPTIONS:
            return CONDITION_DESCRIPTIONS[candidate]

    # No (dataset, state) entry matched: use the per-dataset default, or the
    # generic text when the dataset has none.
    return DEFAULT_CONDITION_DESCRIPTIONS.get(dataset, "EEG recording")
EEGPretrainDatasetSafe class · python · L212-L533 (322 LOC)
data/eeg_pretrain_dataset_safe.py
class EEGPretrainDatasetSafe(Dataset):
    """Dataset for EEG pretraining with safe H5 file handling for multi-GPU
    
    This version opens H5 files on demand in each worker process to avoid
    file handle corruption issues with multiprocessing.
    """
    
    def __init__(
        self,
        data_dir: Path,
        split: str = 'train',
        datasets: List[str] = ['CAUEEG', 'PEARL', 'TD_BRAIN'],
        max_length: int = 24000,  # 2 minutes @ 200Hz
        segment_length: int = 24000,  # Length of each preprocessed segment
        include_eyes_state: bool = True,
        splits_file: Optional[Path] = None,
        patch_size: int = 32,
        max_channels: int = 19,
    ):
        """
        Args:
            data_dir: Root directory containing processed datasets
            split: 'train', 'val', or 'test'
            datasets: List of datasets to include
            max_length: Maximum sequence length in samples (can exceed segment_length
                        to con
__init__ method · python · L219-L291 (73 LOC)
data/eeg_pretrain_dataset_safe.py
    def __init__(
        self,
        data_dir: Path,
        split: str = 'train',
        datasets: List[str] = ['CAUEEG', 'PEARL', 'TD_BRAIN'],
        max_length: int = 24000,  # 2 minutes @ 200Hz
        segment_length: int = 24000,  # Length of each preprocessed segment
        include_eyes_state: bool = True,
        splits_file: Optional[Path] = None,
        patch_size: int = 32,
        max_channels: int = 19,
    ):
        """
        Args:
            data_dir: Root directory containing processed datasets
            split: 'train', 'val', or 'test'
            datasets: List of datasets to include
            max_length: Maximum sequence length in samples (can exceed segment_length
                        to concatenate consecutive segments for longer context)
            segment_length: Length of each preprocessed segment in the H5 files
            include_eyes_state: Whether to include eyes open/closed info
            splits_file: Optional path to custom splits JSON
__getitem__ method · python · L296-L439 (144 LOC)
data/eeg_pretrain_dataset_safe.py
    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        file_path, start_seg_idx, n_avail, dataset = self.segments_info[idx]

        # Open file fresh each time - safer for multiprocessing
        with h5py.File(file_path, 'r') as f:
            # Load and concatenate consecutive segments
            segments = []
            for i in range(n_avail):
                seg = f['eeg/segments'][start_seg_idx + i][:]
                segments.append(seg)
            eeg_data = np.concatenate(segments, axis=-1)  # (C, n_avail * segment_length)

            # Load channel info from H5 (v6 has channel_coords; v5 does not)
            has_v6_coords = 'eeg/channel_coords' in f
            if has_v6_coords:
                file_channel_coords = f['eeg/channel_coords'][:]  # (n_ch, 2)
            else:
                file_channel_coords = None
            file_channel_names = [
                x.decode() if isinstance(x, bytes) else x
                for x in f['eeg/channel_names'][:]
 
Repobility's GitHub App fixes findings like these · https://github.com/apps/repobility-bot
_extract_recording_state method · python · L441-L533 (93 LOC)
data/eeg_pretrain_dataset_safe.py
    def _extract_recording_state(self, f: h5py.File, dataset: str) -> str:
        """Extract detailed recording state from HDF5 file based on dataset
        
        Returns detailed states like:
        - 'resting eyes closed'
        - 'resting eyes open'
        - 'MSIT task'
        - 'Sternberg task'
        - 'unknown'
        """
        # Default state
        state = 'unknown'
        
        try:
            if dataset == 'AD_EEG':
                # AD_EEG has eyes_closed attribute
                if 'eyes_closed' in f.attrs:
                    state = 'resting eyes closed' if f.attrs['eyes_closed'] else 'resting eyes open'
                    
            elif dataset == 'CAUEEG':
                # CAUEEG clinical EEG — check attrs, default to clinical recording
                if 'recording_type' in f.attrs:
                    state = f.attrs['recording_type']
                elif 'eyes_state' in f.attrs:
                    state = f'resting eyes {f.attrs["eyes_state"
collate_batch_eeg function · python · L536-L584 (49 LOC)
data/eeg_pretrain_dataset_safe.py
def collate_batch_eeg(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate function that stacks EEG data properly"""
    eeg_data = torch.stack([item['eeg'] for item in batch])

    # Convert lists to tensors/lists as appropriate
    eyes_states = [item['eyes_state'] for item in batch]
    recording_states = [item['recording_state'] for item in batch]
    condition_descriptions = [item['condition_description'] for item in batch]
    datasets = [item['dataset'] for item in batch]
    file_paths = [item['file_path'] for item in batch]

    # Reference type IDs as tensor for 2D spatial RoPE
    reference_type_ids = torch.tensor(
        [item['reference_type_id'] for item in batch],
        dtype=torch.long
    )
    reference_names = [item['reference_name'] for item in batch]

    # Channel masks for variable channel support
    channel_masks = torch.stack([item['channel_mask'] for item in batch])
    channel_indices_list = [item['channel_indices'] for item in batch]

    # Chan
worker_init_fn function · python · L587-L601 (15 LOC)
data/eeg_pretrain_dataset_safe.py
def worker_init_fn(worker_id):
    """Seed NumPy and `random` in a worker process and set an h5py option.

    Args:
        worker_id: Worker index passed in by the caller; not used here.
    """
    import random
    import numpy as np
    import torch

    # Reduce torch's initial seed to 32 bits and reuse it for the other RNGs.
    # presumably the modulo is there because np.random.seed rejects larger
    # values — TODO confirm
    seed = torch.initial_seed() % 2**32
    random.seed(seed)
    np.random.seed(seed)

    import h5py
    # Sets h5py's global `track_order` config flag to False.
    # NOTE(review): the earlier comment here described this as selecting the
    # "latest file format for better performance"; `track_order` is a
    # different setting — confirm the intent against the h5py docs.
    h5py.get_config().track_order = False
create_safe_pretrain_dataloader function · python · L604-L653 (50 LOC)
data/eeg_pretrain_dataset_safe.py
def create_safe_pretrain_dataloader(
    data_dir: Path,
    split: str = 'train',
    batch_size: int = 8,
    num_workers: int = 4,
    shuffle: bool = True,
    datasets: List[str] = ['CAUEEG', 'PEARL', 'TD_BRAIN'],
    pin_memory: bool = True,
    prefetch_factor: int = 2,
    persistent_workers: bool = True,
    max_length: int = 24000,
    segment_length: int = 24000,
    include_eyes_state: bool = True,
    splits_file: Optional[Path] = None,
    patch_size: int = 32,
    max_channels: int = 19,
) -> DataLoader:
    """Create a safe dataloader for pretraining that handles multiprocessing correctly"""

    # Use the safe dataset
    dataset = EEGPretrainDatasetSafe(
        data_dir=data_dir,
        split=split,
        datasets=datasets,
        max_length=max_length,
        segment_length=segment_length,
        include_eyes_state=include_eyes_state,
        splits_file=splits_file,
        patch_size=patch_size,
        max_channels=max_channels,
    )
    
    # For distrib
load_model function · python · L48-L77 (30 LOC)
eval/evaluate_full_recording.py
def load_model(checkpoint_path: str, device: str = 'cuda'):
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    config = ckpt.get('config', {})
    max_channels = config.get('max_channels', config.get('n_channels', 19))

    model = EEGEncoderV3(
        n_channels=config.get('n_channels', 19),
        hidden_dim=config.get('hidden_dim', 1024),
        n_heads=config.get('n_heads', 16),
        n_kv_heads=config.get('n_kv_heads', 4),
        n_layers=config.get('n_layers', 16),
        patch_size=config.get('patch_size', 32),
        dropout=0.0,
        ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
        rope_base=config.get('rope_base', 500_000.0),
        n_registers=config.get('n_registers', 8),
        use_spatial_rope=config.get('use_spatial_rope', True),
        use_modality_tokens=config.get('use_modality_tokens', True),
        use_condition_encoding=True,
        condition_cache_path='data/condition_embeddings.pt',
        max_channe
load_full_recording function · python · L84-L135 (52 LOC)
eval/evaluate_full_recording.py
def load_full_recording(h5_path: str, max_channels: int = 19):
    """Load all segments from H5 and stitch into continuous recording."""
    with h5py.File(h5_path, 'r') as f:
        segments = f['eeg/segments'][:]
        n_seg = segments.shape[0]
        if n_seg == 1:
            eeg = segments[0]
        else:
            parts = [segments[i, :, :STRIDE] for i in range(n_seg - 1)]
            parts.append(segments[-1])
            eeg = np.concatenate(parts, axis=1)

        file_coords = f['eeg/channel_coords'][:] if 'eeg/channel_coords' in f else None
        ch_names = [x.decode() if isinstance(x, bytes) else x for x in f['eeg/channel_names'][:]]
        file_mask = f['channel_mask'][:] if 'channel_mask' in f else None

    n_native = eeg.shape[0]
    if n_native < max_channels:
        eeg = np.pad(eeg, ((0, max_channels - n_native), (0, 0)))
    elif n_native > max_channels:
        eeg = eeg[:max_channels]
        n_native = max_channels

    # Channel mask
    if max_channe
extract_recording_state function · python · L138-L156 (19 LOC)
eval/evaluate_full_recording.py
def extract_recording_state(filepath, dataset):
    """Infer the recording state from the dataset name and file stem.

    Datasets with a single fixed state are answered from a lookup table.
    DORTMUND, PEARL and TD_BRAIN are decided by substrings of the file name
    (without its extension). Anything else yields 'unknown'.
    """
    stem = Path(filepath).stem

    fixed_states = {
        'AD_EEG': 'resting eyes closed',
        'DS004504': 'resting eyes closed',
        'PARKINSON_UCSD': 'resting eyes closed',
        'READTBI': 'clinical assessment',
        'CAUEEG': 'resting eyes closed',
        'LEMON': 'resting eyes closed',
        'DEPRESSION': 'resting',
        'SRM': 'resting eyes closed',
    }
    fixed = fixed_states.get(dataset)
    if fixed is not None:
        return fixed

    if dataset == 'DORTMUND':
        if 'EyesClosed' in stem:
            return 'resting eyes closed'
        return 'resting eyes open'

    if dataset == 'PEARL':
        # Markers are tested in this order; the first one found wins.
        pearl_markers = (
            ('rest', 'resting'),
            ('msit', 'MSIT task'),
            ('sternberg', 'Sternberg task'),
        )
        for marker, state in pearl_markers:
            if marker in stem:
                return state
        return 'unknown'

    if dataset == 'TD_BRAIN':
        # Everything that is not explicitly eyes-open counts as eyes-closed.
        if 'restEO' in stem:
            return 'resting eyes open'
        return 'resting eyes closed'

    return 'unknown'
embed_full_recording function · python · L160-L191 (32 LOC)
eval/evaluate_full_recording.py
def embed_full_recording(model, h5_path, dataset, device, chunk_patches=1500):
    """Embed entire recording → one pooled vector."""
    max_ch = model.max_channels
    eeg, file_mask, coords = load_full_recording(h5_path, max_channels=max_ch)
    if eeg.shape[1] < 32:
        return None

    if file_mask is not None:
        channel_mask = torch.from_numpy(file_mask).unsqueeze(0).to(device)
    else:
        channel_mask = torch.from_numpy(
            get_dataset_channel_mask(dataset, max_channels=max_ch)
        ).unsqueeze(0).to(device)

    channel_coords = torch.from_numpy(coords).unsqueeze(0).to(device)
    ref_ids = torch.tensor([get_reference_type_id(dataset)], dtype=torch.long, device=device)
    rec_state = extract_recording_state(h5_path, dataset)
    condition_keys = [(dataset, rec_state)]

    chunk_samples = chunk_patches * 32
    all_hidden = []
    for start in range(0, eeg.shape[1], chunk_samples):
        chunk = eeg[:, start:start + chunk_samples].unsqueeze(0).to(d
Repobility — the code-quality scanner for AI-generated software · https://repobility.com
run_classification function · python · L196-L231 (36 LOC)
eval/evaluate_full_recording.py
def run_classification(train_emb, train_lab, test_emb, test_lab):
    results = {}
    # XGBoost
    clf = XGBClassifier(n_estimators=100, max_depth=4, use_label_encoder=False,
                        eval_metric='logloss', verbosity=0)
    clf.fit(train_emb, train_lab)
    probs = clf.predict_proba(test_emb)[:, 1]
    auc = float(roc_auc_score(test_lab, probs))
    results['xgboost'] = {
        'auc': auc,
        'y_true': [int(y) for y in test_lab],
        'y_prob': [float(p) for p in probs],
    }
    # TabPFN
    if HAS_TABPFN:
        n_comp = min(128, train_emb.shape[1], train_emb.shape[0])
        pca = PCA(n_components=n_comp)
        tr_pca = pca.fit_transform(train_emb)
        te_pca = pca.transform(test_emb)
        try:
            clf = TabPFNClassifier(device='cpu', N_ensemble_configurations=16,
                                   ignore_pretraining_limits=True)
        except TypeError:
            try:
                clf = TabPFNClassifier(device='cpu', ignore_pretr
run_regression function · python · L234-L245 (12 LOC)
eval/evaluate_full_recording.py
def run_regression(train_emb, train_lab, test_emb, test_lab):
    """Fit an XGBoost regressor and score it by Pearson correlation.

    The regressor is fitted on the training embeddings/labels and used to
    predict the test labels. When either the test labels or the predictions
    have a standard deviation below 1e-8, the correlation is reported as 0.0
    with p-value 1.0 and pearsonr is not called.

    Returns:
        ``{'xgboost': {...}}`` with keys 'pearson_r', 'p_value', 'y_true'
        and 'y_pred' (the last two as lists of Python floats).
    """
    model = XGBRegressor(n_estimators=100, max_depth=4, verbosity=0)
    model.fit(train_emb, train_lab)
    predictions = model.predict(test_emb)

    degenerate = np.std(test_lab) < 1e-8 or np.std(predictions) < 1e-8
    if degenerate:
        corr, p_val = 0.0, 1.0
    else:
        corr, p_val = pearsonr(test_lab, predictions)

    summary = {
        'pearson_r': float(corr),
        'p_value': float(p_val),
        'y_true': [float(truth) for truth in test_lab],
        'y_pred': [float(guess) for guess in predictions],
    }
    return {'xgboost': summary}
resolve_h5_path function · python · L250-L273 (24 LOC)
eval/evaluate_full_recording.py
def resolve_h5_path(path_str: str, dataset: str = None) -> Optional[str]:
    """Find the on-disk H5 file that a split-file path refers to.

    Candidates are tried in order and the first existing one is returned:
      1. ``path_str`` as given
      2. the result of ``remap_to_v5(path_str)``
      3. ``<v6 base>/<dataset>/<file name>`` (only when ``dataset`` is truthy)
      4. ``<v6 base>/<dataset>/<file stem>_eeg.h5`` (same condition)

    Returns:
        The resolved path as a string, or None when no candidate exists.
    """
    direct = Path(path_str)
    if direct.exists():
        return str(direct)

    remapped = remap_to_v5(path_str)
    if Path(remapped).exists():
        return remapped

    # The v6 layout is organised per dataset, so it cannot be probed
    # without a dataset name.
    if not dataset:
        return None

    dataset_dir = Path('/home/dlinsley/eeg_pretrain_v2/data/processed_v6') / dataset
    for filename in (direct.name, direct.stem + '_eeg.h5'):
        candidate = dataset_dir / filename
        if candidate.exists():
            return str(candidate)

    return None
load_classification_v3_task function · python · L276-L314 (39 LOC)
eval/evaluate_full_recording.py
def load_classification_v3_task(task: str):
    """Load classification_v3 or tdbrain task → list of (h5_path, label, dataset, patient_id)."""
    # Try classification_v3 first, then tdbrain
    for splits_dir in ['data/splits/classification_v3', 'data/splits/tdbrain_downstream']:
        sp = Path(f'{splits_dir}/{task}/splits.json')
        if sp.exists():
            break
    else:
        return None

    with open(sp) as f:
        data = json.load(f)

    datasets = data.get('datasets', ['unknown'])

    entries = {}
    for split in ['train', 'test']:
        items = []
        split_data = data.get(split, {})
        if isinstance(split_data, dict):
            for label_name, files in split_data.items():
                label = 1 if label_name == 'positive' else 0
                for fpath in files:
                    # Fix tdbrain paths
                    if 'processed_v4' in fpath:
                        fpath = fix_tdbrain_path(fpath)
                    ds = None
       
load_regression_task function · python · L317-L345 (29 LOC)
eval/evaluate_full_recording.py
def load_regression_task(task: str):
    """Load regression_v2 task → list of (h5_path, label, dataset, patient_id)."""
    sp = Path(f'data/splits/regression_v2/{task}/splits.json')
    if not sp.exists():
        return None
    with open(sp) as f:
        data = json.load(f)

    target_keys = {'age': 'age', 'mmse': 'mmse', 'bdi': 'bdi', 'rpm': 'rpm',
                   'naart': 'naart', 'disease_duration': 'disease_duration'}
    target_key = target_keys.get(task, task)

    entries = {}
    for split in ['train', 'test']:
        items = []
        for entry in data.get(split, []):
            ds = entry.get('dataset', 'unknown')
            if ds in EXCLUDED_DATASETS:
                continue
            val = entry.get(target_key)
            if val is None:
                continue
            fpath = f"data/processed_v3_minimal/{ds}/{entry['file']}"
            resolved = resolve_h5_path(fpath, ds)
            if resolved:
                pid = entry.get('patient_id', Path(ent
load_new_task function · python · L348-L370 (23 LOC)
eval/evaluate_full_recording.py
def load_new_task(splits_dir: str, task: str):
    """Load a new-format task (dortmund/srm/depression) from its splits file.

    Reads ``data/splits/<splits_dir>/<task>/splits.json`` and, for the
    'train' and 'test' splits, keeps only the entries whose H5 file exists
    under the processed_v6 base directory.

    Returns:
        None when the splits file is missing; otherwise a dict mapping
        'train' and 'test' to lists of ``(h5_path, label, dataset, pid)``
        tuples, where ``dataset`` is the first component of the entry's
        'file' value and ``pid`` is that file's stem.
    """
    splits_path = Path(f'data/splits/{splits_dir}/{task}/splits.json')
    if not splits_path.exists():
        return None
    with open(splits_path) as handle:
        spec = json.load(handle)

    root = Path('/home/dlinsley/eeg_pretrain_v2/data/processed_v6')

    def _existing_items(records):
        """Turn raw split records into tuples, skipping missing files."""
        kept = []
        for record in records:
            rel_file = record['file']
            label = record['label']
            dataset_name = rel_file.split('/')[0]
            full_path = root / rel_file
            if full_path.exists():
                kept.append((str(full_path), label, dataset_name, Path(rel_file).stem))
        return kept

    return {name: _existing_items(spec.get(name, [])) for name in ('train', 'test')}
embed_and_aggregate function · python · L375-L399 (25 LOC)
eval/evaluate_full_recording.py
def embed_and_aggregate(model, items, device):
    """Embed every file in ``items`` and average the embeddings per patient.

    Args:
        model: Passed through to embed_full_recording.
        items: Iterable of ``(h5_path, label, dataset, pid)`` tuples.
        device: Passed through to embed_full_recording.

    Returns:
        ``(embeddings, labels)`` arrays with one row/entry per patient, or
        ``(None, None)`` when no file produced an embedding. Patients whose
        averaged embedding contains NaN are dropped with a printed warning.
    """
    vectors_by_patient = {}
    label_by_patient = {}
    for h5_path, label, dataset, pid in tqdm(items, leave=False):
        vector = embed_full_recording(model, h5_path, dataset, device)
        if vector is None:
            continue
        # The label of the first file seen for a patient is the one kept.
        label_by_patient.setdefault(pid, label)
        vectors_by_patient.setdefault(pid, []).append(vector)

    if not vectors_by_patient:
        return None, None

    emb_arr = np.array([np.mean(vs, axis=0) for vs in vectors_by_patient.values()])
    lab_arr = np.array(list(label_by_patient.values()))

    # Keep only the patients whose averaged embedding is free of NaN.
    keep = ~np.isnan(emb_arr).any(axis=1)
    if keep.all():
        return emb_arr, lab_arr

    n_bad = (~keep).sum()
    print(f"    WARNING: dropping {n_bad} patients with NaN embeddings")
    return emb_arr[keep], lab_arr[keep]
evaluate_task function · python · L402-L428 (27 LOC)
eval/evaluate_full_recording.py
def evaluate_task(model, task_name, entries, device, is_regression=False):
    """Full pipeline: embed → aggregate → classify/regress."""
    print(f"  [{task_name}]")
    results = {}
    for split in ['train', 'test']:
        if split not in entries or not entries[split]:
            return {}
        emb, lab = embed_and_aggregate(model, entries[split], device)
        if emb is None:
            return {}
        results[split] = (emb, lab)
        n_pos = int((lab == 1).sum()) if not is_regression else len(lab)
        print(f"    {split}: {len(lab)} patients" +
              (f" ({n_pos}+/{len(lab)-n_pos}-)" if not is_regression else f" (range [{lab.min():.1f}, {lab.max():.1f}])"))

    train_emb, train_lab = results['train']
    test_emb, test_lab = results['test']

    if is_regression:
        res = run_regression(train_emb, train_lab, test_emb, test_lab)
        print(f"    → r={res['xgboost']['pearson_r']:.3f}")
    else:
        res = run_classification(train_emb, train_la
Source: Repobility analyzer · https://repobility.com
main function · python · L463-L507 (45 LOC)
eval/evaluate_full_recording.py
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--output', default='eval/full_recording_results.json')
    args = parser.parse_args()

    model, config = load_model(args.checkpoint, args.device)
    all_results = {}

    # 1. Classification_v3 + tdbrain tasks
    print("\n=== Classification Tasks ===")
    for task in CLASSIFICATION_TASKS:
        entries = load_classification_v3_task(task)
        if entries is None:
            print(f"  [{task}] SKIPPED (no splits)")
            continue
        res = evaluate_task(model, task, entries, args.device, is_regression=False)
        all_results[task] = res

    # 2. Regression tasks
    print("\n=== Regression Tasks ===")
    for task in REGRESSION_TASKS:
        entries = load_regression_task(task)
        if entries is None:
            print(f"  [{task}] SKIPPED (no splits)")
            continue
    
extract_patient_id function · python · L55-L79 (25 LOC)
eval/evaluate_v3_linear_probe.py
def extract_patient_id(filepath: str) -> str:
    """Derive the patient-level ID used for aggregation from a file path.

    The file stem is matched against a few dataset-specific prefixes; when
    none applies the whole stem is the ID.

    Examples:
      .../ASZED/patient_108_s1_phase1.h5  → patient_108
      .../CAUEEG/patient_00018.h5         → patient_00018
      .../READTBI/110800044AR1_..._201504151548.h5 → 110800044AR1
      .../AD_EEG/patient_001.h5           → patient_001
      .../DS004504/patient_001.h5         → patient_001
    """
    stem = Path(filepath).stem  # file name without the .h5 suffix

    # Tried in order; each pattern captures the patient part in group 1.
    prefix_patterns = (
        r'(patient_\d+)_s\d+',   # ASZED: patient_NNN_sN_phaseN
        r'(\d{9}AR\d)_',         # READTBI: SITE_SUBJECT_DATETIME
        r'(patient_\d+)_rest',   # TD_BRAIN: patient_NNNNNNN_restEC/restEO
    )
    for pattern in prefix_patterns:
        hit = re.match(pattern, stem)
        if hit is not None:
            return hit.group(1)

    # Default: the full stem (CAUEEG patient_00018, AD_EEG patient_001, ...).
    return stem
extract_dataset_from_path function · python · L82-L95 (14 LOC)
eval/evaluate_v3_linear_probe.py
def extract_dataset_from_path(filepath: str) -> str:
    """Return the dataset name encoded in a file path.

    The dataset is the path component that directly follows a
    ``processed_v3_minimal`` / ``processed_v4`` / ``processed_v5`` directory:
      data/processed_v3_minimal/CAUEEG/patient_00018.h5 → CAUEEG
      data/processed_v3_minimal/READTBI/110800044AR1_...h5 → READTBI

    When no such directory is followed by another component, the name of
    the file's parent directory is returned instead.
    """
    path = Path(filepath)
    components = path.parts
    processed_roots = ('processed_v3_minimal', 'processed_v4', 'processed_v5')

    # Pair each component with its successor; a trailing root directory has
    # no successor and therefore never appears as `current` here.
    for current, following in zip(components, components[1:]):
        if current in processed_roots:
            return following

    return path.parent.name
extract_recording_state function · python · L98-L125 (28 LOC)
eval/evaluate_v3_linear_probe.py
def extract_recording_state(filepath: str, dataset: str) -> str:
    """Infer the recording state from the dataset name and file stem.

    Some datasets always map to one fixed state; PEARL and TD_BRAIN are
    decided by substrings of the file name (without its extension). Every
    other case yields 'unknown'.
    """
    stem = Path(filepath).stem

    # Datasets whose state does not depend on the file name.
    constant_states = {
        'AD_EEG': 'unknown',  # Would need H5 attrs; use default
        'CAUEEG': 'unknown',
        'READTBI': 'clinical assessment',
        'DS004504': 'resting eyes closed',
        'PARKINSON_UCSD': 'resting eyes closed',
    }
    if dataset in constant_states:
        return constant_states[dataset]

    # Datasets decided by a substring of the stem; markers are tested in the
    # listed order and the first one found wins.
    filename_markers = {
        'PEARL': (
            ('rest', 'resting'),
            ('msit', 'MSIT task'),
            ('sternberg', 'Sternberg task'),
        ),
        'TD_BRAIN': (
            ('restEC', 'resting eyes closed'),
            ('restEO', 'resting eyes open'),
        ),
    }
    for marker, state in filename_markers.get(dataset, ()):
        if marker in stem:
            return state

    return 'unknown'
ClassificationDataset class · python · L130-L237 (108 LOC)
eval/evaluate_v3_linear_probe.py
class ClassificationDataset(Dataset):
    """Loads EEG segments from v3 split file paths."""

    def __init__(
        self,
        file_paths: List[str],
        label: int,
        max_length: int = 48000,
        n_channels: int = 19,
    ):
        self.max_length = max_length
        self.n_channels = n_channels
        # (h5_path, seg_idx, patient_id, label, dataset, recording_state)
        self.segments_info = []

        for rel_path in file_paths:
            h5_path = Path(rel_path)
            if not h5_path.exists():
                print(f"Warning: File not found: {h5_path}")
                continue
            patient_id = extract_patient_id(rel_path)
            dataset = extract_dataset_from_path(rel_path)
            recording_state = extract_recording_state(rel_path, dataset)
            try:
                with h5py.File(h5_path, 'r') as f:
                    n_segments = f['eeg/segments'].shape[0]
                    for seg_idx in range(n_segments):
         
__init__ method · python · L133-L161 (29 LOC)
eval/evaluate_v3_linear_probe.py
    def __init__(
        self,
        file_paths: List[str],
        label: int,
        max_length: int = 48000,
        n_channels: int = 19,
    ):
        self.max_length = max_length
        self.n_channels = n_channels
        # (h5_path, seg_idx, patient_id, label, dataset, recording_state)
        self.segments_info = []

        for rel_path in file_paths:
            h5_path = Path(rel_path)
            if not h5_path.exists():
                print(f"Warning: File not found: {h5_path}")
                continue
            patient_id = extract_patient_id(rel_path)
            dataset = extract_dataset_from_path(rel_path)
            recording_state = extract_recording_state(rel_path, dataset)
            try:
                with h5py.File(h5_path, 'r') as f:
                    n_segments = f['eeg/segments'].shape[0]
                    for seg_idx in range(n_segments):
                        self.segments_info.append(
                            (str(h5_path), seg_idx, 
__getitem__ method · python · L166-L237 (72 LOC)
eval/evaluate_v3_linear_probe.py
    def __getitem__(self, idx):
        file_path, segment_idx, patient_id, label, dataset, recording_state = \
            self.segments_info[idx]

        with h5py.File(file_path, 'r') as f:
            eeg_data = f['eeg/segments'][segment_idx][:]  # (C, T)
            # Load channel info for variable-channel support
            if 'eeg/channel_coords' in f:
                file_coords = f['eeg/channel_coords'][:]
            else:
                file_coords = None
            file_ch_names = [
                x.decode() if isinstance(x, bytes) else x
                for x in f['eeg/channel_names'][:]
            ]
            if 'channel_mask' in f:
                file_mask = f['channel_mask'][:]
            else:
                file_mask = None

        C, T = eeg_data.shape

        if C > self.n_channels:
            eeg_data = eeg_data[:self.n_channels, :]
        elif C < self.n_channels:
            eeg_data = np.pad(eeg_data, ((0, self.n_channels - C), (0, 0)))

        i
RegressionDataset class · python · L240-L348 (109 LOC)
eval/evaluate_v3_linear_probe.py
class RegressionDataset(Dataset):
    """Loads EEG segments for regression tasks from regression_v2 split format."""

    def __init__(
        self,
        entries: List[dict],
        target_key: str,
        max_length: int = 48000,
        n_channels: int = 19,
        use_v5: bool = True,
    ):
        self.max_length = max_length
        self.n_channels = n_channels
        # (h5_path, seg_idx, patient_id, target_value, dataset, recording_state)
        self.segments_info = []

        for entry in entries:
            dataset = entry['dataset']
            if dataset in EXCLUDED_DATASETS:
                continue
            v3_path = f"data/processed_v3_minimal/{dataset}/{entry['file']}"
            h5_path = Path(remap_to_v5(v3_path) if use_v5 else v3_path)
            if not h5_path.exists():
                continue
            patient_id = entry.get('patient_id', Path(entry['file']).stem)
            target_val = entry[target_key]
            if target_val is None:
      
Provenance: Repobility (https://repobility.com) — every score reproducible from /scan/
__init__ method · python · L243-L277 (35 LOC)
eval/evaluate_v3_linear_probe.py
    def __init__(
        self,
        entries: List[dict],
        target_key: str,
        max_length: int = 48000,
        n_channels: int = 19,
        use_v5: bool = True,
    ):
        self.max_length = max_length
        self.n_channels = n_channels
        # (h5_path, seg_idx, patient_id, target_value, dataset, recording_state)
        self.segments_info = []

        for entry in entries:
            dataset = entry['dataset']
            if dataset in EXCLUDED_DATASETS:
                continue
            v3_path = f"data/processed_v3_minimal/{dataset}/{entry['file']}"
            h5_path = Path(remap_to_v5(v3_path) if use_v5 else v3_path)
            if not h5_path.exists():
                continue
            patient_id = entry.get('patient_id', Path(entry['file']).stem)
            target_val = entry[target_key]
            if target_val is None:
                continue
            recording_state = extract_recording_state(str(h5_path), dataset)
            try:
     
__getitem__ method · python · L282-L348 (67 LOC)
eval/evaluate_v3_linear_probe.py
    def __getitem__(self, idx):
        file_path, segment_idx, patient_id, target_val, dataset, recording_state = \
            self.segments_info[idx]

        with h5py.File(file_path, 'r') as f:
            eeg_data = f['eeg/segments'][segment_idx][:]
            if 'eeg/channel_coords' in f:
                file_coords = f['eeg/channel_coords'][:]
            else:
                file_coords = None
            file_ch_names = [
                x.decode() if isinstance(x, bytes) else x
                for x in f['eeg/channel_names'][:]
            ]
            if 'channel_mask' in f:
                file_mask = f['channel_mask'][:]
            else:
                file_mask = None

        C, T = eeg_data.shape
        if C > self.n_channels:
            eeg_data = eeg_data[:self.n_channels, :]
        elif C < self.n_channels:
            eeg_data = np.pad(eeg_data, ((0, self.n_channels - C), (0, 0)))
        if T > self.max_length:
            eeg_data = eeg_data[:, :self.max_
load_v3_model function · python · L353-L385 (33 LOC)
eval/evaluate_v3_linear_probe.py
def load_v3_model(checkpoint_path: str, device: str = 'cuda',
                  condition_cache_path: str = 'data/condition_embeddings.pt'):
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    config = ckpt.get('config', {})

    max_channels = config.get('max_channels', config.get('n_channels', 19))
    model = EEGEncoderV3(
        n_channels=config.get('n_channels', 19),
        hidden_dim=config.get('hidden_dim', 1024),
        n_heads=config.get('n_heads', 16),
        n_kv_heads=config.get('n_kv_heads', 4),
        n_layers=config.get('n_layers', 16),
        patch_size=config.get('patch_size', 32),
        dropout=0.0,
        ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
        rope_base=config.get('rope_base', 500_000.0),
        n_registers=config.get('n_registers', 8),
        use_spatial_rope=config.get('use_spatial_rope', True),
        use_modality_tokens=config.get('use_modality_tokens', True),
        use_condition_encoding=True
create_random_v3_model function · python · L388-L408 (21 LOC)
eval/evaluate_v3_linear_probe.py
def create_random_v3_model(config: dict, device: str = 'cuda',
                           condition_cache_path: str = 'data/condition_embeddings.pt'):
    """Instantiate an EEGEncoderV3 from ``config`` without loading checkpoint weights.

    Hyper-parameters that are missing from ``config`` fall back to the
    defaults listed in the body. Dropout is fixed at 0.0 and condition
    encoding is always enabled, regardless of ``config``.

    Args:
        config: Mapping of encoder hyper-parameters.
        device: Device the model is moved to.
        condition_cache_path: Forwarded unchanged to the encoder constructor.

    Returns:
        The encoder, moved to ``device`` and switched to eval mode.
    """
    n_channels = config.get('n_channels', 19)

    # (constructor keyword, fallback value) pairs looked up directly in ``config``.
    tunable_defaults = (
        ('hidden_dim', 1024),
        ('n_heads', 16),
        ('n_kv_heads', 4),
        ('n_layers', 16),
        ('patch_size', 32),
        ('ffn_multiplier', 8.0 / 3.0),
        ('rope_base', 500_000.0),
        ('n_registers', 8),
        ('use_spatial_rope', True),
        ('use_modality_tokens', True),
    )
    encoder_kwargs = {
        key: config.get(key, fallback) for key, fallback in tunable_defaults
    }
    encoder_kwargs.update(
        n_channels=n_channels,
        dropout=0.0,
        use_condition_encoding=True,
        condition_cache_path=condition_cache_path,
        # max_channels falls back to n_channels when not configured.
        max_channels=config.get('max_channels', n_channels),
    )

    encoder = EEGEncoderV3(**encoder_kwargs)
    encoder.to(device).eval()
    return encoder
extract_embeddings function · python · L414-L483 (70 LOC)
eval/evaluate_v3_linear_probe.py
def extract_embeddings(
    model: EEGEncoderV3,
    dataloader: DataLoader,
    device: str,
    pooling: str = 'mean',
    vqvae: Optional['VQVAE'] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Extract embeddings. Returns (embeddings, labels, patient_ids).

    Matches training pipeline: applies channel_mask, passes reference_type_ids
    and condition_keys to model.encode().  No per-channel normalization (training
    does not apply any).

    Args:
        pooling: 'mean' for mean over all patch positions,
                 'last' for the last patch position (best for causal models).
        vqvae: Optional VQ-VAE model. If provided, hidden states are quantized
               and reconstructed before pooling (measures VQ-VAE information loss).
    """
    all_emb, all_lab, all_pid = [], [], []

    for batch in tqdm(dataloader, desc="  Embedding", leave=False):
        eeg = batch['eeg'].to(device)

        # Channel mask — zero out absent channels (same as training 
eval_collate_fn function · python · L486-L500 (15 LOC)
eval/evaluate_v3_linear_probe.py
def eval_collate_fn(batch: List[Dict]) -> Dict:
    """Merge a list of per-sample dicts into one batch dict for evaluation.

    Tensor fields (``eeg``, ``channel_mask`` and, when present,
    ``channel_coords``) are stacked along a new leading dimension. ``label``
    and ``reference_type_id`` become tensors (the latter as ``torch.long``
    under the key ``reference_type_ids``). ``patient_id``, ``dataset`` and
    ``recording_state`` are kept as plain Python lists.

    ``channel_coords`` is only included when the first sample carries it.
    """
    def column(key):
        # Gather one field across every sample, preserving batch order.
        return [sample[key] for sample in batch]

    collated = {}
    collated['eeg'] = torch.stack(column('eeg'))
    collated['label'] = torch.tensor(column('label'))
    collated['patient_id'] = column('patient_id')
    collated['channel_mask'] = torch.stack(column('channel_mask'))
    collated['reference_type_ids'] = torch.tensor(
        column('reference_type_id'), dtype=torch.long)
    collated['dataset'] = column('dataset')
    collated['recording_state'] = column('recording_state')

    if 'channel_coords' not in batch[0]:
        return collated

    collated['channel_coords'] = torch.stack(column('channel_coords'))
    return collated
aggregate_to_patient function · python · L505-L559 (55 LOC)
eval/evaluate_v3_linear_probe.py
def aggregate_to_patient(
    embeddings: np.ndarray,
    labels: np.ndarray,
    patient_ids: List[str],
    segment_selection: str = 'all',
    fixed_segments: int = 0,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Average embeddings per patient.

    Args:
        segment_selection: 'all' uses all segments, 'first_half' uses first half
            of segments per patient, 'second_half' uses second half. Segments are
            ordered chronologically (as stored in the H5 file).
        fixed_segments: If > 0, subsample each patient to exactly this many
            segments before averaging. If 0 (default), use all segments.
            Use -1 to auto-detect: subsample to min segment count across patients.
    """
    patient_data = defaultdict(lambda: {'embs': [], 'label': None})

    for emb, lab, pid in zip(embeddings, labels, patient_ids):
        patient_data[pid]['embs'].append(emb)
        patient_data[pid]['label'] = lab

    # Determine fixed segment count if auto-d
_compute_metrics function · python · L564-L574 (11 LOC)
eval/evaluate_v3_linear_probe.py
def _compute_metrics(test_lab, pred, pred_proba):
    """Return binary-classification metrics, each rounded to 4 decimals.

    Args:
        test_lab: Ground-truth labels.
        pred: Predicted labels.
        pred_proba: Scores passed to ``roc_auc_score``.

    Returns:
        Dict with keys ``auc``, ``balanced_accuracy``, ``f1`` and
        ``accuracy``. ``auc`` is NaN when ``roc_auc_score`` raises
        ``ValueError``.
    """
    try:
        auc_value = roc_auc_score(test_lab, pred_proba)
    except ValueError:
        # AUC could not be computed for these inputs; report NaN instead.
        auc_value = float('nan')

    scores = (
        ('auc', auc_value),
        ('balanced_accuracy', balanced_accuracy_score(test_lab, pred)),
        ('f1', f1_score(test_lab, pred, average='binary')),
        ('accuracy', accuracy_score(test_lab, pred)),
    )
    return {name: round(value, 4) for name, value in scores}
Repobility's GitHub App fixes findings like these · https://github.com/apps/repobility-bot
classify_xgboost function · python · L577-L602 (26 LOC)
eval/evaluate_v3_linear_probe.py
def classify_xgboost(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    """Fit an XGBoost classifier on standardized embeddings and score the test set.

    Features are standardized with a scaler fitted on the training embeddings
    only. The positive class is weighted by the negative-to-positive count
    ratio of the training labels (the positive count is floored at 1 to
    avoid division by zero).

    Returns:
        The metric dict produced by ``_compute_metrics`` on the test split.
    """
    standardizer = StandardScaler()
    x_train = standardizer.fit_transform(train_emb)
    x_test = standardizer.transform(test_emb)

    positives = int(train_lab.sum())
    negatives = len(train_lab) - positives

    booster_params = {
        'n_estimators': 200,
        'max_depth': 6,
        'learning_rate': 0.1,
        'scale_pos_weight': negatives / max(positives, 1),
        'eval_metric': 'logloss',
        'use_label_encoder': False,
        'verbosity': 0,
    }
    classifier = xgb.XGBClassifier(**booster_params)
    classifier.fit(x_train, train_lab)

    return _compute_metrics(
        test_lab,
        classifier.predict(x_test),
        # Column 1 holds the probability of the positive class.
        classifier.predict_proba(x_test)[:, 1],
    )
classify_tabpfn function · python · L605-L630 (26 LOC)
eval/evaluate_v3_linear_probe.py
def classify_tabpfn(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    scaler = StandardScaler()
    train_s = scaler.fit_transform(train_emb)
    test_s = scaler.transform(test_emb)

    # StandardScaler produces NaN for zero-variance features (std=0 → div by 0).
    # This happens with heavily masked data (e.g. TBI with 4/19 channels).
    # Replace NaN with 0 (zero-variance features carry no information).
    train_s = np.nan_to_num(train_s, nan=0.0)
    test_s = np.nan_to_num(test_s, nan=0.0)

    # TabPFN works best with fewer features — PCA to 128 dims
    n_components = min(128, train_s.shape[0] - 1, train_s.shape[1])
    pca = PCA(n_components=n_components)
    train_s = pca.fit_transform(train_s)
    test_s = pca.transform(test_s)

    clf = TabPFNClassifier()
    clf.fit(train_s, train_lab)

    pred = clf.predict(test_s)
    pred_proba = clf.predict_proba(test_s)[:, 1]
    return _compute_metrics(test
_compute_regression_metrics function · python · L635-L642 (8 LOC)
eval/evaluate_v3_linear_probe.py
def _compute_regression_metrics(y_true, y_pred):
    """Return regression metrics for predictions against targets.

    Returns:
        Dict with ``pearson_r``, ``r2`` and ``mae`` (each cast to float and
        rounded to 4 decimals) plus ``n``, the number of targets.
    """
    # Only the correlation coefficient is reported; the p-value is discarded.
    correlation, _p_value = pearsonr(y_true, y_pred)

    raw_scores = {
        'pearson_r': correlation,
        'r2': r2_score(y_true, y_pred),
        'mae': mean_absolute_error(y_true, y_pred),
    }
    summary = {name: round(float(value), 4) for name, value in raw_scores.items()}
    summary['n'] = len(y_true)
    return summary
regress_xgboost function · python · L645-L661 (17 LOC)
eval/evaluate_v3_linear_probe.py
def regress_xgboost(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    """Fit an XGBoost regressor on standardized embeddings and score the test set.

    The scaler is fitted on the training embeddings only and then applied to
    the test embeddings.

    Returns:
        The metric dict produced by ``_compute_regression_metrics``.
    """
    standardizer = StandardScaler()
    x_train = standardizer.fit_transform(train_emb)
    x_test = standardizer.transform(test_emb)

    regressor_params = {
        'n_estimators': 200,
        'max_depth': 6,
        'learning_rate': 0.1,
        'verbosity': 0,
    }
    regressor = xgb.XGBRegressor(**regressor_params)
    regressor.fit(x_train, train_lab)

    return _compute_regression_metrics(test_lab, regressor.predict(x_test))
regress_ridge function · python · L664-L675 (12 LOC)
eval/evaluate_v3_linear_probe.py
def regress_ridge(
    train_emb: np.ndarray, train_lab: np.ndarray,
    test_emb: np.ndarray, test_lab: np.ndarray,
) -> Dict[str, float]:
    """Fit a ridge regressor (alpha=1.0) on standardized embeddings and score the test set.

    The scaler is fitted on the training embeddings only and then applied to
    the test embeddings.

    Returns:
        The metric dict produced by ``_compute_regression_metrics``.
    """
    standardizer = StandardScaler()
    x_train = standardizer.fit_transform(train_emb)
    x_test = standardizer.transform(test_emb)

    ridge = Ridge(alpha=1.0)
    ridge.fit(x_train, train_lab)

    return _compute_regression_metrics(test_lab, ridge.predict(x_test))
remap_to_v5 function · python · L683-L739 (57 LOC)
eval/evaluate_v3_linear_probe.py
def remap_to_v5(v3_path: str) -> str:
    """Remap a v3_minimal (or v4) file path to the equivalent v5 path.

    The model was trained on v5 data (robust z-score normalization), so
    evaluation should also use v5 data to eliminate preprocessing mismatch.

    Filename mapping per dataset:
        CAUEEG:         patient_00018.h5     → 00018.h5
        AD_EEG:         patient_001.h5       → sub-001_task-eyesclosed_eeg.h5
        DS004504:       patient_001.h5       → sub-001_task-eyesclosed_eeg.h5
        PEARL:          patient_01_msit.h5   → sub-01_task-msit_eeg.h5
        READTBI:        <same>.h5            → <same>.h5
        TD_BRAIN:       patient_ID_task.h5   → sub-ID_ses-1_task-TASK_eeg.h5
        PARKINSON_UCSD: patient_hc10.h5      → sub-hc10_ses-hc_task-rest_eeg.h5
                        patient_pd11.h5      → sub-pd11_ses-on_task-rest_eeg.h5
    """
    p = Path(v3_path)
    dataset = extract_dataset_from_path(v3_path)
    stem = p.stem  # e.g. 'patient_00018'

    if d
fix_tdbrain_path function · python · L742-L744 (3 LOC)
eval/evaluate_v3_linear_probe.py
def fix_tdbrain_path(path: str) -> str:
    """Point tdbrain_downstream paths at processed_v3_minimal instead of processed_v4.

    Every occurrence of the substring ``processed_v4/`` is rewritten; paths
    without it are returned unchanged.
    """
    stale_prefix = 'processed_v4/'
    current_prefix = 'processed_v3_minimal/'
    # Splitting on the stale prefix and re-joining rewrites all occurrences.
    return current_prefix.join(path.split(stale_prefix))
evaluate_classification_task function · python · L789-L850 (62 LOC)
eval/evaluate_v3_linear_probe.py
def evaluate_classification_task(
    model: EEGEncoderV3,
    task: str,
    splits_dir: Path,
    device: str,
    max_length: int,
    batch_size: int,
    fix_paths: bool = False,
    pooling: str = 'mean',
    segment_selection: str = 'all',
    fixed_segments: int = 0,
    use_v5: bool = True,
    vqvae: Optional['VQVAE'] = None,
) -> Dict[str, Dict[str, float]]:
    splits_file = splits_dir / task / 'splits.json'
    with open(splits_file) as f:
        splits = json.load(f)

    results = {}
    for split in ['train', 'test']:
        pos_files = splits[split]['positive']
        neg_files = splits[split]['negative']

        if fix_paths:
            pos_files = [fix_tdbrain_path(f) for f in pos_files]
            neg_files = [fix_tdbrain_path(f) for f in neg_files]

        if use_v5:
            pos_files = [remap_to_v5(f) for f in pos_files]
            neg_files = [remap_to_v5(f) for f in neg_files]

        n_ch = getattr(model, 'max_channels', 19)
        pos_data = Clas
Repobility — the code-quality scanner for AI-generated software · https://repobility.com
evaluate_regression_task function · python · L853-L909 (57 LOC)
eval/evaluate_v3_linear_probe.py
def evaluate_regression_task(
    model: EEGEncoderV3,
    task: str,
    device: str,
    max_length: int,
    batch_size: int,
    pooling: str = 'mean',
    segment_selection: str = 'all',
    fixed_segments: int = 0,
    use_v5: bool = True,
    vqvae: Optional['VQVAE'] = None,
) -> Dict[str, Dict[str, float]]:
    splits_file = REGRESSION_SPLITS_DIR / task / 'splits.json'
    with open(splits_file) as f:
        splits = json.load(f)

    target_key = REGRESSION_TARGET_KEYS.get(task, task)

    results = {}
    for split in ['train', 'test']:
        entries = splits[split]
        n_ch = getattr(model, 'max_channels', 19)
        dataset = RegressionDataset(entries, target_key, max_length=max_length,
                                    n_channels=n_ch, use_v5=use_v5)

        if len(dataset) == 0:
            print(f"    {split.capitalize():>5}: 0 segments — skipping")
            return {}

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
              
_run_eval function · python · L912-L924 (13 LOC)
eval/evaluate_v3_linear_probe.py
def _run_eval(model, task_list, eval_fn, device, max_length, batch_size, **kwargs):
    """Run evaluation across a list of tasks with a given eval function."""
    task_results = {}
    for task in task_list:
        print(f"\n  [{task}]")
        try:
            metrics = eval_fn(model, task, device=device,
                              max_length=max_length, batch_size=batch_size, **kwargs)
            if metrics:
                task_results[task] = metrics
        except Exception as e:
            print(f"    ERROR: {e}")
    return task_results
_run_single_eval function · python · L927-L963 (37 LOC)
eval/evaluate_v3_linear_probe.py
def _run_single_eval(model, clf_tasks, tdb_tasks, reg_tasks, device, args,
                     pooling='mean', segment_selection='all', fixed_segments=0,
                     vqvae=None):
    """Run a single eval pass with given pooling and segment selection."""
    use_v5 = getattr(args, 'use_v5', True)
    results = {}

    if clf_tasks:
        clf_results = _run_eval(
            model, clf_tasks,
            lambda m, t, **kw: evaluate_classification_task(
                m, t, CLASSIFICATION_SPLITS_DIR,
                pooling=pooling, segment_selection=segment_selection,
                fixed_segments=fixed_segments, use_v5=use_v5, vqvae=vqvae, **kw),
            device, args.max_length, args.batch_size)
        results.update(clf_results)

    if tdb_tasks:
        tdb_results = _run_eval(
            model, tdb_tasks,
            lambda m, t, **kw: evaluate_classification_task(
                m, t, TDBRAIN_SPLITS_DIR, fix_paths=True,
                pooling=pooling, segment_
page 1 / 4next ›