← back to krm9c__ContLearn

Function bodies 268 total

All specs Real LLM only Function bodies
CNNAWBOps.set_AB_matrices method · python · L754-L766 (13 LOC)
src/cl/models/cnn.py
    def set_AB_matrices(self, model, original_arch, new_arch):
        """Build the A/B transition matrices for moving to ``new_arch``.

        Both architecture arguments are ``(filter, feed_sizes)`` pairs. The work
        is delegated to the 3-D or 2-D CNN helper from ``core.awb`` depending on
        ``self.is_cnn3d``; the helper's result is returned unchanged.
        """
        from ..core.awb import set_new_AB_matrices_cnn, set_new_AB_matrices_cnn3d

        old_filter, old_feed = original_arch
        next_filter, next_feed = new_arch

        builder = set_new_AB_matrices_cnn3d if self.is_cnn3d else set_new_AB_matrices_cnn
        return builder(model, old_feed, next_feed, old_filter, next_filter)
CNNAWBOps.partition_for_AB_training method · python · L768-L775 (8 LOC)
src/cl/models/cnn.py
    def partition_for_AB_training(self, model):
        """Split ``model`` so only the A/B matrices are trainable (W frozen).

        Picks the 3-D or 2-D CNN partition helper from ``core.awb`` according to
        ``self.is_cnn3d`` and returns whatever it produces.
        """
        from ..core.awb import partition_for_AB_training_cnn, partition_for_AB_training_cnn3d

        partition = (partition_for_AB_training_cnn3d if self.is_cnn3d
                     else partition_for_AB_training_cnn)
        return partition(model)
CNNAWBOps.compute_V method · python · L777-L784 (8 LOC)
src/cl/models/cnn.py
    def compute_V(self, model):
        """Return ``V = A @ W @ B^T`` computed from the model's AWB factors.

        The 3-D or 2-D CNN routine from ``core.awb`` is chosen by
        ``self.is_cnn3d``; its result is passed straight through.
        """
        from ..core.awb import compute_V_from_AWB_cnn, compute_V_from_AWB_cnn3d

        if not self.is_cnn3d:
            return compute_V_from_AWB_cnn(model)
        return compute_V_from_AWB_cnn3d(model)
CNNAWBOps.partition_for_standard_training method · python · L786-L793 (8 LOC)
src/cl/models/cnn.py
    def partition_for_standard_training(self, model):
        """Split ``model`` for standard training (V trainable, A/B frozen).

        Delegates to the 3-D or 2-D CNN partition helper from ``core.awb``
        based on ``self.is_cnn3d`` and returns its result.
        """
        from ..core.awb import partition_for_standard_training_cnn, partition_for_standard_training_cnn3d

        if self.is_cnn3d:
            partition = partition_for_standard_training_cnn3d
        else:
            partition = partition_for_standard_training_cnn
        return partition(model)
