Function bodies 170 total
main function · python · L966-L1237 (272 LOC)eval/evaluate_v3_linear_probe.py
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str,
default='checkpoints/pretrain_v3_20260208_173529/best_model.pt')
parser.add_argument('--max_length', type=int, default=24000)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--tasks', type=str, nargs='+', default=None,
help='Specific tasks to evaluate (default: all)')
parser.add_argument('--classification_only', action='store_true')
parser.add_argument('--regression_only', action='store_true')
parser.add_argument('--include_random', action='store_true',
help='Also evaluate random (untrained) v3 baseline')
parser.add_argument('--condition_cache', type=str,
default='data/condition_embeddings.pt',
help='Path to condition embeddings cache')
parser.adlist_tools function · python · L33-L41 (9 LOC)mcp_stdio_proxy.py
async def list_tools():
    """Fetch the tool catalogue from the backing HTTP server as MCP ``Tool`` objects.

    Issues ``GET {SERVER_URL}/api/tools`` and converts every returned entry.
    ``name`` is still required (a nameless tool is unusable). A missing
    ``description`` becomes ``None``, and a missing or null ``inputSchema`` falls
    back to an empty object schema. This way one sparse entry no longer aborts
    the whole listing with a ``KeyError``.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.get(f"{SERVER_URL}/api/tools")
        resp.raise_for_status()
        tools_data = resp.json()
    tools = []
    for entry in tools_data:
        schema = entry.get("inputSchema")
        if schema is None:
            schema = {"type": "object", "properties": {}}
        tools.append(
            Tool(name=entry["name"], description=entry.get("description"), inputSchema=schema)
        )
    return tools
# call_tool function · python · L45-L53 (9 LOC)mcp_stdio_proxy.py
async def call_tool(name: str, arguments: dict):
    """Forward an MCP tool call to the backing HTTP server.

    POSTs ``{"name": name, "arguments": arguments}`` to ``{SERVER_URL}/api/call_tool``
    and wraps the decoded reply in a single ``TextContent``.

    Returns:
        A one-element list holding a ``TextContent`` whose text is chosen by
        ``_result_text``.

    Raises:
        httpx.HTTPStatusError: if the server answers with an error status.
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(
            f"{SERVER_URL}/api/call_tool",
            json={"name": name, "arguments": arguments},
        )
        resp.raise_for_status()
        data = resp.json()
    return [TextContent(type="text", text=_result_text(data))]


def _result_text(data) -> str:
    """Return the text to display for a decoded ``/api/call_tool`` reply.

    Uses the reply's ``"text"`` field when the reply is a dict and that field is
    a string. Otherwise (JSON list/string/number, missing or non-string ``"text"``),
    it falls back to ``str(data)`` instead of raising ``AttributeError`` or
    handing a non-string to ``TextContent``.
    """
    text = data.get("text") if isinstance(data, dict) else None
    return text if isinstance(text, str) else str(data)
# main function · python · L56-L58 (3 LOC)mcp_stdio_proxy.py
async def main():
    """Serve the MCP ``server`` over this process's stdin/stdout streams."""
    async with stdio_server() as streams:
        reader, writer = streams
        init_options = server.create_initialization_options()
        await server.run(reader, writer, init_options)
