Function bodies 268 total
AWBOperations.restore_weights method · python · L218-L236 (19 LOC)src/cl/core/awb_operations.py
def restore_weights(
    self,
    model: eqx.Module,
    saved_weights: Any
) -> eqx.Module:
    """Restore model weights after architecture search.

    After architecture search, we restore the original weights before
    initializing A/B matrices. This ensures we're transforming the trained
    weights, not random weights.

    Args:
        model: Model instance (possibly with different architecture)
        saved_weights: Weights saved by save_weights()

    Returns:
        Model with restored weights
    """
    # NOTE(review): the body visible here is a bare `pass`, so as written this
    # returns None rather than a model. Presumably an interface stub that
    # concrete implementations override -- confirm against the enclosing class.
    pass


# create_balanced_validation_set function · python · L54-L154 (101 LOC)src/cl/core/awb_pipeline.py
def create_balanced_validation_set(loader, validation_ratio=0.2, batch_size=64):
"""Create a balanced validation set from a data loader.
Samples validation_ratio% of data from each class to ensure balanced representation.
This is used for AWB architecture search to avoid using full training data.
Handles both vector datasets (MNIST, CIFAR) and graph datasets (synthetic graphs).
Args:
loader: PyTorch DataLoader or torch_geometric DataLoader to sample from
validation_ratio: Fraction of data to use for validation (default 0.2 = 20%)
batch_size: Batch size for validation loader
Returns:
DataLoader with balanced validation set
"""
# Detect if this is a graph dataset or vector dataset
# Check the first batch to determine loader type
first_batch = next(iter(loader))
is_graph = isinstance(first_batch, Batch)
if is_graph:
# Handle graph datasets (torch_geometric)
all_graphs = []
for batccompute_avg_loss function · python · L36-L80 (45 LOC)src/cl/core/awb.py
def compute_avg_loss(record_dict, task_id, epochs, window=None):
"""Compute average loss over last `window` epochs from training record.
Args:
record_dict: Dictionary containing training records (iterations dict)
task_id: Current task ID
epochs: Total epochs per task
window: Number of epochs to average (default: DEFAULT_AWB_AVERAGING_WINDOW)
Returns:
Average loss value over the specified window
"""
if window is None:
window = DEFAULT_AWB_AVERAGING_WINDOW
losses = []
# Handle both old format (train{epoch}) and new format (iterations dict)
if isinstance(record_dict, dict):
# New format: record_dict has 'iterations' key with iteration numbers as keys
for j in range(1, window + 1):
iteration = (task_id + 1) * epochs - j
if iteration in record_dict:
record = record_dict[iteration]
if isinstance(record, dict) and 'losses' in record:
should_change_arch function · python · L83-L107 (25 LOC)src/cl/core/awb.py
def should_change_arch(trainWLoss, end_last,
                       threshold_high=None, min_delta=None):
    """Report whether the loss ratio calls for an architecture change.

    The ratio ``trainWLoss / end_last`` is compared against the high
    threshold; a change is requested only when the ratio is strictly greater.

    Args:
        trainWLoss: Training loss measured after preliminary training.
        end_last: Loss at the end of the previous task.
        threshold_high: Ratio above which a change is requested. ``None``
            selects ``DEFAULT_AWB_CHANGE_THRESHOLD_HIGH``.
        min_delta: Ignored; accepted only for backward compatibility.

    Returns:
        The result of ``ratio > threshold_high``.
    """
    limit = (DEFAULT_AWB_CHANGE_THRESHOLD_HIGH
             if threshold_high is None else threshold_high)
    # Only the ratio drives the decision; min_delta plays no part.
    # NOTE(review): end_last is used as a divisor without a zero check.
    loss_ratio = trainWLoss / end_last
    return loss_ratio > limit


# compute_ab_threshold function · python · L110-L138 (29 LOC)src/cl/core/awb.py
def compute_ab_threshold(trainWLoss, end_last, base_threshold=None):
    """Compute the dynamic convergence threshold for A/B training.

    The threshold is chosen from the loss ratio ``trainWLoss / end_last``
    (taken as 1.0 when ``end_last`` is not positive):

    - ratio >= 3:       ``max(1 / ratio, 0.45)``
    - 2 <= ratio < 3:   ``min(1 / ratio, 0.6)``
    - 1 <= ratio < 2:   ``min(1 / ratio, 0.75)``
    - otherwise:        ``0.8``

    Args:
        trainWLoss: Current training loss after preliminary training.
        end_last: Loss at the end of the previous task.
        base_threshold: Accepted for backward compatibility; it does not
            influence the result.

    Returns:
        Threshold for A/B training convergence.
    """
    ratio = trainWLoss / end_last if end_last > 0 else 1.0

    # The top band is closed at 3.0. Previously a ratio of exactly 3.0 matched
    # neither ``> 3.0`` nor ``< 3.0`` and fell through to the low-ratio
    # default of 0.8.
    if ratio >= 3.0:
        return max(1 / ratio, 0.45)
    if ratio >= 2.0:
        return min(1 / ratio, 0.6)
    if ratio >= 1.0:
        return min(1 / ratio, 0.75)
    return 0.8


# apply_V_transformation function · python · L142-L162 (21 LOC)src/cl/core/awb.py
def apply_V_transformation(model):
    """Apply the V = A @ W @ B.T transformation through the model's interface.

    A model exposing ``apply_V_transformation()`` handles the transformation
    itself. Otherwise the older model-specific helpers are used.

    Args:
        model: AWB-enabled model.

    Returns:
        Whatever the selected transformation routine returns (the model with
        transformed weights).
    """
    if hasattr(model, 'apply_V_transformation'):
        return model.apply_V_transformation()

    # Backward compatibility: MLP instances take the MLP helper, anything
    # else is routed to the GCN helper.
    from ..models.mlp import MLP
    legacy_transform = (compute_V_from_AWB if isinstance(model, MLP)
                        else compute_V_from_AWB_gcn)
    return legacy_transform(model)


