Function bodies 268 total
CNNAWBOps.set_AB_matrices method · python · L754-L766 (13 LOC)src/cl/models/cnn.py
def set_AB_matrices(self, model, original_arch, new_arch):
    """Build the A/B transition matrices for a CNN architecture change.

    Args:
        model: CNN model whose A/B matrices are (re)initialized.
        original_arch: ``(filter, feed_sizes)`` pair of the current architecture.
        new_arch: ``(filter, feed_sizes)`` pair of the target architecture.

    Returns:
        Whatever the selected ``core.awb`` helper returns: the 3D variant
        when ``self.is_cnn3d`` is truthy, the 2D variant otherwise.
    """
    from ..core.awb import (
        set_new_AB_matrices_cnn,
        set_new_AB_matrices_cnn3d,
    )

    old_filter, old_feeds = original_arch
    next_filter, next_feeds = new_arch
    builder = set_new_AB_matrices_cnn3d if self.is_cnn3d else set_new_AB_matrices_cnn
    # The helpers take feed sizes first, then filters.
    return builder(model, old_feeds, next_feeds, old_filter, next_filter)
# CNNAWBOps.partition_for_AB_training method · python · L768-L775 (8 LOC)src/cl/models/cnn.py
def partition_for_AB_training(self, model):
    """Split ``model`` for the A/B phase (W frozen, A/B trainable).

    Delegates to the 3D helper when ``self.is_cnn3d`` is truthy, otherwise
    to the 2D helper, and returns its result unchanged.
    """
    from ..core.awb import (
        partition_for_AB_training_cnn,
        partition_for_AB_training_cnn3d,
    )

    if not self.is_cnn3d:
        return partition_for_AB_training_cnn(model)
    return partition_for_AB_training_cnn3d(model)
# CNNAWBOps.compute_V method · python · L777-L784 (8 LOC)src/cl/models/cnn.py
def compute_V(self, model):
    """Compute the effective weights V = A @ W @ B^T for a CNN model.

    Picks the 3D or 2D ``core.awb`` helper based on ``self.is_cnn3d`` and
    returns its result.
    """
    from ..core.awb import compute_V_from_AWB_cnn, compute_V_from_AWB_cnn3d

    compute = compute_V_from_AWB_cnn3d if self.is_cnn3d else compute_V_from_AWB_cnn
    return compute(model)
# CNNAWBOps.partition_for_standard_training method · python · L786-L793 (8 LOC)src/cl/models/cnn.py
def partition_for_standard_training(self, model):
    """Split ``model`` for standard training (V trainable, A/B frozen).

    Delegates to the 3D or 2D ``core.awb`` helper depending on
    ``self.is_cnn3d`` and returns its result.
    """
    from ..core.awb import (
        partition_for_standard_training_cnn,
        partition_for_standard_training_cnn3d,
    )

    if self.is_cnn3d:
        partitioner = partition_for_standard_training_cnn3d
    else:
        partitioner = partition_for_standard_training_cnn
    return partitioner(model)
# sp_matmul function · python · L26-L43 (18 LOC)src/cl/models/gcn.py
def sp_matmul(A, B, shape):
    """Multiply a COO-style sparse matrix by a dense matrix.

    Args:
        A: (N, M) sparse matrix given as ``(indexes, values)``. ``indexes``
            unpacks into ``(rows, cols)`` index arrays, and ``values`` holds
            one entry per (row, col) pair.
        B: (M, K) dense matrix.
        shape: Number of output rows N. It is passed to
            ``jax.ops.segment_sum`` as ``num_segments``, which must be an int.
            A shape tuple/list is also accepted, in which case its first
            entry is used.

    Returns:
        (N, K) dense matrix.

    Raises:
        ValueError: If ``B`` is not 2-D.
    """
    if B.ndim != 2:
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        raise ValueError(f"sp_matmul expects a 2-D dense matrix B, got ndim={B.ndim}")
    if isinstance(shape, (tuple, list)):
        # Callers may hand over a full shape; segment_sum needs only the row count.
        shape = shape[0]
    (rows, cols), values = A
    gathered = B.take(cols, axis=0)
    # Scale each gathered row of B by its sparse value, then sum the
    # contributions that share an output row index.
    contrib = gathered * values[:, None]
    return jax.ops.segment_sum(contrib, rows, shape)
# GraphPooling.__call__ method · python · L79-L87 (9 LOC)src/cl/models/gcn.py
def __call__(self, x: jnp.ndarray, batch: jnp.ndarray, num_nodes: jnp.ndarray) -> jnp.ndarray:
    """Pool node features into per-graph features.

    ``self.pool_type`` selects ``Pool.sum``, ``Pool.mean`` or ``Pool.max``.
    Any other value falls back to ``Pool.identity``.
    NOTE(review): a misspelled pool_type silently becomes identity pooling.
    """
    reducer_name = self.pool_type if self.pool_type in ('sum', 'mean', 'max') else 'identity'
    return getattr(Pool, reducer_name)(x, batch, num_nodes)
# GCNLayer.__init__ method · python · L108-L119 (12 LOC)src/cl/models/gcn.py
def __init__(self, in_size: int, out_size: int, key: jax.Array,
             bias: bool = True, sparse: bool = False):
    """Create a graph-convolution layer with Glorot-uniform parameters.

    Args:
        in_size: Number of input features per node.
        out_size: Number of output features per node.
        key: PRNG key. It is split into one key for the weight and one for
            the bias.
        bias: If truthy, allocate a ``(1, out_size)`` bias; otherwise
            ``self.bias`` is None.
        sparse: Stored on ``self.sparse``; ``matmul`` uses it to choose the
            sparse adjacency product.
    """
    self.bias_flag = bias
    self.sparse = sparse
    weight_key, bias_key = jax.random.split(key)
    # Keep the initializer local: storing it as an attribute breaks JIT.
    glorot = jax.nn.initializers.glorot_uniform()
    self.weight = glorot(weight_key, (in_size, out_size))
    self.bias = glorot(bias_key, (1, out_size)) if self.bias_flag else None
