Advanced Interactive Contrast Segmentation (3 × 3 Grid)

A 3 × 3 grid of synthetic images, each segmented independently by seeded flood fill. Pass 8 or 9 images as images_flat; the grid always has 3 columns and enough rows to hold them all (3 rows in both cases — the last cell is left blank when only 8 images are supplied).

Interaction

Action

Effect

Left-click

Add a positive seed (green dot) on the clicked panel.

Shift + Left-click

Add a negative seed (red dot) — the connected region grown from it is subtracted from the mask.

Ctrl + Left-click

Add a vertex to the clip polygon of the clicked panel. Once the polygon has at least 3 vertices, the mask is restricted to pixels inside it.

Drag polygon vertex

Reposition any clip-polygon vertex; mask updates on mouse-up.

Hover + Delete / Backspace

Remove the clip-polygon vertex nearest to the cursor if one lies within 15 px; otherwise remove the nearest seed within 15 px.

+ / =

Increase tolerance (grow regions).

-

Decrease tolerance (shrink regions).

c

Clear all seeds (keeps clip polygon).

p

Clear the clip polygon.

After interaction, the resulting boolean mask arrays are in masks_flat (same order as images_flat).

Note

Click on a panel first to give it keyboard focus, then use the key bindings.

import math
import numpy as np
import anyplotlib as apl

# ── Helpers ───────────────────────────────────────────────────────────────────

N     = 192   # image size (pixels per side) for the synthetic demo images
NCOLS = 3     # fixed column count

# Shared pixel-coordinate grids for the demo images, both of shape (N, N):
# xx[r, c] == c (column / x index) and yy[r, c] == r (row / y index).
# Each synthetic image seeds its own generator, so no shared RNG is needed.
xx, yy = np.meshgrid(np.arange(N), np.arange(N))


def _gauss(cx, cy, sigma, amplitude):
    """Evaluate an isotropic 2-D Gaussian on the module-level (xx, yy) grid.

    ``(cx, cy)`` is the centre in (column, row) pixel units, ``sigma`` the
    standard deviation in pixels and ``amplitude`` the peak value.
    """
    dist_sq = (xx - cx) ** 2 + (yy - cy) ** 2
    return amplitude * np.exp(-dist_sq / (2 * sigma ** 2))


def _make_image(seed):
    """Return a reproducible N×N test image made of five noisy Gaussian blobs.

    ``seed`` drives a private generator, so each seed gives a distinct but
    repeatable image.  The result is min–max normalised so its minimum is 0
    and its maximum is 1.
    """
    gen = np.random.default_rng(seed)

    params = []
    for _ in range(5):
        # Draw order (x, y, sigma, amplitude) fixes the image for a given seed.
        cx = gen.integers(30, N - 30)
        cy = gen.integers(30, N - 30)
        sigma = gen.integers(15, 35)
        amp = gen.uniform(0.5, 1.0)
        params.append((cx, cy, sigma, amp))

    img = 0
    for cx, cy, sigma, amp in params:
        img = img + _gauss(cx, cy, sigma, amp)
    img += 0.06 * gen.standard_normal((N, N))

    lo, hi = img.min(), img.max()
    return (img - lo) / (hi - lo)


# ── Images — swap this list for your own (8 or 9 arrays of shape (H, W)) ─────

images_flat = [_make_image(seed) for seed in range(1, 9)]    # 8 images
# images_flat = [_make_image(seed) for seed in range(1, 10)] # uncomment for 9

# ── Grid geometry derived from the image list ─────────────────────────────────

n_images = len(images_flat)
if n_images not in (8, 9):
    raise ValueError(f"images_flat must contain 8 or 9 images, got {n_images}")

# Segmentation indexes pixels as img[row, col], so every image must be 2-D.
# Fail here with a clear message instead of later inside an interactive
# callback, where the refresh helper only prints errors.
_non_2d = [(i, np.shape(im)) for i, im in enumerate(images_flat) if np.ndim(im) != 2]
if _non_2d:
    raise ValueError(
        f"images_flat entries must be 2-D (H, W) arrays; got (index, shape) {_non_2d}"
    )

