Function bodies 268 total
load_results function · python · L16-L30 (15 LOC)examples/utils.py
def load_results(pkl_path: str) -> Dict:
    """Read pickled experiment results back from disk.

    Args:
        pkl_path: Path to the pickle file containing experiment results.

    Returns:
        The unpickled object. Experiment runs are expected to store a
        dictionary with the following keys (not validated here):
        - 'metadata': Experiment configuration (n_tasks, epochs_per_task, etc.)
        - 'tasks': Per-task training metrics (H, V, grad_norm, metrics)
        - 'task_performance_matrix': Performance matrix for CL metrics
    """
    # NOTE: unpickling can execute arbitrary code, so only open result
    # files that come from a trusted source.
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


# compute_cl_metrics function · python · L33-L87 (55 LOC)examples/utils.py
def compute_cl_metrics(data: Dict, is_regression: bool = False) -> Dict:
"""Compute continual learning metrics from experiment data.
Args:
data: Loaded experiment data dictionary
is_regression: If True, compute MSE-based metrics; otherwise accuracy-based
Returns:
Dictionary containing:
- 'avg_metric': Average final performance across all tasks
- 'bwt': Backward Transfer (how much performance changed on old tasks)
- 'forgetting': Maximum forgetting across tasks
"""
matrix = data.get('task_performance_matrix', {})
if not matrix:
return {'avg_metric': 0.0, 'bwt': 0.0, 'forgetting': 0.0}
# Convert to numpy array
n_tasks = len(matrix)
perf_matrix = np.zeros((n_tasks, n_tasks))
for j in range(n_tasks):
for i in range(n_tasks):
i_str = str(i)
if i_str in matrix.get(j, {}):
perf_matrix[j, i] = matrix[j][i_str]
# Average final performance (lasextract_training_curves function · python · L90-L118 (29 LOC)examples/utils.py
def extract_training_curves(data: Dict, metric_key: str = 'test_experience') -> Tuple[np.ndarray, np.ndarray]:
    """Stitch the per-task curves of one metric into a single series.

    Tasks are visited in sorted key order. The iteration numbers of each
    task are shifted so that they continue right after the last iteration
    collected so far, giving one x-axis across all tasks.

    Args:
        data: Loaded experiment data dictionary. Its 'tasks' entry maps task
            ids to records that hold a 'main_training' dictionary.
        metric_key: Series to read from each 'main_training' dictionary
            ('H', 'V', 'grad_norm', 'test_current', 'test_experience').

    Returns:
        Tuple of (iterations, values) arrays; both are empty when no task
        provides data for the requested metric.
    """
    task_table = data.get('tasks', {})
    shifted_iters = []
    collected = []
    shift = 0
    for key in sorted(task_table.keys()):
        record = task_table[key].get('main_training', {})
        steps = record.get('iterations', [])
        series = record.get(metric_key, [])
        # A task contributes only when both of its series are non-empty.
        if len(steps) == 0 or len(series) == 0:
            continue
        shifted_iters.extend(step + shift for step in steps)
        collected.extend(series)
        # The next task starts one step after the last collected iteration.
        shift = shifted_iters[-1] + 1
    return np.array(shifted_iters), np.array(collected)


# plot_comparison function · python · L121-L232 (112 LOC)examples/utils.py
def plot_comparison(baseline_data: Dict, awb_data: Dict,
metric_type: str = 'accuracy',
title_prefix: str = '',
figsize: Tuple[int, int] = (14, 10)) -> plt.Figure:
"""Create a 4-panel comparison plot between Baseline and AWB methods.
Args:
baseline_data: Loaded data for baseline method
awb_data: Loaded data for AWB method
metric_type: 'accuracy' or 'mse' (affects scaling and labels)
title_prefix: Prefix for the figure title (e.g., 'Sine', 'MNIST')
figsize: Figure size as (width, height)
Returns:
matplotlib Figure object
"""
fig, axes = plt.subplots(2, 2, figsize=figsize)
# Colors
baseline_color = '#1f77b4' # Blue
awb_color = '#ff7f0e' # Orange
# Panel (a): Test metric over training
ax = axes[0, 0]
metric_key = 'test_experience'
iters_b, vals_b = extract_training_curves(baseline_data, metric_key)
iters_a, vals_a = extract_traadd_task_boundaries function · python · L235-L251 (17 LOC)examples/utils.py
def add_task_boundaries(ax: plt.Axes, n_tasks: int, epochs_per_task: int,
                        iterations_per_epoch: int = 1, linestyle: str = '--',
                        color: str = 'gray', alpha: float = 0.5):
    """Draw a vertical line wherever one task ends and the next begins.

    A boundary is drawn at ``k * epochs_per_task * iterations_per_epoch``
    for every k from 1 to ``n_tasks - 1``; nothing is drawn for a single task.

    Args:
        ax: Matplotlib axes object to draw on
        n_tasks: Number of tasks
        epochs_per_task: Epochs per task
        iterations_per_epoch: Number of iterations per epoch
        linestyle: Line style for boundaries
        color: Line color
        alpha: Line transparency
    """
    line_style = {'linestyle': linestyle, 'color': color, 'alpha': alpha}
    boundaries = [
        k * epochs_per_task * iterations_per_epoch for k in range(1, n_tasks)
    ]
    for position in boundaries:
        ax.axvline(x=position, **line_style)


