▼ code ▼ output | Cell: utils | deps: torch, numpy | 30.61s |
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Simple utilities for running the models."""
import torch

def to_dtype(dtype_str: str):
    """Translate a dtype name into the corresponding torch dtype.

    "float16" and "bfloat16" map to their torch counterparts; every other
    value falls back to torch.float32.
    """
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float16 if dtype_str == "float16" else torch.float32

def tensor_stats(t: torch.Tensor) -> str:
    """Describe a tensor on one line: shape, dtype, device, mean and std."""
    fields = [
        f"shape={tuple(t.shape)}",
        f"dtype={t.dtype}",
        f"device={t.device}",
        f"mean={t.mean().item():.6f}",
        f"std={t.std().item():.6f}",
    ]
    return ", ".join(fields)

def set_seed(seed: int):
    """Seed torch's random generators for repeatable runs.

    When CUDA is available this also seeds the CUDA generators and sets
    cuDNN to deterministic mode with benchmarking turned off.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return

    # Current device first, then every device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
Downloading setuptools (1.1MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading sympy (6.0MiB) Downloading torch (846.8MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading numpy (15.9MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading networkx (1.9MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading triton (148.4MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading sympy Downloading nvidia-nvjitlink-cu12 Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 234ms
▼ code ▼ output | Cell: bench_utils | deps: torch, numpy | 31.57s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
"""Reusable benchmarking utilities for performance testing."""
import time
import numpy as np
from contextlib import contextmanager
from typing import Callable, Dict, Tuple, Any, Optional
import torch

def to_dtype(dtype_str: str):
    """Resolve a dtype name to a torch dtype, defaulting to torch.float32.

    Only "float16" and "bfloat16" are recognized explicitly.
    """
    if dtype_str == "float16":
        resolved = torch.float16
    elif dtype_str == "bfloat16":
        resolved = torch.bfloat16
    else:
        resolved = torch.float32
    return resolved

def _sync(device: str):
    """Synchronize device if CUDA."""
    if device == "cuda":
        torch.cuda.synchronize()

def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
    """Compute comprehensive latency and throughput statistics."""
    lat_ms = np.array([t * 1000.0 for t in times_s])
    lat_ms_sorted = np.sort(lat_ms)
    n = len(lat_ms)

    stats = {
        "avg_ms": np.mean(lat_ms),
        "min_ms": np.min(lat_ms),
        "max_ms": np.max(lat_ms),
        "std_ms": np.std(lat_ms),
        "p50_ms": np.percentile(lat_ms, 50),
        "p95_ms": np.percentile(lat_ms, 95),
        "p99_ms": np.percentile(lat_ms, 99),
        "num_iters": n
    }

    if tokens is not None and n > 0:
        avg_s = np.mean(times_s)
        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])

    return stats

def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
    """Format timing statistics for display."""
    lines = [
        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
        f"Iterations: {stats.get('num_iters', 0)}",
        "\nLatency Statistics:",
        f"  Average: {stats['avg_ms']:.3f} ms",
        f"  Min:     {stats['min_ms']:.3f} ms",
        f"  Max:     {stats['max_ms']:.3f} ms", 
        f"  Std Dev: {stats['std_ms']:.3f} ms",
        "\nPercentiles:",
        f"  P50 (median): {stats['p50_ms']:.3f} ms",
        f"  P95:          {stats['p95_ms']:.3f} ms",
        f"  P99:          {stats['p99_ms']:.3f} ms",
    ]

    if tokens is not None and 'tokens_per_s' in stats:
        lines.extend([
            "\nThroughput:",
            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
        ])

    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    return "\n".join(lines)

def _bench_engine(
    call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype
) -> Tuple[Any, list]:
    """Run ``call`` through a warmup phase and a timed phase.

    Args:
        call: Zero-argument callable to benchmark.
        warmup: Number of untimed warmup calls (negative values mean zero).
        iters: Number of timed calls; at least one is always executed.
        device: Device string or ``torch.device`` the work runs on.
        dtype: Torch dtype; float16/bfloat16 on CUDA enables autocast.

    Returns:
        ``(last, times_s)``: the return value of the final timed call and the
        list of per-iteration durations in seconds.

    Everything runs under ``torch.inference_mode()``. Each timed iteration
    ends with a device synchronization before the clock is read.
    """
    # A torch.device never compares equal to a plain string, so normalize to
    # the device-type string before testing it. Comparing the raw argument to
    # "cuda" silently disables autocast and synchronization for torch.device.
    device_type = torch.device(device).type
    use_autocast = device_type == "cuda" and dtype in (torch.float16, torch.bfloat16)

    def run_once():
        if use_autocast:
            with torch.autocast(device_type="cuda", dtype=dtype):
                return call()
        return call()

    # Warmup phase
    print(f"\nWarming up ({warmup} iterations)...")
    with torch.inference_mode():
        for _ in range(max(0, warmup)):
            run_once()
        _sync(device_type)

    # Measurement phase
    print(f"Benchmarking ({iters} iterations)...")
    total = max(1, iters)
    report_every = max(1, total // 5)
    times_s = []
    last = None
    with torch.inference_mode():
        for i in range(total):
            start = time.perf_counter()
            last = run_once()
            _sync(device_type)
            times_s.append(time.perf_counter() - start)

            # Progress indicator roughly every 20% of iterations. It is based
            # on completed iterations so the percentage and the running
            # average cover the same samples.
            done = i + 1
            if done < total and done % report_every == 0:
                pct = (done / total) * 100
                avg_so_far = np.mean(times_s) * 1000
                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")

    return last, times_s

def tensor_stats(t: torch.Tensor) -> str:
    """Describe a tensor on one line.

    Reports shape, dtype, device, value range, mean, std and norm.
    """
    low = t.min().item()
    high = t.max().item()
    parts = (
        f"shape={tuple(t.shape)}",
        f"dtype={t.dtype}",
        f"device={t.device}",
        f"range=[{low:.6f}, {high:.6f}]",
        f"mean={t.mean().item():.6f}",
        f"std={t.std().item():.6f}",
        f"norm={t.norm().item():.6f}",
    )
    return ", ".join(parts)

def _output_checksum(result: Any) -> Optional[float]:
    """Sum the primary output tensor of a benchmark result.

    The primary output is ``result`` itself, or its first element when
    ``result`` is a non-empty tuple. Returns None (instead of raising) when
    that value is not a tensor.
    """
    primary = result[0] if isinstance(result, tuple) and len(result) > 0 else result
    if isinstance(primary, torch.Tensor):
        return float(primary.sum().item())
    return None


@contextmanager
def bench_context(
    *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None
):
    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).

    Args:
        warmup: Untimed warmup calls before measuring.
        iters: Timed calls.
        device: Device string or ``torch.device`` forwarded to the engine.
        dtype: Torch dtype forwarded to the engine.
        tokens: Tokens processed per call; enables throughput statistics.
        verbose: Print configuration, input/output summaries and the report.
        save_json: If set, path of a JSON file that receives the configuration,
            the statistics and a checksum of the primary output.

    Yields:
        ``runner(fn, *args, **kwargs)``, which benchmarks ``fn(*args, **kwargs)``
        and returns ``(result, stats)`` where ``result`` is the value of the
        last timed call.
    """

    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
        # Log configuration
        if verbose:
            print("\n┌─ Benchmark Configuration ─────────────────────────────┐")
            # NOTE(review): the device/dtype row below is disabled; presumably
            # because a torch.device does not accept the `<15` format spec.
            # Confirm before re-enabling.
            # print(f"│ Device: {device:<15} Dtype: {dtype}              │")
            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
            # `is not None` matches how the statistics code treats `tokens`.
            if tokens is not None:
                print(f"│ Tokens: {tokens}                                        │")
            print("└────────────────────────────────────────────────────────┘")

        # Log input if it's a tensor
        if verbose and args and isinstance(args[0], torch.Tensor):
            print(f"\nInput: {tensor_stats(args[0])}")

        call = lambda: fn(*args, **kwargs)
        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype)

        # Log output if it's a tensor or tuple with tensors
        if verbose:
            print("\nOutput tensors:")
            if isinstance(result, torch.Tensor):
                print(f"  Primary: {tensor_stats(result)}")
            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                print(f"  Primary: {tensor_stats(result[0])}")
                if len(result) > 1:
                    if isinstance(result[1], torch.Tensor):
                        print(f"  Auxiliary: {tensor_stats(result[1])}")
                    else:
                        print(f"  Auxiliary: {type(result[1]).__name__}")

        # Compute and display statistics
        stats = _compute_stats(times_s, tokens=tokens)
        if verbose:
            print(_format_timing_stats(stats, tokens))

        # Save to JSON if requested
        if save_json:
            import json
            json_data = {
                "implementation": save_json.replace(".json", ""),
                "config": {
                    "warmup": warmup,
                    "iters": iters,
                    "device": str(device),  # Convert device to string
                    "dtype": str(dtype),
                    "tokens": tokens
                },
                "stats": stats,
                # None when the primary output is not a tensor.
                "output_sum": _output_checksum(result),
            }
            with open(save_json, 'w', encoding="utf-8") as f:
                json.dump(json_data, f, indent=2)
            if verbose:
                print(f"\nSaved benchmark results to {save_json}")

        return result, stats

    yield runner

def set_seed(seed: int):
    """Seed torch (CPU and, if present, CUDA) for reproducible runs.

    With CUDA available, cuDNN is additionally switched to deterministic
    mode and its benchmark mode is turned off.
    """
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        for seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
Downloading networkx (1.9MiB) Downloading numpy (15.9MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading setuptools (1.1MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading sympy (6.0MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading torch (846.8MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading triton (148.4MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading nvidia-nvjitlink-cu12 Downloading sympy Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 234ms

This notebook runs the Yamoe and Binned MoE implementations once each with identical inputs to verify they produce consistent outputs.

▼ code ▼ output | Cell: config | deps: torch, numpy | 37.88s |
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
"""Shared configuration for both implementations."""
import torch

# Mixture-of-experts layer dimensions
NUM_EXPERTS = 128
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 3072
TOP_K = 4

# Shape and precision of the benchmark input
BATCH_SIZE = 1
SEQ_LEN = 100
DTYPE = "float32"
# Use CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Fixed RNG seeds so every cell works with the same values
WEIGHT_SEED = 999
EXPERT_SEED = 777
INPUT_SEED = 123
GENERAL_SEED = 42
Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading triton (148.4MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading networkx (1.9MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading setuptools (1.1MiB) Downloading sympy (6.0MiB) Downloading torch (846.8MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading numpy (15.9MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading sympy Downloading nvidia-nvjitlink-cu12 Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 225ms
▼ code ▼ output | Cell: save_data | deps: torch, numpy | 38.59s |
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
Generate deterministic shared weights once and save as artifacts so
both implementations load identical parameters.
"""
import torch
from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED

def save_shared_weights():
    """Build the router and expert tensors from fixed seeds and save each to a .pt file.

    Files are written to the current working directory. The sums of the three
    random tensors are printed afterwards.
    """
    # Router weight: Kaiming-uniform under WEIGHT_SEED.
    torch.manual_seed(WEIGHT_SEED)
    router_weight = torch.nn.init.kaiming_uniform_(torch.empty(NUM_EXPERTS, HIDDEN_SIZE))

    # Expert projections: normal(0, 0.02) under EXPERT_SEED, gate/up drawn
    # before down so the random stream is consumed in a fixed order.
    torch.manual_seed(EXPERT_SEED)
    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)

    # All biases are zeros; the dict order is the order the files are written.
    artifacts = {
        'router_weight.pt': router_weight,
        'router_bias.pt': torch.zeros(NUM_EXPERTS),
        'gate_up_proj.pt': gate_up_proj,
        'gate_up_proj_bias.pt': torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE),
        'down_proj.pt': down_proj,
        'down_proj_bias.pt': torch.zeros(NUM_EXPERTS, HIDDEN_SIZE),
    }
    for filename, tensor in artifacts.items():
        torch.save(tensor, filename)

    print("Saved shared weights to artifacts")
    for label, tensor in (
        ("Router weight sum", router_weight),
        ("Gate/up sum", gate_up_proj),
        ("Down sum", down_proj),
    ):
        print(f"{label}: {tensor.sum().item():.6f}")

save_shared_weights()
Saved shared weights to artifacts Router weight sum: 12.588732 Gate/up sum: 1026.601807 Down sum: 206.729263
Downloading nvidia-cufile-cu12 (1.1MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading networkx (1.9MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading triton (148.4MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading sympy (6.0MiB) Downloading numpy (15.9MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading setuptools (1.1MiB) Downloading torch (846.8MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading nvidia-nvjitlink-cu12 Downloading sympy Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 239ms

Yamoe Implementation

This section runs the Yamoe MoE implementation with optimized Triton kernels.

▼ code ▼ output | Cell: yamoe_run | deps: torch, kernels, numpy | 35.75s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# The directory holding the upstream cell's artifacts comes from the
# environment; fall back to the current directory when it is not set.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")

artifact_root = Path(data_dir)
router_weight = torch.load(artifact_root / 'router_weight.pt')
router_bias = torch.load(artifact_root / 'router_bias.pt')
gate_up_proj = torch.load(artifact_root / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(artifact_root / 'gate_up_proj_bias.pt')
down_proj = torch.load(artifact_root / 'down_proj.pt')
down_proj_bias = torch.load(artifact_root / 'down_proj_bias.pt')

# Print sums so they can be compared with the values reported when saving.
print("Loaded shared weights from artifacts")
for label, tensor in (
    ("Router weight sum", router_weight),
    ("Gate/up sum", gate_up_proj),
    ("Down sum", down_proj),
):
    print(f"{label}: {tensor.sum().item():.6f}")

class YamoeRouter(nn.Module):
    """Linear top-k router that returns dense scores and the chosen expert ids."""

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Cloned so the parameters do not share storage with the inputs.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Route every token to its ``top_k`` highest-scoring experts.

        Returns ``(router_scores, router_indices)``: the scores have the same
        shape as the router logits and hold the softmax over the selected
        logits at the selected positions and zeros elsewhere.
        """
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(tokens, self.weight, self.bias)
        top_logits, router_indices = logits.topk(self.top_k, dim=-1)
        # Normalize over the selected experts only.
        top_probs = F.softmax(top_logits, dim=1, dtype=top_logits.dtype)
        router_scores = torch.zeros_like(logits).scatter_(1, router_indices, top_probs)
        return router_scores, router_indices


class YamoeMoEMLP(nn.Module):
    """MoE MLP that routes tokens with YamoeRouter and runs experts through the Yamoe kernel."""

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Kernel is fetched by name at a pinned revision.
        # Local-build alternative, kept for reference:
        # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

        # Per-expert capacity handed to the kernel.
        # NOTE(review): 256 was used here before (see the disabled line) and
        # the binned reference uses 256 — confirm that 12 is intended.
        # self.expert_capacity = 256
        self.expert_capacity = 12

        # Expert parameters are clones of the tensors passed in.
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        """Return ``(output, routing_weights)`` for a (batch, seq, hidden) input."""
        batch_size, seq_len, hidden_dim = hidden_states.shape

        routing_weights, router_indices = self.router(hidden_states)

        # The kernel takes 2-D token-major inputs.
        flat_tokens = hidden_states.view(-1, hidden_dim)
        flat_weights = routing_weights.view(-1, self.num_experts)

        expert_out = self.yamoe.experts(
            flat_tokens,
            router_indices,
            flat_weights,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            self.expert_capacity,
            self.num_experts,
            self.top_k,
        )

        # Restore the original (batch, seq, hidden) layout.
        return expert_out.view(batch_size, seq_len, hidden_dim), routing_weights

# Run the model
set_seed(GENERAL_SEED)

# This cell always targets CUDA, whatever DEVICE resolved to.
device = torch.device("cuda")
dtype = to_dtype(DTYPE)

print("\n=== Yamoe Implementation ===")
# Build the model from the loaded weights, moved to the target device
shared_weights = (
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
)
model = YamoeMoEMLP(*[w.to(device) for w in shared_weights]).to(device=device)

for label, param in (
    ("Router weight sum", model.router.weight),
    ("Gate/up proj sum", model.gate_up_proj),
    ("Down proj sum", model.down_proj),
):
    print(f"{label}: {param.sum().item():.6f}")

# Input is drawn under its own seed
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model; `output` is the model's (output, routing_weights) tuple
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73 Loaded shared weights from artifacts Router weight sum: 12.588732 Gate/up sum: 1026.601807 Down sum: 206.729263 === Yamoe Implementation === Router weight sum: 12.588732 Gate/up proj sum: 1026.601807 Down proj sum: 206.729340 ┌─ Benchmark Configuration ─────────────────────────────┐ │ Warmup: 10 Iters: 50 │ │ Tokens: 100 │ └────────────────────────────────────────────────────────┘ Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 Warming up (10 iterations)... Benchmarking (50 iterations)... Progress: 20% complete (avg: 8.633 ms) Progress: 40% complete (avg: 8.627 ms) Progress: 60% complete (avg: 8.629 ms) Progress: 80% complete (avg: 8.630 ms) Output tensors: Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ Iterations: 50 Latency Statistics: Average: 8.631 ms Min: 8.526 ms Max: 8.661 ms Std Dev: 0.022 ms Percentiles: P50 (median): 8.636 ms P95: 8.653 ms P99: 8.658 ms Throughput: Tokens/sec: 11586.6 Std Dev: 29.1 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Saved benchmark results to yamoe_results.json Output sum: -0.597250
Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading sympy (6.0MiB) Downloading networkx (1.9MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading numpy (15.9MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading torch (846.8MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading hf-xet (3.0MiB) Downloading setuptools (1.1MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading triton (148.4MiB) Downloading nvidia-cufile-cu12 Downloading hf-xet Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading sympy Downloading nvidia-nvjitlink-cu12 Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 37 packages in 287ms Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s] Fetching 6 files: 17%|█▋ | 1/6 [00:00<00:01, 3.90it/s] Fetching 6 files: 50%|█████ | 3/6 [00:00<00:00, 3.70it/s] Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 7.44it/s]

Artifacts:

yamoe_results.json

Binned Implementation

This section runs the binned implementation that manually handles token gathering/scattering.

▼ code ▼ output | Cell: binned_run | deps: torch, numpy | 42.05s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# The directory holding the upstream cell's artifacts comes from the
# environment; fall back to the current directory when it is not set.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

weights_root = Path(data_dir)
router_weight = torch.load(weights_root / 'router_weight.pt')
router_bias = torch.load(weights_root / 'router_bias.pt')
gate_up_proj = torch.load(weights_root / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(weights_root / 'gate_up_proj_bias.pt')
down_proj = torch.load(weights_root / 'down_proj.pt')
down_proj_bias = torch.load(weights_root / 'down_proj_bias.pt')

# Print sums so they can be compared with the values reported when saving.
print("Loaded shared weights from artifacts")
for label, tensor in (
    ("Router weight sum", router_weight),
    ("Gate/up sum", gate_up_proj),
    ("Down sum", down_proj),
):
    print(f"{label}: {tensor.sum().item():.6f}")

def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Copy each expert's routed token rows into a fixed-capacity buffer.

    Args:
        x: 2-D tensor of token rows; ``x.shape[1]`` is the feature size.
        indices: 1-D integer tensor of flat routing positions grouped by
            expert; position ``p`` refers to token ``p // top_k``.
        bins: 1-D tensor of cumulative per-expert counts; expert ``e`` owns
            ``indices[bins[e - 1]:bins[e]]`` (starting at 0 for ``e == 0``).
        expert_capacity: Maximum rows kept per expert; entries beyond it are
            dropped.
        top_k: Number of routing slots per token.

    Returns:
        Tensor of shape ``(len(bins), expert_capacity, x.shape[1])``; unused
        rows stay zero.
    """
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)

    # Read the bin boundaries once as Python ints instead of pulling a scalar
    # from the tensor for every expert and every token.
    start = 0
    for e, end in enumerate(bins.tolist()):
        n = min(end - start, expert_capacity)
        if n > 0:
            # Map flat routing positions to token rows and copy them in one
            # indexed read rather than one row at a time.
            token_ids = indices[start:start + n] // top_k
            out[e, :n] = x[token_ids]
        start = end
    return out

def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Write per-expert rows back to their (token, slot) positions and sum over slots.

    ``x`` is a 3-D per-expert buffer. For expert ``e``, entry ``i`` belongs to
    flat routing position ``indices[start + i]``, where ``start`` is the end
    of the previous expert's bin. Each row is multiplied by
    ``weights[flat_pos]`` (or 1.0 when ``weights`` is None). Entries beyond
    ``expert_capacity`` are skipped. Returns a tensor of shape
    ``(len(indices) // top_k, x.shape[2])``.
    """
    E, _, H = x.shape
    num_tokens = indices.shape[0] // top_k
    out = torch.zeros((num_tokens, top_k, H), dtype=x.dtype, device=x.device)

    for e in range(E):
        lo = bins[e - 1] if e > 0 else 0
        count = bins[e] - lo
        if count == 0:
            continue
        for i in range(min(count, expert_capacity)):
            flat_pos = indices[lo + i]
            tok, slot = flat_pos // top_k, flat_pos % top_k
            scale = 1.0 if weights is None else weights[flat_pos]
            out[tok, slot] = x[e, i] * scale

    # Collapse the top_k slot axis into one row per token.
    return out.sum(dim=1)

def sort_tokens_by_expert(router_indices, num_experts):
    """Group flat routing positions by expert id.

    Returns ``(sorted_indices, sorted_values, bins, tokens_per_expert)``:
    the permutation that sorts the flattened expert ids, the sorted ids, the
    cumulative per-expert counts, and the per-expert counts (padded to at
    least ``num_experts`` entries).
    """
    expert_ids = router_indices.flatten()
    sorted_values, sorted_indices = expert_ids.sort()
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = tokens_per_expert.cumsum(dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert

def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    """Reference MoE expert pass: gather tokens per expert, run the gated MLP, scatter back.

    ``hidden_states`` is 3-D; the result has the same shape. ``routing_weights``
    is dense with one column per expert, and ``router_indices`` has one column
    per routing slot.
    """
    B, S, H = hidden_states.shape
    E = routing_weights.shape[1]
    K = router_indices.shape[1]

    # Bin routing positions by expert and collect each expert's input rows.
    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    expert_in = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    # Fused gate/up projection; the bias is broadcast over the capacity axis.
    gate_up = torch.bmm(expert_in, gate_up_proj)
    gate_up.add_(gate_up_proj_bias[..., None, :])

    # Gate and up values are interleaved along the last axis.
    gate = gate_up[..., ::2]
    up = gate_up[..., 1::2]

    # Clamp: gate from above only, up on both sides.
    limit = 7.0
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)

    activated = gate * torch.sigmoid(gate * 1.702)
    hidden = (up + 1) * activated
    expert_out = torch.bmm(hidden, down_proj) + down_proj_bias[..., None, :]

    # Pick the routing weight of every (token, slot) pair, flattened to match
    # the flat routing positions.
    dense_weights = routing_weights.view(-1, E)
    chosen_experts = router_indices.view(-1, K)
    selected = dense_weights.gather(1, chosen_experts).reshape(-1)

    # Weighted scatter back to token order.
    combined = binned_scatter(expert_out, indices, selected, bins, expert_capacity, K)

    return combined.view(B, S, H)

class BinnedRouter(nn.Module):
    """Linear top-k router that returns dense scores and the chosen expert ids."""

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Cloned so the parameters do not share storage with the inputs.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Route every token to its ``top_k`` highest-scoring experts.

        Returns ``(router_scores, router_indices)``: the scores have the same
        shape as the router logits and hold the softmax over the selected
        logits at the selected positions and zeros elsewhere.
        """
        flat_tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(flat_tokens, self.weight, self.bias)
        selected_logits, router_indices = logits.topk(self.top_k, dim=-1)
        # Normalize over the selected experts only.
        selected_probs = F.softmax(selected_logits, dim=1, dtype=selected_logits.dtype)
        router_scores = torch.zeros_like(logits).scatter_(1, router_indices, selected_probs)
        return router_scores, router_indices

class BinnedMoEMLP(nn.Module):
    """MoE MLP layer that routes tokens and runs them through binned_experts_ref.

    The router and the four expert tensors are cloned from the supplied
    weights; the per-expert capacity is fixed at 256.
    """

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = BinnedRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_capacity = 256

        # Register clones of the loaded expert weights, in a fixed order.
        expert_tensors = (
            ("gate_up_proj", gate_up_proj),
            ("gate_up_proj_bias", gate_up_proj_bias),
            ("down_proj", down_proj),
            ("down_proj_bias", down_proj_bias),
        )
        for attr_name, tensor in expert_tensors:
            setattr(self, attr_name, nn.Parameter(tensor.clone()))

    def forward(self, hidden_states):
        """Return (expert-mixed output, dense router scores) for hidden_states."""
        scores, indices = self.router(hidden_states)
        mixed = binned_experts_ref(
            hidden_states,
            router_indices=indices,
            routing_weights=scores,
            gate_up_proj=self.gate_up_proj,
            gate_up_proj_bias=self.gate_up_proj_bias,
            down_proj=self.down_proj,
            down_proj_bias=self.down_proj_bias,
            expert_capacity=self.expert_capacity,
        )
        return mixed, scores

# Run the model
# Seed before building the model; the input gets its own seed further down.
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Binned Implementation ===")
# Initialize model with loaded weights
# Only the device is changed here: the weights keep whatever dtype they were
# loaded with, while `dtype` is applied to the input tensor below.
model = BinnedMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

# Checksums of the parameters held by the model, for comparison against the
# sums printed when the weights were loaded.
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate the same input as Yamoe
# Re-seeding immediately before randn makes the input depend only on
# INPUT_SEED, not on any random draws made earlier in this cell.
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
# `tokens` is passed to bench_context (presumably for the tokens/sec figure — confirm in bench_utils).
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts Router weight sum: 12.588732 Gate/up sum: 1026.601807 Down sum: 206.729263 === Binned Implementation === Router weight sum: 12.588732 Gate/up proj sum: 1026.601807 Down proj sum: 206.729340 ┌─ Benchmark Configuration ─────────────────────────────┐ │ Warmup: 10 Iters: 50 │ │ Tokens: 100 │ └────────────────────────────────────────────────────────┘ Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 Warming up (10 iterations)... Benchmarking (50 iterations)... Progress: 20% complete (avg: 104.222 ms) Progress: 40% complete (avg: 104.671 ms) Progress: 60% complete (avg: 105.372 ms) Progress: 80% complete (avg: 105.570 ms) Output tensors: Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ Iterations: 50 Latency Statistics: Average: 105.618 ms Min: 103.417 ms Max: 107.809 ms Std Dev: 1.458 ms Percentiles: P50 (median): 105.048 ms P95: 107.729 ms P99: 107.790 ms Throughput: Tokens/sec: 946.8 Std Dev: 13.0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Saved benchmark results to binned_results.json Output sum: -0.597248
Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading triton (148.4MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading torch (846.8MiB) Downloading networkx (1.9MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading setuptools (1.1MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading sympy (6.0MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading numpy (15.9MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading nvidia-nvjitlink-cu12 Downloading sympy Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 233ms

Artifacts:

binned_results.json

GPT-OSS Implementation

This section runs the GPT-OSS MoE implementation with manual expert loop handling.

▼ code ▼ output | Cell: gptoss_run | deps: torch, numpy | 37.86s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Directory holding the tensors saved by the upstream cell; the environment
# variable names it, and the current directory is the fallback.
data_dir = os.getenv('UVNOTE_INPUT_SAVE_DATA', '.')

# NOTE(review): torch.load unpickles these files, so data_dir must only point
# at trusted artifacts (consider weights_only=True if the torch version allows).
(
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
) = [
    torch.load(Path(data_dir) / filename)
    for filename in (
        'router_weight.pt',
        'router_bias.pt',
        'gate_up_proj.pt',
        'gate_up_proj_bias.pt',
        'down_proj.pt',
        'down_proj_bias.pt',
    )
]

# Print checksums so runs of the different implementations can be compared.
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssRouter(nn.Module):
    """Top-k softmax router used by the GPT-OSS MoE layer.

    Holds clones of the supplied router weight and bias as parameters; the
    top-k count, expert count and hidden width come from the config constants.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Clone so the module does not share storage with the caller's tensors.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Return (dense router scores, selected expert indices) per token.

        The scores matrix is zero everywhere except at each token's selected
        experts, where it holds the softmax over the selected logits.
        """
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(tokens, self.weight, self.bias)

        # Normalise over the k winning logits only, not over all experts.
        top = torch.topk(logits, self.top_k, dim=-1)
        probs = F.softmax(top.values, dim=1, dtype=top.values.dtype)

        dense_scores = torch.zeros_like(logits)
        dense_scores.scatter_(1, top.indices, probs)
        return dense_scores, top.indices

class GptOssExperts(nn.Module):
    """Bank of gated-MLP experts with two execution paths.

    forward() uses a per-expert loop when the input lives on the CPU or the
    module is in training mode, and a dense batched path otherwise. PyTorch
    modules start in training mode, so the batched path only runs after
    .eval() has been called on a non-CPU module.
    """

    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        # Clone so the module does not share storage with the caller's tensors.
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702  # slope inside the sigmoid gate
        self.limit = 7.0  # clamp bound for the gate / up pre-activations

    def _gated_activation(self, gate_up):
        """De-interleave gate/up columns, clamp them and apply the sigmoid gate."""
        # Even columns are the gate (clamped from above only), odd columns are
        # the up branch (clamped on both sides).
        gate = gate_up[..., ::2].clamp(min=None, max=self.limit)
        up = gate_up[..., 1::2].clamp(min=-self.limit, max=self.limit)
        return (up + 1) * (gate * torch.sigmoid(gate * self.alpha))

    def _forward_expert_loop(self, tokens, router_indices, routing_weights, num_experts):
        """Visit each expert that received tokens and accumulate its weighted output."""
        mixed = torch.zeros_like(tokens, dtype=tokens.dtype, device=tokens.device)
        with torch.no_grad():
            # One-hot over experts, permuted so the expert axis comes first.
            mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
            mask = mask.permute(2, 1, 0)
            active = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in active[:, 0]:
            with torch.no_grad():
                _, token_idx = torch.where(mask[expert_idx])
            projected = tokens[token_idx] @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            expert_out = self._gated_activation(projected) @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            scaled = expert_out * routing_weights[token_idx, expert_idx, None]
            # A token routed to several experts accumulates all their outputs.
            mixed.index_add_(0, token_idx, scaled.to(tokens.dtype))
        return mixed

    def _forward_batched(self, tokens, routing_weights, batch_size, num_experts):
        """Run every token through every expert, then weight and sum over experts."""
        replicated = tokens.repeat(num_experts, 1).view(num_experts, -1, self.hidden_size)
        projected = torch.bmm(replicated, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
        expert_out = torch.bmm(self._gated_activation(projected), self.down_proj)
        expert_out = expert_out + self.down_proj_bias[..., None, :]
        expert_out = expert_out.view(num_experts, batch_size, -1, self.hidden_size)
        # Unselected experts carry a zero routing weight and drop out of the sum.
        gates = routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
        return (expert_out * gates).sum(dim=0)

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Mix expert outputs for hidden_states using the given routing.

        Args:
            hidden_states: input whose leading dimension is taken as the batch
                size; flattened to (-1, hidden_size) internally.
            router_indices: selected expert ids per token (used only by the
                loop path).
            routing_weights: dense per-token routing scores; its second
                dimension is taken as the number of experts.

        Returns:
            Tensor of shape (batch_size, -1, hidden_size).
        """
        batch_size = hidden_states.shape[0]
        tokens = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        if tokens.device.type == "cpu" or self.training:
            mixed = self._forward_expert_loop(tokens, router_indices, routing_weights, num_experts)
            return mixed.view(batch_size, -1, self.hidden_size)
        return self._forward_batched(tokens, routing_weights, batch_size, num_experts)

class GptOssMoEMLP(nn.Module):
    """GPT-OSS MoE layer: a GptOssRouter feeding a GptOssExperts bank."""

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        expert_tensors = (gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
        self.router = GptOssRouter(router_weight, router_bias)
        self.experts = GptOssExperts(*expert_tensors)

    def forward(self, hidden_states):
        """Return (expert-mixed output, dense router scores) for hidden_states."""
        scores, indices = self.router(hidden_states)
        mixed = self.experts(
            hidden_states,
            router_indices=indices,
            routing_weights=scores,
        )
        return mixed, scores

# Run the model
# Seed before building the model; the input gets its own seed further down.
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation ===")
# Initialize model with loaded weights
# Only the device is changed here: the weights keep whatever dtype they were
# loaded with, while `dtype` is applied to the input tensor below.
# NOTE(review): model.eval() is never called in this cell, and PyTorch modules
# start in training mode, so GptOssExperts.forward takes its per-expert loop
# branch (`... or self.training`) rather than the batched bmm branch. Confirm
# which path this benchmark is meant to measure.
model = GptOssMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

# Checksums of the parameters held by the model, for comparison against the
# sums printed when the weights were loaded.
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")

# Generate the same input as other implementations
# Re-seeding immediately before randn makes the input depend only on
# INPUT_SEED, not on any random draws made earlier in this cell.
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
# `tokens` is passed to bench_context (presumably for the tokens/sec figure — confirm in bench_utils).
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts Router weight sum: 12.588732 Gate/up sum: 1026.601807 Down sum: 206.729263 === GPT-OSS Implementation === Router weight sum: 12.588732 Gate/up proj sum: 1026.601807 Down proj sum: 206.729340 ┌─ Benchmark Configuration ─────────────────────────────┐ │ Warmup: 10 Iters: 50 │ │ Tokens: 100 │ └────────────────────────────────────────────────────────┘ Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 Warming up (10 iterations)... Benchmarking (50 iterations)... Progress: 20% complete (avg: 46.973 ms) Progress: 40% complete (avg: 47.262 ms) Progress: 60% complete (avg: 47.067 ms) Progress: 80% complete (avg: 46.985 ms) Output tensors: Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ Iterations: 50 Latency Statistics: Average: 47.135 ms Min: 46.582 ms Max: 47.895 ms Std Dev: 0.503 ms Percentiles: P50 (median): 46.789 ms P95: 47.801 ms P99: 47.856 ms Throughput: Tokens/sec: 2121.6 Std Dev: 22.5 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Saved benchmark results to gptoss_results.json Output sum: -0.597250
Downloading nvidia-cufile-cu12 (1.1MiB) Downloading setuptools (1.1MiB) Downloading sympy (6.0MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading numpy (15.9MiB) Downloading networkx (1.9MiB) Downloading torch (846.8MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading triton (148.4MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading sympy Downloading nvidia-nvjitlink-cu12 Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 241ms

Artifacts:

gptoss_results.json

GPT-OSS Implementation (Training Mode)

This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.

▼ code ▼ output | Cell: gptoss_training_run | deps: torch, numpy | 36.75s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os

# Directory holding the tensors saved by the upstream cell; the environment
# variable names it, and the current directory is the fallback.
data_dir = os.getenv('UVNOTE_INPUT_SAVE_DATA', '.')

# NOTE(review): torch.load unpickles these files, so data_dir must only point
# at trusted artifacts (consider weights_only=True if the torch version allows).
(
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
) = [
    torch.load(Path(data_dir) / filename)
    for filename in (
        'router_weight.pt',
        'router_bias.pt',
        'gate_up_proj.pt',
        'gate_up_proj_bias.pt',
        'down_proj.pt',
        'down_proj_bias.pt',
    )
]

# Print checksums so runs of the different implementations can be compared.
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

class GptOssTrainingRouter(nn.Module):
    """Top-k softmax router for the expert-loop ("training") GPT-OSS variant.

    Holds clones of the supplied router weight and bias as parameters; the
    top-k count, expert count and hidden width come from the config constants.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Clone so the module does not share storage with the caller's tensors.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Return (dense router scores, selected expert indices) per token.

        The scores matrix is zero everywhere except at each token's selected
        experts, where it holds the softmax over the selected logits.
        """
        flat_tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(flat_tokens, self.weight, self.bias)

        # Normalise over the k winning logits only, not over all experts.
        winners = torch.topk(logits, self.top_k, dim=-1)
        winner_probs = F.softmax(winners.values, dim=1, dtype=winners.values.dtype)

        scores = torch.zeros_like(logits)
        scores.scatter_(1, winners.indices, winner_probs)
        return scores, winners.indices

class GptOssTrainingExperts(nn.Module):
    """Gated-MLP expert bank that always uses the per-expert loop.

    forward() never inspects the module's training flag or the device; every
    call walks the experts that received at least one token.
    """

    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.expert_dim = self.hidden_size
        # Clone so the module does not share storage with the caller's tensors.
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        self.alpha = 1.702  # slope inside the sigmoid gate
        self.limit = 7.0  # clamp bound for the gate / up pre-activations

    def _run_expert(self, expert_idx, expert_tokens):
        """Apply one expert's gated MLP to the rows routed to it."""
        projected = expert_tokens @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
        # Even columns are the gate (clamped from above only), odd columns are
        # the up branch (clamped on both sides).
        gate = projected[..., ::2].clamp(min=None, max=self.limit)
        up = projected[..., 1::2].clamp(min=-self.limit, max=self.limit)
        activated = (up + 1) * (gate * torch.sigmoid(gate * self.alpha))
        return activated @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Mix expert outputs for hidden_states using the given routing.

        Args:
            hidden_states: input whose leading dimension is taken as the batch
                size; flattened to (-1, hidden_size) internally.
            router_indices: selected expert ids per token.
            routing_weights: dense per-token routing scores; its second
                dimension is taken as the number of experts.

        Returns:
            Tensor of shape (batch_size, -1, hidden_size).
        """
        batch_size = hidden_states.shape[0]
        tokens = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]

        accumulated = torch.zeros_like(tokens, dtype=tokens.dtype, device=tokens.device)
        with torch.no_grad():
            # One-hot over experts, permuted so the expert axis comes first.
            mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
            mask = mask.permute(2, 1, 0)
            active = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in active[:, 0]:
            with torch.no_grad():
                _, token_idx = torch.where(mask[expert_idx])
            expert_out = self._run_expert(expert_idx, tokens[token_idx])
            scaled = expert_out * routing_weights[token_idx, expert_idx, None]
            # A token routed to several experts accumulates all their outputs.
            accumulated.index_add_(0, token_idx, scaled.to(tokens.dtype))

        return accumulated.view(batch_size, -1, self.hidden_size)

class GptOssTrainingMoEMLP(nn.Module):
    """GPT-OSS MoE layer wired to the always-loop expert implementation."""

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        expert_tensors = (gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
        self.router = GptOssTrainingRouter(router_weight, router_bias)
        self.experts = GptOssTrainingExperts(*expert_tensors)

    def forward(self, hidden_states):
        """Return (expert-mixed output, dense router scores) for hidden_states."""
        scores, indices = self.router(hidden_states)
        mixed = self.experts(
            hidden_states,
            router_indices=indices,
            routing_weights=scores,
        )
        return mixed, scores

# Run the model
# Seed before building the model; the input gets its own seed further down.
set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
# Initialize model with loaded weights and force training mode
# Only the device is changed here: the weights keep whatever dtype they were
# loaded with, while `dtype` is applied to the input tensor below.
model = GptOssTrainingMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device)
).to(device=device)

# Set to training mode to force expert loop path
# NOTE(review): GptOssTrainingExperts.forward does not branch on self.training;
# it always runs the expert loop, so this call only affects the flag printed below.
model.train()

# Checksums of the parameters held by the model, for comparison against the
# sums printed when the weights were loaded.
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
print(f"Model training mode: {model.training}")

# Generate the same input as other implementations
# Re-seeding immediately before randn makes the input depend only on
# INPUT_SEED, not on any random draws made earlier in this cell.
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
# `tokens` is passed to bench_context (presumably for the tokens/sec figure — confirm in bench_utils).
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json") as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loaded shared weights from artifacts Router weight sum: 12.588732 Gate/up sum: 1026.601807 Down sum: 206.729263 === GPT-OSS Implementation (Training Mode - Expert Loop) === Router weight sum: 12.588732 Gate/up proj sum: 1026.601807 Down proj sum: 206.729340 Model training mode: True ┌─ Benchmark Configuration ─────────────────────────────┐ │ Warmup: 10 Iters: 50 │ │ Tokens: 100 │ └────────────────────────────────────────────────────────┘ Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 Warming up (10 iterations)... Benchmarking (50 iterations)... Progress: 20% complete (avg: 48.328 ms) Progress: 40% complete (avg: 48.764 ms) Progress: 60% complete (avg: 48.825 ms) Progress: 80% complete (avg: 48.769 ms) Output tensors: Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.294473], mean=0.007812, std=0.043541, norm=5.004632 ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ Iterations: 50 Latency Statistics: Average: 48.630 ms Min: 47.535 ms Max: 49.414 ms Std Dev: 0.559 ms Percentiles: P50 (median): 48.395 ms P95: 49.346 ms P99: 49.390 ms Throughput: Tokens/sec: 2056.3 Std Dev: 23.6 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Saved benchmark results to gptoss_training_results.json Output sum: -0.597250
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading numpy (15.9MiB) Downloading networkx (1.9MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading sympy (6.0MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading setuptools (1.1MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading triton (148.4MiB) Downloading torch (846.8MiB) Downloading nvidia-cufile-cu12 Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading nvidia-nvjitlink-cu12 Downloading sympy Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 26 packages in 234ms

MegaBlocks Implementation

This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.

▼ code ▼ output | Cell: megablocks_run | deps: torch, numpy, kernels | 43.51s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
from collections import namedtuple
import os

# Directory holding the tensors saved by the upstream cell; the environment
# variable names it, and the current directory is the fallback.
data_dir = os.getenv('UVNOTE_INPUT_SAVE_DATA', '.')

print(f"Loading weights from: {data_dir}")

# NOTE(review): torch.load unpickles these files, so data_dir must only point
# at trusted artifacts (consider weights_only=True if the torch version allows).
(
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
) = [
    torch.load(Path(data_dir) / filename)
    for filename in (
        'router_weight.pt',
        'router_bias.pt',
        'gate_up_proj.pt',
        'gate_up_proj_bias.pt',
        'down_proj.pt',
        'down_proj_bias.pt',
    )
]

# Print checksums so runs of the different implementations can be compared.
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")

def build_megablocks_model(device: torch.device):
    """Build a MegaBlocks MoE layer populated with the shared benchmark weights.

    Fetches the "kernels-community/megablocks" kernel, instantiates its
    MegaBlocksMoeMLP layer, and attaches a router plus an expert attribute
    container filled from the module-level tensors loaded earlier in the cell.

    Args:
        device: device for the router layer and the expert parameters.

    Returns:
        The configured MegaBlocksMoeMLP instance.
    """
    # Optimized kernels come from the Hugging Face hub. A local build could be
    # loaded with get_local_kernel(<build dir>, "megablocks") instead.
    megablocks = get_kernel("kernels-community/megablocks")
    model = megablocks.layers.MegaBlocksMoeMLP()

    # The namedtuple *class* itself (never instantiated) serves as a plain
    # attribute holder for the expert weights and settings.
    model.experts = namedtuple(
        "Experts",
        ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"],
    )

    # Router: a Linear layer whose freshly initialised values are overwritten
    # with the shared router weights so all implementations route identically.
    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
    with torch.no_grad():
        model.router.weight.copy_(router_weight)
        model.router.bias.copy_(router_bias)

    # Fill the container: scalar settings first, then cloned weights on `device`.
    experts = model.experts
    experts.alpha = 1.702
    experts.capacity_factor = 4
    shared_tensors = (
        ("gate_up_proj", gate_up_proj),
        ("gate_up_proj_bias", gate_up_proj_bias),
        ("down_proj", down_proj),
        ("down_proj_bias", down_proj_bias),
    )
    for attr_name, tensor in shared_tensors:
        setattr(experts, attr_name, torch.nn.Parameter(tensor.clone().to(device)))
    experts.hidden_size = HIDDEN_SIZE

    # Checksums for comparison against the other implementations' logs.
    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
    print(f"[MegaBlocks] Gate/up projection shape: {tuple(experts.gate_up_proj.shape)}, sum: {experts.gate_up_proj.sum().item():.6f}")
    print(f"[MegaBlocks] Down projection shape: {tuple(experts.down_proj.shape)}, sum: {experts.down_proj.sum().item():.6f}")

    return model

# Create a wrapper to match the interface of other implementations
class MegaBlocksMoEWrapper(nn.Module):
    """Thin adapter around a MegaBlocks MoE model.

    Exposes the wrapped model through a ``forward`` that returns a
    two-element ``(output, routing_weights)`` tuple, matching the interface
    of the other implementations.
    """

    def __init__(self, megablocks_model):
        """Store the model that ``forward`` delegates to."""
        super().__init__()
        self.model = megablocks_model

    def forward(self, hidden_states):
        """Run the wrapped model and return its two results as a tuple.

        Args:
            hidden_states: Input activations; the benchmark passes a tensor
                shaped (batch, seq_len, hidden_dim).

        Returns:
            ``(output, routing_weights)`` exactly as produced by the wrapped
            model; the second element is passed through untouched.
        """
        # The wrapped model must yield exactly two values; unpacking enforces it.
        moe_output, routing_weights = self.model(hidden_states)
        return moe_output, routing_weights

# Run the model
# Seed the RNGs before the model is built.
set_seed(GENERAL_SEED)

# Resolve the configured device and dtype. to_dtype() maps "float16" and
# "bfloat16" to the matching torch dtypes and falls back to float32 for any
# other string.
device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== MegaBlocks Implementation ===")
# Build MegaBlocks model with loaded weights
megablocks_model = build_megablocks_model(device)
# NOTE(review): the model is moved to `device` only; it is not cast to
# `dtype`, while the input below is created with `dtype`. Confirm that
# non-float32 DTYPE settings are expected to work with this cell.
model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)

# Generate the same input as other implementations
# Re-seeding immediately before sampling makes the input depend only on
# INPUT_SEED, not on how much randomness model construction consumed.
set_seed(INPUT_SEED)
# Standard-normal samples scaled by 0.1.
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model
# Total token count for the batch, handed to the benchmark helper.
tokens = BATCH_SIZE * SEQ_LEN
# NOTE(review): bench_context is defined elsewhere; presumably it performs the
# warmup and timed iterations and writes the statistics to `save_json` --
# confirm against its definition.
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json") as bench:
    output, stats = bench(model, x)
    # Presumably `output` is the model's (output, routing_weights) tuple, so
    # output[0] is the primary output tensor -- verify against bench_context.
    print(f"\nOutput sum: {output[0].sum().item():.6f}")
Loading weights from: /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/0dc3119d70b6b7e0618fb3e0070aede3d5fc82296ac58f1ab73305d459560b73 Loaded shared weights from artifacts Router weight sum: 12.588732 Gate/up sum: 1026.601807 Down sum: 206.729263 === MegaBlocks Implementation === [MegaBlocks] Router weight sum: 12.588732 [MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807 [MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340 ┌─ Benchmark Configuration ─────────────────────────────┐ │ Warmup: 10 Iters: 50 │ │ Tokens: 100 │ └────────────────────────────────────────────────────────┘ Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=0.000016, std=0.099892, norm=33.904163 Warming up (10 iterations)... Benchmarking (50 iterations)... Progress: 20% complete (avg: 0.867 ms) Progress: 40% complete (avg: 0.853 ms) Progress: 60% complete (avg: 1.181 ms) Progress: 80% complete (avg: 3.026 ms) Output tensors: Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.050624, 0.050640], mean=-0.000005, std=0.011573, norm=3.928071 Auxiliary: shape=(100, 4), dtype=torch.float32, device=cuda:0, range=[0.220910, 0.294473], mean=0.250000, std=0.010777, norm=5.004632 ━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ Iterations: 50 Latency Statistics: Average: 4.133 ms Min: 0.823 ms Max: 8.589 ms Std Dev: 3.781 ms Percentiles: P50 (median): 0.864 ms P95: 8.579 ms P99: 8.589 ms Throughput: Tokens/sec: 24194.9 Std Dev: 52511.7 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Saved benchmark results to megablocks_results.json Output sum: -0.597249
Downloading setuptools (1.1MiB) Downloading nvidia-cudnn-cu12 (674.0MiB) Downloading numpy (15.9MiB) Downloading nvidia-cusparse-cu12 (274.9MiB) Downloading nvidia-nvjitlink-cu12 (37.4MiB) Downloading hf-xet (3.0MiB) Downloading nvidia-cusolver-cu12 (255.1MiB) Downloading networkx (1.9MiB) Downloading nvidia-cufft-cu12 (184.2MiB) Downloading nvidia-cufile-cu12 (1.1MiB) Downloading triton (148.4MiB) Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB) Downloading nvidia-curand-cu12 (60.7MiB) Downloading sympy (6.0MiB) Downloading nvidia-cuda-cupti-cu12 (9.8MiB) Downloading nvidia-nccl-cu12 (307.4MiB) Downloading nvidia-cusparselt-cu12 (273.9MiB) Downloading nvidia-cublas-cu12 (566.8MiB) Downloading torch (846.8MiB) Downloading nvidia-cufile-cu12 Downloading hf-xet Downloading setuptools Downloading networkx Downloading nvidia-cuda-cupti-cu12 Downloading numpy Downloading sympy Downloading nvidia-nvjitlink-cu12 Downloading nvidia-curand-cu12 Downloading nvidia-cuda-nvrtc-cu12 Downloading triton Downloading nvidia-cufft-cu12 Downloading nvidia-cusolver-cu12 Downloading nvidia-cusparse-cu12 Downloading nvidia-cusparselt-cu12 Downloading nvidia-nccl-cu12 Downloading nvidia-cublas-cu12 Downloading nvidia-cudnn-cu12 Downloading torch Installed 37 packages in 216ms Fetching 66 files: 0%| | 0/66 [00:00<?, ?it/s] Fetching 66 files: 2%|▏ | 1/66 [00:00<00:22, 2.87it/s] Fetching 66 files: 26%|██▌ | 17/66 [00:01<00:04, 11.84it/s] Fetching 66 files: 100%|██████████| 66/66 [00:01<00:00, 43.56it/s]

Performance Visualization

This section reads all benchmark results and creates a comprehensive performance comparison chart.

▼ code ▼ output | Cell: visualization | deps: matplotlib | 3.96s |
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os

# Expected result files, described as (environment variable naming the
# directory that holds that run's outputs, JSON file name). When a variable
# is unset the current directory is used instead.
result_sources = [
    ('UVNOTE_INPUT_YAMOE_RUN', "yamoe_results.json"),
    ('UVNOTE_INPUT_BINNED_RUN', "binned_results.json"),
    ('UVNOTE_INPUT_GPTOSS_RUN', "gptoss_results.json"),
    ('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', "gptoss_training_results.json"),
    ('UVNOTE_INPUT_MEGABLOCKS_RUN', "megablocks_results.json"),
]
result_files = [
    Path(os.environ.get(env_var, '.')) / file_name
    for env_var, file_name in result_sources
]

# Load every result file that exists, keyed by its 'implementation' field;
# files that are absent are reported and skipped.
results = {}
for result_path in result_files:
    if not result_path.exists():
        print(f"Missing {result_path}")
        continue
    with open(result_path, 'r') as handle:
        payload = json.load(handle)
    results[payload['implementation']] = payload
    print(f"Loaded {result_path}")

def _draw_metric_panel(ax, names, values, colors, title, ylabel, label_format,
                       positive_only=False):
    """Draw one bar-chart panel with a value label above each bar.

    Args:
        ax: Matplotlib axes to draw on.
        names: Bar names, used as x tick labels.
        values: One numeric value per name (must be non-empty).
        colors: Bar colors.
        title: Panel title.
        ylabel: Y-axis label.
        label_format: ``str.format`` template applied to each value to build
            its label, e.g. ``'{:.2f}ms'``.
        positive_only: When true, a bar gets a label only if its value is > 0.

    Returns:
        The bar container returned by ``ax.bar``.
    """
    bars = ax.bar(names, values, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax.set_title(title, fontweight='bold', fontsize=14)
    ax.set_ylabel(ylabel, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)

    # Labels sit 1% of the tallest bar above each bar top; the offset is the
    # same for every bar, so compute it once rather than inside the loop.
    label_offset = max(values) * 0.01
    for bar, val in zip(bars, values):
        if positive_only and not val > 0:
            continue
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
                label_format.format(val), ha='center', va='bottom', fontweight='bold')
    return bars


if not results:
    print("No benchmark results found. Run the benchmark cells first.")
else:
    # Extract data for plotting (all lists share the order of `implementations`)
    implementations = list(results.keys())
    stats_per_impl = [results[impl]['stats'] for impl in implementations]
    avg_latencies = [impl_stats['avg_ms'] for impl_stats in stats_per_impl]
    p95_latencies = [impl_stats['p95_ms'] for impl_stats in stats_per_impl]
    # Throughput is optional in the result files; 0 marks "not calculated".
    throughputs = [impl_stats.get('tokens_per_s', 0) for impl_stats in stats_per_impl]

    # Create figure with one subplot per metric
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold')

    # Colors for each implementation
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)]

    # 1. Average latency, 2. P95 latency, 3. throughput
    _draw_metric_panel(ax1, implementations, avg_latencies, colors,
                       'Average Latency', 'Latency (ms)', '{:.2f}ms')
    _draw_metric_panel(ax2, implementations, p95_latencies, colors,
                       '95th Percentile Latency', 'Latency (ms)', '{:.2f}ms')
    # Only label throughput bars whose throughput was actually calculated.
    _draw_metric_panel(ax3, implementations, throughputs, colors,
                       'Throughput', 'Tokens/sec', '{:.0f}', positive_only=True)

    plt.tight_layout()
    plt.savefig("moe_performance_comparison.png", dpi=300)

    # Print summary table
    print("\nPerformance Summary:")
    print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}")
    print("-"*80)

    # Sort by average latency (fastest first); relative speed is measured
    # against the fastest implementation, which therefore shows 1.00x.
    sorted_results = sorted(results.items(), key=lambda item: item[1]['stats']['avg_ms'])
    fastest_name, fastest_data = sorted_results[0]
    fastest_latency = fastest_data['stats']['avg_ms']

    for impl, data in sorted_results:
        impl_stats = data['stats']
        avg_ms = impl_stats['avg_ms']
        p95_ms = impl_stats['p95_ms']
        tokens_s = impl_stats.get('tokens_per_s', 0)
        relative_speed = fastest_latency / avg_ms

        print(f"{impl:<30} {avg_ms:>8.2f}    {p95_ms:>8.2f}    {tokens_s:>8.0f}      {relative_speed:>6.2f}x")

    print(f"\nFastest: {fastest_name} ({fastest_latency:.2f}ms avg)")
    if len(sorted_results) > 1:
        slowest_name, slowest_data = sorted_results[-1]
        slowest_latency = slowest_data['stats']['avg_ms']
        print(f"Slowest: {slowest_name} ({slowest_latency:.2f}ms avg)")
        speedup = slowest_latency / fastest_latency
        print(f"Max Speedup: {speedup:.1f}x")
Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/c5c8a351e1080ea89737c25df783e5c81cd76df0f2b017cedfd813e3bdf2f9f9/yamoe_results.json Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/af01d090b967f1cb05cacea7795553418933b27fc2f188da52f7c4642e456c24/binned_results.json Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/cf359ebbdbfd10241ce11898ee298eefd5da768c42d502b034caf3ba5b16aed6/gptoss_results.json Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/28eb2a85c2dc94e627a0c6373b55120bd67c549ef80cd5b5e94ae756ecd11aff/gptoss_training_results.json Loaded /home/ubuntu/Projects/uvnote-megablocks-bench/.uvnote/cache/a712c225c474c8776a91d23a96a2d4dd5dde0716ed16f6eb0dce9d92b65e06b8/megablocks_results.json Performance Summary: Implementation Avg (ms) P95 (ms) Tokens/sec Relative Speed -------------------------------------------------------------------------------- megablocks_results 4.13 8.58 24195 1.00x yamoe_results 8.63 8.65 11587 0.48x gptoss_results 47.14 47.80 2122 0.09x gptoss_training_results 48.63 49.35 2056 0.08x binned_results 105.62 107.73 947 0.04x Fastest: megablocks_results (4.13ms avg) Slowest: binned_results (105.62ms avg) Max Speedup: 25.6x
Downloading numpy (15.9MiB) Downloading fonttools (4.7MiB) Downloading pillow (6.3MiB) Downloading matplotlib (8.3MiB) Downloading kiwisolver (1.4MiB) Downloading kiwisolver Downloading pillow Downloading fonttools Downloading matplotlib Downloading numpy Installed 11 packages in 24ms