# partition_model_for_AB_training function · python · L165-L178 (14 LOC)src/cl/core/awb.py
def partition_model_for_AB_training(model):
    """Partition a model for A/B training through the model's interface.

    Args:
        model: AWB-enabled model.

    Returns:
        Tuple of (diff_model, static_model).
    """
    if not hasattr(model, 'partition_for_AB_training'):
        # Backward compatibility: fall back to the module-level helper.
        return partition_for_AB_training(model)
    return model.partition_for_AB_training()


# Generated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
partition_model_for_standard_training function · python · L181-L194 (14 LOC)src/cl/core/awb.py
def partition_model_for_standard_training(model):
    """Partition a model for standard training through the model's interface.

    Args:
        model: AWB-enabled model.

    Returns:
        Tuple of (params, static).
    """
    if not hasattr(model, 'partition_for_standard_training'):
        # Backward compatibility: fall back to the module-level helper.
        return partition_for_standard_training(model)
    return model.partition_for_standard_training()


# initialize_AB_matrices function · python · L197-L213 (17 LOC)src/cl/core/awb.py
def initialize_AB_matrices(model, original_arch, new_arch, seed=5):
    """Initialize A/B matrices through the model's interface.

    Args:
        model: AWB-enabled model.
        original_arch: Original architecture specification.
        new_arch: New architecture specification.
        seed: Random seed, forwarded unchanged.

    Returns:
        Model with initialized A/B matrices.
    """
    if not hasattr(model, 'with_new_AB_matrices'):
        # Backward compatibility: fall back to the module-level helper.
        return set_new_AB_matrices(model, original_arch, new_arch, seed)
    return model.with_new_AB_matrices(original_arch, new_arch, seed)


# _create_identity_like_matrix function · python · L221-L244 (24 LOC)src/cl/core/awb.py
def _create_identity_like_matrix(new_size, old_size, key=None):
    """Build a rectangular identity matrix used as an AWB transformation.

    The result has ones on the main diagonal over the first
    ``min(new_size, old_size)`` positions and zeros everywhere else, so the
    overlapping block acts as an identity.

    Args:
        new_size: Target dimension (number of rows).
        old_size: Source dimension (number of columns).
        key: Unused; kept for API compatibility.

    Returns:
        JAX array of shape (new_size, old_size).
    """
    # jnp.eye(N, M) yields exactly the zero matrix with a unit main diagonal.
    return jnp.eye(new_size, old_size)


# set_new_AB_matrices function · python · L247-L287 (41 LOC)src/cl/core/awb.py
def set_new_AB_matrices(model, original_arch, new_arch, seed=5):
"""Initialize A/B matrices for architecture transition (MLP only).
DEPRECATED: Use initialize_AB_matrices() or model.with_new_AB_matrices() instead.
When the architecture changes from original_arch to new_arch,
we create transformation matrices A and B such that
the forward pass becomes: A @ W @ B.T
Uses identity-like initialization so that A @ W @ B^T ≈ W initially,
preserving learned weights in the overlap region for smooth transfer.
Args:
model: Current equinox model (MLP)
original_arch: Original architecture sizes list [in, h1, h2, ..., out]
new_arch: New architecture sizes list [in, h1', h2', ..., out]
seed: Random seed for initializer (unused, kept for API compatibility)
Returns:
Updated model with new A, B matrices and sizes
"""
# A matrices: transform output dimensions [new_out, old_out]
# Identity-like: A @ old_output ≈ ocompute_V_from_AWB function · python · L290-L314 (25 LOC)src/cl/core/awb.py
def compute_V_from_AWB(model):
    """Replace every layer's parameters with their A/B-transformed versions (MLP only).

    DEPRECATED: Use apply_V_transformation() or model.apply_V_transformation() instead.

    For each of the ``len(model.sizes) - 1`` layers the weight becomes
    ``A @ W @ B.T`` and the bias becomes ``bias @ A.T``.

    Args:
        model: Equinox MLP model with trained A, B matrices.

    Returns:
        Updated model with weights set to V = A @ W @ B.T.
    """
    layer_count = len(model.sizes) - 1
    for idx in range(layer_count):
        layer = model.layers[idx]
        a_mat = model.A[idx]
        new_weight = a_mat @ layer.weight @ jnp.transpose(model.B[idx])
        new_bias = layer.bias @ a_mat.T
        # Both replacements were computed from the pre-update layer, so they
        # can be written back in a single tree_at call.
        model = eqx.tree_at(
            lambda m, i=idx: (m.layers[i].weight, m.layers[i].bias),
            model,
            (new_weight, new_bias),
        )
    return model


# partition_for_AB_training function · python · L317-L340 (24 LOC)src/cl/core/awb.py
def partition_for_AB_training(model):
    """Split an MLP so that only A and B are trainable (W frozen).

    DEPRECATED: Use partition_model_for_AB_training() or model.partition_for_AB_training() instead.

    A filter spec that is False everywhere is built first, then the A and B
    entries are switched to True before partitioning.

    Args:
        model: Equinox MLP model.

    Returns:
        Tuple of (diff_model, static_model) where:
        - diff_model: Contains only A and B (trainable)
        - static_model: Contains everything else (frozen)
    """
    everything_frozen = jtu.tree_map(lambda _leaf: False, model)
    trainable_spec = eqx.tree_at(
        lambda m: (m.A, m.B),
        everything_frozen,
        replace=(True, True),
    )
    trainable_part, frozen_part = eqx.partition(model, trainable_spec)
    return trainable_part, frozen_part


