Function bodies 933 total
analyze_checkpoint_structure function · python · L8-L84 (77 LOC)analyze_checkpoint_structure.py
def analyze_checkpoint_structure(path):
if not os.path.exists(path):
print(f"错误: 文件未找到 - {path}")
return
print(f"=== 分析模型文件: {os.path.basename(path)} ===\n")
try:
# 加载 checkpoint
checkpoint = torch.load(path, map_location='cpu')
# 1. 顶层结构
print(f"文件类型: {type(checkpoint)}")
if not isinstance(checkpoint, dict):
print("警告: 文件不是标准的字典格式,可能是直接保存的模型对象。")
return
print(f"包含的顶层键 (Keys): {list(checkpoint.keys())}\n")
# 2. 详细分析 Config (配置信息)
if 'config' in checkpoint:
print("--- [config] 模型配置信息 ---")
config = checkpoint['config']
for k, v in config.items():
if isinstance(v, dict):
print(f" {k}:")
for sub_k, sub_v in v.items():
print(f" {sub_k}: {sub_v}")
else:
print(f" {k}: {v}")
else:
__init__ method · python · L20-L45 (26 LOC)autoencoder/evaluation/ae_evaluator.py
def __init__(self,
             autoencoder,
             parameter_mapper=None,
             wavelet_transform=None,
             device: Optional[torch.device] = None):
    """Initialize the evaluator.

    Args:
        autoencoder: AutoEncoder model to evaluate; it is moved to ``device``.
        parameter_mapper: Optional parameter -> latent mapper. It is moved to
            ``device`` only when it exposes a ``to`` method.
        wavelet_transform: Optional wavelet transform applied to RCS inputs
            before they are fed to the autoencoder.
        device: Compute device; defaults to CUDA when available, else CPU.
    """
    self.autoencoder = autoencoder
    self.parameter_mapper = parameter_mapper
    self.wavelet_transform = wavelet_transform
    self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move the models to the target device.
    self.autoencoder.to(self.device)
    # Compare against None explicitly: container modules (e.g. an empty
    # nn.Sequential) define __len__ and would otherwise be treated as absent.
    if self.parameter_mapper is not None and hasattr(self.parameter_mapper, 'to'):
        self.parameter_mapper.to(self.device)

    # Metric calculator used by the evaluation methods.
    self.metrics_calculator = ReconstructionMetrics(device=self.device)