NROWS = math.ceil(n_images / NCOLS)   # 3 for both 8 and 9

# ── BFS flood-fill ────────────────────────────────────────────────────────────

def _bfs_region(img, row, col, tol):
    H, W = img.shape
    seed_val = float(img[row, col])
    visited  = np.zeros((H, W), dtype=bool)
    visited[row, col] = True
    stack = [(row, col)]
    while stack:
        r, c = stack.pop()
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < H and 0 <= nc < W and not visited[nr, nc]:
                if abs(float(img[nr, nc]) - seed_val) <= tol:
                    visited[nr, nc] = True
                    stack.append((nr, nc))
    return visited


def _compute_mask(img, pos_seeds, neg_seeds, tol, clip_poly):
    """Flood-fill union, optionally restricted to a drawn polygon."""
    H, W = img.shape
    if not pos_seeds:
        return np.zeros((H, W), dtype=bool)
    combined = np.zeros((H, W), dtype=bool)
    for r, c in pos_seeds:
        combined |= _bfs_region(img, r, c, tol)
    for r, c in neg_seeds:
        combined &= ~_bfs_region(img, r, c, tol)

    if clip_poly and len(clip_poly) >= 3:
        # Pure-numpy even-odd ray-casting point-in-polygon
        # Polygon vertices are [x, y] = [col, row] in image-pixel space
        poly   = np.asarray(clip_poly, dtype=float)   # (K, 2) as [x, y]
        rows   = np.arange(H, dtype=float)
        cols   = np.arange(W, dtype=float)
        gc, gr = np.meshgrid(cols, rows)  # gc[r,c]=col, gr[r,c]=row
        xs     = gc.ravel()              # x = col index
        ys     = gr.ravel()              # y = row index
        inside = np.zeros(H * W, dtype=bool)
        n_v    = len(poly)
        xp, yp = poly[:, 0], poly[:, 1]
        for i in range(n_v):
            x1, y1 = xp[i], yp[i]
            x2, y2 = xp[(i + 1) % n_v], yp[(i + 1) % n_v]
            cond = ((y1 > ys) != (y2 > ys)) & (
                xs < (x2 - x1) * (ys - y1) / (y2 - y1 + 1e-12) + x1
            )
            inside ^= cond
        combined &= inside.reshape(H, W)

    return combined


# ── Per-panel state (flat) ────────────────────────────────────────────────────

TOL_STEP    = 0.01    # tolerance change per +/- key press
TOL_MIN     = 0.005   # lower clamp for the tolerance
TOL_MAX     = 0.40    # upper clamp for the tolerance
SEED_RADIUS = 4       # radius passed to add_circles for seed markers
# Off-screen placeholders: marker groups and the polygon widget are parked at
# these coordinates so they stay out of view until real data exists.
_HIDDEN     = [[-9999.0, -9999.0]]
_OFFSCREEN_TRI = [[-9990.0, -9990.0], [-9989.0, -9990.0], [-9989.0, -9989.0]]

# Colormap per panel, cycled by panel index.
_CMAPS = ["gray", "viridis", "plasma", "inferno", "magma",
          "cividis", "hot", "cool", "bone"]

# Mutable segmentation inputs per panel.  Seeds are (row, col) tuples; clip
# polygon vertices are [x, y] = [col, row] lists.
panel_state = [
    {"pos_seeds": [], "neg_seeds": [], "tolerance": 0.08, "clip_poly": []}
    for _ in range(n_images)
]
# Current boolean mask per panel, sized from each image (not the demo size N)
# so user-supplied images of any (H, W) start with a correctly shaped mask.
masks_flat  = [np.zeros(np.shape(img), dtype=bool) for img in images_flat]
# Index of the most recently interacted-with panel; a one-element list so the
# per-panel closures can update it in place.
active_idx  = [0]

# ── Figure ────────────────────────────────────────────────────────────────────

fig, axes = apl.subplots(
    NROWS,
    NCOLS,
    figsize=(900, 900),
    help=(
        "Left-click              →  positive seed (grow)\n"
        "Shift + Left-click      →  negative seed (shrink)\n"
        "Ctrl  + Left-click      →  add clip-polygon vertex\n"
        "Drag polygon vertex     →  reposition (mask updates on release)\n"
        "Delete / Backspace      →  remove nearest vertex or seed\n"
        "+  /  -                 →  tolerance up / down\n"
        "c                       →  clear seeds\n"
        "p                       →  clear clip polygon"
    ),
)

