"""
Pareto pipeline — B737-class cross-subsystem case
==================================================

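Loads the published MOSOF Pareto front from data/pareto-b737-ecs.json
(the copy bundled next to this script, or the parent site's data/
directory as a fallback), verifies non-dominance under the published
axis directions, identifies the knee solution, and writes
pareto_clean.csv and pareto_plot.png next to the script.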

Source: Suslu, Ali, Jennions (2026). MOSOF with NDCI: A Cross-Subsystem
Evaluation. Sensors 26(1), 160. doi:10.3390/s26010160

Run:
    python pareto_pipeline.py
"""
from __future__ import annotations

import csv
import json
import os

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend — no GUI needed.
import matplotlib.pyplot as plt


# Directory holding this script; inputs and outputs are resolved against it.
HERE = os.path.dirname(os.path.abspath(__file__))

# Two candidate locations for the Pareto-front JSON:
#   * _BUNDLED -- ./data/ next to the script (standalone clone);
#   * _PARENT  -- ../../data/, so the script also works in-place inside
#     the buraksuslu.com working tree at engineering/pareto-pipeline/.
_BUNDLED = os.path.join(HERE, "data", "pareto-b737-ecs.json")
_PARENT = os.path.normpath(
    os.path.join(HERE, os.pardir, os.pardir, "data", "pareto-b737-ecs.json")
)

# The bundled copy wins whenever it exists on disk.
if os.path.exists(_BUNDLED):
    DATA = _BUNDLED
else:
    DATA = _PARENT


# ── Non-dominance check ─────────────────────────────────────────────────

def dominates(a: np.ndarray, b: np.ndarray, directions: list[int]) -> bool:
    """Return True when point *a* Pareto-dominates point *b*.

    *a* dominates *b* if it is no worse on every objective and strictly
    better on at least one. ``directions`` holds one entry per
    objective: ``1`` means larger is better; any other value (``-1`` by
    convention) means smaller is better. Objectives are paired with
    ``zip``, so comparison stops at the shortest of the three inputs.
    """
    gained_somewhere = False
    for mine, theirs, sense in zip(a, b, directions):
        # Swap the operands for minimised objectives so that the two
        # checks below can always read "bigger wins".
        hi, lo = (mine, theirs) if sense == 1 else (theirs, mine)
        if hi < lo:
            # a is worse on this objective: it cannot dominate b.
            return False
        if hi > lo:
            gained_somewhere = True
    return gained_somewhere


def non_dominated_mask(points: np.ndarray, directions: list[int]) -> np.ndarray:
    """Return a boolean mask that is True where a point is non-dominated.

    Args:
        points: Indexable collection of points, one point per entry;
            each point is compared objective-by-objective by
            ``dominates``.
        directions: Per-objective sense forwarded to ``dominates``
            (``1`` maximise, otherwise minimise).

    Returns:
        A boolean array of length ``len(points)``. Entry ``i`` is True
        when no other point dominates ``points[i]``. Because domination
        requires a strict improvement on at least one objective, exact
        duplicates do not knock each other out.

    Every point is compared against every other point, so the cost is
    quadratic in the number of points.
    """
    n = len(points)
    mask = np.ones(n, dtype=bool)
    for i in range(n):
        candidate = points[i]
        # any() stops at the first dominator, like an explicit break.
        mask[i] = not any(
            dominates(points[j], candidate, directions)
            for j in range(n)
            if j != i
        )
    return mask


# ── Knee finder ─────────────────────────────────────────────────────────

def find_knee(perf: np.ndarray, cost: np.ndarray) -> int:
    """Return the index of the knee on the normalised perf-cost projection.

    Both objectives are min-max scaled (a zero range is replaced by
    1e-9 so the division is always defined). The reference chord joins
    the lowest-performance point to the highest-performance point, and
    the knee is the point with the largest perpendicular distance from
    that chord. Ties resolve to the lowest index.
    """
    def _unit_scale(values):
        floor = values.min()
        return (values - floor) / max(values.max() - floor, 1e-9)

    # One row per point: column 0 is scaled perf, column 1 is scaled cost.
    scaled = np.column_stack((_unit_scale(perf), _unit_scale(cost)))

    # Chord endpoints: the two extremes along the performance axis.
    start = scaled[int(np.argmin(scaled[:, 0]))]
    end = scaled[int(np.argmax(scaled[:, 0]))]
    chord = end - start
    chord_len = np.linalg.norm(chord) or 1e-9  # guard a degenerate chord

    # |2-D cross product| / chord length = perpendicular distance.
    rel = scaled - start
    offsets = np.abs((chord[0] * rel[:, 1] - chord[1] * rel[:, 0]) / chord_len)
    return int(np.argmax(offsets))


# ── Driver ───────────────────────────────────────────────────────────────

def _write_csv(csv_path: str, perf: np.ndarray, cost: np.ndarray,
               mtbf: np.ndarray, nd: np.ndarray, knee: int) -> None:
    """Write one CSV row per point: formatted objectives plus yes/no flags."""
    with open(csv_path, "w", newline="", encoding="utf-8") as fh:
        w = csv.writer(fh)
        w.writerow(["idx", "performance", "cost_kUSD", "mtbf_kh",
                    "non_dominated", "is_knee"])
        for i, (p, c, m, keep) in enumerate(zip(perf, cost, mtbf, nd)):
            w.writerow([i, f"{p:.4f}", f"{c:.2f}", f"{m:.1f}",
                        "yes" if keep else "no",
                        "yes" if i == knee else "no"])


def _save_plot(png_path: str, perf: np.ndarray, cost: np.ndarray,
               mtbf: np.ndarray, knee: int) -> None:
    """Save the perf-vs-cost scatter (MTBF as colour, knee ringed) as a PNG."""
    fig, ax = plt.subplots(figsize=(7, 5), dpi=110)
    try:
        sc = ax.scatter(cost, perf, c=mtbf, cmap="viridis",
                        s=28, edgecolors="#141413", linewidths=0.4, alpha=0.9)
        ax.scatter([cost[knee]], [perf[knee]], s=220, facecolor="none",
                   edgecolor="#b8400b", linewidths=2.0, label="Knee")
        ax.set_xlabel("Cost (kUSD) -- lower is better")
        ax.set_ylabel("Diagnostic performance (NDCI normalised)")
        ax.set_title("Pareto front -- B737 cross-subsystem MOSOF run")
        cbar = fig.colorbar(sc, ax=ax)
        cbar.set_label("MTBF (kilohours)")
        ax.legend(loc="lower right")
        ax.grid(True, alpha=0.25)
        fig.tight_layout()
        fig.savefig(png_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)


def main() -> None:
    """Run the pipeline: load, verify, locate the knee, write CSV and PNG.

    Reads the JSON file at ``DATA``. The keys used are ``points`` (a
    list of mappings keyed by axis id), ``axes`` (each with ``id`` and
    ``direction``; ``"max"`` means maximise, anything else minimise),
    ``_meta`` (``title`` plus optional ``knee_composition``) and the
    optional top-level ``knee_index``. The axes must include the ids
    ``performance``, ``cost`` and ``reliability``.

    A summary is printed to stdout, and ``pareto_clean.csv`` and
    ``pareto_plot.png`` are written into ``HERE``. The published knee
    is used for both outputs when present; otherwise the 2-D heuristic
    knee from ``find_knee`` is used.

    Raises:
        IndexError: If ``knee_index`` is present but is not an integer
            in ``[0, number of points)``. A negative value is rejected
            because it would wrap for the printout and plot, yet never
            match a CSV row.
    """
    with open(DATA, encoding="utf-8") as fh:
        d = json.load(fh)

    points = d["points"]
    axes = d["axes"]
    n = len(points)
    print(f"{n} points loaded")
    print(f"Source: {d['_meta']['title']}\n")

    # Build the (n, number-of-axes) array in the published axis order.
    axis_ids = [a["id"] for a in axes]
    directions = [(+1 if a["direction"] == "max" else -1) for a in axes]
    arr = np.array([[p[k] for k in axis_ids] for p in points], dtype=float)

    # Non-dominance check.
    nd = non_dominated_mask(arr, directions)
    n_nd = int(nd.sum())
    print(f"{n_nd} non-dominated  "
          f"({'all clean' if n_nd == n else f'{n - n_nd} dominated points'})\n")

    # Knee on the perf-cost projection (this script's heuristic).
    perf = arr[:, axis_ids.index("performance")]
    cost = arr[:, axis_ids.index("cost")]
    mtbf = arr[:, axis_ids.index("reliability")]
    heur_knee = find_knee(perf, cost)
    print(f"Heuristic knee (2D perf-cost): idx {heur_knee} "
          f"perf {perf[heur_knee]:.3f}, ${cost[heur_knee]:.0f}k, "
          f"{mtbf[heur_knee]:.0f} kh")

    # Published knee from the canonical metadata (3D measure).
    pub_idx = d.get("knee_index")
    if pub_idx is not None:
        # Fail early with a clear message instead of indexing with a
        # value that is out of range, non-integer, or silently wraps.
        if not (isinstance(pub_idx, int) and 0 <= pub_idx < n):
            raise IndexError(
                f"knee_index {pub_idx!r} is not a valid index into {n} points"
            )
        print(f"Published knee (3D, from thesis Table 4-5): idx {pub_idx} "
              f"perf {perf[pub_idx]:.3f}, ${cost[pub_idx]:.0f}k, "
              f"{mtbf[pub_idx]:.0f} kh")
    kc = d["_meta"].get("knee_composition")
    if kc:
        composition = " * ".join(f"{k} {v}" for k, v in kc.items())
        print(f"Published knee composition: {composition}")

    # Use the published knee for downstream CSV and plot.
    knee = pub_idx if pub_idx is not None else heur_knee

    # Write derived CSV.
    csv_path = os.path.join(HERE, "pareto_clean.csv")
    _write_csv(csv_path, perf, cost, mtbf, nd, knee)

    # Quick-look plot -- perf vs cost with MTBF encoded by colour.
    png_path = os.path.join(HERE, "pareto_plot.png")
    _save_plot(png_path, perf, cost, mtbf, knee)

    print(f"\nWrote {csv_path}")
    print(f"Wrote {png_path}")


# Run the pipeline only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
