← back to krm9c__ContLearn

Function bodies 268 total

All specs Real LLM only Function bodies
benchmark_dataloader function · python · L192-L241 (50 LOC)
src/cl/datasets/jax_dataloader.py
def benchmark_dataloader(loader, num_batches: int = 100, warmup: int = 10):
    """Benchmark data loading throughput.

    Measures batches/second and identifies bottlenecks.

    Args:
        loader: DataLoader to benchmark (PyTorch or JAX-wrapped)
        num_batches: Number of batches to measure
        warmup: Number of warmup batches (ignored in timing)

    Returns:
        dict with 'batches_per_sec', 'samples_per_sec', 'avg_batch_time_ms'
    """
    import time

    # Warmup
    iter_loader = iter(loader)
    for _ in range(warmup):
        try:
            _ = next(iter_loader)
        except StopIteration:
            iter_loader = iter(loader)
            _ = next(iter_loader)

    # Benchmark
    start_time = time.perf_counter()
    total_samples = 0

    for i in range(num_batches):
        try:
            batch = next(iter_loader)
            # Get batch size
            if isinstance(batch, tuple):
                x = batch[0]
            else:
                x = bat
get_optimal_prefetch_size function · python · L244-L261 (18 LOC)
src/cl/datasets/jax_dataloader.py
def get_optimal_prefetch_size(batch_time_ms: float, data_load_time_ms: float) -> int:
    """Recommend a prefetch queue depth from per-batch timings.

    The recommendation is ``ceil(data_load_time_ms / batch_time_ms) + 1``,
    clamped so it is never smaller than 2. The extra slot keeps one batch
    ready while the others are still loading.

    Args:
        batch_time_ms: Time to process one batch on the accelerator (ms).
        data_load_time_ms: Time to load one batch (ms).

    Returns:
        Recommended prefetch_size (always >= 2). A non-positive
        ``batch_time_ms`` yields the minimum of 2.
    """
    from math import ceil

    minimum_depth = 2

    # No meaningful ratio can be formed without a positive compute time.
    if batch_time_ms <= 0:
        return minimum_depth

    loads_per_step = ceil(data_load_time_ms / batch_time_ms)
    suggested_depth = loads_per_step + 1
    return suggested_depth if suggested_depth > minimum_depth else minimum_depth