# Row-major flattening of the axes grid: axes_flat[i] pairs with images_flat[i].
axes_flat = [axes[i // NCOLS][i % NCOLS] for i in range(NROWS * NCOLS)]

plots_flat = []   # image plot per panel that has an image
clip_wids  = []   # matching clip-polygon widget per panel

# Only panels that actually hold an image get a plot; a trailing blank cell
# (8-image case) is left untouched.
for idx, img in enumerate(images_flat):
    p = axes_flat[idx].imshow(img)
    p.set_colormap(_CMAPS[idx % len(_CMAPS)])

    # Positive (green) then negative (red) seed markers, parked off-screen.
    for marker_name, face in (("pos", "#69f0ae"), ("neg", "#ff5252")):
        p.add_circles(_HIDDEN, name=marker_name,
                      facecolors=face, edgecolors="#ffffff",
                      radius=SEED_RADIUS)

    # Dots previewing clip vertices while the polygon has fewer than 3.
    p.add_circles(_HIDDEN, name="clip_pts",
                  facecolors="#ffeb3b", edgecolors="#ffffff",
                  radius=3)

    # Draggable polygon widget, parked on an off-screen dummy triangle until
    # at least 3 real vertices exist.
    wid = p.add_widget("polygon", color="#ffeb3b", vertices=_OFFSCREEN_TRI)
    clip_wids.append(wid)
    plots_flat.append(p)


# ── Refresh helper ────────────────────────────────────────────────────────────

def _refresh(idx):
    """Re-segment panel ``idx`` and redraw its seeds, clip polygon and mask.

    The recomputed mask is stored in ``masks_flat[idx]``.  Any exception is
    reported (message plus traceback) rather than propagated, so a failure in
    one panel does not escape the interactive event callbacks.
    """
    try:
        state = panel_state[idx]
        plot = plots_flat[idx]

        mask = _compute_mask(
            images_flat[idx],
            state["pos_seeds"],
            state["neg_seeds"],
            state["tolerance"],
            state["clip_poly"],
        )
        masks_flat[idx] = mask

        circles = plot.markers["circles"]

        # Seeds are kept as (row, col); marker offsets are (x, y) = (col, row).
        for group, key in (("pos", "pos_seeds"), ("neg", "neg_seeds")):
            offsets = [(col, row) for row, col in state[key]]
            circles[group].set(offsets=offsets if offsets else _HIDDEN)

        clip = state["clip_poly"]
        widget = clip_wids[idx]
        if len(clip) < 3:
            # Polygon still being built: park the widget off-screen and show
            # each placed vertex as a dot instead.
            widget.set(vertices=_OFFSCREEN_TRI)
            dots = [[v[0], v[1]] for v in clip]
            circles["clip_pts"].set(offsets=dots if dots else _HIDDEN)
        else:
            # Complete polygon: the widget's own handles replace the dots.
            widget.set(vertices=clip)
            circles["clip_pts"].set(offsets=_HIDDEN)

        plot.set_overlay_mask(mask, color="#00e5ff", alpha=0.38)

    except Exception as exc:
        import traceback
        print(f"[panel {idx}] _refresh error: {exc}")
        traceback.print_exc()


# ── Click & key handlers (one closure per panel) ──────────────────────────────

def _make_handlers(idx):
    """Register the pointer and key handlers for panel ``idx``.

    Every handler closes over ``idx``.  It edits only ``panel_state[idx]``,
    records ``idx`` in ``active_idx`` and then calls ``_refresh(idx)``.

    Returns
    -------
    tuple
        ``(poly_dragged, on_click, on_key)`` as returned by the
        ``add_event_handler`` decorators, so callers can keep references.
    """
    p   = plots_flat[idx]
    wid = clip_wids[idx]
    H, W = images_flat[idx].shape
    HIT2 = 15 ** 2   # Delete/Backspace hit radius squared (image pixels)

    # ── Polygon widget: sync vertices → panel_state after any drag ────────────
    @wid.add_event_handler("pointer_up")
    def _poly_dragged(event):
        active_idx[0] = idx
        vs = wid.vertices    # widget data is synced from JS before callbacks
        if vs is None:
            return
        # Filter out the off-screen placeholder vertices (_OFFSCREEN_TRI).
        real = [[float(v[0]), float(v[1])] for v in vs
                if abs(float(v[0])) < 9000 and abs(float(v[1])) < 9000]
        if len(real) < 3:
            # The widget is parked off-screen because the panel has fewer
            # than 3 clip vertices; it carries nothing to sync.  Writing
            # `real` here would wipe a partially built polygon.
            return
        panel_state[idx]["clip_poly"] = real
        _refresh(idx)

    # ── Click: add seed or polygon vertex ─────────────────────────────────────
    @p.add_event_handler("pointer_down")
    def _on_click(event):
        if event.xdata is None or event.ydata is None:
            return
        active_idx[0] = idx
        st   = panel_state[idx]
        # Snap to the nearest pixel and clamp to the image bounds.
        r_px = max(0, min(H - 1, int(round(float(event.ydata)))))
        c_px = max(0, min(W - 1, int(round(float(event.xdata)))))
        if "ctrl" in event.modifiers:
            st["clip_poly"].append([float(c_px), float(r_px)])   # [x, y]
        elif "shift" in event.modifiers:
            st["neg_seeds"].append((r_px, c_px))                 # (row, col)
        else:
            st["pos_seeds"].append((r_px, c_px))                 # (row, col)
        _refresh(idx)

    # ── Keys: tolerance, clear, delete-nearest ─────────────────────────────────
    @p.add_event_handler("key_down")
    def _on_key(event):
        active_idx[0] = idx
        st = panel_state[idx]
        if event.key in ("+", "="):
            st["tolerance"] = min(TOL_MAX, round(st["tolerance"] + TOL_STEP, 4))
            _refresh(idx)
        elif event.key == "-":
            st["tolerance"] = max(TOL_MIN, round(st["tolerance"] - TOL_STEP, 4))
            _refresh(idx)
        elif event.key == "c":
            st["pos_seeds"].clear()
            st["neg_seeds"].clear()
            _refresh(idx)
        elif event.key == "p":
            st["clip_poly"].clear()
            _refresh(idx)
        elif event.key in ("Delete", "Backspace"):
            _delete_nearest(event)

    def _delete_nearest(event):
        """Delete the clip vertex nearest the cursor, else the nearest seed.

        Only points within the hit radius count.  Clip vertices are checked
        first because they are drawn on top of the seeds.
        """
        st = panel_state[idx]
        x = getattr(event, "xdata", None)
        y = getattr(event, "ydata", None)
        if x is None or y is None:
            return
        cx, cy = float(x), float(y)

        def dist2(px, py):
            return (px - cx) ** 2 + (py - cy) ** 2

        # Clip-polygon vertices ([x, y]) take priority.
        poly = st["clip_poly"]
        if poly:
            i = min(range(len(poly)), key=lambda k: dist2(*poly[k]))
            if dist2(*poly[i]) <= HIT2:
                poly.pop(i)
                _refresh(idx)
                return

        # Otherwise the nearest seed of either polarity; seeds are (row, col).
        # Ties resolve to the first candidate, positive seeds before negative.
        candidates = [
            (dist2(c, r), lst, i)
            for lst in (st["pos_seeds"], st["neg_seeds"])
            for i, (r, c) in enumerate(lst)
        ]
        if candidates:
            best_dist, best_list, best_i = min(candidates, key=lambda t: t[0])
            if best_dist <= HIT2:
                best_list.pop(best_i)
                _refresh(idx)

    return _poly_dragged, _on_click, _on_key


# Wire up the event handlers for every panel that holds an image (a trailing
# blank grid cell gets none).
_handlers = list(map(_make_handlers, range(n_images)))

# Bare final expression so the figure is rendered as this example's output.
fig

Total running time of the script: (0 minutes 0.830 seconds)

Gallery generated by Sphinx-Gallery