# print_experiment_summary function · python · L254-L283 (30 LOC)examples/utils.py
def print_experiment_summary(data: Dict, name: str = 'Experiment'):
"""Print a summary of the experiment results.
Args:
data: Loaded experiment data
name: Name to display in the summary
"""
metadata = data.get('metadata', {})
print(f"\n{'='*50}")
print(f" {name} Summary")
print(f"{'='*50}")
print(f"Tasks: {metadata.get('n_tasks', 'N/A')}")
print(f"Epochs per task: {metadata.get('epochs_per_task', 'N/A')}")
print(f"Problem type: {metadata.get('problem', 'N/A')}")
print(f"Network: {metadata.get('network', 'N/A')}")
print(f"AWB enabled: {metadata.get('awb_enabled', False)}")
# Compute and print CL metrics
is_regression = metadata.get('problem', '') == 'regression'
metrics = compute_cl_metrics(data, is_regression)
print(f"\nContinual Learning Metrics:")
if is_regression:
print(f" Average MSE: {metrics['avg_metric']:.6f}")
else:
print(f" Average Accuracy: {metrics['avg_metric']:.4f}main function · python · L38-L113 (76 LOC)run.py
def main():
parser = argparse.ArgumentParser(
description='Continual Learning Framework',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python run.py kkt_run/config/sine.json
python run.py kkt_run/config/sine.json --runs 0
python run.py kkt_run/config/sine.json --runs 2 --no-plots
python run.py kkt_run/config/sine.json --runs 1 --figures-dir outputs/figures
"""
)
parser.add_argument('config', type=str, help='Path to JSON configuration file')
parser.add_argument('--runs', type=int, default=0, help='Run ID for this experiment (default: 0)')
parser.add_argument('--no-plots', action='store_true', help='Skip plot generation')
parser.add_argument('--figures-dir', type=str, default='figures', help='Output directory for figures')
parser.add_argument('--output-dir', type=str, default=None, help='Custom output directory for results')
parser.add_argument('--model-suffix', type=str, defaultRepobility · open methodology · https://repobility.com/research/
arch_search_CNN_fresh function · python · L41-L139 (99 LOC)src/cl/arch_search/cnn_search.py
def arch_search_CNN_fresh(filter_size, feed_sizes, task, trainW_loss, og_epochs, config,
dataloader_curr, dataloader_exp, test_loader_curr, test_loader_exp,
trainer=None):
"""Architecture search for CNN and CNN3D (BACKWARD COMPATIBLE).
This is a backward-compatible wrapper that uses the new generic search algorithm.
Automatically detects CNN vs CNN3D based on channel_in config.
Args:
filter_size: Current filter size
feed_sizes: Current feed layer sizes [feed_input, h1, h2, n_class]
task: Current task ID
trainW_loss: Training loss from preliminary training
og_epochs: Base epochs (overridden by config)
config: Configuration dictionary
dataloader_curr: Current task training data
dataloader_exp: Experience replay data
test_loader_curr: Current task test data
test_loader_exp: Experience replay test data
trainer: Trainer instance (optioprepABs function · python · L148-L204 (57 LOC)src/cl/arch_search/cnn_search.py
def prepABs(model, prev_feed_sizes, prev_filter_size, new_feed_sizes, new_filter_size):
"""
Prepare A and B transformation matrices for CNN architecture transition.
The AWB algorithm transforms OLD weights W_old using: A @ W_old @ B.T
- A transforms output dimensions: shape (new_out, old_out)
- B transforms input dimensions: shape (new_in, old_in)
W_old stays in the model - we do NOT recreate the model.
Args:
model: CNN model with OLD weights (W_old)
prev_feed_sizes: Previous/old feed layer sizes
prev_filter_size: Previous/old filter size
new_feed_sizes: New/target feed layer sizes
new_filter_size: New/target filter size
Returns:
Tuple of (A_feed, B_feed, A_conv, B_conv) transformation matrices
"""
initializer = jax.nn.initializers.glorot_uniform()
# Check what changed
hidden_changed = (list(prev_feed_sizes[1:3]) != list(new_feed_sizes[1:3]))
filter_changed = (new_filter_size != preprepABs_CNN3D function · python · L207-L283 (77 LOC)src/cl/arch_search/cnn_search.py
def prepABs_CNN3D(model, prev_feed_sizes, prev_filter_size, new_feed_sizes, new_filter_size):
"""
Prepare A and B transformation matrices for CNN3D architecture transition.
The AWB algorithm transforms OLD weights W_old using: A @ W_old @ B.T
- A transforms output dimensions: shape (new_out, old_out)
- B transforms input dimensions: shape (new_in, old_in)
W_old stays in the model - we do NOT recreate the model.
CNN3D has two conv layers, so we need A/B for both conv1 and conv2.
Args:
model: CNN3D model with OLD weights (W_old)
prev_feed_sizes: Previous/old feed layer sizes
prev_filter_size: Previous/old filter size
new_feed_sizes: New/target feed layer sizes
new_filter_size: New/target filter size
Returns:
Tuple of (A_feed, B_feed, A_conv1, B_conv1, A_conv2, B_conv2) transformation matrices
"""
initializer = jax.nn.initializers.glorot_uniform()
# Check what changed
hidden_changed = arch_search_GCN function · python · L27-L90 (64 LOC)src/cl/arch_search/gcn_search.py
def arch_search_GCN(original_gcn: List[int], original_mlp: List[int],
task: int, trainW_loss: float, og_epochs: int,
config: Dict[str, Any], train_loader, mem_train_loader, test_loader,
trainer=None, model=None) -> Tuple[List[int], List[int]]:
"""Architecture search for GCN (BACKWARD COMPATIBLE).
This is a backward-compatible wrapper that uses the new generic search algorithm.
Args:
original_gcn: Original GCN architecture sizes [in_size, hidden, ...]
original_mlp: Original MLP architecture sizes [gcn_out, hidden1, hidden2, n_class]
task: Current task ID
trainW_loss: Loss from preliminary training
og_epochs: Number of epochs per search iteration
config: Configuration dictionary
train_loader: Current task training DataLoader
mem_train_loader: Memory/experience replay DataLoader
test_loader: Test DataLoader
trainer: Trainer instance (optprepABs_GCN function · python · L93-L205 (113 LOC)src/cl/arch_search/gcn_search.py
def prepABs_GCN(model, prev_feed_sizes: List[int], prev_gcn_sizes: List[int]):
"""
Prepare A and B transformation matrices for GCN architecture transition.
Based on logic from run_AWB_ALL_functions.py train_model_graph function.
Args:
model: GCN model with new architecture
prev_feed_sizes: Previous feed layer sizes
prev_gcn_sizes: Previous GCN layer sizes
Returns:
Tuple of (A_feed, B_feed, A_gcn, B_gcn) transformation matrices
"""
opt_feed_sizes = model.feed_sizes
opt_gcn_sizes = model.gcn_sizes
initializer = jax.nn.initializers.glorot_uniform()
# Extract ACTUAL layer dimensions from model weights
# The model.feed_sizes may have been updated, but the actual layer weights still have old dimensions
actual_feed_sizes = [model.feed_layers[0].weight.shape[1]] # First layer input size
for layer in model.feed_layers:
actual_feed_sizes.append(layer.weight.shape[0]) # Output size
actual_gcn_si_create_search_model function · python · L35-L47 (13 LOC)src/cl/arch_search/mlp_search.py
def _create_search_model(arch: List[int], seed: int = 0, awb_enabled: bool = False) -> MLP:
    """Build a new MLP for one architecture-search candidate.

    Args:
        arch: Architecture sizes list [input, h1, h2, ..., output]
        seed: Integer turned into the JAX PRNG key handed to the MLP
        awb_enabled: Forwarded to the MLP constructor unchanged

    Returns:
        The constructed MLP instance
    """
    return MLP(sizes=arch, key=jax.random.PRNGKey(seed), awb_enabled=awb_enabled)


# _reinitialize_weights function · python · L50-L76 (27 LOC)src/cl/arch_search/mlp_search.py
def _reinitialize_weights(model: MLP, seed: int = 0) -> MLP:
    """Replace every layer's weight and bias with Glorot-uniform draws.

    Used before each architecture search trial to ensure fair comparison
    between different architectures.

    Args:
        model: MLP model to reinitialize
        seed: Base random seed; layer ``j`` uses ``seed + j`` for its weight
            and ``seed + j + 100`` for its bias

    Returns:
        Model with reinitialized weights (the updated copy returned by
        ``eqx.tree_at``)
    """
    glorot = jax.nn.initializers.glorot_uniform()
    layer_dims = zip(model.sizes[:-1], model.sizes[1:])
    for idx, (fan_in, fan_out) in enumerate(layer_dims):
        # Weights are stored as (out, in); biases as a (1, out) row.
        new_weight = glorot(jax.random.PRNGKey(seed + idx), (fan_out, fan_in))
        new_bias = glorot(jax.random.PRNGKey(seed + idx + 100), (1, fan_out))
        model = eqx.tree_at(lambda m: m.layers[idx].weight, model, new_weight)
        model = eqx.tree_at(lambda m: m.layers[idx].bias, model, new_bias)
    return model


# _compute_search_loss function · python · L79-L109 (31 LOC)src/cl/arch_search/mlp_search.py
def _compute_search_loss(record_dict: Dict, task_id: int, epochs: int,
window: int = None) -> float:
"""Compute average loss over last `window` epochs for architecture comparison.
Args:
record_dict: Dictionary containing training records
task_id: Current task ID
epochs: Total epochs trained
window: Number of epochs to average (default: DEFAULT_ARCH_SEARCH_AVERAGING_WINDOW)
Returns:
Average loss value
"""
if window is None:
window = DEFAULT_ARCH_SEARCH_AVERAGING_WINDOW
losses = []
iterations = record_dict.get('iterations', record_dict)
for j in range(1, window + 1):
iteration = (task_id + 1) * epochs - j
if iteration in iterations:
record = iterations[iteration]
if isinstance(record, dict) and 'losses' in record:
losses.append(record['losses'].get('V', 0))
elif isinstance(record, tuple):
losses.apProvenance: Repobility (https://repobility.com) — every score reproducible from /scan/
_train_candidate_architecture function · python · L112-L190 (79 LOC)src/cl/arch_search/mlp_search.py
def _train_candidate_architecture(
model: MLP,
trainer: Trainer,
task_id: int,
epochs: int,
config: Dict[str, Any],
dataloader_curr,
dataloader_exp,
test_loader_curr,
test_loader_exp,
averaging_window: int = None,
) -> Tuple[MLP, float]:
"""Train a candidate architecture and return its final loss.
Args:
model: MLP model with candidate architecture
trainer: Trainer instance
task_id: Current task ID
epochs: Number of training epochs
config: Training configuration
dataloader_curr: Current task training data
dataloader_exp: Experience replay data
test_loader_curr: Current task test data
test_loader_exp: Experience replay test data
averaging_window: Number of epochs to average for loss computation
Returns:
Tuple of (trained_model, average_loss)
"""
# Partition model for standard training (A/B frozen if present)
params, static = eqx.partitiarch_search_MLP function · python · L193-L269 (77 LOC)src/cl/arch_search/mlp_search.py
def arch_search_MLP(
original_arch: List[int],
task_id: int,
trainW_loss: float,
og_epochs: int,
config: Dict[str, Any],
dataloader_curr,
dataloader_exp,
test_loader_curr,
test_loader_exp,
current_model: Optional[MLP] = None,
) -> List[int]:
"""Architecture search for MLP regression models.
# Added by Claude: Now delegates to core.arch_search.search_architecture()
This is a backward-compatible wrapper that uses the new generic search algorithm.
Searches for an optimal architecture by training candidate architectures
with different hidden layer sizes and selecting the one with lowest loss.
The search process:
1. Uses the current model's loss (trainW_loss) as baseline for comparison
2. Iteratively explores architectures with incrementally larger hidden layers
3. Uses a grid search over hidden layer sizes within a search range
4. Returns the architecture with lowest loss (or original if no improvement)
Params.__init__ method · python · L24-L32 (9 LOC)src/cl/config/params.py
def __init__(self, json_path: str):
    """Load parameters from a JSON file and expose them as attributes.

    Every top-level key of the JSON object becomes an instance attribute
    of the same name.

    Args:
        json_path: Path to JSON configuration file
    """
    # JSON is UTF-8 by specification; decode it explicitly instead of
    # relying on the platform's locale encoding.
    with open(json_path, encoding="utf-8") as f:
        params = json.load(f)
    self.__dict__.update(params)


# Params.save method · python · L34-L41 (8 LOC)src/cl/config/params.py
def save(self, json_path: str):
    """Save parameters to a JSON file.

    All instance attributes are written as one JSON object indented by
    four spaces.

    Args:
        json_path: Path to save JSON file

    Raises:
        TypeError: If an attribute value is not JSON-serialisable. The
            target file is left untouched in that case.
    """
    # Serialise before opening the target: opening with 'w' truncates the
    # file, so a failure during serialisation would otherwise destroy a
    # previously saved configuration.
    payload = json.dumps(self.__dict__, indent=4)
    with open(json_path, 'w', encoding="utf-8") as f:
        f.write(payload)


# Params.update method · python · L43-L51 (9 LOC)src/cl/config/params.py
def update(self, json_path: str):
    """Update parameters from another JSON file.

    Keys present in the file overwrite attributes of the same name; keys
    not yet present are added; attributes the file does not mention are
    left unchanged.

    Args:
        json_path: Path to JSON file with updates
    """
    # JSON is UTF-8 by specification; decode it explicitly instead of
    # relying on the platform's locale encoding.
    with open(json_path, encoding="utf-8") as f:
        params = json.load(f)
    self.__dict__.update(params)


# apply_defaults function · python · L63-L263 (201 LOC)src/cl/config/params.py
def apply_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
"""Apply default values from constants.py to config dictionary.
Only applies defaults for parameters that are not already specified in config.
This allows config files to only specify non-default values.
Added by Claude: Automatic defaults for layer-level AWB refactor.
Args:
config: Configuration dictionary from JSON file
Returns:
Configuration dictionary with defaults applied
"""
# Create a copy to avoid modifying original
config = config.copy()
# Helper function to set default if not present
def set_default(key: str, value: Any):
if key not in config:
config[key] = value
# ===== DATASET-DRIVEN AUTO-SELECTION =====
# Get dataset (only required parameter from user)
data = config.get('data', constants.DEFAULT_DATA)
# Auto-select prob, problem, network, loss, metric based on dataset
if data in constants.DATASET_CONFIG_MAPload_config function · python · L266-L285 (20 LOC)src/cl/config/params.py
def load_config(json_path: str, apply_defaults_flag: bool = True) -> Dict[str, Any]:
    """Load configuration from JSON file with automatic defaults.

    Convenience function that returns a dictionary directly. The parsed
    configuration is passed through ``apply_defaults`` unless disabled.

    Args:
        json_path: Path to JSON configuration file
        apply_defaults_flag: If True, apply defaults via ``apply_defaults``
            (default: True); if False, return the parsed JSON as-is

    Returns:
        Dictionary of configuration parameters, with defaults applied when
        requested
    """
    # JSON is UTF-8 by specification; decode it explicitly instead of
    # relying on the platform's locale encoding.
    with open(json_path, encoding="utf-8") as f:
        config = json.load(f)
    if apply_defaults_flag:
        config = apply_defaults(config)
    return config


# load_search_config function · python · L54-L99 (46 LOC)src/cl/core/arch_search.py
def load_search_config(config: Dict[str, Any], model_type: Optional[str] = None) -> Dict[str, Any]:
"""Load architecture search hyperparameters with defaults from constants.py.
# Added by Claude: Consolidates config loading from all search functions
Replaces the 20+ lines of config.get() calls that were duplicated in
mlp_search.py (lines 242-248), cnn_search.py (lines 66-79), and
gcn_search.py (lines 171-178).
IMPORTANT: All defaults come from constants.py, NOT hardcoded.
Uses config.get(key, DEFAULT_FROM_CONSTANTS) pattern.
Args:
config: Full configuration dictionary
model_type: Optional model type ('mlp', 'cnn', 'cnn3d', 'gcn')
Used to select model-specific epoch defaults
Returns:
Dict with standardized search configuration keys:
- search_epochs: Training epochs per candidate architecture
- search_lr: Learning rate for candidate training
- search_batch_size: Batch siRepobility · severity-and-effort ranking · https://repobility.com
compute_search_loss function · python · L102-L148 (47 LOC)src/cl/core/arch_search.py
def compute_search_loss(
record_dict: Dict,
task_id: int,
epochs: int,
window: Optional[int] = None
) -> float:
"""Compute average loss from recent training iterations.
# Added by Claude: Unified loss computation for all model types
Replaces duplicated code in:
- mlp_search.py _compute_search_loss() (lines 79-109)
- gcn_search.py _compute_search_loss() (lines 27-64)
- cnn_search.py inline loss extraction (lines 142-147, 206-211)
Handles both dict and tuple record formats for compatibility.
Args:
record_dict: Training records dictionary with 'iterations' key
task_id: Current task ID (0-indexed)
epochs: Total epochs trained for this candidate
window: Number of recent epochs to average (default: from constants)
Returns:
Average loss value over the window, or float('inf') if no data
"""
if window is None:
window = DEFAULT_ARCH_SEARCH_AVERAGING_WINDOW
losses = []
iteratiopartition_for_search function · python · L151-L174 (24 LOC)src/cl/core/arch_search.py
def partition_for_search(model: eqx.Module) -> Tuple[eqx.Module, eqx.Module]:
"""Partition model for architecture search training (freeze A/B if AWB enabled).
# Added by Claude: Generic delegate pattern like apply_V_transformation() in awb.py
Delegates to model.partition_for_standard_training() if available (AWB models).
Falls back to standard partitioning for models without AWB.
This consolidates the partitioning logic that was duplicated in:
- mlp_search.py _train_candidate_architecture() (lines 142-147)
- cnn_search.py (lines 108-117, 181-190)
- gcn_search.py _partition_for_standard_training_gcn() (lines 124-138)
Args:
model: Model instance (MLP, CNN, CNN3D, or GCN)
Returns:
Tuple of (params, static) ready for training
"""
if hasattr(model, 'partition_for_standard_training'):
# AWB-enabled model - use model's partition method
return model.partition_for_standard_training()
else:
# Non-AWB reinitialize_weights function · python · L177-L201 (25 LOC)src/cl/core/arch_search.py
def reinitialize_weights(model: eqx.Module, seed: int = 0) -> eqx.Module:
    """Delegate weight reinitialization to the model itself.

    # Added by Claude: Generic delegate pattern

    Forwards to ``model.reinitialize_weights(seed)`` and returns whatever
    that call returns. Intended to give every candidate architecture a
    fresh starting point during architecture search.

    Args:
        model: Model instance to reinitialize
        seed: Random seed passed through to the model's method

    Returns:
        The result of ``model.reinitialize_weights(seed)``

    Raises:
        NotImplementedError: If model doesn't implement reinitialize_weights()
    """
    if not hasattr(model, 'reinitialize_weights'):
        model_name = type(model).__name__
        raise NotImplementedError(
            f"Model {model_name} doesn't implement reinitialize_weights() method. "
            "Models must implement the search interface to use generic architecture search."
        )
    return model.reinitialize_weights(seed)


# build_train_config function · python · L204-L245 (42 LOC)src/cl/core/arch_search.py
def build_train_config(config: Dict[str, Any], search_cfg: Dict[str, Any]) -> Dict[str, Any]:
"""Build training configuration dict for trainer.train__CL() calls during search.
# Added by Claude: Standardizes train_config construction
Replaces inline config building in:
- mlp_search.py _train_candidate_architecture() (lines 160-167)
- cnn_search.py (lines 123-130, 193-196)
- gcn_search.py (lines 223-226)
Args:
config: Full configuration dictionary
search_cfg: Search-specific config from load_search_config()
Returns:
Training config dict compatible with trainer.train__CL()
"""
# Get problem type to determine defaults
problem = config.get('problem', 'vectors')
prob = config.get('prob', 'regression')
# Select appropriate defaults based on problem type
if problem == 'graph':
default_batch = config.get('batch_size', 32) # Graph default
default_replay = config.get('len_exp_replay', DEFAULT_REPcheck_early_stopping function · python · L248-L278 (31 LOC)src/cl/core/arch_search.py
def check_early_stopping(
current_loss: float,
best_loss: float,
patience_counter: int,
patience: int = 3,
min_improvement: float = 1e-4
) -> Tuple[bool, int]:
"""Check if architecture search should terminate early.
# Added by Claude: Speed optimization - NEW functionality
Enables early termination when search converges or stagnates.
Args:
current_loss: Loss of current candidate
best_loss: Best loss found so far
patience_counter: Current patience count (iterations without improvement)
patience: Max iterations without improvement before stopping
min_improvement: Minimum delta to count as improvement
Returns:
Tuple of (should_stop, updated_patience_counter)
"""
# Check if current candidate improved over best
if current_loss < best_loss - min_improvement:
# Improvement found - reset patience
return False, 0
else:
# No improvement - increment patience
uadapt_search_range function · python · L281-L304 (24 LOC)src/cl/core/arch_search.py
def adapt_search_range(
    iteration: int,
    improvement_rate: float,
    base_range: int
) -> int:
    """Shrink the search range as the relative improvement gets smaller.

    # Added by Claude: Speed optimization - NEW functionality

    Args:
        iteration: Current search iteration (not used by the computation)
        improvement_rate: Relative improvement supplied by the caller;
            compared against the 1% and 5% thresholds
        base_range: Initial search range

    Returns:
        - ``max(2, base_range // 2)`` when improvement_rate < 0.01
        - ``max(3, base_range * 2 // 3)`` when improvement_rate < 0.05
        - ``base_range`` unchanged otherwise (no lower bound is applied in
          this case, and a NaN rate also lands here)
    """
    # The comparisons are kept as "<" so that a NaN rate fails both tests.
    if improvement_rate < 0.01:  # Converged (< 1% improvement)
        floor, narrowed = 2, base_range // 2
    elif improvement_rate < 0.05:  # Converging (< 5% improvement)
        floor, narrowed = 3, base_range * 2 // 3
    else:  # Still exploring
        return base_range
    return max(floor, narrowed)


# should_evaluate_candidate function · python · L307-L340 (34 LOC)src/cl/core/arch_search.py
def should_evaluate_candidate(
candidate_size: int,
best_size: int,
best_loss: float,
baseline_loss: float,
expansion_threshold: float = 1.5
) -> bool:
"""Decide if a candidate architecture is worth evaluating.
# Added by Claude: Speed optimization - NEW functionality
Prunes candidates unlikely to improve via heuristics.
Skips very large expansions unless loss is still far from baseline.
Args:
candidate_size: Total parameters in candidate
best_size: Total parameters in current best
best_loss: Current best loss
baseline_loss: Baseline loss from preliminary training
expansion_threshold: Max size ratio to allow without loss check
Returns:
True if candidate should be evaluated, False to skip
"""
if best_size == 0:
return True # First candidate
size_ratio = candidate_size / best_size
loss_ratio = best_loss / baseline_loss if baseline_loss > 0 else 1.0
# Allow large e_train_and_evaluate_candidate function · python · L383-L472 (90 LOC)src/cl/core/arch_search.py
def _train_and_evaluate_candidate(
candidate_arch,
model,
task_id: int,
trial_number: int,
trainer,
train_data,
train_config: Dict[str, Any],
config: Dict[str, Any],
search_cfg: Dict[str, Any],
problem_type: str,
loss_type: str,
) -> float:
"""Train a candidate architecture and return its loss.
# Added by Claude: Shared evaluation logic for both grid and Bayesian search
Extracted to avoid code duplication between search methods.
Args:
candidate_arch: Architecture specification to evaluate
model: Reference model (for create_with_architecture interface)
task_id: Current task ID
trial_number: Trial/candidate number (for seeding)
trainer: Trainer instance
train_data: Training data tuple
train_config: Training configuration
config: Full configuration dict
search_cfg: Search-specific configuration
problem_type: 'vectors' or 'graph'
loss_type: 'Methodology: Repobility · https://repobility.com/research/state-of-ai-code-2026/
search_architecture_bayesian function · python · L475-L639 (165 LOC)src/cl/core/arch_search.py
def search_architecture_bayesian(
model: eqx.Module,
baseline_arch,
task_id: int,
baseline_loss: float,
dataloader_curr,
dataloader_exp,
test_loader_curr,
test_loader_exp,
config: Dict[str, Any],
trainer=None,
model_type: Optional[str] = None
):
"""Architecture search using Bayesian Optimization (Optuna).
# Added by Claude: Bayesian alternative to grid search
Uses Optuna's TPE (Tree-structured Parzen Estimator) sampler to
intelligently explore the architecture space with fewer evaluations.
Typically evaluates 4-5 candidates instead of 8, while finding
similar or better architectures. All training infrastructure
(trainer.train__CL, partitioning, etc.) is reused unchanged.
Args:
model: Current model instance (used to get awb_enabled state and interface)
baseline_arch: Baseline architecture (current best)
task_id: Current task ID
baseline_loss: Loss from preliminary training (bassearch_architecture_grid function · python · L642-L842 (201 LOC)src/cl/core/arch_search.py
def search_architecture_grid(
model: eqx.Module,
baseline_arch,
task_id: int,
baseline_loss: float,
dataloader_curr,
dataloader_exp,
test_loader_curr,
test_loader_exp,
config: Dict[str, Any],
trainer=None,
model_type: Optional[str] = None
):
"""Grid-based architecture search (original implementation).
# Added by Claude: Renamed from search_architecture for clarity
This is the original grid search implementation that evaluates
all candidates generated by model.generate_search_candidates().
Args:
Same as search_architecture()
Returns:
Optimal architecture found during search
"""
print(f" Starting grid architecture search for task {task_id}")
print(f" Baseline architecture: {baseline_arch}")
print(f" Baseline loss: {baseline_loss:.6f}")
# Load search configuration
search_cfg = load_search_config(config, model_type)
# Create trainer if not provided
if trainer is None:search_architecture function · python · L845-L904 (60 LOC)src/cl/core/arch_search.py
def search_architecture(
model: eqx.Module,
baseline_arch,
task_id: int,
baseline_loss: float,
dataloader_curr,
dataloader_exp,
test_loader_curr,
test_loader_exp,
config: Dict[str, Any],
trainer=None,
model_type: Optional[str] = None
):
"""Generic architecture search function for any model.
# Added by Claude: Dispatcher for grid vs Bayesian search
Selects search method based on config['arch_search_method']:
- 'grid': Traditional grid search (default)
- 'bayesian': Bayesian Optimization using Optuna
Works for ANY model implementing the search interface:
- model.generate_search_candidates(iteration, current_best, config)
- model.create_with_architecture(arch_spec, seed, awb_enabled)
- model.reinitialize_weights(seed)
Args:
model: Current model instance (used to get awb_enabled state)
baseline_arch: Baseline architecture (current best)
task_id: Current task IDestimate_model_memory_mb function · python · L43-L62 (20 LOC)src/cl/core/async_checkpoint.py
def estimate_model_memory_mb(model) -> float:
    """Sum the byte size of every JAX array in a model, in MB.

    The model is flattened with ``jax.tree_util.tree_leaves``; only leaves
    that are instances of ``jnp.ndarray`` are counted, every other leaf is
    ignored.

    Args:
        model: Any JAX pytree, e.g. an Equinox model (MLP, CNN, GCN)

    Returns:
        Estimated memory in MB (1 MB = 1024 * 1024 bytes)
    """
    bytes_per_mb = 1024 * 1024
    array_bytes = sum(
        leaf.size * leaf.dtype.itemsize  # element count times bytes per element
        for leaf in jax.tree_util.tree_leaves(model)
        if isinstance(leaf, jnp.ndarray)
    )
    return array_bytes / bytes_per_mb


# estimate_records_memory_mb function · python · L65-L83 (19 LOC)src/cl/core/async_checkpoint.py
def estimate_records_memory_mb(record_dict: Dict[str, Any]) -> float:
    """Give a rough size estimate for the training records, in MB.

    The records are not traversed. The estimate is derived only from the
    number of entries under 'tasks' and 'iterations', charging roughly
    10 KB per task and 1 KB per iteration.

    Args:
        record_dict: Training records dictionary

    Returns:
        Estimated memory in MB (approximate)
    """
    kib = 1024
    task_count = len(record_dict.get('tasks', {}))
    iteration_count = len(record_dict.get('iterations', {}))
    approx_bytes = task_count * 10 * kib + iteration_count * kib
    return approx_bytes / (kib * kib)


# get_available_memory_gb function · python · L86-L98 (13 LOC)src/cl/core/async_checkpoint.py
def get_available_memory_gb() -> float:
    """Report the available system memory in GB using psutil.

    Returns:
        ``psutil.virtual_memory().available`` converted to GB
        (1 GB = 1024**3 bytes), or 0.0 when psutil cannot be imported
    """
    bytes_per_gb = 1024**3
    try:
        import psutil
        return psutil.virtual_memory().available / bytes_per_gb
    except ImportError:
        # psutil not available, return conservative estimate
        return 0.0


# AsyncCheckpointManager.__init__ method · python · L111-L142 (32 LOC)src/cl/core/async_checkpoint.py
def __init__(self,
save_dir: str = "outputs/checkpoints",
max_checkpoints: int = 3,
memory_limit_gb: float = 8.0,
enable_async: bool = True):
"""Initialize checkpoint manager.
Args:
save_dir: Directory to save checkpoints
max_checkpoints: Maximum number of checkpoints to keep (older ones deleted)
memory_limit_gb: Maximum memory to use for checkpointing (GB)
enable_async: If False, use synchronous saving (for debugging)
"""
self.save_dir = Path(save_dir)
self.max_checkpoints = max_checkpoints
self.memory_limit_gb = memory_limit_gb
self.enable_async = enable_async
# Create save directory
self.save_dir.mkdir(parents=True, exist_ok=True)
# Background saving thread
self._save_queue = queue.Queue()
self._shutdown = False
self._save_thread = None
# Track saved AsyncCheckpointManager._background_saver method · python · L144-L171 (28 LOC)src/cl/core/async_checkpoint.py
def _background_saver(self):
"""Background thread that saves checkpoints from queue."""
while not self._shutdown:
try:
# Wait for save request (timeout to check shutdown flag)
save_task = self._save_queue.get(timeout=1.0)
if save_task is None: # Shutdown signal
break
# Unpack save task
model_path, model, record_dict_path, record_dict = save_task
# Save model
if model is not None:
eqx.tree_serialise_leaves(model_path, model)
# Save records
if record_dict is not None:
with open(record_dict_path, 'wb') as f:
pickle.dump(record_dict, f)
self._save_queue.task_done()
except queue.Empty:
continue
except Exception as e:
print(f"[AsyncCheckpoint] Error saving cheRepobility · open methodology · https://repobility.com/research/
AsyncCheckpointManager._cleanup_old_checkpoints method · python · L173-L186 (14 LOC)src/cl/core/async_checkpoint.py
def _cleanup_old_checkpoints(self):
"""Remove old checkpoints if exceeding max_checkpoints."""
if len(self._checkpoint_history) > self.max_checkpoints:
# Remove oldest checkpoint
old_checkpoint = self._checkpoint_history.pop(0)
model_path, record_path = old_checkpoint
try:
if os.path.exists(model_path):
os.remove(model_path)
if os.path.exists(record_path):
os.remove(record_path)
except Exception as e:
print(f"[AsyncCheckpoint] Failed to delete old checkpoint: {e}", file=sys.stderr)AsyncCheckpointManager.save_checkpoint method · python · L188-L245 (58 LOC)src/cl/core/async_checkpoint.py
def save_checkpoint(self,
model,
record_dict: Dict[str, Any],
task_id: int,
epoch: int,
prefix: str = "checkpoint") -> bool:
"""Save a checkpoint asynchronously (or synchronously if disabled).
Args:
model: Equinox model to save
record_dict: Training records dictionary
task_id: Current task ID
epoch: Current epoch
prefix: Filename prefix
Returns:
True if checkpoint was queued/saved, False if skipped due to memory
"""
# Estimate memory requirements
model_mem_mb = estimate_model_memory_mb(model)
records_mem_mb = estimate_records_memory_mb(record_dict)
total_mem_mb = model_mem_mb + records_mem_mb
total_mem_gb = total_mem_mb / 1024
# Check memory limit
if total_mem_gb > self.memory_limit_gb:
print(f"[AAsyncCheckpointManager.wait_all method · python · L247-L254 (8 LOC)src/cl/core/async_checkpoint.py
def wait_all(self, timeout: Optional[float] = None):
    """Wait for all pending checkpoint saves to complete.

    Does nothing when asynchronous saving is disabled or no queue exists.

    Args:
        timeout: Maximum time to wait in seconds (None = wait forever).
            When the timeout expires the method simply returns; saves that
            are still pending stay queued.
    """
    import time

    if not self.enable_async or self._save_queue is None:
        return

    pending = self._save_queue
    if timeout is None:
        pending.join()
        return

    # Queue.join() has no timeout argument, so wait on the same condition
    # variable that join() itself uses, bounded by a monotonic deadline.
    deadline = time.monotonic() + timeout
    with pending.all_tasks_done:
        while pending.unfinished_tasks:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            pending.all_tasks_done.wait(remaining)


# AsyncCheckpointManager.shutdown method · python · L256-L261 (6 LOC)src/cl/core/async_checkpoint.py
def shutdown(self):
    """Shutdown the checkpoint manager gracefully.

    Does nothing when asynchronous saving is disabled or no saver thread
    exists. Otherwise a ``None`` sentinel is queued behind any pending
    saves and the saver thread is given up to 10 seconds to reach it, so
    checkpoints that were already queued are written instead of dropped.
    """
    if not self.enable_async or self._save_thread is None:
        return

    # Queue the sentinel BEFORE raising the stop flag: the saver loop runs
    # only while the flag is clear, so raising it first would make the
    # worker exit after its current task and abandon everything still
    # waiting in the queue.
    self._save_queue.put(None)  # Shutdown signal
    self._save_thread.join(timeout=10.0)

    # Fallback: if the worker is still busy after the grace period, the flag
    # makes it stop as soon as it finishes the save in progress.
    self._shutdown = True


# AWBOperations.search_architecture method · python · L41-L77 (37 LOC)src/cl/core/awb_operations.py
def search_architecture(
self,
model: eqx.Module,
task_id: int,
baseline_loss: float,
dataloader_curr,
dataloader_exp,
test_loader_curr,
test_loader_exp,
config: Dict[str, Any],
trainer=None
) -> Any:
"""Search for optimal architecture for current task.
This is STEP 3a of the AWB pipeline. The search creates fresh candidate
models with random initialization, trains each for a fixed number of epochs,
and returns the architecture specification that achieves lowest loss.
Args:
model: Current model (used to extract baseline architecture)
task_id: Current task ID
baseline_loss: Loss from preliminary training (baseline for comparison)
dataloader_curr: Current task training data
dataloader_exp: Experience replay data
test_loader_curr: Current task test data
test_loader_exp: ExperienAWBOperations.set_AB_matrices method · python · L80-L105 (26 LOC)src/cl/core/awb_operations.py
def set_AB_matrices(self, model: eqx.Module, original_arch: Any, new_arch: Any) -> eqx.Module:
    """Prepare the A and B transformation matrices for an architecture change.

    Setup stage for STEP 3b of the AWB pipeline. Two matrices are created
    that carry the old architecture over to the new one:

    - A maps the old output dimensions to the new output dimensions.
    - B maps the old input dimensions to the new input dimensions.

    The existing weights W stay in the model. While AB training runs, the
    forward pass uses ``A @ W @ B^T`` (computed by ``get_AWBT``); A and B
    are trainable and W is frozen.

    Note:
        No implementation is visible in this definition; the text above
        is the contract an implementation is expected to satisfy.

    Args:
        model: Model carrying the old architecture and its weights W.
        original_arch: Specification of the architecture being replaced.
        new_arch: Specification of the target architecture, as returned
            by ``search_architecture``.

    Returns:
        The model with A/B matrices initialized and W left unchanged.
    """
AWBOperations.partition_for_AB_training method · python · L108-L129 (22 LOC)src/cl/core/awb_operations.py
def partition_for_AB_training(self, model: eqx.Module) -> Tuple[eqx.Module, eqx.Module]:
    """Split the model into trainable and frozen parts for AB training.

    Used by STEP 3b of the AWB pipeline. In that phase
    (``notABTrain=False``) gradient descent updates only the A and B
    matrices; the old weights W, and everything else, stay frozen.

    Note:
        This definition is a placeholder whose body is ``pass``; the text
        documents the contract an implementation is expected to satisfy.

    Args:
        model: Model that holds the A, B and W matrices.

    Returns:
        A ``(trainable_params, static_params)`` pair:

        - ``trainable_params``: only the A and B matrices.
        - ``static_params``: W together with every other model parameter.
    """
    pass


# AWBOperations.compute_V method · python · L132-L154 (23 LOC)src/cl/core/awb_operations.py
def compute_V(self, model: eqx.Module) -> eqx.Module:
    """Fold the trained A and B matrices into the weights: V = A @ W @ B^T.

    STEP 4 of the AWB pipeline. Once AB training has finished, the
    effective weights are obtained by applying the trained transformation
    matrices to the old weights::

        V      = A @ W @ B^T
        V_bias = A @ bias

    and the model is updated to use V as its new weights.

    Note:
        This definition is a placeholder whose body is ``pass``; the text
        documents the contract an implementation is expected to satisfy.

    Args:
        model: Model holding the trained A and B matrices and the old
            weights W.

    Returns:
        The model with its weights replaced by V. The A/B matrices are
        still present, but are frozen in subsequent training.
    """
    pass


# Provenance: Repobility (https://repobility.com) — every score reproducible from /scan/
AWBOperations.partition_for_standard_training method · python · L157-L179 (23 LOC)src/cl/core/awb_operations.py
def partition_for_standard_training(self, model: eqx.Module) -> Tuple[eqx.Module, eqx.Module]:
    """Split the model into trainable and frozen parts for standard training.

    Applies to STEP 5 of the AWB pipeline and to all training after it.
    In this mode the transformed weights V are updated while the A and B
    matrices stay frozen, which preserves the architecture transformation.

    Note:
        This definition is a placeholder whose body is ``pass``; the text
        documents the contract an implementation is expected to satisfy.

    Args:
        model: Model that holds the V, A and B matrices.

    Returns:
        A ``(trainable_params, static_params)`` pair:

        - ``trainable_params``: V and the other trainable parameters,
          excluding A/B.
        - ``static_params``: the frozen A and B matrices.
    """
    pass


# AWBOperations.get_model_architecture method · python · L182-L197 (16 LOC)src/cl/core/awb_operations.py
def get_model_architecture(self, model: eqx.Module) -> Any:
    """Report the architecture specification of ``model``.

    The result uses the same format that ``search_architecture()`` and
    ``set_AB_matrices()`` expect.

    Note:
        This definition is a placeholder whose body is ``pass``; the text
        documents the contract an implementation is expected to satisfy.

    Args:
        model: Model instance to inspect.

    Returns:
        The current architecture specification, in a model-specific format.
    """
    pass


# AWBOperations.save_weights method · python · L200-L215 (16 LOC)src/cl/core/awb_operations.py
def save_weights(self, model: eqx.Module) -> Any:
    """Capture the model's current weights ahead of architecture search.

    Architecture search builds fresh, randomly initialized models, so the
    current weights have to be saved first in order to restore them once
    the search has finished.

    Note:
        This definition is a placeholder whose body is ``pass``; the text
        documents the contract an implementation is expected to satisfy.

    Args:
        model: Model instance whose weights are to be saved.

    Returns:
        The saved weights, in a model-specific format.
    """
    pass


# page 1 / 6next ›