# evaluate_autoencoder_reconstruction method · python · L47-L133 (87 LOC)autoencoder/evaluation/ae_evaluator.py
def evaluate_autoencoder_reconstruction(self,
test_rcs: np.ndarray,
batch_size: int = 16) -> Dict[str, Any]:
"""
评估AutoEncoder重建质量
Args:
test_rcs: [N, 91, 91, 2] 测试RCS数据
batch_size: 批次大小
Returns:
evaluation_results: 评估结果
"""
print("评估AutoEncoder重建质量...")
self.autoencoder.eval()
# 数据预处理
if self.wavelet_transform:
rcs_tensor = torch.FloatTensor(test_rcs)
test_wavelet = self.wavelet_transform.forward_transform(rcs_tensor)
else:
test_wavelet = torch.FloatTensor(test_rcs)
# 批量重建
reconstructed_data = []
latent_representations = []
inference_times = []
n_batches = (len(test_wavelet) + batch_size - 1) // batch_size
with torch.no_grad():
for i in range(n_batches):
start_idx = ievaluate_parameter_mapping method · python · L135-L200 (66 LOC)autoencoder/evaluation/ae_evaluator.py
def evaluate_parameter_mapping(self,
test_params: np.ndarray,
test_rcs: np.ndarray) -> Dict[str, Any]:
"""
评估参数映射质量
Args:
test_params: [N, 9] 测试参数
test_rcs: [N, 91, 91, 2] 测试RCS
Returns:
evaluation_results: 评估结果
"""
if self.parameter_mapper is None:
raise ValueError("参数映射器未设置,无法进行评估")
print("评估参数映射质量...")
# 1. 获取目标隐空间表示
self.autoencoder.eval()
if self.wavelet_transform:
rcs_tensor = torch.FloatTensor(test_rcs)
test_wavelet = self.wavelet_transform.forward_transform(rcs_tensor)
else:
test_wavelet = torch.FloatTensor(test_rcs)
with torch.no_grad():
test_wavelet = test_wavelet.to(self.device)
_, target_latent = self.autoencoder(test_wavelet)
target_latent = target_latent.cpu().numpy()
# 2. 参evaluate_end_to_end method · python · L202-L283 (82 LOC)autoencoder/evaluation/ae_evaluator.py
def evaluate_end_to_end(self,
test_params: np.ndarray,
test_rcs: np.ndarray) -> Dict[str, Any]:
"""
端到端评估:参数 → RCS重建
Args:
test_params: [N, 9] 测试参数
test_rcs: [N, 91, 91, 2] 测试RCS
Returns:
evaluation_results: 评估结果
"""
if self.parameter_mapper is None:
raise ValueError("参数映射器未设置,无法进行端到端评估")
print("执行端到端评估...")
# 1. 参数 → 隐空间
if hasattr(self.parameter_mapper, 'predict'):
pred_latent = self.parameter_mapper.predict(test_params)
pred_latent = torch.FloatTensor(pred_latent)
else:
self.parameter_mapper.eval()
with torch.no_grad():
params_tensor = torch.FloatTensor(test_params).to(self.device)
pred_latent = self.parameter_mapper(params_tensor)
# 2. 隐空间 → 小波系数
self.autoencoder.eval()
with torch.no_grad():
_analyze_latent_space method · python · L285-L308 (24 LOC)autoencoder/evaluation/ae_evaluator.py
def _analyze_latent_space(self, latent_vectors: np.ndarray) -> Dict[str, Any]:
"""分析隐空间特性"""
analysis = {}
# 基础统计
analysis['shape'] = latent_vectors.shape
analysis['mean'] = np.mean(latent_vectors, axis=0)
analysis['std'] = np.std(latent_vectors, axis=0)
analysis['min'] = np.min(latent_vectors, axis=0)
analysis['max'] = np.max(latent_vectors, axis=0)
# 维度利用率
dimension_usage = np.std(latent_vectors, axis=0)
active_dims = np.sum(dimension_usage > 0.01) # 标准差大于0.01的维度
analysis['active_dimensions'] = active_dims
analysis['dimension_usage_ratio'] = active_dims / latent_vectors.shape[1]
# 相关性分析
correlation_matrix = np.corrcoef(latent_vectors.T)
mean_correlation = np.mean(np.abs(correlation_matrix[np.triu_indices_from(correlation_matrix, k=1)]))
analysis['mean_correlation'] = mean_correlation
return analysis_analyze_latent_consistency method · python · L310-L334 (25 LOC)autoencoder/evaluation/ae_evaluator.py
def _analyze_latent_consistency(self,
target_latent: np.ndarray,
pred_latent: np.ndarray) -> Dict[str, float]:
"""分析隐空间一致性"""
# 逐维度相关性
dim_correlations = []
for i in range(target_latent.shape[1]):
corr = np.corrcoef(target_latent[:, i], pred_latent[:, i])[0, 1]
if not np.isnan(corr):
dim_correlations.append(corr)
# 整体相关性
overall_correlation = np.corrcoef(target_latent.flatten(), pred_latent.flatten())[0, 1]
# 维度重要性分析
dim_importance = np.std(target_latent, axis=0)
weighted_correlation = np.average(dim_correlations, weights=dim_importance[:len(dim_correlations)])
return {
'overall_correlation': overall_correlation if not np.isnan(overall_correlation) else 0.0,
'mean_dim_correlation': np.mean(dim_correlations) if dim_correlations else 0.0,
'weighted_correlationSame scanner, your repo: https://repobility.com — Repobility
generate_comprehensive_report method · python · L336-L427 (92 LOC)autoencoder/evaluation/ae_evaluator.py
def generate_comprehensive_report(self,
ae_results: Dict[str, Any],
mapping_results: Optional[Dict[str, Any]] = None,
e2e_results: Optional[Dict[str, Any]] = None) -> str:
"""生成综合评估报告"""
report = "\\n" + "="*60 + "\\n"
report += " AutoEncoder系统综合评估报告\\n"
report += "="*60 + "\\n"
# AutoEncoder重建评估
report += "\\n🔧 AutoEncoder重建评估:\\n"
report += "-" * 40 + "\\n"
ae_metrics = ae_results['reconstruction_metrics']
performance = ae_results['performance']
report += f"重建质量指标:\\n"
report += f" MSE: {ae_metrics['mse']:.6f}\\n"
report += f" SSIM: {ae_metrics['ssim_mean']:.4f} ± {ae_metrics['ssim_std']:.4f}\\n"
report += f" 相关系数: {ae_metrics['correlation']:.4f}\\n"
report += f" R²决定系数: {ae_metrics['r2_score']:.4f}\\n"
reportest_ae_evaluator function · python · L430-L445 (16 LOC)autoencoder/evaluation/ae_evaluator.py
def test_ae_evaluator():
    """Smoke check for the AE evaluator module.

    A functional test would need a real model; because of import constraints
    this only prints the list of provided features and returns True.
    """
    print("=== AE评估器测试 ===")
    print("AE评估器模块创建完成")
    print("包含以下功能:")
    features = (
        "AutoEncoder重建质量评估",
        "参数映射质量评估",
        "端到端性能评估",
        "隐空间分析",
        "综合评估报告生成",
    )
    for feature in features:
        print(f"- {feature}")
    return True
# __init__ method · python · L19-L26 (8 LOC)autoencoder/evaluation/reconstruction_metrics.py
def __init__(self, device: torch.device = None):
"""
初始化评估器
Args:
device: 计算设备
"""
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')compute_all_metrics method · python · L28-L58 (31 LOC)autoencoder/evaluation/reconstruction_metrics.py
def compute_all_metrics(self,
                        pred_rcs: torch.Tensor,
                        true_rcs: torch.Tensor) -> Dict[str, float]:
    """Compute every reconstruction-quality metric group.

    Args:
        pred_rcs: Predicted RCS, documented as [B, 91, 91, 2].
        true_rcs: Ground-truth RCS with the same shape.

    Returns:
        One flat dict merging the basic-error, SSIM, frequency-domain,
        physics and statistical groups, in that order (later groups win on
        duplicate keys).
    """
    metric_groups = (
        self.compute_basic_errors,         # basic error metrics
        self.compute_ssim_metrics,         # structural similarity
        self.compute_frequency_metrics,    # frequency-domain consistency
        self.compute_physics_metrics,      # physics-constraint metrics
        self.compute_statistical_metrics,  # correlation / R² / KL / ranges
    )
    merged = {}
    for compute_group in metric_groups:
        merged.update(compute_group(pred_rcs, true_rcs))
    return merged
# compute_basic_errors method · python · L60-L92 (33 LOC)autoencoder/evaluation/reconstruction_metrics.py
def compute_basic_errors(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> Dict[str, float]:
"""计算基础误差指标"""
# 确保在同一设备
pred_rcs = pred_rcs.to(self.device)
true_rcs = true_rcs.to(self.device)
# MSE (均方误差)
mse = F.mse_loss(pred_rcs, true_rcs).item()
# MAE (平均绝对误差)
mae = F.l1_loss(pred_rcs, true_rcs).item()
# RMSE (均方根误差)
rmse = torch.sqrt(F.mse_loss(pred_rcs, true_rcs)).item()
# 相对误差
true_rcs_safe = torch.where(torch.abs(true_rcs) < 1e-8,
torch.sign(true_rcs) * 1e-8, true_rcs)
relative_error = torch.mean(torch.abs((pred_rcs - true_rcs) / true_rcs_safe)).item()
# 最大误差
max_error = torch.max(torch.abs(pred_rcs - true_rcs)).item()
return {
'mse': mse,
'mae': mae,
'rmse': rmse,
'relative_error': relative_error,
compute_ssim_metrics method · python · L94-L130 (37 LOC)autoencoder/evaluation/reconstruction_metrics.py
def compute_ssim_metrics(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> Dict[str, float]:
"""计算结构相似性指标"""
pred_np = pred_rcs.detach().cpu().numpy()
true_np = true_rcs.detach().cpu().numpy()
batch_size = pred_np.shape[0]
ssim_scores = []
# 对每个样本和频率计算SSIM
for b in range(batch_size):
for freq in range(2): # 两个频率
pred_freq = pred_np[b, :, :, freq]
true_freq = true_np[b, :, :, freq]
# 标准化到[0,1]范围
pred_norm = self._normalize_for_ssim(pred_freq)
true_norm = self._normalize_for_ssim(true_freq)
try:
ssim_score = ssim(true_norm, pred_norm, data_range=1.0)
ssim_scores.append(ssim_score)
except Exception:
# 如果SSIM计算失败,使用默认值
ssim_scores.append(0.0)
avg_ssim = np_normalize_for_ssim method · python · L132-L140 (9 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _normalize_for_ssim(self, data: np.ndarray) -> np.ndarray:
"""为SSIM计算标准化数据"""
data_min = np.min(data)
data_max = np.max(data)
if data_max - data_min < 1e-8:
return np.zeros_like(data)
return (data - data_min) / (data_max - data_min)compute_frequency_metrics method · python · L142-L177 (36 LOC)autoencoder/evaluation/reconstruction_metrics.py
def compute_frequency_metrics(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> Dict[str, float]:
"""计算频域一致性指标"""
pred_rcs = pred_rcs.to(self.device)
true_rcs = true_rcs.to(self.device)
# FFT分析
pred_fft = torch.fft.fft2(pred_rcs, dim=[1, 2])
true_fft = torch.fft.fft2(true_rcs, dim=[1, 2])
# 幅度谱误差
pred_magnitude = torch.abs(pred_fft)
true_magnitude = torch.abs(true_fft)
magnitude_error = F.mse_loss(pred_magnitude, true_magnitude).item()
# 相位谱误差
pred_phase = torch.angle(pred_fft)
true_phase = torch.angle(true_fft)
phase_error = F.mse_loss(pred_phase, true_phase).item()
# 功率谱密度误差
pred_power = pred_magnitude ** 2
true_power = true_magnitude ** 2
power_error = F.mse_loss(pred_power, true_power).item()
# 频率间一致性
freq_consistency = self._compute_frequency_cPowered by Repobility — scan your code at https://repobility.com
_compute_frequency_consistency method · python · L179-L196 (18 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_frequency_consistency(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> float:
"""计算频率间一致性"""
# 1.5GHz和3GHz的一致性分析
pred_1_5g = pred_rcs[:, :, :, 0] # [B, 91, 91]
pred_3g = pred_rcs[:, :, :, 1]
true_1_5g = true_rcs[:, :, :, 0]
true_3g = true_rcs[:, :, :, 1]
# 频率差异的一致性
pred_diff = pred_3g - pred_1_5g
true_diff = true_3g - true_1_5g
diff_error = F.mse_loss(pred_diff, true_diff).item()
return diff_errorcompute_physics_metrics method · python · L198-L219 (22 LOC)autoencoder/evaluation/reconstruction_metrics.py
def compute_physics_metrics(self,
                            pred_rcs: torch.Tensor,
                            true_rcs: torch.Tensor) -> Dict[str, float]:
    """Compute physics-motivated metrics after moving inputs to ``self.device``.

    Returns:
        ``symmetry_error`` (phi-mirror symmetry mismatch),
        ``continuity_error`` (spatial-gradient mismatch) and
        ``negative_ratio`` (fraction of negative predicted values; negatives
        may be legitimate because of preprocessing).
    """
    pred_dev = pred_rcs.to(self.device)
    true_dev = true_rcs.to(self.device)
    symmetry = self._compute_symmetry_error(pred_dev, true_dev)
    continuity = self._compute_continuity_error(pred_dev, true_dev)
    negatives = self._compute_negative_ratio(pred_dev)
    return dict(symmetry_error=symmetry,
                continuity_error=continuity,
                negative_ratio=negatives)
# _compute_symmetry_error method · python · L221-L248 (28 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_symmetry_error(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> float:
"""计算对称性误差"""
center_phi = 45 # φ=0°对应第45列
symmetry_errors = []
for i in range(1, min(center_phi + 1, 46)): # 最多检查45度范围
left_idx = center_phi - i
right_idx = center_phi + i
if right_idx < 91:
# 预测的对称性
pred_left = pred_rcs[:, :, left_idx, :]
pred_right = pred_rcs[:, :, right_idx, :]
pred_sym_diff = pred_left - pred_right
# 真实的对称性
true_left = true_rcs[:, :, left_idx, :]
true_right = true_rcs[:, :, right_idx, :]
true_sym_diff = true_left - true_right
# 对称性误差
sym_error = F.mse_loss(pred_sym_diff, true_sym_diff)
symmetry_errors.append(sym_error.item())
return np.mean(symmetry_errors_compute_continuity_error method · python · L250-L267 (18 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_continuity_error(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> float:
"""计算连续性误差"""
# θ方向梯度
pred_grad_theta = pred_rcs[:, 1:, :, :] - pred_rcs[:, :-1, :, :]
true_grad_theta = true_rcs[:, 1:, :, :] - true_rcs[:, :-1, :, :]
# φ方向梯度
pred_grad_phi = pred_rcs[:, :, 1:, :] - pred_rcs[:, :, :-1, :]
true_grad_phi = true_rcs[:, :, 1:, :] - true_rcs[:, :, :-1, :]
# 梯度误差
grad_error_theta = F.mse_loss(pred_grad_theta, true_grad_theta)
grad_error_phi = F.mse_loss(pred_grad_phi, true_grad_phi)
return (grad_error_theta + grad_error_phi).item() / 2_compute_negative_ratio method · python · L269-L273 (5 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_negative_ratio(self, pred_rcs: torch.Tensor) -> float:
"""计算负值比例"""
total_elements = pred_rcs.numel()
negative_elements = torch.sum(pred_rcs < 0).item()
return negative_elements / total_elementscompute_statistical_metrics method · python · L275-L302 (28 LOC)autoencoder/evaluation/reconstruction_metrics.py
def compute_statistical_metrics(self,
                                pred_rcs: torch.Tensor,
                                true_rcs: torch.Tensor) -> Dict[str, float]:
    """Compute statistical agreement metrics on ``self.device``.

    Returns:
        ``correlation`` (Pearson), ``r2_score``, ``kl_divergence`` (histogram
        approximation), plus the keys produced by ``_compute_range_metrics``.
    """
    pred_rcs, true_rcs = pred_rcs.to(self.device), true_rcs.to(self.device)
    metrics = {
        'correlation': self._compute_correlation(pred_rcs, true_rcs),
        'r2_score': self._compute_r2_score(pred_rcs, true_rcs),
        'kl_divergence': self._compute_kl_divergence(pred_rcs, true_rcs),
    }
    # Range metrics are merged last, matching the original update order.
    metrics.update(self._compute_range_metrics(pred_rcs, true_rcs))
    return metrics
# _compute_correlation method · python · L304-L326 (23 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_correlation(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> float:
"""计算皮尔逊相关系数"""
pred_flat = pred_rcs.flatten()
true_flat = true_rcs.flatten()
# 计算相关系数
pred_mean = torch.mean(pred_flat)
true_mean = torch.mean(true_flat)
numerator = torch.sum((pred_flat - pred_mean) * (true_flat - true_mean))
pred_std = torch.sqrt(torch.sum((pred_flat - pred_mean) ** 2))
true_std = torch.sqrt(torch.sum((true_flat - true_mean) ** 2))
denominator = pred_std * true_std
if denominator < 1e-8:
return 0.0
correlation = (numerator / denominator).item()
return correlation_compute_r2_score method · python · L328-L340 (13 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_r2_score(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> float:
"""计算R²决定系数"""
ss_res = torch.sum((true_rcs - pred_rcs) ** 2)
ss_tot = torch.sum((true_rcs - torch.mean(true_rcs)) ** 2)
if ss_tot < 1e-8:
return 0.0
r2 = 1 - ss_res / ss_tot
return r2.item()Repobility · code-quality intelligence platform · https://repobility.com
_compute_kl_divergence method · python · L342-L362 (21 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_kl_divergence(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> float:
"""计算KL散度近似"""
# 将数据转换为概率分布(简化处理)
pred_hist = torch.histc(pred_rcs, bins=50, min=pred_rcs.min(), max=pred_rcs.max())
true_hist = torch.histc(true_rcs, bins=50, min=true_rcs.min(), max=true_rcs.max())
# 标准化为概率分布
pred_prob = pred_hist / torch.sum(pred_hist)
true_prob = true_hist / torch.sum(true_hist)
# 避免log(0)
pred_prob = torch.clamp(pred_prob, min=1e-8)
true_prob = torch.clamp(true_prob, min=1e-8)
# KL散度
kl = torch.sum(true_prob * torch.log(true_prob / pred_prob))
return kl.item()_compute_range_metrics method · python · L364-L383 (20 LOC)autoencoder/evaluation/reconstruction_metrics.py
def _compute_range_metrics(self,
pred_rcs: torch.Tensor,
true_rcs: torch.Tensor) -> Dict[str, float]:
"""计算数值范围指标"""
pred_min, pred_max = torch.min(pred_rcs).item(), torch.max(pred_rcs).item()
true_min, true_max = torch.min(true_rcs).item(), torch.max(true_rcs).item()
pred_range = pred_max - pred_min
true_range = true_max - true_min
range_error = abs(pred_range - true_range) / max(true_range, 1e-8)
return {
'pred_min': pred_min,
'pred_max': pred_max,
'true_min': true_min,
'true_max': true_max,
'range_error': range_error
}generate_report method · python · L385-L421 (37 LOC)autoencoder/evaluation/reconstruction_metrics.py
def generate_report(self,
metrics: Dict[str, float],
detailed: bool = True) -> str:
"""生成评估报告"""
report = "\\n=== RCS重建质量评估报告 ===\\n"
# 基础误差指标
report += "\\n📊 基础误差指标:\\n"
report += f" MSE (均方误差): {metrics.get('mse', 0):.6f}\\n"
report += f" MAE (平均绝对误差): {metrics.get('mae', 0):.6f}\\n"
report += f" RMSE (均方根误差): {metrics.get('rmse', 0):.6f}\\n"
report += f" 相对误差: {metrics.get('relative_error', 0):.6f}\\n"
# 结构相似性
report += "\\n🔍 结构相似性指标:\\n"
report += f" SSIM (平均): {metrics.get('ssim_mean', 0):.4f}\\n"
report += f" SSIM (标准差): {metrics.get('ssim_std', 0):.4f}\\n"
# 物理约束
report += "\\n⚖️ 物理约束指标:\\n"
report += f" 对称性误差: {metrics.get('symmetry_error', 0):.6f}\\n"
report += f" 连续性误差: {metrics.get('continuity_error', 0):.6f}\\n"
# 统计指标
test_reconstruction_metrics function · python · L424-L464 (41 LOC)autoencoder/evaluation/reconstruction_metrics.py
def test_reconstruction_metrics():
"""测试重建质量评估"""
print("=== 重建质量评估测试 ===")
# 创建测试数据
batch_size = 5
true_rcs = torch.randn(batch_size, 91, 91, 2) * 10
# 创建不同质量的预测数据进行测试
test_cases = [
("完美重建", true_rcs),
("添加噪声", true_rcs + torch.randn_like(true_rcs) * 0.5),
("系统偏移", true_rcs + 2.0),
("比例缩放", true_rcs * 0.8),
("随机噪声", torch.randn_like(true_rcs) * 5)
]
# 创建评估器
evaluator = ReconstructionMetrics()
# 测试每种情况
for case_name, pred_rcs in test_cases:
print(f"\\n--- {case_name} ---")
metrics = evaluator.compute_all_metrics(pred_rcs, true_rcs)
# 显示关键指标
print(f"MSE: {metrics['mse']:.6f}")
print(f"SSIM: {metrics['ssim_mean']:.4f}")
print(f"相关系数: {metrics['correlation']:.4f}")
print(f"R²: {metrics['r2_score']:.4f}")
# 生成详细报告
print("\\n" + "="*50)
print("详细评估报告示例:")
best_metrics = evaluator.compute_all_metrics(test_cases[1][1], true_rcAdditiveDualBranchWaveletCNN class · python · L35-L406 (372 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
class AdditiveDualBranchWaveletCNN(BaseAutoEncoder):
"""
叠加型双分支小波CNN (Wavelet模式)
"""
def __init__(self,
latent_dim: int = 256,
num_frequencies: int = 2,
wavelet_bands: int = 4,
dropout_rate: float = 0.2,
input_size: int = 49,
activation_encoder: str = 'relu',
activation_high: str = 'sin',
activation_smooth: str = 'tanh',
learnable_weights: bool = False,
alpha_high: float = 0.5,
alpha_smooth: float = 0.5,
output_activation: str = None):
"""
初始化叠加型双分支AutoEncoder
Args:
latent_dim: 隐空间维度
num_frequencies: 频率数量 (2 for 1.5GHz+3GHz)
wavelet_bands: 小波频带数 (4: LL, LH, HL, HH)
dropout_rate: Dropout比例
input_size: 输入小波系数尺寸 (49)
activation_encoder: Encoder激活函数 (默认'relu', 可选: 'gelu', 'swis__init__ method · python · L40-L163 (124 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def __init__(self,
latent_dim: int = 256,
num_frequencies: int = 2,
wavelet_bands: int = 4,
dropout_rate: float = 0.2,
input_size: int = 49,
activation_encoder: str = 'relu',
activation_high: str = 'sin',
activation_smooth: str = 'tanh',
learnable_weights: bool = False,
alpha_high: float = 0.5,
alpha_smooth: float = 0.5,
output_activation: str = None):
"""
初始化叠加型双分支AutoEncoder
Args:
latent_dim: 隐空间维度
num_frequencies: 频率数量 (2 for 1.5GHz+3GHz)
wavelet_bands: 小波频带数 (4: LL, LH, HL, HH)
dropout_rate: Dropout比例
input_size: 输入小波系数尺寸 (49)
activation_encoder: Encoder激活函数 (默认'relu', 可选: 'gelu', 'swish', 'sin'等)
activation_high: 高频分支Decoder激活函数 (默认'sin', 推荐: 'sin', 'gelu', 'swish')
_build_decoder_fc method · python · L165-L185 (21 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def _build_decoder_fc(self, activation: str) -> nn.Sequential:
    """Build the decoder's fully connected stack.

    Walks ``self.intermediate_dims`` in reverse starting from
    ``self.latent_dim``: each hidden layer is Linear -> activation -> Dropout.
    A final Linear to ``self.flattened_size`` plus activation (no dropout)
    closes the stack.
    """
    layers = []
    in_features = self.latent_dim
    for out_features in reversed(self.intermediate_dims):
        layers += [
            nn.Linear(in_features, out_features),
            get_activation(activation),
            nn.Dropout(self.dropout_rate),
        ]
        in_features = out_features
    # Projection to the flattened feature-map size consumed by the conv decoder.
    layers += [nn.Linear(in_features, self.flattened_size), get_activation(activation)]
    return nn.Sequential(*layers)
# _build_decoder_conv method · python · L187-L209 (23 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def _build_decoder_conv(self, activation: str) -> nn.Sequential:
    """Build the decoder's transposed-convolution stack.

    Reshapes the flat [4096] vector to a [256, 4, 4] map. It then applies
    three stride-2 upsampling stages (4 -> 8 -> 16 -> 32) with channels
    256 -> 128 -> 64 -> 32. Each stage is ConvTranspose2d -> BatchNorm2d ->
    activation, with Dropout2d after every stage except the last. Layer
    order matches the original exactly so checkpoint state_dict indices stay
    valid.
    """
    channels = (256, 128, 64, 32)
    last_stage = len(channels) - 2
    layers = [nn.Unflatten(1, (256, 4, 4))]
    for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
        layers += [
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c_out),
            get_activation(activation),
        ]
        if stage < last_stage:
            layers.append(nn.Dropout2d(self.dropout_rate))
    return nn.Sequential(*layers)
# Repobility · severity-and-effort ranking · https://repobility.com
_initialize_weights method · python · L211-L223 (13 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def _initialize_weights(self):
"""Xavier权重初始化"""
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)encode method · python · L225-L244 (20 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode wavelet coefficients into the latent space.

    Args:
        x: Channels-last wavelet coefficients, documented as
            [B, input_size, input_size, num_freq*4].

    Returns:
        Latent representation, documented as [B, latent_dim].
    """
    # Channels-last -> channels-first ([B, H, W, C] -> [B, C, H, W]) for Conv2d,
    # then CNN features -> fully connected projection to the latent space.
    channels_first = x.permute(0, 3, 1, 2)
    return self.encoder_fc(self.encoder(channels_first))
# decode method · python · L246-L289 (44 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def decode(self, latent: torch.Tensor) -> torch.Tensor:
"""
双分支解码器: 隐空间 → 小波系数
Args:
latent: [B, latent_dim] 隐空间表示
Returns:
recon: [B, input_size, input_size, num_freq*4] 重建小波系数
"""
# ===== 分支1: 高频Decoder =====
feat_high = self.decoder_high_fc(latent) # [B, 4096]
feat_high = self.decoder_high_conv(feat_high) # [B, 32, 32, 32]
recon_high = self.final_conv_high(feat_high) # [B, num_freq*4, 32, 32]
# ===== 分支2: 低频Decoder =====
feat_smooth = self.decoder_smooth_fc(latent) # [B, 4096]
feat_smooth = self.decoder_smooth_conv(feat_smooth) # [B, 32, 32, 32]
recon_smooth = self.final_conv_smooth(feat_smooth) # [B, num_freq*4, 32, 32]
# ===== 输出叠加 =====
# 确保权重归一化(可选,防止输出爆炸)
if self.learnable_weights:
# 可学习权重:直接使用,不强制归一化,允许模型自适应学习幅度
alpha_high_norm = self.alpha_high
alpha_smooth_norm = self.alpha_smforward method · python · L291-L304 (14 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run encode followed by the dual-branch decode.

    Args:
        x: Wavelet coefficients, documented as
            [B, input_size, input_size, num_freq*4].

    Returns:
        Tuple ``(recon, latent)``: the reconstructed coefficients (same
        documented shape as ``x``) and the latent code [B, latent_dim].
    """
    latent = self.encode(x)
    return self.decode(latent), latent
# forward_with_branches method · python · L306-L357 (52 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def forward_with_branches(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
前向传播(返回分支输出,用于可视化和频域分析)
Args:
x: [B, input_size, input_size, num_freq*4] 小波系数
Returns:
recon: [B, input_size, input_size, num_freq*4] 叠加后的重建
latent: [B, latent_dim] 隐空间表示
recon_high: [B, input_size, input_size, num_freq*4] 高频分支重建
recon_smooth: [B, input_size, input_size, num_freq*4] 低频分支重建
"""
# Encode
latent = self.encode(x)
# ===== 分支1: 高频Decoder =====
feat_high = self.decoder_high_fc(latent)
feat_high = self.decoder_high_conv(feat_high)
recon_high = self.final_conv_high(feat_high)
# ===== 分支2: 低频Decoder =====
feat_smooth = self.decoder_smooth_fc(latent)
feat_smooth = self.decoder_smooth_conv(feat_smooth)
recon_smooth = self.final_conv_smooth(feat_smooth)
# ===== 上采样到目标尺寸 =get_parameter_count method · python · L359-L389 (31 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def get_parameter_count(self) -> Dict[str, int]:
"""获取参数统计"""
# Encoder参数
encoder_params = sum(p.numel() for p in self.encoder.parameters()) + \
sum(p.numel() for p in self.encoder_fc.parameters())
# High-frequency Decoder参数
decoder_high_params = sum(p.numel() for p in self.decoder_high_fc.parameters()) + \
sum(p.numel() for p in self.decoder_high_conv.parameters()) + \
sum(p.numel() for p in self.final_conv_high.parameters())
# Low-frequency Decoder参数
decoder_smooth_params = sum(p.numel() for p in self.decoder_smooth_fc.parameters()) + \
sum(p.numel() for p in self.decoder_smooth_conv.parameters()) + \
sum(p.numel() for p in self.final_conv_smooth.parameters())
# 总decoder参数(两个分支之和)
decoder_params = decoder_high_params + decoder_smooth_params
# 总参数和可训练参数
get_model_info method · python · L391-L406 (16 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def get_model_info(self) -> dict:
    """Return a summary dict describing this model's configuration.

    ``parameters`` counts only trainable parameters (``requires_grad``).
    ``alpha_high`` / ``alpha_smooth`` are read with ``.item()``, so they are
    assumed to be tensors (parameters or buffers) — TODO confirm in __init__.
    """
    return dict(
        type='AdditiveDualBranchWaveletCNN',
        latent_dim=self.latent_dim,
        input_channels=self.input_channels,
        input_size=self.input_size,
        activation_encoder=self.activation_encoder_type,
        activation_high=self.activation_high_type,
        activation_smooth=self.activation_smooth_type,
        learnable_weights=self.learnable_weights,
        alpha_high=self.alpha_high.item(),
        alpha_smooth=self.alpha_smooth.item(),
        intermediate_dims=self.intermediate_dims,
        parameters=sum(p.numel() for p in self.parameters() if p.requires_grad),
    )
# AdditiveDualBranchDirectCNN class · python · L409-L746 (338 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
class AdditiveDualBranchDirectCNN(BaseAutoEncoder):
"""
叠加型双分支Direct CNN (Direct模式 - 无小波变换)
"""
def __init__(self,
latent_dim: int = 256,
num_frequencies: int = 2,
dropout_rate: float = 0.2,
input_size: int = 91,
activation_encoder: str = 'relu',
activation_high: str = 'sin',
activation_smooth: str = 'tanh',
learnable_weights: bool = False,
alpha_high: float = 0.5,
alpha_smooth: float = 0.5,
output_activation: str = None):
"""
初始化Direct模式叠加型双分支AutoEncoder
Args:
latent_dim: 隐空间维度
num_frequencies: 频率数量 (2 for 1.5GHz+3GHz)
dropout_rate: Dropout比例
input_size: 输入RCS尺寸 (91)
activation_encoder: Encoder激活函数 (默认'relu')
activation_high: 高频分支Decoder激活函数 (默认'sin')
activation_smooth: 低频分支DecSame scanner, your repo: https://repobility.com — Repobility
__init__ method · python · L414-L538 (125 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def __init__(self,
latent_dim: int = 256,
num_frequencies: int = 2,
dropout_rate: float = 0.2,
input_size: int = 91,
activation_encoder: str = 'relu',
activation_high: str = 'sin',
activation_smooth: str = 'tanh',
learnable_weights: bool = False,
alpha_high: float = 0.5,
alpha_smooth: float = 0.5,
output_activation: str = None):
"""
初始化Direct模式叠加型双分支AutoEncoder
Args:
latent_dim: 隐空间维度
num_frequencies: 频率数量 (2 for 1.5GHz+3GHz)
dropout_rate: Dropout比例
input_size: 输入RCS尺寸 (91)
activation_encoder: Encoder激活函数 (默认'relu')
activation_high: 高频分支Decoder激活函数 (默认'sin')
activation_smooth: 低频分支Decoder激活函数 (默认'tanh')
learnable_weights: 是否学习叠加权重
alpha_high: 高频分支权重
alpha__build_decoder_fc method · python · L540-L558 (19 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def _build_decoder_fc(self, activation: str) -> nn.Sequential:
    """Build the decoder's fully connected stack.

    Walks ``self.intermediate_dims`` in reverse starting from
    ``self.latent_dim``: each hidden layer is Linear -> activation -> Dropout.
    A final Linear to ``self.flattened_size`` plus activation (no dropout)
    closes the stack.
    """
    layers = []
    in_features = self.latent_dim
    for out_features in reversed(self.intermediate_dims):
        layers += [
            nn.Linear(in_features, out_features),
            get_activation(activation),
            nn.Dropout(self.dropout_rate),
        ]
        in_features = out_features
    # Projection to the flattened feature-map size consumed by the conv decoder.
    layers += [nn.Linear(in_features, self.flattened_size), get_activation(activation)]
    return nn.Sequential(*layers)
# _build_decoder_conv method · python · L560-L583 (24 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def _build_decoder_conv(self, activation: str) -> nn.Sequential:
    """Build the decoder's transposed-convolution stack.

    Reshapes the flat vector to a [512, 4, 4] map. It then applies four
    stride-2 upsampling stages (4 -> 8 -> 16 -> 32 -> 64) with channels
    512 -> 256 -> 128 -> 64 -> 32. Each stage is ConvTranspose2d ->
    BatchNorm2d -> activation, with Dropout2d after every stage except the
    last. Layer order matches the original exactly so checkpoint state_dict
    indices stay valid.
    """
    channels = (512, 256, 128, 64, 32)
    last_stage = len(channels) - 2
    layers = [nn.Unflatten(1, (512, 4, 4))]
    for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
        layers += [
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c_out),
            get_activation(activation),
        ]
        if stage < last_stage:
            layers.append(nn.Dropout2d(self.dropout_rate))
    return nn.Sequential(*layers)
# _initialize_weights method · python · L585-L597 (13 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def _initialize_weights(self):
"""权重初始化"""
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)encode method · python · L599-L604 (6 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode channels-last RCS data into the latent space.

    The input is permuted [B, H, W, C] -> [B, C, H, W], passed through the
    CNN encoder, then through the fully connected projection.
    """
    channels_first = x.permute(0, 3, 1, 2)
    return self.encoder_fc(self.encoder(channels_first))
# decode method · python · L606-L638 (33 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def decode(self, latent: torch.Tensor) -> torch.Tensor:
"""双分支解码器"""
# 高频分支
feat_high = self.decoder_high_fc(latent)
feat_high = self.decoder_high_conv(feat_high)
recon_high = self.final_conv_high(feat_high)
# 低频分支
feat_smooth = self.decoder_smooth_fc(latent)
feat_smooth = self.decoder_smooth_conv(feat_smooth)
recon_smooth = self.final_conv_smooth(feat_smooth)
# 叠加
if self.learnable_weights:
# 可学习权重:直接使用,不强制归一化
alpha_high_norm = self.alpha_high
alpha_smooth_norm = self.alpha_smooth
else:
alpha_high_norm = self.alpha_high
alpha_smooth_norm = self.alpha_smooth
recon = alpha_high_norm * recon_high + alpha_smooth_norm * recon_smooth
# 上采样: [64, 64] → [91, 91]
recon = F.interpolate(recon, size=(self.input_size, self.input_size),
mode='bilinear', align_corners=False)
reforward method · python · L640-L644 (5 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run encode followed by the dual-branch decode.

    Returns:
        Tuple ``(recon, latent)``.
    """
    latent = self.encode(x)
    return self.decode(latent), latent
# forward_with_branches method · python · L646-L697 (52 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def forward_with_branches(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
前向传播(返回分支输出,用于可视化和频域分析)
Args:
x: [B, input_size, input_size, num_freq] RCS数据
Returns:
recon: [B, input_size, input_size, num_freq] 叠加后的重建
latent: [B, latent_dim] 隐空间表示
recon_high: [B, input_size, input_size, num_freq] 高频分支重建
recon_smooth: [B, input_size, input_size, num_freq] 低频分支重建
"""
# Encode
latent = self.encode(x)
# ===== 分支1: 高频Decoder =====
feat_high = self.decoder_high_fc(latent)
feat_high = self.decoder_high_conv(feat_high)
recon_high = self.final_conv_high(feat_high)
# ===== 分支2: 低频Decoder =====
feat_smooth = self.decoder_smooth_fc(latent)
feat_smooth = self.decoder_smooth_conv(feat_smooth)
recon_smooth = self.final_conv_smooth(feat_smooth)
# ===== 上采样到目标尺寸: [64, 64Powered by Repobility — scan your code at https://repobility.com
get_parameter_count method · python · L699-L729 (31 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def get_parameter_count(self) -> Dict[str, int]:
"""获取参数统计"""
# Encoder参数
encoder_params = sum(p.numel() for p in self.encoder.parameters()) + \
sum(p.numel() for p in self.encoder_fc.parameters())
# High-frequency Decoder参数
decoder_high_params = sum(p.numel() for p in self.decoder_high_fc.parameters()) + \
sum(p.numel() for p in self.decoder_high_conv.parameters()) + \
sum(p.numel() for p in self.final_conv_high.parameters())
# Low-frequency Decoder参数
decoder_smooth_params = sum(p.numel() for p in self.decoder_smooth_fc.parameters()) + \
sum(p.numel() for p in self.decoder_smooth_conv.parameters()) + \
sum(p.numel() for p in self.final_conv_smooth.parameters())
# 总decoder参数(两个分支之和)
decoder_params = decoder_high_params + decoder_smooth_params
# 总参数和可训练参数
get_model_info method · python · L731-L746 (16 LOC)autoencoder/models/additive_dual_branch_autoencoder.py
def get_model_info(self) -> dict:
    """Return a summary dict describing this model's configuration.

    ``parameters`` counts only trainable parameters (``requires_grad``).
    ``alpha_high`` / ``alpha_smooth`` are read with ``.item()``, so they are
    assumed to be tensors (parameters or buffers) — TODO confirm in __init__.
    """
    return dict(
        type='AdditiveDualBranchDirectCNN',
        latent_dim=self.latent_dim,
        input_channels=self.input_channels,
        input_size=self.input_size,
        activation_encoder=self.activation_encoder_type,
        activation_high=self.activation_high_type,
        activation_smooth=self.activation_smooth_type,
        learnable_weights=self.learnable_weights,
        alpha_high=self.alpha_high.item(),
        alpha_smooth=self.alpha_smooth.item(),
        intermediate_dims=self.intermediate_dims,
        parameters=sum(p.numel() for p in self.parameters() if p.requires_grad),
    )
# AdditiveDualBranchWaveletMLPAutoEncoder class · python · L34-L284 (251 LOC)autoencoder/models/additive_dual_branch_mlp.py
class AdditiveDualBranchWaveletMLPAutoEncoder(BaseAutoEncoder):
"""
叠加型双分支小波MLP AutoEncoder
"""
def _init_layer(self, layer: nn.Linear, activation_name: str, is_first: bool = False):
"""
根据激活函数初始化层权重
"""
if activation_name == 'sin':
# SIREN初始化
omega_0 = 30.0
if is_first:
limit = 1.0 / layer.in_features
nn.init.uniform_(layer.weight, -limit, limit)
else:
limit = np.sqrt(6 / layer.in_features) / omega_0
nn.init.uniform_(layer.weight, -limit, limit)
else:
# 标准Xavier初始化
nn.init.xavier_uniform_(layer.weight)
if layer.bias is not None:
nn.init.constant_(layer.bias, 0)
def __init__(self,
latent_dim: int = 256,
num_frequencies: int = 2,
wavelet_bands: int = 4,
dropout_rate: float = 0.2,
input_sipage 1 / 19next ›