← back to drewlinsley__tmp_bs_mcp

Function bodies 170 total

All specs Real LLM only Function bodies
__init__ method · python · L262-L297 (36 LOC)
models/spatial_rope.py
    def __init__(
        self,
        head_dim: int,
        n_reference_types: int = 7,
        max_seq_len: int = 2048,
        base: float = 10000.0,
        spatial_weight: float = 0.5,  # Balance between spatial and temporal
    ):
        super().__init__()
        self.head_dim = head_dim
        self.spatial_weight = spatial_weight

        # Split head_dim between spatial and temporal
        # Each gets half, then split further for sin/cos
        self.spatial_dim = head_dim // 2
        self.temporal_dim = head_dim - self.spatial_dim

        # Ensure dimensions are even for rotation
        if self.spatial_dim % 4 != 0:
            # Adjust to make spatial_dim divisible by 4
            self.spatial_dim = (self.spatial_dim // 4) * 4
            self.temporal_dim = head_dim - self.spatial_dim

        if self.spatial_dim > 0:
            self.spatial_rope = Spatial2DRoPE(
                head_dim=self.spatial_dim,
                n_reference_types=n_reference_types,
      
forward method · python · L299-L349 (51 LOC)
models/spatial_rope.py
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        channel_indices: torch.Tensor,
        time_indices: torch.Tensor,
        reference_type_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply combined spatial and temporal RoPE.

        Args:
            q, k: Shape (batch, n_heads, seq_len, head_dim)
            channel_indices: Which channel each position corresponds to (batch, seq_len)
            time_indices: Which time step each position corresponds to (batch, seq_len)
            reference_type_ids: Reference type IDs (batch,)

        Returns:
            q_rotated, k_rotated: Rotated tensors
        """
        batch_size, n_heads, seq_len, head_dim = q.shape

        # Split into spatial and temporal parts
        q_spatial = q[..., :self.spatial_dim]
        q_temporal = q[..., self.spatial_dim:]
        k_spatial = k[..., :self.spatial_dim]
        k_temporal = k[..., self.spatial_d
SpatialAttention class · python · L352-L447 (96 LOC)
models/spatial_rope.py
class SpatialAttention(nn.Module):
    """
    Multi-Head Attention with 2D Spatial RoPE for EEG.

    This attention operates on channels as positions (not time patches).
    Each channel has a 2D position on the scalp, and the attention learns
    spatial relationships modulated by the reference type.

    Use this for cross-channel attention within a time window.
    """

    def __init__(
        self,
        hidden_dim: int,
        n_heads: int,
        n_reference_types: int = 7,
        dropout: float = 0.1,
        use_qk_norm: bool = True,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        self.scale = self.head_dim ** -0.5

        if hidden_dim % n_heads != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})")

        # QKV projection
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_p
__init__ method · python · L363-L396 (34 LOC)
models/spatial_rope.py
    def __init__(
        self,
        hidden_dim: int,
        n_heads: int,
        n_reference_types: int = 7,
        dropout: float = 0.1,
        use_qk_norm: bool = True,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        self.scale = self.head_dim ** -0.5

        if hidden_dim % n_heads != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by n_heads ({n_heads})")

        # QKV projection
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # 2D Spatial RoPE
        self.spatial_rope = Spatial2DRoPE(
            head_dim=self.head_dim,
            n_reference_types=n_reference_types,
        )

        # QK normalization
        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_no
forward method · python · L398-L447 (50 LOC)
models/spatial_rope.py
    def forward(
        self,
        x: torch.Tensor,
        reference_type_ids: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply spatial attention.

        Args:
            x: Input tensor of shape (batch, n_channels, hidden_dim)
               where n_channels = 19 for standard EEG
            reference_type_ids: Reference type IDs of shape (batch,)
            mask: Optional attention mask

        Returns:
            Output tensor of shape (batch, n_channels, hidden_dim)
        """
        B, N, D = x.shape

        # QKV projection
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)  # Each: (B, N, n_heads, head_dim)
        q = q.transpose(1, 2)  # (B, n_heads, N, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # QK normalization
        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(
SpatialTransformerBlock class · python · L450-L506 (57 LOC)
models/spatial_rope.py
class SpatialTransformerBlock(nn.Module):
    """
    Transformer block with 2D Spatial RoPE for EEG.

    Pre-norm architecture with RMSNorm and SwiGLU FFN.
    """

    def __init__(
        self,
        hidden_dim: int,
        n_heads: int,
        n_reference_types: int = 7,
        dropout: float = 0.1,
        use_qk_norm: bool = True,
    ):
        super().__init__()

        self.norm1 = RMSNorm(hidden_dim)
        self.attn = SpatialAttention(
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_reference_types=n_reference_types,
            dropout=dropout,
            use_qk_norm=use_qk_norm,
        )

        self.norm2 = RMSNorm(hidden_dim)

        # SwiGLU FFN
        hidden_ff = hidden_dim * 4
        self.w1 = nn.Linear(hidden_dim, hidden_ff, bias=False)
        self.w2 = nn.Linear(hidden_ff, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, hidden_ff, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(
  
__init__ method · python · L457-L483 (27 LOC)
models/spatial_rope.py
    def __init__(
        self,
        hidden_dim: int,
        n_heads: int,
        n_reference_types: int = 7,
        dropout: float = 0.1,
        use_qk_norm: bool = True,
    ):
        """
        Build the sub-layers of one spatial transformer block.

        Args:
            hidden_dim: Width of the block's input and output features.
            n_heads: Number of attention heads, forwarded to SpatialAttention.
            n_reference_types: Number of reference types, forwarded to
                SpatialAttention.
            dropout: Forwarded to SpatialAttention and also used as the
                probability of the dropout applied to the FFN output.
            use_qk_norm: Forwarded to SpatialAttention.
        """
        super().__init__()

        # Normalization and attention for the first residual branch.
        self.norm1 = RMSNorm(hidden_dim)
        self.attn = SpatialAttention(
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            n_reference_types=n_reference_types,
            dropout=dropout,
            use_qk_norm=use_qk_norm,
        )

        # Normalization for the second (feed-forward) residual branch.
        self.norm2 = RMSNorm(hidden_dim)

        # SwiGLU FFN: w1 and w3 expand to 4x hidden_dim, w2 projects back down.
        # All three projections are bias-free.
        hidden_ff = hidden_dim * 4
        self.w1 = nn.Linear(hidden_dim, hidden_ff, bias=False)
        self.w2 = nn.Linear(hidden_ff, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, hidden_ff, bias=False)
        self.dropout = nn.Dropout(dropout)
Want this analysis on your repo? https://repobility.com/scan/
forward method · python · L485-L506 (22 LOC)
models/spatial_rope.py
    def forward(
        self,
        x: torch.Tensor,
        reference_type_ids: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the block: attention residual branch, then SwiGLU FFN residual branch.

        Args:
            x: Input of shape (batch, n_channels, hidden_dim)
            reference_type_ids: Reference type IDs of shape (batch,)
            mask: Optional attention mask

        Returns:
            The sum of the input and both residual branches.
        """
        # Attention branch: normalize first, then add the result back.
        attn_in = self.norm1(x)
        x = x + self.attn(attn_in, reference_type_ids, mask)

        # Feed-forward branch: SiLU-gated product of two projections,
        # projected back down by w2 and passed through dropout.
        ffn_in = self.norm2(x)
        gate = F.silu(self.w1(ffn_in))
        value = self.w3(ffn_in)
        ffn_out = self.w2(gate * value)
        return x + self.dropout(ffn_out)
SpatialChannelEncoder class · python · L509-L598 (90 LOC)
models/spatial_rope.py
class SpatialChannelEncoder(nn.Module):
    """
    Encode EEG channels with spatial awareness.

    This encoder treats each channel as a position with 2D coordinates.
    It processes each time window across all channels, learning spatial
    relationships that are modulated by the reference type.

    Input: (batch, n_channels, time_samples)
    Output: (batch, n_windows, hidden_dim)
    """

    def __init__(
        self,
        n_channels: int = 19,
        hidden_dim: int = 768,
        n_heads: int = 12,
        n_layers: int = 2,  # Fewer layers for spatial processing
        n_reference_types: int = 7,
        window_size: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.window_size = window_size

        # Per-channel embedding (window_size -> hidden_dim)
        self.channel_embed = nn.Linear(window_size, hidden_dim)

        # Spatial transformer blocks
       
__init__ method · python · L521-L555 (35 LOC)
models/spatial_rope.py
    def __init__(
        self,
        n_channels: int = 19,
        hidden_dim: int = 768,
        n_heads: int = 12,
        n_layers: int = 2,  # Fewer layers for spatial processing
        n_reference_types: int = 7,
        window_size: int = 32,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.window_size = window_size

        # Per-channel embedding (window_size -> hidden_dim)
        self.channel_embed = nn.Linear(window_size, hidden_dim)

        # Spatial transformer blocks
        self.spatial_blocks = nn.ModuleList([
            SpatialTransformerBlock(
                hidden_dim=hidden_dim,
                n_heads=n_heads,
                n_reference_types=n_reference_types,
                dropout=dropout,
            )
            for _ in range(n_layers)
        ])

        # Pool across channels to get window representation
        self.channel_pool = nn.Sequential(
forward method · python · L557-L598 (42 LOC)
models/spatial_rope.py
    def forward(
        self,
        x: torch.Tensor,
        reference_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Encode EEG with spatial awareness.

        Args:
            x: Input EEG of shape (batch, n_channels, time_samples)
            reference_type_ids: Reference type IDs of shape (batch,)

        Returns:
            Encoded representation of shape (batch, n_windows, hidden_dim)
        """
        batch_size, n_channels, total_samples = x.shape
        n_windows = total_samples // self.window_size

        # Reshape to windows: (batch, n_channels, n_windows, window_size)
        x = x[:, :, :n_windows * self.window_size]
        x = x.view(batch_size, n_channels, n_windows, self.window_size)

        # Process each window with spatial attention
        outputs = []
        for w in range(n_windows):
            # Get window: (batch, n_channels, window_size)
            window = x[:, :, w, :]

            # Embed each channel: (ba
get_reference_type_tensor function · python · L601-L620 (20 LOC)
models/spatial_rope.py
def get_reference_type_tensor(
    reference_names: list,
    device: torch.device = None,
) -> torch.Tensor:
    """
    Map reference names to their integer type IDs.

    Names missing from EEGConfig.REFERENCE_TYPES are given the ID stored
    under the 'unknown' key.

    Args:
        reference_names: List of reference name strings (e.g., ['A1-A2', 'FCz', 'average'])
        device: Target device; None leaves the tensor on the default device

    Returns:
        Long tensor of shape (batch,) with one reference type ID per name
    """
    type_table = EEGConfig.REFERENCE_TYPES
    type_ids = [
        type_table.get(ref_name, type_table['unknown'])
        for ref_name in reference_names
    ]
    # Passing device=None to torch.tensor selects the default device.
    return torch.tensor(type_ids, dtype=torch.long, device=device)
VariableSpatial2DRoPE class · python · L28-L105 (78 LOC)
models/variable_channel_embedding.py
class VariableSpatial2DRoPE(nn.Module):
    """2D RoPE that takes channel coordinates as input (not fixed buffer).

    Unlike Spatial2DRoPE which has a fixed 19-channel position buffer,
    this version accepts (x, y) coordinates per channel per sample,
    enabling variable channel counts across the batch.
    """

    def __init__(self, head_dim: int, n_reference_types: int = 7,
                 base: float = 10000.0, use_reference_conditioning: bool = True):
        super().__init__()
        self.head_dim = head_dim
        self.use_reference_conditioning = use_reference_conditioning

        quarter_dim = head_dim // 4
        if head_dim % 4 != 0:
            raise ValueError(f"head_dim ({head_dim}) must be divisible by 4 for 2D RoPE")

        inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, dtype=torch.float32) / quarter_dim))
        self.register_buffer('inv_freq', inv_freq)

        if use_reference_conditioning:
            self.reference_freq_scale = nn.Parameter(t
__init__ method · python · L36-L51 (16 LOC)
models/variable_channel_embedding.py
    def __init__(self, head_dim: int, n_reference_types: int = 7,
                 base: float = 10000.0, use_reference_conditioning: bool = True):
        """Set up the frequency buffer and optional per-reference parameters.

        Args:
            head_dim: Per-head feature size; must be divisible by 4.
            n_reference_types: Number of rows in the per-reference parameter
                tables.
            base: Base of the geometric frequency schedule.
            use_reference_conditioning: When True, create the learnable
                per-reference parameters; when False they are not created.

        Raises:
            ValueError: If ``head_dim`` is not divisible by 4.
        """
        super().__init__()
        self.head_dim = head_dim
        self.use_reference_conditioning = use_reference_conditioning

        quarter_dim = head_dim // 4
        if head_dim % 4 != 0:
            raise ValueError(f"head_dim ({head_dim}) must be divisible by 4 for 2D RoPE")

        # quarter_dim inverse frequencies: base ** -(i / quarter_dim) for i in [0, quarter_dim).
        inv_freq = 1.0 / (base ** (torch.arange(0, quarter_dim, dtype=torch.float32) / quarter_dim))
        # Registered as a buffer, so it is not a trainable parameter.
        self.register_buffer('inv_freq', inv_freq)

        if use_reference_conditioning:
            # One row per reference type, two columns, initialized to 1.0.
            self.reference_freq_scale = nn.Parameter(torch.ones(n_reference_types, 2))
            # One head_dim-sized vector per reference type, initialized to 0.
            self.reference_rotation_offset = nn.Parameter(torch.zeros(n_reference_types, head_dim))
forward method · python · L53-L105 (53 LOC)
models/variable_channel_embedding.py
    def forward(self, q, k, channel_coords, reference_type_ids=None):
        """
        Args:
            q, k: (batch, n_heads, n_channels, head_dim)
            channel_coords: (batch, n_channels, 2) — (x, y) per channel
            reference_type_ids: (batch,) — reference type per sample
        """
        batch_size, n_heads, n_channels, head_dim = q.shape
        device = q.device
        dtype = q.dtype

        if reference_type_ids is None:
            reference_type_ids = torch.full((batch_size,), 2, dtype=torch.long, device=device)

        x_pos = channel_coords[:, :, 0]  # (batch, n_channels)
        y_pos = channel_coords[:, :, 1]  # (batch, n_channels)

        if self.use_reference_conditioning:
            freq_scale = self.reference_freq_scale[reference_type_ids]  # (batch, 2)
            freq_scale_x = freq_scale[:, 0:1]  # (batch, 1)
            freq_scale_y = freq_scale[:, 1:2]  # (batch, 1)
        else:
            freq_scale_x = torch.ones(batch_size, 1, devic
All rows above produced by Repobility · https://repobility.com
VariableSpatialPatchEmbedding class · python · L108-L218 (111 LOC)
models/variable_channel_embedding.py
class VariableSpatialPatchEmbedding(nn.Module):
    """Spatial-aware patch embedding supporting variable channel counts.

    Input: (B, N_patches, max_channels, patch_size) with channel_mask (B, max_channels)
    Output: (B, N_patches, hidden_dim)

    Channels are identified by their (x, y) scalp coordinates via channel_coords.
    Attention is computed only over present channels (masked by channel_mask).
    """

    def __init__(
        self,
        max_channels: int = 64,
        patch_size: int = 32,
        hidden_dim: int = 768,
        n_heads: int = 4,
        n_reference_types: int = 7,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.max_channels = max_channels
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads
        self.n_heads = n_heads

        self.channel_embed = nn.Linear(patch_size, hidden_dim)
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
    
__init__ method · python · L118-L147 (30 LOC)
models/variable_channel_embedding.py
    def __init__(
        self,
        max_channels: int = 64,
        patch_size: int = 32,
        hidden_dim: int = 768,
        n_heads: int = 4,
        n_reference_types: int = 7,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.max_channels = max_channels
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads
        self.n_heads = n_heads

        self.channel_embed = nn.Linear(patch_size, hidden_dim)
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.spatial_rope = VariableSpatial2DRoPE(
            head_dim=self.head_dim,
            n_reference_types=n_reference_types,
            use_reference_conditioning=True,
        )

        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)
        self.norm = RMSNorm(hidden_dim)
        self.dropout_p = dropou
forward method · python · L149-L218 (70 LOC)
models/variable_channel_embedding.py
    def forward(
        self,
        patches: torch.Tensor,
        channel_mask: Optional[torch.Tensor] = None,
        channel_coords: Optional[torch.Tensor] = None,
        reference_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            patches: (B, N, C, P) — C can be any number up to max_channels
            channel_mask: (B, C) — 1 for present, 0 for absent
            channel_coords: (B, C, 2) — (x, y) scalp coordinates per channel
            reference_type_ids: (B,) — reference type per sample
        """
        B, N, C, P = patches.shape
        device = patches.device

        # Embed each channel's patch
        x = self.channel_embed(patches)  # (B, N, C, D)
        x = x.view(B * N, C, self.hidden_dim)

        # QKV projection
        qkv = self.qkv(x).reshape(B * N, C, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)  # (B*N, heads, C, head_dim)
        k = k.transpose(1, 2)
get_channel_coords_tensor function · python · L221-L239 (19 LOC)
models/variable_channel_embedding.py
def get_channel_coords_tensor(channel_names: List[str], device='cpu') -> torch.Tensor:
    """Convert channel names to an (N, 2) coordinate tensor.

    Each channel is looked up in EEGConfig.CHANNEL_COORDS. If the name is
    not found, a fixed set of aliases (T7->T3, T8->T4, P7->T5, P8->T6) is
    tried. Channels that are still unknown get (0, 0) coordinates.

    Args:
        channel_names: Channel name strings, one per output row.
        device: Device the tensor is created on.

    Returns:
        float32 tensor of shape (N, 2). An empty ``channel_names`` yields
        shape (0, 2), so the trailing coordinate axis is always present.
    """
    known_coords = EEGConfig.CHANNEL_COORDS
    # Built once per call instead of once per channel.
    aliases = {'T7': 'T3', 'T8': 'T4', 'P7': 'T5', 'P8': 'T6'}
    unknown_position = (0.0, 0.0)

    coords = []
    for ch in channel_names:
        # Prefer the exact name; fall back to the alias only when it is missing.
        lookup_name = ch if ch in known_coords else aliases.get(ch, ch)
        coords.append(known_coords.get(lookup_name, unknown_position))

    if not coords:
        # torch.tensor([]) would have shape (0,), dropping the coordinate axis.
        return torch.empty((0, 2), dtype=torch.float32, device=device)
    return torch.tensor(coords, dtype=torch.float32, device=device)
compute_band_powers function · python · L23-L33 (11 LOC)
server/eeg_features.py
def compute_band_powers(data_1ch: np.ndarray, srate: int = SRATE) -> dict:
    """Return each band's share of the total Welch PSD for one channel.

    The PSD is estimated with segments of at most 512 samples. If the summed
    PSD is below 1e-12, every band is reported as 0.0. Band edges come from
    BANDS and are applied as lo <= f < hi.
    """
    segment_len = min(512, len(data_1ch))
    freqs, psd = signal.welch(data_1ch, fs=srate, nperseg=segment_len)
    total_power = np.sum(psd)

    # Near-silent channel: avoid dividing by a vanishing total.
    if total_power < 1e-12:
        return dict.fromkeys(BANDS, 0.0)

    return {
        band: float(np.sum(psd[(freqs >= lo) & (freqs < hi)]) / total_power)
        for band, (lo, hi) in BANDS.items()
    }
compute_hjorth function · python · L36-L45 (10 LOC)
server/eeg_features.py
def compute_hjorth(data_1ch: np.ndarray) -> dict:
    """Compute Hjorth activity, mobility, complexity.

    activity   = var(x)
    mobility   = sqrt(var(dx) / var(x))
    complexity = sqrt(var(ddx) / var(dx)) / mobility

    Denominators are floored at 1e-12. When a derivative has no samples
    (signals shorter than 2 or 3 samples), its variance is taken as 0.0, so
    degenerate inputs give 0.0 for the affected measures.

    Args:
        data_1ch: 1-D samples of a single channel.

    Returns:
        Dict with float values under 'activity', 'mobility', 'complexity'.
    """
    x = np.asarray(data_1ch)
    d1 = np.diff(x)
    d2 = np.diff(d1)

    # An empty array has no variance. Using 0.0 (not a tiny epsilon) keeps the
    # ratios below at 0 instead of blowing up against the 1e-12 floors, and
    # avoids the NaN that np.var returns for empty input.
    activity = float(np.var(x)) if x.size > 0 else 0.0
    var_d1 = float(np.var(d1)) if d1.size > 0 else 0.0
    var_d2 = float(np.var(d2)) if d2.size > 0 else 0.0

    mobility = float(np.sqrt(var_d1 / max(activity, 1e-12)))
    complexity = float(np.sqrt(var_d2 / max(var_d1, 1e-12)) / max(mobility, 1e-12))
    return {'activity': activity, 'mobility': mobility, 'complexity': complexity}
compute_spectral_ratios function · python · L48-L60 (13 LOC)
server/eeg_features.py
def compute_spectral_ratios(band_powers: dict) -> dict:
    """Derive four band-power ratios from a band-power mapping.

    Bands missing from ``band_powers`` count as 0, and every denominator is
    floored at 1e-8 so the division is always defined.
    """
    # (output key, numerator band, denominator band)
    ratio_spec = (
        ('theta_alpha_ratio', 'theta', 'alpha'),
        ('theta_beta_ratio', 'theta', 'beta'),
        ('delta_alpha_ratio', 'delta', 'alpha'),
        ('alpha_beta_ratio', 'alpha', 'beta'),
    )
    return {
        key: band_powers.get(numer, 0) / max(band_powers.get(denom, 0), 1e-8)
        for key, numer, denom in ratio_spec
    }
compute_statistical_features function · python · L63-L72 (10 LOC)
server/eeg_features.py
def compute_statistical_features(data_1ch: np.ndarray) -> dict:
    """Summarize one channel with six scalar statistics, all returned as floats."""
    stats = {}
    stats['mean'] = float(np.mean(data_1ch))
    stats['std'] = float(np.std(data_1ch))
    stats['skewness'] = float(_skewness(data_1ch))
    stats['kurtosis'] = float(_kurtosis(data_1ch))
    stats['peak_to_peak'] = float(np.ptp(data_1ch))
    stats['zero_crossing_rate'] = float(_zero_crossing_rate(data_1ch))
    return stats
Repobility · MCP-ready · https://repobility.com
_skewness function · python · L75-L79 (5 LOC)
server/eeg_features.py
def _skewness(x):
    """Return the mean cubed z-score of x, or 0.0 when its std is below 1e-12."""
    center = np.mean(x)
    spread = np.std(x)
    if spread < 1e-12:
        # Near-constant input: standardizing would divide by ~0.
        return 0.0
    z_scores = (x - center) / spread
    return np.mean(z_scores ** 3)
_kurtosis function · python · L82-L86 (5 LOC)
server/eeg_features.py
def _kurtosis(x):
    """Return the mean fourth-power z-score of x minus 3, or 0.0 when std is below 1e-12."""
    center = np.mean(x)
    spread = np.std(x)
    if spread < 1e-12:
        # Near-constant input: standardizing would divide by ~0.
        return 0.0
    z_scores = (x - center) / spread
    return np.mean(z_scores ** 4) - 3
extract_features function · python · L93-L147 (55 LOC)
server/eeg_features.py
def extract_features(eeg_data: np.ndarray, channel_mask: Optional[np.ndarray] = None,
                     srate: int = SRATE) -> np.ndarray:
    """
    Extract traditional EEG features from multi-channel data.

    Args:
        eeg_data: (n_channels, n_samples)
        channel_mask: (n_channels,) — 1 for active, 0 for padding
        srate: sample rate

    Returns:
        features: (n_features,) — concatenated per-channel features averaged across active channels
    """
    n_ch = eeg_data.shape[0]

    if channel_mask is None:
        channel_mask = np.ones(n_ch, dtype=np.float32)

    active_channels = np.where(channel_mask > 0.5)[0]
    if len(active_channels) == 0:
        active_channels = np.arange(min(n_ch, 19))

    all_features = []

    for ch in active_channels:
        x = eeg_data[ch]
        if np.std(x) < 1e-8:
            continue

        bands = compute_band_powers(x, srate)
        hjorth = compute_hjorth(x)
        ratios = compute_spectral_ratios(bands)
      
EEGInferenceEngine class · python · L109-L286 (178 LOC)
server/inference.py
class EEGInferenceEngine:
    """Wraps the frozen encoder for inference."""

    def __init__(self, checkpoint_path: str, device: str = 'cuda'):
        self.device = device
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        config = ckpt.get('config', {})
        self.max_channels = config.get('max_channels', config.get('n_channels', 19))
        self.hidden_dim = config.get('hidden_dim', 512)

        self.model = EEGEncoderV3(
            n_channels=config.get('n_channels', 19),
            hidden_dim=self.hidden_dim,
            n_heads=config.get('n_heads', 8),
            n_kv_heads=config.get('n_kv_heads', 2),
            n_layers=config.get('n_layers', 8),
            patch_size=config.get('patch_size', 32),
            dropout=0.0,
            ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
            rope_base=config.get('rope_base', 500_000.0),
            n_registers=config.get('n_registers', 8),
            use_spatial_rope=
__init__ method · python · L112-L145 (34 LOC)
server/inference.py
    def __init__(self, checkpoint_path: str, device: str = 'cuda'):
        self.device = device
        ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        config = ckpt.get('config', {})
        self.max_channels = config.get('max_channels', config.get('n_channels', 19))
        self.hidden_dim = config.get('hidden_dim', 512)

        self.model = EEGEncoderV3(
            n_channels=config.get('n_channels', 19),
            hidden_dim=self.hidden_dim,
            n_heads=config.get('n_heads', 8),
            n_kv_heads=config.get('n_kv_heads', 2),
            n_layers=config.get('n_layers', 8),
            patch_size=config.get('patch_size', 32),
            dropout=0.0,
            ffn_multiplier=config.get('ffn_multiplier', 8.0 / 3.0),
            rope_base=config.get('rope_base', 500_000.0),
            n_registers=config.get('n_registers', 8),
            use_spatial_rope=config.get('use_spatial_rope', True),
            use_modality_tokens=config.
encode method · python · L148-L204 (57 LOC)
server/inference.py
    def encode(self, eeg_data: np.ndarray, channel_mask: np.ndarray,
               channel_coords: np.ndarray, metadata: Dict) -> np.ndarray:
        """
        Encode full EEG recording → single embedding vector.

        Args:
            eeg_data: (max_channels, n_samples) float32
            channel_mask: (max_channels,) float32
            channel_coords: (max_channels, 2) float32
            metadata: dict with recording info

        Returns:
            embedding: (hidden_dim,) float32
        """
        device = self.device

        eeg = torch.from_numpy(eeg_data).unsqueeze(0).to(device)
        mask = torch.from_numpy(channel_mask).unsqueeze(0).to(device)
        coords = torch.from_numpy(channel_coords).unsqueeze(0).to(device)

        # Zero out inactive channels
        eeg = eeg * mask.unsqueeze(-1)

        # Trim to patch boundary
        ps = 32
        T = eeg.shape[2]
        eeg = eeg[:, :, :T // ps * ps]
        if eeg.shape[2] < ps:
            raise ValueErro
map_query_to_tasks method · python · L206-L219 (14 LOC)
server/inference.py
    def map_query_to_tasks(self, query: str) -> List[str]:
        """Return the sorted task names whose keywords occur in the query.

        Matching is a case-insensitive substring test of each QUERY_TASK_MAP
        keyword against the query. If nothing matches, the tasks listed under
        the 'general' keyword are returned instead.
        """
        haystack = query.lower()
        matched = {
            task
            for keyword, task_list in QUERY_TASK_MAP.items()
            if keyword in haystack
            for task in task_list
        }

        if not matched:
            # Default to general screening
            matched = set(QUERY_TASK_MAP['general'])

        return sorted(matched)
predict method · python · L221-L245 (25 LOC)
server/inference.py
    def predict(self, embedding: np.ndarray, tasks: List[str]) -> Dict:
        """
        Run predictions using the embedding.

        For now, returns the embedding with task metadata.
        Full SFT prediction will be added in Phase 2.

        Returns dict of {task_name: {probability, classification, confidence, model_auc}}.
        """
        # Phase 1: Return encoder-based analysis (embedding similarity to training cohorts)
        # Phase 2: Will add SFT model for proper Yes/No probability per task
        results = {}
        for task in tasks:
            info = TASK_INFO.get(task, {})
            results[task] = {
                'probability': None,  # Phase 2: SFT prediction
                'classification': 'unavailable',
                'confidence': 0.0,
                'model_auc': info.get('auc', 0.0),
                'n_test': info.get('n_test', 0),
                'description': info.get('description', task),
                'name': info.get('name', task),
     
All rows scored by the Repobility analyzer (https://repobility.com)
get_supported_conditions method · python · L247-L252 (6 LOC)
server/inference.py
    def get_supported_conditions(self) -> List[Dict]:
        """List every TASK_INFO entry as a dict, highest 'auc' first.

        Each record is the task's info dict with the task name added under 'task'.
        """
        ranked = sorted(TASK_INFO.items(), key=lambda entry: -entry[1]['auc'])
        conditions = []
        for task, info in ranked:
            record = {'task': task}
            record.update(info)
            conditions.append(record)
        return conditions
get_model_card method · python · L254-L286 (33 LOC)
server/inference.py
    def get_model_card(self) -> str:
        """Return model card text."""
        return f"""# NeoCog EEG Encoder — Model Card

## Architecture
- **Type**: EEGEncoderV3 (Transformer with SwiGLU, GQA, Spatial RoPE)
- **Parameters**: {sum(p.numel() for p in self.model.parameters()):,}
- **Hidden dim**: {self.hidden_dim}
- **Max channels**: {self.max_channels}
- **Patch size**: 32 samples (160ms at 200Hz)
- **Context**: Up to 4 minutes per chunk (1500 patches)

## Training
- **Objective**: Self-supervised next-patch prediction + spectral feature prediction
- **Data**: ~4000 EEG recordings from 7 clinical datasets
- **Datasets**: AD_EEG, CAUEEG, DS004504, PEARL, TD_BRAIN, READTBI, DORTMUND
- **Channel range**: 4 (portable) to 64 (HD-EEG) channels
- **Preprocessing**: 200Hz resample, robust z-score normalization

## Validation
- Evaluated on held-out patients (no train/test patient overlap)
- XGBoost + TabPFN classifiers on frozen encoder embeddings
- Best AUC: 0.95 (Vascular Dementia), 0.
lifespan function · python · L93-L95 (3 LOC)
server/mcp_server.py
async def lifespan(app: FastAPI):
    """Hold ``session_manager.run()`` open across the single ``yield``.

    The context is entered before yielding and exited when the generator is
    resumed or closed. ``app`` is accepted but not used here.

    NOTE(review): presumably wrapped as an async context manager and passed
    to the FastAPI app as its lifespan handler — the decorator and the
    registration are not visible in this excerpt; confirm.
    """
    async with session_manager.run():
        yield
favicon function · python · L114-L116 (3 LOC)
server/mcp_server.py
async def favicon():
    """Return the file at ``_ICON_PATH`` as an ``image/png`` file response."""
    # Imported inside the handler rather than at module level.
    from fastapi.responses import FileResponse
    return FileResponse(_ICON_PATH, media_type="image/png")
upload_eeg function · python · L191-L201 (11 LOC)
server/mcp_server.py
async def upload_eeg(file: UploadFile = File(...)):
    """Store an uploaded EEG file under UPLOAD_DIR and return its upload ID.

    The client-supplied filename is reduced to its final path component
    before use, so directory parts in the name cannot influence where the
    file is written. A missing filename is treated as an empty name and
    rejected by the extension check.

    Returns:
        On success, a dict with 'upload_id' (first 8 characters of a UUID4),
        'filename' (the sanitized name) and 'size_mb'. If the extension is
        not in SUPPORTED_EXTENSIONS, a 400 JSONResponse with a 'detail'
        message.
    """
    # The filename comes from the client: it may be absent and may contain
    # path separators (including Windows-style backslashes).
    raw_name = file.filename or ""
    safe_name = Path(raw_name.replace("\\", "/")).name

    ext = Path(safe_name).suffix.lower()
    if ext not in SUPPORTED_EXTENSIONS:
        return JSONResponse(status_code=400,
            content={"detail": f"Unsupported format '{ext}'. Use: {', '.join(SUPPORTED_EXTENSIONS)}"})

    upload_id = str(uuid.uuid4())[:8]
    # The "<id>_" prefix is how the stored file is located again by upload ID.
    dest = UPLOAD_DIR / f"{upload_id}_{safe_name}"
    with open(dest, "wb") as f_out:
        shutil.copyfileobj(file.file, f_out)
    return {"upload_id": upload_id, "filename": safe_name,
            "size_mb": round(dest.stat().st_size / 1e6, 1)}
find_upload function · python · L204-L208 (5 LOC)
server/mcp_server.py
def find_upload(upload_id: str) -> Optional[Path]:
    """Return the first entry in UPLOAD_DIR named "<upload_id>_...", or None."""
    matches = (
        entry
        for entry in UPLOAD_DIR.iterdir()
        if entry.name.startswith(upload_id + "_")
    )
    return next(matches, None)
get_or_compute_embedding function · python · L211-L230 (20 LOC)
server/mcp_server.py
def get_or_compute_embedding(upload_id: str):
    """Return ``(embedding, None)`` for an upload, or ``(None, error_message)``.

    A cached embedding is returned directly. Otherwise the uploaded file is
    preprocessed and encoded, and the embedding, metadata and extracted
    features are each stored in their caches before returning. Any exception
    during processing is converted into an error message.
    """
    if upload_id in embedding_cache:
        return embedding_cache[upload_id], None

    upload_path = find_upload(upload_id)
    if upload_path is None:
        return None, f"Upload ID '{upload_id}' not found. It may have expired. Please re-upload at /upload."

    try:
        eeg, ch_mask, ch_coords, meta = preprocess_eeg(upload_path)
        vector = engine.encode(eeg, ch_mask, ch_coords, meta)
        embedding_cache[upload_id] = vector
        metadata_cache[upload_id] = meta
        # The raw data is already loaded, so fill the feature cache too.
        feature_cache[upload_id] = extract_features(eeg, ch_mask)
    except Exception as exc:
        return None, f"Error processing EEG: {str(exc)}"
    return vector, None
get_or_compute_features function · python · L233-L248 (16 LOC)
server/mcp_server.py
def get_or_compute_features(upload_id: str):
    """Return ``(features, None)`` from cache or a fresh extraction.

    On a cache miss the uploaded file is located, preprocessed, and passed to
    ``extract_features``; the result is stored in ``feature_cache``. When the
    upload is missing or any step raises, ``(None, message)`` is returned.
    """
    if upload_id in feature_cache:
        return feature_cache[upload_id], None

    source_path = find_upload(upload_id)
    if source_path is None:
        return None, f"Upload ID '{upload_id}' not found."

    try:
        # Only the signal and the channel mask are needed for features.
        signal, mask, _coords, _meta = preprocess_eeg(source_path)
        computed = extract_features(signal, mask)
        feature_cache[upload_id] = computed
    except Exception as exc:
        return None, f"Error extracting features: {str(exc)}"
    return computed, None
Want this analysis on your repo? https://repobility.com/scan/
run_single_screen function · python · L251-L288 (38 LOC)
server/mcp_server.py
def run_single_screen(task: str, clf_data: dict, embedding: np.ndarray) -> dict:
    """Run a single classifier on an embedding. Returns structured result."""
    model = clf_data['model']
    display_name = clf_data['display_name']

    if clf_data['type'] == 'classification':
        inp = embedding.reshape(1, -1)
        # TabPFN needs PCA transform first
        if clf_data.get('method') == 'tabpfn' and 'pca' in clf_data:
            inp = clf_data['pca'].transform(inp)
        prob = float(model.predict_proba(inp)[0, 1])
        pred = "POSITIVE" if prob > 0.5 else "NEGATIVE"
        auc = clf_data['auc']
        n_test = clf_data['n_test']

        if auc < 0.70:
            confidence = "LOW"
        elif abs(prob - 0.5) < 0.1:
            confidence = "EQUIVOCAL"
        elif auc >= 0.85:
            confidence = "HIGH"
        else:
            confidence = "MODERATE"

        return {
            'task': task, 'display_name': display_name, 'type': 'classification',
          
handle_sse function · python · L301-L303 (3 LOC)
server/mcp_server.py
async def handle_sse(request: Request):
    """Open an SSE connection for the request and run the MCP server over it."""
    connection = sse.connect_sse(request.scope, request.receive, request._send)
    async with connection as streams:
        # The first two stream objects are handed to the server, in order.
        first_stream, second_stream = streams[0], streams[1]
        init_options = mcp_server.create_initialization_options()
        await mcp_server.run(first_stream, second_stream, init_options)
list_tools function · python · L318-L458 (141 LOC)
server/mcp_server.py
async def list_tools():
    tools = []

    # ── Upload via web page ──
    tools.append(Tool(
        name="get_upload_link",
        description=(
            "ALWAYS call this tool FIRST when BrainScope is mentioned. "
            "Returns the upload link to show the user so they can submit their EEG file. "
            "You must call this before doing anything else."
        ),
        inputSchema={"type": "object", "properties": {}},
    ))

    # ── Recording info ──
    tools.append(Tool(
        name="get_recording_info",
        description="Get basic information about an uploaded EEG recording: duration, channels, sample rate, format.",
        inputSchema={
            "type": "object",
            "properties": {
                "upload_id": {"type": "string", "description": "Upload ID from the upload step"},
            },
            "required": ["upload_id"],
        },
    ))

    # ── Run all screens at once ──
    tools.append(Tool(
        name="run_all_screens",
   
call_tool function · python · L462-L493 (32 LOC)
server/mcp_server.py
async def call_tool(name: str, arguments: dict):
    if name == "get_upload_link":
        url = "https://b4b9ezp8vloj5o-8080.proxy.runpod.net/upload"
        return [TextContent(type="text", text=(
            f"Show this link to the user: {url}\n"
            f"Format as: [Upload your EEG recording here]({url})\n"
            f"Tell them to upload their EEG file (.edf/.bdf/.h5/.set) and paste "
            f"the upload_id back. If they don't have an EEG, suggest asking their "
            f"doctor to order one."
        ))]

    if name == "get_recording_info":
        return await handle_recording_info(arguments)

    if name == "run_all_screens":
        return await handle_run_all_screens(arguments)

    if name == "run_baseline_screens":
        return await handle_run_baseline_screens(arguments)

    if name == "search_literature":
        return await handle_search_literature(arguments)

    if name == "get_model_card":
        return [TextContent(type="text", text=engine.get_m
handle_recording_info function · python · L498-L531 (34 LOC)
server/mcp_server.py
async def handle_recording_info(args: dict):
    upload_id = args.get("upload_id")
    if not upload_id:
        return [TextContent(type="text", text="Please provide an upload_id.")]

    filepath = find_upload(upload_id)
    if filepath is None:
        return [TextContent(type="text", text=f"Upload ID '{upload_id}' not found.")]

    try:
        eeg_data, mask, coords, metadata = preprocess_eeg(filepath)
        n_active = int(mask.sum())
        text = (
            f"**Recording: {metadata['filename']}**\n"
            f"- Duration: {metadata['duration_s']:.0f}s ({metadata['duration_s']/60:.1f} min)\n"
            f"- Active EEG channels: {n_active}\n"
            f"- Sample rate: {metadata['srate']} Hz\n"
            f"- Format: {metadata['source_format']}\n"
            f"- Total samples: {metadata['n_samples']:,}\n"
            f"\nRecording quality notes:\n"
        )
        if metadata['duration_s'] < 60:
            text += "- WARNING: Recording is very short (<1 min). Res
handle_single_screen function · python · L534-L545 (12 LOC)
server/mcp_server.py
async def handle_single_screen(task: str, clf_data: dict, args: dict):
    """Run one classifier for the upload named in ``args`` and format the result.

    Returns a single-element list holding a text ``TextContent``: either the
    formatted screening result, a prompt for a missing ``upload_id``, or the
    error message from computing the embedding.
    """
    def _reply(message):
        # Every outcome of this handler is one text content item.
        return [TextContent(type="text", text=message)]

    upload_id = args.get("upload_id")
    if not upload_id:
        return _reply("Please provide an upload_id.")

    embedding, error = get_or_compute_embedding(upload_id)
    if error:
        return _reply(error)

    outcome = run_single_screen(task, clf_data, embedding)
    return _reply(_format_single_result(outcome))
handle_run_all_screens function · python · L548-L580 (33 LOC)
server/mcp_server.py
async def handle_run_all_screens(args: dict):
    upload_id = args.get("upload_id")
    if not upload_id:
        return [TextContent(type="text", text="Please provide an upload_id.")]

    if upload_id in screen_cache:
        return screen_cache[upload_id]

    embedding, error = get_or_compute_embedding(upload_id)
    if error:
        return [TextContent(type="text", text=error)]

    # Run all classifiers in parallel
    loop = asyncio.get_event_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        futures = [
            loop.run_in_executor(pool, run_single_screen, task, clf_data, embedding)
            for task, clf_data in classifiers.items()
        ]
        results = await asyncio.gather(*futures)

    # Generate OpenEvidence-style report with inline citations + visualization JSON
    from server.response_generator import generate_enriched_report
    narrative, viz_json = generate_enriched_report(
        list(results), metadata_cache.get(upload_id)
    )
handle_run_baseline_screens function · python · L583-L650 (68 LOC)
server/mcp_server.py
async def handle_run_baseline_screens(args: dict):
    """Run traditional EEG feature classifiers as baseline comparison."""
    upload_id = args.get("upload_id")
    if not upload_id:
        return [TextContent(type="text", text="Please provide an upload_id.")]

    features, error = get_or_compute_features(upload_id)
    if error:
        return [TextContent(type="text", text=error)]

    if not feature_classifiers:
        return [TextContent(type="text", text="No feature baseline classifiers loaded.")]

    lines = ["# Traditional EEG Feature Analysis (Baseline)\n"]
    lines.append("*Using handcrafted features: band powers (delta/theta/alpha/beta/gamma), "
                 "Hjorth parameters (activity/mobility/complexity), spectral ratios, "
                 "and statistical measures.*\n")

    def _run_feature_clf(task, clf_data):
        model = clf_data['model']
        inp = features.reshape(1, -1)
        if clf_data['type'] == 'classification':
            prob = float(mode
All rows above produced by Repobility · https://repobility.com
handle_search_literature function · python · L653-L693 (41 LOC)
server/mcp_server.py
async def handle_search_literature(args: dict):
    from server.response_generator import CITATIONS

    conditions = args.get("conditions", [])
    if not conditions:
        return [TextContent(type="text", text="Please provide a list of conditions to search for.")]

    lines = ["# Relevant Literature\n"]
    ref_num = 1
    seen = set()

    for condition in conditions:
        cond_lower = condition.lower().strip()
        matched_keys = []
        for key in CITATIONS:
            if key in cond_lower or cond_lower in key:
                matched_keys.append(key)
        if not matched_keys:
            for key in CITATIONS:
                if any(word in cond_lower for word in key.split('_')):
                    matched_keys.append(key)

        for key in matched_keys:
            for ref in CITATIONS[key]:
                if ref['citation'] not in seen:
                    seen.add(ref['citation'])
                    tags_str = " | ".join(f"[{t}]" for t in ref['tags'])
     
_format_single_result function · python · L696-L714 (19 LOC)
server/mcp_server.py
def _format_single_result(result: dict) -> str:
    if result['type'] == 'classification':
        return (
            f"**{result['display_name']}**\n\n"
            f"- Prediction: **{result['prediction']}**\n"
            f"- Probability: {result['probability']:.3f}\n"
            f"- Confidence: {result['confidence']}\n"
            f"- Model AUC: {result['auc']:.2f} (validated on {result['n_test']} held-out patients)\n\n"
            f"*Screening result — not a clinical diagnosis.*"
        )
    else:
        lo, hi = result['valid_range']
        return (
            f"**{result['display_name']}**\n\n"
            f"- Predicted value: **{result['predicted_value']:.1f}**\n"
            f"- Valid range: [{lo:.0f}, {hi:.0f}]\n"
            f"- Model correlation: r={result['pearson_r']:.2f} (validated on {result['n_test']} patients)\n\n"
            f"*Screening result — not a clinical diagnosis.*"
        )
api_list_tools function · python · L720-L723 (4 LOC)
server/mcp_server.py
async def api_list_tools():
    """REST endpoint: list all tools (for stdio proxy).

    Each tool from ``list_tools`` is reduced to a plain dict with its
    ``name``, ``description`` and ``inputSchema``.
    """
    payload = []
    for tool in await list_tools():
        payload.append({
            "name": tool.name,
            "description": tool.description,
            "inputSchema": tool.inputSchema,
        })
    return payload
‹ prev · page 3 / 4 · next ›