MNISTDataset.__init__ method · python · L49-L78 (30 LOC)
src/cl/datasets/mnist.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the MNIST dataset.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        self._n_tasks = config.get('n_task', 5)
        self.rotation_range = config.get('rotation_range', DEFAULT_ROTATION_RANGE)
        self.scaling_range = config.get('scaling_range', DEFAULT_SCALING_RANGE)
        self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)

        # Load MNIST dataset
        print("Loading MNIST dataset")
        my_transforms = transforms.Compose([transforms.ToTensor()])
        self.dataset = torchvision.datasets.MNIST(
            './data', train=True, download=True, transform=my_transforms
        )

        # Extract images and labels
        [self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
        self.images = torch.stack(self.images, dim=0)
        self.labels = np.array(self.labels)

        # Apply debug limit 
MNISTDataset._load_task_data method · python · L80-L118 (39 LOC)
src/cl/datasets/mnist.py
    def _load_task_data(self, task_id: int) -> None:
        """Load data for a specific MNIST task with transforms.

        Applies rotation and scaling transforms based on task_id to create
        distribution shift between tasks.

        Args:
            task_id: Task identifier (0-indexed)
        """
        # Set deterministic seed for reproducible task generation
        # Critical for accurate CL metrics - ensures task data is identical during training and evaluation
        np.random.seed(task_id * DEFAULT_PERMUTATION_SEED_MULTIPLIER)

        X = self.images.clone()
        y = self.labels.copy()

        # Apply task-specific transformations
        rot_angle = np.random.random() * self.rotation_range
        scaling_min, scaling_max = self.scaling_range
        scaling = np.random.random() * (scaling_max - scaling_min) + scaling_min

        X = torchvision.transforms.functional.affine(
            X, rot_angle,
            translate=(scaling, scaling),
            scal
PermutedMNISTDataset.__init__ method · python · L157-L186 (30 LOC)
src/cl/datasets/mnist.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the Permuted MNIST dataset.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        self._n_tasks = config.get('n_task', 10)
        self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
        self.seed_multiplier = config.get('permutation_seed_multiplier', DEFAULT_PERMUTATION_SEED_MULTIPLIER)
        self.image_size = config.get('image_size', DEFAULT_INPUT_SIZE_MNIST)

        # Load MNIST dataset
        print("Loading Permuted MNIST dataset")
        my_transforms = transforms.Compose([transforms.ToTensor()])
        self.dataset = torchvision.datasets.MNIST(
            './data', train=True, download=True, transform=my_transforms
        )

        # Extract images and labels
        [self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
        self.images = torch.stack(self.images, dim=0)
        self.labels = np.array
PermutedMNISTDataset._load_task_data method · python · L188-L218 (31 LOC)
src/cl/datasets/mnist.py
    def _load_task_data(self, task_id: int) -> None:
        """Load data for a specific permutation task.

        Applies a task-specific permutation to all pixels.

        Args:
            task_id: Task identifier (0-indexed)
        """
        X = self.images.clone()
        y = self.labels.copy()

        # Generate task-specific permutation
        rng = np.random.RandomState(seed=task_id * self.seed_multiplier)
        perm = rng.permutation(self.image_size * self.image_size)

        # Flatten, permute, reshape
        # X shape: (N, 1, 28, 28) -> flatten to (N, 784) -> permute -> reshape back
        X = X.view(X.shape[0], -1)[:, perm].view(X.shape[0], 1, self.image_size, self.image_size)

        # Use deterministic train/test split for reproducibility
        # Use the same RNG state for consistent splits across training and evaluation
        n_samples = X.shape[0]
        n_train = int(self.train_split * n_samples)

        train_idx = rng.randint(0, n_samples, n_train)
generate_sine_data function · python · L30-L75 (46 LOC)
src/cl/datasets/sine.py
def generate_sine_data(delta: float, n_tasks: int = 40, output_path: str = 'data/Incremental_Sine1e^4.p',
                       seed: int = 1) -> str:
    """Generate sine data for continual learning tasks.

    Creates a pickle file containing sine wave data for multiple tasks.
    Each task has gradually increasing frequency and amplitude.

    Args:
        delta: Perturbation value for gradual task drift
        n_tasks: Number of tasks to generate (default: 40)
        output_path: Path to save the pickle file
        seed: Random seed for reproducibility

    Returns:
        Path to the generated pickle file

    Data format per task:
        (y, time, phase, amplitude, frequency) where:
        - y: Sine wave values, shape (n_samples, n_time_points)
        - time: Time points array
        - phase: Phase values, shape (n_samples, 1)
        - amplitude: Amplitude values, shape (n_samples, 1)
        - frequency: Frequency values, shape (n_samples, 1)
    """
    # Added by Cl
Repobility · code-quality intelligence · https://repobility.com
SineDataset.__init__ method · python · L103-L149 (47 LOC)
src/cl/datasets/sine.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the sine dataset.

        Args:
            config: Configuration dictionary
        """
        super().__init__(config)

        self.delta = config.get('delta', 0.00001)
        self.data_path = config.get('data_path', 'data/Incremental_Sine1e^4.p')
        self._n_tasks = config.get('n_task', 40)
        self.test_size = config.get('test_size', 0.2)

        # Noise parameters for Experiment 3 (increasing noise per task)
        self.noise_enabled = config.get('noise_enabled', False)
        self.noise_scale = config.get('noise_scale', 0.1)
        self.noise_increment = config.get('noise_increment', 0.05)

        # Generate data if file doesn't exist
        if not os.path.exists(self.data_path):
            # Added by Claude: ensure parent directory exists before generating data
            data_dir = os.path.dirname(self.data_path)
            if data_dir and not os.path.exists(data_dir):
                os.
SineDataset._load_task_data method · python · L151-L186 (36 LOC)
src/cl/datasets/sine.py
    def _load_task_data(self, task_id: int) -> None:
        """Load data for a specific sine task.

        Extracts sine data for the given task and creates train/test splits.
        Features: [phase, amplitude, frequency]
        Target: sine wave values (flattened)

        Args:
            task_id: Task identifier (0-indexed)
        """
        if task_id >= self._available_tasks:
            raise ValueError(f"Task {task_id} not available. Max task: {self._available_tasks - 1}")

        # Extract task data
        y, time, phase, amplitude, frequency = self.raw_data['task' + str(task_id)]

        # Create feature matrix: [phase, amplitude, frequency]
        X = np.concatenate([phase, amplitude.reshape([-1, 1]), frequency.reshape([-1, 1])], axis=1)

        # Flatten y if needed (shape: n_samples x n_time_points -> n_samples)
        # For regression, we typically predict all time points
        # But original code treats y as target directly
        y = y.astype(np.float32)
BaseGraphDataset.__init__ method · python · L54-L85 (32 LOC)
src/cl/datasets/synthetic_graph.py
    def __init__(self, config: Dict[str, Any]):
        """Initialize the graph dataset.

        Args:
            config: Configuration dictionary containing:
                - batch_size: Batch size for DataLoaders
                - n_class: Number of classes for task sampling
                - class_per_task: Number of classes per task
                - debug_mode: (optional) Enable debug mode
                - debug_limit: (optional) Number of samples in debug mode
        """
        self.config = config
        self.batch_size = config.get('batch_size', DEFAULT_BATCH_SIZE_GRAPH)
        self.debug_mode = config.get('debug_mode', False)
        self.debug_limit = config.get('debug_limit', 100)

        # Graph datasets (to be set by subclass)
        self.dataset = None
        self.train_data = None
        self.test_data = None

        # Experience replay buffer (list of graph data objects)
        self.memory_train: List = []
        # Added by Claude: Separate test buffer fo
BaseGraphDataset._load_dataset method · python · L88-L96 (9 LOC)
src/cl/datasets/synthetic_graph.py
    def _load_dataset(self) -> None:
        """Load the graph dataset.

        Subclasses must implement this to populate:
        - self.dataset: Full dataset
        - self.train_data: Training split
        - self.test_data: Test split
        """
        pass
BaseGraphDataset.generate_dataset method · python · L120-L184 (65 LOC)
src/cl/datasets/synthetic_graph.py
    def generate_dataset(self, task_id: int, batch_size: int = None,
                         phase: str = 'training') -> Tuple[DataLoader, DataLoader]:
        """Generate train/memory dataloaders for a task.

        Implements continuum_Graph_classification logic with deterministic task selection.
        Uses persistent task-to-class mapping to ensure reproducibility for CL metrics.

        Args:
            task_id: Current task ID (0-indexed)
            batch_size: Batch size for DataLoader
            phase: 'training' or 'testing'

        Returns:
            Tuple of (current_loader, memory_loader)
        """
        if batch_size is None:
            batch_size = self.batch_size

        n_class = self.config.get('n_class', self.num_classes)
        select = self.config.get('class_per_task', 2)

        # Fixed by Claude: Restore proper task-based class selection
        # Use persistent mapping to ensure reproducibility across CL evaluation
        if task_id not in self
BaseGraphDataset.get_test_loader method · python · L186-L198 (13 LOC)
src/cl/datasets/synthetic_graph.py
    def get_test_loader(self, batch_size: int = None) -> DataLoader:
        """Build a shuffled DataLoader over the held-out test split.

        Args:
            batch_size: Samples per batch; when None, ``self.batch_size``
                is used instead.

        Returns:
            DataLoader wrapping ``self.test_data`` with shuffling enabled.
        """
        effective_batch_size = self.batch_size if batch_size is None else batch_size
        return DataLoader(
            self.test_data,
            batch_size=effective_batch_size,
            shuffle=True,
        )
BaseGraphDataset.generate_test_loader method · python · L200-L218 (19 LOC)
src/cl/datasets/synthetic_graph.py
    def generate_test_loader(self, task_id: int, batch_size: int = None) -> DataLoader:
        """Return the test-phase loader for one task.

        Delegates to ``generate_dataset`` with ``phase='testing'`` and keeps
        only the first element of the returned pair. Intended for per-task
        evaluation, e.g. filling a performance matrix A[j][i] (accuracy on
        task i after training task j).

        Args:
            task_id: Task ID to generate test loader for
            batch_size: Batch size for DataLoader; falls back to
                ``self.batch_size`` when None

        Returns:
            DataLoader for task-specific test data
        """
        effective_batch_size = self.batch_size if batch_size is None else batch_size

        # generate_dataset yields (current_loader, memory_loader); only the
        # current-task loader is relevant for evaluation.
        task_loader, _memory_loader = self.generate_dataset(
            task_id, effective_batch_size, phase='testing'
        )
        return task_loader
BaseGraphDataset.append_to_experience method · python · L220-L226 (7 LOC)
src/cl/datasets/synthetic_graph.py
    def append_to_experience(self, task_id: int) -> None:
        """Do nothing; present only for interface compatibility.

        For graph datasets, experience replay is handled directly in
        generate_dataset by appending to memory_train, so there is nothing
        left to do here.

        Args:
            task_id: Task identifier (unused).
        """
        return None
Generated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
BaseGraphDataset.get_model_config method · python · L228-L239 (12 LOC)
src/cl/datasets/synthetic_graph.py
    def get_model_config(self) -> Dict[str, Any]:
        """Describe the dataset dimensions a model needs at construction time.

        Returns:
            Dictionary where ``input_size`` and ``num_features`` both carry
            ``self.num_features``, and ``output_size`` and ``num_classes``
            both carry ``self.num_classes``.
        """
        feature_dim = self.num_features
        class_count = self.num_classes

        model_config = {}
        model_config['input_size'] = feature_dim
        model_config['output_size'] = class_count
        model_config['num_features'] = feature_dim
        model_config['num_classes'] = class_count
        return model_config
SyntheticGraphDataset._load_dataset method · python · L260-L296 (37 LOC)
src/cl/datasets/synthetic_graph.py
    def _load_dataset(self) -> None:
        """Load synthetic graph dataset."""
        # Get dataset parameters from config
        num_graphs = self.config.get('num_graphs', DEFAULT_SYNTHETIC_NUM_GRAPHS)
        num_channels = self.config.get('num_channels', DEFAULT_SYNTHETIC_NUM_CHANNELS)
        avg_num_nodes = self.config.get('avg_num_nodes', DEFAULT_SYNTHETIC_AVG_NUM_NODES)
        num_classes = self.config.get('num_classes', DEFAULT_SYNTHETIC_NUM_CLASSES)

        # Set seed for reproducibility
        torch_geometric.seed.seed_everything(DEFAULT_GRAPH_SEED)

        # Create synthetic dataset and shuffle (matching old code behavior)
        self.dataset = FakeDataset(
            num_graphs=num_graphs,
            num_channels=num_channels,
            avg_num_nodes=avg_num_nodes,
            num_classes=num_classes,
            transform=_transform_graph
        ).shuffle()

        # Apply debug limit if enabled
        if self.debug_mode:
            print(f"DEBUG MODE: Lim
TaskShiftGraphDataset.__init__ method · python · L350-L369 (20 LOC)
src/cl/datasets/synthetic_graph.py
    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        # Task-shift parameters
        self.task_shift_enabled = config.get('task_shift_enabled', DEFAULT_SYNTHETIC_TASK_SHIFT_ENABLED)
        self.feature_noise_base = config.get('feature_noise_base', DEFAULT_SYNTHETIC_FEATURE_NOISE_BASE)
        self.edge_dropout_base = config.get('edge_dropout_base', DEFAULT_SYNTHETIC_EDGE_DROPOUT_BASE)
        self.feature_shift_base = config.get('feature_shift_base', DEFAULT_SYNTHETIC_FEATURE_SHIFT_BASE)
        self.n_tasks = config.get('n_task', 5)

        # Added by Claude: Perturbation control modes
        self.perturbation_mode = config.get('perturbation_mode', 'linear')
        self.perturbation_growth_rate = config.get('perturbation_growth_rate', 1.5)
        self.perturbation_step_size = config.get('perturbation_step_size', 2)
        self.custom_perturbations = config.get('custom_perturbations', {})

        # Fixed by Claude: Store base train/test splits
TaskShiftGraphDataset._load_dataset method · python · L371-L426 (56 LOC)
src/cl/datasets/synthetic_graph.py
    def _load_dataset(self) -> None:
        """Load synthetic graph dataset with FIXED train/test split."""
        # Get dataset parameters from config
        num_graphs = self.config.get('num_graphs', DEFAULT_SYNTHETIC_NUM_GRAPHS)
        num_channels = self.config.get('num_channels', DEFAULT_SYNTHETIC_NUM_CHANNELS)
        avg_num_nodes = self.config.get('avg_num_nodes', DEFAULT_SYNTHETIC_AVG_NUM_NODES)
        num_classes = self.config.get('num_classes', DEFAULT_SYNTHETIC_NUM_CLASSES_PER_TASK)

        # Set seed for reproducibility
        torch_geometric.seed.seed_everything(DEFAULT_GRAPH_SEED)

        # Create synthetic dataset and shuffle ONCE
        self.dataset = FakeDataset(
            num_graphs=num_graphs,
            num_channels=num_channels,
            avg_num_nodes=avg_num_nodes,
            num_classes=num_classes,
            transform=_transform_graph
        ).shuffle()

        # Apply debug limit if enabled
        if self.debug_mode:
            print(f"DE
TaskShiftGraphDataset._get_perturbation_params method · python · L428-L483 (56 LOC)
src/cl/datasets/synthetic_graph.py
    def _get_perturbation_params(self, task_id: int) -> Tuple[float, float, float]:
        """Get perturbation parameters for a specific task based on mode.

        Added by Claude: Flexible perturbation control for different experimental designs.

        Args:
            task_id: Current task ID

        Returns:
            Tuple of (feature_noise_std, edge_dropout_prob, feature_shift)
        """
        # Task 0 always has no perturbation (baseline)
        if task_id == 0:
            return 0.0, 0.0, 0.0

        if self.perturbation_mode == 'linear':
            # Linear scaling: perturbation = task_id * base
            noise_std = task_id * self.feature_noise_base
            dropout_prob = task_id * self.edge_dropout_base
            shift = task_id * self.feature_shift_base

        elif self.perturbation_mode == 'exponential':
            # Exponential scaling: perturbation = base * (rate ^ task_id)
            rate = self.perturbation_growth_rate
            noise_std 
TaskShiftGraphDataset._apply_task_perturbation method · python · L485-L546 (62 LOC)
src/cl/datasets/synthetic_graph.py
    def _apply_task_perturbation(self, data_list: List, task_id: int, seed: int = None) -> List:
        """Apply task-specific perturbations to graph data.

        Uses per-graph deterministic seeding so that the same graph always gets
        the same perturbation pattern regardless of whether it's in train or test.
        This fixes the train/test generalization gap caused by different random seeds.

        Args:
            data_list: List of graph data objects
            task_id: Current task ID (0 = no perturbation)
            seed: Base random seed (combined with graph index for per-graph seeding)

        Returns:
            List of perturbed graph data objects
        """
        import copy
        import torch

        base_seed = seed if seed is not None else DEFAULT_GRAPH_SEED

        # Get perturbation parameters based on mode
        feature_noise_std, edge_dropout_prob, feature_shift = self._get_perturbation_params(task_id)

        # Task 0 has no perturbation -
TaskShiftGraphDataset._get_or_generate_task_data method · python · L548-L577 (30 LOC)
src/cl/datasets/synthetic_graph.py
    def _get_or_generate_task_data(self, task_id: int) -> Tuple[List, List]:
        """Generate or retrieve cached train/test data for a task.

        Fixed by Claude: Now applies perturbations to FIXED train/test splits
        instead of shuffling and re-splitting per task. This prevents data leakage.

        Args:
            task_id: Task ID

        Returns:
            Tuple of (train_data, test_data) for this task
        """
        # Check if already generated
        if task_id in self._task_train_data and task_id in self._task_test_data:
            return self._task_train_data[task_id], self._task_test_data[task_id]

        # Fixed by Claude: Apply perturbations to FIXED train and test sets SEPARATELY
        # This ensures the same underlying graphs are always in train or test
        train_data = self._apply_task_perturbation(
            self._base_train_data, task_id, seed=DEFAULT_GRAPH_SEED
        )
        test_data = self._apply_task_perturbation(
            se
TaskShiftGraphDataset.generate_dataset method · python · L579-L625 (47 LOC)
src/cl/datasets/synthetic_graph.py
    def generate_dataset(self, task_id: int, batch_size: int = None,
                         phase: str = 'training') -> Tuple[DataLoader, DataLoader]:
        """Generate train/memory dataloaders for a task with domain shift.

        Each task uses the same classes but with task-specific perturbations.
        Memory buffer accumulates perturbed data from all seen tasks.

        Args:
            task_id: Current task ID (0-indexed)
            batch_size: Batch size for DataLoader
            phase: 'training' or 'testing'

        Returns:
            Tuple of (current_loader, memory_loader)
        """
        if batch_size is None:
            batch_size = self.batch_size

        # Track if this is first time generating this task's data
        is_new_task = task_id not in self._task_train_data

        # Get or generate train/test data for this task
        train_data, test_data = self._get_or_generate_task_data(task_id)

        # Select appropriate data based on phase
     
Repobility · open methodology · https://repobility.com/research/
TaskShiftGraphDataset.get_perturbation_schedule method · python · L627-L643 (17 LOC)
src/cl/datasets/synthetic_graph.py
    def get_perturbation_schedule(self) -> Dict[int, Dict[str, float]]:
        """Tabulate the perturbation parameters of every task.

        Queries ``_get_perturbation_params`` once per task in
        ``range(self.n_tasks)``; handy for logging and visualization.

        Returns:
            Dict mapping task_id -> {noise, dropout, shift}
        """
        def _describe(task_id):
            # Strict 3-way unpack: the helper must return exactly three values.
            noise, dropout, shift = self._get_perturbation_params(task_id)
            return {'noise': noise, 'dropout': dropout, 'shift': shift}

        return {task_id: _describe(task_id) for task_id in range(self.n_tasks)}
TUGraphDataset._load_dataset method · python · L671-L701 (31 LOC)
src/cl/datasets/synthetic_graph.py
    def _load_dataset(self) -> None:
        """Load TU graph dataset."""
        data_name = self.config.get('data', 'MUTAG')

        # Set seed for reproducibility
        torch_geometric.seed.seed_everything(DEFAULT_GRAPH_SEED)

        # Load TU Dataset
        self.dataset = TUDataset(
            root='data/TUDataset',
            name=data_name,
            transform=_transform_graph
        ).shuffle()

        # Apply debug limit if enabled
        if self.debug_mode:
            print(f"DEBUG MODE: Limiting {data_name} data from {len(self.dataset)} to {self.debug_limit} samples")
            self.dataset = self.dataset[:self.debug_limit]

        # Split into train/test
        length = len(self.dataset)
        train_split = self.config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
        self.train_data = self.dataset[:int(train_split * length)]
        self.test_data = self.dataset[int(train_split * length):]

        print(f'Dataset: {data_name}')
        print('===
load_graph_dataset function · python · L714-L734 (21 LOC)
src/cl/datasets/synthetic_graph.py
def load_graph_dataset(config: Dict[str, Any]) -> BaseGraphDataset:
    """Instantiate the graph dataset selected by ``config['data']``.

    Args:
        config: Configuration dictionary; the 'data' key picks the dataset
            and defaults to 'synthetic' when absent. The same dictionary is
            forwarded to the chosen dataset's constructor.

    Returns:
        Graph dataset instance (SyntheticGraphDataset, TaskShiftGraphDataset, or TUGraphDataset)

    Raises:
        ValueError: If the requested dataset name is not supported.
    """
    requested = config.get('data', 'synthetic')

    if requested == 'synthetic':
        return SyntheticGraphDataset(config)

    # Task-shift synthetic dataset for domain-shift continual learning.
    if requested == 'synthetic_taskshift':
        return TaskShiftGraphDataset(config)

    # TU benchmark datasets share a single loader class.
    if requested in ('MUTAG', 'ENZYMES', 'PROTEINS'):
        return TUGraphDataset(config)

    raise ValueError(f"Unknown graph dataset: {requested}. "
                     f"Supported: 'synthetic', 'synthetic_taskshift', 'MUTAG', 'ENZYMES', 'PROTEINS'")
CNNorig.__init__ method · python · L34-L45 (12 LOC)
src/cl/models/cnn.py
    def __init__(self, key):
        """Create one convolutional layer followed by a three-layer MLP head.

        Args:
            key: JAX PRNG key; split four ways so each layer gets its own
                subkey for weight initialization.
        """
        conv_key, fc1_key, fc2_key, fc3_key = jax.random.split(key, 4)

        # Single conv stage: 1 input channel -> 3 output channels, 4x4 kernel.
        self.conv_layers = [eqx.nn.Conv2d(1, 3, kernel_size=4, key=conv_key)]

        # MLP on the flattened conv features, ending in 10 outputs.
        # NOTE(review): 1728 = 3 * 24 * 24 — confirm this matches the flattened
        # feature count produced by the pooling used in __call__.
        dense_shapes = ((1728, 512), (512, 64), (64, 10))
        dense_keys = (fc1_key, fc2_key, fc3_key)
        self.feed_layers = [
            eqx.nn.Linear(n_in, n_out, key=layer_key)
            for (n_in, n_out), layer_key in zip(dense_shapes, dense_keys)
        ]
CNNorig.__call__ method · python · L47-L52 (6 LOC)
src/cl/models/cnn.py
    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Compute the 10 output logits for a single 1x28x28 image.

        Pipeline: conv -> 2x2 max-pool -> ReLU -> flatten -> two ReLU-activated
        linear layers -> final linear layer (no activation).

        Args:
            x: One unbatched image with a leading channel axis.

        Returns:
            Raw (un-normalized) outputs of the last linear layer.
        """
        # The 4x4 conv maps 1x28x28 -> 3x25x25. The pool must slide with
        # stride 1 so the result is 3x24x24 = 1728 features, which is the
        # input width of feed_layers[0] (Linear(1728, 512)). The previous
        # stride of 2 produced 3x12x12 = 432 features and therefore a shape
        # mismatch in the first linear layer.
        pool = eqx.nn.MaxPool2d(kernel_size=2, stride=1)
        x = jnp.ravel(jax.nn.relu(pool(self.conv_layers[0](x))))
        x = jax.nn.relu(self.feed_layers[0](x))
        x = jax.nn.relu(self.feed_layers[1](x))
        x = self.feed_layers[2](x)
        return x
CNN.__init__ method · python · L74-L143 (70 LOC)
src/cl/models/cnn.py
    def __init__(self, key, filter_size, feed_sizes,
                 input_size=None,
                 channel_in=1,
                 channel_out=None,
                 awb_arch=None,
                 awb_filter_size=None,
                 padding=None,
                 stride=None):
        """
        Args:
            key: PRNG key
            filter_size: Convolutional filter size
            feed_sizes: List of feed-forward layer sizes
            input_size: Input image size (default: 28 for MNIST/Omniglot)
            channel_in: Number of input channels (default: 1)
            channel_out: Number of output channels (default: 3)
            awb_arch: AWB architecture (default: [1875, 700, 100, 10])
            awb_filter_size: AWB filter size (default: 5)
            padding: Convolution padding (default: 0)
            stride: Convolution stride (default: 1)
        """
        key1, key2, key3, key4 = jax.random.split(key, 4)

        # Set defaults from constants
        se
CNN.calc_output_size method · python · L145-L150 (6 LOC)
src/cl/models/cnn.py
    def calc_output_size(self, fil_size, input_size=None):
        """Return the side length of a convolution's output, truncated to int.

        Evaluates ``(input_size - fil_size + 2 * padding) / stride + 1`` with
        the instance's ``padding`` and ``stride``.

        Args:
            fil_size: Convolution filter (kernel) size.
            input_size: Input side length; ``self.input_size`` when None.
        """
        side = self.input_size if input_size is None else input_size
        padded_span = side - fil_size + 2 * self.padding
        return int(padded_span / self.stride + 1)
CNN.pool_output_size method · python · L152-L157 (6 LOC)
src/cl/models/cnn.py
    def pool_output_size(self, pool_size, conv_inputsize, pool_stride=None):
        """Return the side length of a pooling layer's output, truncated to int.

        Evaluates ``(conv_inputsize - pool_size) / pool_stride + 1``.

        Args:
            pool_size: Pooling window size.
            conv_inputsize: Side length of the feature map being pooled.
            pool_stride: Pooling stride; DEFAULT_POOL_STRIDE when None.
        """
        step = DEFAULT_POOL_STRIDE if pool_stride is None else pool_stride
        return int((conv_inputsize - pool_size) / step + 1)
Repobility (the analyzer behind this table) · https://repobility.com
CNN.__call__ method · python · L159-L168 (10 LOC)
src/cl/models/cnn.py
    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Run the standard forward pass for a single image.

        Pipeline: first conv layer -> 2x2 max-pool (stride 2) -> ReLU ->
        flatten, then every feed layer except the last with ReLU, and the
        last feed layer without an activation.

        Args:
            x: One unbatched image with a leading channel axis.

        Returns:
            Output of the final feed layer.
        """
        pool = eqx.nn.MaxPool2d(kernel_size=2, stride=2)
        conv_out = self.conv_layers[0](x)
        features = jnp.ravel(jax.nn.relu(pool(conv_out)))

        # Hidden layers are ReLU-activated; the output layer stays linear.
        for hidden_layer in self.feed_layers[:-1]:
            features = jax.nn.relu(hidden_layer(features))

        output_layer = self.feed_layers[-1]
        return output_layer(features)
CNN.get_AWBT method · python · L175-L205 (31 LOC)
src/cl/models/cnn.py
    def get_AWBT(self, x):
        """Forward pass using AWB transformation.

        Uses AWBMixin.awb_transform_conv for efficient batched computation.
        Single einsum call replaces nested Python loops, enabling JIT optimization.
        """
        # AWB transformation on conv layer using AWBMixin
        # For single-channel input (MNIST), conv weight is (out_ch, 1, H, W)
        # We transform each output channel: A[i] @ W[i,0] @ B[i].T
        # Using stacked arrays: A_conv (out_ch, new_f, old_f), W[:, 0] (out_ch, H, W)
        conv_weights = self.conv_layers[0].weight[:, 0]  # (channel_out, H, W)
        weights_transformed = self.awb_transform_conv(self.A_conv, conv_weights, self.B_conv)
        # Add back the channel_in dimension for convolution
        weights_transformed = weights_transformed[:, jnp.newaxis, :, :]  # (out_ch, 1, new_H, new_W)

        x = jnp.expand_dims(x, axis=0)
        x = jax.lax.conv_general_dilated(lhs=x, rhs=weights_transformed, window_strides=
CNN.get_awb_layer_specs method · python · L208-L214 (7 LOC)
src/cl/models/cnn.py
    def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
        """Build one AWB spec per feed layer (the conv layer is handled separately).

        Returns:
            List of AWBLayerSpec, in feed-layer order, each pairing a feed
            layer with the ``A_feed``/``B_feed`` entries at the same index
            and tagged with layer_type 'linear2'.
        """
        specs = []
        for idx, feed_layer in enumerate(self.feed_layers):
            specs.append(
                AWBLayerSpec(
                    layer=feed_layer,
                    A=self.A_feed[idx],
                    B=self.B_feed[idx],
                    layer_type='linear2',
                    layer_index=idx,
                )
            )
        return specs
CNN.partition_for_AB_training method · python · L216-L221 (6 LOC)
src/cl/models/cnn.py
    def partition_for_AB_training(self):
        """Split the model for A/B training (freeze W, train A/B).

        Builds an all-False filter spec shaped like the model, sets the
        ``A_conv``, ``B_conv``, ``A_feed`` and ``B_feed`` entries to True,
        and partitions the model with that spec.

        Returns:
            The pair returned by ``eqx.partition(self, filter_spec)``.
        """
        def _awb_entries(model):
            return (model.A_conv, model.B_conv, model.A_feed, model.B_feed)

        nothing_selected = jtu.tree_map(lambda _leaf: False, self)
        awb_selected = eqx.tree_at(
            _awb_entries,
            nothing_selected,
            replace=(True,) * 4,
        )
        return eqx.partition(self, awb_selected)
CNN.partition_for_standard_training method · python · L223-L252 (30 LOC)
src/cl/models/cnn.py
    def partition_for_standard_training(self):
        """Partition for standard training (freeze A/B, train W).

        A_conv, B_conv are stacked arrays, A_feed, B_feed are lists of arrays.
        Uses eqx.is_array to properly separate arrays from non-arrays (ints, etc.),
        then moves A/B matrices to static for freezing.
        """
        # Fixed by Claude: Use eqx.is_array to properly handle non-array leaves (ints)
        # Previous approach with tree_map(lambda _: True, self) put ints in params
        # which caused jax.grad to fail with "int64 not supported" error
        params, static = eqx.partition(self, eqx.is_array)

        # Move A/B matrices from params to static (freeze them)
        # Use is_leaf=lambda x: x is None to handle None values in static
        static = eqx.tree_at(
            lambda x: (x.A_conv, x.B_conv, x.A_feed, x.B_feed),
            static,
            replace=(self.A_conv, self.B_conv, self.A_feed, self.B_feed),
            is_leaf=lambda
CNN.generate_search_candidates method · python · L255-L318 (64 LOC)
src/cl/models/cnn.py
    def generate_search_candidates(self, iteration, current_best, config):
        """Generate candidate architectures for CNN search.

        # Added by Claude: Model-specific search strategy for CNN
        CNN searches over filter_size and feed hidden dimensions.
        Returns list of (filter_size, feed_sizes) tuples.

        Args:
            iteration: Current search iteration (0-indexed)
            current_best: Tuple of (filter_size, feed_sizes) for current best
            config: Configuration dict with search hyperparameters

        Returns:
            List of (filter_size, feed_sizes) candidate tuples
        """
        from ..config.constants import (
            DEFAULT_ARCH_SEARCH_HIDDEN_RANGE,
            DEFAULT_ARCH_SEARCH_FILTER_MIN,
            DEFAULT_ARCH_SEARCH_FILTER_MAX,
            DEFAULT_ARCH_SEARCH_STEP_SIZE_MLP,
        )

        # Get search hyperparameters
        search_hidden_range = config.get('arch_search_hidden_range', DEFAULT_ARCH_SEARCH_HI
CNN.create_with_architecture method · python · L320-L346 (27 LOC)
src/cl/models/cnn.py
    def create_with_architecture(self, arch_spec, seed=0, awb_enabled=True):
        """Build a fresh CNN for a candidate architecture.

        The searched dimensions come from ``arch_spec``; the remaining
        settings (input size, channel counts, padding, stride) are copied
        from the current instance.

        Args:
            arch_spec: Tuple of (filter_size, feed_sizes)
            seed: Integer seed used to derive the PRNG key for initialization
            awb_enabled: Accepted for interface compatibility; not used here

        Returns:
            New CNN instance with the requested architecture
        """
        filter_size, feed_sizes = arch_spec
        rng = jax.random.PRNGKey(seed)

        # Settings that are not part of the search are inherited from self.
        inherited = {
            field: getattr(self, field)
            for field in ("input_size", "channel_in", "channel_out",
                          "padding", "stride")
        }
        return CNN(
            key=rng,
            filter_size=filter_size,
            feed_sizes=feed_sizes,
            **inherited,
        )
CNN.reinitialize_weights method · python · L348-L372 (25 LOC)
src/cl/models/cnn.py
    def reinitialize_weights(self, seed=0):
        """Return a freshly initialized CNN with this instance's architecture.

        Nothing is modified in place: a new CNN is constructed from the
        current ``filter_size``, ``feed_sizes``, ``input_size``, channel
        counts, ``padding`` and ``stride``, using a PRNG key derived from
        ``seed``.

        Args:
            seed: Random seed for reproducible initialization

        Returns:
            New CNN with fresh weights (same architecture)
        """
        rng = jax.random.PRNGKey(seed)

        # Every constructor setting except the key is carried over from self.
        same_arch = {
            field: getattr(self, field)
            for field in ("filter_size", "feed_sizes", "input_size",
                          "channel_in", "channel_out", "padding", "stride")
        }
        return CNN(key=rng, **same_arch)
Repobility · code-quality intelligence · https://repobility.com
CNN3D.__init__ method · python · L396-L490 (95 LOC)
src/cl/models/cnn.py
    def __init__(self, key, filter_size, feed_sizes,
                 input_size=None,
                 channel_in=None,
                 channel_out=None,
                 num_classes=10,
                 awb_filter_increment=None,
                 awb_hidden_layers=None):
        """
        Args:
            key: PRNG key
            filter_size: Convolutional filter size
            feed_sizes: List of feed-forward layer sizes
            input_size: Input image size (default: 32 for CIFAR)
            channel_in: Number of input channels (default: 3)
            channel_out: Number of output channels (default: 32)
            num_classes: Number of output classes (default: 10)
            awb_filter_increment: Increment for AWB filter size (default: 2)
            awb_hidden_layers: AWB hidden layer sizes (default: [512, 256])
        """
        key1, key2, key3, key4, key5 = jax.random.split(key, 5)

        # Set defaults from constants
        if input_size is None:
          
CNN3D.calc_output_size method · python · L492-L498 (7 LOC)
src/cl/models/cnn.py
    def calc_output_size(self, input_size, fil_size, pool_size=2):
        """Return the spatial size left after one convolution + pooling stage.

        The convolution output is taken as ``input_size - fil_size + 1`` and
        the pooling step floor-divides that by ``pool_size``.

        Args:
            input_size: Spatial size going into the convolution
            fil_size: Convolution filter size
            pool_size: Pooling window size (default: 2)

        Returns:
            Spatial size after convolution and pooling
        """
        return (input_size - fil_size + 1) // pool_size
CNN3D.__call__ method · python · L500-L513 (14 LOC)
src/cl/models/cnn.py
    def __call__(self, x: Float[Array, "3 32 32"]) -> Float[Array, "num_classes"]:
        """Standard forward pass (uses the layer weights directly).

        Applies two conv -> ReLU -> 2x2 max-pool stages, flattens the result,
        then runs the feed layers with ReLU after every layer except the
        last, whose output is returned as-is.
        """
        # The pooling layer is built once and reused for both stages.
        pool = eqx.nn.MaxPool2d(kernel_size=2, stride=2)

        h = x
        for stage in (0, 1):
            h = pool(jax.nn.relu(self.conv_layers[stage](h)))

        # Flatten before the feed-forward stack.
        h = jnp.ravel(h)
        for hidden in self.feed_layers[:-1]:
            h = jax.nn.relu(hidden(h))

        # No activation on the final layer.
        return self.feed_layers[-1](h)
CNN3D.get_AWBT method · python · L515-L555 (41 LOC)
src/cl/models/cnn.py
    def get_AWBT(self, x):
        """Forward pass using AWB transformation.

        Uses AWBMixin.awb_transform_conv for efficient batched computation.
        Single einsum call replaces nested Python loops, enabling JIT optimization.
        """
        # AWB transformation on first conv layer using AWBMixin
        # A_conv1, B_conv1: (channel_out, channel_in, new_f, old_f)
        # conv weight: (channel_out, channel_in, old_f, old_f)
        weights_transformed1 = self.awb_transform_conv(self.A_conv1, self.conv_layers[0].weight, self.B_conv1)

        x = jnp.expand_dims(x, axis=0)
        x = jax.lax.conv_general_dilated(lhs=x, rhs=weights_transformed1, window_strides=(1, 1), padding="VALID")
        x = x.squeeze(0)
        # Add conv1 bias - eqx.nn.Conv2d bias is (out_ch, 1, 1), broadcasts with (out_ch, H, W)
        x = x + self.conv_layers[0].bias
        x = jax.nn.relu(x)
        x = eqx.nn.MaxPool2d(kernel_size=2, stride=2)(x)

        # AWB transformation on second conv
CNN3D.get_awb_layer_specs method · python · L558-L564 (7 LOC)
src/cl/models/cnn.py
    def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
        """Return one AWBLayerSpec per feed layer (conv layers are not included).

        Each spec pairs ``feed_layers[i]`` with ``A_feed[i]`` and
        ``B_feed[i]`` and is tagged with layer type ``'linear2'`` and its
        index.
        """
        return [
            AWBLayerSpec(
                layer=feed_layer,
                A=self.A_feed[idx],
                B=self.B_feed[idx],
                layer_type='linear2',
                layer_index=idx,
            )
            for idx, feed_layer in enumerate(self.feed_layers)
        ]
CNN3D.partition_for_AB_training method · python · L566-L571 (6 LOC)
src/cl/models/cnn.py
    def partition_for_AB_training(self):
        """Split the model so that only the A/B matrices are selected for training.

        A boolean filter tree with the same structure as ``self`` is built in
        which every leaf is ``False``; the ``A_conv1``, ``B_conv1``,
        ``A_conv2``, ``B_conv2``, ``A_feed`` and ``B_feed`` subtrees are then
        replaced by ``True``.

        Returns:
            The result of ``eqx.partition(self, filter_spec)``: the part of
            the model selected by the filter (A/B) and the remainder (W and
            everything else).
        """
        def _ab_nodes(tree):
            # The six AWB subtrees that get flipped to "trainable".
            return (
                tree.A_conv1, tree.B_conv1,
                tree.A_conv2, tree.B_conv2,
                tree.A_feed, tree.B_feed,
            )

        # Begin with an all-False mask, then switch on only the A/B entries.
        all_frozen = jtu.tree_map(lambda _leaf: False, self)
        ab_only = eqx.tree_at(_ab_nodes, all_frozen, replace=(True,) * 6)
        return eqx.partition(self, ab_only)
CNN3D.partition_for_standard_training method · python · L573-L604 (32 LOC)
src/cl/models/cnn.py
    def partition_for_standard_training(self):
        """Partition for standard training (freeze A/B, train W).

        A_conv1, B_conv1, A_conv2, B_conv2 are now stacked 4D arrays.
        A_feed, B_feed are lists of arrays.
        Uses eqx.is_array to properly separate arrays from non-arrays (ints, etc.),
        then moves A/B matrices to static for freezing.
        """
        # Fixed by Claude: Use eqx.is_array to properly handle non-array leaves (ints)
        # Previous approach with tree_map(lambda _: True, self) put ints in params
        # which caused jax.grad to fail with "int64 not supported" error
        params, static = eqx.partition(self, eqx.is_array)

        # Move A/B matrices from params to static (freeze them)
        # Use is_leaf=lambda x: x is None to handle None values in static
        static = eqx.tree_at(
            lambda x: (x.A_conv1, x.B_conv1, x.A_conv2, x.B_conv2, x.A_conv, x.B_conv, x.A_feed, x.B_feed),
            static,
            replace=(
CNN3D.generate_search_candidates method · python · L607-L672 (66 LOC)
src/cl/models/cnn.py
    def generate_search_candidates(self, iteration, current_best, config):
        """Generate candidate architectures for CNN3D search.

        # Added by Claude: Model-specific search strategy for CNN3D
        CNN3D searches over filter_size and feed hidden dimensions (same as CNN).
        Returns list of (filter_size, feed_sizes) tuples.

        Args:
            iteration: Current search iteration (0-indexed)
            current_best: Tuple of (filter_size, feed_sizes) for current best
            config: Configuration dict with search hyperparameters

        Returns:
            List of (filter_size, feed_sizes) candidate tuples
        """
        from ..config.constants import (
            DEFAULT_ARCH_SEARCH_HIDDEN_RANGE,
            DEFAULT_ARCH_SEARCH_FILTER_MIN,
            DEFAULT_ARCH_SEARCH_FILTER_MAX,
            DEFAULT_ARCH_SEARCH_STEP_SIZE_MLP,
        )

        # Get search hyperparameters
        search_hidden_range = config.get('arch_search_hidden_range', DE
Generated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
CNN3D.create_with_architecture method · python · L674-L698 (25 LOC)
src/cl/models/cnn.py
    def create_with_architecture(self, arch_spec, seed=0, awb_enabled=True):
        """Build a fresh CNN3D for a candidate architecture.

        The searched dimensions come from ``arch_spec``; ``input_size``,
        ``channel_in`` and ``channel_out`` are copied from the current
        instance.

        Args:
            arch_spec: Tuple of (filter_size, feed_sizes)
            seed: Integer seed used to derive the PRNG key for initialization
            awb_enabled: Accepted for interface compatibility; not used here

        Returns:
            New CNN3D instance with the requested architecture
        """
        filter_size, feed_sizes = arch_spec
        rng = jax.random.PRNGKey(seed)

        # Settings that are not part of the search are inherited from self.
        # NOTE(review): num_classes, awb_filter_increment and awb_hidden_layers
        # are not forwarded, so the constructor's own defaults apply for them
        # -- confirm this is intended.
        inherited = {
            field: getattr(self, field)
            for field in ("input_size", "channel_in", "channel_out")
        }
        return CNN3D(
            key=rng,
            filter_size=filter_size,
            feed_sizes=feed_sizes,
            **inherited,
        )
CNN3D.reinitialize_weights method · python · L700-L721 (22 LOC)
src/cl/models/cnn.py
    def reinitialize_weights(self, seed=0):
        """Return a freshly initialized CNN3D with this instance's architecture.

        Nothing is modified in place: a new CNN3D is constructed from the
        current ``filter_size``, ``feed_sizes``, ``input_size``,
        ``channel_in`` and ``channel_out``, using a PRNG key derived from
        ``seed``.

        Args:
            seed: Random seed for reproducible initialization

        Returns:
            New CNN3D with fresh weights (same architecture)
        """
        rng = jax.random.PRNGKey(seed)

        # Every setting passed to the constructor besides the key is carried
        # over from self.
        # NOTE(review): num_classes, awb_filter_increment and awb_hidden_layers
        # are not forwarded, so the constructor's own defaults apply for them
        # -- confirm this is intended.
        same_arch = {
            field: getattr(self, field)
            for field in ("filter_size", "feed_sizes", "input_size",
                          "channel_in", "channel_out")
        }
        return CNN3D(key=rng, **same_arch)
CNNAWBOps.search_architecture method · python · L739-L752 (14 LOC)
src/cl/models/cnn.py
    def search_architecture(self, model, task_id, baseline_loss, dataloader_curr,
                           dataloader_exp, test_loader_curr, test_loader_exp, config, trainer=None):
        """Run the shared architecture search for a CNN / CNN3D model.

        The model's current ``(filter_size, feed_sizes)`` is used as the
        baseline architecture, and the model type string (``'cnn3d'`` or
        ``'cnn'``) is chosen from ``self.is_cnn3d``. All other arguments are
        passed through unchanged to ``core.arch_search.search_architecture``.

        Returns:
            Whatever ``core.arch_search.search_architecture`` returns.
        """
        # Imported at call time, under an alias so it does not read like a
        # recursive call to this method.
        from ..core.arch_search import search_architecture as run_search

        # Baseline is the model's present architecture; feed_sizes is copied
        # into a new list.
        baseline_arch = (model.filter_size, list(model.feed_sizes))

        if self.is_cnn3d:
            model_type = 'cnn3d'
        else:
            model_type = 'cnn'

        return run_search(
            model=model,
            baseline_arch=baseline_arch,
            task_id=task_id,
            baseline_loss=baseline_loss,
            dataloader_curr=dataloader_curr,
            dataloader_exp=dataloader_exp,
            test_loader_curr=test_loader_curr,
            test_loader_exp=test_loader_exp,
            config=config,
            trainer=trainer,
            model_type=model_type,
        )
‹ prev · page 4 / 6 · next ›