← back to krm9c__ContLearn

Function bodies 268 total

All specs Real LLM only Function bodies
HamiltonianMixin.return_Hamiltonian_graph method · python · L652-L717 (66 LOC)
src/cl/core/hamiltonian.py
    def return_Hamiltonian_graph(self, params, data, notABTrain, grad_weights=None,
                                 normalize_dV=True, dV_scale=1.0):
        """Compute Hamiltonian gradient for graph classification.

        Uses JIT-compiled core functions for high GPU utilization.

        Args:
            params: Trainable model parameters
            data: Tuple of (static, (batch, batch_ex, deltax, delta_adj))
                  OR (static, (x, y, adj, b, n, exp_x, exp_y, exp_adj, exp_b, exp_n, deltax, delta_adj))
            notABTrain: True for standard training, False for AWB A/B training
            grad_weights: Optional [alpha, beta, gamma] weights for gradient combination
            normalize_dV: Whether to normalize dV by parameter count (default True)
            dV_scale: Additional scaling factor for dV (default 1.0)

        Returns:
            Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx, dV_dadj))
        """
        if grad_weights is None:
            grad_weight
HamiltonianMixin.return_Hamiltonian_mse method · python · L719-L757 (39 LOC)
src/cl/core/hamiltonian.py
    def return_Hamiltonian_mse(self, params, data, notABTrain=True, grad_weights=None,
                               normalize_dV=True, dV_scale=1.0):
        """Compute Hamiltonian gradient for MSE regression.

        Uses JIT-compiled core functions for high GPU utilization.

        Args:
            params: Trainable model parameters
            data: Tuple of (statics, (x, y, exp_x, exp_y, deltax, flag))
            notABTrain: True for standard training, False for AWB A/B training
            grad_weights: Optional [alpha, beta, gamma] weights for gradient combination
            normalize_dV: Whether to normalize dV by parameter count (default True)
            dV_scale: Additional scaling factor for dV (default 1.0)

        Returns:
            Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx))
        """
        if grad_weights is None:
            grad_weights = DEFAULT_GRAD_WEIGHTS
        alpha, beta, gamma = grad_weights

        statics, (x, y, exp_x, exp_y, deltax, flag) 
