"""
2D Image with Histogram
=======================

Display a 2-D image with physical axes, a colourmap, and an interactive
histogram below — all wired together with draggable threshold widgets.

Layout
------
A :class:`~anyplotlib.GridSpec` with two rows puts the image
on top and a bar-chart histogram below.  Two
:class:`~anyplotlib.widgets.VLineWidget` handles on the histogram mark the
``display_min`` / ``display_max`` thresholds; dragging them updates the
image colour scale in real time.

Key bindings on the image panel: **R** reset view · **C** toggle colorbar ·
**L** / **S** cycle colour-scale modes.

New ``imshow`` parameters
-------------------------
``cmap``
    Colormap name passed directly to :meth:`~anyplotlib.Axes.imshow`
    (e.g. ``"viridis"``, ``"inferno"``).  Defaults to ``"gray"``.
``vmin`` / ``vmax``
    Colormap clipping limits in data units.  Values outside the range are
    clamped to the colormap endpoints.  Defaults to the data min/max.
``origin``
    ``"upper"`` (default) places row 0 at the top (image convention).
    ``"lower"`` places row 0 at the bottom (scientific / matrix convention)
    and automatically reverses the y-axis so tick values increase upward.
"""
import numpy as np
import anyplotlib as apl


rng = np.random.default_rng(1)

# ── Synthetic diffraction pattern ─────────────────────────────────────────────
N = 256
x = np.linspace(-5, 5, N)   # physical axis in nm
y = np.linspace(-5, 5, N)
XX, YY = np.meshgrid(x, y)
R = np.sqrt(XX ** 2 + YY ** 2)


def _ring(r, r0, width, amp):
    return amp * np.exp(-0.5 * ((r - r0) / width) ** 2)


image = (
    _ring(R, 0.0, 0.30, 1.00)    # central spot
    + _ring(R, 2.1, 0.15, 0.55)  # first-order ring
    + _ring(R, 4.2, 0.15, 0.25)  # second-order ring
    + rng.normal(scale=0.04, size=(N, N))
)

# ── Layout: image (top, 3×) + histogram bar chart (bottom, 1×) ────────────────
gs = apl.GridSpec(2, 1, height_ratios=[3, 1])
fig = apl.Figure(figsize=(500, 640))
ax_img  = fig.add_subplot(gs[0, 0])
ax_hist = fig.add_subplot(gs[1, 0])

# ── Image panel — cmap, vmin, vmax supplied directly to imshow ────────────────
vmin_init = float(image.min())
vmax_init = float(image.max())

# Pass cmap, vmin, and vmax directly — no separate set_colormap / set_clim call
# needed for the initial display.
v = ax_img.imshow(image, axes=[x, y], units="nm",
                  cmap="inferno", vmin=vmin_init, vmax=vmax_init)

# First-order spot markers in the same physical coordinates used by imshow
spot_nm = np.array([[ 2.1,  0.0], [-2.1,  0.0],
                    [ 0.0,  2.1], [ 0.0, -2.1]])
v.add_circles(spot_nm, name="spots", radius=7,
              edgecolors="#00e5ff", facecolors="#00e5ff22",
              labels=["g1", "g1_bar", "g2", "g2_bar"])

# ── Histogram bar chart ────────────────────────────────────────────────────────
counts, edges = np.histogram(image.ravel(), bins=64)
bin_centers   = 0.5 * (edges[:-1] + edges[1:])

h = ax_hist.bar(counts, x_centers=bin_centers, orient="v",
                color="#4fc3f7", y_units="count")

# ── Draggable threshold handles on the histogram ──────────────────────────────
wlo = h.add_vline_widget(vmin_init, color="#ff6e40")   # low-threshold handle
whi = h.add_vline_widget(vmax_init, color="#ffffff")   # high-threshold handle


@wlo.add_event_handler("pointer_up")
def _apply_low(event):
    """Update image display_min when the low handle is released."""
    v.set_clim(vmin=event.source.x)


@whi.add_event_handler("pointer_up")
def _apply_high(event):
    """Update image display_max when the high handle is released."""
    v.set_clim(vmax=event.source.x)


fig # Interactive

# %%
# Adjust colour map and display range
# ------------------------------------
# :meth:`~anyplotlib.Plot2D.set_colormap` switches the palette;
# :meth:`~anyplotlib.Plot2D.set_clim` adjusts the display range.
# Both are equivalent to passing ``cmap`` / ``vmin`` / ``vmax`` at construction.

v.set_colormap("viridis")
v.set_clim(vmin=0.0, vmax=0.8)

fig

# %%
# origin='lower' — scientific / matrix convention
# ------------------------------------------------
# Passing ``origin='lower'`` places row 0 of the data at the *bottom* of the
# image, matching the matplotlib / scientific convention.  The y-axis is
# automatically reversed so tick values still increase upward.

mat = np.arange(64, dtype=float).reshape(8, 8)   # row 0 = small values

fig2, ax2 = apl.subplots()
v2 = ax2.imshow(mat, cmap="plasma", origin="lower")

fig2 # Interactive

