"""Reproduction of Step Law (Predictable Scale) Figure 1.

Anonymised reproduction of arXiv:2503.04715 Figure 1:
A learning-rate x batch-size hyperparameter contour plot for a 1B-parameter
model trained on 100B tokens. Filled / line contours show the relative loss
increase (in percent) around the empirical optimum, with the global minimum
(red X) and Step Law's predicted optimum (yellow star) overlaid.

All data is synthesized from the published power-law form
    eta(N, D) = 1.79 * N^-0.713 * D^0.307
    B(D)      = 0.58 * D^0.571
together with a quadratic relative-loss model in (log_eta, log_B)-space,
so no checkpoints, training logs, or external assets are needed.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap

# Global typography / frame defaults applied to every figure created below:
# serif text at 9.5 pt and a slightly thinner-than-default axes frame.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 9.5
plt.rcParams["axes.linewidth"] = 0.8

# Scale of the reproduced setting, plugged into the Step Law formulas below
# (the module docstring describes it as a 1B-parameter model on 100B tokens).
N = 1e9
D = 1e11

# Step Law power-law predictions for learning rate and batch size.
eta_star = 1.79 * N ** -0.713 * D ** 0.307
B_star = 0.58 * D ** 0.571

# Log-spaced sweep over learning rate (columns) and batch size (rows).
_GRID_POINTS = 220
_LR_BOUNDS = (2e-4, 8e-3)
_BS_BOUNDS = (2e4, 2e6)
LR = np.logspace(
    np.log10(_LR_BOUNDS[0]), np.log10(_LR_BOUNDS[1]), num=_GRID_POINTS
)
BS = np.logspace(
    np.log10(_BS_BOUNDS[0]), np.log10(_BS_BOUNDS[1]), num=_GRID_POINTS
)
LRg, BSg = np.meshgrid(LR, BS)

# Offsets from (eta_star, B_star) measured in decades (log10 units).
_log_eta_star = np.log10(eta_star)
_log_B_star = np.log10(B_star)
x = np.log10(LRg) - _log_eta_star
y = np.log10(BSg) - _log_B_star

# Rotate the offsets by -22 degrees so the quadratic bowl below is tilted
# relative to the learning-rate / batch-size axes.
theta = np.deg2rad(-22)
_cos_t = np.cos(theta)
_sin_t = np.sin(theta)
xr = x * _cos_t - y * _sin_t
yr = x * _sin_t + y * _cos_t

# Anisotropic quadratic "relative loss" surface. The quadratic part is
# positive definite, so the surface bottoms out at 0.094 exactly where
# x == y == 0, i.e. at (eta_star, B_star).
rel_pct = (
    0.094
    + 1.4 * (xr ** 2)
    + 4.5 * (yr ** 2)
    + 0.6 * xr * yr
)
# Absolute loss surface derived linearly from the relative one.
loss = 2.073 + 0.0009 * rel_pct

# Contour levels on rel_pct and the line colour used for each level.
LEVELS = [0.125, 0.25, 0.5, 1.0, 2.0]
contour_colors = [
    "#4d3b8a",
    "#7a3a8c",
    "#c4407c",
    "#e07a3a",
    "#f0a83a",
]

# ---------------------------------------------------------------------------
# Figure assembly: contour lines of rel_pct over the (LR, BS) grid, reference
# markers for the different "laws", a loss colorbar, legend, and PNG export.
# ---------------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(7.4, 5.4))

# Configure the log-log axes *before* drawing anything. clabel() chooses label
# positions, rotations and inline gaps from the axes geometry at call time, so
# the scales and limits must already be final when it runs.
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(2e-4, 8e-3)
ax.set_ylim(2e4, 2e6)
ax.set_xticks([5e-4, 1e-3, 5e-3])
ax.set_xticklabels([r"$5\times 10^{-4}$", r"$10^{-3}$", r"$5\times 10^{-3}$"])
ax.set_yticks([1e5, 1e6])
ax.set_xlabel("Learning Rate", fontweight="bold", fontsize=10)
ax.set_ylabel("Batch Size", fontweight="bold", fontsize=10)
ax.grid(True, which="both", linewidth=0.3, color="#bbb", linestyle=":")

# Filled contours of the absolute loss. NOTE: alpha=0.0 makes this layer
# fully transparent, so only the line contours below are actually visible.
_fill_cmap = LinearSegmentedColormap.from_list(
    "step_law_loss",
    ["#4a3585", "#785a9c", "#a587b2", "#cba4b3", "#e0bdaa", "#eed7b5", "#f3e8ce"],
)
cf = ax.contourf(LRg, BSg, loss, levels=40, cmap=_fill_cmap, alpha=0.0)

# Line contours of the relative loss, one colour per level, labelled inline
# as "+x.xxx%".
cs = ax.contour(LRg, BSg, rel_pct, levels=LEVELS, colors=contour_colors,
                linewidths=1.6)
fmt = {lv: f"+{lv:.3f}%" for lv in LEVELS}
ax.clabel(cs, inline=True, fmt=fmt, fontsize=7.2, inline_spacing=4)

# Single source of truth for the point markers, used for both the scatter
# artists and their legend proxies so the two cannot drift apart.
# Columns: label, marker, LR factor, BS factor (both relative to
# eta_star / B_star), scatter size, face colour, edge colour, edge width,
# zorder, legend marker size. Rows are in legend order.
_POINT_MARKERS = [
    ("Global Minimum", "X", 0.95, 0.96, 120, "#cf3b3b", "#5a1414", 0.6, 7, 10),
    ("Ours (Step Law)", "*", 1.05, 1.02, 170, "#e8c84a", "#7a5a10", 0.8, 6, 12),
    ("DeepSeek Law", "^", 0.55, 1.4, 110, "#26b6c4", "#114a5a", 0.6, 5, 10),
    ("Porian Law", "s", 1.35, 1.3, 90, "#a73da6", "#3e0d3e", 0.6, 5, 9),
]

leg_handles = []
for (_label, _marker, _lr_factor, _bs_factor, _size, _face, _edge,
     _edge_lw, _zorder, _legend_ms) in _POINT_MARKERS:
    ax.scatter([eta_star * _lr_factor], [B_star * _bs_factor], marker=_marker,
               s=_size, color=_face, edgecolor=_edge, linewidth=_edge_lw,
               zorder=_zorder, label=_label)
    leg_handles.append(
        Line2D([0], [0], marker=_marker, color="none", markerfacecolor=_face,
               markeredgecolor=_edge, markersize=_legend_ms, label=_label)
    )

# Two dashed vertical guides (the "Microsoft Law" legend entry uses the same
# colour and dash style) with a rotated "+2.000%" annotation between them.
_GUIDE_COLOR = "#e6624a"
for _lr_factor in (0.18, 0.21):
    ax.axvline(eta_star * _lr_factor, ls="--", color=_GUIDE_COLOR, lw=1.0,
               alpha=0.85)
ax.text(eta_star * 0.20, 5e4, "+2.000%", color=_GUIDE_COLOR, fontsize=7.0,
        rotation=90, ha="right", va="bottom")

# Line-style legend entries. NOTE(review): "OpenAI Law" is legend-only; no
# matching artist is drawn in this block -- confirm that this is intended.
leg_handles.extend([
    Line2D([0], [0], color=_GUIDE_COLOR, lw=1.2, ls="--",
           label="Microsoft Law"),
    Line2D([0], [0], color="#f0a83a", lw=1.4, ls="-.",
           label="OpenAI Law"),
])

# Colorbar built from a stand-alone ScalarMappable with a fixed 2.06-2.18
# range and its own colormap; it is not tied to the contour artists above.
_bar_cmap = LinearSegmentedColormap.from_list(
    "loss_bar",
    ["#3a5a90", "#6f86b6", "#b9b8d2", "#e8b9b3", "#dc6a72", "#b8395f"],
)
_bar_mappable = plt.cm.ScalarMappable(
    norm=plt.Normalize(vmin=2.06, vmax=2.18),
    cmap=_bar_cmap,
)
cb = fig.colorbar(_bar_mappable, ax=ax, pad=0.012, fraction=0.04, aspect=22)
cb.set_label("Loss", rotation=270, labelpad=12, fontsize=10)
cb.ax.tick_params(labelsize=8)

ax.legend(handles=leg_handles, loc="upper right", frameon=True, fontsize=7.5,
          edgecolor="#888", framealpha=0.95)

# Export the figure and report the output path on stdout.
_OUT_PATH = "predictscale_contour_repro.png"
fig.tight_layout()
fig.savefig(_OUT_PATH, dpi=200, bbox_inches="tight")
print(f"saved {_OUT_PATH}")
