← back to drewlinsley__tmp_bs_mcp

Function bodies 170 total

All specs Real LLM only Function bodies
main function · python · L966-L1237 (272 LOC)
eval/evaluate_v3_linear_probe.py
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str,
                        default='checkpoints/pretrain_v3_20260208_173529/best_model.pt')
    parser.add_argument('--max_length', type=int, default=24000)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--tasks', type=str, nargs='+', default=None,
                        help='Specific tasks to evaluate (default: all)')
    parser.add_argument('--classification_only', action='store_true')
    parser.add_argument('--regression_only', action='store_true')
    parser.add_argument('--include_random', action='store_true',
                        help='Also evaluate random (untrained) v3 baseline')
    parser.add_argument('--condition_cache', type=str,
                        default='data/condition_embeddings.pt',
                        help='Path to condition embeddings cache')
    parser.ad
list_tools function · python · L33-L41 (9 LOC)
mcp_stdio_proxy.py
async def list_tools():
    """Fetch the tool catalogue from the HTTP server and wrap each entry as a ``Tool``.

    Sends ``GET {SERVER_URL}/api/tools`` with a 30-second client timeout and
    calls ``raise_for_status()`` on the response, so an HTTP error status
    propagates as an exception. Every item of the decoded JSON payload is
    turned into a ``Tool`` from its ``name``, ``description`` and
    ``inputSchema`` keys.
    """
    async with httpx.AsyncClient(timeout=30.0) as http:
        response = await http.get(f"{SERVER_URL}/api/tools")
        response.raise_for_status()
        payload = response.json()

    tools = []
    for entry in payload:
        tools.append(
            Tool(
                name=entry["name"],
                description=entry["description"],
                inputSchema=entry["inputSchema"],
            )
        )
    return tools
call_tool function · python · L45-L53 (9 LOC)
mcp_stdio_proxy.py
async def call_tool(name: str, arguments: dict):
    """Forward one tool invocation to the HTTP server and return its reply as text.

    Posts ``{"name": name, "arguments": arguments}`` as JSON to
    ``{SERVER_URL}/api/call_tool`` with a 120-second client timeout and calls
    ``raise_for_status()`` on the response.

    Returns:
        A one-element list holding a ``TextContent`` whose text is the reply's
        ``"text"`` field when present, otherwise the string form of the whole
        decoded reply.
    """
    async with httpx.AsyncClient(timeout=120.0) as http:
        request_body = {"name": name, "arguments": arguments}
        response = await http.post(f"{SERVER_URL}/api/call_tool", json=request_body)
        response.raise_for_status()
        payload = response.json()

    fallback_text = str(payload)
    reply_text = payload.get("text", fallback_text)
    return [TextContent(type="text", text=reply_text)]
main function · python · L56-L58 (3 LOC)
mcp_stdio_proxy.py
async def main():
    """Open the stdio transport and run ``server`` on its read/write streams."""
    async with stdio_server() as streams:
        reader, writer = streams
        init_options = server.create_initialization_options()
        await server.run(reader, writer, init_options)
RMSNorm class · python · L7-L16 (10 LOC)
models/common.py
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Each vector along the last dimension is scaled by the reciprocal square
    root of its mean squared value (plus ``eps``) and then multiplied
    element-wise by ``weight``. No mean is subtracted and no bias is added.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        gain = torch.ones(dim)
        self.weight = nn.Parameter(gain)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the feature (last) axis, kept for broadcasting.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = x * inv_rms
        return normalised * self.weight
__init__ method · python · L10-L13 (4 LOC)
models/common.py
    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the learnable gain vector and store the stabilising constant.

        Args:
            dim: Length of the ``weight`` parameter vector.
            eps: Constant stored as ``self.eps`` (the class's ``forward`` adds
                it to the mean square before taking the reciprocal root).
        """
        super().__init__()
        self.eps = eps
        # The gain starts at one for every feature.
        self.weight = nn.Parameter(torch.ones(dim))
RotaryPositionalEmbedding class · python · L36-L59 (24 LOC)
models/eeg_encoder_v3.py
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int = 4096, base: float = 500_000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[2]
        if seq_len > self.cos_cached.shape[2]:
            self._precompute_cache(seq_len)
        return (
            self.cos_cached[:, :, :seq_len, :],
 
Want fix-PRs on findings? Install Repobility's GitHub App · github.com/apps/repobility-bot
__init__ method · python · L37-L43 (7 LOC)
models/eeg_encoder_v3.py
    def __init__(self, dim: int, max_seq_len: int = 4096, base: float = 500_000.0):
        """Create the rotary inverse frequencies and precompute the cos/sin tables.

        Args:
            dim: Rotary dimension; one inverse frequency is created for every
                even index in ``range(0, dim, 2)``.
            max_seq_len: Number of positions precomputed at construction time.
            base: Base of the geometric frequency progression.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # inv_freq[i] = base ** (-(2 * i) / dim)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Must run after inv_freq is registered: the cache is built from it.
        self._precompute_cache(max_seq_len)