# Repobility — same analyzer, your code, free for public repos · /scan/
GCNLayer.matmul method · python · L121-L126 (6 LOC)src/cl/models/gcn.py
def matmul(self, A, B, shape):
    """Multiply the adjacency ``A`` by ``B``.

    When ``self.sparse`` is truthy, ``A`` goes through ``sp_matmul`` with
    ``shape``. Otherwise a dense ``jnp.matmul`` is used and ``shape`` is
    ignored.
    """
    return sp_matmul(A, B, shape) if self.sparse else jnp.matmul(A, B)
# GCNLayer.normalize_adjacency method · python · L129-L161 (33 LOC)src/cl/models/gcn.py
def normalize_adjacency(self, adj: jax.Array) -> jax.Array:
"""Apply symmetric normalization with self-loops.
Implements the standard GCN normalization from Kipf & Welling (2017):
 = D̃^(-1/2) @ (A + I) @ D̃^(-1/2)
where:
- A is the adjacency matrix
- I is the identity matrix (self-loops)
- D̃ is the degree matrix of (A + I)
Args:
adj: Adjacency matrix (num_nodes, num_nodes)
Returns:
Normalized adjacency matrix (num_nodes, num_nodes)
"""
# Add self-loops: Ã = A + I
num_nodes = adj.shape[0]
adj_with_loops = adj + jnp.eye(num_nodes)
# Compute degree matrix: D̃_ii = sum_j Ã_ij
degree = jnp.sum(adj_with_loops, axis=1)
# Compute D̃^(-1/2) with numerical stability (handle isolated nodes)
deg_inv_sqrt = jnp.power(degree, -0.5)
deg_inv_sqrt = jnp.where(jnp.isinf(deg_inv_sqrt), 0., deg_inv_sqrt)
# Symmetric noGCNLayer.__call__ method · python · L163-L185 (23 LOC)src/cl/models/gcn.py
def __call__(self, x: jax.Array, adj: jax.Array) -> jax.Array:
    """Forward pass for the GCN layer: ``adj @ (x @ W) + b``.

    Adjacency normalization is applied by the ``T.GCNNorm()`` transform in
    the data pipeline (loops.py:get_graph_transforms). It is deliberately
    NOT applied here, to avoid normalizing twice.

    Args:
        x: Node features (num_nodes, in_size).
        adj: Adjacency already normalized by ``T.GCNNorm()``. This is a dense
            (num_nodes, num_nodes) array, or the ``(indexes, values)`` pair
            that ``sp_matmul`` accepts when ``self.sparse`` is set.

    Returns:
        Updated node features (num_nodes, out_size).
    """
    support = x @ self.weight
    # The third argument is the number of output rows. The sparse path hands
    # it to ``segment_sum`` as ``num_segments``, which must be an int, not
    # the full shape tuple; the dense path ignores it.
    # Assumes a square adjacency, so output rows == num_nodes.
    out = self.matmul(adj, support, support.shape[0])
    if self.bias_flag:
        out = out + self.bias
    return out
# GCN.__init__ method · python · L254-L329 (76 LOC)src/cl/models/gcn.py
def __init__(self, in_size: int, feed_sizes: List[int] = None,
gcn_sizes: List[int] = None, node_num: int = 0,
SEED: int = 1234, out_size: int = 2, graph: bool = True,
awb_fnn_arch: List[int] = None, awb_gcn_arch: List[int] = None):
"""Initialize GCN model."""
self.SEED = SEED
self.graph = graph
self.node_num = node_num
# Default architectures
if gcn_sizes is None:
gcn_sizes = [in_size, 128]
if feed_sizes is None:
feed_sizes = [128, 128, 128, out_size]
# Ensure first GCN layer matches input size
gcn_sizes[0] = in_size
self.gcn_sizes = gcn_sizes
# Ensure last feed layer matches output size
feed_sizes[-1] = out_size
self.feed_sizes = feed_sizes
# AWB architectures - use provided or defaults
# AWB should preserve input and output dimensions but can expand hidden layers
if awb_fnn_aGCN.matmul method · python · L331-L336 (6 LOC)src/cl/models/gcn.py
def matmul(self, A, B, shape):
    """Adjacency product used by the model.

    A dense ``jnp.matmul`` is used unless ``self.sparse`` is truthy, in
    which case ``sp_matmul(A, B, shape)`` handles a sparse ``A``.
    """
    if not self.sparse:
        return jnp.matmul(A, B)
    return sp_matmul(A, B, shape)