sp_matmul function · python · L26-L43 (18 LOC)
src/cl/models/gcn.py
def sp_matmul(A, B, shape):
    """Sparse-dense matrix product ``A @ B`` for graph operations.

    Args:
        A: (N, M) sparse matrix in COO form, given as ``((rows, cols), values)``;
            nonzero ``k`` holds ``values[k]`` at position ``(rows[k], cols[k])``.
        B: (M, K) dense matrix.
        shape: Number of output rows ``N``. A shape tuple is also accepted, in
            which case it is read as the ``(N, K)`` output shape and only its
            first element is used.

    Returns:
        (N, K) dense matrix.

    Raises:
        ValueError: If ``B`` is not two-dimensional.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if B.ndim != 2:
        raise ValueError(f"sp_matmul expects a 2-D dense operand, got B.ndim={B.ndim}")
    # jax.ops.segment_sum needs an integer segment count, not a shape tuple.
    if isinstance(shape, (tuple, list)):
        shape = shape[0]
    indexes, values = A
    rows, cols = indexes
    # Gather the rows of B addressed by each nonzero's column, scale them by the
    # nonzero value, then scatter-add into the output row given by `rows`.
    gathered = B.take(cols, axis=0)
    prod = gathered * values[:, None]
    return jax.ops.segment_sum(prod, rows, shape)
GraphPooling.__call__ method · python · L79-L87 (9 LOC)
src/cl/models/gcn.py
    def __call__(self, x: jnp.ndarray, batch: jnp.ndarray, num_nodes: jnp.ndarray) -> jnp.ndarray:
        """Pool node features into per-graph features.

        ``self.pool_type`` selects ``Pool.sum``, ``Pool.mean`` or ``Pool.max``.
        Any other value falls back to ``Pool.identity``. The chosen reducer
        receives ``(x, batch, num_nodes)`` unchanged.
        """
        reducers = {
            'sum': Pool.sum,
            'mean': Pool.mean,
            'max': Pool.max,
        }
        reducer = reducers.get(self.pool_type, Pool.identity)
        return reducer(x, batch, num_nodes)
GCNLayer.__init__ method · python · L108-L119 (12 LOC)
src/cl/models/gcn.py
    def __init__(self, in_size: int, out_size: int, key: jax.Array,
                 bias: bool = True, sparse: bool = False):
        self.bias_flag = bias
        self.sparse = sparse
        # Use initializer locally, don't store as attribute (breaks JIT)
        initializer = jax.nn.initializers.glorot_uniform()
        wkey, bkey = jax.random.split(key)
        self.weight = initializer(wkey, (in_size, out_size))
        if self.bias_flag:
            self.bias = initializer(bkey, (1, out_size))
        else:
            self.bias = None
Repobility — same analyzer, your code, free for public repos · /scan/
GCNLayer.matmul method · python · L121-L126 (6 LOC)
src/cl/models/gcn.py
    def matmul(self, A, B, shape):
        """Compute ``A @ B``, using the COO sparse kernel when ``self.sparse``.

        In sparse mode ``A`` is the ``((rows, cols), values)`` tuple consumed by
        ``sp_matmul`` and ``shape`` is forwarded to it. In dense mode ``shape``
        is ignored and ``jnp.matmul`` is used.
        """
        if not self.sparse:
            return jnp.matmul(A, B)
        return sp_matmul(A, B, shape)
GCNLayer.normalize_adjacency method · python · L129-L161 (33 LOC)
src/cl/models/gcn.py
    def normalize_adjacency(self, adj: jax.Array) -> jax.Array:
        """Apply symmetric normalization with self-loops.

        Implements the standard GCN normalization from Kipf & Welling (2017):
        Â = D̃^(-1/2) @ (A + I) @ D̃^(-1/2)

        where:
        - A is the adjacency matrix
        - I is the identity matrix (self-loops)
        - D̃ is the degree matrix of (A + I)

        Args:
            adj: Adjacency matrix (num_nodes, num_nodes)

        Returns:
            Normalized adjacency matrix (num_nodes, num_nodes)
        """
        # Add self-loops: Ã = A + I
        num_nodes = adj.shape[0]
        adj_with_loops = adj + jnp.eye(num_nodes)

        # Compute degree matrix: D̃_ii = sum_j Ã_ij
        degree = jnp.sum(adj_with_loops, axis=1)

        # Compute D̃^(-1/2) with numerical stability (handle isolated nodes)
        deg_inv_sqrt = jnp.power(degree, -0.5)
        deg_inv_sqrt = jnp.where(jnp.isinf(deg_inv_sqrt), 0., deg_inv_sqrt)

        # Symmetric no
GCNLayer.__call__ method · python · L163-L185 (23 LOC)
src/cl/models/gcn.py
    def __call__(self, x: jax.Array, adj: jax.Array) -> jax.Array:
        """Forward pass for GCN layer.

        Implements: adj @ X @ W + b

        Note: Adjacency normalization is applied via T.GCNNorm() transform
        in the data pipeline (loops.py:get_graph_transforms), so we do NOT
        normalize here to avoid double normalization.

        Args:
            x: Node features (num_nodes, in_size)
            adj: Adjacency matrix (num_nodes, num_nodes) - already normalized by T.GCNNorm()

        Returns:
            Updated node features (num_nodes, out_size)
        """
        # Graph convolution: adj @ (x @ W) + b
        # Note: adj is already normalized by T.GCNNorm() transform
        support = x @ self.weight
        x = self.matmul(adj, support, support.shape)
        if self.bias_flag:
            x += self.bias
        return x
GCN.__init__ method · python · L254-L329 (76 LOC)
src/cl/models/gcn.py
    def __init__(self, in_size: int, feed_sizes: List[int] = None,
                 gcn_sizes: List[int] = None, node_num: int = 0,
                 SEED: int = 1234, out_size: int = 2, graph: bool = True,
                 awb_fnn_arch: List[int] = None, awb_gcn_arch: List[int] = None):
        """Initialize GCN model."""
        self.SEED = SEED
        self.graph = graph
        self.node_num = node_num

        # Default architectures
        if gcn_sizes is None:
            gcn_sizes = [in_size, 128]
        if feed_sizes is None:
            feed_sizes = [128, 128, 128, out_size]

        # Ensure first GCN layer matches input size
        gcn_sizes[0] = in_size
        self.gcn_sizes = gcn_sizes

        # Ensure last feed layer matches output size
        feed_sizes[-1] = out_size
        self.feed_sizes = feed_sizes

        # AWB architectures - use provided or defaults
        # AWB should preserve input and output dimensions but can expand hidden layers
        if awb_fnn_a
GCN.matmul method · python · L331-L336 (6 LOC)
src/cl/models/gcn.py
    def matmul(self, A, B, shape):
        """Compute ``A @ B``; sparse COO ``A`` goes through ``sp_matmul``.

        ``shape`` is only forwarded in sparse mode. Dense mode uses
        ``jnp.matmul`` and ignores it.
        """
        return sp_matmul(A, B, shape) if self.sparse else jnp.matmul(A, B)
GCN.__call__ method · python · L338-L365 (28 LOC)
src/cl/models/gcn.py
    def __call__(self, x: jax.Array, adj: jax.Array, batch: jax.Array,
                 n_nodes: jax.Array) -> jax.Array:
        """Standard forward pass.

        Args:
            x: Node features (total_nodes, in_size)
            adj: Adjacency matrix (total_nodes, total_nodes)
            batch: Batch assignment for each node (total_nodes,)
            n_nodes: Number of nodes per graph (batch_size,)

        Returns:
            Class logits (batch_size, n_class)
        """
        # GCN layers with LeakyReLU
        for layer in self.gcn_layers:
            x = jax.nn.leaky_relu(layer(x, adj))

        # Graph pooling
        x = self.pool_layer(x, batch, n_nodes)

        # Feed-forward layers (all but last with LeakyReLU)
        for i in range(len(self.feed_sizes) - 2):
            x = jax.nn.leaky_relu(self.feed_layers[i](x))

        # Final layer (no activation)
        x = self.feed_layers[-1](x)

        return x
GCN.get_AWBT method · python · L367-L415 (49 LOC)
src/cl/models/gcn.py
    def get_AWBT(self, x: jax.Array, adj: jax.Array, batch: jax.Array,
                 n_nodes: jax.Array) -> jax.Array:
        """Forward pass using AWB transformation.

        Uses AWBMixin.awb_transform_linear for efficient computation.
        Uses V = A @ W @ B.T for both GCN and feed layers.

        Note: Adjacency normalization is applied via T.GCNNorm() transform
        in the data pipeline, so we do NOT normalize here.

        Args:
            x: Node features (total_nodes, in_size)
            adj: Adjacency matrix (total_nodes, total_nodes) - already normalized by T.GCNNorm()
            batch: Batch assignment for each node (total_nodes,)
            n_nodes: Number of nodes per graph (batch_size,)

        Returns:
            Class logits (batch_size, n_class)
        """
        # GCN layers with AWB transformation using AWBMixin
        for i in range(len(self.gcn_layers)):
            # Compute AWB transformed weight: V = A_gcn @ W @ B_gcn^T
            V_weight
GCN.get_awb_layer_specs method · python · L418-L424 (7 LOC)
src/cl/models/gcn.py
    def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
        """Return one ``AWBLayerSpec`` per feed-forward layer.

        Each spec pairs ``self.feed_layers[i]`` with ``self.A_feed[i]`` and
        ``self.B_feed[i]`` and is tagged ``layer_type='linear'`` and
        ``layer_index=i``. GCN layers are excluded; they are handled separately.
        """
        return [
            AWBLayerSpec(layer=feed_layer, A=self.A_feed[idx], B=self.B_feed[idx],
                         layer_type='linear', layer_index=idx)
            for idx, feed_layer in enumerate(self.feed_layers)
        ]
Open data scored by Repobility · https://repobility.com
GCN.partition_for_AB_training method · python · L426-L431 (6 LOC)
src/cl/models/gcn.py
    def partition_for_AB_training(self):
        """Partition the model so only the A/B matrices are trainable (W frozen).

        Builds a filter spec that is ``False`` for every leaf, then flips the
        ``A_gcn``, ``B_gcn``, ``A_feed`` and ``B_feed`` subtrees to ``True``.
        Returns the pair produced by ``eqx.partition(self, spec)``.
        """
        def ab_nodes(tree):
            return (tree.A_gcn, tree.B_gcn, tree.A_feed, tree.B_feed)

        frozen_everywhere = jtu.tree_map(lambda _leaf: False, self)
        spec = eqx.tree_at(ab_nodes, frozen_everywhere, replace=(True,) * 4)
        return eqx.partition(self, spec)
GCN.partition_for_standard_training method · python · L433-L440 (8 LOC)
src/cl/models/gcn.py
    def partition_for_standard_training(self):
        """Partition the model for standard training (W trainable, A/B frozen).

        All array leaves are first split into ``params`` via ``eqx.is_array``.
        The A/B subtrees (``A_gcn``, ``B_gcn``, ``A_feed``, ``B_feed``) are then
        removed from ``params`` (set to ``None``) and restored into ``static``
        from ``self``. Returns ``(params, static)``.
        """
        def ab_nodes(tree):
            return (tree.A_gcn, tree.B_gcn, tree.A_feed, tree.B_feed)

        params, static = eqx.partition(self, eqx.is_array)
        ab_values = ab_nodes(self)
        static = eqx.tree_at(ab_nodes, static, replace=ab_values)
        params = eqx.tree_at(ab_nodes, params, replace=(None,) * 4)
        return params, static
GCN.generate_search_candidates method · python · L443-L516 (74 LOC)
src/cl/models/gcn.py
    def generate_search_candidates(self, iteration: int, current_best: Tuple[List[int], List[int]],
                                   config: Dict[str, Any]) -> List[Tuple[List[int], List[int]]]:
        """Generate candidate architectures for GCN using neighborhood search.

        GCN architecture has two components:
        - gcn_sizes: [in_size, gcn_hidden, ...] for graph convolution layers
        - feed_sizes: [gcn_out, mlp_h1, mlp_h2, n_class] for feedforward layers

        Search strategy:
        - Uses neighborhood search with expanding radius (n)
        - For each iteration, explores 3x3x3 = 27 candidates
        - GCN hidden: z2 + n*(j+1)*step_gcn for j in [0,1,2]
        - MLP hidden1: x1 + n*(k+1)*step_mlp for k in [0,1,2]
        - MLP hidden2: x2 + n*(r+1)*step_mlp for r in [0,1,2]

        Args:
            iteration: Current search iteration (controls neighborhood size n)
            current_best: Tuple of (gcn_sizes, mlp_sizes) lists
            config: Configurat