HamiltonianMixin.return_Hamiltonian_class method · python · L759-L801 (43 LOC)
src/cl/core/hamiltonian.py
    def return_Hamiltonian_class(self, params, data, notABTrain=True, grad_weights=None,
                                 normalize_dV=True, dV_scale=1.0):
        """Compute Hamiltonian gradient for classification.

        Uses JIT-compiled core functions for high GPU utilization.

        Args:
            params: Trainable model parameters
            data: Tuple of (statics, (x, y, exp_x, exp_y, deltax, flag))
            notABTrain: True for standard training, False for AWB A/B training
            grad_weights: Optional [alpha, beta, gamma] weights for gradient combination
            normalize_dV: Whether to normalize dV by parameter count (default True)
            dV_scale: Additional scaling factor for dV (default 1.0)

        Returns:
            Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx))
        """
        if grad_weights is None:
            grad_weights = DEFAULT_GRAD_WEIGHTS
        alpha, beta, gamma = grad_weights

        statics, (x, y, exp_x, exp_y, deltax, fl
get_graph_transforms function · python · L38-L54 (17 LOC)
src/cl/core/loops.py
def get_graph_transforms():
    """Lazily build and cache the graph transform pipeline.

    ``torch_geometric`` is imported only on the first call, so runs that never
    touch graph data do not need that import. The cached pipeline is::

        T.Compose([T.GCNNorm(), T.ToDense()])

    ``T.NormalizeFeatures()`` from the older pipeline is intentionally absent
    (see the comment next to the construction below).

    Returns:
        The module-level cached ``torch_geometric.transforms.Compose`` object.
    """
    global _GRAPH_TRANSFORMS
    if _GRAPH_TRANSFORMS is not None:
        # Already built on a previous call; reuse the cached pipeline.
        return _GRAPH_TRANSFORMS

    import torch_geometric.transforms as T

    # Fixed by Claude: Removed T.NormalizeFeatures() which was destroying the
    # class-feature correlation in FakeDataset (correlation dropped from 0.998 to -0.135)
    _GRAPH_TRANSFORMS = T.Compose([
        T.GCNNorm(),
        T.ToDense(),
    ])
    return _GRAPH_TRANSFORMS
TrainingLoopsMixin._clip_gradients method · python · L60-L97 (38 LOC)
src/cl/core/loops.py
    def _clip_gradients(self, grads, max_norm=None):
        """Clip gradients by global norm.

        Added by Claude: Following Pascanu et al. (2013) recommendations for
        gradient clipping in recurrent networks, applicable to continual learning
        with potentially unstable gradients.

        Reference:
        - Pascanu et al., "On the difficulty of training recurrent neural networks",
          ICML 2013

        Args:
            grads: PyTree of gradients
            max_norm: Maximum gradient norm (None = no clipping)

        Returns:
            Tuple of (clipped_grads, global_norm, was_clipped)
        """
        if max_norm is None or max_norm <= 0:
            # No clipping
            grad_leaves = jax.tree_util.tree_leaves(grads)
            global_norm = jnp.sqrt(sum([jnp.sum(g**2) for g in grad_leaves]))
            return grads, global_norm, False

        # Compute global norm
        grad_leaves = jax.tree_util.tree_leaves(grads)
        global_norm = j
TrainingLoopsMixin._compute_metrics_on_sampled_batches method · python · L99-L218 (120 LOC)
src/cl/core/loops.py
    def _compute_metrics_on_sampled_batches(self, params, static, loader,
                                            num_batches=10, problem_type='vectors',
                                            notABTrain=True, transforms=None):
        """Efficiently compute metrics on N sampled batches from a loader.

        Args:
            params: Model parameters
            static: Static model components
            loader: Data loader (tuple of current_loader, exp_loader)
            num_batches: Number of batches to sample (default 10)
            problem_type: 'vectors' or 'graph'
            notABTrain: Whether using normal training (True) or AWB training (False)
            transforms: Transform pipeline for graph data

        Returns:
            Tuple of (current_task_metric, experience_metric)
        """
        current_metrics = []
        exp_metrics = []

        if problem_type == 'graph':
            if isinstance(loader, tuple):
                current_loader, exp_loade
TrainingLoopsMixin._compute_perturbation_variance method · python · L220-L287 (68 LOC)
src/cl/core/loops.py
    def _compute_perturbation_variance(self, trainloader, exploader, problem_type, max_batches=5):
        """Pre-compute variance for perturbation sampling.

        Uses the mean difference approach for feature variance.
        For graphs, also computes adjacency variance.

        Args:
            trainloader: Current task data loader
            exploader: Experience replay data loader
            problem_type: 'vectors' or 'graph'
            max_batches: Maximum batches to sample for variance estimation (default 5)
                        5 batches × 2048 batch_size = 10,240 samples (statistically sufficient)

        Returns:
            Tuple of (var_x, var_adj) where var_adj is 0 for vectors
        """
        var_x_list, var_adj_list = [], []
        transforms = get_graph_transforms() if problem_type == 'graph' else None

        # Fixed by Claude: Handle empty experience loader (Task 0 for graph datasets)
        # Try to check if exploader has any data
        exp_iter_
Repobility · code-quality intelligence · https://repobility.com
LossMixin.loss_fn_mse method · python · L31-L38 (8 LOC)
src/cl/core/losses.py
    def loss_fn_mse(self, params, statics, x, y):
        """MSE loss for regression (vectors).

        Args:
            params: Trainable model parameters, merged with ``statics`` via
                ``eqx.combine``.
            statics: Static (non-trainable) model components.
            x: Input batch; the model is ``jax.vmap``-ed over its leading axis.
            y: Regression targets for the batch.

        Returns:
            Scalar ``mean((y - preds) ** 2)``.
        """
        model = eqx.combine(params, statics)
        preds = jax.vmap(model)(x)
        # Added by Claude: squeeze extra dimension if present (batch, 1, output) -> (batch, output)
        if preds.ndim == 3 and preds.shape[1] == 1:
            preds = jnp.squeeze(preds, axis=1)
        y = jnp.asarray(y)
        # If targets and predictions hold the same number of elements but differ
        # in shape (e.g. y is (batch,) while preds is (batch, 1)), `y - preds`
        # would silently broadcast to (batch, batch) and compare every target
        # with every prediction. Align y to preds so pairs match element-wise.
        if y.shape != preds.shape and y.size == preds.size:
            y = jnp.reshape(y, preds.shape)
        return jnp.mean((y - preds)**2)
LossMixin.accuracy_vectors method · python · L41-L67 (27 LOC)
src/cl/core/losses.py
    def accuracy_vectors(self, params, statics, x, y):
        """Accuracy metric for classification (vectors).

        Args:
            params: Model parameters
            statics: Static model components
            x: Input features, shape (batch, ...)
            y: Class labels - either scalar indices shape (batch,) or one-hot shape (batch, num_classes)

        Returns:
            Accuracy as float between 0 and 1
        """
        model = eqx.combine(params, statics)
        preds = jax.vmap(model)(x)
        # Squeeze extra dimension if present (batch, 1, classes) -> (batch, classes)
        if preds.ndim == 3 and preds.shape[1] == 1:
            preds = jnp.squeeze(preds, axis=1)
        pred = jnp.argmax(jax.nn.softmax(preds), axis=1)
        # Handle both one-hot encoded (batch, num_classes) and scalar labels (batch,)
        if y.ndim == 2 and y.shape[1] > 1:
            # One-hot encoded: convert to class indices
            y = jnp.argmax(y, axis=1)
        elif y.n
LossMixin.mse_vectors method · python · L70-L77 (8 LOC)
src/cl/core/losses.py
    def mse_vectors(self, params, statics, x, y):
        """MSE-style metric for regression (vectors).

        Note: ``optax.l2_loss`` computes ``0.5 * (a - b) ** 2`` element-wise, so
        the value returned here is half of the plain mean squared error.

        Args:
            params: Trainable model parameters, merged with ``statics`` via
                ``eqx.combine``.
            statics: Static (non-trainable) model components.
            x: Input batch; the model is ``jax.vmap``-ed over its leading axis.
            y: Regression targets for the batch.

        Returns:
            Scalar mean of ``optax.l2_loss(y, preds)``.
        """
        model = eqx.combine(params, statics)
        preds = jax.vmap(model)(x)
        # Added by Claude: squeeze extra dimension if present (batch, 1, output) -> (batch, output)
        if preds.ndim == 3 and preds.shape[1] == 1:
            preds = jnp.squeeze(preds, axis=1)
        y = jnp.asarray(y)
        # Same-size but differently-shaped targets (e.g. (batch,) vs (batch, 1))
        # would otherwise broadcast to (batch, batch); align them to preds.
        if y.shape != preds.shape and y.size == preds.size:
            y = jnp.reshape(y, preds.shape)
        return jnp.mean(optax.l2_loss(y, preds))
LossMixin.loss_fn_class_graph method · python · L82-L89 (8 LOC)
src/cl/core/losses.py
    def loss_fn_class_graph(self, params, statics, x, y, adj=None):
        """Cross-entropy loss for graph classification.

        Args:
            params: Trainable model parameters, merged with ``statics`` via
                ``eqx.combine``.
            statics: Static model components.
            x: Per-graph node features; ``x[i]`` is fed to the model together
                with ``adj[i]``.
            y: Integer class index per graph (a trailing singleton axis is
                tolerated).
            adj: Per-graph adjacency, indexed in lockstep with ``x``. Despite the
                ``None`` default it is required: ``adj[i]`` is indexed below.

        Returns:
            Scalar mean negative log-likelihood of the target classes.
        """
        model = eqx.combine(params, statics)
        # One forward pass per graph. take_along_axis below needs a 2-D
        # (batch, num_classes) array, so each call must yield a 1-D score vector.
        logits = jnp.stack([model(x[i], adj[i]).T for i in range(len(x))])
        # log_softmax turns scores into log-probabilities. It is idempotent, so
        # this is also correct if the model already emits log-probabilities.
        log_probs = jax.nn.log_softmax(logits, axis=1)
        labels = jnp.reshape(y, (-1,)).astype(jnp.int32)
        target_log_probs = jnp.take_along_axis(log_probs, labels[:, None], axis=1)
        # Labels are class indices, not one-hot vectors: do NOT multiply by y
        # (that would drop every class-0 sample and scale the rest by label).
        return -jnp.mean(target_log_probs)
LossMixin.accuracy_graphs method · python · L98-L103 (6 LOC)
src/cl/core/losses.py
    def accuracy_graphs(self, params, statics, x, adj, b, n):
        """Per-graph class log-probabilities (standard forward pass).

        Despite the name, no accuracy is computed here. The model is applied to
        each graph ``(x[i], adj[i], b[i], n[i])``, the outputs are concatenated
        along the leading axis, and ``log_softmax`` over axis 1 is returned.
        """
        net = eqx.combine(params, statics)
        outputs = []
        for idx in range(len(x)):
            outputs.append(net(x[idx], adj[idx], b[idx], n[idx]))
        stacked = jnp.concatenate(outputs)
        return jax.nn.log_softmax(stacked, axis=1)
LossMixin.accuracy_graphs_AWBT method · python · L106-L111 (6 LOC)
src/cl/core/losses.py
    def accuracy_graphs_AWBT(self, params, statics, x, adj, b, n):
        """Per-graph class log-probabilities using the AWB forward pass.

        Mirrors ``accuracy_graphs`` but calls ``model.get_AWBT`` for each graph
        ``(x[i], adj[i], b[i], n[i])``. Despite the name, the return value is
        ``log_softmax`` (axis 1) of the concatenated outputs, not an accuracy.
        """
        net = eqx.combine(params, statics)
        outputs = []
        for idx in range(len(x)):
            outputs.append(net.get_AWBT(x[idx], adj[idx], b[idx], n[idx]))
        stacked = jnp.concatenate(outputs)
        return jax.nn.log_softmax(stacked, axis=1)
LossMixin.return_loss_grad method · python · L135-L162 (28 LOC)
src/cl/core/losses.py
    def return_loss_grad(self, params, batch, static):
        """Compute loss and gradient for a batch.

        Args:
            params: Trainable model parameters
            batch: Input batch (varies by problem type)
            static: Static model components

        Returns:
            Tuple of (loss, gradients)
        """
        if self.problem == 'vectors':
            (x, y) = batch
            if self.loss == 'class':
                grads = jax.grad(self.loss_fn_class)(params, static, x, y)
                loss = self.loss_fn_class(params, static, x, y)
            elif self.loss == 'mse':
                grads = jax.grad(self.loss_fn_mse)(params, static, x, y)
                loss = self.loss_fn_mse(params, static, x, y)
        elif self.problem == 'graph':
            (x, y, adj) = batch
            if self.loss == 'class':
                grads = jax.grad(self.loss_fn_class_graph)(params, static, x, y, adj=adj)
                loss = self.loss_fn_class_graph(params
LossMixin.return_metric method · python · L164-L229 (66 LOC)
src/cl/core/losses.py
    def return_metric(self, params, statics, data, notABTrain=True):
        """Compute evaluation metric based on problem type.

        Args:
            params: Trainable model parameters
            statics: Static model components
            data: Input data (format depends on problem type)
                  For vectors: (x, y) where y is scalar class indices shape (batch,)
            notABTrain: True for standard forward, False for AWB forward

        Returns:
            Metric value (accuracy for classification, MSE for regression)
        """
        model = eqx.combine(params, statics)

        if self.problem == 'vectors':
            x, y = data
            if self.metric == 'class':
                # Ensure y is 1D int64 array of class indices
                y = y.astype(jnp.int64)
                if y.ndim == 2:
                    y = jnp.squeeze(y, axis=-1)
                if notABTrain:
                    preds = jax.vmap(model)(x)
                else:
          
Repobility (the analyzer behind this table) · https://repobility.com
get_gpu_memory_usage function · python · L16-L56 (41 LOC)
src/cl/core/profiling.py
def get_gpu_memory_usage():
    """
    Get current GPU memory usage from JAX.

    Returns:
        Tuple of (used_gb, total_gb, percent) or None if unavailable
    """
    try:
        import jax
        from jax.lib import xla_bridge

        # Get GPU backend
        backend = xla_bridge.get_backend()

        # Get memory stats from first GPU device
        devices = backend.devices()
        if not devices:
            return None

        device = devices[0]

        # Get memory info
        mem_stats = device.memory_stats()
        if mem_stats is None:
            return None

        bytes_in_use = mem_stats.get('bytes_in_use', 0)
        # Try to get peak memory
        peak_bytes = mem_stats.get('peak_bytes_in_use', bytes_in_use)

        # Convert to GB
        used_gb = bytes_in_use / (1024**3)
        peak_gb = peak_bytes / (1024**3)

        # Try to estimate total memory (H200 has 144GB typically)
        # We'll use peak as a proxy since total isn't always available
format_memory_stats function · python · L58-L70 (13 LOC)
src/cl/core/profiling.py
def format_memory_stats():
    """
    Format GPU memory usage as a one-line string.

    Returns:
        ``"GPU: <used>GB used, <peak>GB peak"`` with two decimals, e.g.
        ``"GPU: 2.30GB used, 3.10GB peak"``, or ``"GPU: N/A"`` when
        ``get_gpu_memory_usage()`` returns ``None``.
    """
    mem = get_gpu_memory_usage()
    if mem is None:
        return "GPU: N/A"

    # Only the first two fields are used. get_gpu_memory_usage's docstring
    # advertises a (used, total, percent) triple, so slicing keeps this helper
    # working whether two or three values come back instead of raising
    # "too many values to unpack" inside a profiling print.
    used_gb, peak_gb = mem[:2]
    return f"GPU: {used_gb:.2f}GB used, {peak_gb:.2f}GB peak"
enable_profiling function · python · L72-L80 (9 LOC)
src/cl/core/profiling.py
def enable_profiling(enabled: bool):
    """
    Turn the module-wide profiling switch on or off.

    Stores ``enabled`` in the module-level ``_PROFILING_ENABLED`` flag, which
    the ``profile`` decorator and ``profile_section`` consult.

    Args:
        enabled: True to print profiling output, False to silence it
    """
    global _PROFILING_ENABLED
    _PROFILING_ENABLED = enabled
profile function · python · L82-L110 (29 LOC)
src/cl/core/profiling.py
def profile(phase_name: str):
    """
    Decorator to time a function if profiling is enabled.

    When profiling is disabled, this has zero overhead (simple boolean check).
    When enabled, prints timing and GPU memory information for the decorated function.

    Args:
        phase_name: Human-readable name for this profiling phase

    Example:
        @profile("Dataset Loading")
        def load_data():
            # ... code ...
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not _PROFILING_ENABLED:
                return func(*args, **kwargs)

            print(f"\n[PROFILE] {phase_name} starting... | {format_memory_stats()}")
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed = time.time() - start_time
            print(f"[PROFILE] {phase_name} complete: {elapsed:.2f}s | {format_memory_stats()}")
            return result
        return wrapper
    return deco
profile_section function · python · L112-L141 (30 LOC)
src/cl/core/profiling.py
def profile_section(phase_name: str, enabled: bool = None):
    """
    Context manager for profiling a code section.

    Args:
        phase_name: Human-readable name for this profiling phase
        enabled: Optional override for profiling enabled status; when None the
            module-level ``_PROFILING_ENABLED`` flag is read on entry

    Returns:
        A context-manager instance. When enabled, entering prints a start line
        with GPU memory stats and exiting prints the elapsed wall-clock time.
        Exceptions raised inside the ``with`` body are not suppressed.

    Example:
        with profile_section("JAX Pre-conversion"):
            # ... code to profile ...
    """
    class ProfileContext:
        def __enter__(self):
            self.enabled = _PROFILING_ENABLED if enabled is None else enabled

            if self.enabled:
                print(f"\n[PROFILE] {phase_name} starting... | {format_memory_stats()}")
                self.start_time = time.time()
            return self

        def __exit__(self, *args):
            if self.enabled:
                elapsed = time.time() - self.start_time
                print(f"[PROFILE] {phase_name} complete: {elapsed:.2f}s | {format_memory_stats()}")
            # Implicit None return: exceptions from the body propagate.

    # Return an instance. The `with` statement looks up __enter__/__exit__ on
    # the object's type, so returning the class itself raised a TypeError.
    return ProfileContext()
RecordingMixin._compute_eigenvalues method · python · L26-L148 (123 LOC)
src/cl/core/recording.py
    def _compute_eigenvalues(self, model, combined=False):
        """Compute eigenvalues of A/B matrices (AWB mode) or weight matrices (standard mode).

        For AWB-enabled models: computes eigenvalues of A and B matrices.
        For standard models: computes eigenvalues of weight matrices (W).

        Args:
            model: The model (combined params + static), or None to skip computation
            combined: If True, return model; if False, extract from self

        Returns:
            dict: Eigenvalues organized as:
                AWB mode:
                {
                    'A': {'layer_0': array, 'layer_1': array, ...},
                    'B': {'layer_0': array, 'layer_1': array, ...}
                }
                Standard mode (weights stored in 'A' key for plot compatibility):
                {
                    'A': {'layer_0': array, 'layer_1': array, ...},
                    'B': {}
                }
                If model is None, returns empty dict
RecordingMixin.record_metrics method · python · L150-L188 (39 LOC)
src/cl/core/recording.py
    def record_metrics(self,
                      iteration: int,
                      step: int,
                      task_id: int,
                      losses: Dict[str, float],
                      gradients: Dict[str, float],
                      metrics: Dict[str, float],
                      model,
                      extra_metrics: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Record all metrics for a single iteration in unified format.

        Args:
            iteration: Iteration number (key for the record)
            step: Current step within task
            task_id: Current task ID
            losses: Dict of loss values (H, V, dV, dV_dx, dV_dtheta, dV_dadj, etc.)
            gradients: Dict of gradient norms
            metrics: Dict with train, test_current, test_experience metrics
            model: The model (for eigenvalue extraction)
            extra_metrics: Optional dict for dataset-specific metrics

        Returns:
            dict: C
RecordingMixin.initialize_record_dict method · python · L190-L229 (40 LOC)
src/cl/core/recording.py
    def initialize_record_dict(self, config: Dict[str, Any], run_id: int = 0) -> Dict[str, Any]:
        """Initialize the recording dictionary with metadata.

        Args:
            config: Configuration dictionary containing problem/dataset info
            run_id: The run/repetition number for this experiment

        Returns:
            dict: Initialized recording structure with task-based format
        """
        # Added by Claude: New task-based recording structure
        # - Global iterations: track progress across all tasks
        # - Within-task epochs: track progress within each task
        # - AB iterations: separate counter for AB training phases
        # Also maintain 'iterations' dict for backward compatibility with compute_avg_loss
        record_dict = {
            'metadata': {
                'problem': config.get('problem', 'unknown'),
                'prob': config.get('prob', 'unknown'),
                'dataset': config.get('data', 'unknown'),
         
Generated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
RecordingMixin.initialize_task method · python · L231-L261 (31 LOC)
src/cl/core/recording.py
    def initialize_task(self, record_dict: Dict[str, Any], task_id: int, arch_info: Dict[str, Any]):
        """Initialize a new task's recording structure.

        Args:
            record_dict: The recording dictionary
            task_id: Task ID to initialize
            arch_info: Architecture information (sizes, filter_size, etc.)
        """
        # Added by Claude: Initialize task with main_training structure
        record_dict['tasks'][task_id] = {
            'main_training': {
                'iterations': [],      # Global iteration numbers
                'epochs': [],          # Within-task epoch numbers
                'H': [],
                'V': [],
                'dV': [],
                'dV_dx': [],
                'dV_dtheta': [],
                'grad_norm': [],
                'train_metric': [],
                'test_current': [],
                'test_experience': [],
                'eigenvalues': {'A': {}, 'B': {}}
            },
            'phase_info
RecordingMixin.record_preliminary_summary method · python · L263-L282 (20 LOC)
src/cl/core/recording.py
    def record_preliminary_summary(self, record_dict: Dict[str, Any], task_id: int,
                                   n_epochs: int, warmup_epochs: int,
                                   final_loss: float, decision: str):
        """Store a compact summary of a task's preliminary phase.

        Writes (overwriting any previous value) ``record_dict['tasks'][task_id]
        ['preliminary']``; no per-epoch metrics are recorded here.

        Args:
            record_dict: The recording dictionary; ``record_dict['tasks'][task_id]``
                must already exist
            task_id: Task ID
            n_epochs: Total preliminary epochs
            warmup_epochs: Warmup epochs within preliminary
            final_loss: Final loss after preliminary training (stored as float)
            decision: 'arch_change' or 'no_change'
        """
        # Added by Claude: Record preliminary phase summary
        summary = dict(
            n_epochs=n_epochs,
            warmup_epochs=warmup_epochs,
            final_loss=float(final_loss),
            decision=decision,
        )
        task_entry = record_dict['tasks'][task_id]
        task_entry['preliminary'] = summary
RecordingMixin.initialize_ab_training method · python · L284-L297 (14 LOC)
src/cl/core/recording.py
    def initialize_ab_training(self, record_dict: Dict[str, Any], task_id: int):
        """Create (or reset) the empty AB-training record for a task.

        Assigns a fresh structure to ``record_dict['tasks'][task_id]['ab_training']``
        with empty ``iterations``/``H``/``V`` lists and empty A/B eigenvalue maps.

        Args:
            record_dict: The recording dictionary; ``record_dict['tasks'][task_id]``
                must already exist
            task_id: Task ID
        """
        # Added by Claude: Initialize AB training structure with separate iteration counter
        fresh = {
            'iterations': [],  # AB-specific iteration numbers (local to AB phase)
            'H': [],
            'V': [],
            'ab_eigenvalues': {'A': {}, 'B': {}},
        }
        record_dict['tasks'][task_id]['ab_training'] = fresh
RecordingMixin.record_ab_training_epoch method · python · L299-L329 (31 LOC)
src/cl/core/recording.py
    def record_ab_training_epoch(self, record_dict: Dict[str, Any], task_id: int,
                                 iteration: int, losses: Dict[str, float], model):
        """Record a single epoch of AB training with AB eigenvalues.

        Args:
            record_dict: The recording dictionary
            task_id: Task ID
            iteration: AB training iteration number (local to AB phase)
            losses: Loss dictionary
            model: Model for eigenvalue extraction
        """
        # Added by Claude: Auto-initialize task and AB training if not exists
        if task_id not in record_dict.get('tasks', {}):
            arch_info = {'sizes': 'unknown'}
            self.initialize_task(record_dict, task_id, arch_info)
        if 'ab_training' not in record_dict['tasks'][task_id]:
            self.initialize_ab_training(record_dict, task_id)

        # Added by Claude: Record AB training epoch
        ab = record_dict['tasks'][task_id]['ab_training']
        ab['iteratio
RecordingMixin.record_main_training_epoch method · python · L331-L373 (43 LOC)
src/cl/core/recording.py
    def record_main_training_epoch(self, record_dict: Dict[str, Any], task_id: int,
                                   global_iteration: int, epoch: int,
                                   losses: Dict[str, float], gradients: Dict[str, float],
                                   metrics: Dict[str, float], model):
        """Record a single epoch of main training (Step 5 or standard).

        Args:
            record_dict: The recording dictionary
            task_id: Task ID
            global_iteration: Global iteration number (across all tasks)
            epoch: Within-task epoch number
            losses: Loss dictionary
            gradients: Gradient dictionary
            metrics: Metrics dictionary
            model: Model for eigenvalue extraction
        """
        # Added by Claude: Auto-initialize task if not exists (for backward compatibility)
        if task_id not in record_dict.get('tasks', {}):
            # Get minimal architecture info from model
            arch_in
RecordingMixin.record_task_performance method · python · L375-L413 (39 LOC)
src/cl/core/recording.py
    def record_task_performance(self, record_dict: Dict[str, Any], current_task_id: int,
                                task_performances: Dict[int, float]):
        """Record performance on all tasks after training current_task_id.

        This builds the performance matrix A_{i,j} needed for CL metrics:
        - ACC (Average Accuracy)
        - BWT (Backward Transfer)
        - F (Average Forgetting)
        - FWT (Forward Transfer)

        Args:
            record_dict: The recording dictionary
            current_task_id: The task that was just trained (j in A_{i,j})
            task_performances: Dict mapping task_id -> performance metric
                               e.g., {0: 0.95, 1: 0.92, 2: 0.89}
                               Performance on tasks 0,1,2 after training task 2

        Example:
            After training task 2:
            task_performances = {
                0: 0.93,  # accuracy on task 0 after training task 2
                1: 0.91,  # accuracy on tas
RecordingMixin.save_record_dict method · python · L415-L445 (31 LOC)
src/cl/core/recording.py
    def save_record_dict(self, record_dict: Dict[str, Any], base_path: str):
        """Save the recording dictionary to file using problem/dataset name.

        Args:
            record_dict: The complete recording dictionary
            base_path: Base path for saving (typically config['model_path'])
        """
        metadata = record_dict['metadata']

        # Create filename: problem_dataset_network[_awb]_run{run_id}_records.pkl
        run_id = metadata.get('run_id', 0)
        awb_suffix = "_awb" if metadata.get('awb_enabled', False) else ""
        filename = f"{metadata['prob']}_{metadata['dataset']}_{metadata['network']}{awb_suffix}_run{run_id}_records.pkl"

        # Fixed by Claude: Use base_path directly as save directory (not parent)
        # This keeps each config's results in its own folder
        # e.g., model_path="kkt_run/kkt/results/sine_condition1_optimized" → save there
        if base_path:
            save_dir = base_path
            if not os.path.exists(
RecordingMixin.save_all_runs method · python · L448-L488 (41 LOC)
src/cl/core/recording.py
    def save_all_runs(all_runs_records: Dict[str, Any], base_path: str, config: Dict[str, Any]):
        """Save all experiment runs to a single file.

        Args:
            all_runs_records: Dictionary with run IDs as keys and record_dicts as values
            base_path: Base path for saving
            config: Configuration dictionary for metadata
        """
        # Create a consolidated structure with runs at the top level
        consolidated = {
            'runs': all_runs_records,
            'metadata': {
                'total_runs': len(all_runs_records),
                'problem': config.get('problem', 'unknown'),
                'prob': config.get('prob', 'unknown'),
                'dataset': config.get('data', 'unknown'),
                'network': config.get('network', 'unknown'),
                'awb_enabled': config.get('awb_enabled', False),
            }
        }

        # Create filename with AWB suffix if enabled
        awb_suffix = "_awb" if config.get(
Repobility analyzer · published findings · https://repobility.com
Trainer.__init__ method · python · L45-L55 (11 LOC)
src/cl/core/trainer.py
    def __init__(self, loss='mse', metric='mse', problem='vectors'):
        """Initialize the Trainer.

        Args:
            loss: Loss function type ('mse' for regression, 'class' for classification)
            metric: Metric function type ('mse' for MSE, 'class' for accuracy)
            problem: Problem type ('vectors' for MLP/CNN, 'graph' for GNN)
        """
        self.loss = loss
        self.problem = problem
        self.metric = metric
ContinualDataset.__init__ method · python · L31-L39 (9 LOC)
src/cl/datasets/base.py
    def __init__(self, config: Dict[str, Any], data_x, data_y):
        self.config = config
        self.x = data_x if torch.is_tensor(data_x) else torch.from_numpy(data_x.astype(np.float32))
        self.y = data_y if torch.is_tensor(data_y) else torch.from_numpy(data_y.astype(np.float32))

        # Reshape for fully connected networks (flatten images)
        if self.config.get('problem') == 'classification':
            if self.config.get('network') == 'fcnn':
                self.x = self.x.reshape([-1, 784])
BaseDataset.__init__ method · python · L71-L110 (40 LOC)
src/cl/datasets/base.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the dataset.

        Args:
            config: Configuration dictionary containing:
                - batch_size: Batch size for DataLoaders
                - len_exp_replay: Maximum experience replay buffer size
                - debug_mode: (optional) Enable debug mode with limited data
                - debug_limit: (optional) Number of samples in debug mode
        """
        self.config = config
        self.batch_size = config.get('batch_size', 64)
        self.len_exp_replay = config.get('len_exp_replay', 20000)
        self.debug_mode = config.get('debug_mode', False)
        self.debug_limit = config.get('debug_limit', 100)

        # Current task data
        self.X_train: Optional[torch.Tensor] = None
        self.y_train: Optional[np.ndarray] = None
        self.X_test: Optional[torch.Tensor] = None
        self.y_test: Optional[np.ndarray] = None

        # Experience replay buffers
        self.exp_x_
BaseDataset._load_task_data method · python · L113-L123 (11 LOC)
src/cl/datasets/base.py
    def _load_task_data(self, task_id: int) -> None:
        """Internal method to load data for a specific task.

        Subclasses must implement this to populate:
        - self.X_train, self.y_train (training data)
        - self.X_test, self.y_test (test data)

        Args:
            task_id: Task identifier (0-indexed)
        """
        pass
BaseDataset.load_task method · python · L125-L138 (14 LOC)
src/cl/datasets/base.py
    def load_task(self, task_id: int) -> None:
        """Load one task's data, then enforce the debug-mode size cap.

        Delegates loading to ``self._load_task_data(task_id)``; afterwards,
        only when ``self.debug_mode`` is truthy, ``self._apply_debug_limit()``
        truncates the freshly loaded splits.

        Args:
            task_id: Task identifier (0-indexed)
        """
        self._load_task_data(task_id)
        if not self.debug_mode:
            return
        self._apply_debug_limit()
BaseDataset._apply_debug_limit method · python · L140-L155 (16 LOC)
src/cl/datasets/base.py
    def _apply_debug_limit(self) -> None:
        """Truncate the loaded task data to at most ``self.debug_limit`` samples.

        The training pair (X_train, y_train) and the test pair (X_test, y_test)
        are handled independently. A pair is sliced, with a message printed
        first, only when its X is not None and longer than the limit.
        """
        limit = self.debug_limit
        head = slice(None, limit)  # equivalent to [:limit]

        train_x = self.X_train
        if train_x is not None and len(train_x) > limit:
            print(f"DEBUG MODE: Limiting training data from {len(self.X_train)} to {limit} samples")
            self.X_train = train_x[head]
            self.y_train = self.y_train[head]

        test_x = self.X_test
        if test_x is not None and len(test_x) > limit:
            print(f"DEBUG MODE: Limiting test data from {len(self.X_test)} to {limit} samples")
            self.X_test = test_x[head]
            self.y_test = self.y_test[head]
BaseDataset._rebalance_buffer method · python · L175-L248 (74 LOC)
src/cl/datasets/base.py
    def _rebalance_buffer(self, task_id: int, is_train: bool = True) -> None:
        """Rebalance experience buffer with task-weighted sampling.

        Added by Claude: Implements balanced replay with recency weighting.
        Allocation: 10% recent task, 80% older tasks (equal split), 10% random.

        Args:
            task_id: Current task identifier (most recent task)
            is_train: True for training buffer, False for test buffer
        """
        if is_train:
            exp_x = self.exp_x_train
            exp_y = self.exp_y_train
            exp_task_ids = self.exp_task_ids_train
        else:
            exp_x = self.exp_x_test
            exp_y = self.exp_y_test
            exp_task_ids = self.exp_task_ids_test

        if len(exp_x) <= self.len_exp_replay:
            return  # No rebalancing needed

        n_tasks = task_id + 1

        if not self.balanced_replay_enabled or n_tasks == 1:
            # Disabled or first task: simple random sampling
         
BaseDataset.append_to_experience method · python · L250-L290 (41 LOC)
src/cl/datasets/base.py
    def append_to_experience(self, task_id: int) -> None:
        """Add current task data to the experience replay buffer.

        Manages buffer size using task-balanced sampling when enabled.

        Args:
            task_id: Current task identifier
        """
        # Convert to tensor if needed
        X_train = self.X_train if torch.is_tensor(self.X_train) else torch.from_numpy(
            np.array(self.X_train, dtype=np.float32))
        X_test = self.X_test if torch.is_tensor(self.X_test) else torch.from_numpy(
            np.array(self.X_test, dtype=np.float32))
        y_train = self.y_train if isinstance(self.y_train, np.ndarray) else np.array(self.y_train)
        y_test = self.y_test if isinstance(self.y_test, np.ndarray) else np.array(self.y_test)

        # Create task ID arrays for new samples
        new_task_ids_train = np.full(len(X_train), task_id, dtype=np.int8)
        new_task_ids_test = np.full(len(X_test), task_id, dtype=np.int8)

        if not self._exp
Repobility · code-quality intelligence · https://repobility.com
BaseDataset.get_task_data method · python · L292-L315 (24 LOC)
src/cl/datasets/base.py
    def get_task_data(self, task_id: int, phase: str) -> Tuple[Tuple, Tuple]:
        """Select the current-task pair and the replay pair for ``task_id``.

        Args:
            task_id: Task identifier; replay buffers are used only when it is > 0
                and the buffers have been initialized.
            phase: 'training' selects the train split; any other value selects
                the test split.

        Returns:
            ``((current_x, current_y), (experience_x, experience_y))``. When replay
            is not available (task 0 or uninitialized buffers), the experience
            pair is the current pair.
        """
        replay_ready = task_id > 0 and self._exp_initialized

        if phase == 'training':
            current = (self.X_train, self.y_train)
            replay = (self.exp_x_train, self.exp_y_train) if replay_ready else current
        else:  # anything other than 'training' is treated as testing
            current = (self.X_test, self.y_test)
            replay = (self.exp_x_test, self.exp_y_test) if replay_ready else current

        return current, replay
BaseDataset.generate_dataset method · python · L317-L364 (48 LOC)
src/cl/datasets/base.py
    def generate_dataset(self, task_id: int, batch_size: int, phase: str) -> Tuple[DataLoader, DataLoader]:
        """Generate DataLoaders for a specific task.

        This is the main interface method called by runners.

        Args:
            task_id: Task identifier (0-indexed)
            batch_size: Batch size for DataLoaders
            phase: 'training' or 'testing'

        Returns:
            Tuple of (current_task_loader, experience_replay_loader)
        """
        # Load task data if in training phase (test uses same task data)
        if phase == 'training':
            self.load_task(task_id)

        # Get current and experience data
        (x_curr, y_curr), (x_exp, y_exp) = self.get_task_data(task_id, phase)

        # Create datasets
        dataset_curr = ContinualDataset(self.config, x_curr, y_curr)
        dataset_exp = ContinualDataset(self.config, x_exp, y_exp)

        # Create DataLoaders
        # Note: num_workers=0 required when using JAX because os.f
BaseDataset.generate_test_loader method · python · L366-L396 (31 LOC)
src/cl/datasets/base.py
    def generate_test_loader(self, task_id: int, batch_size: int = None) -> DataLoader:
        """Generate test loader for a specific task (for CL metrics evaluation).

        Added by Claude: This method enables per-task evaluation needed for computing
        the performance matrix A[j][i] = accuracy on task i after training task j.

        Args:
            task_id: Task ID to generate test loader for
            batch_size: Batch size for DataLoader (uses self.batch_size if None)

        Returns:
            DataLoader for task-specific test data
        """
        if batch_size is None:
            batch_size = self.batch_size

        # Load task data (populates self.X_test, self.y_test)
        self.load_task(task_id)

        # Create dataset and loader for test data
        dataset_test = ContinualDataset(self.config, self.X_test, self.y_test)

        loader_kwargs = {
            'batch_size': batch_size,
            'shuffle': False,  # No shuffle for evaluation
      
BaseDataset.get_model_config method · python · L398-L408 (11 LOC)
src/cl/datasets/base.py
    def get_model_config(self) -> Dict[str, Any]:
        """Describe the dataset dimensions needed to build a model.

        Returns:
            A new dict with keys ``input_size``, ``output_size`` and ``n_tasks``,
            each read from the attribute of the same name on this dataset.
        """
        model_cfg = dict(
            input_size=self.input_size,
            output_size=self.output_size,
            n_tasks=self.n_tasks,
        )
        return model_cfg
CIFAR10Dataset.__init__ method · python · L49-L79 (31 LOC)
src/cl/datasets/cifar.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the CIFAR-10 dataset.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        self._n_tasks = config.get('n_task', 5)
        self.rotation_range = config.get('rotation_range', DEFAULT_ROTATION_RANGE)
        self.scaling_range = config.get('scaling_range', DEFAULT_SCALING_RANGE)
        self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)

        # Load CIFAR-10 dataset
        print("Loading CIFAR-10 dataset")
        my_transforms = transforms.Compose([transforms.ToTensor()])
        self.dataset = torchvision.datasets.CIFAR10(
            './data', train=True, download=True, transform=my_transforms
        )

        # Extract images and labels
        [self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
        self.images = torch.stack(self.images, dim=0)
        self.labels = np.array(self.labels)
        print(f"CI
CIFAR10Dataset._load_task_data method · python · L81-L119 (39 LOC)
src/cl/datasets/cifar.py
    def _load_task_data(self, task_id: int) -> None:
        """Load data for a specific CIFAR-10 task with transforms.

        Applies rotation and scaling transforms based on task_id to create
        distribution shift between tasks.

        Args:
            task_id: Task identifier (0-indexed)
        """
        # Fixed by Claude: Set deterministic seed for reproducible task generation
        # Critical for accurate CL metrics - ensures task data is identical during training and evaluation
        np.random.seed(task_id * DEFAULT_PERMUTATION_SEED_MULTIPLIER)

        X = self.images.clone()
        y = self.labels.copy()

        # Apply task-specific transformations
        rot_angle = np.random.random() * self.rotation_range
        scaling_min, scaling_max = self.scaling_range
        scaling = np.random.random() * (scaling_max - scaling_min) + scaling_min

        X = torchvision.transforms.functional.affine(
            X, rot_angle,
            translate=(scaling, scalin
CIFAR100Dataset.__init__ method · python · L163-L194 (32 LOC)
src/cl/datasets/cifar.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the CIFAR-100 dataset.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        self._n_tasks = config.get('n_task', 5)
        self._n_classes = config.get('n_class', 100)
        self.rotation_range = config.get('rotation_range', DEFAULT_ROTATION_RANGE)
        self.scaling_range = config.get('scaling_range', DEFAULT_SCALING_RANGE)
        self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)

        # Load CIFAR-100 dataset
        print("Loading CIFAR-100 dataset")
        my_transforms = transforms.Compose([transforms.ToTensor()])
        self.dataset = torchvision.datasets.CIFAR100(
            './data', train=True, download=True, transform=my_transforms
        )

        # Extract images and labels
        [self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
        self.images = torch.stack(self.images, dim=0)
     
CIFAR100Dataset._load_task_data method · python · L196-L234 (39 LOC)
src/cl/datasets/cifar.py
    def _load_task_data(self, task_id: int) -> None:
        """Load data for a specific CIFAR-100 task with transforms.

        Applies rotation and scaling transforms based on task_id to create
        distribution shift between tasks.

        Args:
            task_id: Task identifier (0-indexed)
        """
        # Fixed by Claude: Set deterministic seed for reproducible task generation
        # Critical for accurate CL metrics - ensures task data is identical during training and evaluation
        np.random.seed(task_id * DEFAULT_PERMUTATION_SEED_MULTIPLIER)

        X = self.images.clone()
        y = self.labels.copy()

        # Apply task-specific transformations
        rot_angle = np.random.random() * self.rotation_range
        scaling_min, scaling_max = self.scaling_range
        scaling = np.random.random() * (scaling_max - scaling_min) + scaling_min

        X = torchvision.transforms.functional.affine(
            X, rot_angle,
            translate=(scaling, scali
Repobility (the analyzer behind this table) · https://repobility.com
PrefetchDataLoader.__init__ method · python · L43-L56 (14 LOC)
src/cl/datasets/jax_dataloader.py
    def __init__(self, dataloader, prefetch_size: int = 2, device=None, loss_type='classification'):
        """Wrap ``dataloader`` so its batches can be prefetched onto a JAX device.

        Args:
            dataloader: Iterable of batches to wrap.
            prefetch_size: Number of batches to prefetch ahead of the consumer.
            device: JAX device to place batches on. When None, the first device
                returned by ``jax.devices('gpu')`` is used.
            loss_type: Stored on the instance; used later for dtype conversion.

        Raises:
            RuntimeError: If ``device`` is None and JAX exposes no GPU.
        """
        self.dataloader = dataloader
        self.prefetch_size = prefetch_size
        if device is None:
            no_gpu_msg = "No GPU found. JAX continual learning requires GPU."
            try:
                gpu_devices = jax.devices('gpu')
            except RuntimeError as err:
                # jax.devices('gpu') raises (rather than returning an empty list)
                # when no GPU backend is available; report the same clear error.
                raise RuntimeError(no_gpu_msg) from err
            if not gpu_devices:
                raise RuntimeError(no_gpu_msg)
            # Per the original note, this is the GPU assigned via CUDA_VISIBLE_DEVICES.
            device = gpu_devices[0]
        self.device = device
        self.loss_type = loss_type
        # Diagnostic logging for device detection.
        print(f"[DEBUG] PrefetchDataLoader initialized with device: {self.device}")
        print(f"[DEBUG] All available JAX devices: {jax.devices()}")
PrefetchDataLoader.__iter__ method · python · L58-L123 (66 LOC)
src/cl/datasets/jax_dataloader.py
    def __iter__(self) -> Iterator[Tuple[jnp.ndarray, jnp.ndarray]]:
        """Iterate over batches with GPU prefetching."""

        # Queue to hold prefetched GPU batches
        batch_queue = queue.Queue(maxsize=self.prefetch_size)

        # Exception storage for background thread
        exception_storage = [None]

        def prefetch_worker():
            """Background thread that loads and transfers batches to GPU."""
            try:
                for batch in self.dataloader:
                    # Unpack batch (handle both 2-tuple and 3-tuple formats)
                    if len(batch) == 2:
                        x, y = batch
                    elif len(batch) == 3:
                        x, y, _ = batch  # Ignore task_id or other metadata
                    else:
                        x, y = batch[0], batch[1]

                    # Convert PyTorch tensors to numpy
                    if hasattr(x, 'numpy'):
                        x = x.numpy()
                    
wrap_dataloader function · python · L162-L187 (26 LOC)
src/cl/datasets/jax_dataloader.py
def wrap_dataloader(loader, prefetch_size: int = 2, device=None):
    """Return a JAX-prefetching wrapper around one or two PyTorch DataLoaders.

    Args:
        loader: Either a single DataLoader, or a 2-tuple
            ``(current_loader, replay_loader)`` for continual learning. Only an
            actual ``tuple`` of length 2 is treated as a pair; anything else is
            wrapped as a single loader.
        prefetch_size: Number of batches to prefetch (default: 2)
        device: JAX device (default: first GPU)

    Returns:
        DualPrefetchDataLoader for a 2-tuple, otherwise PrefetchDataLoader.

    Example::

        fast_loader = wrap_dataloader(pytorch_loader, prefetch_size=3)
        fast_loader = wrap_dataloader((train_loader, exp_loader))
    """
    is_pair = isinstance(loader, tuple) and len(loader) == 2
    if not is_pair:
        return PrefetchDataLoader(loader, prefetch_size, device)

    # Continual learning: current-task loader plus experience-replay loader.
    current_loader, replay_loader = loader[0], loader[1]
    return DualPrefetchDataLoader(current_loader, replay_loader, prefetch_size, device)
‹ prev · page 3 / 6 · next ›