#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Raman Substrate Analysis GUI — control + paired modes (v3)
==========================================================

Mode 1: Control-only (legacy v1.1 behavior)
Paste:
    x    baseline    substrate1    substrate2 ...

Mode 2: Paired analysis
Paste a header row with triplets:
    x    PairA__baseline    PairA__control    PairA__sample    PairB__baseline    PairB__control    PairB__sample ...

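The mode is auto-detected from the first row: paired mode is used when any header
contains __baseline, __control or __sample (case-insensitive); otherwise control mode.
Columns may be separated by tabs, commas, semicolons, pipes or whitespace; the
delimiter is detected from the first row.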

Outputs:
- TXT report
- TSV export
  * Control mode: residual columns
  * Paired mode: d_control, d_sample, delta, smooth_bg, peaks for each pair
"""

import io
import re
import tkinter as tk
from tkinter import ttk, messagebox, filedialog
import numpy as np

APP_TITLE = "Raman Substrate Analysis GUI v3"

# Multiplier for the peak index PI: fraction of points with |residual| > PI_K * sigma_ref.
PI_K = 3.0
# Maximum number of substrate columns analysed in control mode; further columns are ignored.
MAX_SUBSTRATES = 20
# Backward-compatible alias for the original (misspelled) constant name.
MAX_SUBPPRATES = MAX_SUBSTRATES
# Classification thresholds applied to each fit. The *_pi_* values are fractions
# (0.005 == 0.5 %), not percentages.
THRESHOLDS = {
    "clean_rms_ratio_max": 1.20,
    "clean_pi_max": 0.005,
    "intrinsic_rms_ratio_min": 1.90,
    "intrinsic_pi_min": 0.025,
}

def robust_std_mad(y: np.ndarray) -> float:
    """Estimate the standard deviation of ``y`` from its median absolute deviation.

    The MAD is scaled by 1.4826, the factor that makes it match the standard
    deviation for normally distributed data while staying robust to outliers.
    """
    values = np.asarray(y, dtype=float)
    deviations = np.abs(values - np.median(values))
    return 1.4826 * np.median(deviations)

def detect_delimiter(line: str) -> str:
    """Guess the column delimiter used in ``line``.

    A tab wins whenever one is present. Otherwise the most frequent of ``,``,
    ``;`` and ``|`` is returned, with ties resolved in that order. When none of
    them occurs, a single space is returned, meaning "split on any whitespace".
    """
    if '\t' in line:
        return '\t'
    chosen, chosen_count = ' ', 0
    for candidate in (',', ';', '|'):
        occurrences = line.count(candidate)
        if occurrences > chosen_count:  # strict: earlier candidate keeps a tie
            chosen, chosen_count = candidate, occurrences
    return chosen

def split_line(line: str, delim: str):
    """Split ``line`` into fields.

    ``delim == ' '`` means "any run of whitespace" and yields no empty fields.
    For any other delimiter the fields are stripped (empty fields are kept);
    if that produces fewer than two fields, the line is re-split on any run of
    comma, tab, semicolon, pipe or space characters instead.
    """
    if delim != ' ':
        fields = [field.strip() for field in line.split(delim)]
        if len(fields) >= 2:
            return fields
        # The expected delimiter is absent on this line: try the common ones.
        return [tok for tok in re.split(r"[,\t;| ]+", line.strip()) if tok]
    return [tok for tok in re.split(r"\s+", line.strip()) if tok]

def is_floatish(s: str) -> bool:
    """Return True if ``str(s)`` parses as a float, treating ',' as a decimal point."""
    try:
        float(str(s).replace(",", "."))
    except Exception:
        return False
    else:
        return True

def moving_average(y, window=51):
    """Return a centered moving average of ``y`` with the same length as ``y``.

    Args:
        y: 1-D sequence of numbers.
        window: requested window length. It is raised to at least 5 and made
            odd. If ``y`` is shorter than the window, the window shrinks to an
            odd length close to ``len(y)`` (at least 3).

    Returns:
        A float array of length ``len(y)``. Ends are handled by repeating the
        first/last value (edge padding). An empty input yields an empty array.
    """
    y = np.asarray(y, dtype=float)
    if y.size == 0:
        # np.pad(mode="edge") cannot extend an empty axis; there is nothing to smooth.
        return y.copy()
    w = max(5, int(window))
    if w % 2 == 0:
        w += 1  # odd window keeps the average centered on each sample
    if len(y) < w:
        w = max(3, (len(y) // 2) * 2 + 1)
    pad = w // 2
    # With an odd w, padding by w//2 on both sides makes the "valid" convolution
    # output exactly len(y) samples.
    yp = np.pad(y, pad_width=pad, mode="edge")
    kernel = np.ones(w) / w
    return np.convolve(yp, kernel, mode="valid")

def fit_one(x: np.ndarray, b: np.ndarray, s: np.ndarray):
    """Project spectrum ``s`` onto baseline ``b`` plus an affine term and score the residual.

    Fits ``s ≈ alpha·b + beta + gamma·x`` by least squares and characterises the
    residual ``d = s − fit`` with MAD-based noise scales. The residual is then
    classified against THRESHOLDS as "CLEAN", "INTRINSIC SIGNAL PRESENT" or
    "MINOR PARASITICS".

    Returns:
        dict with alpha, beta, gamma (fit coefficients), residual (array d),
        sigma_b, sigma_ref (noise scales), rms_d, rms_ratio, l1_d, pi (metrics)
        and classification (label string).
    """
    xs = np.asarray(x, dtype=float)
    base = np.asarray(b, dtype=float)
    spec = np.asarray(s, dtype=float)

    design = np.column_stack([base, np.ones(len(xs)), xs])
    coeffs = np.linalg.lstsq(design, spec, rcond=None)[0]
    alpha, beta, gamma = coeffs.tolist()
    resid = spec - design @ coeffs
    abs_resid = np.abs(resid)

    # Baseline scale and a first, unclipped residual scale.
    sigma_b = robust_std_mad(base)
    sigma_d0 = robust_std_mad(resid)
    # Re-estimate the residual scale using only points within 3 sigma.
    inliers = abs_resid <= 3.0 * max(sigma_b, sigma_d0, 1e-12)
    sigma_d_clip = robust_std_mad(resid[inliers]) if inliers.any() else sigma_d0
    sigma_ref = max(sigma_b, sigma_d_clip, 1e-12)

    rms_d = float(np.sqrt(np.mean(resid ** 2)))
    l1_d = float(np.mean(abs_resid))
    pi = float(np.mean(abs_resid > PI_K * sigma_ref))
    # sigma_ref is floored at 1e-12, but a NaN scale would still fail "> 0".
    if sigma_ref > 0:
        rms_ratio = rms_d / sigma_ref
    else:
        rms_ratio = float("inf")

    is_clean = (rms_ratio <= THRESHOLDS["clean_rms_ratio_max"]
                and pi <= THRESHOLDS["clean_pi_max"])
    is_intrinsic = (rms_ratio >= THRESHOLDS["intrinsic_rms_ratio_min"]
                    or pi >= THRESHOLDS["intrinsic_pi_min"])
    if is_clean:
        label = "CLEAN"
    elif is_intrinsic:
        label = "INTRINSIC SIGNAL PRESENT"
    else:
        label = "MINOR PARASITICS"

    return {
        "alpha": float(alpha),
        "beta": float(beta),
        "gamma": float(gamma),
        "residual": resid,
        "sigma_b": float(sigma_b),
        "sigma_ref": float(sigma_ref),
        "rms_d": float(rms_d),
        "rms_ratio": float(rms_ratio),
        "l1_d": float(l1_d),
        "pi": float(pi),
        "classification": label,
    }

def paired_one(x, baseline, control, sample, window=51):
    """Compare a sample spectrum against its control after baseline projection.

    Both spectra are fitted against the same baseline with fit_one. The
    difference of their residuals (delta) is split into a moving-average
    background (smooth_bg, using ``window``) and the remaining peak component
    (peaks = delta − smooth_bg).

    Returns:
        dict with control_fit, sample_fit, the delta / smooth_bg / peaks arrays,
        and the metrics useful_rms, useful_l1, background_rms, substrate_rms
        (the control fit's rms_d), parasite_rss and spr. spr is
        useful_rms / parasite_rss, or inf when parasite_rss is 0.
    """
    control_fit = fit_one(x, baseline, control)
    sample_fit = fit_one(x, baseline, sample)

    delta = sample_fit["residual"] - control_fit["residual"]
    smooth_bg = moving_average(delta, window=window)
    peaks = delta - smooth_bg

    def _rms(values):
        return float(np.sqrt(np.mean(values ** 2)))

    useful_rms = _rms(peaks)
    useful_l1 = float(np.mean(np.abs(peaks)))
    background_rms = _rms(smooth_bg)
    substrate_rms = float(control_fit["rms_d"])
    # Root-sum-square of the two parasitic contributions.
    parasite_rss = float(np.sqrt(background_rms ** 2 + substrate_rms ** 2))
    if parasite_rss > 0:
        spr = useful_rms / parasite_rss
    else:
        spr = float("inf")

    return dict(
        control_fit=control_fit,
        sample_fit=sample_fit,
        delta=delta,
        smooth_bg=smooth_bg,
        peaks=peaks,
        useful_rms=useful_rms,
        useful_l1=useful_l1,
        background_rms=background_rms,
        substrate_rms=substrate_rms,
        parasite_rss=parasite_rss,
        spr=spr,
    )

def parse_text_table(text: str):
    """Split pasted text into rows of string fields.

    Blank lines are ignored. The delimiter is detected from the first
    non-blank line and applied to every line.

    Raises:
        ValueError: if fewer than two non-blank lines are present.
    """
    non_blank = [line for line in text.strip().splitlines() if line.strip()]
    if len(non_blank) < 2:
        raise ValueError("Please paste at least two rows.")
    delimiter = detect_delimiter(non_blank[0])
    return [split_line(line, delimiter) for line in non_blank]

def parse_control_mode(rows):
    """Parse control-mode rows: x, baseline, substrate1, substrate2, ...

    The first row is treated as a header when any cell from the third column
    onward is not numeric. Data rows are skipped if they have fewer than three
    fields or contain a field that does not parse as a number (decimal commas
    are accepted). At most MAX_SUBPPRATES substrate columns are kept.

    Substrate names come from the header (whitespace collapsed, cut to 30
    characters) or default to ``Substrate_<n>``. Duplicate names get a " (2)",
    " (3)", ... suffix, so a substrate is not silently overwritten when the
    caller keys its results by name.

    Returns:
        dict with "mode" ("control"), "x" and "b" (1-D float arrays) and
        "subs" (list of {"name": str, "values": 1-D float array}).

    Raises:
        ValueError: if fewer than three numeric rows are found.
    """
    header_present = any(not is_floatish(tok) for tok in rows[0][2:])
    headers = rows[0] if header_present else None
    data_rows = rows[1:] if header_present else rows

    # Drop trailing empty fields (e.g. from a trailing delimiter). Otherwise a
    # single such row widens max_cols, every other row gets padded with "" and
    # fails float conversion, and the whole table is rejected.
    trimmed = []
    for r in data_rows:
        r = list(r)
        while r and r[-1] == "":
            r.pop()
        trimmed.append(r)

    numeric = []
    max_cols = max((len(r) for r in trimmed), default=0)
    for r in trimmed:
        if len(r) < 3:
            continue
        padded = r + [""] * (max_cols - len(r))
        try:
            numeric.append([float(v.replace(",", ".")) for v in padded])
        except ValueError:
            continue  # incomplete or non-numeric row
    if len(numeric) < 3:
        raise ValueError("Not enough numeric rows found.")
    arr = np.asarray(numeric, dtype=float)
    x = arr[:, 0]
    b = arr[:, 1]

    subs = []
    used = set()
    for j in range(2, min(arr.shape[1], 2 + MAX_SUBPPRATES)):
        default_name = f"Substrate_{j-1}"
        if headers and j < len(headers) and not is_floatish(headers[j]):
            name = str(headers[j]).strip()
        else:
            name = default_name
        # An empty header cell would otherwise produce an empty name.
        name = re.sub(r"\s+", " ", name).strip()[:30] or default_name
        unique, k = name, 2
        while unique in used:
            suffix = f" ({k})"
            unique = name[:30 - len(suffix)] + suffix
            k += 1
        used.add(unique)
        subs.append({"name": unique, "values": arr[:, j]})
    return {"mode": "control", "x": x, "b": b, "subs": subs}

def parse_paired_mode(rows):
    """Parse paired-mode rows whose header row names column triplets.

    Header cells of the form ``<Pair>__baseline``, ``<Pair>__control`` and
    ``<Pair>__sample`` define one pair. The suffix is case-insensitive and is
    split on the last ``__``. Prefixes lacking any of the three columns are
    ignored. Column 0 is always x. Data rows containing a missing or
    non-numeric field are skipped; decimal commas are accepted.

    Returns:
        dict with "mode" ("paired"), "x" (1-D float array) and "pairs" (list
        of {"name", "baseline", "control", "sample"} dicts of 1-D arrays).

    Raises:
        ValueError: if fewer than three numeric rows are found, if a complete
            triplet refers to a column that has no data, or if no complete
            triplet exists.
    """
    headers = rows[0]
    # Drop trailing empty fields (e.g. a trailing tab) so one such row does not
    # widen max_cols and make every other row fail float conversion.
    data_rows = []
    for r in rows[1:]:
        r = list(r)
        while r and r[-1] == "":
            r.pop()
        data_rows.append(r)
    max_cols = max((len(r) for r in data_rows), default=0)
    numeric = []
    for r in data_rows:
        padded = r + [""] * (max_cols - len(r))
        try:
            numeric.append([float(v.replace(",", ".")) for v in padded])
        except ValueError:
            continue  # incomplete or non-numeric row
    if len(numeric) < 3:
        raise ValueError("Not enough numeric rows found for paired mode.")
    arr = np.asarray(numeric, dtype=float)
    n_cols = arr.shape[1]

    # Group header columns by prefix: {prefix: {suffix: column_index}}.
    groups = {}
    for j, h in enumerate(headers):
        hh = str(h).strip()
        if "__" not in hh:
            continue
        prefix, suffix = hh.rsplit("__", 1)
        groups.setdefault(prefix, {})[suffix.lower()] = j

    roles = ("baseline", "control", "sample")
    pairs = []
    for prefix, mapping in groups.items():
        if not all(k in mapping for k in roles):
            continue
        # A header wider than the data would otherwise surface as a bare IndexError.
        absent = [str(headers[mapping[k]]) for k in roles if mapping[k] >= n_cols]
        if absent:
            raise ValueError(
                f"Pair '{prefix}': column(s) {', '.join(absent)} have no data "
                f"(numeric rows only have {n_cols} columns)."
            )
        pairs.append({
            "name": prefix,
            "baseline": arr[:, mapping["baseline"]],
            "control": arr[:, mapping["control"]],
            "sample": arr[:, mapping["sample"]],
        })
    if not pairs:
        raise ValueError("No valid paired triplets found. Expected headers like Pair__baseline / Pair__control / Pair__sample.")
    return {"mode": "paired", "x": arr[:, 0], "pairs": pairs}

def detect_mode(rows):
    """Return "paired" if any first-row cell names a paired column, else "control".

    A cell names a paired column when it contains "__baseline", "__control"
    or "__sample", compared case-insensitively.
    """
    markers = ("__baseline", "__control", "__sample")
    for cell in rows[0]:
        lowered = str(cell).lower()
        if any(marker in lowered for marker in markers):
            return "paired"
    return "control"

def make_control_report(results):
    """Format control-mode results as a fixed-width plain-text report.

    Args:
        results: mapping of substrate name -> result dict returned by fit_one.

    Returns:
        The report text: a table with one row per substrate, followed by a
        methodology summary. The summary's numbers are taken from PI_K and
        THRESHOLDS, so the text stays in sync if those constants change.
    """
    cols = ["Substrate", "α", "β", "γ", "σ_b", "σ_ref", "RMS_d", "RMS_d/σ_ref", "L1_d",
            f"PI(>|{PI_K}·σ_ref|) %", "Class"]
    name_len = max((len(k) for k in results), default=0)  # no crash on empty results
    min_widths = [max(11, min(30, name_len + 2)), 10, 10, 10, 10, 10, 10, 14, 10, 16, 24]
    rows = [
        [name, f"{res['alpha']:.6g}", f"{res['beta']:.6g}", f"{res['gamma']:.6g}",
         f"{res['sigma_b']:.6g}", f"{res['sigma_ref']:.6g}", f"{res['rms_d']:.6g}",
         f"{res['rms_ratio']:.4g}", f"{res['l1_d']:.6g}", f"{100*res['pi']:.3f}", res["classification"]]
        for name, res in results.items()
    ]
    # The nominal widths are minimums. A column is widened when its header or
    # any cell would otherwise touch the next column, keeping at least one
    # separating space. Examples: the PI header, 30-character names, or
    # exponent-form numbers such as -1.23457e-05.
    table = [cols] + rows
    widths = [max([base] + [len(r[i]) + 1 for r in table]) for i, base in enumerate(min_widths)]

    def fmt(cells):
        return "".join(c.ljust(w) for c, w in zip(cells, widths))

    header = fmt(cols)
    lines = ["Raman baseline-projection report", "=" * 78, header, "-" * len(header)]
    lines.extend(fmt(r) for r in rows)
    t = THRESHOLDS
    lines += [
        "",
        "Methodology Summary:",
        "  Fit s ≈ α·b + β·1 + γ·x (least squares) per substrate; residual d = s − (α·b + β + γ·x).",
        "  σ_b from MAD(b). Stabilized σ_ref = max(MAD(b), MAD(d after 3σ clip)).",
        f"  Metrics: RMS_d, RMS_d/σ_ref, L1_d, PI (fraction of |d| > {PI_K:g}·σ_ref).",
        f"  Classification: CLEAN if (RMS_d/σ_ref ≤ {t['clean_rms_ratio_max']:.2f}) AND (PI ≤ {100 * t['clean_pi_max']:g}%);",
        f"                 INTRINSIC if (RMS_d/σ_ref ≥ {t['intrinsic_rms_ratio_min']:.2f}) OR (PI ≥ {100 * t['intrinsic_pi_min']:g}%);",
        "                 otherwise MINOR PARASITICS.",
    ]
    return "\n".join(lines)

def make_paired_report(results):
    """Format paired-mode results as a fixed-width plain-text report.

    Args:
        results: mapping of pair name -> result dict returned by paired_one.

    Returns:
        The report text: a table with one row per pair, followed by a
        methodology summary.
    """
    cols = ["Pair", "α_control", "α_sample", "Control class", "Useful RMS", "Background RMS",
            "Substrate RMS", "Parasite RSS", "SPR", "Useful L1"]
    name_len = max((len(k) for k in results), default=0)  # no crash on empty results
    min_widths = [max(12, min(34, name_len + 2)), 10, 10, 18, 12, 15, 14, 14, 10, 12]
    rows = [
        [name, f"{res['control_fit']['alpha']:.6g}", f"{res['sample_fit']['alpha']:.6g}",
         res['control_fit']['classification'], f"{res['useful_rms']:.6g}", f"{res['background_rms']:.6g}",
         f"{res['substrate_rms']:.6g}", f"{res['parasite_rss']:.6g}", f"{res['spr']:.4g}", f"{res['useful_l1']:.6g}"]
        for name, res in results.items()
    ]
    # The nominal widths are minimums. A column is widened when its header or
    # any cell would otherwise touch the next column, keeping at least one
    # separating space. Examples: the 24-character "INTRINSIC SIGNAL PRESENT"
    # in the 18-wide class column, or long pair names.
    table = [cols] + rows
    widths = [max([base] + [len(r[i]) + 1 for r in table]) for i, base in enumerate(min_widths)]

    def fmt(cells):
        return "".join(c.ljust(w) for c, w in zip(cells, widths))

    header = fmt(cols)
    lines = ["Raman paired-analysis report", "=" * 78, header, "-" * len(header)]
    lines.extend(fmt(r) for r in rows)
    lines += [
        "",
        "Methodology Summary:",
        "  For each pair, the control and sample spectra are baseline-projected separately against the supplied baseline.",
        "  The paired residual is Δ = d_sample − d_control. Δ is then decomposed into a smooth background term",
        "  (moving average) and a peak-rich component. Useful Raman signal is quantified as RMS(peaks).",
        "  Parasitic contribution is quantified as RSS(RMS(background), RMS(control residual)).",
        "  SPR = Useful RMS / Parasite RSS. Larger SPR indicates that the structured sample contribution dominates",
        "  over smooth background and substrate-induced parasitics.",
    ]
    return "\n".join(lines)

class App(tk.Tk):
    """Main window: paste a table, analyze it, preview the report and save results.

    State kept between callbacks (reset by on_clear):
        last_mode: "control" or "paired" for the last successful analysis, else None.
        last_results: per-substrate (control) or per-pair (paired) result dicts.
        last_data: the parsed input dict the results were computed from.
    """

    def __init__(self):
        super().__init__()
        self.title(APP_TITLE)
        self.geometry("1100x760")
        self.minsize(920, 620)
        self.last_mode = None
        self.last_results = None
        self.last_data = None
        self._build_ui()

    def _build_ui(self):
        """Create the toolbar, status bar, input pane and results pane."""
        top = ttk.Frame(self, padding=8)
        top.pack(side=tk.TOP, fill=tk.X)
        ttk.Button(top, text="Analyze", command=self.on_analyze).pack(side=tk.LEFT)
        self.btn_export = ttk.Button(top, text="Export TSV", command=self.on_export, state=tk.DISABLED)
        self.btn_export.pack(side=tk.LEFT, padx=(8, 0))
        self.btn_report = ttk.Button(top, text="Save Report (TXT)", command=self.on_report, state=tk.DISABLED)
        self.btn_report.pack(side=tk.LEFT, padx=(8, 0))
        ttk.Button(top, text="Clear", command=self.on_clear).pack(side=tk.LEFT, padx=(8, 0))

        # The packer allocates space in packing order. The status bar must
        # therefore be packed before the results pane (also side=BOTTOM), or it
        # ends up between the two panes instead of at the bottom of the window.
        self.status = tk.StringVar(value="Ready.")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=(8, 4)).pack(side=tk.BOTTOM, fill=tk.X)

        mid = ttk.LabelFrame(self, text="Paste your table here:", padding=8)
        mid.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=8, pady=(0, 8))
        self.txt_input = self._scrolled_text(mid, wrap=tk.NONE, height=20, undo=True)

        bot = ttk.LabelFrame(self, text="Results / report preview:", padding=8)
        bot.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True, padx=8, pady=(0, 8))
        # No wrapping, so the fixed-width report table stays aligned; scroll instead.
        self.txt_results = self._scrolled_text(bot, wrap=tk.NONE, height=14)

    @staticmethod
    def _scrolled_text(parent, **text_options):
        """Create a Text widget in ``parent`` with horizontal and vertical scrollbars."""
        text = tk.Text(parent, **text_options)
        xscroll = ttk.Scrollbar(parent, orient="horizontal", command=text.xview)
        yscroll = ttk.Scrollbar(parent, orient="vertical", command=text.yview)
        # Scrollbars are packed before the expanding Text; packed after it, they
        # only get the leftover cavity below the text.
        xscroll.pack(side=tk.BOTTOM, fill=tk.X)
        yscroll.pack(side=tk.RIGHT, fill=tk.Y)
        text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
        text.configure(xscrollcommand=xscroll.set, yscrollcommand=yscroll.set)
        return text

    def on_clear(self):
        """Empty both text panes, forget the last analysis and disable saving."""
        self.txt_input.delete("1.0", tk.END)
        self.txt_results.delete("1.0", tk.END)
        self.last_mode = None
        self.last_results = None
        self.last_data = None
        self.btn_export.configure(state=tk.DISABLED)
        self.btn_report.configure(state=tk.DISABLED)
        self.status.set("Cleared.")

    def on_analyze(self):
        """Parse the pasted table, run the analysis for its mode and show the report."""
        txt = self.txt_input.get("1.0", tk.END)
        if not txt.strip():
            messagebox.showinfo(APP_TITLE, "Please paste your table first.")
            return
        try:
            rows = parse_text_table(txt)
            mode = detect_mode(rows)
            data = parse_paired_mode(rows) if mode == "paired" else parse_control_mode(rows)
        except Exception as e:  # UI boundary: report any parse failure to the user
            messagebox.showerror(APP_TITLE, f"Parse error:\n{e}")
            return

        try:
            if data["mode"] == "control":
                results = {sub["name"]: fit_one(data["x"], data["b"], sub["values"])
                           for sub in data["subs"]}
                report = make_control_report(results)
            else:
                results = {pair["name"]: paired_one(data["x"], pair["baseline"], pair["control"], pair["sample"])
                           for pair in data["pairs"]}
                report = make_paired_report(results)
        except Exception as e:  # UI boundary: report any numerical failure to the user
            messagebox.showerror(APP_TITLE, f"Analysis error:\n{e}")
            return

        self.last_mode = data["mode"]
        self.last_results = results
        self.last_data = data
        self.txt_results.delete("1.0", tk.END)
        self.txt_results.insert(tk.END, report)
        self.btn_export.configure(state=tk.NORMAL)
        self.btn_report.configure(state=tk.NORMAL)
        self.status.set(f"Analysis complete ({self.last_mode} mode).")

    def _build_tsv(self):
        """Return the TSV text for the last analysis (one row per x value)."""
        x = self.last_data["x"]
        names = list(self.last_results)
        out = io.StringIO()
        if self.last_mode == "control":
            out.write("\t".join(["x"] + [f"residual_{name}" for name in names]) + "\n")
            for i, xi in enumerate(x):
                row = [f"{xi}"] + [f"{self.last_results[name]['residual'][i]}" for name in names]
                out.write("\t".join(row) + "\n")
        else:
            header = ["x"]
            for name in names:
                header += [f"{name}__d_control", f"{name}__d_sample", f"{name}__delta", f"{name}__smooth_bg", f"{name}__peaks"]
            out.write("\t".join(header) + "\n")
            for i, xi in enumerate(x):
                row = [f"{xi}"]
                for name in names:
                    rr = self.last_results[name]
                    row += [f"{rr['control_fit']['residual'][i]}", f"{rr['sample_fit']['residual'][i]}",
                            f"{rr['delta'][i]}", f"{rr['smooth_bg'][i]}", f"{rr['peaks'][i]}"]
                out.write("\t".join(row) + "\n")
        return out.getvalue()

    def on_export(self):
        """Ask for a destination and save the last analysis as a TSV file."""
        if not self.last_results or not self.last_data:
            return
        initial = "control_residuals.tsv" if self.last_mode == "control" else "paired_analysis.tsv"
        path = filedialog.asksaveasfilename(title="Save TSV", defaultextension=".tsv", initialfile=initial,
                                            filetypes=[("TSV files", "*.tsv"), ("All files", "*.*")])
        if not path:
            return
        content = self._build_tsv()
        try:
            with open(path, "w", encoding="utf-8") as f:
                f.write(content)
        except OSError as e:
            messagebox.showerror(APP_TITLE, f"Could not save TSV:\n{e}")
            return
        self.status.set(f"TSV saved to: {path}")

    def on_report(self):
        """Ask for a destination and save the text report of the last analysis."""
        if not self.last_results:
            return
        paired = self.last_mode == "paired"
        text = make_paired_report(self.last_results) if paired else make_control_report(self.last_results)
        initial = "paired_report.txt" if paired else "control_report.txt"
        path = filedialog.asksaveasfilename(title="Save TXT report", defaultextension=".txt", initialfile=initial,
                                            filetypes=[("Text files", "*.txt"), ("All files", "*.*")])
        if not path:
            return
        try:
            with open(path, "w", encoding="utf-8") as f:
                f.write(text)
        except OSError as e:
            messagebox.showerror(APP_TITLE, f"Could not save report:\n{e}")
            return
        self.status.set(f"Report saved to: {path}")

def main():
    """Create the application window and run the Tk event loop until it is closed."""
    App().mainloop()

if __name__ == "__main__":
    main()
