"""1x7 small-multiples line plot: Pass@K for Qwen3-4B.

Reproduces the `passk_qwen4b` figure from the PreRL paper
(arXiv:2602.02488). Compares Ours vs GRPO across six math benchmarks
plus the average column (last panel, soft-purple background).

Standalone: run `python prerl_passk_qwen4b.py`. All data is inline; the
figure is written to `passk_qwen4b.png` in the current directory.
"""

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# ──────────────────────────────────────────────────────────────────────
# K values (log2-spaced)
# ──────────────────────────────────────────────────────────────────────
# Sample budgets K = 2**0 .. 2**8, i.e. 1, 2, 4, ..., 256 (nine points).
K_VALUES = [2 ** exponent for exponent in range(9)]

# ──────────────────────────────────────────────────────────────────────
# Data (Qwen3-4B). Each list aligns with K_VALUES above.
# ──────────────────────────────────────────────────────────────────────
# Pass@K scores, indexed as DATA[algorithm][benchmark] -> list of floats.
# Each list holds one value per entry of K_VALUES (nine values, K = 1 .. 256).
# Values are stored as fractions; the plotting code multiplies them by 100
# when formatting the y-axis tick labels under the "Pass@K (%)" axis label.
# An extra "avg" series is added to each algorithm's dict later in the script.
DATA = {
    # Series labelled "Ours" in the legend.
    "Ours": {
        "minerva":  [0.304865, 0.334460, 0.359647, 0.383087, 0.405598, 0.426486, 0.444187, 0.460351, 0.475689],
        "aime24":   [0.509778, 0.603747, 0.677154, 0.727426, 0.760887, 0.788026, 0.817054, 0.842754, 0.861764],
        "aime25":   [0.432444, 0.517194, 0.585988, 0.634240, 0.667442, 0.696793, 0.731475, 0.765654, 0.794394],
        "amc23":    [0.899083, 0.936672, 0.950135, 0.955258, 0.959838, 0.966073, 0.972210, 0.974847, 0.975000],
        "math500":  [0.897987, 0.917559, 0.927944, 0.934850, 0.940064, 0.944716, 0.949121, 0.953031, 0.957036],
        "olympiad": [0.413951, 0.437050, 0.454850, 0.469750, 0.483052, 0.494420, 0.503774, 0.511872, 0.518588],
    },
    # Series labelled "GRPO" in the legend (the comparison baseline).
    "GRPO": {
        "minerva":  [0.303199, 0.336934, 0.364681, 0.388603, 0.409653, 0.428501, 0.444822, 0.460563, 0.479799],
        "aime24":   [0.463222, 0.543437, 0.617671, 0.675826, 0.712266, 0.740692, 0.768174, 0.797021, 0.827642],
        "aime25":   [0.403889, 0.490693, 0.560218, 0.608361, 0.642509, 0.672837, 0.703557, 0.733968, 0.761073],
        "amc23":    [0.880500, 0.928585, 0.951694, 0.960918, 0.967412, 0.972857, 0.974867, 0.975000, 0.975000],
        "math500":  [0.892120, 0.914349, 0.925825, 0.932961, 0.938154, 0.942254, 0.946080, 0.950202, 0.953604],
        "olympiad": [0.403412, 0.427072, 0.445834, 0.462269, 0.476845, 0.490005, 0.501309, 0.510334, 0.516221],
    },
}

# ──────────────────────────────────────────────────────────────────────
# Compute AVG per algo: unweighted mean over the six benchmarks at each K
# ──────────────────────────────────────────────────────────────────────
# Benchmarks that contribute to the average, in summation order.
DATASETS = ["aime25", "aime24", "amc23", "math500", "olympiad", "minerva"]
for curves in DATA.values():
    # One unweighted mean per K index, stored next to the per-benchmark
    # series under the "avg" key.
    curves["avg"] = [
        sum(curves[name][idx] for name in DATASETS) / len(DATASETS)
        for idx, _ in enumerate(K_VALUES)
    ]

# ──────────────────────────────────────────────────────────────────────
# Style
# ──────────────────────────────────────────────────────────────────────
# Per-algorithm line style. Dict order is the plotting order and therefore
# the legend order; the higher zorder draws "Ours" on top of "GRPO".
ALGO_STYLES = {
    "Ours": dict(color="#9370DB", marker="*", markersize=9, lw=2.2, zorder=3),
    "GRPO": dict(color="#444444", marker="^", markersize=6, lw=2.0, zorder=2),
}

# Panel order, left to right. Each panel title is its key in upper case.
COL_KEYS = [
    "aime25",
    "aime24",
    "amc23",
    "math500",
    "olympiad",
    "minerva",
    "avg",
]
COL_TITLES = [key.upper() for key in COL_KEYS]

# ──────────────────────────────────────────────────────────────────────
# Plot
# ──────────────────────────────────────────────────────────────────────
# One row of seven panels, one per entry of COL_KEYS.
fig, axes = plt.subplots(1, 7, figsize=(30, 4.5))
plt.subplots_adjust(wspace=0.28, bottom=0.22, top=0.88)

# X ticks shown on every panel (every other entry of K_VALUES).
X_TICKS = [1, 4, 16, 64, 256]


def _percent_tick_label(value, _pos):
    """Return the tick label for a fractional y value, expressed in percent.

    The ``g`` presentation type leaves whole percents unchanged ("30") but
    keeps half-percent tick positions exact ("32.5"). A fixed ``.0f`` would
    round those to a neighbouring integer, so evenly spaced ticks would carry
    unevenly spaced labels.
    """
    return f"{value * 100:g}"


for ax, ds, title in zip(axes, COL_KEYS, COL_TITLES):
    # Draw both algorithms; ALGO_STYLES order determines legend order.
    for algo, style in ALGO_STYLES.items():
        ax.plot(
            K_VALUES, DATA[algo][ds],
            color=style["color"], marker=style["marker"],
            markersize=style["markersize"], linewidth=style["lw"],
            zorder=style["zorder"], label=algo,
        )

    # Log2 x-axis so the power-of-two K values are evenly spaced.
    ax.set_xscale("log", base=2)
    ax.set_xticks(X_TICKS)
    ax.set_xticklabels(X_TICKS, fontsize=9)
    ax.minorticks_off()
    # A fresh formatter per axis: matplotlib formatters are bound to the axis
    # they are installed on, so instances are not shared between panels.
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(_percent_tick_label))
    ax.grid(True, linestyle="-", alpha=0.5, color="lightgrey")
    ax.set_title(title, fontsize=14, fontweight="bold", pad=8)
    ax.set_xlabel("Number of Samples $K$", fontsize=10)

    # Soft-purple background marks the aggregate panel.
    if ds == "avg":
        ax.set_facecolor("#f7f4fb")

# Only the left-most panel carries the y label and the legend.
first_ax = axes[0]
first_ax.set_ylabel("Pass@K (%)", fontsize=12)
first_ax.legend(loc="lower right", frameon=True, fontsize=10,
                fancybox=True, framealpha=0.9)

fig.suptitle("Qwen3-4B", fontsize=14, fontweight="bold", y=0.98)

# Written to the current working directory.
fig.savefig("passk_qwen4b.png", dpi=200, bbox_inches="tight")
print("Saved passk_qwen4b.png")