GCN.create_with_architecture method · python · L519-L548 (30 LOC)
src/cl/models/gcn.py
    def create_with_architecture(cls, arch_spec: Tuple[List[int], List[int]],
                                 seed: int = 0, awb_enabled: bool = True) -> 'GCN':
        """Create GCN model with specified architecture.

        Args:
            arch_spec: Tuple of (gcn_sizes, feed_sizes) lists
            seed: Random seed for weight initialization
            awb_enabled: Whether to enable AWB (always True for GCN)

        Returns:
            New GCN instance with specified architecture
        """
        gcn_sizes, feed_sizes = arch_spec

        # Infer parameters from architecture
        in_size = gcn_sizes[0]
        out_size = feed_sizes[-1]

        # Create new GCN with specified architecture
        return cls(
            in_size=in_size,
            feed_sizes=feed_sizes,
            gcn_sizes=gcn_sizes,
            node_num=0,  # Will be set from actual data
            SEED=seed,
            out_size=out_size,
            graph=True,
            awb_fnn_arch=None,  # 
GCN.reinitialize_weights method · python · L550-L567 (18 LOC)
src/cl/models/gcn.py
    def reinitialize_weights(self, seed: int = 0) -> 'GCN':
        """Reinitialize GCN weights for fair architecture comparison.

        For GCN models, we create a fresh instance because the architecture
        is fixed at initialization (GCN and feed layers).

        Args:
            seed: Random seed for initialization

        Returns:
            New GCN instance with reinitialized weights
        """
        # Create fresh GCN with same architecture
        return self.create_with_architecture(
            arch_spec=(self.gcn_sizes, self.feed_sizes),
            seed=seed,
            awb_enabled=True
        )
GCNAWBOps.search_architecture method · python · L581-L606 (26 LOC)
src/cl/models/gcn.py
    def search_architecture(self, model, task_id, baseline_loss, dataloader_curr,
                           dataloader_exp, test_loader_curr, test_loader_exp, config, trainer=None):
        """Search for optimal GCN architecture.

        Added by Claude: Supports awb_predetermined_arch config to bypass search for debugging.
        Format: awb_predetermined_arch = {task_id: {"gcn_sizes": [...], "feed_sizes": [...]}}
        """
        # Added by Claude: Check for predetermined architecture (bypasses search for debugging)
        predetermined_archs = config.get('awb_predetermined_arch', {})
        if str(task_id) in predetermined_archs or task_id in predetermined_archs:
            arch_config = predetermined_archs.get(str(task_id)) or predetermined_archs.get(task_id)
            new_arch = (arch_config['gcn_sizes'], arch_config['feed_sizes'])
            print(f"  [GCN] Using PREDETERMINED architecture (search bypassed)")
            print(f"  Predetermined arch: GCN={new_arch[0]}
GCNAWBOps.set_AB_matrices method · python · L608-L616 (9 LOC)
src/cl/models/gcn.py
    def set_AB_matrices(self, model, original_arch, new_arch):
        """Build A/B transition matrices for a GCN architecture change.

        Each architecture is a ``(gcn_sizes, feed_sizes)`` pair. Both pairs are
        unpacked and forwarded to ``set_new_AB_matrices_gcn``, whose result is
        returned unchanged.
        """
        from ..core.awb import set_new_AB_matrices_gcn

        old_gcn, old_feed = original_arch
        next_gcn, next_feed = new_arch
        return set_new_AB_matrices_gcn(model, old_gcn, old_feed, next_gcn, next_feed)
awb_transform function · python · L19-L44 (26 LOC)
src/cl/models/layers.py
def awb_transform(A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
    """Unified AWB weight transformation: V = A @ W @ B.T

    Works for any number of batch dimensions using einsum with ellipsis.
    This enables efficient batched computation for:
    - MLP: 2D matrices (no batch dims)
    - CNN: 3D matrices (batch over output channels)
    - CNN3D: 4D matrices (batch over output and input channels)
    - Transformers: arbitrary batch dims (layers, heads, etc.)

    Args:
        A: Transformation matrix with shape (..., new_out, old_out)
        W: Weight matrix with shape (..., old_out, old_in)
        B: Transformation matrix with shape (..., new_in, old_in)

    Returns:
        Transformed weight V with shape (..., new_out, new_in)

    Example:
        # Single layer (MLP-style)
        V = awb_transform(A, W, B)  # A: (m,k), W: (k,n), B: (p,n) -> V: (m,p)

        # Batched over channels (CNN3D-style)
        V = awb_transform(A, W, B)  # A: (o,i,m,k), W: (o,i,k,n), B:
Repobility · severity-and-effort ranking · https://repobility.com
AWBLayerSpec.validate method · python · L65-L91 (27 LOC)
src/cl/models/layers.py
    def validate(self) -> List[str]:
        """Validate AWB matrix shapes against layer dimensions.

        Returns:
            List of error messages (empty if valid)
        """
        errors = []

        if self.A is not None and hasattr(self.layer, 'weight'):
            # For most layers: A transforms output dimension
            expected_A_cols = self.layer.weight.shape[0]  # old_out
            if self.A.shape[1] != expected_A_cols:
                errors.append(
                    f"Layer {self.layer_index} ({self.layer_type}): "
                    f"A.shape[1]={self.A.shape[1]} != weight.shape[0]={expected_A_cols}"
                )

        if self.B is not None and hasattr(self.layer, 'weight'):
            # For most layers: B transforms input dimension
            expected_B_cols = self.layer.weight.shape[1]  # old_in
            if self.B.shape[1] != expected_B_cols:
                errors.append(
                    f"Layer {self.layer_index} ({self.layer_type}): 
Linear.compute_V_weight method · python · L127-L150 (24 LOC)
src/cl/models/layers.py
    def compute_V_weight(self, A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
        """Compute transformed weight V = A @ W @ B.T for Linear layer.

        Args:
            A: Output transformation matrix (new_out, old_out)
            W: Original weight matrix (old_out, old_in)
            B: Input transformation matrix (new_in, old_in)

        Returns:
            Transformed weight V with shape (new_out, new_in)
        """
        # Validate shapes
        if A.shape[1] != W.shape[0]:
            raise AWBShapeError(
                f"Linear layer: A.shape={A.shape} incompatible with W.shape={W.shape}. "
                f"Expected A.shape[1]={A.shape[1]} == W.shape[0]={W.shape[0]}"
            )
        if B.shape[1] != W.shape[1]:
            raise AWBShapeError(
                f"Linear layer: B.shape={B.shape} incompatible with W.shape={W.shape}. "
                f"Expected B.shape[1]={B.shape[1]} == W.shape[1]={W.shape[1]}"
            )

        return A @ W @ j
Linear.compute_V_bias method · python · L152-L166 (15 LOC)
src/cl/models/layers.py
    def compute_V_bias(self, A: jax.Array, B: jax.Array, bias: jax.Array) -> jax.Array:
        """Map a Linear layer's bias into the transformed output space.

        The bias is a row vector ``(1, old_out)``, so the output-side matrix
        ``A`` (``(new_out, old_out)``) is applied from the right as ``A.T``.
        ``B`` only affects the input side and is not used here.

        Returns:
            Transformed bias with shape ``(1, new_out)``.
        """
        a_transposed = A.T
        return bias @ a_transposed
LinearGCN.__call__ method · python · L191-L196 (6 LOC)
src/cl/models/layers.py
    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass: W @ x + bias"""
        x = self.weight @ x
        if self.bias is not None:
            x = x + self.bias
        return x
LinearGCN.compute_V_weight method · python · L199-L222 (24 LOC)
src/cl/models/layers.py
    def compute_V_weight(self, A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
        """Compute transformed weight V = A @ W @ B.T for LinearGCN layer.

        Args:
            A: Output transformation matrix (new_out, old_out)
            W: Original weight matrix (old_out, old_in)
            B: Input transformation matrix (new_in, old_in)

        Returns:
            Transformed weight V with shape (new_out, new_in)
        """
        # Validate shapes
        if A.shape[1] != W.shape[0]:
            raise AWBShapeError(
                f"LinearGCN layer: A.shape={A.shape} incompatible with W.shape={W.shape}. "
                f"Expected A.shape[1]={A.shape[1]} == W.shape[0]={W.shape[0]}"
            )
        if B.shape[1] != W.shape[1]:
            raise AWBShapeError(
                f"LinearGCN layer: B.shape={B.shape} incompatible with W.shape={W.shape}. "
                f"Expected B.shape[1]={B.shape[1]} == W.shape[1]={W.shape[1]}"
            )

        return 
LinearGCN.compute_V_bias method · python · L224-L238 (15 LOC)
src/cl/models/layers.py
    def compute_V_bias(self, A: jax.Array, B: jax.Array, bias: jax.Array) -> jax.Array:
        """Transform the bias of a LinearGCN layer as ``bias @ B.T``.

        ``A`` is accepted for interface parity with the other layers but is not
        used. ``B`` is the input-side transform ``(new_in, old_in)``.

        NOTE(review): if ``bias`` has shape ``(old_out, 1)``, this product only
        type-checks when ``old_in == 1``. Its result is then
        ``(old_out, new_in)`` rather than ``(new_out, 1)``. ``Linear2``, whose
        forward is also ``W @ x``, transforms its bias as ``A @ bias``.
        Confirm the actual LinearGCN bias shape before relying on this.
        """
        b_transposed = B.T
        return bias @ b_transposed
Linear2.compute_V_weight method · python · L270-L293 (24 LOC)
src/cl/models/layers.py
    def compute_V_weight(self, A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
        """Compute transformed weight V = A @ W @ B.T for Linear2 layer.

        Args:
            A: Output transformation matrix (new_out, old_out)
            W: Original weight matrix (old_out, old_in)
            B: Input transformation matrix (new_in, old_in)

        Returns:
            Transformed weight V with shape (new_out, new_in)
        """
        # Validate shapes
        if A.shape[1] != W.shape[0]:
            raise AWBShapeError(
                f"Linear2 layer: A.shape={A.shape} incompatible with W.shape={W.shape}. "
                f"Expected A.shape[1]={A.shape[1]} == W.shape[0]={W.shape[0]}"
            )
        if B.shape[1] != W.shape[1]:
            raise AWBShapeError(
                f"Linear2 layer: B.shape={B.shape} incompatible with W.shape={W.shape}. "
                f"Expected B.shape[1]={B.shape[1]} == W.shape[1]={W.shape[1]}"
            )

        return A @ W 
Linear2.compute_V_bias method · python · L295-L309 (15 LOC)
src/cl/models/layers.py
    def compute_V_bias(self, A: jax.Array, B: jax.Array, bias: jax.Array) -> jax.Array:
        """Map a Linear2 bias column into the transformed output space.

        Linear2's forward is ``W @ x`` with a column bias ``(old_out, 1)``. The
        output-side matrix ``A`` (``(new_out, old_out)``) is therefore applied
        from the left. ``B`` is not used for the bias.

        Returns:
            Transformed bias with shape ``(new_out, 1)``.
        """
        transformed = A @ bias
        return transformed
Repobility · MCP-ready · https://repobility.com
Dropout.__call__ method · python · L326-L347 (22 LOC)
src/cl/models/layers.py
    def __call__(self, inputs: jax.Array, rng: jax.Array, is_training: bool = True) -> jax.Array:
        """Apply dropout.

        Args:
            inputs: Input array
            rng: JAX PRNG key (required)
            is_training: Whether to apply dropout (default: True)

        Returns:
            Dropout-applied array if training, inputs otherwise
        """
        if rng is None:
            raise ValueError(
                "Dropout layer requires a PRNG key argument. "
                "Call with `apply_fun(params, inputs, rng)` where rng is a jax.random.PRNGKey."
            )

        keep = jax.random.bernoulli(rng, self.rate, shape=inputs.shape)
        outs = jnp.where(keep, inputs / self.rate, 0)
        # Return inputs unchanged if not training
        out = lax.cond(is_training, outs, lambda x: x, inputs, lambda x: x)
        return out
compute_V_conv2d_single_channel function · python · L351-L381 (31 LOC)
src/cl/models/layers.py
def compute_V_conv2d_single_channel(
    A_list: List[jax.Array],
    W: jax.Array,
    B_list: List[jax.Array],
    channel_out: int
) -> List[List[jax.Array]]:
    """Compute V = A @ W @ B.T for Conv2d with single input channel (MNIST).

    For single-channel Conv2d (e.g., MNIST), each output filter is transformed
    independently. This is used in CNN models with 1-channel input.

    Args:
        A_list: List of A matrices, one per output filter [channel_out]
                Each A has shape (new_filter_size, old_filter_size)
        W: Conv weights with shape [channel_out, channel_in, H, W]
        B_list: List of B matrices, one per output filter [channel_out]
                Each B has shape (new_filter_size, old_filter_size)
        channel_out: Number of output channels

    Returns:
        Transformed weights as list of lists: [channel_out][1]
        Each element has shape (new_filter_size, new_filter_size)
    """
    new_conv_weights = []
    for i in range(channel_out)
compute_V_conv2d_multi_channel function · python · L384-L420 (37 LOC)
src/cl/models/layers.py
def compute_V_conv2d_multi_channel(
    A_list: List[List[jax.Array]],
    W: jax.Array,
    B_list: List[List[jax.Array]],
    channel_out: int,
    channel_in: int
) -> List[List[jax.Array]]:
    """Compute V = A @ W @ B.T for Conv2d with multiple input channels (CIFAR).

    For multi-channel Conv2d (e.g., CIFAR with 3 channels), each filter is
    transformed per input-output channel pair. This enables fine-grained
    control over channel-wise transformations.

    Args:
        A_list: Nested list of A matrices [channel_out][channel_in]
                Each A has shape (new_filter_size, old_filter_size)
        W: Conv weights with shape [channel_out, channel_in, H, W]
        B_list: Nested list of B matrices [channel_out][channel_in]
                Each B has shape (new_filter_size, old_filter_size)
        channel_out: Number of output channels
        channel_in: Number of input channels

    Returns:
        Transformed weights as nested list: [channel_out][channel_in]
    
AWBMixin.awb_transform_linear method · python · L459-L486 (28 LOC)
src/cl/models/layers.py
    def awb_transform_linear(A: jax.Array, W: jax.Array, B: jax.Array) -> jax.Array:
        """Compute the transformed linear-layer weight ``V = A @ W @ B.T``.

        This is a thin wrapper: the computation is delegated to the
        module-level ``awb_transform`` helper. The shapes below describe this
        method's intended contract; leading batch dimensions are supported
        only to the extent that helper supports them.

        Args:
            A: Output-side transformation, expected shape (..., new_out, old_out).
            W: Original weight, expected shape (..., old_out, old_in).
            B: Input-side transformation, expected shape (..., new_in, old_in).

        Returns:
            Transformed weight, expected shape (..., new_out, new_in).

        Example:
            # Single layer: A (m, k), W (k, n), B (p, n) -> V (m, p)
            V = AWBMixin.awb_transform_linear(A, W, B)

            # L stacked layers: A (L, m, k), W (L, k, n), B (L, p, n) -> V (L, m, p)
            V = AWBMixin.awb_transform_linear(A_stacked, W_stacked, B_stacked)
        """
        transformed = awb_transform(A, W, B)
        return transformed
AWBMixin.awb_transform_conv method · python · L489-L516 (28 LOC)
src/cl/models/layers.py
    def awb_transform_conv(
        A: jax.Array,
        W: jax.Array,
        B: jax.Array
    ) -> jax.Array:
        """Transform convolutional layer weights: V = A @ W @ B.T

        Uses einsum with ellipsis for arbitrary channel dimensions.
        Works for single-channel (3D) or multi-channel (4D) conv weights.

        Args:
            A: Transformation matrix (out_ch, [in_ch], new_f, old_f)
            W: Conv weight (out_ch, [in_ch], old_f, old_f)
            B: Transformation matrix (out_ch, [in_ch], new_f, old_f)

        Returns:
            Transformed weight V (out_ch, [in_ch], new_f, new_f)

        Example:
            # CNN single-channel (MNIST)
            V = AWBMixin.awb_transform_conv(A, W, B)
            # A: (o,m,k), W: (o,k,n), B: (o,p,n) -> V: (o,m,p)

            # CNN3D multi-channel (CIFAR)
            V = AWBMixin.awb_transform_conv(A, W, B)
            # A: (o,i,m,k), W: (o,i,k,n), B: (o,i,p,n) -> V: (o,i,m,p)
        """
        return awb_transform(
AWBMixin.awb_transform_bias_linear method · python · L519-L539 (21 LOC)
src/cl/models/layers.py
    def awb_transform_bias_linear(
        A: jax.Array,
        bias: jax.Array,
        bias_shape: str = 'row'
    ) -> jax.Array:
        """Transform a linear layer's bias with the output-side matrix ``A``.

        Args:
            A: Transformation matrix of shape (new_out, old_out).
            bias: Original bias, laid out according to ``bias_shape``.
            bias_shape: ``'row'`` for a (1, old_out) bias, ``'col'`` for an
                (old_out, 1) bias.

        Returns:
            Transformed bias in the same layout: (1, new_out) for ``'row'``,
            (new_out, 1) for ``'col'``.

        Raises:
            ValueError: If ``bias_shape`` is neither ``'row'`` nor ``'col'``.
        """
        if bias_shape == 'row':
            # Row bias: (1, old_out) @ A.T (old_out, new_out) -> (1, new_out)
            return bias @ A.T
        if bias_shape == 'col':
            # Column bias (Linear2/LinearGCN style):
            # A (new_out, old_out) @ (old_out, 1) -> (new_out, 1)
            return A @ bias
        # Reject anything else. Silently treating unknown values as a column
        # bias would hide typos and produce a wrongly laid-out result.
        raise ValueError(
            f"bias_shape must be 'row' or 'col', got {bias_shape!r}"
        )
AWBMixin.get_awb_matrices_spec method · python · L541-L553 (13 LOC)
src/cl/models/layers.py
    def get_awb_matrices_spec(self) -> dict:
        """Describe which attributes of the model hold its AWB A/B matrices.

        Abstract hook: concrete models must override it.

        Returns:
            A dict with the keys:
                - 'linear_A' / 'linear_B': attribute name(s) holding the
                  linear layers' A and B matrices
                - 'conv_A' / 'conv_B': attribute name(s) holding the conv
                  layers' A and B matrices (optional)

        Raises:
            NotImplementedError: Always, in this base class.
        """
        hook = "get_awb_matrices_spec"
        raise NotImplementedError(f"Subclasses must implement {hook}()")
AWBMixin.get_awb_layer_specs method · python · L555-L563 (9 LOC)
src/cl/models/layers.py
    def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
        """Return one AWBLayerSpec for every AWB-transformable layer.

        Abstract hook: subclasses override it to report each layer together
        with the A/B matrices tracked for it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        hook = "get_awb_layer_specs"
        raise NotImplementedError(f"Subclasses must implement {hook}()")
Repobility — same analyzer, your code, free for public repos · /scan/
AWBMixin.get_AWBT method · python · L565-L570 (6 LOC)
src/cl/models/layers.py
    def get_AWBT(self, *args, **kwargs):
        """Model-specific forward pass through the AWB-transformed weights.

        Abstract hook: subclasses supply the actual implementation. Any
        positional or keyword arguments are accepted and ignored here.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        del args, kwargs  # unused by the abstract base
        hook = "get_AWBT"
        raise NotImplementedError(f"Subclasses must implement {hook}()")
AWBMixin.partition_for_AB_training method · python · L572-L577 (6 LOC)
src/cl/models/layers.py
    def partition_for_AB_training(self):
        """Split the model for A/B training: W frozen, A and B trainable.

        Abstract hook: each subclass provides its own partitioning logic.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Subclasses must implement partition_for_AB_training()"
        raise NotImplementedError(message)
AWBMixin.partition_for_standard_training method · python · L579-L584 (6 LOC)
src/cl/models/layers.py
    def partition_for_standard_training(self):
        """Split the model for standard training: A and B frozen, W trainable.

        Abstract hook: each subclass provides its own partitioning logic.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = "Subclasses must implement partition_for_standard_training()"
        raise NotImplementedError(message)
MLP.__init__ method · python · L45-L79 (35 LOC)
src/cl/models/mlp.py
    def __init__(self, sizes: List[int], key: Optional[jax.Array] = None, awb_enabled: bool = False):
        """Initialize the MLP.

        Args:
            sizes: List of layer dimensions [input_dim, hidden1, hidden2, ..., output_dim]
            key: Optional JAX PRNG key (uses default if not provided)
            awb_enabled: Whether to initialize A/B matrices for AWB (default: False)
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        self.sizes = sizes
        self.layers = []
        self.awb_enabled = awb_enabled

        # Only initialize A/B matrices if AWB is enabled
        if awb_enabled:
            # A transforms output dimension: shape (out_size, 1) initially
            # B transforms input dimension: shape (out_size, in_size) initially
            self.A = [
                jax.random.normal(jax.random.PRNGKey(0), shape=(y, 1))
                for y in sizes[1:]
            ]
            self.B = [
                jax.random.normal(j
MLP.__call__ method · python · L81-L93 (13 LOC)
src/cl/models/mlp.py
    def __call__(self, x: jax.Array) -> jax.Array:
        """Plain forward pass, without any AWB transformation.

        Every layer except the last is followed by a tanh nonlinearity. The
        last layer's output is returned without an activation.

        Args:
            x: Input fed to the first layer. NOTE(review): whether a batched
                (batch, input_dim) input works, or only a single (input_dim,)
                example, depends on the layer type, which is not visible
                here. Confirm before relying on batching.

        Returns:
            The output of the final layer.
        """
        hidden_layers, output_layer = self.layers[:-1], self.layers[-1]
        activation = x
        for hidden in hidden_layers:
            activation = jax.nn.tanh(hidden(activation))
        return output_layer(activation)
MLP.getAWB method · python · L95-L127 (33 LOC)
src/cl/models/mlp.py
    def getAWB(self, x: jax.Array) -> jax.Array:
        """Forward pass using AWB transformation: A @ W @ B.T.

        Uses AWBMixin.awb_transform_linear for efficient computation.
        Used during AWB training (Step 3b) when A/B matrices are being optimized
        while W is frozen. The effective weight becomes V = A @ W @ B.T.

        Args:
            x: Input tensor of shape (input_dim,) - note: expects unbatched input

        Returns:
            Output tensor after AWB transformation

        Raises:
            ValueError: If AWB is not enabled (A/B matrices are None)
        """
        # Note: We don't check awb_enabled here to allow JIT tracing.
        # The caller is responsible for ensuring A/B are properly initialized
        # before calling getAWB(). This is guaranteed by the AWB pipeline:
        # - partition_for_AB_training() is only called after with_new_AB_matrices()
        # - hamiltonian.py only calls getAWB when notABTrain=False

        for i in range(
MLP.get_awb_layer_specs method · python · L130-L163 (34 LOC)
src/cl/models/mlp.py
    def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
        """Get AWB layer specifications for each transformable layer.

        Returns:
            List of AWBLayerSpec, one per layer, containing layer, A, B matrices

        Example:
            >>> specs = model.get_awb_layer_specs()
            >>> for spec in specs:
            ...     errors = spec.validate()  # Check shape compatibility
        """
        if not self.awb_enabled or self.A is None or self.B is None:
            # If AWB not enabled, return specs with None A/B
            return [
                AWBLayerSpec(
                    layer=self.layers[i],
                    A=None,
                    B=None,
                    layer_type='linear',
                    layer_index=i
                )
                for i in range(len(self.layers))
            ]

        return [
            AWBLayerSpec(
                layer=self.layers[i],
                A=self.A[i] if self.A else None,
                B
MLP.apply_V_transformation method · python · L165-L199 (35 LOC)
src/cl/models/mlp.py
    def apply_V_transformation(self) -> 'MLP':
        """Apply V = A @ W @ B.T transformation to all layers.

        This is STEP 4 of the AWB algorithm. After training A/B matrices,
        compute the effective weights V and update the model.

        Returns:
            New model with transformed weights V = A @ W @ B.T

        Raises:
            ValueError: If AWB is not enabled

        Example:
            >>> # After Step 3b (A/B training)
            >>> model = model.apply_V_transformation()
            >>> # Now model has V weights, ready for Step 5 training
        """
        if not self.awb_enabled or self.A is None or self.B is None:
            raise ValueError(
                "apply_V_transformation requires awb_enabled=True and A/B matrices. "
                "Train A/B matrices first (Step 3b)."
            )

        model = self
        for i, spec in enumerate(self.get_awb_layer_specs()):
            layer = spec.layer
            # Use layer's AWB methods to
Open data scored by Repobility · https://repobility.com
MLP.partition_for_AB_training method · python · L201-L225 (25 LOC)
src/cl/models/mlp.py
    def partition_for_AB_training(self):
        """Partition model for A/B training (freeze W, train A/B).

        This is used in STEP 3b of AWB algorithm where we train A/B matrices
        while keeping layer weights W frozen.

        Returns:
            Tuple of (diff_model, static_model) where:
                - diff_model: Contains only A and B (trainable)
                - static_model: Contains layer weights (frozen)

        Example:
            >>> diff_model, static_model = model.partition_for_AB_training()
            >>> # Train diff_model (only A/B parameters)
            >>> model = eqx.combine(diff_model, static_model)
        """
        # Create filter spec: True for A/B, False for everything else
        filter_spec = jtu.tree_map(lambda _: False, self)
        filter_spec = eqx.tree_at(
            lambda x: (x.A, x.B),
            filter_spec,
            replace=(True, True)
        )
        diff_model, static_model = eqx.partition(self, filter_spec)
        
MLP.partition_for_standard_training method · python · L227-L259 (33 LOC)
src/cl/models/mlp.py
    def partition_for_standard_training(self):
        """Partition model for standard training (freeze A/B, train W).

        This is used in standard CL training and STEP 5 of AWB (train V with A/B frozen).

        Returns:
            Tuple of (params, static) where:
                - params: Contains trainable arrays (layer weights, biases)
                - static: Contains A, B matrices (frozen)

        Example:
            >>> params, static = model.partition_for_standard_training()
            >>> # Train params (only layer weights/biases)
            >>> model = eqx.combine(params, static)
        """
        params, static = eqx.partition(self, eqx.is_array)

        if self.awb_enabled and self.A is not None and self.B is not None:
            # Move A and B to static (frozen)
            static = eqx.tree_at(
                lambda x: (x.A, x.B),
                static,
                replace=(self.A, self.B)
            )

            # Remove A and B from params (set 
MLP.with_new_AB_matrices method · python · L261-L299 (39 LOC)
src/cl/models/mlp.py
    def with_new_AB_matrices(self, original_arch: List[int], new_arch: List[int], seed: int = 5) -> 'MLP':
        """Initialize A/B matrices for architecture transition.

        This is used in STEP 3a of AWB algorithm when changing from original
        architecture to new architecture.

        Args:
            original_arch: Original sizes [in, h1, h2, ..., out]
            new_arch: New sizes [in, h1', h2', ..., out]
            seed: Random seed for initialization

        Returns:
            Model with new A/B matrices initialized

        Example:
            >>> # After architecture search finds optimal size
            >>> model = model.with_new_AB_matrices([3, 10, 10, 2], [3, 15, 20, 2])
            >>> # Now A/B matrices are ready for Step 3b training
        """
        initializer = jax.nn.initializers.glorot_uniform()

        # A matrices: transform output dimensions [new_out, old_out]
        A_list = [
            initializer(jax.random.PRNGKey(seed + i), (y_new, y
‹ prev · page 5 / 6 · next ›