"""Reproduction of Step Law (Predictable Scale) Figure 2 — 3D loss surface.

Anonymised reproduction of arXiv:2503.04715 Figure 2 right column:
twin 3D surfaces of training loss as a function of (learning rate, batch
size). The convex, bowl-shaped surface highlights a single basin around the
Step-Law optimum. Both panels render the same surface: the top panel is
titled for LR-axis slices and the bottom panel for BS-axis slices, and they
differ only in camera azimuth.

All data is synthesized from a quadratic-in-log-space model anchored at
the published Step-Law optimum (eta* = 1.79 N^-0.713 D^0.307,
B* = 0.58 D^0.571). No checkpoints or training logs are needed.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers 3D)

# Global typography for every text element in the figure: a small serif face.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 9

# Scale point at which the Step-Law optimum is evaluated; N and D are the
# inputs to the eta* and B* power laws quoted in the module docstring.
N, D = 1.0e9, 1.0e11
eta_star = 1.79 * N ** -0.713 * D ** 0.307
B_star = 0.58 * D ** 0.571


def _log_span(center, low_factor, high_factor, count=70):
    """Return *count* log-spaced samples from center*low_factor to center*high_factor."""
    start = np.log10(center * low_factor)
    stop = np.log10(center * high_factor)
    return np.logspace(start, stop, count)


# LR spans 0.2x..4.5x eta*, BS spans 0.05x..18x B*. np.meshgrid uses its
# default "xy" indexing, so rows of LRg/BSg follow BS and columns follow LR.
LR = _log_span(eta_star, 0.20, 4.5)
BS = _log_span(B_star, 0.05, 18)
LRg, BSg = np.meshgrid(LR, BS)

# Offsets from the optimum in log10 space. The quadratic form below is
# positive-definite (4 * 0.040 * 0.018 > 0.012 ** 2), so the loss attains its
# minimum value 0.74 exactly at x = y = 0, i.e. at (eta*, B*).
x = np.log10(LRg) - np.log10(eta_star)
y = np.log10(BSg) - np.log10(B_star)
_LOSS_FLOOR = 0.74
_COEF_LR, _COEF_BS, _COEF_CROSS = 0.040, 0.018, 0.012
loss = _LOSS_FLOOR + _COEF_LR * x ** 2 + _COEF_BS * y ** 2 + _COEF_CROSS * x * y

# Colour stops from the low end of the colormap (deep red, first entry) through
# pale neutrals to the high end (blue, last entry).
_STEP_LAW_COLORS = [
    "#7c1f1c",
    "#a8443e",
    "#c66e60",
    "#dca09b",
    "#dcc1c9",
    "#bdc8df",
    "#94a8d3",
    "#5e7bc0",
]
CMAP = LinearSegmentedColormap.from_list("step_law_3d", _STEP_LAW_COLORS)

# Log10 plotting coordinates shared by both panels (computed once, not per panel).
log_lr = np.log10(LRg)
log_bs = np.log10(BSg)


def _draw_loss_panel(figure, position, title, azim):
    """Add one 3D loss-surface panel and its colorbar to *figure*.

    Both panels plot the same (log10 LR, log10 BS, loss) surface with identical
    styling; only the subplot slot, the title and the camera azimuth differ.
    Building them here keeps the two panels from drifting apart.

    Args:
        figure: Matplotlib figure that receives the subplot and the colorbar.
        position: 1-based slot in the 2x1 subplot grid.
        title: Italic panel title.
        azim: Camera azimuth in degrees, passed to ``view_init``.

    Returns:
        ``(ax, surf, cbar)``: the 3D axes, the surface artist and its colorbar.
    """
    ax = figure.add_subplot(2, 1, position, projection="3d")
    surf = ax.plot_surface(
        log_lr, log_bs, loss,
        cmap=CMAP, edgecolor="white", linewidth=0.06, alpha=0.95, antialiased=True,
    )
    ax.set_xlabel("Learning Rate (log10)", fontsize=8.5, labelpad=6)
    ax.set_ylabel("Batch Size (log10)", fontsize=8.5, labelpad=6)
    ax.set_zlabel("Smooth Loss (log)", fontsize=8.5, labelpad=4)
    ax.set_title(title, fontsize=10, style="italic", pad=6)
    ax.view_init(elev=24, azim=azim)
    ax.set_box_aspect((1.2, 1.0, 0.55))

    cbar = figure.colorbar(surf, ax=ax, shrink=0.55, pad=0.07, aspect=22)
    cbar.ax.tick_params(labelsize=7)
    return ax, surf, cbar


fig = plt.figure(figsize=(8.5, 8.5))

# NOTE(review): the titles mention "4 ... slices", but each panel draws the
# full surface; the panels differ only in camera azimuth (-58 vs -128 deg).
# Top panel: LR-axis view.
ax1, surf1, cb1 = _draw_loss_panel(fig, 1, "4 LR-axis slices", azim=-58)
# Bottom panel: BS-axis view (same surface, different camera).
ax2, surf2, cb2 = _draw_loss_panel(fig, 2, "4 BS-axis slices", azim=-128)

# Operate on `fig` explicitly rather than relying on pyplot's current figure.
fig.subplots_adjust(hspace=0.18, top=0.96, bottom=0.04, left=0.0, right=0.92)
OUTPUT_PATH = "predictscale_3dloss_repro.png"
fig.savefig(OUTPUT_PATH, dpi=200, bbox_inches="tight")
print(f"saved {OUTPUT_PATH}")
