Function bodies 268 total
benchmark_dataloader function · python · L192-L241 (50 LOC)src/cl/datasets/jax_dataloader.py
def benchmark_dataloader(loader, num_batches: int = 100, warmup: int = 10):
"""Benchmark data loading throughput.
Measures batches/second and identifies bottlenecks.
Args:
loader: DataLoader to benchmark (PyTorch or JAX-wrapped)
num_batches: Number of batches to measure
warmup: Number of warmup batches (ignored in timing)
Returns:
dict with 'batches_per_sec', 'samples_per_sec', 'avg_batch_time_ms'
"""
import time
# Warmup
iter_loader = iter(loader)
for _ in range(warmup):
try:
_ = next(iter_loader)
except StopIteration:
iter_loader = iter(loader)
_ = next(iter_loader)
# Benchmark
start_time = time.perf_counter()
total_samples = 0
for i in range(num_batches):
try:
batch = next(iter_loader)
# Get batch size
if isinstance(batch, tuple):
x = batch[0]
else:
x = batget_optimal_prefetch_size function · python · L244-L261 (18 LOC)src/cl/datasets/jax_dataloader.py
def get_optimal_prefetch_size(batch_time_ms: float, data_load_time_ms: float) -> int:
"""Calculate optimal prefetch queue size.
Args:
batch_time_ms: Time to process one batch on GPU (ms)
data_load_time_ms: Time to load one batch from disk (ms)
Returns:
Recommended prefetch_size
Formula: prefetch_size = ceil(data_load_time / batch_time) + 1
This ensures the GPU never waits for data.
"""
import math
if batch_time_ms <= 0:
return 2 # Default
ratio = data_load_time_ms / batch_time_ms
return max(2, math.ceil(ratio) + 1)MNISTDataset.__init__ method · python · L49-L78 (30 LOC)src/cl/datasets/mnist.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the MNIST dataset.
Args:
config: Configuration dictionary
"""
super().__init__(config)
self._n_tasks = config.get('n_task', 5)
self.rotation_range = config.get('rotation_range', DEFAULT_ROTATION_RANGE)
self.scaling_range = config.get('scaling_range', DEFAULT_SCALING_RANGE)
self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
# Load MNIST dataset
print("Loading MNIST dataset")
my_transforms = transforms.Compose([transforms.ToTensor()])
self.dataset = torchvision.datasets.MNIST(
'./data', train=True, download=True, transform=my_transforms
)
# Extract images and labels
[self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
self.images = torch.stack(self.images, dim=0)
self.labels = np.array(self.labels)
# Apply debug limit MNISTDataset._load_task_data method · python · L80-L118 (39 LOC)src/cl/datasets/mnist.py
def _load_task_data(self, task_id: int) -> None:
"""Load data for a specific MNIST task with transforms.
Applies rotation and scaling transforms based on task_id to create
distribution shift between tasks.
Args:
task_id: Task identifier (0-indexed)
"""
# Set deterministic seed for reproducible task generation
# Critical for accurate CL metrics - ensures task data is identical during training and evaluation
np.random.seed(task_id * DEFAULT_PERMUTATION_SEED_MULTIPLIER)
X = self.images.clone()
y = self.labels.copy()
# Apply task-specific transformations
rot_angle = np.random.random() * self.rotation_range
scaling_min, scaling_max = self.scaling_range
scaling = np.random.random() * (scaling_max - scaling_min) + scaling_min
X = torchvision.transforms.functional.affine(
X, rot_angle,
translate=(scaling, scaling),
scalPermutedMNISTDataset.__init__ method · python · L157-L186 (30 LOC)src/cl/datasets/mnist.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the Permuted MNIST dataset.
Args:
config: Configuration dictionary
"""
super().__init__(config)
self._n_tasks = config.get('n_task', 10)
self.train_split = config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
self.seed_multiplier = config.get('permutation_seed_multiplier', DEFAULT_PERMUTATION_SEED_MULTIPLIER)
self.image_size = config.get('image_size', DEFAULT_INPUT_SIZE_MNIST)
# Load MNIST dataset
print("Loading Permuted MNIST dataset")
my_transforms = transforms.Compose([transforms.ToTensor()])
self.dataset = torchvision.datasets.MNIST(
'./data', train=True, download=True, transform=my_transforms
)
# Extract images and labels
[self.images, self.labels] = [list(t) for t in zip(*self.dataset)]
self.images = torch.stack(self.images, dim=0)
self.labels = np.arrayPermutedMNISTDataset._load_task_data method · python · L188-L218 (31 LOC)src/cl/datasets/mnist.py
def _load_task_data(self, task_id: int) -> None:
"""Load data for a specific permutation task.
Applies a task-specific permutation to all pixels.
Args:
task_id: Task identifier (0-indexed)
"""
X = self.images.clone()
y = self.labels.copy()
# Generate task-specific permutation
rng = np.random.RandomState(seed=task_id * self.seed_multiplier)
perm = rng.permutation(self.image_size * self.image_size)
# Flatten, permute, reshape
# X shape: (N, 1, 28, 28) -> flatten to (N, 784) -> permute -> reshape back
X = X.view(X.shape[0], -1)[:, perm].view(X.shape[0], 1, self.image_size, self.image_size)
# Use deterministic train/test split for reproducibility
# Use the same RNG state for consistent splits across training and evaluation
n_samples = X.shape[0]
n_train = int(self.train_split * n_samples)
train_idx = rng.randint(0, n_samples, n_train)generate_sine_data function · python · L30-L75 (46 LOC)src/cl/datasets/sine.py
def generate_sine_data(delta: float, n_tasks: int = 40, output_path: str = 'data/Incremental_Sine1e^4.p',
seed: int = 1) -> str:
"""Generate sine data for continual learning tasks.
Creates a pickle file containing sine wave data for multiple tasks.
Each task has gradually increasing frequency and amplitude.
Args:
delta: Perturbation value for gradual task drift
n_tasks: Number of tasks to generate (default: 40)
output_path: Path to save the pickle file
seed: Random seed for reproducibility
Returns:
Path to the generated pickle file
Data format per task:
(y, time, phase, amplitude, frequency) where:
- y: Sine wave values, shape (n_samples, n_time_points)
- time: Time points array
- phase: Phase values, shape (n_samples, 1)
- amplitude: Amplitude values, shape (n_samples, 1)
- frequency: Frequency values, shape (n_samples, 1)
"""
# Added by ClRepobility · code-quality intelligence · https://repobility.com
SineDataset.__init__ method · python · L103-L149 (47 LOC)src/cl/datasets/sine.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the sine dataset.
Args:
config: Configuration dictionary
"""
super().__init__(config)
self.delta = config.get('delta', 0.00001)
self.data_path = config.get('data_path', 'data/Incremental_Sine1e^4.p')
self._n_tasks = config.get('n_task', 40)
self.test_size = config.get('test_size', 0.2)
# Noise parameters for Experiment 3 (increasing noise per task)
self.noise_enabled = config.get('noise_enabled', False)
self.noise_scale = config.get('noise_scale', 0.1)
self.noise_increment = config.get('noise_increment', 0.05)
# Generate data if file doesn't exist
if not os.path.exists(self.data_path):
# Added by Claude: ensure parent directory exists before generating data
data_dir = os.path.dirname(self.data_path)
if data_dir and not os.path.exists(data_dir):
os.SineDataset._load_task_data method · python · L151-L186 (36 LOC)src/cl/datasets/sine.py
def _load_task_data(self, task_id: int) -> None:
"""Load data for a specific sine task.
Extracts sine data for the given task and creates train/test splits.
Features: [phase, amplitude, frequency]
Target: sine wave values (flattened)
Args:
task_id: Task identifier (0-indexed)
"""
if task_id >= self._available_tasks:
raise ValueError(f"Task {task_id} not available. Max task: {self._available_tasks - 1}")
# Extract task data
y, time, phase, amplitude, frequency = self.raw_data['task' + str(task_id)]
# Create feature matrix: [phase, amplitude, frequency]
X = np.concatenate([phase, amplitude.reshape([-1, 1]), frequency.reshape([-1, 1])], axis=1)
# Flatten y if needed (shape: n_samples x n_time_points -> n_samples)
# For regression, we typically predict all time points
# But original code treats y as target directly
y = y.astype(np.float32)BaseGraphDataset.__init__ method · python · L54-L85 (32 LOC)src/cl/datasets/synthetic_graph.py
def __init__(self, config: Dict[str, Any]):
"""Initialize the graph dataset.
Args:
config: Configuration dictionary containing:
- batch_size: Batch size for DataLoaders
- n_class: Number of classes for task sampling
- class_per_task: Number of classes per task
- debug_mode: (optional) Enable debug mode
- debug_limit: (optional) Number of samples in debug mode
"""
self.config = config
self.batch_size = config.get('batch_size', DEFAULT_BATCH_SIZE_GRAPH)
self.debug_mode = config.get('debug_mode', False)
self.debug_limit = config.get('debug_limit', 100)
# Graph datasets (to be set by subclass)
self.dataset = None
self.train_data = None
self.test_data = None
# Experience replay buffer (list of graph data objects)
self.memory_train: List = []
# Added by Claude: Separate test buffer foBaseGraphDataset._load_dataset method · python · L88-L96 (9 LOC)src/cl/datasets/synthetic_graph.py
def _load_dataset(self) -> None:
"""Load the graph dataset.
Subclasses must implement this to populate:
- self.dataset: Full dataset
- self.train_data: Training split
- self.test_data: Test split
"""
passBaseGraphDataset.generate_dataset method · python · L120-L184 (65 LOC)src/cl/datasets/synthetic_graph.py
def generate_dataset(self, task_id: int, batch_size: int = None,
phase: str = 'training') -> Tuple[DataLoader, DataLoader]:
"""Generate train/memory dataloaders for a task.
Implements continuum_Graph_classification logic with deterministic task selection.
Uses persistent task-to-class mapping to ensure reproducibility for CL metrics.
Args:
task_id: Current task ID (0-indexed)
batch_size: Batch size for DataLoader
phase: 'training' or 'testing'
Returns:
Tuple of (current_loader, memory_loader)
"""
if batch_size is None:
batch_size = self.batch_size
n_class = self.config.get('n_class', self.num_classes)
select = self.config.get('class_per_task', 2)
# Fixed by Claude: Restore proper task-based class selection
# Use persistent mapping to ensure reproducibility across CL evaluation
if task_id not in selfBaseGraphDataset.get_test_loader method · python · L186-L198 (13 LOC)src/cl/datasets/synthetic_graph.py
def get_test_loader(self, batch_size: int = None) -> DataLoader:
"""Get test data loader.
Args:
batch_size: Batch size for DataLoader
Returns:
DataLoader for test data
"""
if batch_size is None:
batch_size = self.batch_size
return DataLoader(self.test_data, batch_size=batch_size, shuffle=True)BaseGraphDataset.generate_test_loader method · python · L200-L218 (19 LOC)src/cl/datasets/synthetic_graph.py
def generate_test_loader(self, task_id: int, batch_size: int = None) -> DataLoader:
"""Generate test loader for a specific task (for CL metrics evaluation).
Added by Claude: This method enables per-task evaluation needed for computing
the performance matrix A[j][i] = accuracy on task i after training task j.
Args:
task_id: Task ID to generate test loader for
batch_size: Batch size for DataLoader
Returns:
DataLoader for task-specific test data
"""
if batch_size is None:
batch_size = self.batch_size
# Use generate_dataset in test phase to get task-specific data
test_loader, _ = self.generate_dataset(task_id, batch_size, phase='testing')
return test_loaderBaseGraphDataset.append_to_experience method · python · L220-L226 (7 LOC)src/cl/datasets/synthetic_graph.py
def append_to_experience(self, task_id: int) -> None:
"""No-op for graph datasets.
Experience replay is handled directly in generate_dataset by appending
to memory_train. This method exists for interface compatibility.
"""
passGenerated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
BaseGraphDataset.get_model_config method · python · L228-L239 (12 LOC)src/cl/datasets/synthetic_graph.py
def get_model_config(self) -> Dict[str, Any]:
"""Return configuration for model initialization.
Returns:
Dictionary with input_size, output_size, and graph-specific config
"""
return {
'input_size': self.num_features,
'output_size': self.num_classes,
'num_features': self.num_features,
'num_classes': self.num_classes,
}SyntheticGraphDataset._load_dataset method · python · L260-L296 (37 LOC)src/cl/datasets/synthetic_graph.py
def _load_dataset(self) -> None:
"""Load synthetic graph dataset."""
# Get dataset parameters from config
num_graphs = self.config.get('num_graphs', DEFAULT_SYNTHETIC_NUM_GRAPHS)
num_channels = self.config.get('num_channels', DEFAULT_SYNTHETIC_NUM_CHANNELS)
avg_num_nodes = self.config.get('avg_num_nodes', DEFAULT_SYNTHETIC_AVG_NUM_NODES)
num_classes = self.config.get('num_classes', DEFAULT_SYNTHETIC_NUM_CLASSES)
# Set seed for reproducibility
torch_geometric.seed.seed_everything(DEFAULT_GRAPH_SEED)
# Create synthetic dataset and shuffle (matching old code behavior)
self.dataset = FakeDataset(
num_graphs=num_graphs,
num_channels=num_channels,
avg_num_nodes=avg_num_nodes,
num_classes=num_classes,
transform=_transform_graph
).shuffle()
# Apply debug limit if enabled
if self.debug_mode:
print(f"DEBUG MODE: LimTaskShiftGraphDataset.__init__ method · python · L350-L369 (20 LOC)src/cl/datasets/synthetic_graph.py
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
# Task-shift parameters
self.task_shift_enabled = config.get('task_shift_enabled', DEFAULT_SYNTHETIC_TASK_SHIFT_ENABLED)
self.feature_noise_base = config.get('feature_noise_base', DEFAULT_SYNTHETIC_FEATURE_NOISE_BASE)
self.edge_dropout_base = config.get('edge_dropout_base', DEFAULT_SYNTHETIC_EDGE_DROPOUT_BASE)
self.feature_shift_base = config.get('feature_shift_base', DEFAULT_SYNTHETIC_FEATURE_SHIFT_BASE)
self.n_tasks = config.get('n_task', 5)
# Added by Claude: Perturbation control modes
self.perturbation_mode = config.get('perturbation_mode', 'linear')
self.perturbation_growth_rate = config.get('perturbation_growth_rate', 1.5)
self.perturbation_step_size = config.get('perturbation_step_size', 2)
self.custom_perturbations = config.get('custom_perturbations', {})
# Fixed by Claude: Store base train/test splitsTaskShiftGraphDataset._load_dataset method · python · L371-L426 (56 LOC)src/cl/datasets/synthetic_graph.py
def _load_dataset(self) -> None:
"""Load synthetic graph dataset with FIXED train/test split."""
# Get dataset parameters from config
num_graphs = self.config.get('num_graphs', DEFAULT_SYNTHETIC_NUM_GRAPHS)
num_channels = self.config.get('num_channels', DEFAULT_SYNTHETIC_NUM_CHANNELS)
avg_num_nodes = self.config.get('avg_num_nodes', DEFAULT_SYNTHETIC_AVG_NUM_NODES)
num_classes = self.config.get('num_classes', DEFAULT_SYNTHETIC_NUM_CLASSES_PER_TASK)
# Set seed for reproducibility
torch_geometric.seed.seed_everything(DEFAULT_GRAPH_SEED)
# Create synthetic dataset and shuffle ONCE
self.dataset = FakeDataset(
num_graphs=num_graphs,
num_channels=num_channels,
avg_num_nodes=avg_num_nodes,
num_classes=num_classes,
transform=_transform_graph
).shuffle()
# Apply debug limit if enabled
if self.debug_mode:
print(f"DETaskShiftGraphDataset._get_perturbation_params method · python · L428-L483 (56 LOC)src/cl/datasets/synthetic_graph.py
def _get_perturbation_params(self, task_id: int) -> Tuple[float, float, float]:
"""Get perturbation parameters for a specific task based on mode.
Added by Claude: Flexible perturbation control for different experimental designs.
Args:
task_id: Current task ID
Returns:
Tuple of (feature_noise_std, edge_dropout_prob, feature_shift)
"""
# Task 0 always has no perturbation (baseline)
if task_id == 0:
return 0.0, 0.0, 0.0
if self.perturbation_mode == 'linear':
# Linear scaling: perturbation = task_id * base
noise_std = task_id * self.feature_noise_base
dropout_prob = task_id * self.edge_dropout_base
shift = task_id * self.feature_shift_base
elif self.perturbation_mode == 'exponential':
# Exponential scaling: perturbation = base * (rate ^ task_id)
rate = self.perturbation_growth_rate
noise_std TaskShiftGraphDataset._apply_task_perturbation method · python · L485-L546 (62 LOC)src/cl/datasets/synthetic_graph.py
def _apply_task_perturbation(self, data_list: List, task_id: int, seed: int = None) -> List:
"""Apply task-specific perturbations to graph data.
Uses per-graph deterministic seeding so that the same graph always gets
the same perturbation pattern regardless of whether it's in train or test.
This fixes the train/test generalization gap caused by different random seeds.
Args:
data_list: List of graph data objects
task_id: Current task ID (0 = no perturbation)
seed: Base random seed (combined with graph index for per-graph seeding)
Returns:
List of perturbed graph data objects
"""
import copy
import torch
base_seed = seed if seed is not None else DEFAULT_GRAPH_SEED
# Get perturbation parameters based on mode
feature_noise_std, edge_dropout_prob, feature_shift = self._get_perturbation_params(task_id)
# Task 0 has no perturbation -TaskShiftGraphDataset._get_or_generate_task_data method · python · L548-L577 (30 LOC)src/cl/datasets/synthetic_graph.py
def _get_or_generate_task_data(self, task_id: int) -> Tuple[List, List]:
"""Generate or retrieve cached train/test data for a task.
Fixed by Claude: Now applies perturbations to FIXED train/test splits
instead of shuffling and re-splitting per task. This prevents data leakage.
Args:
task_id: Task ID
Returns:
Tuple of (train_data, test_data) for this task
"""
# Check if already generated
if task_id in self._task_train_data and task_id in self._task_test_data:
return self._task_train_data[task_id], self._task_test_data[task_id]
# Fixed by Claude: Apply perturbations to FIXED train and test sets SEPARATELY
# This ensures the same underlying graphs are always in train or test
train_data = self._apply_task_perturbation(
self._base_train_data, task_id, seed=DEFAULT_GRAPH_SEED
)
test_data = self._apply_task_perturbation(
seTaskShiftGraphDataset.generate_dataset method · python · L579-L625 (47 LOC)src/cl/datasets/synthetic_graph.py
def generate_dataset(self, task_id: int, batch_size: int = None,
phase: str = 'training') -> Tuple[DataLoader, DataLoader]:
"""Generate train/memory dataloaders for a task with domain shift.
Each task uses the same classes but with task-specific perturbations.
Memory buffer accumulates perturbed data from all seen tasks.
Args:
task_id: Current task ID (0-indexed)
batch_size: Batch size for DataLoader
phase: 'training' or 'testing'
Returns:
Tuple of (current_loader, memory_loader)
"""
if batch_size is None:
batch_size = self.batch_size
# Track if this is first time generating this task's data
is_new_task = task_id not in self._task_train_data
# Get or generate train/test data for this task
train_data, test_data = self._get_or_generate_task_data(task_id)
# Select appropriate data based on phase
Repobility · open methodology · https://repobility.com/research/
TaskShiftGraphDataset.get_perturbation_schedule method · python · L627-L643 (17 LOC)src/cl/datasets/synthetic_graph.py
def get_perturbation_schedule(self) -> Dict[int, Dict[str, float]]:
"""Return the perturbation schedule for all tasks.
Added by Claude: Useful for logging and visualization.
Returns:
Dict mapping task_id -> {noise, dropout, shift}
"""
schedule = {}
for task_id in range(self.n_tasks):
noise, dropout, shift = self._get_perturbation_params(task_id)
schedule[task_id] = {
'noise': noise,
'dropout': dropout,
'shift': shift
}
return scheduleTUGraphDataset._load_dataset method · python · L671-L701 (31 LOC)src/cl/datasets/synthetic_graph.py
def _load_dataset(self) -> None:
"""Load TU graph dataset."""
data_name = self.config.get('data', 'MUTAG')
# Set seed for reproducibility
torch_geometric.seed.seed_everything(DEFAULT_GRAPH_SEED)
# Load TU Dataset
self.dataset = TUDataset(
root='data/TUDataset',
name=data_name,
transform=_transform_graph
).shuffle()
# Apply debug limit if enabled
if self.debug_mode:
print(f"DEBUG MODE: Limiting {data_name} data from {len(self.dataset)} to {self.debug_limit} samples")
self.dataset = self.dataset[:self.debug_limit]
# Split into train/test
length = len(self.dataset)
train_split = self.config.get('train_test_split', DEFAULT_TRAIN_TEST_SPLIT)
self.train_data = self.dataset[:int(train_split * length)]
self.test_data = self.dataset[int(train_split * length):]
print(f'Dataset: {data_name}')
print('===load_graph_dataset function · python · L714-L734 (21 LOC)src/cl/datasets/synthetic_graph.py
def load_graph_dataset(config: Dict[str, Any]) -> BaseGraphDataset:
"""Factory function to load appropriate graph dataset.
Args:
config: Configuration dictionary with 'data' key
Returns:
Graph dataset instance (SyntheticGraphDataset, TaskShiftGraphDataset, or TUGraphDataset)
"""
data_name = config.get('data', 'synthetic')
if data_name == 'synthetic':
return SyntheticGraphDataset(config)
# Added by Claude: Task-shift synthetic dataset for domain-shift CL
elif data_name == 'synthetic_taskshift':
return TaskShiftGraphDataset(config)
elif data_name in ['MUTAG', 'ENZYMES', 'PROTEINS']:
return TUGraphDataset(config)
else:
raise ValueError(f"Unknown graph dataset: {data_name}. "
f"Supported: 'synthetic', 'synthetic_taskshift', 'MUTAG', 'ENZYMES', 'PROTEINS'")CNNorig.__init__ method · python · L34-L45 (12 LOC)src/cl/models/cnn.py
def __init__(self, key):
key1, key2, key3, key4 = jax.random.split(key, 4)
# Standard CNN setup: convolutional layer, followed by flattening,
# with a small MLP on top.
self.conv_layers = [
eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
]
self.feed_layers =[
eqx.nn.Linear(1728, 512, key=key2),
eqx.nn.Linear(512, 64, key=key3),
eqx.nn.Linear(64, 10, key=key4),
]CNNorig.__call__ method · python · L47-L52 (6 LOC)src/cl/models/cnn.py
def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
x = jnp.ravel(jax.nn.relu(eqx.nn.MaxPool2d(kernel_size=2, stride=2)(self.conv_layers[0](x))))
x = jax.nn.relu(self.feed_layers[0](x))
x = jax.nn.relu(self.feed_layers[1](x))
x = self.feed_layers[2](x)
return xCNN.__init__ method · python · L74-L143 (70 LOC)src/cl/models/cnn.py
def __init__(self, key, filter_size, feed_sizes,
input_size=None,
channel_in=1,
channel_out=None,
awb_arch=None,
awb_filter_size=None,
padding=None,
stride=None):
"""
Args:
key: PRNG key
filter_size: Convolutional filter size
feed_sizes: List of feed-forward layer sizes
input_size: Input image size (default: 28 for MNIST/Omniglot)
channel_in: Number of input channels (default: 1)
channel_out: Number of output channels (default: 3)
awb_arch: AWB architecture (default: [1875, 700, 100, 10])
awb_filter_size: AWB filter size (default: 5)
padding: Convolution padding (default: 0)
stride: Convolution stride (default: 1)
"""
key1, key2, key3, key4 = jax.random.split(key, 4)
# Set defaults from constants
seCNN.calc_output_size method · python · L145-L150 (6 LOC)src/cl/models/cnn.py
def calc_output_size(self, fil_size, input_size=None):
"""Calculate output size after convolution"""
if input_size is None:
input_size = self.input_size
output = ((input_size - fil_size + 2 * self.padding) / self.stride) + 1
return int(output)CNN.pool_output_size method · python · L152-L157 (6 LOC)src/cl/models/cnn.py
def pool_output_size(self, pool_size, conv_inputsize, pool_stride=None):
"""Calculate output size after pooling"""
if pool_stride is None:
pool_stride = DEFAULT_POOL_STRIDE
output = ((conv_inputsize - pool_size) / pool_stride) + 1
return int(output)Repobility (the analyzer behind this table) · https://repobility.com
CNN.__call__ method · python · L159-L168 (10 LOC)src/cl/models/cnn.py
def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
x = jnp.ravel(jax.nn.relu(eqx.nn.MaxPool2d(kernel_size=2, stride=2)(self.conv_layers[0](x))))
for lin in self.feed_layers[:-1]:
#print("x in model:", x.shape)
x = jax.nn.relu(lin(x))
x = self.feed_layers[-1](x)
#for lin in self.feed_layers:
#x = jax.nn.relu(lin(x))
#x = self.feed_layers[0](x)
return xCNN.get_AWBT method · python · L175-L205 (31 LOC)src/cl/models/cnn.py
def get_AWBT(self, x):
"""Forward pass using AWB transformation.
Uses AWBMixin.awb_transform_conv for efficient batched computation.
Single einsum call replaces nested Python loops, enabling JIT optimization.
"""
# AWB transformation on conv layer using AWBMixin
# For single-channel input (MNIST), conv weight is (out_ch, 1, H, W)
# We transform each output channel: A[i] @ W[i,0] @ B[i].T
# Using stacked arrays: A_conv (out_ch, new_f, old_f), W[:, 0] (out_ch, H, W)
conv_weights = self.conv_layers[0].weight[:, 0] # (channel_out, H, W)
weights_transformed = self.awb_transform_conv(self.A_conv, conv_weights, self.B_conv)
# Add back the channel_in dimension for convolution
weights_transformed = weights_transformed[:, jnp.newaxis, :, :] # (out_ch, 1, new_H, new_W)
x = jnp.expand_dims(x, axis=0)
x = jax.lax.conv_general_dilated(lhs=x, rhs=weights_transformed, window_strides=CNN.get_awb_layer_specs method · python · L208-L214 (7 LOC)src/cl/models/cnn.py
def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
"""Get AWB specs for feed layers only (conv handled separately)."""
return [
AWBLayerSpec(layer=self.feed_layers[i], A=self.A_feed[i], B=self.B_feed[i],
layer_type='linear2', layer_index=i)
for i in range(len(self.feed_layers))
]CNN.partition_for_AB_training method · python · L216-L221 (6 LOC)src/cl/models/cnn.py
def partition_for_AB_training(self):
"""Partition for A/B training (freeze W, train A/B)."""
filter_spec = jtu.tree_map(lambda _: False, self)
filter_spec = eqx.tree_at(lambda x: (x.A_conv, x.B_conv, x.A_feed, x.B_feed),
filter_spec, replace=(True, True, True, True))
return eqx.partition(self, filter_spec)CNN.partition_for_standard_training method · python · L223-L252 (30 LOC)src/cl/models/cnn.py
def partition_for_standard_training(self):
"""Partition for standard training (freeze A/B, train W).
A_conv, B_conv are stacked arrays, A_feed, B_feed are lists of arrays.
Uses eqx.is_array to properly separate arrays from non-arrays (ints, etc.),
then moves A/B matrices to static for freezing.
"""
# Fixed by Claude: Use eqx.is_array to properly handle non-array leaves (ints)
# Previous approach with tree_map(lambda _: True, self) put ints in params
# which caused jax.grad to fail with "int64 not supported" error
params, static = eqx.partition(self, eqx.is_array)
# Move A/B matrices from params to static (freeze them)
# Use is_leaf=lambda x: x is None to handle None values in static
static = eqx.tree_at(
lambda x: (x.A_conv, x.B_conv, x.A_feed, x.B_feed),
static,
replace=(self.A_conv, self.B_conv, self.A_feed, self.B_feed),
is_leaf=lambdaCNN.generate_search_candidates method · python · L255-L318 (64 LOC)src/cl/models/cnn.py
def generate_search_candidates(self, iteration, current_best, config):
"""Generate candidate architectures for CNN search.
# Added by Claude: Model-specific search strategy for CNN
CNN searches over filter_size and feed hidden dimensions.
Returns list of (filter_size, feed_sizes) tuples.
Args:
iteration: Current search iteration (0-indexed)
current_best: Tuple of (filter_size, feed_sizes) for current best
config: Configuration dict with search hyperparameters
Returns:
List of (filter_size, feed_sizes) candidate tuples
"""
from ..config.constants import (
DEFAULT_ARCH_SEARCH_HIDDEN_RANGE,
DEFAULT_ARCH_SEARCH_FILTER_MIN,
DEFAULT_ARCH_SEARCH_FILTER_MAX,
DEFAULT_ARCH_SEARCH_STEP_SIZE_MLP,
)
# Get search hyperparameters
search_hidden_range = config.get('arch_search_hidden_range', DEFAULT_ARCH_SEARCH_HICNN.create_with_architecture method · python · L320-L346 (27 LOC)src/cl/models/cnn.py
def create_with_architecture(self, arch_spec, seed=0, awb_enabled=True):
"""Create CNN instance with specified architecture.
# Added by Claude: Instance method for search (uses self for channel_in, etc.)
Args:
arch_spec: Tuple of (filter_size, feed_sizes)
seed: Random seed for weight initialization
awb_enabled: Not used for CNN (always has AWB matrices)
Returns:
New CNN instance with specified architecture
"""
filter_size, feed_sizes = arch_spec
# Added by Claude: Use self attributes to preserve channel_in, channel_out, input_size
key = jax.random.PRNGKey(seed)
return CNN(
key=key,
filter_size=filter_size,
feed_sizes=feed_sizes,
input_size=self.input_size,
channel_in=self.channel_in,
channel_out=self.channel_out,
padding=self.padding,
stride=self.stride
)CNN.reinitialize_weights method · python · L348-L372 (25 LOC)src/cl/models/cnn.py
def reinitialize_weights(self, seed=0):
"""Reinitialize model weights for fair architecture comparison.
# Added by Claude: Weight reinitialization for search
For CNN, we create a fresh model instead of reinitializing in place
since the architecture is already set.
Args:
seed: Random seed for reproducible initialization
Returns:
New CNN with fresh weights (same architecture)
"""
# For CNN, architecture is fixed at creation, so we return a fresh instance
key = jax.random.PRNGKey(seed)
return CNN(
key=key,
filter_size=self.filter_size,
feed_sizes=self.feed_sizes,
input_size=self.input_size,
channel_in=self.channel_in,
channel_out=self.channel_out,
padding=self.padding,
stride=self.stride
)Repobility · code-quality intelligence · https://repobility.com
CNN3D.__init__ method · python · L396-L490 (95 LOC)src/cl/models/cnn.py
def __init__(self, key, filter_size, feed_sizes,
input_size=None,
channel_in=None,
channel_out=None,
num_classes=10,
awb_filter_increment=None,
awb_hidden_layers=None):
"""
Args:
key: PRNG key
filter_size: Convolutional filter size
feed_sizes: List of feed-forward layer sizes
input_size: Input image size (default: 32 for CIFAR)
channel_in: Number of input channels (default: 3)
channel_out: Number of output channels (default: 32)
num_classes: Number of output classes (default: 10)
awb_filter_increment: Increment for AWB filter size (default: 2)
awb_hidden_layers: AWB hidden layer sizes (default: [512, 256])
"""
key1, key2, key3, key4, key5 = jax.random.split(key, 5)
# Set defaults from constants
if input_size is None:
CNN3D.calc_output_size method · python · L492-L498 (7 LOC)src/cl/models/cnn.py
def calc_output_size(self, input_size, fil_size, pool_size=2):
"""Calculate output size after convolution and pooling."""
# After conv: (input_size - fil_size + 1)
# After pool: floor((conv_out) / pool_size)
conv_out = input_size - fil_size + 1
pool_out = conv_out // pool_size
return pool_outCNN3D.__call__ method · python · L500-L513 (14 LOC)src/cl/models/cnn.py
def __call__(self, x: Float[Array, "3 32 32"]) -> Float[Array, "num_classes"]:
# First conv + pool
x = jax.nn.relu(self.conv_layers[0](x))
x = eqx.nn.MaxPool2d(kernel_size=2, stride=2)(x)
# Second conv + pool
x = jax.nn.relu(self.conv_layers[1](x))
x = eqx.nn.MaxPool2d(kernel_size=2, stride=2)(x)
# Flatten
x = jnp.ravel(x)
# Feed forward layers
for lin in self.feed_layers[:-1]:
x = jax.nn.relu(lin(x))
x = self.feed_layers[-1](x)
return xCNN3D.get_AWBT method · python · L515-L555 (41 LOC)src/cl/models/cnn.py
def get_AWBT(self, x):
"""Forward pass using AWB transformation.
Uses AWBMixin.awb_transform_conv for efficient batched computation.
Single einsum call replaces nested Python loops, enabling JIT optimization.
"""
# AWB transformation on first conv layer using AWBMixin
# A_conv1, B_conv1: (channel_out, channel_in, new_f, old_f)
# conv weight: (channel_out, channel_in, old_f, old_f)
weights_transformed1 = self.awb_transform_conv(self.A_conv1, self.conv_layers[0].weight, self.B_conv1)
x = jnp.expand_dims(x, axis=0)
x = jax.lax.conv_general_dilated(lhs=x, rhs=weights_transformed1, window_strides=(1, 1), padding="VALID")
x = x.squeeze(0)
# Add conv1 bias - eqx.nn.Conv2d bias is (out_ch, 1, 1), broadcasts with (out_ch, H, W)
x = x + self.conv_layers[0].bias
x = jax.nn.relu(x)
x = eqx.nn.MaxPool2d(kernel_size=2, stride=2)(x)
# AWB transformation on second convCNN3D.get_awb_layer_specs method · python · L558-L564 (7 LOC)src/cl/models/cnn.py
def get_awb_layer_specs(self) -> List[AWBLayerSpec]:
"""Get AWB specs for feed layers only (conv handled separately)."""
return [
AWBLayerSpec(layer=self.feed_layers[i], A=self.A_feed[i], B=self.B_feed[i],
layer_type='linear2', layer_index=i)
for i in range(len(self.feed_layers))
]CNN3D.partition_for_AB_training method · python · L566-L571 (6 LOC)src/cl/models/cnn.py
def partition_for_AB_training(self):
    """Partition the model so only the A/B matrices are trainable (W frozen).

    Returns:
        The ``(params, static)`` pair from ``eqx.partition``. ``params``
        selects A_conv1, B_conv1, A_conv2, B_conv2, A_feed and B_feed; every
        other leaf is marked False and ends up in ``static``.
    """
    def _ab_fields(model):
        # The six A/B attributes that become trainable.
        return (model.A_conv1, model.B_conv1, model.A_conv2,
                model.B_conv2, model.A_feed, model.B_feed)

    frozen_everywhere = jtu.tree_map(lambda _: False, self)
    ab_only = eqx.tree_at(_ab_fields, frozen_everywhere, replace=(True,) * 6)
    return eqx.partition(self, ab_only)
# CNN3D.partition_for_standard_training method · python · L573-L604 (32 LOC)src/cl/models/cnn.py
def partition_for_standard_training(self):
"""Partition for standard training (freeze A/B, train W).
A_conv1, B_conv1, A_conv2, B_conv2 are now stacked 4D arrays.
A_feed, B_feed are lists of arrays.
Uses eqx.is_array to properly separate arrays from non-arrays (ints, etc.),
then moves A/B matrices to static for freezing.
"""
# Fixed by Claude: Use eqx.is_array to properly handle non-array leaves (ints)
# Previous approach with tree_map(lambda _: True, self) put ints in params
# which caused jax.grad to fail with "int64 not supported" error
params, static = eqx.partition(self, eqx.is_array)
# Move A/B matrices from params to static (freeze them)
# Use is_leaf=lambda x: x is None to handle None values in static
static = eqx.tree_at(
lambda x: (x.A_conv1, x.B_conv1, x.A_conv2, x.B_conv2, x.A_conv, x.B_conv, x.A_feed, x.B_feed),
static,
replace=(CNN3D.generate_search_candidates method · python · L607-L672 (66 LOC)src/cl/models/cnn.py
def generate_search_candidates(self, iteration, current_best, config):
"""Generate candidate architectures for CNN3D search.
# Added by Claude: Model-specific search strategy for CNN3D
CNN3D searches over filter_size and feed hidden dimensions (same as CNN).
Returns list of (filter_size, feed_sizes) tuples.
Args:
iteration: Current search iteration (0-indexed)
current_best: Tuple of (filter_size, feed_sizes) for current best
config: Configuration dict with search hyperparameters
Returns:
List of (filter_size, feed_sizes) candidate tuples
"""
from ..config.constants import (
DEFAULT_ARCH_SEARCH_HIDDEN_RANGE,
DEFAULT_ARCH_SEARCH_FILTER_MIN,
DEFAULT_ARCH_SEARCH_FILTER_MAX,
DEFAULT_ARCH_SEARCH_STEP_SIZE_MLP,
)
# Get search hyperparameters
search_hidden_range = config.get('arch_search_hidden_range', DEGenerated by Repobility's multi-pass static-analysis pipeline (https://repobility.com)
CNN3D.create_with_architecture method · python · L674-L698 (25 LOC)src/cl/models/cnn.py
def create_with_architecture(self, arch_spec, seed=0, awb_enabled=True):
    """Construct a new CNN3D with the requested architecture.

    ``input_size``, ``channel_in`` and ``channel_out`` are copied from this
    instance, so only the filter size and feed sizes vary between candidates.

    Args:
        arch_spec: ``(filter_size, feed_sizes)`` tuple.
        seed: Seed passed to ``jax.random.PRNGKey`` for weight initialization.
        awb_enabled: Accepted for interface compatibility; not used here.

    Returns:
        A freshly initialized CNN3D instance.
    """
    filter_size, feed_sizes = arch_spec
    # Shape-related settings are inherited from the current model.
    inherited = {
        'input_size': self.input_size,
        'channel_in': self.channel_in,
        'channel_out': self.channel_out,
    }
    return CNN3D(
        key=jax.random.PRNGKey(seed),
        filter_size=filter_size,
        feed_sizes=feed_sizes,
        **inherited,
    )
# CNN3D.reinitialize_weights method · python · L700-L721 (22 LOC)src/cl/models/cnn.py
def reinitialize_weights(self, seed=0):
    """Return a CNN3D with this model's architecture and fresh weights.

    CNN3D's architecture is fixed at construction, so this does not reset
    weights in place. It builds a new instance through
    ``create_with_architecture``, which also carries over ``input_size``,
    ``channel_in`` and ``channel_out``. Going through that method keeps a
    single construction path: the two methods previously duplicated the
    full ``CNN3D(...)`` call and could drift apart.

    Args:
        seed: Seed passed to ``jax.random.PRNGKey`` for the new weights.

    Returns:
        New CNN3D with the same ``filter_size`` / ``feed_sizes`` and freshly
        initialized weights.
    """
    return self.create_with_architecture(
        (self.filter_size, self.feed_sizes), seed=seed
    )
# CNNAWBOps.search_architecture method · python · L739-L752 (14 LOC)src/cl/models/cnn.py
def search_architecture(self, model, task_id, baseline_loss, dataloader_curr,
                        dataloader_exp, test_loader_curr, test_loader_exp, config, trainer=None):
    """Run architecture search for a CNN / CNN3D model.

    Thin wrapper around ``core.arch_search.search_architecture``. The baseline
    architecture is ``(model.filter_size, list(model.feed_sizes))``.
    ``model_type`` is ``'cnn3d'`` when ``self.is_cnn3d`` is truthy, otherwise
    ``'cnn'``. All other arguments are forwarded unchanged.
    """
    from ..core.arch_search import search_architecture as run_search

    if self.is_cnn3d:
        variant = 'cnn3d'
    else:
        variant = 'cnn'

    current_arch = (model.filter_size, list(model.feed_sizes))
    return run_search(
        model=model,
        baseline_arch=current_arch,
        task_id=task_id,
        baseline_loss=baseline_loss,
        dataloader_curr=dataloader_curr,
        dataloader_exp=dataloader_exp,
        test_loader_curr=test_loader_curr,
        test_loader_exp=test_loader_exp,
        config=config,
        trainer=trainer,
        model_type=variant,
    )