_precompute_cache method · python · L45-L50 (6 LOC)
models/eeg_encoder_v3.py
    def _precompute_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])
forward method · python · L52-L59 (8 LOC)
models/eeg_encoder_v3.py
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the cached cos/sin tables trimmed to the sequence length of ``x``.

        The sequence length is read from axis 2 of ``x``. If it exceeds the
        number of cached positions, the cache is rebuilt for that length first.

        Returns:
            ``(cos, sin)``, each sliced to the first ``x.shape[2]`` positions
            along axis 2 of the cached tables.
        """
        needed = x.shape[2]
        cached_len = self.cos_cached.shape[2]
        if cached_len < needed:
            self._precompute_cache(needed)
        cos = self.cos_cached[:, :, :needed, :]
        sin = self.sin_cached[:, :, :needed, :]
        return cos, sin
_rotate_half function · python · L62-L64 (3 LOC)
models/eeg_encoder_v3.py
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)
_apply_rope function · python · L67-L69 (3 LOC)
models/eeg_encoder_v3.py
def _apply_rope(q, k, cos, sin):
    """Rotate ``q`` and ``k`` with the same cos/sin tables.

    Each tensor ``t`` becomes ``t * cos + _rotate_half(t) * sin``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    def rotate(t):
        return t * cos + _rotate_half(t) * sin

    q_rotated = rotate(q)
    k_rotated = rotate(k)
    return (q_rotated, k_rotated)