# partition_for_standard_training function · python · L343-L375 (33 LOC)src/cl/core/awb.py
def partition_for_standard_training(model):
    """Split an MLP so that layer arrays are trainable and A/B are frozen.

    DEPRECATED: Use partition_model_for_standard_training() or model.partition_for_standard_training() instead.

    The model is first partitioned on ``eqx.is_array``. A and B are then
    written into the static half and replaced by ``None`` in the params half.

    Args:
        model: Equinox MLP model.

    Returns:
        Tuple of (params, static) where:
        - params: Contains trainable arrays (weights, biases)
        - static: Contains A, B matrices (frozen)
    """
    def _ab_nodes(tree):
        return tree.A, tree.B

    trainable, frozen = eqx.partition(model, eqx.is_array)
    # Drop A and B from the trainable half ...
    trainable = eqx.tree_at(_ab_nodes, trainable, replace=(None, None))
    # ... and carry the real matrices in the frozen half instead.
    frozen = eqx.tree_at(_ab_nodes, frozen, replace=(model.A, model.B))
    return trainable, frozen


# partition_for_AB_training_cnn function · python · L378-L394 (17 LOC)src/cl/core/awb.py
def partition_for_AB_training_cnn(model):
    """Split a CNN so that only its A/B matrices are trainable (W frozen).

    Args:
        model: Equinox CNN model with A_conv, B_conv, A_feed, B_feed.

    Returns:
        Tuple of (diff_model, static_model).
    """
    def _ab_nodes(tree):
        return tree.A_conv, tree.B_conv, tree.A_feed, tree.B_feed

    # Start from an all-False spec and enable just the four A/B fields.
    everything_frozen = jtu.tree_map(lambda _leaf: False, model)
    trainable_spec = eqx.tree_at(
        _ab_nodes, everything_frozen, replace=(True,) * 4
    )
    trainable_part, frozen_part = eqx.partition(model, trainable_spec)
    return trainable_part, frozen_part