# GCN.__call__ method · python · L338-L365 (28 LOC)src/cl/models/gcn.py
def __call__(self, x: jax.Array, adj: jax.Array, batch: jax.Array,
             n_nodes: jax.Array) -> jax.Array:
    """Standard (non-AWB) forward pass.

    The pipeline is:
      1. Every GCN layer, each followed by LeakyReLU.
      2. Graph pooling.
      3. The first ``len(self.feed_sizes) - 2`` feed layers, each followed
         by LeakyReLU.
      4. ``self.feed_layers[-1]`` with no activation.

    Args:
        x: Node features (total_nodes, in_size)
        adj: Adjacency matrix (total_nodes, total_nodes)
        batch: Batch assignment for each node (total_nodes,)
        n_nodes: Number of nodes per graph (batch_size,)

    Returns:
        Class logits (batch_size, n_class)
    """
    h = x
    for gcn in self.gcn_layers:
        h = jax.nn.leaky_relu(gcn(h, adj))
    h = self.pool_layer(h, batch, n_nodes)
    num_activated_feed = len(self.feed_sizes) - 2
    for idx in range(num_activated_feed):
        h = jax.nn.leaky_relu(self.feed_layers[idx](h))
    # Output head: raw logits, no activation.
    return self.feed_layers[-1](h)
# GCN.get_AWBT method · python · L367-L415 (49 LOC)src/cl/models/gcn.py
def get_AWBT(self, x: jax.Array, adj: jax.Array, batch: jax.Array,
n_nodes: jax.Array) -> jax.Array:
"""Forward pass using AWB transformation.
Uses AWBMixin.awb_transform_linear for efficient computation.
Uses V = A @ W @ B.T for both GCN and feed layers.
Note: Adjacency normalization is applied via T.GCNNorm() transform
in the data pipeline, so we do NOT normalize here.
Args:
x: Node features (total_nodes, in_size)
adj: Adjacency matrix (total_nodes, total_nodes) - already normalized by T.GCNNorm()
batch: Batch assignment for each node (total_nodes,)
n_nodes: Number of nodes per graph (batch_size,)
Returns:
Class logits (batch_size, n_class)
"""
# GCN layers with AWB transformation using AWBMixin
for i in range(len(self.gcn_layers)):
# Compute AWB transformed weight: V = A_gcn @ W @ B_gcn^T
V_weightGCN.get_awb_layer_specs method · python · L418-L424 (7 LOC)src/cl/models/gcn.py
def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
    """Describe the AWB-transformable feed layers.

    GCN layers are not included; they are handled separately. Feed layer
    ``idx`` is paired with ``self.A_feed[idx]`` and ``self.B_feed[idx]``
    and tagged as ``'linear'``.
    """
    return [
        AWBLayerSpec(
            layer=feed_layer,
            A=self.A_feed[idx],
            B=self.B_feed[idx],
            layer_type='linear',
            layer_index=idx,
        )
        for idx, feed_layer in enumerate(self.feed_layers)
    ]
# Open data scored by Repobility · https://repobility.com
GCN.partition_for_AB_training method · python · L426-L431 (6 LOC)src/cl/models/gcn.py
def partition_for_AB_training(self):
    """Split the model so only the A/B transition matrices are trainable.

    Returns:
        ``eqx.partition`` output ``(diff_model, static_model)``.
        ``diff_model`` holds ``A_gcn``, ``B_gcn``, ``A_feed`` and ``B_feed``.
        Everything else, including the layer weights W, is in
        ``static_model``.
    """
    def _ab_fields(m):
        return (m.A_gcn, m.B_gcn, m.A_feed, m.B_feed)

    all_frozen = jtu.tree_map(lambda _: False, self)
    spec = eqx.tree_at(_ab_fields, all_frozen, replace=(True,) * 4)
    return eqx.partition(self, spec)
# GCN.partition_for_standard_training method · python · L433-L440 (8 LOC)src/cl/models/gcn.py
def partition_for_standard_training(self):
    """Split the model so W is trainable and the A/B matrices are frozen.

    The model's arrays are first partitioned with ``eqx.is_array``. The four
    A/B fields are then moved from the trainable half into the static half:
    they are set to None in the trainable half and restored from ``self``
    in the static half.

    Returns:
        ``(params, static)`` tuple.
    """
    def _ab_fields(m):
        return (m.A_gcn, m.B_gcn, m.A_feed, m.B_feed)

    frozen_ab = (self.A_gcn, self.B_gcn, self.A_feed, self.B_feed)
    trainable, frozen = eqx.partition(self, eqx.is_array)
    frozen = eqx.tree_at(_ab_fields, frozen, replace=frozen_ab)
    trainable = eqx.tree_at(_ab_fields, trainable, replace=(None,) * 4)
    return trainable, frozen
