Function bodies 268 total
HamiltonianMixin.return_Hamiltonian_graph method · python · L652-L717 (66 LOC)src/cl/core/hamiltonian.py
def return_Hamiltonian_graph(self, params, data, notABTrain, grad_weights=None,
normalize_dV=True, dV_scale=1.0):
"""Compute Hamiltonian gradient for graph classification.
Uses JIT-compiled core functions for high GPU utilization.
Args:
params: Trainable model parameters
data: Tuple of (static, (batch, batch_ex, deltax, delta_adj))
OR (static, (x, y, adj, b, n, exp_x, exp_y, exp_adj, exp_b, exp_n, deltax, delta_adj))
notABTrain: True for standard training, False for AWB A/B training
grad_weights: Optional [alpha, beta, gamma] weights for gradient combination
normalize_dV: Whether to normalize dV by parameter count (default True)
dV_scale: Additional scaling factor for dV (default 1.0)
Returns:
Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx, dV_dadj))
"""
if grad_weights is None:
grad_weightHamiltonianMixin.return_Hamiltonian_mse method · python · L719-L757 (39 LOC)src/cl/core/hamiltonian.py
def return_Hamiltonian_mse(self, params, data, notABTrain=True, grad_weights=None,
normalize_dV=True, dV_scale=1.0):
"""Compute Hamiltonian gradient for MSE regression.
Uses JIT-compiled core functions for high GPU utilization.
Args:
params: Trainable model parameters
data: Tuple of (statics, (x, y, exp_x, exp_y, deltax, flag))
notABTrain: True for standard training, False for AWB A/B training
grad_weights: Optional [alpha, beta, gamma] weights for gradient combination
normalize_dV: Whether to normalize dV by parameter count (default True)
dV_scale: Additional scaling factor for dV (default 1.0)
Returns:
Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx))
"""
if grad_weights is None:
grad_weights = DEFAULT_GRAD_WEIGHTS
alpha, beta, gamma = grad_weights
statics, (x, y, exp_x, exp_y, deltax, flag) HamiltonianMixin.return_Hamiltonian_class method · python · L759-L801 (43 LOC)src/cl/core/hamiltonian.py
def return_Hamiltonian_class(self, params, data, notABTrain=True, grad_weights=None,
normalize_dV=True, dV_scale=1.0):
"""Compute Hamiltonian gradient for classification.
Uses JIT-compiled core functions for high GPU utilization.
Args:
params: Trainable model parameters
data: Tuple of (statics, (x, y, exp_x, exp_y, deltax, flag))
notABTrain: True for standard training, False for AWB A/B training
grad_weights: Optional [alpha, beta, gamma] weights for gradient combination
normalize_dV: Whether to normalize dV by parameter count (default True)
dV_scale: Additional scaling factor for dV (default 1.0)
Returns:
Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx))
"""
if grad_weights is None:
grad_weights = DEFAULT_GRAD_WEIGHTS
alpha, beta, gamma = grad_weights
statics, (x, y, exp_x, exp_y, deltax, flget_graph_transforms function · python · L38-L54 (17 LOC)src/cl/core/loops.py
def get_graph_transforms():
    """Return the shared graph transform pipeline, building it on first use.

    ``torch_geometric`` is imported inside the function so that it is only
    required when graph transforms are actually requested.

    The pipeline is ``T.Compose([T.GCNNorm(), T.ToDense()])``. The older
    pipeline additionally ended with ``T.NormalizeFeatures()``; that step was
    removed because it was destroying the class-feature correlation in
    FakeDataset (correlation dropped from 0.998 to -0.135).

    Returns:
        The module-level cached ``T.Compose`` instance.
    """
    global _GRAPH_TRANSFORMS
    if _GRAPH_TRANSFORMS is not None:
        # Already built by an earlier call; reuse the cached pipeline.
        return _GRAPH_TRANSFORMS

    import torch_geometric.transforms as T

    _GRAPH_TRANSFORMS = T.Compose([T.GCNNorm(), T.ToDense()])
    return _GRAPH_TRANSFORMS
# TrainingLoopsMixin._clip_gradients method · python · L60-L97 (38 LOC)src/cl/core/loops.py
def _clip_gradients(self, grads, max_norm=None):
"""Clip gradients by global norm.
Added by Claude: Following Pascanu et al. (2013) recommendations for
gradient clipping in recurrent networks, applicable to continual learning
with potentially unstable gradients.
Reference:
- Pascanu et al., "On the difficulty of training recurrent neural networks",
ICML 2013
Args:
grads: PyTree of gradients
max_norm: Maximum gradient norm (None = no clipping)
Returns:
Tuple of (clipped_grads, global_norm, was_clipped)
"""
if max_norm is None or max_norm <= 0:
# No clipping
grad_leaves = jax.tree_util.tree_leaves(grads)
global_norm = jnp.sqrt(sum([jnp.sum(g**2) for g in grad_leaves]))
return grads, global_norm, False
# Compute global norm
grad_leaves = jax.tree_util.tree_leaves(grads)
global_norm = jTrainingLoopsMixin._compute_metrics_on_sampled_batches method · python · L99-L218 (120 LOC)src/cl/core/loops.py
def _compute_metrics_on_sampled_batches(self, params, static, loader,
num_batches=10, problem_type='vectors',
notABTrain=True, transforms=None):
"""Efficiently compute metrics on N sampled batches from a loader.
Args:
params: Model parameters
static: Static model components
loader: Data loader (tuple of current_loader, exp_loader)
num_batches: Number of batches to sample (default 10)
problem_type: 'vectors' or 'graph'
notABTrain: Whether using normal training (True) or AWB training (False)
transforms: Transform pipeline for graph data
Returns:
Tuple of (current_task_metric, experience_metric)
"""
current_metrics = []
exp_metrics = []
if problem_type == 'graph':
if isinstance(loader, tuple):
current_loader, exp_loadeTrainingLoopsMixin._compute_perturbation_variance method · python · L220-L287 (68 LOC)src/cl/core/loops.py
def _compute_perturbation_variance(self, trainloader, exploader, problem_type, max_batches=5):
"""Pre-compute variance for perturbation sampling.
Uses the mean difference approach for feature variance.
For graphs, also computes adjacency variance.
Args:
trainloader: Current task data loader
exploader: Experience replay data loader
problem_type: 'vectors' or 'graph'
max_batches: Maximum batches to sample for variance estimation (default 5)
5 batches × 2048 batch_size = 10,240 samples (statistically sufficient)
Returns:
Tuple of (var_x, var_adj) where var_adj is 0 for vectors
"""
var_x_list, var_adj_list = [], []
transforms = get_graph_transforms() if problem_type == 'graph' else None
# Fixed by Claude: Handle empty experience loader (Task 0 for graph datasets)
# Try to check if exploader has any data
exp_iter_Repobility · code-quality intelligence · https://repobility.com
LossMixin.loss_fn_mse method · python · L31-L38 (8 LOC)src/cl/core/losses.py
def loss_fn_mse(self, params, statics, x, y):
    """Mean squared error between model predictions and targets (vectors).

    Args:
        params: Trainable part of the model, recombined with ``statics``.
        statics: Static part of the model.
        x: Batch of inputs; the model is vmapped over its leading axis.
        y: Regression targets compared against the predictions.

    Returns:
        Mean of the squared residuals ``(y - preds) ** 2``.
    """
    net = eqx.combine(params, statics)
    outputs = jax.vmap(net)(x)
    # Predictions shaped (batch, 1, output) carry a redundant singleton axis;
    # collapse it to (batch, output) before comparing with the targets.
    has_singleton_axis = outputs.ndim == 3 and outputs.shape[1] == 1
    if has_singleton_axis:
        outputs = jnp.squeeze(outputs, axis=1)
    residual = y - outputs
    return jnp.mean(residual**2)
# LossMixin.accuracy_vectors method · python · L41-L67 (27 LOC)src/cl/core/losses.py
def accuracy_vectors(self, params, statics, x, y):
"""Accuracy metric for classification (vectors).
Args:
params: Model parameters
statics: Static model components
x: Input features, shape (batch, ...)
y: Class labels - either scalar indices shape (batch,) or one-hot shape (batch, num_classes)
Returns:
Accuracy as float between 0 and 1
"""
model = eqx.combine(params, statics)
preds = jax.vmap(model)(x)
# Squeeze extra dimension if present (batch, 1, classes) -> (batch, classes)
if preds.ndim == 3 and preds.shape[1] == 1:
preds = jnp.squeeze(preds, axis=1)
pred = jnp.argmax(jax.nn.softmax(preds), axis=1)
# Handle both one-hot encoded (batch, num_classes) and scalar labels (batch,)
if y.ndim == 2 and y.shape[1] > 1:
# One-hot encoded: convert to class indices
y = jnp.argmax(y, axis=1)
elif y.nLossMixin.mse_vectors method · python · L70-L77 (8 LOC)src/cl/core/losses.py
def mse_vectors(self, params, statics, x, y):
    """Squared-error metric for regression (vectors).

    Args:
        params: Trainable part of the model, recombined with ``statics``.
        statics: Static part of the model.
        x: Batch of inputs; the model is vmapped over its leading axis.
        y: Regression targets.

    Returns:
        Mean of ``optax.l2_loss(y, preds)``.

    NOTE(review): optax documents ``l2_loss`` as half the squared error, so
    this value is presumably 0.5x the plain mean squared error - confirm
    before comparing it with other MSE numbers.
    """
    net = eqx.combine(params, statics)
    outputs = jax.vmap(net)(x)
    # Predictions shaped (batch, 1, output) carry a redundant singleton axis;
    # collapse it to (batch, output) before comparing with the targets.
    has_singleton_axis = outputs.ndim == 3 and outputs.shape[1] == 1
    if has_singleton_axis:
        outputs = jnp.squeeze(outputs, axis=1)
    per_element = optax.l2_loss(y, outputs)
    return jnp.mean(per_element)
# LossMixin.loss_fn_class_graph method · python · L82-L89 (8 LOC)src/cl/core/losses.py
def loss_fn_class_graph(self, params, statics, x, y, adj=None):
    """Loss for graph classification built from per-graph model outputs.

    Args:
        params: Trainable part of the model, recombined with ``statics``.
        statics: Static part of the model.
        x: Indexable collection of per-graph inputs.
        y: Labels; cast to ``jnp.int64`` and used as gather indices.
        adj: Indexable collection of per-graph adjacency inputs, indexed in
            step with ``x``. It is indexed unconditionally, so the ``None``
            default cannot actually be used.

    Returns:
        ``-mean(y * gathered)``, where ``gathered`` holds the model outputs
        selected at the label index along axis 1.
    """
    model = eqx.combine(params, statics)
    n_graphs = len(x)
    # Each per-graph output is transposed before stacking along a new axis 0.
    per_graph = [model(x[idx], adj[idx]).T for idx in range(n_graphs)]
    stacked = jnp.stack(per_graph)
    # The original code stacks a second time; kept so behaviour is unchanged.
    scores = jnp.stack(stacked)
    labels = y.astype(jnp.int64)
    gathered = jnp.take_along_axis(scores, jnp.expand_dims(labels, 1), axis=1)
    # NOTE(review): the gathered scores are multiplied by the label values
    # themselves, so samples labelled 0 contribute nothing; and if `labels`
    # is (batch,) while `gathered` is (batch, 1) the product broadcasts to
    # (batch, batch). Confirm both are intended.
    return -jnp.mean(labels * gathered)
# LossMixin.accuracy_graphs method · python · L98-L103 (6 LOC)src/cl/core/losses.py
def accuracy_graphs(self, params, statics, x, adj, b, n):
    """Log-softmax outputs for a batch of graphs (standard forward pass).

    Despite the name, no accuracy value is computed here: the method returns
    the log-softmax of the concatenated per-graph model outputs.

    Args:
        params: Trainable part of the model, recombined with ``statics``.
        statics: Static part of the model.
        x, adj, b, n: Indexable per-graph inputs; element ``i`` of each is
            passed to the model for graph ``i``.

    Returns:
        ``jax.nn.log_softmax`` over axis 1 of the concatenated outputs.
    """
    model = eqx.combine(params, statics)
    graph_indices = range(len(x))
    per_graph = [model(x[idx], adj[idx], b[idx], n[idx]) for idx in graph_indices]
    concatenated = jnp.concatenate(per_graph)
    return jax.nn.log_softmax(concatenated, axis=1)
# LossMixin.accuracy_graphs_AWBT method · python · L106-L111 (6 LOC)src/cl/core/losses.py
def accuracy_graphs_AWBT(self, params, statics, x, adj, b, n):
    """Log-softmax outputs for a batch of graphs (AWB forward pass).

    Despite the name, no accuracy value is computed here: the method returns
    the log-softmax of the concatenated per-graph ``model.get_AWBT`` outputs.

    Args:
        params: Trainable part of the model, recombined with ``statics``.
        statics: Static part of the model.
        x, adj, b, n: Indexable per-graph inputs; element ``i`` of each is
            passed to ``model.get_AWBT`` for graph ``i``.

    Returns:
        ``jax.nn.log_softmax`` over axis 1 of the concatenated outputs.
    """
    model = eqx.combine(params, statics)
    graph_indices = range(len(x))
    per_graph = [model.get_AWBT(x[idx], adj[idx], b[idx], n[idx]) for idx in graph_indices]
    concatenated = jnp.concatenate(per_graph)
    return jax.nn.log_softmax(concatenated, axis=1)
# LossMixin.return_loss_grad method · python · L135-L162 (28 LOC)src/cl/core/losses.py
def return_loss_grad(self, params, batch, static):
"""Compute loss and gradient for a batch.
Args:
params: Trainable model parameters
batch: Input batch (varies by problem type)
static: Static model components
Returns:
Tuple of (loss, gradients)
"""
if self.problem == 'vectors':
(x, y) = batch
if self.loss == 'class':
grads = jax.grad(self.loss_fn_class)(params, static, x, y)
loss = self.loss_fn_class(params, static, x, y)
elif self.loss == 'mse':
grads = jax.grad(self.loss_fn_mse)(params, static, x, y)
loss = self.loss_fn_mse(params, static, x, y)
elif self.problem == 'graph':
(x, y, adj) = batch
if self.loss == 'class':
grads = jax.grad(self.loss_fn_class_graph)(params, static, x, y, adj=adj)
loss = self.loss_fn_class_graph(paramsLossMixin.return_metric method · python · L164-L229 (66 LOC)src/cl/core/losses.py
def return_metric(self, params, statics, data, notABTrain=True):
"""Compute evaluation metric based on problem type.
Args:
params: Trainable model parameters
statics: Static model components
data: Input data (format depends on problem type)
For vectors: (x, y) where y is scalar class indices shape (batch,)
notABTrain: True for standard forward, False for AWB forward
Returns:
Metric value (accuracy for classification, MSE for regression)
"""
model = eqx.combine(params, statics)
if self.problem == 'vectors':
x, y = data
if self.metric == 'class':
# Ensure y is 1D int64 array of class indices
y = y.astype(jnp.int64)
if y.ndim == 2:
y = jnp.squeeze(y, axis=-1)
if notABTrain:
preds = jax.vmap(model)(x)
else:
Repobility (the analyzer behind this table) · https://repobility.com
get_gpu_memory_usage function · python · L16-L56 (41 LOC)src/cl/core/profiling.py
def get_gpu_memory_usage():
"""
Get current GPU memory usage from JAX.
Returns:
Tuple of (used_gb, total_gb, percent) or None if unavailable
"""
try:
import jax
from jax.lib import xla_bridge
# Get GPU backend
backend = xla_bridge.get_backend()
# Get memory stats from first GPU device
devices = backend.devices()
if not devices:
return None
device = devices[0]
# Get memory info
mem_stats = device.memory_stats()
if mem_stats is None:
return None
bytes_in_use = mem_stats.get('bytes_in_use', 0)
# Try to get peak memory
peak_bytes = mem_stats.get('peak_bytes_in_use', bytes_in_use)
# Convert to GB
used_gb = bytes_in_use / (1024**3)
peak_gb = peak_bytes / (1024**3)
# Try to estimate total memory (H200 has 144GB typically)
# We'll use peak as a proxy since total isn't always available
format_memory_stats function · python · L58-L70 (13 LOC)src/cl/core/profiling.py
def format_memory_stats():
    """Describe the current GPU memory usage as a short string.

    Returns:
        "GPU: N/A" when ``get_gpu_memory_usage()`` returns ``None``;
        otherwise a string of the form
        "GPU: 2.30GB used, 3.10GB peak" (two decimals each).
    """
    stats = get_gpu_memory_usage()
    if stats is not None:
        # The stats are unpacked as exactly two values: used and peak, in GB.
        used_gb, peak_gb = stats
        return f"GPU: {used_gb:.2f}GB used, {peak_gb:.2f}GB peak"
    return "GPU: N/A"
# enable_profiling function · python · L72-L80 (9 LOC)src/cl/core/profiling.py
def enable_profiling(enabled: bool):
    """Switch profiling on or off for the whole module.

    The value is stored as-is in the module-level ``_PROFILING_ENABLED`` flag.

    Args:
        enabled: True to turn profiling output on, False to turn it off
    """
    global _PROFILING_ENABLED
    _PROFILING_ENABLED = enabled
# profile function · python · L82-L110 (29 LOC)src/cl/core/profiling.py
def profile(phase_name: str):
"""
Decorator to time a function if profiling is enabled.
When profiling is disabled, this has zero overhead (simple boolean check).
When enabled, prints timing and GPU memory information for the decorated function.
Args:
phase_name: Human-readable name for this profiling phase
Example:
@profile("Dataset Loading")
def load_data():
# ... code ...
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not _PROFILING_ENABLED:
return func(*args, **kwargs)
print(f"\n[PROFILE] {phase_name} starting... | {format_memory_stats()}")
start_time = time.time()
result = func(*args, **kwargs)
elapsed = time.time() - start_time
print(f"[PROFILE] {phase_name} complete: {elapsed:.2f}s | {format_memory_stats()}")
return result
return wrapper
return decoprofile_section function · python · L112-L141 (30 LOC)src/cl/core/profiling.py
def profile_section(phase_name: str, enabled: bool = None):
"""
Context manager for profiling a code section.
Args:
phase_name: Human-readable name for this profiling phase
enabled: Optional override for profiling enabled status
Example:
with profile_section("JAX Pre-conversion"):
# ... code to profile ...
"""
class ProfileContext:
def __enter__(self):
if enabled is None:
self.enabled = _PROFILING_ENABLED
else:
self.enabled = enabled
if self.enabled:
print(f"\n[PROFILE] {phase_name} starting... | {format_memory_stats()}")
self.start_time = time.time()
return self
def __exit__(self, *args):
if self.enabled:
elapsed = time.time() - self.start_time
print(f"[PROFILE] {phase_name} complete: {elapsed:.2f}s | {format_memory_stats()}")
return ProfileContextRecordingMixin._compute_eigenvalues method · python · L26-L148 (123 LOC)src/cl/core/recording.py
def _compute_eigenvalues(self, model, combined=False):
"""Compute eigenvalues of A/B matrices (AWB mode) or weight matrices (standard mode).
For AWB-enabled models: computes eigenvalues of A and B matrices.
For standard models: computes eigenvalues of weight matrices (W).
Args:
model: The model (combined params + static), or None to skip computation
combined: If True, return model; if False, extract from self
Returns:
dict: Eigenvalues organized as:
AWB mode:
{
'A': {'layer_0': array, 'layer_1': array, ...},
'B': {'layer_0': array, 'layer_1': array, ...}
}
Standard mode (weights stored in 'A' key for plot compatibility):
{
'A': {'layer_0': array, 'layer_1': array, ...},
'B': {}
}
If model is None, returns empty dictRecordingMixin.record_metrics method · python · L150-L188 (39 LOC)src/cl/core/recording.py
def record_metrics(self,
iteration: int,
step: int,
task_id: int,
losses: Dict[str, float],
gradients: Dict[str, float],
metrics: Dict[str, float],
model,
extra_metrics: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Record all metrics for a single iteration in unified format.
Args:
iteration: Iteration number (key for the record)
step: Current step within task
task_id: Current task ID
losses: Dict of loss values (H, V, dV, dV_dx, dV_dtheta, dV_dadj, etc.)
gradients: Dict of gradient norms
metrics: Dict with train, test_current, test_experience metrics
model: The model (for eigenvalue extraction)
extra_metrics: Optional dict for dataset-specific metrics
Returns:
dict: CRecordingMixin.initialize_record_dict method · python · L190-L229 (40 LOC)src/cl/core/recording.py
def initialize_record_dict(self, config: Dict[str, Any], run_id: int = 0) -> Dict[str, Any]:
"""Initialize the recording dictionary with metadata.
Args:
config: Configuration dictionary containing problem/dataset info
run_id: The run/repetition number for this experiment
Returns:
dict: Initialized recording structure with task-based format
"""
# Added by Claude: New task-based recording structure
# - Global iterations: track progress across all tasks
# - Within-task epochs: track progress within each task
# - AB iterations: separate counter for AB training phases
# Also maintain 'iterations' dict for backward compatibility with compute_avg_loss
record_dict = {
'metadata': {
'problem': config.get('problem', 'unknown'),
'prob': config.get('prob', 'unknown'),
'dataset': config.get('data', 'unknown'),
Generated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
RecordingMixin.initialize_task method · python · L231-L261 (31 LOC)src/cl/core/recording.py
def initialize_task(self, record_dict: Dict[str, Any], task_id: int, arch_info: Dict[str, Any]):
"""Initialize a new task's recording structure.
Args:
record_dict: The recording dictionary
task_id: Task ID to initialize
arch_info: Architecture information (sizes, filter_size, etc.)
"""
# Added by Claude: Initialize task with main_training structure
record_dict['tasks'][task_id] = {
'main_training': {
'iterations': [], # Global iteration numbers
'epochs': [], # Within-task epoch numbers
'H': [],
'V': [],
'dV': [],
'dV_dx': [],
'dV_dtheta': [],
'grad_norm': [],
'train_metric': [],
'test_current': [],
'test_experience': [],
'eigenvalues': {'A': {}, 'B': {}}
},
'phase_infoRecordingMixin.record_preliminary_summary method · python · L263-L282 (20 LOC)src/cl/core/recording.py
def record_preliminary_summary(self, record_dict: Dict[str, Any], task_id: int,
                               n_epochs: int, warmup_epochs: int,
                               final_loss: float, decision: str):
    """Store a summary of the preliminary phase for one task.

    Only the summary is written (no per-epoch metrics); it replaces any
    existing ``'preliminary'`` entry of the task.

    Args:
        record_dict: The recording dictionary; ``record_dict['tasks'][task_id]``
            must already exist
        task_id: Task ID
        n_epochs: Total preliminary epochs
        warmup_epochs: Warmup epochs within preliminary
        final_loss: Final loss after preliminary training (stored as ``float``)
        decision: 'arch_change' or 'no_change'
    """
    summary = dict(
        n_epochs=n_epochs,
        warmup_epochs=warmup_epochs,
        final_loss=float(final_loss),
        decision=decision,
    )
    task_record = record_dict['tasks'][task_id]
    task_record['preliminary'] = summary
# RecordingMixin.initialize_ab_training method · python · L284-L297 (14 LOC)src/cl/core/recording.py
def initialize_ab_training(self, record_dict: Dict[str, Any], task_id: int):
    """Create an empty AB-training record for a task.

    Any existing ``'ab_training'`` entry of the task is replaced.

    Args:
        record_dict: The recording dictionary; ``record_dict['tasks'][task_id]``
            must already exist
        task_id: Task ID
    """
    # 'iterations' holds AB-specific iteration numbers (local to the AB
    # phase); every series starts as its own fresh list.
    ab_record = {series: [] for series in ('iterations', 'H', 'V')}
    ab_record['ab_eigenvalues'] = {'A': {}, 'B': {}}
    record_dict['tasks'][task_id]['ab_training'] = ab_record
# RecordingMixin.record_ab_training_epoch method · python · L299-L329 (31 LOC)src/cl/core/recording.py
def record_ab_training_epoch(self, record_dict: Dict[str, Any], task_id: int,
iteration: int, losses: Dict[str, float], model):
"""Record a single epoch of AB training with AB eigenvalues.
Args:
record_dict: The recording dictionary
task_id: Task ID
iteration: AB training iteration number (local to AB phase)
losses: Loss dictionary
model: Model for eigenvalue extraction
"""
# Added by Claude: Auto-initialize task and AB training if not exists
if task_id not in record_dict.get('tasks', {}):
arch_info = {'sizes': 'unknown'}
self.initialize_task(record_dict, task_id, arch_info)
if 'ab_training' not in record_dict['tasks'][task_id]:
self.initialize_ab_training(record_dict, task_id)
# Added by Claude: Record AB training epoch
ab = record_dict['tasks'][task_id]['ab_training']
ab['iteratioRecordingMixin.record_main_training_epoch method · python · L331-L373 (43 LOC)src/cl/core/recording.py
def record_main_training_epoch(self, record_dict: Dict[str, Any], task_id: int,
global_iteration: int, epoch: int,
losses: Dict[str, float], gradients: Dict[str, float],
metrics: Dict[str, float], model):
"""Record a single epoch of main training (Step 5 or standard).
Args:
record_dict: The recording dictionary
task_id: Task ID
global_iteration: Global iteration number (across all tasks)
epoch: Within-task epoch number
losses: Loss dictionary
gradients: Gradient dictionary
metrics: Metrics dictionary
model: Model for eigenvalue extraction
"""
# Added by Claude: Auto-initialize task if not exists (for backward compatibility)
if task_id not in record_dict.get('tasks', {}):
# Get minimal architecture info from model
arch_inRecordingMixin.record_task_performance method · python · L375-L413 (39 LOC)src/cl/core/recording.py
def record_task_performance(self, record_dict: Dict[str, Any], current_task_id: int,
task_performances: Dict[int, float]):
"""Record performance on all tasks after training current_task_id.
This builds the performance matrix A_{i,j} needed for CL metrics:
- ACC (Average Accuracy)
- BWT (Backward Transfer)
- F (Average Forgetting)
- FWT (Forward Transfer)
Args:
record_dict: The recording dictionary
current_task_id: The task that was just trained (j in A_{i,j})
task_performances: Dict mapping task_id -> performance metric
e.g., {0: 0.95, 1: 0.92, 2: 0.89}
Performance on tasks 0,1,2 after training task 2
Example:
After training task 2:
task_performances = {
0: 0.93, # accuracy on task 0 after training task 2
1: 0.91, # accuracy on tasRecordingMixin.save_record_dict method · python · L415-L445 (31 LOC)src/cl/core/recording.py
def save_record_dict(self, record_dict: Dict[str, Any], base_path: str):
"""Save the recording dictionary to file using problem/dataset name.
Args:
record_dict: The complete recording dictionary
base_path: Base path for saving (typically config['model_path'])
"""
metadata = record_dict['metadata']
# Create filename: problem_dataset_network[_awb]_run{run_id}_records.pkl
run_id = metadata.get('run_id', 0)
awb_suffix = "_awb" if metadata.get('awb_enabled', False) else ""
filename = f"{metadata['prob']}_{metadata['dataset']}_{metadata['network']}{awb_suffix}_run{run_id}_records.pkl"
# Fixed by Claude: Use base_path directly as save directory (not parent)
# This keeps each config's results in its own folder
# e.g., model_path="kkt_run/kkt/results/sine_condition1_optimized" → save there
if base_path:
save_dir = base_path
if not os.path.exists(RecordingMixin.save_all_runs method · python · L448-L488 (41 LOC)src/cl/core/recording.py
def save_all_runs(all_runs_records: Dict[str, Any], base_path: str, config: Dict[str, Any]):
"""Save all experiment runs to a single file.
Args:
all_runs_records: Dictionary with run IDs as keys and record_dicts as values
base_path: Base path for saving
config: Configuration dictionary for metadata
"""
# Create a consolidated structure with runs at the top level
consolidated = {
'runs': all_runs_records,
'metadata': {
'total_runs': len(all_runs_records),
'problem': config.get('problem', 'unknown'),
'prob': config.get('prob', 'unknown'),
'dataset': config.get('data', 'unknown'),
'network': config.get('network', 'unknown'),
'awb_enabled': config.get('awb_enabled', False),
}
}
# Create filename with AWB suffix if enabled
awb_suffix = "_awb" if config.get(Repobility analyzer · published findings · https://repobility.com
Trainer.__init__ method · python · L45-L55 (11 LOC)src/cl/core/trainer.py
def __init__(self, loss='mse', metric='mse', problem='vectors'):
    """Store the trainer's loss, metric and problem selections.

    Args:
        loss: Loss function type ('mse' for regression, 'class' for classification)
        metric: Metric function type ('mse' for MSE, 'class' for accuracy)
        problem: Problem type ('vectors' for MLP/CNN, 'graph' for GNN)
    """
    self.loss, self.problem, self.metric = loss, problem, metric
# ContinualDataset.__init__ method · python · L31-L39 (9 LOC)src/cl/datasets/base.py
def __init__(self, config: Dict[str, Any], data_x, data_y):
    """Wrap features and labels as torch tensors.

    Args:
        config: Configuration dictionary; the 'problem' and 'network' entries
            decide whether the features are flattened.
        data_x: Features. A torch tensor is stored as-is; anything else is
            cast with ``astype(np.float32)`` and passed to ``torch.from_numpy``.
        data_y: Labels, converted the same way as ``data_x``.
    """
    self.config = config

    if torch.is_tensor(data_x):
        self.x = data_x
    else:
        self.x = torch.from_numpy(data_x.astype(np.float32))

    if torch.is_tensor(data_y):
        self.y = data_y
    else:
        self.y = torch.from_numpy(data_y.astype(np.float32))

    # Fully connected classification networks take flattened images
    # (784 features per sample).
    wants_flat_input = (
        self.config.get('problem') == 'classification'
        and self.config.get('network') == 'fcnn'
    )
    if wants_flat_input:
        self.x = self.x.reshape([-1, 784])
# BaseDataset.__init__ method · python · L71-L110 (40 LOC)src/cl/datasets/base.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the dataset.
Args:
config: Configuration dictionary containing:
- batch_size: Batch size for DataLoaders
- len_exp_replay: Maximum experience replay buffer size
- debug_mode: (optional) Enable debug mode with limited data
- debug_limit: (optional) Number of samples in debug mode
"""
self.config = config
self.batch_size = config.get('batch_size', 64)
self.len_exp_replay = config.get('len_exp_replay', 20000)
self.debug_mode = config.get('debug_mode', False)
self.debug_limit = config.get('debug_limit', 100)
# Current task data
self.X_train: Optional[torch.Tensor] = None
self.y_train: Optional[np.ndarray] = None
self.X_test: Optional[torch.Tensor] = None
self.y_test: Optional[np.ndarray] = None
# Experience replay buffers
self.exp_x_BaseDataset._load_task_data method · python · L113-L123 (11 LOC)src/cl/datasets/base.py
def _load_task_data(self, task_id: int) -> None:
    """Hook that loads the data for one task; subclasses must implement it.

    An implementation is expected to populate:
    - self.X_train, self.y_train (training data)
    - self.X_test, self.y_test (test data)

    This base version does nothing.

    Args:
        task_id: Task identifier (0-indexed)
    """
    return None
# BaseDataset.load_task method · python · L125-L138 (14 LOC)src/cl/datasets/base.py
def load_task(self, task_id: int) -> None:
    """Load one task's data, then trim it when debug mode is on.

    Delegates the actual loading to ``_load_task_data`` and, if
    ``self.debug_mode`` is truthy, calls ``_apply_debug_limit`` afterwards.

    Args:
        task_id: Task identifier (0-indexed)
    """
    # Loading itself is the subclass's responsibility.
    self._load_task_data(task_id)

    if not self.debug_mode:
        return
    self._apply_debug_limit()
# BaseDataset._apply_debug_limit method · python · L140-L155 (16 LOC)src/cl/datasets/base.py
def _apply_debug_limit(self) -> None:
    """Truncate the current task's data to at most ``debug_limit`` samples.

    The train pair (X_train, y_train) and the test pair (X_test, y_test) are
    handled independently: a pair is sliced to its first ``debug_limit``
    entries, and a message is printed, only when its X array is set and
    longer than the limit.
    """
    limit = self.debug_limit

    train_over_limit = self.X_train is not None and len(self.X_train) > limit
    if train_over_limit:
        print(f"DEBUG MODE: Limiting training data from {len(self.X_train)} to {limit} samples")
        self.X_train = self.X_train[:limit]
        self.y_train = self.y_train[:limit]

    test_over_limit = self.X_test is not None and len(self.X_test) > limit
    if test_over_limit:
        print(f"DEBUG MODE: Limiting test data from {len(self.X_test)} to {limit} samples")
        self.X_test = self.X_test[:limit]
        self.y_test = self.y_test[:limit]
# BaseDataset._rebalance_buffer method · python · L175-L248 (74 LOC)src/cl/datasets/base.py
def _rebalance_buffer(self, task_id: int, is_train: bool = True) -> None:
"""Rebalance experience buffer with task-weighted sampling.
Added by Claude: Implements balanced replay with recency weighting.
Allocation: 10% recent task, 80% older tasks (equal split), 10% random.
Args:
task_id: Current task identifier (most recent task)
is_train: True for training buffer, False for test buffer
"""
if is_train:
exp_x = self.exp_x_train
exp_y = self.exp_y_train
exp_task_ids = self.exp_task_ids_train
else:
exp_x = self.exp_x_test
exp_y = self.exp_y_test
exp_task_ids = self.exp_task_ids_test
if len(exp_x) <= self.len_exp_replay:
return # No rebalancing needed
n_tasks = task_id + 1
if not self.balanced_replay_enabled or n_tasks == 1:
# Disabled or first task: simple random sampling
BaseDataset.append_to_experience method · python · L250-L290 (41 LOC)src/cl/datasets/base.py
def append_to_experience(self, task_id: int) -> None:
"""Add current task data to the experience replay buffer.
Manages buffer size using task-balanced sampling when enabled.
Args:
task_id: Current task identifier
"""
# Convert to tensor if needed
X_train = self.X_train if torch.is_tensor(self.X_train) else torch.from_numpy(
np.array(self.X_train, dtype=np.float32))
X_test = self.X_test if torch.is_tensor(self.X_test) else torch.from_numpy(
np.array(self.X_test, dtype=np.float32))
y_train = self.y_train if isinstance(self.y_train, np.ndarray) else np.array(self.y_train)
y_test = self.y_test if isinstance(self.y_test, np.ndarray) else np.array(self.y_test)
# Create task ID arrays for new samples
new_task_ids_train = np.full(len(X_train), task_id, dtype=np.int8)
new_task_ids_test = np.full(len(X_test), task_id, dtype=np.int8)
if not self._expRepobility · code-quality intelligence · https://repobility.com
BaseDataset.get_task_data method · python · L292-L315 (24 LOC)src/cl/datasets/base.py
def get_task_data(self, task_id: int, phase: str) -> Tuple[Tuple, Tuple]:
    """Return the current-task data together with the experience data.

    The experience buffers are used only when ``task_id > 0`` and the buffers
    have been initialised; otherwise the current data doubles as the
    experience data.

    Args:
        task_id: Task identifier
        phase: 'training' selects the train split; any other value is
            handled as testing

    Returns:
        Tuple of ((current_x, current_y), (experience_x, experience_y))
    """
    if phase == 'training':
        current = (self.X_train, self.y_train)
        if task_id > 0 and self._exp_initialized:
            return current, (self.exp_x_train, self.exp_y_train)
        # Use current as experience for task 0 (or before the buffer exists).
        return current, (self.X_train, self.y_train)

    # testing
    current = (self.X_test, self.y_test)
    if task_id > 0 and self._exp_initialized:
        return current, (self.exp_x_test, self.exp_y_test)
    return current, (self.X_test, self.y_test)
# BaseDataset.generate_dataset method · python · L317-L364 (48 LOC)src/cl/datasets/base.py
def generate_dataset(self, task_id: int, batch_size: int, phase: str) -> Tuple[DataLoader, DataLoader]:
"""Generate DataLoaders for a specific task.
This is the main interface method called by runners.
Args:
task_id: Task identifier (0-indexed)
batch_size: Batch size for DataLoaders
phase: 'training' or 'testing'
Returns:
Tuple of (current_task_loader, experience_replay_loader)
"""
# Load task data if in training phase (test uses same task data)
if phase == 'training':
self.load_task(task_id)
# Get current and experience data
(x_curr, y_curr), (x_exp, y_exp) = self.get_task_data(task_id, phase)
# Create datasets
dataset_curr = ContinualDataset(self.config, x_curr, y_curr)
dataset_exp = ContinualDataset(self.config, x_exp, y_exp)
# Create DataLoaders
# Note: num_workers=0 required when using JAX because os.fBaseDataset.generate_test_loader method · python · L366-L396 (31 LOC)src/cl/datasets/base.py
def generate_test_loader(self, task_id: int, batch_size: int = None) -> DataLoader:
"""Generate test loader for a specific task (for CL metrics evaluation).
Added by Claude: This method enables per-task evaluation needed for computing
the performance matrix A[j][i] = accuracy on task i after training task j.
Args:
task_id: Task ID to generate test loader for
batch_size: Batch size for DataLoader (uses self.batch_size if None)
Returns:
DataLoader for task-specific test data
"""
if batch_size is None:
batch_size = self.batch_size
# Load task data (populates self.X_test, self.y_test)
self.load_task(task_id)
# Create dataset and loader for test data
dataset_test = ContinualDataset(self.config, self.X_test, self.y_test)
loader_kwargs = {
'batch_size': batch_size,
'shuffle': False, # No shuffle for evaluation
BaseDataset.get_model_config method · python · L398-L408 (11 LOC)src/cl/datasets/base.py
def get_model_config(self) -> Dict[str, Any]:
    """Collect the values needed to initialise a model for this dataset.

    Returns:
        Dictionary with the keys 'input_size', 'output_size' and 'n_tasks',
        each read from the attribute of the same name.
    """
    exported = ('input_size', 'output_size', 'n_tasks')
    return {field: getattr(self, field) for field in exported}
# CIFAR10Dataset.__init__ method · python · L49-L79 (31 LOC)src/cl/datasets/cifar.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the CIFAR-10 dataset.
Args:
config: Configuration dictionary
"""
super().__init__(config)
self._n_tasks = config.get('n_task', 5)
self.rotation_range = config.get('rotation_range', DEFAULT_ROTATION_RANGE)
self.scaling_range = config.get('scaling_range', DEFAULT_SCALING_RANGE)
self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
# Load CIFAR-10 dataset
print("Loading CIFAR-10 dataset")
my_transforms = transforms.Compose([transforms.ToTensor()])
self.dataset = torchvision.datasets.CIFAR10(
'./data', train=True, download=True, transform=my_transforms
)
# Extract images and labels
[self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
self.images = torch.stack(self.images, dim=0)
self.labels = np.array(self.labels)
print(f"CICIFAR10Dataset._load_task_data method · python · L81-L119 (39 LOC)src/cl/datasets/cifar.py
def _load_task_data(self, task_id: int) -> None:
"""Load data for a specific CIFAR-10 task with transforms.
Applies rotation and scaling transforms based on task_id to create
distribution shift between tasks.
Args:
task_id: Task identifier (0-indexed)
"""
# Fixed by Claude: Set deterministic seed for reproducible task generation
# Critical for accurate CL metrics - ensures task data is identical during training and evaluation
np.random.seed(task_id * DEFAULT_PERMUTATION_SEED_MULTIPLIER)
X = self.images.clone()
y = self.labels.copy()
# Apply task-specific transformations
rot_angle = np.random.random() * self.rotation_range
scaling_min, scaling_max = self.scaling_range
scaling = np.random.random() * (scaling_max - scaling_min) + scaling_min
X = torchvision.transforms.functional.affine(
X, rot_angle,
translate=(scaling, scalinCIFAR100Dataset.__init__ method · python · L163-L194 (32 LOC)src/cl/datasets/cifar.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the CIFAR-100 dataset.
Args:
config: Configuration dictionary
"""
super().__init__(config)
self._n_tasks = config.get('n_task', 5)
self._n_classes = config.get('n_class', 100)
self.rotation_range = config.get('rotation_range', DEFAULT_ROTATION_RANGE)
self.scaling_range = config.get('scaling_range', DEFAULT_SCALING_RANGE)
self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
# Load CIFAR-100 dataset
print("Loading CIFAR-100 dataset")
my_transforms = transforms.Compose([transforms.ToTensor()])
self.dataset = torchvision.datasets.CIFAR100(
'./data', train=True, download=True, transform=my_transforms
)
# Extract images and labels
[self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
self.images = torch.stack(self.images, dim=0)
CIFAR100Dataset._load_task_data method · python · L196-L234 (39 LOC)src/cl/datasets/cifar.py
def _load_task_data(self, task_id: int) -> None:
"""Load data for a specific CIFAR-100 task with transforms.
Applies rotation and scaling transforms based on task_id to create
distribution shift between tasks.
Args:
task_id: Task identifier (0-indexed)
"""
# Fixed by Claude: Set deterministic seed for reproducible task generation
# Critical for accurate CL metrics - ensures task data is identical during training and evaluation
np.random.seed(task_id * DEFAULT_PERMUTATION_SEED_MULTIPLIER)
X = self.images.clone()
y = self.labels.copy()
# Apply task-specific transformations
rot_angle = np.random.random() * self.rotation_range
scaling_min, scaling_max = self.scaling_range
scaling = np.random.random() * (scaling_max - scaling_min) + scaling_min
X = torchvision.transforms.functional.affine(
X, rot_angle,
translate=(scaling, scaliRepobility (the analyzer behind this table) · https://repobility.com
PrefetchDataLoader.__init__ method · python · L43-L56 (14 LOC)src/cl/datasets/jax_dataloader.py
def __init__(self, dataloader, prefetch_size: int = 2, device=None, loss_type='classification'):
    """Store the wrapped loader and resolve the JAX device used for batches.

    Args:
        dataloader: Iterable of batches to wrap; stored unchanged.
        prefetch_size: Stored unchanged; ``__iter__`` uses it as the ``maxsize``
            of its batch queue.
        device: JAX device to use. When None, the first entry of
            ``jax.devices('gpu')`` is selected.
        loss_type: Stored on the instance (original note: kept for dtype conversion).

    Raises:
        RuntimeError: If ``device`` is None and no GPU device can be obtained.
    """
    self.dataloader = dataloader
    self.prefetch_size = prefetch_size
    # Explicitly use GPU (respects CUDA_VISIBLE_DEVICES from parallel scripts)
    if device is None:
        no_gpu_message = "No GPU found. JAX continual learning requires GPU."
        try:
            gpu_devices = jax.devices('gpu')
        except RuntimeError as err:
            # jax.devices('gpu') raises RuntimeError on its own when no GPU backend
            # is available, so the emptiness check below is never reached in that
            # case. Re-raise with the intended message and keep the JAX error
            # chained as the cause.
            raise RuntimeError(no_gpu_message) from err
        if not gpu_devices:
            raise RuntimeError(no_gpu_message)
        device = gpu_devices[0]  # Will be the GPU assigned by CUDA_VISIBLE_DEVICES
    self.device = device
    self.loss_type = loss_type  # Stored for dtype conversion
    # Diagnostic logging for device detection
    print(f"[DEBUG] PrefetchDataLoader initialized with device: {self.device}")
    print(f"[DEBUG] All available JAX devices: {jax.devices()}")
# PrefetchDataLoader.__iter__ method · python · L58-L123 (66 LOC)src/cl/datasets/jax_dataloader.py
def __iter__(self) -> Iterator[Tuple[jnp.ndarray, jnp.ndarray]]:
"""Iterate over batches with GPU prefetching."""
# Queue to hold prefetched GPU batches
batch_queue = queue.Queue(maxsize=self.prefetch_size)
# Exception storage for background thread
exception_storage = [None]
def prefetch_worker():
"""Background thread that loads and transfers batches to GPU."""
try:
for batch in self.dataloader:
# Unpack batch (handle both 2-tuple and 3-tuple formats)
if len(batch) == 2:
x, y = batch
elif len(batch) == 3:
x, y, _ = batch # Ignore task_id or other metadata
else:
x, y = batch[0], batch[1]
# Convert PyTorch tensors to numpy
if hasattr(x, 'numpy'):
x = x.numpy()
wrap_dataloader function · python · L162-L187 (26 LOC)src/cl/datasets/jax_dataloader.py
def wrap_dataloader(loader, prefetch_size: int = 2, device=None):
    """Pick and build the prefetching wrapper that matches ``loader``.

    A tuple of exactly two elements is treated as a (current, experience) loader
    pair and wrapped in ``DualPrefetchDataLoader``. Anything else is wrapped in
    ``PrefetchDataLoader``.

    Args:
        loader: A single DataLoader, or a 2-tuple of DataLoaders for continual
            learning.
        prefetch_size: Forwarded unchanged to the chosen wrapper (default: 2).
        device: Forwarded unchanged to the chosen wrapper (default: None).

    Returns:
        PrefetchDataLoader or DualPrefetchDataLoader

    Example:
        >>> # Single loader
        >>> fast_loader = wrap_dataloader(pytorch_loader, prefetch_size=3)
        >>> # Continual learning (current + experience)
        >>> fast_loader = wrap_dataloader((train_loader, exp_loader))
    """
    is_loader_pair = isinstance(loader, tuple) and len(loader) == 2
    if not is_loader_pair:
        # Single loader
        return PrefetchDataLoader(loader, prefetch_size, device)
    # Continual learning: dual loader
    current_loader = loader[0]
    experience_loader = loader[1]
    return DualPrefetchDataLoader(current_loader, experience_loader, prefetch_size, device)