# Source: Repobility analyzer · https://repobility.com
partition_for_standard_training_cnn function · python · L397-L433 (37 LOC)src/cl/core/awb.py
def partition_for_standard_training_cnn(model):
"""Partition CNN model for standard training (freeze A/B, train W).
A_conv and B_conv are now stacked 3D arrays (not lists).
A_feed and B_feed remain lists.
Uses eqx.is_array to properly separate arrays from non-arrays (ints, etc.),
then moves A/B matrices to static for freezing.
Args:
model: Equinox CNN model
Returns:
Tuple of (params, static)
"""
# Fixed by Claude: Use eqx.is_array to properly handle non-array leaves (ints)
# Previous approach with tree_map(lambda _: True, model) put ints in params
# which caused jax.grad to fail with "int64 not supported" error
params, static = eqx.partition(model, eqx.is_array)
# Move A/B matrices from params to static (freeze them)
# Use is_leaf=lambda x: x is None to handle None values in static
static = eqx.tree_at(
lambda x: (x.A_conv, x.B_conv, x.A_feed, x.B_feed),
static,
replace=(model.A_convpartition_for_AB_training_cnn3d function · python · L437-L453 (17 LOC)src/cl/core/awb.py
def partition_for_AB_training_cnn3d(model):
    """Split a CNN3D so that only its A/B matrices are trainable (W frozen).

    Args:
        model: Equinox CNN3D model with A_conv1, B_conv1, A_conv2, B_conv2,
            A_feed, B_feed.

    Returns:
        Tuple of (diff_model, static_model).
    """
    def _ab_nodes(tree):
        return (tree.A_conv1, tree.B_conv1,
                tree.A_conv2, tree.B_conv2,
                tree.A_feed, tree.B_feed)

    # Start from an all-False spec and enable just the six A/B fields.
    everything_frozen = jtu.tree_map(lambda _leaf: False, model)
    trainable_spec = eqx.tree_at(
        _ab_nodes, everything_frozen, replace=(True,) * 6
    )
    trainable_part, frozen_part = eqx.partition(model, trainable_spec)
    return trainable_part, frozen_part


# partition_for_standard_training_cnn3d function · python · L456-L493 (38 LOC)src/cl/core/awb.py
def partition_for_standard_training_cnn3d(model):
"""Partition CNN3D model for standard training (freeze A/B, train W).
A_conv1, B_conv1, A_conv2, B_conv2 are now stacked 4D arrays (not nested lists).
A_feed, B_feed remain lists of arrays.
Uses eqx.is_array to properly separate arrays from non-arrays (ints, etc.),
then moves A/B matrices to static for freezing.
Args:
model: Equinox CNN3D model
Returns:
Tuple of (params, static)
"""
# Fixed by Claude: Use eqx.is_array to properly handle non-array leaves (ints)
# Previous approach with tree_map(lambda _: True, model) put ints in params
# which caused jax.grad to fail with "int64 not supported" error
params, static = eqx.partition(model, eqx.is_array)
# Move A/B matrices from params to static (freeze them)
# Use is_leaf=lambda x: x is None to handle None values in static
static = eqx.tree_at(
lambda x: (x.A_conv1, x.B_conv1, x.A_conv2, x.B_conv2, x.A_cpartition_for_AB_training_gnn function · python · L496-L512 (17 LOC)src/cl/core/awb.py
def partition_for_AB_training_gnn(model):
    """Split a GNN so that only its A/B matrices are trainable (W frozen).

    Args:
        model: Equinox GNN model with A_gcn, B_gcn, A_feed, B_feed.

    Returns:
        Tuple of (diff_model, static_model).
    """
    def _ab_nodes(tree):
        return tree.A_gcn, tree.B_gcn, tree.A_feed, tree.B_feed

    # Start from an all-False spec and enable just the four A/B fields.
    everything_frozen = jtu.tree_map(lambda _leaf: False, model)
    trainable_spec = eqx.tree_at(
        _ab_nodes, everything_frozen, replace=(True,) * 4
    )
    trainable_part, frozen_part = eqx.partition(model, trainable_spec)
    return trainable_part, frozen_part


# partition_for_standard_training_gnn function · python · L515-L538 (24 LOC)src/cl/core/awb.py
def partition_for_standard_training_gnn(model):
    """Split a GNN so that layer arrays are trainable and A/B are frozen.

    The model is partitioned on ``eqx.is_array``. The four A/B fields are then
    written into the static half and replaced by ``None`` in the params half.

    Args:
        model: Equinox GNN model.

    Returns:
        Tuple of (params, static).
    """
    def _ab_nodes(tree):
        return tree.A_gcn, tree.B_gcn, tree.A_feed, tree.B_feed

    trainable, frozen = eqx.partition(model, eqx.is_array)
    # Drop A/B from the trainable half ...
    trainable = eqx.tree_at(_ab_nodes, trainable, replace=(None,) * 4)
    # ... and carry the real matrices in the frozen half instead.
    frozen = eqx.tree_at(
        _ab_nodes,
        frozen,
        replace=(model.A_gcn, model.B_gcn, model.A_feed, model.B_feed),
    )
    return trainable, frozen


# save_layer_weights function · python · L541-L554 (14 LOC)src/cl/core/awb.py
def save_layer_weights(model):
    """Collect every layer's weight and bias from ``model.layers``.

    Intended for taking a snapshot before architecture search so the values
    can be put back afterwards.

    Args:
        model: Equinox MLP model.

    Returns:
        Tuple of (weight_list, bias_list), both in layer order.
    """
    layers = model.layers
    saved_weights = [layer.weight for layer in layers]
    saved_biases = [layer.bias for layer in layers]
    return saved_weights, saved_biases


# restore_layer_weights function · python · L557-L571 (15 LOC)src/cl/core/awb.py
def restore_layer_weights(model, weight_list, bias_list):
    """Write saved weights and biases back into ``model.layers``.

    Entry ``k`` of each list replaces the weight / bias of layer ``k``.

    Args:
        model: Equinox MLP model.
        weight_list: List of weight matrices.
        bias_list: List of bias vectors (indexed in step with weight_list).

    Returns:
        Model with restored weights.
    """
    for idx, saved_weight in enumerate(weight_list):
        model = eqx.tree_at(lambda m, i=idx: m.layers[i].weight, model, saved_weight)
        model = eqx.tree_at(lambda m, i=idx: m.layers[i].bias, model, bias_list[idx])
    return model


# compute_V_from_AWB_gcn function · python · L574-L607 (34 LOC)src/cl/core/awb.py
def compute_V_from_AWB_gcn(model):
"""Compute new weights V = A @ W @ B.T for GCN model.
This is STEP 4 of the AWB algorithm for GCN: after training A and B matrices,
we compute the effective weights V and update the model to use them.
Based on train_model_graph from run_AWB_ALL_functions.py.
Args:
model: Equinox GCN model with trained A_gcn, B_gcn, A_feed, B_feed matrices
Returns:
Updated model with weights set to V = A @ W @ B.T
"""
# Transform GCN layer weights: V = A @ W @ B.T
for k in range(len(model.gcn_layers)):
Vw = model.A_gcn[k] @ model.gcn_layers[k].weight @ jnp.transpose(model.B_gcn[k])
Vb = model.gcn_layers[k].bias @ model.B_gcn[k].T
model = eqx.tree_at(lambda x, idx=k: x.gcn_layers[idx].weight, model, Vw)
model = eqx.tree_at(lambda x, idx=k: x.gcn_layers[idx].bias, model, Vb)
# Transform feed layer weights: V = (A @ W.T @ B.T).T = B @ W @ A.T
# Note: feed_layers use Linear3 wRepobility's GitHub App fixes findings like these · https://github.com/apps/repobility-bot
save_gcn_layer_weights function · python · L610-L623 (14 LOC)src/cl/core/awb.py
def save_gcn_layer_weights(model):
    """Collect weights and biases of the GCN layers and the feed layers.

    Args:
        model: Equinox GCN model.

    Returns:
        Tuple of (gcn_weights, gcn_biases, mlp_weights, mlp_biases), each a
        list in layer order.
    """
    gcn, feed = model.gcn_layers, model.feed_layers
    return (
        [layer.weight for layer in gcn],
        [layer.bias for layer in gcn],
        [layer.weight for layer in feed],
        [layer.bias for layer in feed],
    )


# restore_gcn_layer_weights function · python · L626-L645 (20 LOC)src/cl/core/awb.py
def restore_gcn_layer_weights(model, gcn_weights, gcn_biases, mlp_weights, mlp_biases):
    """Write saved GCN and feed-layer parameters back into the model.

    Args:
        model: Equinox GCN model.
        gcn_weights: List of GCN weight matrices.
        gcn_biases: List of GCN bias vectors (indexed in step with gcn_weights).
        mlp_weights: List of MLP weight matrices.
        mlp_biases: List of MLP bias vectors (indexed in step with mlp_weights).

    Returns:
        Model with restored weights.
    """
    # GCN layers first, then the feed (MLP) layers.
    for pos, saved_weight in enumerate(gcn_weights):
        model = eqx.tree_at(lambda m, i=pos: m.gcn_layers[i].weight, model, saved_weight)
        model = eqx.tree_at(lambda m, i=pos: m.gcn_layers[i].bias, model, gcn_biases[pos])
    for pos, saved_weight in enumerate(mlp_weights):
        model = eqx.tree_at(lambda m, i=pos: m.feed_layers[i].weight, model, saved_weight)
        model = eqx.tree_at(lambda m, i=pos: m.feed_layers[i].bias, model, mlp_biases[pos])
    return model


# create_optimizer_for_phase function · python · L648-L663 (16 LOC)src/cl/core/awb.py
def create_optimizer_for_phase(phase, learning_rate=1e-4):
    """Build the Adam optimizer used for a given training phase.

    Args:
        phase: One of 'standard', 'ab_training', 'v_training'. Any value other
            than 'v_training' is handled like 'standard'.
        learning_rate: Learning rate for the optimizer. Not used for
            'v_training', which always runs at 1e-3.

    Returns:
        Optax optimizer.
    """
    # V training uses a fixed, higher learning rate; every other phase uses
    # the caller's value.
    step_size = 1e-3 if phase == 'v_training' else learning_rate
    return optax.adam(step_size)


# save_cnn_layer_weights function · python · L667-L679 (13 LOC)src/cl/core/awb.py
def save_cnn_layer_weights(model):
    """Collect conv-layer weights plus feed-layer weights and biases.

    Conv-layer biases are not collected.

    Args:
        model: Equinox CNN model.

    Returns:
        Tuple of (conv_weights, feed_weights, feed_biases), each a list in
        layer order.
    """
    feed = model.feed_layers
    return (
        [layer.weight for layer in model.conv_layers],
        [layer.weight for layer in feed],
        [layer.bias for layer in feed],
    )


# restore_cnn_layer_weights function · python · L682-L699 (18 LOC)src/cl/core/awb.py
def restore_cnn_layer_weights(model, conv_weights, feed_weights, feed_biases):
    """Write saved conv weights and feed-layer parameters back into the model.

    Args:
        model: Equinox CNN model.
        conv_weights: List of conv weight matrices.
        feed_weights: List of feed weight matrices.
        feed_biases: List of feed bias vectors (indexed in step with feed_weights).

    Returns:
        Model with restored weights.
    """
    # Conv layers only have their weights restored.
    for pos, saved_weight in enumerate(conv_weights):
        model = eqx.tree_at(lambda m, i=pos: m.conv_layers[i].weight, model, saved_weight)
    for pos, saved_weight in enumerate(feed_weights):
        model = eqx.tree_at(lambda m, i=pos: m.feed_layers[i].weight, model, saved_weight)
        model = eqx.tree_at(lambda m, i=pos: m.feed_layers[i].bias, model, feed_biases[pos])
    return model


# compute_V_from_AWB_cnn function · python · L702-L739 (38 LOC)src/cl/core/awb.py
def compute_V_from_AWB_cnn(model):
"""Compute new weights V = A @ W @ B.T for CNN model.
This is STEP 4 of the AWB algorithm for CNN: after training A and B matrices,
we compute the effective weights V and update the model to use them.
A_conv and B_conv are stacked 3D arrays with shape:
(channel_out, new_filter_size, filter_size)
This uses einsum-based awb_transform for efficient batched computation.
Args:
model: Equinox CNN model with trained A_conv, B_conv, A_feed, B_feed matrices
Returns:
Updated model with weights set to V = A @ W @ B.T
"""
import jax.numpy as jnp
from ..models.layers import awb_transform
# Transform conv layer weights using einsum
# A_conv: (out_ch, new_f, old_f), W[:, 0]: (out_ch, H, W), B_conv: (out_ch, new_f, old_f)
# For single input channel (MNIST): weight[:, 0] is (out_ch, H, W)
conv_weights = model.conv_layers[0].weight[:, 0] # (channel_out, H, W)
new_conv_weights = awb_tcompute_V_from_AWB_cnn3d function · python · L742-L779 (38 LOC)src/cl/core/awb.py
def compute_V_from_AWB_cnn3d(model):
"""Compute new weights V = A @ W @ B.T for CNN3D model.
This is STEP 4 of the AWB algorithm for CNN3D: after training A and B matrices,
we compute the effective weights V and update the model to use them.
A_conv1/2 and B_conv1/2 are stacked 4D arrays with shape:
(channel_out, channel_in, new_filter_size, filter_size)
This uses einsum-based awb_transform for efficient batched computation.
Args:
model: Equinox CNN3D model with trained A_conv1, B_conv1, A_conv2, B_conv2, A_feed, B_feed
Returns:
Updated model with weights set to V = A @ W @ B.T
"""
import jax.numpy as jnp
from ..models.layers import awb_transform
# Transform conv layer 1 weights using einsum
# A_conv1: (out, in, new_f, old_f), W: (out, in, H, W), B_conv1: (out, in, new_f, old_f)
# Result V: (out, in, new_f, new_f)
new_conv1_weights = awb_transform(model.A_conv1, model.conv_layers[0].weight, model.B_conv1)
set_new_AB_matrices_cnn function · python · L782-L807 (26 LOC)src/cl/core/awb.py
def set_new_AB_matrices_cnn(model, original_feed_sizes, new_feed_sizes, original_filter, new_filter):
"""Set new A/B matrices for CNN (single conv layer) architecture transition.
Args:
model: CNN model (with old weights to be transformed)
original_feed_sizes: Original feed layer sizes
new_feed_sizes: New feed layer sizes
original_filter: Original filter size
new_filter: New filter size
Returns:
Model with updated A/B matrices (W_old preserved)
"""
from ..arch_search.cnn_search import prepABs
A_feed, B_feed, A_conv, B_conv = prepABs(model, original_feed_sizes, original_filter,
new_feed_sizes, new_filter)
model = eqx.tree_at(lambda x: x.A_feed, model, A_feed)
model = eqx.tree_at(lambda x: x.B_feed, model, B_feed)
model = eqx.tree_at(lambda x: x.A_conv, model, A_conv)
model = eqx.tree_at(lambda x: x.B_conv, model, B_conv)
model = eqx.tree_at(lambda Repobility analyzer · published findings · https://repobility.com
set_new_AB_matrices_cnn3d function · python · L810-L837 (28 LOC)src/cl/core/awb.py
def set_new_AB_matrices_cnn3d(model, original_feed_sizes, new_feed_sizes, original_filter, new_filter):
"""Set new A/B matrices for CNN3D (two conv layers) architecture transition.
Args:
model: CNN3D model (with old weights to be transformed)
original_feed_sizes: Original feed layer sizes
new_feed_sizes: New feed layer sizes
original_filter: Original filter size
new_filter: New filter size
Returns:
Model with updated A/B matrices (W_old preserved)
"""
from ..arch_search.cnn_search import prepABs_CNN3D
A_feed, B_feed, A_conv1, B_conv1, A_conv2, B_conv2 = prepABs_CNN3D(
model, original_feed_sizes, original_filter, new_feed_sizes, new_filter)
model = eqx.tree_at(lambda x: x.A_feed, model, A_feed)
model = eqx.tree_at(lambda x: x.B_feed, model, B_feed)
model = eqx.tree_at(lambda x: x.A_conv1, model, A_conv1)
model = eqx.tree_at(lambda x: x.B_conv1, model, B_conv1)
model = eqx.tree_at(lambset_new_AB_matrices_gcn function · python · L840-L868 (29 LOC)src/cl/core/awb.py
def set_new_AB_matrices_gcn(model, prev_gcn_sizes, prev_feed_sizes, opt_gcn, opt_mlp):
    """Install new layer sizes and matching A/B matrices on a GCN model.

    Args:
        model: GCN model.
        prev_gcn_sizes: Previous GCN layer sizes.
        prev_feed_sizes: Previous feed layer sizes.
        opt_gcn: Optimal GCN layer sizes (stored as ``gcn_sizes``).
        opt_mlp: Optimal MLP/feed layer sizes (stored as ``feed_sizes``).

    Returns:
        Model with updated A/B matrices and architecture.
    """
    from ..arch_search.gcn_search import prepABs_GCN

    def _ab_nodes(tree):
        return tree.A_feed, tree.B_feed, tree.A_gcn, tree.B_gcn

    # The sizes are recorded first: prepABs_GCN is handed the resized model
    # together with the previous sizes.
    resized = eqx.tree_at(lambda m: m.gcn_sizes, model, opt_gcn)
    resized = eqx.tree_at(lambda m: m.feed_sizes, resized, opt_mlp)

    feed_a, feed_b, gcn_a, gcn_b = prepABs_GCN(resized, prev_feed_sizes, prev_gcn_sizes)
    return eqx.tree_at(_ab_nodes, resized, replace=(feed_a, feed_b, gcn_a, gcn_b))


# _tree_norm function · python · L46-L56 (11 LOC)src/cl/core/hamiltonian.py
def _tree_norm(tree):
    """Return the L2 norm taken over every leaf of a pytree.

    Args:
        tree: PyTree of JAX arrays.

    Returns:
        Scalar: square root of the summed squares of all leaves.
    """
    squared_total = 0
    for leaf in jax.tree_util.tree_leaves(tree):
        squared_total = squared_total + jnp.sum(leaf ** 2)
    return jnp.sqrt(squared_total)


# _normalize_tree function · python · L59-L71 (13 LOC)src/cl/core/hamiltonian.py
def _normalize_tree(tree, eps=1e-8):
    """Scale a pytree towards unit L2 norm.

    Every leaf is divided by ``norm + eps``, where ``norm`` is the value
    returned by ``_tree_norm(tree)``.

    Args:
        tree: PyTree of JAX arrays.
        eps: Small constant added to the norm for numerical stability.

    Returns:
        Tuple of (normalized_tree, original_norm).
    """
    original_norm = _tree_norm(tree)
    divisor = original_norm + eps
    scaled = jax.tree_util.tree_map(lambda leaf: leaf / divisor, tree)
    return scaled, original_norm


# _loss_mse_standard function · python · L81-L98 (18 LOC)src/cl/core/hamiltonian.py
def _loss_mse_standard(params, static, x, y):
    """Mean ``optax.l2_loss`` for standard training.

    The model is rebuilt from ``params`` and ``static`` and mapped over the
    leading axis of ``x`` with ``jax.vmap``.

    Args:
        params: Trainable model parameters (PyTree).
        static: Frozen model components (PyTree).
        x: Input features (JAX array, shape [batch, input_dim]).
        y: Target values (JAX array, shape [batch]).

    Returns:
        Scalar loss value.
    """
    net = eqx.combine(params, static)
    predictions = jax.vmap(net)(x)
    # Drop a singleton middle axis: (batch, 1, output) -> (batch, output).
    has_singleton_axis = predictions.ndim == 3 and predictions.shape[1] == 1
    if has_singleton_axis:
        predictions = jnp.squeeze(predictions, axis=1)
    per_element = optax.l2_loss(y, predictions)
    return jnp.mean(per_element)


# _loss_mse_awb function · python · L102-L109 (8 LOC)src/cl/core/hamiltonian.py
def _loss_mse_awb(params, static, x, y):
    """Mean ``optax.l2_loss`` for AWB training (forward pass via model.getAWB).

    Args:
        params: Trainable model parameters (PyTree).
        static: Frozen model components (PyTree).
        x: Inputs, mapped over their leading axis with ``jax.vmap``.
        y: Target values.

    Returns:
        Scalar loss value.
    """
    net = eqx.combine(params, static)
    predictions = jax.vmap(net.getAWB)(x)
    # Drop a singleton middle axis: (batch, 1, output) -> (batch, output).
    has_singleton_axis = predictions.ndim == 3 and predictions.shape[1] == 1
    if has_singleton_axis:
        predictions = jnp.squeeze(predictions, axis=1)
    per_element = optax.l2_loss(y, predictions)
    return jnp.mean(per_element)


# _loss_class_standard function · python · L113-L127 (15 LOC)src/cl/core/hamiltonian.py
def _loss_class_standard(params, static, x, y):
    """Mean softmax cross-entropy for standard classification training.

    The model is rebuilt from ``params`` and ``static`` and mapped over the
    leading axis of ``x`` with ``jax.vmap``. Its outputs are handed to Optax
    as raw logits.

    Args:
        params: Trainable model parameters (PyTree).
        static: Frozen model components (PyTree).
        x: Input features (JAX array, shape [batch, ...]).
        y: Target labels (JAX array, shape [batch], int64).

    Returns:
        Scalar loss value.
    """
    model = eqx.combine(params, static)
    logits = jax.vmap(model)(x)
    # softmax_cross_entropy_with_integer_labels normalizes the logits itself.
    # The explicit log_softmax that used to precede it was redundant, since
    # log_softmax is idempotent, so it has been dropped.
    per_example = optax.losses.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(per_example)


# _loss_class_awb function · python · L131-L136 (6 LOC)src/cl/core/hamiltonian.py
def _loss_class_awb(params, static, x, y):
    """Mean softmax cross-entropy for AWB training (forward pass via model.get_AWBT).

    Args:
        params: Trainable model parameters (PyTree).
        static: Frozen model components (PyTree).
        x: Inputs, mapped over their leading axis with ``jax.vmap``.
        y: Integer target labels.

    Returns:
        Scalar loss value.
    """
    model = eqx.combine(params, static)
    logits = jax.vmap(model.get_AWBT)(x)
    # softmax_cross_entropy_with_integer_labels normalizes the logits itself.
    # The explicit log_softmax that used to precede it was redundant, since
    # log_softmax is idempotent, so it has been dropped.
    per_example = optax.losses.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(per_example)


# Generated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
_loss_graph_standard function · python · L140-L157 (18 LOC)src/cl/core/hamiltonian.py
def _loss_graph_standard(params, static, x, adj, b, n, y):
    """Mean softmax cross-entropy for standard graph-classification training.

    The model is rebuilt from ``params`` and ``static`` and called once on
    the whole batch (no ``vmap``).

    Args:
        params: Trainable model parameters (PyTree).
        static: Frozen model components (PyTree).
        x: Node features (JAX array).
        adj: Adjacency matrix (JAX array).
        b: Batch assignment vector (JAX array).
        n: Node counts per graph (JAX array).
        y: Target labels (JAX array, int64).

    Returns:
        Scalar loss value.
    """
    net = eqx.combine(params, static)
    logits = net(x, adj, b, n)
    per_graph = optax.losses.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(per_graph)


# _hamiltonian_core_mse_standard function · python · L175-L262 (88 LOC)src/cl/core/hamiltonian.py
def _hamiltonian_core_mse_standard(params, static, x, y, exp_x, exp_y, deltax,
alpha, beta, gamma, sqrt_param_count, dV_scale):
"""JIT-compiled Hamiltonian core for MSE regression (standard training).
Computes: grad = alpha * delta_theta + beta * grad_V + gamma * grad_dV
Args:
params: Trainable parameters
static: Frozen model components
x, y: Current task batch
exp_x, exp_y: Experience replay batch
deltax: Input perturbation
alpha, beta, gamma: Gradient combination weights
sqrt_param_count: Pre-computed sqrt(param_count) for normalization
dV_scale: Additional dV scaling factor
Returns:
Tuple of (grad, (H, V, dV, dV_dtheta, dV_dx))
"""
# Loss function for current task (closed over y)
def loss_fn_curr(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model)(xx)
# Squeeze extra dimension if present: (batch, 1, output) -> (b_hamiltonian_core_mse_awb function · python · L266-L329 (64 LOC)src/cl/core/hamiltonian.py
def _hamiltonian_core_mse_awb(params, static, x, y, exp_x, exp_y, deltax,
alpha, beta, gamma, sqrt_param_count, dV_scale):
"""JIT-compiled Hamiltonian core for MSE regression (AWB training)."""
# Loss function for current task using AWB forward (closed over y)
def loss_fn_curr(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model.getAWB)(xx)
# Squeeze extra dimension if present: (batch, 1, output) -> (batch, output)
if pred.ndim == 3 and pred.shape[1] == 1:
pred = jnp.squeeze(pred, axis=1)
# Also squeeze final dimension for scalar outputs (batch, 1) -> (batch,)
if pred.ndim == 2 and pred.shape[1] == 1:
pred = jnp.squeeze(pred, axis=-1)
return jnp.mean(optax.l2_loss(y, pred))
# Loss function for experience data using AWB forward (closed over exp_y)
def loss_fn_exp(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model.getAWB)(xx_hamiltonian_core_class_standard function · python · L333-L383 (51 LOC)src/cl/core/hamiltonian.py
def _hamiltonian_core_class_standard(params, static, x, y, exp_x, exp_y, deltax,
alpha, beta, gamma, sqrt_param_count, dV_scale):
"""JIT-compiled Hamiltonian core for classification (standard training)."""
# Note: softmax_cross_entropy_with_integer_labels expects raw logits, not log_softmax
def loss_fn_curr(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model)(xx)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, y))
def loss_fn_exp(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model)(xx)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, exp_y))
# Compute delta_theta (gradient on current task)
delta_theta = jax.grad(loss_fn_curr)(params, x)
# Normalize wdot to unit magnitude (direction only)
wdot_unnorm = jax.tree_util.tree_map(lambda g: -g, delta_theta)
wdot, wdot_norm = _normalize_tree(wdot__hamiltonian_core_class_awb function · python · L387-L444 (58 LOC)src/cl/core/hamiltonian.py
def _hamiltonian_core_class_awb(params, static, x, y, exp_x, exp_y, deltax,
alpha, beta, gamma, sqrt_param_count, dV_scale):
"""JIT-compiled Hamiltonian core for classification (AWB training).
Notes:
- Partition functions properly exclude non-arrays (ints) from params
- AWB forward pass (get_AWBT) uses einsum-based transforms for efficiency
JIT compilation completes in <1s and provides 4.4x speedup.
"""
# Note: softmax_cross_entropy_with_integer_labels expects raw logits, not log_softmax
def loss_fn_curr(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model.get_AWBT)(xx)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, y))
def loss_fn_exp(p, xx):
model = eqx.combine(p, static)
pred = jax.vmap(model.get_AWBT)(xx)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, exp_y))
# Compute delta_theta (gradient o_hamiltonian_core_graph_standard function · python · L448-L512 (65 LOC)src/cl/core/hamiltonian.py
def _hamiltonian_core_graph_standard(params, static, x, y, adj, b, n,
exp_x, exp_y, exp_adj, exp_b, exp_n,
deltax, delta_adj,
alpha, beta, gamma, sqrt_param_count, dV_scale):
"""JIT-compiled Hamiltonian core for graph classification (standard training)."""
def loss_fn_curr(p, xx, xxadj):
model = eqx.combine(p, static)
pred = model(xx, xxadj, b, n)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, y))
def loss_fn_exp(p, xx, xxadj):
model = eqx.combine(p, static)
pred = model(xx, xxadj, exp_b, exp_n)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, exp_y))
# Compute delta_theta (gradient on current task)
delta_theta = jax.grad(loss_fn_curr)(params, x, adj)
# Normalize wdot to unit magnitude (direction only)
# This makes dV_dθ measure alignment _hamiltonian_core_graph_awb function · python · L516-L575 (60 LOC)src/cl/core/hamiltonian.py
def _hamiltonian_core_graph_awb(params, static, x, y, adj, b, n,
exp_x, exp_y, exp_adj, exp_b, exp_n,
deltax, delta_adj,
alpha, beta, gamma, sqrt_param_count, dV_scale):
"""JIT-compiled Hamiltonian core for graph classification (AWB training)."""
def loss_fn_curr(p, xx, xxadj):
model = eqx.combine(p, static)
pred = model.get_AWBT(xx, xxadj, b, n)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, y))
def loss_fn_exp(p, xx, xxadj):
model = eqx.combine(p, static)
pred = model.get_AWBT(xx, xxadj, exp_b, exp_n)
return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(pred, exp_y))
# Compute delta_theta (gradient on current task)
delta_theta = jax.grad(loss_fn_curr)(params, x, adj)
# Normalize wdot to unit magnitude (direction only)
wdot_unnorm = jax.tree_util.tree_map(lambda _get_param_count_jax function · python · L582-L592 (11 LOC)src/cl/core/hamiltonian.py
def _get_param_count_jax(params):
    """Return the total element count of a parameter pytree as a JAX scalar.

    Every leaf produced by ``jax.tree_util.tree_leaves`` contributes its
    ``.size``; the integer total is then wrapped in a JAX array.

    Args:
        params: PyTree of parameters; each leaf must expose ``.size``.

    Returns:
        JAX scalar holding the total parameter count (``jnp.float64``
        requested).
    """
    # NOTE(review): float64 is presumably only honoured when JAX x64 mode is
    # enabled -- confirm the project turns it on.
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += leaf.size
    return jnp.array(total, dtype=jnp.float64)
HamiltonianMixin._count_parameters method · python · L606-L616 (11 LOC) · src/cl/core/hamiltonian.py
def _count_parameters(self, params):
    """Return how many scalar entries the parameter pytree holds.

    Args:
        params: PyTree of parameters; each leaf must expose ``.size``.

    Returns:
        Sum of the leaf sizes, converted to a Python float.
    """
    leaf_sizes = [leaf.size for leaf in jax.tree_util.tree_leaves(params)]
    return float(sum(leaf_sizes))

# HamiltonianMixin._get_sqrt_param_count method · python · L618-L629 (12 LOC) · src/cl/core/hamiltonian.py
def _get_sqrt_param_count(self, params):
    """Return sqrt(total parameter count) as a JAX scalar.

    The leaf sizes are summed in plain Python; only the final total is
    turned into a JAX array before the square root is taken.

    Args:
        params: PyTree of parameters; each leaf must expose ``.size``.

    Returns:
        JAX scalar equal to the square root of the summed leaf sizes
        (``jnp.float64`` requested).
    """
    n_params = 0
    for leaf in jax.tree_util.tree_leaves(params):
        n_params += leaf.size
    count_scalar = jnp.array(n_params, dtype=jnp.float64)
    return jnp.sqrt(count_scalar)

# HamiltonianMixin._normalize_dV method · python · L631-L650 (20 LOC) · src/cl/core/hamiltonian.py
def _normalize_dV(self, dV, params, normalize=True, scale_factor=1.0):
    """Scale dV, optionally dividing by sqrt(parameter count) first.

    Args:
        dV: Raw dV value.
        params: Model parameters; only read (via ``self._count_parameters``)
            when ``normalize`` is true.
        normalize: If false, skip the division and apply ``scale_factor`` only.
        scale_factor: Extra multiplicative factor applied in both cases.

    Returns:
        ``dV / sqrt(param_count) * scale_factor`` when normalizing,
        otherwise ``dV * scale_factor``.
    """
    if normalize:
        # Divide first, then scale -- same operation order as the unnormalized
        # path followed by the manual factor.
        n_params = self._count_parameters(params)
        return (dV / jnp.sqrt(n_params)) * scale_factor
    return dV * scale_factor