# RMSNorm class · python · L7-L16 (10 LOC)models/common.py
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Computes ``y = x / sqrt(mean(x**2, dim=-1) + eps) * weight``.

    The mean-square statistic is computed in at least float32. Squaring a
    float16 input overflows for |x| > ~256 (fp16 max is 65504), which made
    ``rsqrt`` return 0 and silently zeroed the whole row. Results for float32
    inputs are unchanged, and float64 inputs keep float64 precision.

    Args:
        dim: size of the last (normalized) dimension.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xc = x.to(compute_dtype)
        normed = xc * torch.rsqrt(xc.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back before applying the gain so dtype promotion against
        # ``weight`` yields the same output dtype as before.
        return normed.to(x.dtype) * self.weight
# __init__ method · python · L10-L13 (4 LOC)models/common.py
def __init__(self, dim: int, eps: float = 1e-6):
    """Create the learnable gain (initialised to ones) and store the epsilon.

    Args:
        dim: size of the normalized (last) dimension.
        eps: small constant added before the reciprocal square root.
    """
    super().__init__()
    self.eps = eps
    gain = torch.ones(dim)
    self.weight = nn.Parameter(gain)
# RotaryPositionalEmbedding class · python · L36-L59 (24 LOC)models/eeg_encoder_v3.py
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int = 4096, base: float = 500_000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, device=self.inv_freq.device)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
self.register_buffer('sin_cached', emb.sin()[None, None, :, :])
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
seq_len = x.shape[2]
if seq_len > self.cos_cached.shape[2]:
self._precompute_cache(seq_len)
return (
self.cos_cached[:, :, :seq_len, :],
Want fix-PRs on findings? Install Repobility's GitHub App · github.com/apps/repobility-bot
__init__ method · python · L37-L43 (7 LOC)models/eeg_encoder_v3.py
def __init__(self, dim: int, max_seq_len: int = 4096, base: float = 500_000.0):
    """Build the RoPE inverse frequencies and the initial cos/sin cache.

    Args:
        dim: rotary dimension; one frequency is produced per pair of features.
        max_seq_len: number of positions precomputed up front.
        base: frequency base (theta) of the rotary embedding.
    """
    super().__init__()
    self.dim = dim
    self.max_seq_len = max_seq_len
    exponents = torch.arange(0, dim, 2).float() / dim
    self.register_buffer('inv_freq', 1.0 / (base ** exponents))
    self._precompute_cache(max_seq_len)
# _precompute_cache method · python · L45-L50 (6 LOC)models/eeg_encoder_v3.py
def _precompute_cache(self, seq_len: int):
    """(Re)build the ``cos_cached`` / ``sin_cached`` tables for ``seq_len`` positions.

    Each table has shape (1, 1, seq_len, 2 * len(inv_freq)) so it broadcasts
    over the leading batch/head dims.

    NOTE(review): the buffers are registered as persistent (the default), so they
    are saved in ``state_dict``. Growing the cache at runtime changes their saved
    shape. Confirm that is intended before making them non-persistent, because
    existing checkpoints contain these keys.
    """
    positions = torch.arange(seq_len, device=self.inv_freq.device)
    angles = torch.outer(positions, self.inv_freq)
    angles = torch.cat((angles, angles), dim=-1)  # one copy per rotated half
    for name, table in (('cos_cached', angles.cos()), ('sin_cached', angles.sin())):
        self.register_buffer(name, table[None, None, :, :])
# forward method · python · L52-L59 (8 LOC)models/eeg_encoder_v3.py
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(cos, sin)`` tables covering the first ``x.shape[2]`` positions.

    Only the length of ``x`` along dim 2 is used. The cache is rebuilt first if
    it is shorter than that. The returned slices keep the cache's dtype; they
    are not cast to ``x.dtype``.
    """
    n_pos = x.shape[2]
    if self.cos_cached.shape[2] < n_pos:
        self._precompute_cache(n_pos)
    window = slice(0, n_pos)
    return (
        self.cos_cached[:, :, window, :],
        self.sin_cached[:, :, window, :],
    )
# _rotate_half function · python · L62-L64 (3 LOC)models/eeg_encoder_v3.py
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the last-dim halves ``(a, b)`` to ``(-b, a)``, the RoPE pairing."""
    lower, upper = torch.chunk(x, 2, dim=-1)
    return torch.cat((upper.neg(), lower), dim=-1)
# _apply_rope function · python · L67-L69 (3 LOC)models/eeg_encoder_v3.py
def _apply_rope(q, k, cos, sin):
    """Rotate queries and keys with precomputed RoPE tables.

    Each tensor ``t`` becomes ``t * cos + _rotate_half(t) * sin``. ``cos`` and
    ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    def rotate(t):
        return t * cos + _rotate_half(t) * sin

    return rotate(q), rotate(k)
# GQAttention class · python · L76-L144 (69 LOC)models/eeg_encoder_v3.py
class GQAttention(nn.Module):
"""Pre-norm Grouped Query Attention with QK-Norm and RoPE."""
def __init__(
self,
dim: int,
n_heads: int = 24,
n_kv_heads: int = 6,
dropout: float = 0.1,
use_qk_norm: bool = True,
use_rope: bool = True,
rope_base: float = 500_000.0,
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
self.n_rep = n_heads // n_kv_heads # repetitions per KV head
self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
self.use_qk_norm = use_qk_norm
if use_qk_norm:
self.q_norm = RMSNorm(self.head_dim)
self.k_norm = RMSNor__init__ method · python · L79-L110 (32 LOC)models/eeg_encoder_v3.py
def __init__(
self,
dim: int,
n_heads: int = 24,
n_kv_heads: int = 6,
dropout: float = 0.1,
use_qk_norm: bool = True,
use_rope: bool = True,
rope_base: float = 500_000.0,
):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
self.n_rep = n_heads // n_kv_heads # repetitions per KV head
self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
self.use_qk_norm = use_qk_norm
if use_qk_norm:
self.q_norm = RMSNorm(self.head_dim)
self.k_norm = RMSNorm(self.head_dim)
self.use_rope = use_rope
if use_rope:
self.rope = Rforward method · python · L112-L144 (33 LOC)models/eeg_encoder_v3.py
def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
B, T, _ = x.shape
q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
if self.use_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
if self.use_rope:
cos, sin = self.rope(q)
# RoPE only applied to Q heads and corresponding KV heads
# Expand cos/sin for Q heads
q = q * cos + _rotate_half(q) * sin
# For KV heads, take every n_rep-th slice of cos/sin
cos_kv = cos[:, :self.n_kv_heads, :, :]
sin_kv = sin[:, :self.n_kv_heads, :, :]
k = k * cos_kv + _rotate_half(k) * sin_kv
# Repeat KV heads to match Q heads
if self.n_rep > 1:
k = k.repeat_interIf a scraper extracted this row, it came from Repobility (https://repobility.com)
SwiGLUFFN class · python · L151-L163 (13 LOC)models/eeg_encoder_v3.py
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    Args:
        dim: input and output feature size.
        multiplier: hidden width is ``int(dim * multiplier)`` before rounding.
        dropout: dropout probability applied to the block output.
    """

    def __init__(self, dim: int, multiplier: float = 8.0 / 3.0, dropout: float = 0.1):
        super().__init__()
        hidden = int(dim * multiplier)
        # Round UP (ceil, not nearest) to the next multiple of 128 for GPU-friendly
        # matmul shapes, e.g. dim=64 -> int(170.67)=170 -> 256.
        hidden = ((hidden + 127) // 128) * 128
        self.w1 = nn.Linear(dim, hidden, bias=False)  # gate branch (through SiLU)
        self.w2 = nn.Linear(hidden, dim, bias=False)  # down projection
        self.w3 = nn.Linear(dim, hidden, bias=False)  # linear branch
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
# __init__ method · python · L152-L160 (9 LOC)models/eeg_encoder_v3.py
def __init__(self, dim: int, multiplier: float = 8.0 / 3.0, dropout: float = 0.1):
    """Create the SwiGLU projections.

    The hidden width is ``int(dim * multiplier)`` rounded UP to a multiple of 128.

    Args:
        dim: input and output feature size.
        multiplier: hidden-width multiplier before rounding.
        dropout: output dropout probability.
    """
    super().__init__()
    raw_hidden = int(dim * multiplier)
    hidden = -(-raw_hidden // 128) * 128  # ceil to the next multiple of 128
    self.w1 = nn.Linear(dim, hidden, bias=False)
    self.w2 = nn.Linear(hidden, dim, bias=False)
    self.w3 = nn.Linear(dim, hidden, bias=False)
    self.dropout = nn.Dropout(dropout)
# TransformerBlockV3 class · python · L170-L193 (24 LOC)models/eeg_encoder_v3.py
class TransformerBlockV3(nn.Module):
    """Pre-norm transformer layer: an attention residual, then a SwiGLU FFN residual.

    Args:
        dim: model width.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads (grouped-query attention).
        dropout: dropout used by attention and FFN.
        ffn_multiplier: SwiGLU hidden-width multiplier.
        use_qk_norm: apply RMSNorm to per-head queries/keys.
        use_rope: apply rotary position embeddings.
        rope_base: RoPE frequency base.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.1,
        ffn_multiplier: float = 8.0 / 3.0,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        rope_base: float = 500_000.0,
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = GQAttention(
            dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout=dropout,
            use_qk_norm=use_qk_norm, use_rope=use_rope, rope_base=rope_base,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, multiplier=ffn_multiplier, dropout=dropout)

    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x), is_causal=is_causal)
        x = x + attn_out
        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out
# __init__ method · python · L171-L188 (18 LOC)models/eeg_encoder_v3.py
def __init__(
    self,
    dim: int,
    n_heads: int,
    n_kv_heads: int,
    dropout: float = 0.1,
    ffn_multiplier: float = 8.0 / 3.0,
    use_qk_norm: bool = True,
    use_rope: bool = True,
    rope_base: float = 500_000.0,
):
    """Create the two pre-norms, the GQA attention and the SwiGLU FFN (in that order)."""
    super().__init__()
    self.norm1 = RMSNorm(dim)
    self.attn = GQAttention(
        dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, dropout=dropout,
        use_qk_norm=use_qk_norm, use_rope=use_rope, rope_base=rope_base,
    )
    self.norm2 = RMSNorm(dim)
    self.ffn = SwiGLUFFN(dim, multiplier=ffn_multiplier, dropout=dropout)
# forward method · python · L190-L193 (4 LOC)models/eeg_encoder_v3.py
def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
    """Apply ``x + attn(norm1(x))`` and then ``x + ffn(norm2(x))``.

    ``is_causal`` is passed straight to the attention sub-layer.
    """
    attn_out = self.attn(self.norm1(x), is_causal=is_causal)
    x = x + attn_out
    ffn_out = self.ffn(self.norm2(x))
    return x + ffn_out
# SpatialPatchEmbedding class · python · L200-L293 (94 LOC)models/eeg_encoder_v3.py
class SpatialPatchEmbedding(nn.Module):
"""Spatial-aware patch embedding using 2D RoPE — pinned from v2."""
def __init__(
self,
n_channels: int = 19,
patch_size: int = 32,
hidden_dim: int = 768,
n_heads: int = 4,
n_reference_types: int = 7,
dropout: float = 0.1,
):
super().__init__()
self.n_channels = n_channels
self.patch_size = patch_size
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // n_heads
self.channel_embed = nn.Linear(patch_size, hidden_dim)
self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.n_heads = n_heads
self.spatial_rope = Spatial2DRoPE(
head_dim=self.head_dim,
n_reference_types=n_reference_types,
use_reference_conditioning=True,
)
self.q_norm = RMSNorm(self.head_dim)
self.k_n__init__ method · python · L203-L234 (32 LOC)models/eeg_encoder_v3.py
def __init__(
    self,
    n_channels: int = 19,
    patch_size: int = 32,
    hidden_dim: int = 768,
    n_heads: int = 4,
    n_reference_types: int = 7,
    dropout: float = 0.1,
):
    """Create the per-channel patch embedding and the cross-channel attention parts.

    Args:
        n_channels: number of EEG channels.
        patch_size: samples per patch; input width of ``channel_embed``.
        hidden_dim: embedding width.
        n_heads: attention heads; ``head_dim = hidden_dim // n_heads``.
        n_reference_types: number of reference schemes known to ``Spatial2DRoPE``.
        dropout: stored as ``dropout_p`` (no ``nn.Dropout`` module is created here).

    NOTE(review): nothing checks ``hidden_dim % n_heads == 0``. ``Spatial2DRoPE``
    additionally requires ``head_dim`` to be divisible by 4.
    """
    super().__init__()
    self.n_channels = n_channels
    self.patch_size = patch_size
    self.hidden_dim = hidden_dim
    head_dim = hidden_dim // n_heads
    self.head_dim = head_dim
    # Lift each channel's raw patch samples to the model width.
    self.channel_embed = nn.Linear(patch_size, hidden_dim)
    # Fused Q/K/V projection plus output projection for the attention block.
    self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
    self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.n_heads = n_heads
    self.spatial_rope = Spatial2DRoPE(
        head_dim=head_dim,
        n_reference_types=n_reference_types,
        use_reference_conditioning=True,
    )
    self.q_norm, self.k_norm = RMSNorm(head_dim), RMSNorm(head_dim)
    self.norm = RMSNorm(hidden_dim)
    self.dropout_p = dropout
# forward method · python · L236-L293 (58 LOC)models/eeg_encoder_v3.py
def forward(
self,
patches: torch.Tensor,
channel_mask: Optional[torch.Tensor] = None,
reference_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, N, C, P = patches.shape
device = patches.device
x = self.channel_embed(patches)
x = x.view(B * N, C, self.hidden_dim)
qkv = self.qkv(x).reshape(B * N, C, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(2)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = self.q_norm(q)
k = self.k_norm(k)
if reference_type_ids is not None:
ref_ids_expanded = reference_type_ids.repeat_interleave(N)
else:
ref_ids_expanded = None
q, k = self.spatial_rope(q, k, ref_ids_expanded)
if channel_mask is not None:
mask_expanded = channel_mask.repeat_interleave(N, dim=0)
attn_mask_bool = (mask_expanded < 0.5).unsqueeze(1).unRepobility · code-quality intelligence · https://repobility.com
PredictionHead class · python · L300-L317 (18 LOC)models/eeg_encoder_v3.py
class PredictionHead(nn.Module):
    """Residual SwiGLU block followed by a linear projection to ``patch_dim``.

    Computes ``proj(out_norm(x + down(silu(up(norm(x))) * gate(norm(x)))))``.

    NOTE(review): SiLU is applied to ``up`` and ``gate`` is the linear factor,
    which reverses the usual naming. The attribute names are kept unchanged
    because checkpoints store them.

    Args:
        hidden_dim: width of the incoming hidden states.
        patch_dim: output width of the final projection.
    """

    def __init__(self, hidden_dim: int, patch_dim: int):
        super().__init__()
        inner = hidden_dim * 2
        self.norm = RMSNorm(hidden_dim)
        self.up = nn.Linear(hidden_dim, inner, bias=False)
        self.gate = nn.Linear(hidden_dim, inner, bias=False)
        self.down = nn.Linear(inner, hidden_dim, bias=False)
        self.out_norm = RMSNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, patch_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.norm(x)
        residual = self.down(F.silu(self.up(z)) * self.gate(z))
        return self.proj(self.out_norm(x + residual))
# __init__ method · python · L303-L310 (8 LOC)models/eeg_encoder_v3.py
def __init__(self, hidden_dim: int, patch_dim: int):
    """Create the norm, the 2x-wide SwiGLU projections, the output norm and the final projection."""
    super().__init__()
    inner = hidden_dim * 2
    self.norm = RMSNorm(hidden_dim)
    self.up = nn.Linear(hidden_dim, inner, bias=False)
    self.gate = nn.Linear(hidden_dim, inner, bias=False)
    self.down = nn.Linear(inner, hidden_dim, bias=False)
    self.out_norm = RMSNorm(hidden_dim)
    self.proj = nn.Linear(hidden_dim, patch_dim)
# forward method · python · L312-L317 (6 LOC)models/eeg_encoder_v3.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Residual SwiGLU update of ``x``, then output norm and projection."""
    z = self.norm(x)
    residual = self.down(F.silu(self.up(z)) * self.gate(z))
    return self.proj(self.out_norm(x + residual))
# FeaturePredictionHead class · python · L324-L349 (26 LOC)models/eeg_encoder_v3.py
class FeaturePredictionHead(nn.Module):
"""SwiGLU residual block + projection for EEG feature prediction.
Predicts per-channel features: 5 band powers + 3 Hjorth + 1 variance = 9
Output shape: (B, N, C*9) = (B, N, 171)
Matches PredictionHead architecture for sufficient capacity.
"""
def __init__(self, hidden_dim: int, n_channels: int, n_features: int = 9):
super().__init__()
self.n_channels = n_channels
self.n_features = n_features
out_dim = n_channels * n_features
self.norm = RMSNorm(hidden_dim)
self.up = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
self.gate = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
self.down = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
self.out_norm = RMSNorm(hidden_dim)
self.proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.norm(x)
h = F.silu(self.up(h)) * self.gate(h)
__init__ method · python · L332-L342 (11 LOC)models/eeg_encoder_v3.py
def __init__(self, hidden_dim: int, n_channels: int, n_features: int = 9):
    """Create a residual SwiGLU head projecting to ``n_channels * n_features`` outputs.

    Args:
        hidden_dim: width of the incoming hidden states.
        n_channels: number of EEG channels predicted.
        n_features: features per channel (default 9).
    """
    super().__init__()
    self.n_channels, self.n_features = n_channels, n_features
    widened = hidden_dim * 2
    self.norm = RMSNorm(hidden_dim)
    self.up = nn.Linear(hidden_dim, widened, bias=False)
    self.gate = nn.Linear(hidden_dim, widened, bias=False)
    self.down = nn.Linear(widened, hidden_dim, bias=False)
    self.out_norm = RMSNorm(hidden_dim)
    self.proj = nn.Linear(hidden_dim, n_channels * n_features)
# forward method · python · L344-L349 (6 LOC)models/eeg_encoder_v3.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Residual SwiGLU update, then project to ``n_channels * n_features`` values per position."""
    normed = self.norm(x)
    activated = F.silu(self.up(normed))
    mixed = activated * self.gate(normed)
    x = x + self.down(mixed)
    return self.proj(self.out_norm(x))
# MTPModule class · python · L356-L388 (33 LOC)models/eeg_encoder_v3.py
class MTPModule(nn.Module):
"""One depth of multi-step prediction.
Takes the previous depth's hidden states and the ground-truth next-patch
embedding, concatenates (after RMSNorm), projects back to hidden_dim,
runs through one transformer layer, then reuses the shared prediction heads.
"""
def __init__(self, hidden_dim: int, n_heads: int, n_kv_heads: int,
dropout: float, ffn_multiplier: float, rope_base: float):
super().__init__()
self.norm_hidden = RMSNorm(hidden_dim)
self.norm_embed = RMSNorm(hidden_dim)
self.proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
self.block = TransformerBlockV3(
dim=hidden_dim, n_heads=n_heads, n_kv_heads=n_kv_heads,
dropout=dropout, ffn_multiplier=ffn_multiplier,
use_qk_norm=True, use_rope=True, rope_base=rope_base,
)
def forward(self, hidden: torch.Tensor, patch_embeds: torch.Tensor,
is_causal: bool =__init__ method · python · L364-L374 (11 LOC)models/eeg_encoder_v3.py
def __init__(self, hidden_dim: int, n_heads: int, n_kv_heads: int,
             dropout: float, ffn_multiplier: float, rope_base: float):
    """Create the two input norms, the 2D->D fusion projection and one transformer block.

    The block always uses QK-norm and RoPE.
    """
    super().__init__()
    self.norm_hidden = RMSNorm(hidden_dim)
    self.norm_embed = RMSNorm(hidden_dim)
    # Fuse [hidden ; patch embedding] back down to hidden_dim.
    self.proj = nn.Linear(2 * hidden_dim, hidden_dim, bias=False)
    block_kwargs = dict(
        dim=hidden_dim, n_heads=n_heads, n_kv_heads=n_kv_heads,
        dropout=dropout, ffn_multiplier=ffn_multiplier,
        use_qk_norm=True, use_rope=True, rope_base=rope_base,
    )
    self.block = TransformerBlockV3(**block_kwargs)
forward method · python · L376-L388 (13 LOC)models/eeg_encoder_v3.py
def forward(self, hidden: torch.Tensor, patch_embeds: torch.Tensor,
            is_causal: bool = True) -> torch.Tensor:
    """Fuse previous-depth hidden states with shifted patch embeddings and run one block.

    Args:
        hidden: (B, N, D) hidden states from the previous depth.
        patch_embeds: (B, N, D) embedded ground-truth patches at offset k.
        is_causal: forwarded to the transformer block.

    Returns:
        (B, N, D) hidden states for this depth.
    """
    stacked = torch.cat((self.norm_hidden(hidden), self.norm_embed(patch_embeds)), dim=-1)
    return self.block(self.proj(stacked), is_causal=is_causal)
# EEGEncoderV3 class · python · L395-L725 (331 LOC)models/eeg_encoder_v3.py
class EEGEncoderV3(nn.Module):
"""
EEG Encoder v3 for next-patch prediction with feature losses.
Scaled transformer with GQA, SwiGLU 8/3, RoPE base=500k, scaled init.
Optional register tokens for working memory.
"""
def __init__(
self,
n_channels: int = 19,
hidden_dim: int = 1536,
n_heads: int = 24,
n_kv_heads: int = 6,
n_layers: int = 24,
dropout: float = 0.1,
patch_size: int = 32,
ffn_multiplier: float = 8.0 / 3.0,
rope_base: float = 500_000.0,
use_qk_norm: bool = True,
use_rope: bool = True,
use_condition_encoding: bool = True,
condition_cache_path: str = 'data/condition_embeddings.pt',
max_channels: int = None,
use_spatial_rope: bool = True,
n_registers: int = 0,
use_modality_tokens: bool = False,
mtp_depths: int = 0,
):
super().__init__()
self.n_channels = n_channels
self.max_ch__init__ method · python · L403-L533 (131 LOC)models/eeg_encoder_v3.py
def __init__(
self,
n_channels: int = 19,
hidden_dim: int = 1536,
n_heads: int = 24,
n_kv_heads: int = 6,
n_layers: int = 24,
dropout: float = 0.1,
patch_size: int = 32,
ffn_multiplier: float = 8.0 / 3.0,
rope_base: float = 500_000.0,
use_qk_norm: bool = True,
use_rope: bool = True,
use_condition_encoding: bool = True,
condition_cache_path: str = 'data/condition_embeddings.pt',
max_channels: int = None,
use_spatial_rope: bool = True,
n_registers: int = 0,
use_modality_tokens: bool = False,
mtp_depths: int = 0,
):
super().__init__()
self.n_channels = n_channels
self.max_channels = max_channels or n_channels
self.hidden_dim = hidden_dim
self.patch_size = patch_size
self.n_layers = n_layers
self.mtp_depths = mtp_depths
self.n_heads = n_heads
self.n_kv_heads = n_init_weights method · python · L535-L548 (14 LOC)models/eeg_encoder_v3.py
def _init_weights(self):
    """Initialise every ``nn.Linear`` and down-scale the residual output projections.

    All linear weights get ``trunc_normal_(std=0.02)`` and their biases are
    zeroed. Each block's ``attn.out_proj`` and ``ffn.w2`` are then re-drawn with
    ``std = 0.02 / sqrt(2 * n_layers)``.

    NOTE(review): ``trunc_normal_`` bounds default to absolute values a=-2, b=2,
    so with std 0.02 the truncation is effectively a no-op (plain normal init).
    NOTE(review): only ``self.blocks`` receive the residual scaling. Confirm
    whether any extra transformer blocks (e.g. MTP depths) should as well.
    """
    base_std = 0.02
    residual_std = base_std / math.sqrt(2 * self.n_layers)
    linears = [m for m in self.modules() if isinstance(m, nn.Linear)]
    for layer in linears:
        nn.init.trunc_normal_(layer.weight, std=base_std)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
    for blk in self.blocks:
        for proj in (blk.attn.out_proj, blk.ffn.w2):
            nn.init.trunc_normal_(proj.weight, std=residual_std)
# patchify method · python · L550-L557 (8 LOC)models/eeg_encoder_v3.py
def patchify(self, eeg: torch.Tensor, flatten: bool = True) -> torch.Tensor:
    """Cut (B, C, T) EEG into non-overlapping patches of ``self.patch_size`` samples.

    Trailing samples that do not fill a whole patch are dropped.

    Returns:
        (B, N, C * patch_size) if ``flatten`` else (B, N, C, patch_size),
        where N = T // patch_size.
    """
    batch, channels, samples = eeg.shape
    size = self.patch_size
    count = samples // size
    windows = eeg[:, :, :count * size].view(batch, channels, count, size)
    patches = windows.permute(0, 2, 1, 3)
    return patches.reshape(batch, count, -1) if flatten else patches
# encode_condition method · python · L559-L581 (23 LOC)models/eeg_encoder_v3.py
def encode_condition(self, condition_keys, device):
if not self.use_condition_encoding or self.condition_cache is None:
return None
# Move cache to device once (avoids CPU→GPU copy every forward pass)
if self._condition_cache_device != device:
self.condition_cache = {k: v.to(device) for k, v in self.condition_cache.items()}
self._condition_cache_device = device
embeddings = []
for dataset, recording_state in condition_keys:
key_str = f"{dataset}|{recording_state}"
if key_str in self.condition_cache:
embed = self.condition_cache[key_str]
else:
key_str = f"{dataset}|default"
if key_str in self.condition_cache:
embed = self.condition_cache[key_str]
else:
import warnings
warnings.warn(f"No condition embedding for {dataset}|{recording_state}, using arbitrencode method · python · L583-L652 (70 LOC)models/eeg_encoder_v3.py
def encode(
self,
eeg: torch.Tensor,
condition_keys=None,
channel_mask=None,
reference_type_ids=None,
channel_coords=None,
) -> Dict[str, torch.Tensor]:
if self.use_spatial_rope:
patches = self.patchify(eeg, flatten=False)
if self.use_variable_channels:
x = self.patch_embed(
patches, channel_mask=channel_mask,
channel_coords=channel_coords,
reference_type_ids=reference_type_ids,
)
else:
x = self.patch_embed(patches, channel_mask=channel_mask, reference_type_ids=reference_type_ids)
else:
patches = self.patchify(eeg, flatten=True)
x = self.patch_embed(patches)
B = x.shape[0]
if self.use_modality_tokens:
# Build discrete condition token
if condition_keys is not None and self.use_condition_encoding:
forward method · python · L654-L725 (72 LOC)models/eeg_encoder_v3.py
def forward(
self,
eeg: torch.Tensor,
condition_keys=None,
channel_mask=None,
reference_type_ids=None,
channel_coords=None,
) -> Dict[str, torch.Tensor]:
if channel_mask is not None:
eeg = eeg * channel_mask.unsqueeze(-1)
target_patches = self.patchify(eeg, flatten=True)
enc_out = self.encode(
eeg, condition_keys=condition_keys,
channel_mask=channel_mask, reference_type_ids=reference_type_ids,
channel_coords=channel_coords,
)
hidden = enc_out['hidden']
next_patch_pred = self.next_patch_head(hidden[:, :-1, :])
feature_pred = self.feature_head(hidden[:, :-1, :])
result = {
'hidden': hidden,
'predictions': next_patch_pred,
'feature_predictions': feature_pred,
'targets': target_patches,
'channel_mask': channel_mask,
}
# Multi-token prediction (tWant fix-PRs on findings? Install Repobility's GitHub App · github.com/apps/repobility-bot
rotate_half function · python · L27-L30 (4 LOC)models/spatial_rope.py
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``cat(-x2, x1)`` where ``x1, x2`` are the halves of the last dimension."""
    front, back = torch.chunk(x, chunks=2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
# apply_rotary_pos_emb function · python · L33-L42 (10 LOC)models/spatial_rope.py
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``q`` and ``k`` by the given RoPE tables.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``. Returns the
    rotated ``(q, k)`` pair.
    """
    rotated = tuple((t * cos) + (rotate_half(t) * sin) for t in (q, k))
    return rotated
# Spatial2DRoPE class · python · L45-L188 (144 LOC)models/spatial_rope.py
class Spatial2DRoPE(nn.Module):
"""
2D Rotary Position Embedding for EEG with Reference Conditioning.
Maps 19 EEG channels to their (x, y) scalp coordinates and applies
rotary embeddings. Reference type modulates the rotation frequencies,
allowing the model to learn different spatial relationships for different
reference schemes.
The core idea:
- Standard RoPE: rotation angle = position * frequency
- 2D RoPE: rotation angle = (x * freq_x) + (y * freq_y)
- Reference-conditioned: frequencies are scaled by reference type
This is applied to the channel dimension, not the temporal dimension.
For full EEG encoding, combine with temporal RoPE.
"""
def __init__(
self,
head_dim: int,
n_reference_types: int = 7,
base: float = 10000.0,
use_reference_conditioning: bool = True,
):
super().__init__()
self.head_dim = head_dim
self.n_reference_types = n_reference_types
__init__ method · python · L63-L103 (41 LOC)models/spatial_rope.py
def __init__(
self,
head_dim: int,
n_reference_types: int = 7,
base: float = 10000.0,
use_reference_conditioning: bool = True,
):
super().__init__()
self.head_dim = head_dim
self.n_reference_types = n_reference_types
self.base = base
self.use_reference_conditioning = use_reference_conditioning
# We split head_dim into 4 parts: x_sin, x_cos, y_sin, y_cos
# Each part gets head_dim // 4 dimensions
quarter_dim = head_dim // 4
if head_dim % 4 != 0:
raise ValueError(f"head_dim ({head_dim}) must be divisible by 4 for 2D RoPE")
# Base frequencies for x and y dimensions
inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, dtype=torch.float32) / quarter_dim))
self.register_buffer('inv_freq', inv_freq)
# Channel positions from EEGConfig (normalized to ~[-1, 1])
# Standard 19-channel 10-20 montage
channel_coords = get_reference_type_id method · python · L105-L107 (3 LOC)models/spatial_rope.py
def get_reference_type_id(self, reference_name: str) -> int:
    """Look up the integer ID for a reference scheme name, falling back to ``'unknown'``."""
    ref_types = EEGConfig.REFERENCE_TYPES
    fallback = ref_types['unknown']
    return ref_types.get(reference_name, fallback)
# forward method · python · L109-L188 (80 LOC)models/spatial_rope.py
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
reference_type_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply 2D spatial RoPE with reference conditioning.
Args:
q: Query tensor of shape (batch, n_heads, n_channels, head_dim)
k: Key tensor of shape (batch, n_heads, n_channels, head_dim)
reference_type_ids: Reference type IDs of shape (batch,)
If None, uses reference type 2 (average) as default.
Returns:
q_rotated, k_rotated: Rotated query and key tensors
"""
batch_size, n_heads, n_channels, head_dim = q.shape
device = q.device
dtype = q.dtype
# Default to average reference if not specified
if reference_type_ids is None:
reference_type_ids = torch.full(
(batch_size,), 2, dtype=torch.long, device=device
)
# Get (xTemporalRoPE class · python · L191-L243 (53 LOC)models/spatial_rope.py
class TemporalRoPE(nn.Module):
"""
Standard 1D Rotary Position Embedding for temporal dimension.
Applied to patch/time positions in the EEG sequence.
"""
def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
# Precompute cos/sin embeddings
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, device=self.inv_freq.device)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
def forward(
__init__ method · python · L198-L208 (11 LOC)models/spatial_rope.py
def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
    """Compute inverse frequencies and precompute the temporal cos/sin cache.

    Args:
        dim: rotary dimension; one frequency per pair of features.
        max_seq_len: number of positions precomputed up front.
        base: RoPE frequency base.
    """
    super().__init__()
    self.dim = dim
    self.max_seq_len = max_seq_len
    half_idx = torch.arange(0, dim, 2).float()
    self.register_buffer('inv_freq', 1.0 / (base ** (half_idx / dim)))
    self._precompute_cache(max_seq_len)
_precompute_cache method · python · L210-L215 (6 LOC)models/spatial_rope.py
def _precompute_cache(self, seq_len: int):
    """(Re)build non-persistent cos/sin tables of shape (1, 1, seq_len, 2 * len(inv_freq)).

    ``persistent=False`` keeps these tables out of ``state_dict``.
    """
    steps = torch.arange(seq_len, device=self.inv_freq.device)
    phase = torch.outer(steps, self.inv_freq).repeat(1, 2)  # same as cat([f, f], -1)
    tables = {'cos_cached': phase.cos(), 'sin_cached': phase.sin()}
    for key, val in tables.items():
        self.register_buffer(key, val.unsqueeze(0).unsqueeze(0), persistent=False)
# forward method · python · L217-L243 (27 LOC)models/spatial_rope.py
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply temporal RoPE to ``q`` and ``k``.

    Args:
        q, k: tensors of shape (batch, n_heads, seq_len, head_dim).
        seq_len: number of positions to use; defaults to ``q.shape[2]``.

    Returns:
        ``(q_rotated, k_rotated)``. The cache is grown first if it is too short,
        and the cos/sin slices are cast to ``q.dtype``.
    """
    n_pos = q.shape[2] if seq_len is None else seq_len
    if n_pos > self.cos_cached.shape[2]:
        self._precompute_cache(n_pos)
    cos = self.cos_cached[:, :, :n_pos, :].to(q.dtype)
    sin = self.sin_cached[:, :, :n_pos, :].to(q.dtype)
    return apply_rotary_pos_emb(q, k, cos, sin)
# SpatialTemporalRoPE class · python · L246-L349 (104 LOC)models/spatial_rope.py
class SpatialTemporalRoPE(nn.Module):
"""
Combined Spatial (2D) and Temporal (1D) RoPE for EEG.
For a sequence where each position represents a specific (channel, time) pair,
this applies both spatial and temporal rotary embeddings.
The key insight is that EEG has two positional dimensions:
1. Spatial: Which channel (electrode position on scalp)
2. Temporal: When in the recording
This module handles the case where the sequence is structured as:
[ch0_t0, ch1_t0, ..., ch18_t0, ch0_t1, ch1_t1, ..., ch18_t1, ...]
i.e., channels interleaved within time steps.
"""
def __init__(
self,
head_dim: int,
n_reference_types: int = 7,
max_seq_len: int = 2048,
base: float = 10000.0,
spatial_weight: float = 0.5, # Balance between spatial and temporal
):
super().__init__()
self.head_dim = head_dim
self.spatial_weight = spatial_weight
# Split head_dim between spatial a