Function bodies 170 total
__init__ method · python · L262-L297 (36 LOC)models/spatial_rope.py
def __init__(
self,
head_dim: int,
n_reference_types: int = 7,
max_seq_len: int = 2048,
base: float = 10000.0,
spatial_weight: float = 0.5, # Balance between spatial and temporal
):
super().__init__()
self.head_dim = head_dim
self.spatial_weight = spatial_weight
# Split head_dim between spatial and temporal
# Each gets half, then split further for sin/cos
self.spatial_dim = head_dim // 2
self.temporal_dim = head_dim - self.spatial_dim
# Ensure dimensions are even for rotation
if self.spatial_dim % 4 != 0:
# Adjust to make spatial_dim divisible by 4
self.spatial_dim = (self.spatial_dim // 4) * 4
self.temporal_dim = head_dim - self.spatial_dim
if self.spatial_dim > 0:
self.spatial_rope = Spatial2DRoPE(
head_dim=self.spatial_dim,
n_reference_types=n_reference_types,
forward method · python · L299-L349 (51 LOC)models/spatial_rope.py
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
channel_indices: torch.Tensor,
time_indices: torch.Tensor,
reference_type_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply combined spatial and temporal RoPE.
Args:
q, k: Shape (batch, n_heads, seq_len, head_dim)
channel_indices: Which channel each position corresponds to (batch, seq_len)
time_indices: Which time step each position corresponds to (batch, seq_len)
reference_type_ids: Reference type IDs (batch,)
Returns:
q_rotated, k_rotated: Rotated tensors
"""
batch_size, n_heads, seq_len, head_dim = q.shape
# Split into spatial and temporal parts
q_spatial = q[..., :self.spatial_dim]
q_temporal = q[..., self.spatial_dim:]
k_spatial = k[..., :self.spatial_dim]
k_temporal = k[..., self.spatial_dSpatialAttention class · python · L352-L447 (96 LOC)models/spatial_rope.py
class SpatialAttention(nn.Module):
"""
Multi-Head Attention with 2D Spatial RoPE for EEG.
This attention operates on channels as positions (not time patches).
Each channel has a 2D position on the scalp, and the attention learns
spatial relationships modulated by the reference type.
Use this for cross-channel attention within a time window.
"""
def __init__(
self,
hidden_dim: int,
n_heads: int,
n_reference_types: int = 7,
dropout: float = 0.1,
use_qk_norm: bool = True,
):
super().__init__()
self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.head_dim = hidden_dim // n_heads
self.scale = self.head_dim ** -0.5
if hidden_dim % n_heads != 0:
raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})")
# QKV projection
self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
self.out_p__init__ method · python · L363-L396 (34 LOC)models/spatial_rope.py
def __init__(
self,
hidden_dim: int,
n_heads: int,
n_reference_types: int = 7,
dropout: float = 0.1,
use_qk_norm: bool = True,
):
super().__init__()
self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.head_dim = hidden_dim // n_heads
self.scale = self.head_dim ** -0.5
if hidden_dim % n_heads != 0:
raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})")
# QKV projection
self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
# 2D Spatial RoPE
self.spatial_rope = Spatial2DRoPE(
head_dim=self.head_dim,
n_reference_types=n_reference_types,
)
# QK normalization
self.use_qk_norm = use_qk_norm
if use_qk_norm:
self.q_norm = RMSNorm(self.head_dim)
self.k_noforward method · python · L398-L447 (50 LOC)models/spatial_rope.py
def forward(
self,
x: torch.Tensor,
reference_type_ids: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Apply spatial attention.
Args:
x: Input tensor of shape (batch, n_channels, hidden_dim)
where n_channels = 19 for standard EEG
reference_type_ids: Reference type IDs of shape (batch,)
mask: Optional attention mask
Returns:
Output tensor of shape (batch, n_channels, hidden_dim)
"""
B, N, D = x.shape
# QKV projection
qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(2) # Each: (B, N, n_heads, head_dim)
q = q.transpose(1, 2) # (B, n_heads, N, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# QK normalization
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(SpatialTransformerBlock class · python · L450-L506 (57 LOC)models/spatial_rope.py
class SpatialTransformerBlock(nn.Module):
"""
Transformer block with 2D Spatial RoPE for EEG.
Pre-norm architecture with RMSNorm and SwiGLU FFN.
"""
def __init__(
self,
hidden_dim: int,
n_heads: int,
n_reference_types: int = 7,
dropout: float = 0.1,
use_qk_norm: bool = True,
):
super().__init__()
self.norm1 = RMSNorm(hidden_dim)
self.attn = SpatialAttention(
hidden_dim=hidden_dim,
n_heads=n_heads,
n_reference_types=n_reference_types,
dropout=dropout,
use_qk_norm=use_qk_norm,
)
self.norm2 = RMSNorm(hidden_dim)
# SwiGLU FFN
hidden_ff = hidden_dim * 4
self.w1 = nn.Linear(hidden_dim, hidden_ff, bias=False)
self.w2 = nn.Linear(hidden_ff, hidden_dim, bias=False)
self.w3 = nn.Linear(hidden_dim, hidden_ff, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(
__init__ method · python · L457-L483 (27 LOC)models/spatial_rope.py
def __init__(
    self,
    hidden_dim: int,
    n_heads: int,
    n_reference_types: int = 7,
    dropout: float = 0.1,
    use_qk_norm: bool = True,
):
    """Create the attention and feed-forward sub-layers of the block.

    Args:
        hidden_dim: Feature width used by both norms, the attention layer
            and the input/output side of the feed-forward projections.
        n_heads: Forwarded to SpatialAttention.
        n_reference_types: Forwarded to SpatialAttention.
        dropout: Forwarded to SpatialAttention and also used for the
            block's own nn.Dropout.
        use_qk_norm: Forwarded to SpatialAttention.
    """
    super().__init__()

    def _bias_free_linear(n_in: int, n_out: int) -> nn.Linear:
        return nn.Linear(n_in, n_out, bias=False)

    # Sub-modules are created in the same order as before so that parameter
    # registration (and therefore random initialisation) is unchanged.
    attention_options = {
        'hidden_dim': hidden_dim,
        'n_heads': n_heads,
        'n_reference_types': n_reference_types,
        'dropout': dropout,
        'use_qk_norm': use_qk_norm,
    }
    self.norm1 = RMSNorm(hidden_dim)
    self.attn = SpatialAttention(**attention_options)
    self.norm2 = RMSNorm(hidden_dim)

    # SwiGLU FFN: the inner width is four times the model width.
    ffn_dim = 4 * hidden_dim
    self.w1 = _bias_free_linear(hidden_dim, ffn_dim)
    self.w2 = _bias_free_linear(ffn_dim, hidden_dim)
    self.w3 = _bias_free_linear(hidden_dim, ffn_dim)
    self.dropout = nn.Dropout(dropout)
forward method · python · L485-L506 (22 LOC)models/spatial_rope.py
def forward(
    self,
    x: torch.Tensor,
    reference_type_ids: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run the block: residual attention followed by a residual SwiGLU FFN.
    Args:
        x: Input of shape (batch, n_channels, hidden_dim)
        reference_type_ids: Reference type IDs of shape (batch,)
        mask: Optional attention mask
    Returns:
        The sum of the input and both residual branches.
    """
    # Attention branch: normalise first, then add the result back onto x.
    attended = self.attn(self.norm1(x), reference_type_ids, mask)
    after_attention = x + attended

    # Feed-forward branch: SiLU-gated product of two projections, projected
    # back by w2, with dropout applied before the residual add.
    normed = self.norm2(after_attention)
    gated = F.silu(self.w1(normed)) * self.w3(normed)
    ffn_out = self.dropout(self.w2(gated))
    return after_attention + ffn_out
# Listing header that followed this snippet: SpatialChannelEncoder class · python · L509-L598 (90 LOC)models/spatial_rope.py
class SpatialChannelEncoder(nn.Module):
"""
Encode EEG channels with spatial awareness.
This encoder treats each channel as a position with 2D coordinates.
It processes each time window across all channels, learning spatial
relationships that are modulated by the reference type.
Input: (batch, n_channels, time_samples)
Output: (batch, n_windows, hidden_dim)
"""
def __init__(
self,
n_channels: int = 19,
hidden_dim: int = 768,
n_heads: int = 12,
n_layers: int = 2, # Fewer layers for spatial processing
n_reference_types: int = 7,
window_size: int = 32,
dropout: float = 0.1,
):
super().__init__()
self.n_channels = n_channels
self.hidden_dim = hidden_dim
self.window_size = window_size
# Per-channel embedding (window_size -> hidden_dim)
self.channel_embed = nn.Linear(window_size, hidden_dim)
# Spatial transformer blocks
__init__ method · python · L521-L555 (35 LOC)models/spatial_rope.py
def __init__(
self,
n_channels: int = 19,
hidden_dim: int = 768,
n_heads: int = 12,
n_layers: int = 2, # Fewer layers for spatial processing
n_reference_types: int = 7,
window_size: int = 32,
dropout: float = 0.1,
):
super().__init__()
self.n_channels = n_channels
self.hidden_dim = hidden_dim
self.window_size = window_size
# Per-channel embedding (window_size -> hidden_dim)
self.channel_embed = nn.Linear(window_size, hidden_dim)
# Spatial transformer blocks
self.spatial_blocks = nn.ModuleList([
SpatialTransformerBlock(
hidden_dim=hidden_dim,
n_heads=n_heads,
n_reference_types=n_reference_types,
dropout=dropout,
)
for _ in range(n_layers)
])
# Pool across channels to get window representation
self.channel_pool = nn.Sequential(forward method · python · L557-L598 (42 LOC)models/spatial_rope.py
def forward(
self,
x: torch.Tensor,
reference_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Encode EEG with spatial awareness.
Args:
x: Input EEG of shape (batch, n_channels, time_samples)
reference_type_ids: Reference type IDs of shape (batch,)
Returns:
Encoded representation of shape (batch, n_windows, hidden_dim)
"""
batch_size, n_channels, total_samples = x.shape
n_windows = total_samples // self.window_size
# Reshape to windows: (batch, n_channels, n_windows, window_size)
x = x[:, :, :n_windows * self.window_size]
x = x.view(batch_size, n_channels, n_windows, self.window_size)
# Process each window with spatial attention
outputs = []
for w in range(n_windows):
# Get window: (batch, n_channels, window_size)
window = x[:, :, w, :]
# Embed each channel: (baget_reference_type_tensor function · python · L601-L620 (20 LOC)models/spatial_rope.py
def get_reference_type_tensor(
    reference_names: list,
    device: torch.device = None,
) -> torch.Tensor:
    """
    Map reference names to their integer type IDs.
    Args:
        reference_names: List of reference name strings (e.g., ['A1-A2', 'FCz', 'average'])
        device: Target device; when None the tensor is left where
            torch.tensor created it
    Returns:
        Long tensor with one type ID per entry of reference_names
    """
    def _type_id(name):
        # Names missing from the table fall back to the 'unknown' entry.
        table = EEGConfig.REFERENCE_TYPES
        return table.get(name, table['unknown'])

    type_ids = torch.tensor(list(map(_type_id, reference_names)), dtype=torch.long)
    if device is None:
        return type_ids
    return type_ids.to(device)
# Listing header that followed this snippet: VariableSpatial2DRoPE class · python · L28-L105 (78 LOC)models/variable_channel_embedding.py
class VariableSpatial2DRoPE(nn.Module):
"""2D RoPE that takes channel coordinates as input (not fixed buffer).
Unlike Spatial2DRoPE which has a fixed 19-channel position buffer,
this version accepts (x, y) coordinates per channel per sample,
enabling variable channel counts across the batch.
"""
def __init__(self, head_dim: int, n_reference_types: int = 7,
base: float = 10000.0, use_reference_conditioning: bool = True):
super().__init__()
self.head_dim = head_dim
self.use_reference_conditioning = use_reference_conditioning
quarter_dim = head_dim // 4
if head_dim % 4 != 0:
raise ValueError(f"head_dim ({head_dim}) must be divisible by 4 for 2D RoPE")
inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, dtype=torch.float32) / quarter_dim))
self.register_buffer('inv_freq', inv_freq)
if use_reference_conditioning:
self.reference_freq_scale = nn.Parameter(t__init__ method · python · L36-L51 (16 LOC)models/variable_channel_embedding.py
def __init__(self, head_dim: int, n_reference_types: int = 7,
             base: float = 10000.0, use_reference_conditioning: bool = True):
    """Set up the rotary frequencies and optional per-reference parameters.

    Args:
        head_dim: Per-head feature width; must be a multiple of 4.
        n_reference_types: Number of rows in the per-reference parameters.
        base: Base of the geometric frequency progression.
        use_reference_conditioning: When True, create the learnable
            per-reference frequency scales and rotation offsets.

    Raises:
        ValueError: If head_dim is not divisible by 4.
    """
    super().__init__()
    self.head_dim = head_dim
    self.use_reference_conditioning = use_reference_conditioning

    n_freqs, leftover = divmod(head_dim, 4)
    if leftover:
        raise ValueError(f"head_dim ({head_dim}) must be divisible by 4 for 2D RoPE")

    # One inverse frequency per quarter of the head dimension.
    exponents = torch.arange(0, n_freqs, dtype=torch.float32) / n_freqs
    self.register_buffer('inv_freq', 1.0 / (base ** exponents))

    if not use_reference_conditioning:
        return

    # Scales start at 1 and offsets at 0; each reference type owns one row,
    # with two scale entries and head_dim offset entries.
    self.reference_freq_scale = nn.Parameter(torch.ones(n_reference_types, 2))
    self.reference_rotation_offset = nn.Parameter(torch.zeros(n_reference_types, head_dim))
# Listing header that followed this snippet: forward method · python · L53-L105 (53 LOC)models/variable_channel_embedding.py
def forward(self, q, k, channel_coords, reference_type_ids=None):
"""
Args:
q, k: (batch, n_heads, n_channels, head_dim)
channel_coords: (batch, n_channels, 2) — (x, y) per channel
reference_type_ids: (batch,) — reference type per sample
"""
batch_size, n_heads, n_channels, head_dim = q.shape
device = q.device
dtype = q.dtype
if reference_type_ids is None:
reference_type_ids = torch.full((batch_size,), 2, dtype=torch.long, device=device)
x_pos = channel_coords[:, :, 0] # (batch, n_channels)
y_pos = channel_coords[:, :, 1] # (batch, n_channels)
if self.use_reference_conditioning:
freq_scale = self.reference_freq_scale[reference_type_ids] # (batch, 2)
freq_scale_x = freq_scale[:, 0:1] # (batch, 1)
freq_scale_y = freq_scale[:, 1:2] # (batch, 1)
else:
freq_scale_x = torch.ones(batch_size, 1, devicAll rows above produced by Repobility · https://repobility.com
VariableSpatialPatchEmbedding class · python · L108-L218 (111 LOC)models/variable_channel_embedding.py
class VariableSpatialPatchEmbedding(nn.Module):
"""Spatial-aware patch embedding supporting variable channel counts.
Input: (B, N_patches, max_channels, patch_size) with channel_mask (B, max_channels)
Output: (B, N_patches, hidden_dim)
Channels are identified by their (x, y) scalp coordinates via channel_coords.
Attention is computed only over present channels (masked by channel_mask).
"""
def __init__(
self,
max_channels: int = 64,
patch_size: int = 32,
hidden_dim: int = 768,
n_heads: int = 4,
n_reference_types: int = 7,
dropout: float = 0.1,
):
super().__init__()
self.max_channels = max_channels
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // n_heads
self.n_heads = n_heads
self.channel_embed = nn.Linear(patch_size, hidden_dim)
self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
__init__ method · python · L118-L147 (30 LOC)models/variable_channel_embedding.py
def __init__(
self,
max_channels: int = 64,
patch_size: int = 32,
hidden_dim: int = 768,
n_heads: int = 4,
n_reference_types: int = 7,
dropout: float = 0.1,
):
super().__init__()
self.max_channels = max_channels
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // n_heads
self.n_heads = n_heads
self.channel_embed = nn.Linear(patch_size, hidden_dim)
self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.spatial_rope = VariableSpatial2DRoPE(
head_dim=self.head_dim,
n_reference_types=n_reference_types,
use_reference_conditioning=True,
)
self.q_norm = RMSNorm(self.head_dim)
self.k_norm = RMSNorm(self.head_dim)
self.norm = RMSNorm(hidden_dim)
self.dropout_p = dropouforward method · python · L149-L218 (70 LOC)models/variable_channel_embedding.py
def forward(
self,
patches: torch.Tensor,
channel_mask: Optional[torch.Tensor] = None,
channel_coords: Optional[torch.Tensor] = None,
reference_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
patches: (B, N, C, P) — C can be any number up to max_channels
channel_mask: (B, C) — 1 for present, 0 for absent
channel_coords: (B, C, 2) — (x, y) scalp coordinates per channel
reference_type_ids: (B,) — reference type per sample
"""
B, N, C, P = patches.shape
device = patches.device
# Embed each channel's patch
x = self.channel_embed(patches) # (B, N, C, D)
x = x.view(B * N, C, self.hidden_dim)
# QKV projection
qkv = self.qkv(x).reshape(B * N, C, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q = q.transpose(1, 2) # (B*N, heads, C, head_dim)
k = k.transpose(1, 2)
get_channel_coords_tensor function · python · L221-L239 (19 LOC)models/variable_channel_embedding.py
def get_channel_coords_tensor(channel_names: List[str], device='cpu') -> torch.Tensor:
    """Convert channel names to an (N, 2) coordinate tensor.

    Each name is looked up in EEGConfig.CHANNEL_COORDS, first directly and
    then through a small alias table (T7->T3, T8->T4, P7->T5, P8->T6).
    Names that resolve to nothing get (0, 0) coordinates.

    Args:
        channel_names: Channel labels, one per output row.
        device: Device on which the tensor is created.

    Returns:
        float32 tensor of shape (len(channel_names), 2). An empty name list
        yields a (0, 2) tensor rather than a 1-D empty tensor, so callers
        can always index the last dimension.
    """
    # Built once per call instead of once per unknown channel.
    aliases = {'T7': 'T3', 'T8': 'T4', 'P7': 'T5', 'P8': 'T6'}
    known = EEGConfig.CHANNEL_COORDS
    unknown_position = (0.0, 0.0)

    coords = []
    for ch in channel_names:
        # Prefer the exact name; only consult the alias table on a miss.
        key = ch if ch in known else aliases.get(ch, ch)
        coords.append(known[key] if key in known else unknown_position)

    if not coords:
        # torch.tensor([]) would have shape (0,), breaking the (N, 2) contract.
        return torch.empty((0, 2), dtype=torch.float32, device=device)
    return torch.tensor(coords, dtype=torch.float32, device=device)
# Listing header that followed this snippet: compute_band_powers function · python · L23-L33 (11 LOC)server/eeg_features.py
def compute_band_powers(data_1ch: np.ndarray, srate: int = SRATE) -> dict:
    """Return each band's share of the total Welch power for one channel.

    Args:
        data_1ch: Samples of a single channel.
        srate: Sampling rate passed to scipy.signal.welch as ``fs``.

    Returns:
        Dict keyed by the names in BANDS. Every value is 0.0 when the summed
        spectrum is below 1e-12; otherwise it is the band's summed spectrum
        divided by the total.
    """
    segment_len = min(512, len(data_1ch))
    freqs, psd = signal.welch(data_1ch, fs=srate, nperseg=segment_len)
    total_power = np.sum(psd)
    if total_power < 1e-12:
        return dict.fromkeys(BANDS, 0.0)

    def _share(lo, hi):
        # Bands are half-open: lo is included, hi is excluded.
        in_band = (freqs >= lo) & (freqs < hi)
        return float(np.sum(psd[in_band]) / total_power)

    return {name: _share(lo, hi) for name, (lo, hi) in BANDS.items()}
# Listing header that followed this snippet: compute_hjorth function · python · L36-L45 (10 LOC)server/eeg_features.py
def compute_hjorth(data_1ch: np.ndarray) -> dict:
    """Return the Hjorth activity, mobility and complexity of one channel.

    Activity is the variance of the signal, mobility is the square root of
    the ratio between the variance of the first difference and the activity,
    and complexity is the same ratio taken one difference higher, divided by
    the mobility. Denominators are floored at 1e-12.
    """
    floor = 1e-12

    def _variance_or_floor(values):
        # An empty difference has no variance; use the floor instead.
        return np.var(values) if len(values) > 0 else floor

    first_diff = np.diff(data_1ch)
    second_diff = np.diff(first_diff)

    activity = float(np.var(data_1ch))
    var_first = _variance_or_floor(first_diff)
    var_second = _variance_or_floor(second_diff)

    mobility = float(np.sqrt(var_first / max(activity, floor)))
    mobility_of_diff = np.sqrt(var_second / max(var_first, floor))
    complexity = float(mobility_of_diff / max(mobility, floor))

    return {'activity': activity, 'mobility': mobility, 'complexity': complexity}
# Listing header that followed this snippet: compute_spectral_ratios function · python · L48-L60 (13 LOC)server/eeg_features.py
def compute_spectral_ratios(band_powers: dict) -> dict:
    """Derive theta/alpha, theta/beta, delta/alpha and alpha/beta ratios.

    Bands missing from ``band_powers`` count as 0, and every denominator is
    floored at 1e-8 so the division never fails.
    """
    def _power(band):
        return band_powers.get(band, 0)

    def _ratio(numerator, denominator):
        return numerator / max(denominator, 1e-8)

    theta, alpha = _power('theta'), _power('alpha')
    beta, delta = _power('beta'), _power('delta')

    ratios = {}
    ratios['theta_alpha_ratio'] = _ratio(theta, alpha)
    ratios['theta_beta_ratio'] = _ratio(theta, beta)
    ratios['delta_alpha_ratio'] = _ratio(delta, alpha)
    ratios['alpha_beta_ratio'] = _ratio(alpha, beta)
    return ratios
# Listing header that followed this snippet: compute_statistical_features function · python · L63-L72 (10 LOC)server/eeg_features.py
def compute_statistical_features(data_1ch: np.ndarray) -> dict:
    """Summarise one channel with six scalar statistics.

    Returns:
        Dict of plain floats keyed by 'mean', 'std', 'skewness', 'kurtosis',
        'peak_to_peak' and 'zero_crossing_rate', in that order.
    """
    extractors = (
        ('mean', np.mean),
        ('std', np.std),
        ('skewness', _skewness),
        ('kurtosis', _kurtosis),
        ('peak_to_peak', np.ptp),
        ('zero_crossing_rate', _zero_crossing_rate),
    )
    return {label: float(extract(data_1ch)) for label, extract in extractors}
_skewness function · python · L75-L79 (5 LOC)server/eeg_features.py
def _skewness(x):
m = np.mean(x)
s = np.std(x)
if s < 1e-12: return 0.0
return np.mean(((x - m) / s) ** 3)_kurtosis function · python · L82-L86 (5 LOC)server/eeg_features.py
def _kurtosis(x):
m = np.mean(x)
s = np.std(x)
if s < 1e-12: return 0.0
return np.mean(((x - m) / s) ** 4) - 3extract_features function · python · L93-L147 (55 LOC)server/eeg_features.py
def extract_features(eeg_data: np.ndarray, channel_mask: Optional[np.ndarray] = None,
srate: int = SRATE) -> np.ndarray:
"""
Extract traditional EEG features from multi-channel data.
Args:
eeg_data: (n_channels, n_samples)
channel_mask: (n_channels,) — 1 for active, 0 for padding
srate: sample rate
Returns:
features: (n_features,) — concatenated per-channel features averaged across active channels
"""
n_ch = eeg_data.shape[0]
if channel_mask is None:
channel_mask = np.ones(n_ch, dtype=np.float32)
active_channels = np.where(channel_mask > 0.5)[0]
if len(active_channels) == 0:
active_channels = np.arange(min(n_ch, 19))
all_features = []
for ch in active_channels:
x = eeg_data[ch]
if np.std(x) < 1e-8:
continue
bands = compute_band_powers(x, srate)
hjorth = compute_hjorth(x)
ratios = compute_spectral_ratios(bands)
EEGInferenceEngine class · python · L109-L286 (178 LOC)server/inference.py
class EEGInferenceEngine:
"""Wraps the frozen encoder for inference."""
def __init__(self, checkpoint_path: str, device: str = 'cuda'):
self.device = device
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
config = ckpt.get('config', {})
self.max_channels = config.get('max_channels', config.get('n_channels', 19))
self.hidden_dim = config.get('hidden_dim', 512)
self.model = EEGEncoderV3(
n_channels=config.get('n_channels', 19),
hidden_dim=self.hidden_dim,
n_heads=config.get('n_heads', 8),
n_kv_heads=config.get('n_kv_heads', 2),
n_layers=config.get('n_layers', 8),
patch_size=config.get('patch_size', 32),
dropout=0.0,
ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
rope_base=config.get('rope_base', 500_000.0),
n_registers=config.get('n_registers', 8),
use_spatial_rope=__init__ method · python · L112-L145 (34 LOC)server/inference.py
def __init__(self, checkpoint_path: str, device: str = 'cuda'):
self.device = device
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
config = ckpt.get('config', {})
self.max_channels = config.get('max_channels', config.get('n_channels', 19))
self.hidden_dim = config.get('hidden_dim', 512)
self.model = EEGEncoderV3(
n_channels=config.get('n_channels', 19),
hidden_dim=self.hidden_dim,
n_heads=config.get('n_heads', 8),
n_kv_heads=config.get('n_kv_heads', 2),
n_layers=config.get('n_layers', 8),
patch_size=config.get('patch_size', 32),
dropout=0.0,
ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
rope_base=config.get('rope_base', 500_000.0),
n_registers=config.get('n_registers', 8),
use_spatial_rope=config.get('use_spatial_rope', True),
use_modality_tokens=config.encode method · python · L148-L204 (57 LOC)server/inference.py
def encode(self, eeg_data: np.ndarray, channel_mask: np.ndarray,
channel_coords: np.ndarray, metadata: Dict) -> np.ndarray:
"""
Encode full EEG recording → single embedding vector.
Args:
eeg_data: (max_channels, n_samples) float32
channel_mask: (max_channels,) float32
channel_coords: (max_channels, 2) float32
metadata: dict with recording info
Returns:
embedding: (hidden_dim,) float32
"""
device = self.device
eeg = torch.from_numpy(eeg_data).unsqueeze(0).to(device)
mask = torch.from_numpy(channel_mask).unsqueeze(0).to(device)
coords = torch.from_numpy(channel_coords).unsqueeze(0).to(device)
# Zero out inactive channels
eeg = eeg * mask.unsqueeze(-1)
# Trim to patch boundary
ps = 32
T = eeg.shape[2]
eeg = eeg[:, :, :T // ps * ps]
if eeg.shape[2] < ps:
raise ValueErromap_query_to_tasks method · python · L206-L219 (14 LOC)server/inference.py
def map_query_to_tasks(self, query: str) -> List[str]:
    """Return the sorted task names whose keyword occurs in the query.

    Matching is a case-insensitive substring test against the keys of
    QUERY_TASK_MAP. When no keyword matches, the 'general' entry is used.
    """
    needle = query.lower()
    matched = {
        task
        for keyword, task_list in QUERY_TASK_MAP.items()
        if keyword in needle
        for task in task_list
    }
    if matched:
        return sorted(matched)
    # Default to general screening
    return sorted(set(QUERY_TASK_MAP['general']))
# Listing header that followed this snippet: predict method · python · L221-L245 (25 LOC)server/inference.py
def predict(self, embedding: np.ndarray, tasks: List[str]) -> Dict:
"""
Run predictions using the embedding.
For now, returns the embedding with task metadata.
Full SFT prediction will be added in Phase 2.
Returns dict of {task_name: {probability, classification, confidence, model_auc}}.
"""
# Phase 1: Return encoder-based analysis (embedding similarity to training cohorts)
# Phase 2: Will add SFT model for proper Yes/No probability per task
results = {}
for task in tasks:
info = TASK_INFO.get(task, {})
results[task] = {
'probability': None, # Phase 2: SFT prediction
'classification': 'unavailable',
'confidence': 0.0,
'model_auc': info.get('auc', 0.0),
'n_test': info.get('n_test', 0),
'description': info.get('description', task),
'name': info.get('name', task),
All rows scored by the Repobility analyzer (https://repobility.com)
get_supported_conditions method · python · L247-L252 (6 LOC)server/inference.py
def get_supported_conditions(self) -> List[Dict]:
    """List every entry of TASK_INFO, highest 'auc' first.

    Each returned dict starts with a 'task' key holding the task name,
    followed by that task's info fields.
    """
    ranked = sorted(TASK_INFO.items(), key=lambda item: -item[1]['auc'])
    conditions = []
    for task, info in ranked:
        entry = {'task': task}
        entry.update(info)
        conditions.append(entry)
    return conditions
# Listing header that followed this snippet: get_model_card method · python · L254-L286 (33 LOC)server/inference.py
def get_model_card(self) -> str:
"""Return model card text."""
return f"""# NeoCog EEG Encoder — Model Card
## Architecture
- **Type**: EEGEncoderV3 (Transformer with SwiGLU, GQA, Spatial RoPE)
- **Parameters**: {sum(p.numel() for p in self.model.parameters()):,}
- **Hidden dim**: {self.hidden_dim}
- **Max channels**: {self.max_channels}
- **Patch size**: 32 samples (160ms at 200Hz)
- **Context**: Up to 4 minutes per chunk (1500 patches)
## Training
- **Objective**: Self-supervised next-patch prediction + spectral feature prediction
- **Data**: ~4000 EEG recordings from 7 clinical datasets
- **Datasets**: AD_EEG, CAUEEG, DS004504, PEARL, TD_BRAIN, READTBI, DORTMUND
- **Channel range**: 4 (portable) to 64 (HD-EEG) channels
- **Preprocessing**: 200Hz resample, robust z-score normalization
## Validation
- Evaluated on held-out patients (no train/test patient overlap)
- XGBoost + TabPFN classifiers on frozen encoder embeddings
- Best AUC: 0.95 (Vascular Dementia), 0.lifespan function · python · L93-L95 (3 LOC)server/mcp_server.py
async def lifespan(app: FastAPI):
    """Hold the session manager's run() context open around the single yield."""
    running_manager = session_manager.run()
    async with running_manager:
        # Control returns here only when the consumer resumes the generator.
        yield
# Listing header that followed this snippet: favicon function · python · L114-L116 (3 LOC)server/mcp_server.py
async def favicon():
    """Return the file at _ICON_PATH as an image/png response."""
    from fastapi.responses import FileResponse

    icon_response = FileResponse(_ICON_PATH, media_type="image/png")
    return icon_response
# Listing header that followed this snippet: upload_eeg function · python · L191-L201 (11 LOC)server/mcp_server.py
async def upload_eeg(file: UploadFile = File(...)):
    """Store an uploaded EEG file and return its upload ID.

    Args:
        file: The uploaded file; its extension must be in
            SUPPORTED_EXTENSIONS.

    Returns:
        On success, a dict with 'upload_id', 'filename' and 'size_mb'.
        For an unsupported (or missing) extension, a 400 JSONResponse
        whose 'detail' lists the accepted formats.
    """
    # The filename is chosen by the client, so keep only its final path
    # component: directory parts ("../", absolute paths, Windows separators)
    # must never influence where the file is written.
    raw_name = file.filename or ""
    safe_name = Path(raw_name.replace("\\", "/")).name

    ext = Path(safe_name).suffix.lower()
    if ext not in SUPPORTED_EXTENSIONS:
        return JSONResponse(status_code=400,
                            content={"detail": f"Unsupported format '{ext}'. Use: {', '.join(SUPPORTED_EXTENSIONS)}"})

    upload_id = str(uuid.uuid4())[:8]
    dest = UPLOAD_DIR / f"{upload_id}_{safe_name}"
    # NOTE(review): copyfileobj is a blocking call inside an async handler;
    # left unchanged here, but large uploads will stall the event loop.
    with open(dest, "wb") as f_out:
        shutil.copyfileobj(file.file, f_out)
    return {"upload_id": upload_id, "filename": safe_name,
            "size_mb": round(dest.stat().st_size / 1e6, 1)}
# Listing header that followed this snippet: find_upload function · python · L204-L208 (5 LOC)server/mcp_server.py
def find_upload(upload_id: str) -> Optional[Path]:
    """Return the first entry of UPLOAD_DIR named '<upload_id>_...', or None."""
    prefix = upload_id + "_"
    candidates = (entry for entry in UPLOAD_DIR.iterdir() if entry.name.startswith(prefix))
    return next(candidates, None)
# Listing header that followed this snippet: get_or_compute_embedding function · python · L211-L230 (20 LOC)server/mcp_server.py
def get_or_compute_embedding(upload_id: str):
    """Get cached embedding or compute from uploaded file.

    Args:
        upload_id: ID returned by the upload step.

    Returns:
        ``(embedding, None)`` on success, or ``(None, message)`` when the
        upload cannot be found or preprocessing/encoding fails.

    Side effects:
        Fills embedding_cache and metadata_cache for ``upload_id`` and, when
        feature extraction also succeeds, feature_cache.
    """
    import logging

    if upload_id in embedding_cache:
        return embedding_cache[upload_id], None

    filepath = find_upload(upload_id)
    if filepath is None:
        return None, f"Upload ID '{upload_id}' not found. It may have expired. Please re-upload at /upload."

    try:
        eeg_data, channel_mask, channel_coords, metadata = preprocess_eeg(filepath)
        embedding = engine.encode(eeg_data, channel_mask, channel_coords, metadata)
    except Exception as e:
        # Boundary handler: callers expect an error string, not an exception.
        return None, f"Error processing EEG: {str(e)}"

    embedding_cache[upload_id] = embedding
    metadata_cache[upload_id] = metadata

    # Also cache features while we have the raw data. This is opportunistic:
    # a failure here must not turn a successfully computed (and already
    # cached) embedding into an error, otherwise the first call fails while
    # an identical retry succeeds from the cache.
    try:
        feature_cache[upload_id] = extract_features(eeg_data, channel_mask)
    except Exception:
        logging.getLogger(__name__).warning(
            "Feature pre-caching failed for upload %s", upload_id, exc_info=True
        )
    return embedding, None
# Listing header that followed this snippet: get_or_compute_features function · python · L233-L248 (16 LOC)server/mcp_server.py
def get_or_compute_features(upload_id: str):
    """Return ``(features, None)`` from cache or by recomputing, else ``(None, message)``.

    A cache miss triggers preprocessing of the uploaded file followed by
    feature extraction; the result is stored in feature_cache.
    """
    if upload_id in feature_cache:
        return feature_cache[upload_id], None

    source_path = find_upload(upload_id)
    if source_path is None:
        return None, f"Upload ID '{upload_id}' not found."

    try:
        # Coordinates and metadata are not needed for handcrafted features.
        eeg_data, channel_mask, _coords, _metadata = preprocess_eeg(source_path)
        computed = extract_features(eeg_data, channel_mask)
        feature_cache[upload_id] = computed
    except Exception as exc:
        return None, f"Error extracting features: {str(exc)}"
    return computed, None
run_single_screen function · python · L251-L288 (38 LOC)server/mcp_server.py
def run_single_screen(task: str, clf_data: dict, embedding: np.ndarray) -> dict:
"""Run a single classifier on an embedding. Returns structured result."""
model = clf_data['model']
display_name = clf_data['display_name']
if clf_data['type'] == 'classification':
inp = embedding.reshape(1, -1)
# TabPFN needs PCA transform first
if clf_data.get('method') == 'tabpfn' and 'pca' in clf_data:
inp = clf_data['pca'].transform(inp)
prob = float(model.predict_proba(inp)[0, 1])
pred = "POSITIVE" if prob > 0.5 else "NEGATIVE"
auc = clf_data['auc']
n_test = clf_data['n_test']
if auc < 0.70:
confidence = "LOW"
elif abs(prob - 0.5) < 0.1:
confidence = "EQUIVOCAL"
elif auc >= 0.85:
confidence = "HIGH"
else:
confidence = "MODERATE"
return {
'task': task, 'display_name': display_name, 'type': 'classification',
handle_sse function · python · L301-L303 (3 LOC)server/mcp_server.py
async def handle_sse(request: Request):
    """Open an SSE connection for this request and run the MCP server over it."""
    connection = sse.connect_sse(request.scope, request.receive, request._send)
    async with connection as streams:
        # presumably (read, write) streams, in that order — TODO confirm
        first_stream, second_stream = streams[0], streams[1]
        init_options = mcp_server.create_initialization_options()
        await mcp_server.run(first_stream, second_stream, init_options)
# Listing header that followed this snippet: list_tools function · python · L318-L458 (141 LOC)server/mcp_server.py
async def list_tools():
tools = []
# ── Upload via web page ──
tools.append(Tool(
name="get_upload_link",
description=(
"ALWAYS call this tool FIRST when BrainScope is mentioned. "
"Returns the upload link to show the user so they can submit their EEG file. "
"You must call this before doing anything else."
),
inputSchema={"type": "object", "properties": {}},
))
# ── Recording info ──
tools.append(Tool(
name="get_recording_info",
description="Get basic information about an uploaded EEG recording: duration, channels, sample rate, format.",
inputSchema={
"type": "object",
"properties": {
"upload_id": {"type": "string", "description": "Upload ID from the upload step"},
},
"required": ["upload_id"],
},
))
# ── Run all screens at once ──
tools.append(Tool(
name="run_all_screens",
call_tool function · python · L462-L493 (32 LOC)server/mcp_server.py
async def call_tool(name: str, arguments: dict):
if name == "get_upload_link":
url = "https://b4b9ezp8vloj5o-8080.proxy.runpod.net/upload"
return [TextContent(type="text", text=(
f"Show this link to the user: {url}\n"
f"Format as: [Upload your EEG recording here]({url})\n"
f"Tell them to upload their EEG file (.edf/.bdf/.h5/.set) and paste "
f"the upload_id back. If they don't have an EEG, suggest asking their "
f"doctor to order one."
))]
if name == "get_recording_info":
return await handle_recording_info(arguments)
if name == "run_all_screens":
return await handle_run_all_screens(arguments)
if name == "run_baseline_screens":
return await handle_run_baseline_screens(arguments)
if name == "search_literature":
return await handle_search_literature(arguments)
if name == "get_model_card":
return [TextContent(type="text", text=engine.get_mhandle_recording_info function · python · L498-L531 (34 LOC)server/mcp_server.py
async def handle_recording_info(args: dict):
upload_id = args.get("upload_id")
if not upload_id:
return [TextContent(type="text", text="Please provide an upload_id.")]
filepath = find_upload(upload_id)
if filepath is None:
return [TextContent(type="text", text=f"Upload ID '{upload_id}' not found.")]
try:
eeg_data, mask, coords, metadata = preprocess_eeg(filepath)
n_active = int(mask.sum())
text = (
f"**Recording: {metadata['filename']}**\n"
f"- Duration: {metadata['duration_s']:.0f}s ({metadata['duration_s']/60:.1f} min)\n"
f"- Active EEG channels: {n_active}\n"
f"- Sample rate: {metadata['srate']} Hz\n"
f"- Format: {metadata['source_format']}\n"
f"- Total samples: {metadata['n_samples']:,}\n"
f"\nRecording quality notes:\n"
)
if metadata['duration_s'] < 60:
text += "- WARNING: Recording is very short (<1 min). Reshandle_single_screen function · python · L534-L545 (12 LOC)server/mcp_server.py
async def handle_single_screen(task: str, clf_data: dict, args: dict):
    """Run one classifier for the upload named in ``args`` and format the result.

    Returns a single-element list of TextContent holding either the formatted
    screening result or an explanatory message (missing upload_id, or the
    error reported while obtaining the embedding).
    """
    def _reply(message):
        return [TextContent(type="text", text=message)]

    upload_id = args.get("upload_id")
    if not upload_id:
        return _reply("Please provide an upload_id.")

    embedding, error = get_or_compute_embedding(upload_id)
    if error:
        return _reply(error)

    screening = run_single_screen(task, clf_data, embedding)
    return _reply(_format_single_result(screening))
# Listing header that followed this snippet: handle_run_all_screens function · python · L548-L580 (33 LOC)server/mcp_server.py
async def handle_run_all_screens(args: dict):
upload_id = args.get("upload_id")
if not upload_id:
return [TextContent(type="text", text="Please provide an upload_id.")]
if upload_id in screen_cache:
return screen_cache[upload_id]
embedding, error = get_or_compute_embedding(upload_id)
if error:
return [TextContent(type="text", text=error)]
# Run all classifiers in parallel
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as pool:
futures = [
loop.run_in_executor(pool, run_single_screen, task, clf_data, embedding)
for task, clf_data in classifiers.items()
]
results = await asyncio.gather(*futures)
# Generate OpenEvidence-style report with inline citations + visualization JSON
from server.response_generator import generate_enriched_report
narrative, viz_json = generate_enriched_report(
list(results), metadata_cache.get(upload_id)
)handle_run_baseline_screens function · python · L583-L650 (68 LOC)server/mcp_server.py
async def handle_run_baseline_screens(args: dict):
"""Run traditional EEG feature classifiers as baseline comparison."""
upload_id = args.get("upload_id")
if not upload_id:
return [TextContent(type="text", text="Please provide an upload_id.")]
features, error = get_or_compute_features(upload_id)
if error:
return [TextContent(type="text", text=error)]
if not feature_classifiers:
return [TextContent(type="text", text="No feature baseline classifiers loaded.")]
lines = ["# Traditional EEG Feature Analysis (Baseline)\n"]
lines.append("*Using handcrafted features: band powers (delta/theta/alpha/beta/gamma), "
"Hjorth parameters (activity/mobility/complexity), spectral ratios, "
"and statistical measures.*\n")
def _run_feature_clf(task, clf_data):
model = clf_data['model']
inp = features.reshape(1, -1)
if clf_data['type'] == 'classification':
prob = float(modeAll rows above produced by Repobility · https://repobility.com
handle_search_literature function · python · L653-L693 (41 LOC)server/mcp_server.py
async def handle_search_literature(args: dict):
from server.response_generator import CITATIONS
conditions = args.get("conditions", [])
if not conditions:
return [TextContent(type="text", text="Please provide a list of conditions to search for.")]
lines = ["# Relevant Literature\n"]
ref_num = 1
seen = set()
for condition in conditions:
cond_lower = condition.lower().strip()
matched_keys = []
for key in CITATIONS:
if key in cond_lower or cond_lower in key:
matched_keys.append(key)
if not matched_keys:
for key in CITATIONS:
if any(word in cond_lower for word in key.split('_')):
matched_keys.append(key)
for key in matched_keys:
for ref in CITATIONS[key]:
if ref['citation'] not in seen:
seen.add(ref['citation'])
tags_str = " | ".join(f"[{t}]" for t in ref['tags'])
_format_single_result function · python · L696-L714 (19 LOC)server/mcp_server.py
def _format_single_result(result: dict) -> str:
if result['type'] == 'classification':
return (
f"**{result['display_name']}**\n\n"
f"- Prediction: **{result['prediction']}**\n"
f"- Probability: {result['probability']:.3f}\n"
f"- Confidence: {result['confidence']}\n"
f"- Model AUC: {result['auc']:.2f} (validated on {result['n_test']} held-out patients)\n\n"
f"*Screening result — not a clinical diagnosis.*"
)
else:
lo, hi = result['valid_range']
return (
f"**{result['display_name']}**\n\n"
f"- Predicted value: **{result['predicted_value']:.1f}**\n"
f"- Valid range: [{lo:.0f}, {hi:.0f}]\n"
f"- Model correlation: r={result['pearson_r']:.2f} (validated on {result['n_test']} patients)\n\n"
f"*Screening result — not a clinical diagnosis.*"
)api_list_tools function · python · L720-L723 (4 LOC)server/mcp_server.py
async def api_list_tools():
    """REST endpoint: list all tools (for stdio proxy)."""
    summaries = []
    for tool in await list_tools():
        summaries.append({
            "name": tool.name,
            "description": tool.description,
            "inputSchema": tool.inputSchema,
        })
    return summaries