GQAttention class · python · L76-L144 (69 LOC)
models/eeg_encoder_v3.py
class GQAttention(nn.Module):
    """Pre-norm Grouped Query Attention with QK-Norm and RoPE."""

    def __init__(
        self,
        dim: int,
        n_heads: int = 24,
        n_kv_heads: int = 6,
        dropout: float = 0.1,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        rope_base: float = 500_000.0,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.n_rep = n_heads // n_kv_heads  # repetitions per KV head

        self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNor
__init__ method · python · L79-L110 (32 LOC)
models/eeg_encoder_v3.py
    def __init__(
        self,
        dim: int,
        n_heads: int = 24,
        n_kv_heads: int = 6,
        dropout: float = 0.1,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        rope_base: float = 500_000.0,
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.n_rep = n_heads // n_kv_heads  # repetitions per KV head

        self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        self.use_rope = use_rope
        if use_rope:
            self.rope = R
forward method · python · L112-L144 (33 LOC)
models/eeg_encoder_v3.py
    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        B, T, _ = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if self.use_rope:
            cos, sin = self.rope(q)
            # RoPE only applied to Q heads and corresponding KV heads
            # Expand cos/sin for Q heads
            q = q * cos + _rotate_half(q) * sin
            # For KV heads, take every n_rep-th slice of cos/sin
            cos_kv = cos[:, :self.n_kv_heads, :, :]
            sin_kv = sin[:, :self.n_kv_heads, :, :]
            k = k * cos_kv + _rotate_half(k) * sin_kv

        # Repeat KV heads to match Q heads
        if self.n_rep > 1:
            k = k.repeat_inter
If a scraper extracted this row, it came from Repobility (https://repobility.com)
SwiGLUFFN class · python · L151-L163 (13 LOC)
models/eeg_encoder_v3.py
class SwiGLUFFN(nn.Module):
    """Gated feed-forward block: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    The hidden width is ``int(dim * multiplier)`` rounded up to the next
    multiple of 128. All three projections are bias-free.
    """

    def __init__(self, dim: int, multiplier: float = 8.0 / 3.0, dropout: float = 0.1):
        super().__init__()
        raw_width = int(dim * multiplier)
        # Pad up to the next multiple of 128 for GPU efficiency.
        hidden = raw_width + (-raw_width % 128)
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        projected = self.w2(gated)
        return self.dropout(projected)
__init__ method · python · L152-L160 (9 LOC)
models/eeg_encoder_v3.py
    def __init__(self, dim: int, multiplier: float = 8.0 / 3.0, dropout: float = 0.1):
        """Build the three bias-free projections and the output dropout.

        Args:
            dim: Input and output feature size.
            multiplier: Hidden width is ``int(dim * multiplier)`` before padding.
            dropout: Probability passed to ``nn.Dropout``.
        """
        super().__init__()
        hidden = int(dim * multiplier)
        # Round up to the next multiple of 128 for GPU efficiency
        hidden = ((hidden + 127) // 128) * 128
        # w1 and w3 both map dim -> hidden; w2 maps hidden -> dim.
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.dropout = nn.Dropout(dropout)
TransformerBlockV3 class · python · L170-L193 (24 LOC)
models/eeg_encoder_v3.py
class TransformerBlockV3(nn.Module):
    """Pre-norm transformer block: attention residual followed by FFN residual.

    ``x`` is first updated with ``attn(norm1(x))`` and then with
    ``ffn(norm2(x))``, each added back onto the running activation.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.1,
        ffn_multiplier: float = 8.0 / 3.0,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        rope_base: float = 500_000.0,
    ):
        super().__init__()
        # Sub-modules are created in the same order as before so parameter
        # registration and random initialisation are unchanged.
        self.norm1 = RMSNorm(dim)
        attention_args = (dim, n_heads, n_kv_heads, dropout, use_qk_norm, use_rope, rope_base)
        self.attn = GQAttention(*attention_args)
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, multiplier=ffn_multiplier, dropout=dropout)

    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        attended = self.attn(self.norm1(x), is_causal=is_causal)
        x = x + attended
        transformed = self.ffn(self.norm2(x))
        return x + transformed
__init__ method · python · L171-L188 (18 LOC)
models/eeg_encoder_v3.py
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.1,
        ffn_multiplier: float = 8.0 / 3.0,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        rope_base: float = 500_000.0,
    ):
        """Assemble the attention and feed-forward sub-layers with their norms.

        Args:
            dim: Model width, used for both norms, the attention and the FFN.
            n_heads: Forwarded to ``GQAttention``.
            n_kv_heads: Forwarded to ``GQAttention``.
            dropout: Forwarded to both ``GQAttention`` and ``SwiGLUFFN``.
            ffn_multiplier: Forwarded to ``SwiGLUFFN`` as ``multiplier``.
            use_qk_norm: Forwarded to ``GQAttention``.
            use_rope: Forwarded to ``GQAttention``.
            rope_base: Forwarded to ``GQAttention``.
        """
        super().__init__()
        # norm1 feeds the attention branch, norm2 feeds the FFN branch.
        self.norm1 = RMSNorm(dim)
        self.attn = GQAttention(
            dim, n_heads, n_kv_heads, dropout, use_qk_norm, use_rope, rope_base,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, multiplier=ffn_multiplier, dropout=dropout)
forward method · python · L190-L193 (4 LOC)
models/eeg_encoder_v3.py
    def forward(self, x: torch.Tensor, is_causal: bool = True) -> torch.Tensor:
        """Apply the attention residual, then the feed-forward residual.

        Args:
            x: Input activations.
            is_causal: Passed through to the attention layer.

        Returns:
            ``x`` after both residual updates.
        """
        attended = self.attn(self.norm1(x), is_causal=is_causal)
        after_attention = x + attended
        transformed = self.ffn(self.norm2(after_attention))
        return after_attention + transformed
SpatialPatchEmbedding class · python · L200-L293 (94 LOC)
models/eeg_encoder_v3.py
class SpatialPatchEmbedding(nn.Module):
    """Spatial-aware patch embedding using 2D RoPE — pinned from v2."""

    def __init__(
        self,
        n_channels: int = 19,
        patch_size: int = 32,
        hidden_dim: int = 768,
        n_heads: int = 4,
        n_reference_types: int = 7,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads

        self.channel_embed = nn.Linear(patch_size, hidden_dim)

        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.n_heads = n_heads

        self.spatial_rope = Spatial2DRoPE(
            head_dim=self.head_dim,
            n_reference_types=n_reference_types,
            use_reference_conditioning=True,
        )

        self.q_norm = RMSNorm(self.head_dim)
        self.k_n
__init__ method · python · L203-L234 (32 LOC)
models/eeg_encoder_v3.py
    def __init__(
        self,
        n_channels: int = 19,
        patch_size: int = 32,
        hidden_dim: int = 768,
        n_heads: int = 4,
        n_reference_types: int = 7,
        dropout: float = 0.1,
    ):
        """Create the per-channel embedding, attention projections and norms.

        Args:
            n_channels: Stored as ``self.n_channels``.
            patch_size: Input width of the per-channel linear embedding.
            hidden_dim: Embedding width; also the attention width.
            n_heads: Number of attention heads; ``head_dim`` is
                ``hidden_dim // n_heads``.
            n_reference_types: Forwarded to ``Spatial2DRoPE``.
            dropout: Stored as a plain probability in ``self.dropout_p``
                (no ``nn.Dropout`` module is created here).
        """
        super().__init__()
        self.n_channels = n_channels
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads

        # Maps one channel's patch (patch_size values) to hidden_dim.
        self.channel_embed = nn.Linear(patch_size, hidden_dim)

        # Single fused projection producing Q, K and V (3 * hidden_dim outputs).
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.n_heads = n_heads

        # Spatial2DRoPE raises ValueError unless head_dim is divisible by 4.
        self.spatial_rope = Spatial2DRoPE(
            head_dim=self.head_dim,
            n_reference_types=n_reference_types,
            use_reference_conditioning=True,
        )

        # Per-head norms sized to head_dim.
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

        self.norm = RMSNorm(hidden_dim)
        self.dropout_p = dropout
forward method · python · L236-L293 (58 LOC)
models/eeg_encoder_v3.py
    def forward(
        self,
        patches: torch.Tensor,
        channel_mask: Optional[torch.Tensor] = None,
        reference_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, N, C, P = patches.shape
        device = patches.device

        x = self.channel_embed(patches)
        x = x.view(B * N, C, self.hidden_dim)

        qkv = self.qkv(x).reshape(B * N, C, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if reference_type_ids is not None:
            ref_ids_expanded = reference_type_ids.repeat_interleave(N)
        else:
            ref_ids_expanded = None
        q, k = self.spatial_rope(q, k, ref_ids_expanded)

        if channel_mask is not None:
            mask_expanded = channel_mask.repeat_interleave(N, dim=0)
            attn_mask_bool = (mask_expanded < 0.5).unsqueeze(1).un
Repobility · code-quality intelligence · https://repobility.com
PredictionHead class · python · L300-L317 (18 LOC)
models/eeg_encoder_v3.py
class PredictionHead(nn.Module):
    """Gated-SiLU residual block followed by a normalised projection — pinned from v2.

    The input is normalised, expanded to twice the hidden width through a
    SiLU-activated branch multiplied by a linear gate, projected back and
    added to the input. The sum is normalised again and mapped to
    ``patch_dim`` outputs.
    """

    def __init__(self, hidden_dim: int, patch_dim: int):
        super().__init__()
        expanded_dim = hidden_dim * 2
        self.norm = RMSNorm(hidden_dim)
        self.up = nn.Linear(hidden_dim, expanded_dim, bias=False)
        self.gate = nn.Linear(hidden_dim, expanded_dim, bias=False)
        self.down = nn.Linear(expanded_dim, hidden_dim, bias=False)
        self.out_norm = RMSNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, patch_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        gated = F.silu(self.up(normed)) * self.gate(normed)
        residual = x + self.down(gated)
        return self.proj(self.out_norm(residual))
__init__ method · python · L303-L310 (8 LOC)
models/eeg_encoder_v3.py
    def __init__(self, hidden_dim: int, patch_dim: int):
        """Create the gated residual block and the output projection.

        Args:
            hidden_dim: Width of the incoming hidden states.
            patch_dim: Output width of the final projection.
        """
        super().__init__()
        self.norm = RMSNorm(hidden_dim)
        # up and gate both expand to 2 * hidden_dim; down maps back to hidden_dim.
        self.up = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.gate = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.down = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.out_norm = RMSNorm(hidden_dim)
        # The only layer in this head that keeps the default bias.
        self.proj = nn.Linear(hidden_dim, patch_dim)
forward method · python · L312-L317 (6 LOC)
models/eeg_encoder_v3.py
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the gated residual block on ``x`` and project the normalised result."""
        normed = self.norm(x)
        gated = F.silu(self.up(normed)) * self.gate(normed)
        residual = x + self.down(gated)
        return self.proj(self.out_norm(residual))
FeaturePredictionHead class · python · L324-L349 (26 LOC)
models/eeg_encoder_v3.py
class FeaturePredictionHead(nn.Module):
    """SwiGLU residual block + projection for EEG feature prediction.

    Predicts per-channel features: 5 band powers + 3 Hjorth + 1 variance = 9
    Output shape: (B, N, C*9) = (B, N, 171)
    Matches PredictionHead architecture for sufficient capacity.
    """

    def __init__(self, hidden_dim: int, n_channels: int, n_features: int = 9):
        super().__init__()
        self.n_channels = n_channels
        self.n_features = n_features
        out_dim = n_channels * n_features
        self.norm = RMSNorm(hidden_dim)
        self.up = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.gate = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.down = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.out_norm = RMSNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        h = F.silu(self.up(h)) * self.gate(h)
 
__init__ method · python · L332-L342 (11 LOC)
models/eeg_encoder_v3.py
    def __init__(self, hidden_dim: int, n_channels: int, n_features: int = 9):
        """Create the gated residual block and the feature projection.

        Args:
            hidden_dim: Width of the incoming hidden states.
            n_channels: Number of channels; stored on the module.
            n_features: Features predicted per channel; stored on the module.
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_features = n_features
        # One output per (channel, feature) pair.
        out_dim = n_channels * n_features
        self.norm = RMSNorm(hidden_dim)
        # up and gate both expand to 2 * hidden_dim; down maps back to hidden_dim.
        self.up = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.gate = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.down = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.out_norm = RMSNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, out_dim)
forward method · python · L344-L349 (6 LOC)
models/eeg_encoder_v3.py
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated residual update to ``x``, then normalise and project it."""
        pre = self.norm(x)
        activated = F.silu(self.up(pre))
        update = self.down(activated * self.gate(pre))
        summed = x + update
        return self.proj(self.out_norm(summed))
MTPModule class · python · L356-L388 (33 LOC)
models/eeg_encoder_v3.py
class MTPModule(nn.Module):
    """One depth of multi-step prediction.

    Takes the previous depth's hidden states and the ground-truth next-patch
    embedding, concatenates (after RMSNorm), projects back to hidden_dim,
    runs through one transformer layer, then reuses the shared prediction heads.
    """

    def __init__(self, hidden_dim: int, n_heads: int, n_kv_heads: int,
                 dropout: float, ffn_multiplier: float, rope_base: float):
        super().__init__()
        self.norm_hidden = RMSNorm(hidden_dim)
        self.norm_embed = RMSNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.block = TransformerBlockV3(
            dim=hidden_dim, n_heads=n_heads, n_kv_heads=n_kv_heads,
            dropout=dropout, ffn_multiplier=ffn_multiplier,
            use_qk_norm=True, use_rope=True, rope_base=rope_base,
        )

    def forward(self, hidden: torch.Tensor, patch_embeds: torch.Tensor,
                is_causal: bool =
__init__ method · python · L364-L374 (11 LOC)
models/eeg_encoder_v3.py
    def __init__(self, hidden_dim: int, n_heads: int, n_kv_heads: int,
                 dropout: float, ffn_multiplier: float, rope_base: float):
        """Create the two input norms, the fusion projection and one transformer block.

        Args:
            hidden_dim: Width of both inputs and of the output.
            n_heads: Forwarded to ``TransformerBlockV3``.
            n_kv_heads: Forwarded to ``TransformerBlockV3``.
            dropout: Forwarded to ``TransformerBlockV3``.
            ffn_multiplier: Forwarded to ``TransformerBlockV3``.
            rope_base: Forwarded to ``TransformerBlockV3``.
        """
        super().__init__()
        self.norm_hidden = RMSNorm(hidden_dim)
        self.norm_embed = RMSNorm(hidden_dim)
        # Takes 2 * hidden_dim inputs and maps them back to hidden_dim.
        self.proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        # QK-norm and RoPE are always enabled for this block.
        self.block = TransformerBlockV3(
            dim=hidden_dim, n_heads=n_heads, n_kv_heads=n_kv_heads,
            dropout=dropout, ffn_multiplier=ffn_multiplier,
            use_qk_norm=True, use_rope=True, rope_base=rope_base,
        )
Open data scored by Repobility · https://repobility.com
forward method · python · L376-L388 (13 LOC)
models/eeg_encoder_v3.py
    def forward(self, hidden: torch.Tensor, patch_embeds: torch.Tensor,
                is_causal: bool = True) -> torch.Tensor:
        """Fuse the previous depth's states with patch embeddings and run one block.

        Both inputs are normalised separately, concatenated on the last axis,
        projected back to the hidden width and passed through the transformer
        block.

        Args:
            hidden: (B, N, D) hidden states from previous depth
            patch_embeds: (B, N, D) embedded ground-truth patches at offset k
            is_causal: whether to use causal masking
        Returns:
            (B, N, D) hidden states for this depth
        """
        normed_hidden = self.norm_hidden(hidden)
        normed_embeds = self.norm_embed(patch_embeds)
        fused = torch.cat((normed_hidden, normed_embeds), dim=-1)
        projected = self.proj(fused)
        return self.block(projected, is_causal=is_causal)
EEGEncoderV3 class · python · L395-L725 (331 LOC)
models/eeg_encoder_v3.py
class EEGEncoderV3(nn.Module):
    """
    EEG Encoder v3 for next-patch prediction with feature losses.

    Scaled transformer with GQA, SwiGLU 8/3, RoPE base=500k, scaled init.
    Optional register tokens for working memory.
    """

    def __init__(
        self,
        n_channels: int = 19,
        hidden_dim: int = 1536,
        n_heads: int = 24,
        n_kv_heads: int = 6,
        n_layers: int = 24,
        dropout: float = 0.1,
        patch_size: int = 32,
        ffn_multiplier: float = 8.0 / 3.0,
        rope_base: float = 500_000.0,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        use_condition_encoding: bool = True,
        condition_cache_path: str = 'data/condition_embeddings.pt',
        max_channels: int = None,
        use_spatial_rope: bool = True,
        n_registers: int = 0,
        use_modality_tokens: bool = False,
        mtp_depths: int = 0,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.max_ch
__init__ method · python · L403-L533 (131 LOC)
models/eeg_encoder_v3.py
    def __init__(
        self,
        n_channels: int = 19,
        hidden_dim: int = 1536,
        n_heads: int = 24,
        n_kv_heads: int = 6,
        n_layers: int = 24,
        dropout: float = 0.1,
        patch_size: int = 32,
        ffn_multiplier: float = 8.0 / 3.0,
        rope_base: float = 500_000.0,
        use_qk_norm: bool = True,
        use_rope: bool = True,
        use_condition_encoding: bool = True,
        condition_cache_path: str = 'data/condition_embeddings.pt',
        max_channels: int = None,
        use_spatial_rope: bool = True,
        n_registers: int = 0,
        use_modality_tokens: bool = False,
        mtp_depths: int = 0,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.max_channels = max_channels or n_channels
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.mtp_depths = mtp_depths
        self.n_heads = n_heads
        self.n_kv_heads = n
_init_weights method · python · L535-L548 (14 LOC)
models/eeg_encoder_v3.py
    def _init_weights(self):
        """Apply scaled init: residual projections get 1/sqrt(2*N_layers)."""
        residual_std = 0.02 / math.sqrt(2 * self.n_layers)

        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Scale down residual projections
        for block in self.blocks:
            torch.nn.init.trunc_normal_(block.attn.out_proj.weight, std=residual_std)
            torch.nn.init.trunc_normal_(block.ffn.w2.weight, std=residual_std)
patchify method · python · L550-L557 (8 LOC)
models/eeg_encoder_v3.py
    def patchify(self, eeg: torch.Tensor, flatten: bool = True) -> torch.Tensor:
        """Cut a recording into consecutive, non-overlapping temporal patches.

        Trailing samples that do not fill a whole patch are dropped. A
        recording shorter than one patch yields zero patches rather than an
        error.

        Args:
            eeg: Tensor of shape ``(B, C, T)``.
            flatten: If True, merge the channel and within-patch axes.

        Returns:
            ``(B, n_patches, C * patch_size)`` when ``flatten`` is True,
            otherwise ``(B, n_patches, C, patch_size)``, where
            ``n_patches = T // patch_size``.
        """
        B, C, T = eeg.shape
        n_patches = T // self.patch_size
        eeg = eeg[:, :, :n_patches * self.patch_size]
        patches = eeg.view(B, C, n_patches, self.patch_size).permute(0, 2, 1, 3)
        if flatten:
            # Spell the last dimension out instead of using -1: with zero
            # patches (T < patch_size) the tensor is empty and reshape cannot
            # infer a -1 dimension.
            patches = patches.reshape(B, n_patches, C * self.patch_size)
        return patches
encode_condition method · python · L559-L581 (23 LOC)
models/eeg_encoder_v3.py
    def encode_condition(self, condition_keys, device):
        if not self.use_condition_encoding or self.condition_cache is None:
            return None
        # Move cache to device once (avoids CPU→GPU copy every forward pass)
        if self._condition_cache_device != device:
            self.condition_cache = {k: v.to(device) for k, v in self.condition_cache.items()}
            self._condition_cache_device = device
        embeddings = []
        for dataset, recording_state in condition_keys:
            key_str = f"{dataset}|{recording_state}"
            if key_str in self.condition_cache:
                embed = self.condition_cache[key_str]
            else:
                key_str = f"{dataset}|default"
                if key_str in self.condition_cache:
                    embed = self.condition_cache[key_str]
                else:
                    import warnings
                    warnings.warn(f"No condition embedding for {dataset}|{recording_state}, using arbitr
encode method · python · L583-L652 (70 LOC)
models/eeg_encoder_v3.py
    def encode(
        self,
        eeg: torch.Tensor,
        condition_keys=None,
        channel_mask=None,
        reference_type_ids=None,
        channel_coords=None,
    ) -> Dict[str, torch.Tensor]:
        if self.use_spatial_rope:
            patches = self.patchify(eeg, flatten=False)
            if self.use_variable_channels:
                x = self.patch_embed(
                    patches, channel_mask=channel_mask,
                    channel_coords=channel_coords,
                    reference_type_ids=reference_type_ids,
                )
            else:
                x = self.patch_embed(patches, channel_mask=channel_mask, reference_type_ids=reference_type_ids)
        else:
            patches = self.patchify(eeg, flatten=True)
            x = self.patch_embed(patches)

        B = x.shape[0]

        if self.use_modality_tokens:
            # Build discrete condition token
            if condition_keys is not None and self.use_condition_encoding:
             
forward method · python · L654-L725 (72 LOC)
models/eeg_encoder_v3.py
    def forward(
        self,
        eeg: torch.Tensor,
        condition_keys=None,
        channel_mask=None,
        reference_type_ids=None,
        channel_coords=None,
    ) -> Dict[str, torch.Tensor]:
        if channel_mask is not None:
            eeg = eeg * channel_mask.unsqueeze(-1)

        target_patches = self.patchify(eeg, flatten=True)

        enc_out = self.encode(
            eeg, condition_keys=condition_keys,
            channel_mask=channel_mask, reference_type_ids=reference_type_ids,
            channel_coords=channel_coords,
        )
        hidden = enc_out['hidden']

        next_patch_pred = self.next_patch_head(hidden[:, :-1, :])
        feature_pred = self.feature_head(hidden[:, :-1, :])

        result = {
            'hidden': hidden,
            'predictions': next_patch_pred,
            'feature_predictions': feature_pred,
            'targets': target_patches,
            'channel_mask': channel_mask,
        }

        # Multi-token prediction (t
Want fix-PRs on findings? Install Repobility's GitHub App · github.com/apps/repobility-bot
rotate_half function · python · L27-L30 (4 LOC)
models/spatial_rope.py
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis, negating the half that moves to the front."""
    lower, upper = torch.chunk(x, 2, dim=-1)
    swapped = (upper.neg(), lower)
    return torch.cat(swapped, dim=-1)
apply_rotary_pos_emb function · python · L33-L42 (10 LOC)
models/spatial_rope.py
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with shared cos/sin tables.

    Each tensor ``t`` becomes ``(t * cos) + (rotate_half(t) * sin)``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    def rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos) + (rotate_half(t) * sin)

    q_rotated = rotate(q)
    k_rotated = rotate(k)
    return q_rotated, k_rotated
Spatial2DRoPE class · python · L45-L188 (144 LOC)
models/spatial_rope.py
class Spatial2DRoPE(nn.Module):
    """
    2D Rotary Position Embedding for EEG with Reference Conditioning.

    Maps 19 EEG channels to their (x, y) scalp coordinates and applies
    rotary embeddings. Reference type modulates the rotation frequencies,
    allowing the model to learn different spatial relationships for different
    reference schemes.

    The core idea:
    - Standard RoPE: rotation angle = position * frequency
    - 2D RoPE: rotation angle = (x * freq_x) + (y * freq_y)
    - Reference-conditioned: frequencies are scaled by reference type

    This is applied to the channel dimension, not the temporal dimension.
    For full EEG encoding, combine with temporal RoPE.
    """

    def __init__(
        self,
        head_dim: int,
        n_reference_types: int = 7,
        base: float = 10000.0,
        use_reference_conditioning: bool = True,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_reference_types = n_reference_types
      
__init__ method · python · L63-L103 (41 LOC)
models/spatial_rope.py
    def __init__(
        self,
        head_dim: int,
        n_reference_types: int = 7,
        base: float = 10000.0,
        use_reference_conditioning: bool = True,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_reference_types = n_reference_types
        self.base = base
        self.use_reference_conditioning = use_reference_conditioning

        # We split head_dim into 4 parts: x_sin, x_cos, y_sin, y_cos
        # Each part gets head_dim // 4 dimensions
        quarter_dim = head_dim // 4
        if head_dim % 4 != 0:
            raise ValueError(f"head_dim ({head_dim}) must be divisible by 4 for 2D RoPE")

        # Base frequencies for x and y dimensions
        inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, dtype=torch.float32) / quarter_dim))
        self.register_buffer('inv_freq', inv_freq)

        # Channel positions from EEGConfig (normalized to ~[-1, 1])
        # Standard 19-channel 10-20 montage
        channel_coords = 
get_reference_type_id method · python · L105-L107 (3 LOC)
models/spatial_rope.py
    def get_reference_type_id(self, reference_name: str) -> int:
        """Look up the ID for ``reference_name``, falling back to the ``'unknown'`` entry."""
        known_types = EEGConfig.REFERENCE_TYPES
        fallback_id = known_types['unknown']
        return known_types.get(reference_name, fallback_id)
forward method · python · L109-L188 (80 LOC)
models/spatial_rope.py
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        reference_type_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply 2D spatial RoPE with reference conditioning.

        Args:
            q: Query tensor of shape (batch, n_heads, n_channels, head_dim)
            k: Key tensor of shape (batch, n_heads, n_channels, head_dim)
            reference_type_ids: Reference type IDs of shape (batch,)
                If None, uses reference type 2 (average) as default.

        Returns:
            q_rotated, k_rotated: Rotated query and key tensors
        """
        batch_size, n_heads, n_channels, head_dim = q.shape
        device = q.device
        dtype = q.dtype

        # Default to average reference if not specified
        if reference_type_ids is None:
            reference_type_ids = torch.full(
                (batch_size,), 2, dtype=torch.long, device=device
            )

        # Get (x
TemporalRoPE class · python · L191-L243 (53 LOC)
models/spatial_rope.py
class TemporalRoPE(nn.Module):
    """
    Standard 1D Rotary Position Embedding for temporal dimension.

    Applied to patch/time positions in the EEG sequence.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos/sin embeddings
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)

    def forward(
    
__init__ method · python · L198-L208 (11 LOC)
models/spatial_rope.py
    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos/sin embeddings
        self._precompute_cache(max_seq_len)
If a scraper extracted this row, it came from Repobility (https://repobility.com)
_precompute_cache method · python · L210-L215 (6 LOC)
models/spatial_rope.py
    def _precompute_cache(self, seq_len: int):
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False)
forward method · python · L217-L243 (27 LOC)
models/spatial_rope.py
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate queries and keys by their temporal position.

        Args:
            q, k: Shape (batch, n_heads, seq_len, head_dim)
            seq_len: Number of positions to use; when omitted it is read
                from dimension 2 of q.

        Returns:
            q_rotated, k_rotated: The result of apply_rotary_pos_emb on
            q and k with the cached cos/sin tables.
        """
        length = q.shape[2] if seq_len is None else seq_len

        # Grow the cached tables when the request exceeds what is stored.
        if self.cos_cached.shape[2] < length:
            self._precompute_cache(length)

        # Slice to the requested positions and match the query dtype.
        target_dtype = q.dtype
        cos = self.cos_cached[:, :, :length, :].to(target_dtype)
        sin = self.sin_cached[:, :, :length, :].to(target_dtype)

        return apply_rotary_pos_emb(q, k, cos, sin)
SpatialTemporalRoPE class · python · L246-L349 (104 LOC)
models/spatial_rope.py
class SpatialTemporalRoPE(nn.Module):
    """
    Combined Spatial (2D) and Temporal (1D) RoPE for EEG.

    For a sequence where each position represents a specific (channel, time) pair,
    this applies both spatial and temporal rotary embeddings.

    The key insight is that EEG has two positional dimensions:
    1. Spatial: Which channel (electrode position on scalp)
    2. Temporal: When in the recording

    This module handles the case where the sequence is structured as:
    [ch0_t0, ch1_t0, ..., ch18_t0, ch0_t1, ch1_t1, ..., ch18_t1, ...]
    i.e., channels interleaved within time steps.
    """

    def __init__(
        self,
        head_dim: int,
        n_reference_types: int = 7,
        max_seq_len: int = 2048,
        base: float = 10000.0,
        spatial_weight: float = 0.5,  # Balance between spatial and temporal
    ):
        super().__init__()
        self.head_dim = head_dim
        self.spatial_weight = spatial_weight

        # Split head_dim between spatial a
‹ prevpage 2 / 4next ›