# GCN.generate_search_candidates method · python · L443-L516 (74 LOC)src/cl/models/gcn.py
def generate_search_candidates(self, iteration: int, current_best: Tuple[List[int], List[int]],
config: Dict[str, Any]) -> List[Tuple[List[int], List[int]]]:
"""Generate candidate architectures for GCN using neighborhood search.
GCN architecture has two components:
- gcn_sizes: [in_size, gcn_hidden, ...] for graph convolution layers
- feed_sizes: [gcn_out, mlp_h1, mlp_h2, n_class] for feedforward layers
Search strategy:
- Uses neighborhood search with expanding radius (n)
- For each iteration, explores 3x3x3 = 27 candidates
- GCN hidden: z2 + n*(j+1)*step_gcn for j in [0,1,2]
- MLP hidden1: x1 + n*(k+1)*step_mlp for k in [0,1,2]
- MLP hidden2: x2 + n*(r+1)*step_mlp for r in [0,1,2]
Args:
iteration: Current search iteration (controls neighborhood size n)
current_best: Tuple of (gcn_sizes, mlp_sizes) lists
config: ConfiguratGCN.create_with_architecture method · python · L519-L548 (30 LOC)src/cl/models/gcn.py
def create_with_architecture(cls, arch_spec: Tuple[List[int], List[int]],
seed: int = 0, awb_enabled: bool = True) -> 'GCN':
"""Create GCN model with specified architecture.
Args:
arch_spec: Tuple of (gcn_sizes, feed_sizes) lists
seed: Random seed for weight initialization
awb_enabled: Whether to enable AWB (always True for GCN)
Returns:
New GCN instance with specified architecture
"""
gcn_sizes, feed_sizes = arch_spec
# Infer parameters from architecture
in_size = gcn_sizes[0]
out_size = feed_sizes[-1]
# Create new GCN with specified architecture
return cls(
in_size=in_size,
feed_sizes=feed_sizes,
gcn_sizes=gcn_sizes,
node_num=0, # Will be set from actual data
SEED=seed,
out_size=out_size,
graph=True,
awb_fnn_arch=None, # GCN.reinitialize_weights method · python · L550-L567 (18 LOC)src/cl/models/gcn.py
def reinitialize_weights(self, seed: int = 0) -> 'GCN':
    """Return a freshly initialized GCN with this model's architecture.

    GCN layer shapes are fixed at construction, so reinitializing means
    building a new instance via ``create_with_architecture``. This is used
    for fair architecture comparison.

    Args:
        seed: Random seed for the new instance's weight initialization.

    Returns:
        New GCN instance with reinitialized weights (AWB enabled).
    """
    # Pass copies, not our own lists. GCN.__init__ writes into the size lists
    # it receives (gcn_sizes[0], feed_sizes[-1]) and stores them, so sharing
    # them would alias mutable state between the old and new models.
    arch_spec = (list(self.gcn_sizes), list(self.feed_sizes))
    return self.create_with_architecture(
        arch_spec=arch_spec,
        seed=seed,
        awb_enabled=True,
    )
# GCNAWBOps.search_architecture method · python · L581-L606 (26 LOC)src/cl/models/gcn.py
def search_architecture(self, model, task_id, baseline_loss, dataloader_curr,
dataloader_exp, test_loader_curr, test_loader_exp, config, trainer=None):
"""Search for optimal GCN architecture.
Added by Claude: Supports awb_predetermined_arch config to bypass search for debugging.
Format: awb_predetermined_arch = {task_id: {"gcn_sizes": [...], "feed_sizes": [...]}}
"""
# Added by Claude: Check for predetermined architecture (bypasses search for debugging)
predetermined_archs = config.get('awb_predetermined_arch', {})
if str(task_id) in predetermined_archs or task_id in predetermined_archs:
arch_config = predetermined_archs.get(str(task_id)) or predetermined_archs.get(task_id)
new_arch = (arch_config['gcn_sizes'], arch_config['feed_sizes'])
print(f" [GCN] Using PREDETERMINED architecture (search bypassed)")
print(f" Predetermined arch: GCN={new_arch[0]}GCNAWBOps.set_AB_matrices method · python · L608-L616 (9 LOC)src/cl/models/gcn.py
def set_AB_matrices(self, model, original_arch, new_arch):
    """Build the A/B transition matrices for a GCN architecture change.

    Args:
        model: GCN model whose A/B matrices are (re)initialized.
        original_arch: ``(gcn_sizes, feed_sizes)`` of the current architecture.
        new_arch: ``(gcn_sizes, feed_sizes)`` of the target architecture.

    Returns:
        The result of ``core.awb.set_new_AB_matrices_gcn``.
    """
    from ..core.awb import set_new_AB_matrices_gcn

    old_gcn, old_feed = original_arch
    next_gcn, next_feed = new_arch
    return set_new_AB_matrices_gcn(model, old_gcn, old_feed, next_gcn, next_feed)
# awb_transform function · python · L19-L44 (26 LOC)src/cl/models/layers.py
def awb_transform(A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
"""Unified AWB weight transformation: V = A @ W @ B.T
Works for any number of batch dimensions using einsum with ellipsis.
This enables efficient batched computation for:
- MLP: 2D matrices (no batch dims)
- CNN: 3D matrices (batch over output channels)
- CNN3D: 4D matrices (batch over output and input channels)
- Transformers: arbitrary batch dims (layers, heads, etc.)
Args:
A: Transformation matrix with shape (..., new_out, old_out)
W: Weight matrix with shape (..., old_out, old_in)
B: Transformation matrix with shape (..., new_in, old_in)
Returns:
Transformed weight V with shape (..., new_out, new_in)
Example:
# Single layer (MLP-style)
V = awb_transform(A, W, B) # A: (m,k), W: (k,n), B: (p,n) -> V: (m,p)
# Batched over channels (CNN3D-style)
V = awb_transform(A, W, B) # A: (o,i,m,k), W: (o,i,k,n), B:Repobility · severity-and-effort ranking · https://repobility.com
AWBLayerSpec.validate method · python · L65-L91 (27 LOC)src/cl/models/layers.py
def validate(self) -> List[str]:
"""Validate AWB matrix shapes against layer dimensions.
Returns:
List of error messages (empty if valid)
"""
errors = []
if self.A is not None and hasattr(self.layer, 'weight'):
# For most layers: A transforms output dimension
expected_A_cols = self.layer.weight.shape[0] # old_out
if self.A.shape[1] != expected_A_cols:
errors.append(
f"Layer {self.layer_index} ({self.layer_type}): "
f"A.shape[1]={self.A.shape[1]} != weight.shape[0]={expected_A_cols}"
)
if self.B is not None and hasattr(self.layer, 'weight'):
# For most layers: B transforms input dimension
expected_B_cols = self.layer.weight.shape[1] # old_in
if self.B.shape[1] != expected_B_cols:
errors.append(
f"Layer {self.layer_index} ({self.layer_type}): Linear.compute_V_weight method · python · L127-L150 (24 LOC)src/cl/models/layers.py
def compute_V_weight(self, A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
"""Compute transformed weight V = A @ W @ B.T for Linear layer.
Args:
A: Output transformation matrix (new_out, old_out)
W: Original weight matrix (old_out, old_in)
B: Input transformation matrix (new_in, old_in)
Returns:
Transformed weight V with shape (new_out, new_in)
"""
# Validate shapes
if A.shape[1] != W.shape[0]:
raise AWBShapeError(
f"Linear layer: A.shape={A.shape} incompatible with W.shape={W.shape}. "
f"Expected A.shape[1]={A.shape[1]} == W.shape[0]={W.shape[0]}"
)
if B.shape[1] != W.shape[1]:
raise AWBShapeError(
f"Linear layer: B.shape={B.shape} incompatible with W.shape={W.shape}. "
f"Expected B.shape[1]={B.shape[1]} == W.shape[1]={W.shape[1]}"
)
return A @ W @ jLinear.compute_V_bias method · python · L152-L166 (15 LOC)src/cl/models/layers.py
def compute_V_bias(self, A: jax.Array, B: jax.Array, bias: jax.Array) -> jax.Array:
    """Map a Linear layer's row bias into the new output space.

    Linear stores its bias as a row vector (1, out_size), so the output
    transform is applied from the right: ``bias @ A.T``.

    Args:
        A: Output transformation matrix (new_out, old_out)
        B: Input transformation matrix (unused for the bias)
        bias: Original bias vector (1, old_out)

    Returns:
        Transformed bias with shape (1, new_out)
    """
    projected = bias @ A.T
    return projected
# LinearGCN.__call__ method · python · L191-L196 (6 LOC)src/cl/models/layers.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Apply the layer as ``W @ x``, adding ``self.bias`` when it is not None.

    The weight multiplies from the left, so ``x``'s leading axis must match
    the weight's input axis. Presumably this is (in_size, ...); confirm with
    callers.
    """
    y = self.weight @ x
    return y if self.bias is None else y + self.bias
# LinearGCN.compute_V_weight method · python · L199-L222 (24 LOC)src/cl/models/layers.py
def compute_V_weight(self, A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
"""Compute transformed weight V = A @ W @ B.T for LinearGCN layer.
Args:
A: Output transformation matrix (new_out, old_out)
W: Original weight matrix (old_out, old_in)
B: Input transformation matrix (new_in, old_in)
Returns:
Transformed weight V with shape (new_out, new_in)
"""
# Validate shapes
if A.shape[1] != W.shape[0]:
raise AWBShapeError(
f"LinearGCN layer: A.shape={A.shape} incompatible with W.shape={W.shape}. "
f"Expected A.shape[1]={A.shape[1]} == W.shape[0]={W.shape[0]}"
)
if B.shape[1] != W.shape[1]:
raise AWBShapeError(
f"LinearGCN layer: B.shape={B.shape} incompatible with W.shape={W.shape}. "
f"Expected B.shape[1]={B.shape[1]} == W.shape[1]={W.shape[1]}"
)
return LinearGCN.compute_V_bias method · python · L224-L238 (15 LOC)src/cl/models/layers.py
def compute_V_bias(self, A: jax.Array, B: jax.Array, bias: jax.Array) -> jax.Array:
    """Compute the transformed bias for a LinearGCN layer.

    LinearGCN's forward pass is ``W @ x + bias`` with a column bias of shape
    (out_size, 1). The bias therefore lives on the output axis and must be
    mapped with the output transformation: ``A @ bias``. This matches
    ``Linear2.compute_V_bias`` and the 'col' branch of
    ``AWBMixin.awb_transform_bias_linear``.

    Args:
        A: Output transformation matrix (new_out, old_out)
        B: Input transformation matrix (unused for the bias)
        bias: Original bias vector (old_out, 1)

    Returns:
        Transformed bias with shape (new_out, 1)
    """
    # The previous ``bias @ B.T`` applied the *input* transform:
    # (old_out, 1) @ (old_in, new_in) only type-checks when old_in == 1 and
    # never yields the documented (new_out, 1) shape.
    return A @ bias
# Linear2.compute_V_weight method · python · L270-L293 (24 LOC)src/cl/models/layers.py
def compute_V_weight(self, A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
"""Compute transformed weight V = A @ W @ B.T for Linear2 layer.
Args:
A: Output transformation matrix (new_out, old_out)
W: Original weight matrix (old_out, old_in)
B: Input transformation matrix (new_in, old_in)
Returns:
Transformed weight V with shape (new_out, new_in)
"""
# Validate shapes
if A.shape[1] != W.shape[0]:
raise AWBShapeError(
f"Linear2 layer: A.shape={A.shape} incompatible with W.shape={W.shape}. "
f"Expected A.shape[1]={A.shape[1]} == W.shape[0]={W.shape[0]}"
)
if B.shape[1] != W.shape[1]:
raise AWBShapeError(
f"Linear2 layer: B.shape={B.shape} incompatible with W.shape={W.shape}. "
f"Expected B.shape[1]={B.shape[1]} == W.shape[1]={W.shape[1]}"
)
return A @ W Linear2.compute_V_bias method · python · L295-L309 (15 LOC)src/cl/models/layers.py
def compute_V_bias(self, A: jax.Array, B: jax.Array, bias: jax.Array) -> jax.Array:
    """Map a Linear2 column bias into the new output space.

    Linear2 computes ``W @ x`` with a column bias (out_size, 1), so the
    output transform multiplies from the left.

    Args:
        A: Output transformation matrix (new_out, old_out)
        B: Input transformation matrix (unused for the bias)
        bias: Original bias vector (old_out, 1)

    Returns:
        Transformed bias with shape (new_out, 1)
    """
    transformed = A @ bias
    return transformed
# Repobility · MCP-ready · https://repobility.com
Dropout.__call__ method · python · L326-L347 (22 LOC)src/cl/models/layers.py
def __call__(self, inputs: jax.Array, rng: jax.Array, is_training: bool = True) -> jax.Array:
    """Apply inverted dropout when training; pass inputs through otherwise.

    ``self.rate`` acts as the *keep* probability. Each element survives with
    probability ``rate``, and survivors are scaled by ``1 / rate``.

    Args:
        inputs: Input array.
        rng: JAX PRNG key. Required even when ``is_training`` is False.
        is_training: Whether to apply dropout (default: True). May be a
            Python bool or a traced boolean scalar.

    Returns:
        The dropout-applied array if training, ``inputs`` otherwise.

    Raises:
        ValueError: If ``rng`` is None.
    """
    if rng is None:
        raise ValueError(
            "Dropout layer requires a PRNG key argument. "
            "Call with `apply_fun(params, inputs, rng)` where rng is a jax.random.PRNGKey."
        )
    keep = jax.random.bernoulli(rng, self.rate, shape=inputs.shape)
    outs = jnp.where(keep, inputs / self.rate, 0)
    # Use the documented ``cond(pred, true_fun, false_fun)`` signature instead
    # of the legacy five-argument per-branch-operand form. Branches close
    # over their results, so no operands are needed.
    return lax.cond(is_training, lambda: outs, lambda: inputs)
# compute_V_conv2d_single_channel function · python · L351-L381 (31 LOC)src/cl/models/layers.py
def compute_V_conv2d_single_channel(
A_list: List[jax.Array],
W: jax.Array,
B_list: List[jax.Array],
channel_out: int
) -> List[List[jax.Array]]:
"""Compute V = A @ W @ B.T for Conv2d with single input channel (MNIST).
For single-channel Conv2d (e.g., MNIST), each output filter is transformed
independently. This is used in CNN models with 1-channel input.
Args:
A_list: List of A matrices, one per output filter [channel_out]
Each A has shape (new_filter_size, old_filter_size)
W: Conv weights with shape [channel_out, channel_in, H, W]
B_list: List of B matrices, one per output filter [channel_out]
Each B has shape (new_filter_size, old_filter_size)
channel_out: Number of output channels
Returns:
Transformed weights as list of lists: [channel_out][1]
Each element has shape (new_filter_size, new_filter_size)
"""
new_conv_weights = []
for i in range(channel_out)compute_V_conv2d_multi_channel function · python · L384-L420 (37 LOC)src/cl/models/layers.py
def compute_V_conv2d_multi_channel(
A_list: List[List[jax.Array]],
W: jax.Array,
B_list: List[List[jax.Array]],
channel_out: int,
channel_in: int
) -> List[List[jax.Array]]:
"""Compute V = A @ W @ B.T for Conv2d with multiple input channels (CIFAR).
For multi-channel Conv2d (e.g., CIFAR with 3 channels), each filter is
transformed per input-output channel pair. This enables fine-grained
control over channel-wise transformations.
Args:
A_list: Nested list of A matrices [channel_out][channel_in]
Each A has shape (new_filter_size, old_filter_size)
W: Conv weights with shape [channel_out, channel_in, H, W]
B_list: Nested list of B matrices [channel_out][channel_in]
Each B has shape (new_filter_size, old_filter_size)
channel_out: Number of output channels
channel_in: Number of input channels
Returns:
Transformed weights as nested list: [channel_out][channel_in]
AWBMixin.awb_transform_linear method · python · L459-L486 (28 LOC)src/cl/models/layers.py
def awb_transform_linear(
    A: jax.Array,
    W: jax.Array,
    B: jax.Array
) -> jax.Array:
    """Return the AWB-transformed linear weight ``V = A @ W @ B.T``.

    This is a thin delegate to the module-level ``awb_transform``. Any
    leading batch dimensions are passed straight through, so the same call
    serves a single layer or a stack of layers.

    Args:
        A: Transformation matrix (..., new_out, old_out)
        W: Weight matrix (..., old_out, old_in)
        B: Transformation matrix (..., new_in, old_in)

    Returns:
        Transformed weight V (..., new_out, new_in)

    Example:
        # Single layer
        V = AWBMixin.awb_transform_linear(A, W, B)
        # A: (m,k), W: (k,n), B: (p,n) -> V: (m,p)
        # Stacked layers
        V = AWBMixin.awb_transform_linear(A_stacked, W_stacked, B_stacked)
        # A: (L,m,k), W: (L,k,n), B: (L,p,n) -> V: (L,m,p)
    """
    return awb_transform(A=A, W=W, B=B)
# AWBMixin.awb_transform_conv method · python · L489-L516 (28 LOC)src/cl/models/layers.py
def awb_transform_conv(
A: jax.Array,
W: jax.Array,
B: jax.Array
) -> jax.Array:
"""Transform convolutional layer weights: V = A @ W @ B.T
Uses einsum with ellipsis for arbitrary channel dimensions.
Works for single-channel (3D) or multi-channel (4D) conv weights.
Args:
A: Transformation matrix (out_ch, [in_ch], new_f, old_f)
W: Conv weight (out_ch, [in_ch], old_f, old_f)
B: Transformation matrix (out_ch, [in_ch], new_f, old_f)
Returns:
Transformed weight V (out_ch, [in_ch], new_f, new_f)
Example:
# CNN single-channel (MNIST)
V = AWBMixin.awb_transform_conv(A, W, B)
# A: (o,m,k), W: (o,k,n), B: (o,p,n) -> V: (o,m,p)
# CNN3D multi-channel (CIFAR)
V = AWBMixin.awb_transform_conv(A, W, B)
# A: (o,i,m,k), W: (o,i,k,n), B: (o,i,p,n) -> V: (o,i,m,p)
"""
return awb_transform(AWBMixin.awb_transform_bias_linear method · python · L519-L539 (21 LOC)src/cl/models/layers.py
def awb_transform_bias_linear(
    A: jax.Array,
    bias: jax.Array,
    bias_shape: str = 'row'
) -> jax.Array:
    """Map a linear-layer bias into the new output space.

    Args:
        A: Transformation matrix (new_out, old_out)
        bias: Original bias
        bias_shape: ``'row'`` for a (1, out) bias. Any other value is treated
            as a column (out, 1) bias.

    Returns:
        Transformed bias in the same row/column convention as the input.
    """
    if bias_shape != 'row':
        # Column bias (Linear2/LinearGCN): left-multiply -> (new_out, 1)
        return A @ bias
    # Row bias (Linear): right-multiply by A^T -> (1, new_out)
    return bias @ A.T
# AWBMixin.get_awb_matrices_spec method · python · L541-L553 (13 LOC)src/cl/models/layers.py
def get_awb_matrices_spec(self) -> dict:
    """Abstract hook: say which attributes hold this model's A/B matrices.

    Subclasses return a dict with these keys:
      - 'linear_A': attribute name(s) for linear-layer A matrices
      - 'linear_B': attribute name(s) for linear-layer B matrices
      - 'conv_A': attribute name(s) for conv-layer A matrices (optional)
      - 'conv_B': attribute name(s) for conv-layer B matrices (optional)

    Raises:
        NotImplementedError: Always, on the base mixin.
    """
    message = "Subclasses must implement get_awb_matrices_spec()"
    raise NotImplementedError(message)
# AWBMixin.get_awb_layer_specs method · python · L555-L563 (9 LOC)src/cl/models/layers.py
def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
    """Abstract hook: list an AWBLayerSpec for every transformable layer.

    Raises:
        NotImplementedError: Always, on the base mixin; subclasses override.
    """
    message = "Subclasses must implement get_awb_layer_specs()"
    raise NotImplementedError(message)
# Repobility — same analyzer, your code, free for public repos · /scan/
AWBMixin.get_AWBT method · python · L565-L570 (6 LOC)src/cl/models/layers.py
def get_AWBT(self, *args, **kwargs):
    """Abstract hook for the AWB-transformed forward pass.

    Concrete models override this with their own implementation.

    Raises:
        NotImplementedError: Always, on the base mixin.
    """
    message = "Subclasses must implement get_AWBT()"
    raise NotImplementedError(message)
# AWBMixin.partition_for_AB_training method · python · L572-L577 (6 LOC)src/cl/models/layers.py
def partition_for_AB_training(self):
    """Abstract hook: split the model for A/B training (W frozen, A/B trainable).

    Raises:
        NotImplementedError: Always, on the base mixin; subclasses override.
    """
    message = "Subclasses must implement partition_for_AB_training()"
    raise NotImplementedError(message)
# AWBMixin.partition_for_standard_training method · python · L579-L584 (6 LOC)src/cl/models/layers.py
def partition_for_standard_training(self):
    """Abstract hook: split the model for standard training (A/B frozen, W trainable).

    Raises:
        NotImplementedError: Always, on the base mixin; subclasses override.
    """
    message = "Subclasses must implement partition_for_standard_training()"
    raise NotImplementedError(message)
# MLP.__init__ method · python · L45-L79 (35 LOC)src/cl/models/mlp.py
def __init__(self, sizes: List[int], key: Optional[jax.Array] = None, awb_enabled: bool = False):
"""Initialize the MLP.
Args:
sizes: List of layer dimensions [input_dim, hidden1, hidden2, ..., output_dim]
key: Optional JAX PRNG key (uses default if not provided)
awb_enabled: Whether to initialize A/B matrices for AWB (default: False)
"""
if key is None:
key = jax.random.PRNGKey(0)
self.sizes = sizes
self.layers = []
self.awb_enabled = awb_enabled
# Only initialize A/B matrices if AWB is enabled
if awb_enabled:
# A transforms output dimension: shape (out_size, 1) initially
# B transforms input dimension: shape (out_size, in_size) initially
self.A = [
jax.random.normal(jax.random.PRNGKey(0), shape=(y, 1))
for y in sizes[1:]
]
self.B = [
jax.random.normal(jMLP.__call__ method · python · L81-L93 (13 LOC)src/cl/models/mlp.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Run the plain (non-AWB) forward pass.

    Every layer except the last is followed by ``tanh``. The last layer's
    output is returned without an activation.

    Args:
        x: Input tensor of shape (batch, input_dim) or (input_dim,)

    Returns:
        Output tensor of shape (batch, output_dim) or (output_dim,)
    """
    hidden_layers, output_layer = self.layers[:-1], self.layers[-1]
    for hidden in hidden_layers:
        x = jax.nn.tanh(hidden(x))
    return output_layer(x)
# MLP.getAWB method · python · L95-L127 (33 LOC)src/cl/models/mlp.py
def getAWB(self, x: jax.Array) -> jax.Array:
"""Forward pass using AWB transformation: A @ W @ B.T.
Uses AWBMixin.awb_transform_linear for efficient computation.
Used during AWB training (Step 3b) when A/B matrices are being optimized
while W is frozen. The effective weight becomes V = A @ W @ B.T.
Args:
x: Input tensor of shape (input_dim,) - note: expects unbatched input
Returns:
Output tensor after AWB transformation
Raises:
ValueError: If AWB is not enabled (A/B matrices are None)
"""
# Note: We don't check awb_enabled here to allow JIT tracing.
# The caller is responsible for ensuring A/B are properly initialized
# before calling getAWB(). This is guaranteed by the AWB pipeline:
# - partition_for_AB_training() is only called after with_new_AB_matrices()
# - hamiltonian.py only calls getAWB when notABTrain=False
for i in range(MLP.get_awb_layer_specs method · python · L130-L163 (34 LOC)src/cl/models/mlp.py
def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
"""Get AWB layer specifications for each transformable layer.
Returns:
List of AWBLayerSpec, one per layer, containing layer, A, B matrices
Example:
>>> specs = model.get_awb_layer_specs()
>>> for spec in specs:
... errors = spec.validate() # Check shape compatibility
"""
if not self.awb_enabled or self.A is None or self.B is None:
# If AWB not enabled, return specs with None A/B
return [
AWBLayerSpec(
layer=self.layers[i],
A=None,
B=None,
layer_type='linear',
layer_index=i
)
for i in range(len(self.layers))
]
return [
AWBLayerSpec(
layer=self.layers[i],
A=self.A[i] if self.A else None,
BMLP.apply_V_transformation method · python · L165-L199 (35 LOC)src/cl/models/mlp.py
def apply_V_transformation(self) -> 'MLP':
"""Apply V = A @ W @ B.T transformation to all layers.
This is STEP 4 of the AWB algorithm. After training A/B matrices,
compute the effective weights V and update the model.
Returns:
New model with transformed weights V = A @ W @ B.T
Raises:
ValueError: If AWB is not enabled
Example:
>>> # After Step 3b (A/B training)
>>> model = model.apply_V_transformation()
>>> # Now model has V weights, ready for Step 5 training
"""
if not self.awb_enabled or self.A is None or self.B is None:
raise ValueError(
"apply_V_transformation requires awb_enabled=True and A/B matrices. "
"Train A/B matrices first (Step 3b)."
)
model = self
for i, spec in enumerate(self.get_awb_layer_specs()):
layer = spec.layer
# Use layer's AWB methods toOpen data scored by Repobility · https://repobility.com
MLP.partition_for_AB_training method · python · L201-L225 (25 LOC)src/cl/models/mlp.py
def partition_for_AB_training(self):
"""Partition model for A/B training (freeze W, train A/B).
This is used in STEP 3b of AWB algorithm where we train A/B matrices
while keeping layer weights W frozen.
Returns:
Tuple of (diff_model, static_model) where:
- diff_model: Contains only A and B (trainable)
- static_model: Contains layer weights (frozen)
Example:
>>> diff_model, static_model = model.partition_for_AB_training()
>>> # Train diff_model (only A/B parameters)
>>> model = eqx.combine(diff_model, static_model)
"""
# Create filter spec: True for A/B, False for everything else
filter_spec = jtu.tree_map(lambda _: False, self)
filter_spec = eqx.tree_at(
lambda x: (x.A, x.B),
filter_spec,
replace=(True, True)
)
diff_model, static_model = eqx.partition(self, filter_spec)
MLP.partition_for_standard_training method · python · L227-L259 (33 LOC)src/cl/models/mlp.py
def partition_for_standard_training(self):
"""Partition model for standard training (freeze A/B, train W).
This is used in standard CL training and STEP 5 of AWB (train V with A/B frozen).
Returns:
Tuple of (params, static) where:
- params: Contains trainable arrays (layer weights, biases)
- static: Contains A, B matrices (frozen)
Example:
>>> params, static = model.partition_for_standard_training()
>>> # Train params (only layer weights/biases)
>>> model = eqx.combine(params, static)
"""
params, static = eqx.partition(self, eqx.is_array)
if self.awb_enabled and self.A is not None and self.B is not None:
# Move A and B to static (frozen)
static = eqx.tree_at(
lambda x: (x.A, x.B),
static,
replace=(self.A, self.B)
)
# Remove A and B from params (set MLP.with_new_AB_matrices method · python · L261-L299 (39 LOC)src/cl/models/mlp.py
def with_new_AB_matrices(self, original_arch: List[int], new_arch: List[int], seed: int = 5) -> 'MLP':
"""Initialize A/B matrices for architecture transition.
This is used in STEP 3a of AWB algorithm when changing from original
architecture to new architecture.
Args:
original_arch: Original sizes [in, h1, h2, ..., out]
new_arch: New sizes [in, h1', h2', ..., out]
seed: Random seed for initialization
Returns:
Model with new A/B matrices initialized
Example:
>>> # After architecture search finds optimal size
>>> model = model.with_new_AB_matrices([3, 10, 10, 2], [3, 15, 20, 2])
>>> # Now A/B matrices are ready for Step 3b training
"""
initializer = jax.nn.initializers.glorot_uniform()
# A matrices: transform output dimensions [new_out, old_out]
A_list = [
initializer(jax.random.PRNGKey(seed + i), (y_new, y