# Source code for anyplotlib.figure._figure
"""
figure/_figure.py
=================
Top-level Figure widget (the single anywidget.AnyWidget subclass).
"""
from __future__ import annotations
import contextlib
import json
import pathlib
import time
import anywidget
import traitlets
from anyplotlib.axes import Axes, InsetAxes
from anyplotlib.axes._inset_axes import _plot_kind
from anyplotlib.figure._gridspec import SubplotSpec
from anyplotlib.callbacks import Event
from anyplotlib._repr_utils import repr_html_iframe
# Package root: this module is ``figure/_figure.py``, so two ``.parent``
# hops land on the directory that holds ``figure_esm.js``.
_HERE = pathlib.Path(__file__).parent.parent
# JS canvas renderer source, read once at import and used as ``Figure._esm``.
_ESM_SOURCE = _HERE.joinpath("figure_esm.js").read_text(encoding="utf-8")
class Figure(anywidget.AnyWidget):
    """Multi-panel interactive figure widget.

    The top-level container for all plots and the only ``anywidget.AnyWidget``
    subclass in anyplotlib. It owns all traitlets and acts as the Python ↔
    JavaScript bridge via the ``figure_esm.js`` canvas renderer.

    Create via :func:`subplots` (recommended) or directly::

        fig = Figure(2, 2, figsize=(800, 600))
        ax = fig.add_subplot((0, 0))
        v2d = ax.imshow(data)

    Parameters
    ----------
    nrows, ncols : int, optional
        Grid dimensions. Default 1 row, 1 column.
    figsize : (width, height), optional
        Figure size in CSS pixels. Default ``(640, 480)``.
    width_ratios : sequence of float, optional
        Relative column widths. Length must equal *ncols*. Any iterable
        (list, tuple, NumPy array) is accepted; ``None`` or an empty
        sequence means equal widths.
    height_ratios : sequence of float, optional
        Relative row heights. Length must equal *nrows*. Same rules as
        *width_ratios*.
    sharex, sharey : bool or list, optional
        ``True`` links pan/zoom across all panels on the respective axis.
        A list is forwarded verbatim to the renderer as explicit share
        groups. Default False (independent pan/zoom per panel).
    display_stats : bool, optional
        When True the JS renderer shows a per-panel FPS / frame-time
        overlay. Default False.
    help : str, optional
        Figure-level help text shown in a **?** badge overlay. It is
        suppressed when ``anyplotlib.show_help`` is False. Default ``""``
        (no badge).
    **kwargs
        Forwarded to ``anywidget.AnyWidget``.

    See Also
    --------
    subplots : Recommended factory for creating Figure and Axes grid.
    """

    layout_json = traitlets.Unicode("{}").tag(sync=True)
    fig_width = traitlets.Int(640).tag(sync=True)
    fig_height = traitlets.Int(480).tag(sync=True)
    # Bidirectional event bus. JS writes interaction events here for Python
    # to dispatch. Python writes targeted widget updates tagged
    # ``source="python"``, which _dispatch_event ignores on the way back.
    event_json = traitlets.Unicode("{}").tag(sync=True)
    # When True the JS renderer shows a per-panel FPS / frame-time overlay.
    display_stats = traitlets.Bool(False).tag(sync=True)
    # Figure-level help text shown in a '?' badge overlay in JS.
    # Empty string means no badge. Gated by apl.show_help at the Python level.
    help_text = traitlets.Unicode("").tag(sync=True)

    _esm = _ESM_SOURCE

    # Static CSS injected by anywidget alongside _esm.
    # .apl-scale-wrap — outer container; width:100% means it always fills
    # the cell without any JS width updates.
    # .apl-outer — the figure root; will-change:transform pre-promotes
    # it to a GPU compositing layer so transform:scale()
    # changes cost zero layout/paint passes.
    _css = """\
.apl-scale-wrap {
display: block;
width: 100%;
overflow: visible;
position: relative;
line-height: 0;
}
.apl-outer {
display: inline-block;
position: relative;
user-select: none;
z-index: 1;
isolation: isolate;
will-change: transform;
transform-origin: top left;
vertical-align: top;
/* min-width: max-content prevents the inline-block from shrinking when
the parent container (scaleWrap, width:100%) narrows because the
Jupyter cell is narrower than the figure's native width. Without
this, outerDiv.offsetWidth collapses to cellW, causing _applyScale()
to compute s = cellW/cellW = 1.0 (no-op) instead of the correct
s = cellW/nativeW < 1. */
min-width: max-content;
}
"""

    def __init__(self, nrows=1, ncols=1, figsize=(640, 480),
                 width_ratios=None, height_ratios=None,
                 sharex=False, sharey=False,
                 display_stats=False, help="", **kwargs):
        super().__init__(**kwargs)
        self._nrows = nrows
        self._ncols = ncols
        self._width_ratios = self._ratios_or_default(width_ratios, ncols)
        self._height_ratios = self._ratios_or_default(height_ratios, nrows)
        self._sharex = sharex
        self._sharey = sharey
        # Panel id -> owning Axes / plot object / InsetAxes respectively.
        self._axes_map: dict = {}
        self._plots_map: dict = {}
        self._insets_map: dict = {}
        # None means "let the renderer choose" until subplots_adjust is called.
        self._hspace: float | None = None
        self._wspace: float | None = None
        # batch() state: while _batching is True, _push only records ids here.
        self._batching: bool = False
        self._batch_dirty: set = set()
        # Geometry-channel bookkeeping (per panel id): a monotonic revision
        # and the last geometry dict sent, so geometry is re-transmitted only
        # when its values genuinely change.
        self._geom_rev: dict = {}
        self._geom_last: dict = {}
        # Coalesce the initial trait writes so the resize observer runs once.
        with self.hold_trait_notifications():
            self.fig_width = figsize[0]
            self.fig_height = figsize[1]
            self.display_stats = display_stats
            self.help_text = self._resolve_help(help)
        self._push_layout()

    @staticmethod
    def _ratios_or_default(ratios, n: int) -> list:
        """Return *ratios* as a list, or ``[1] * n`` when None or empty.

        ``None`` is tested explicitly instead of relying on truthiness, so
        array-likes such as NumPy arrays (whose truth value is ambiguous and
        raises) are accepted.
        """
        values = [] if ratios is None else list(ratios)
        return values or [1] * n

    @staticmethod
    def _resolve_help(text: str) -> str:
        """Return *text* if ``apl.show_help`` is True (default), else ``""``."""
        try:
            import anyplotlib as _apl
            if not getattr(_apl, "show_help", True):
                return ""
        except ImportError:
            pass
        return text or ""

    def set_help(self, text: str) -> None:
        """Set (or clear) the figure-level help text shown in the **?** badge.

        Parameters
        ----------
        text : str
            Help string displayed when the user clicks the **?** badge.
            Pass an empty string (``""``) to remove the badge entirely.
            Newlines (``\\n``) are respected in the card. Ignored (badge
            hidden) when ``anyplotlib.show_help`` is False.

        Examples
        --------
        >>> fig.set_help("Drag peak: move μ/A\\nPress f: least-squares fit")
        >>> fig.set_help("")  # hide the badge
        """
        self.help_text = self._resolve_help(text)

    def subplots_adjust(self, hspace: float | None = None,
                        wspace: float | None = None) -> None:
        """Set the spacing between subplot panels.

        Only the arguments that are explicitly provided are updated; omitting
        an argument leaves the current value unchanged.

        Parameters
        ----------
        hspace : float, optional
            Fraction of the average row height to use as vertical gap between
            panels. ``0.1`` adds a gap of 10 % of the mean row height.
            ``None`` (default) leaves the current hspace unchanged.
        wspace : float, optional
            Fraction of the average column width to use as horizontal gap.
            ``None`` (default) leaves the current wspace unchanged.
        """
        if hspace is not None:
            self._hspace = float(hspace)
        if wspace is not None:
            self._wspace = float(wspace)
        self._push_layout()

    # ── subplot creation ──────────────────────────────────────────────────────

    def add_subplot(self, spec) -> Axes:
        """Add a subplot cell and return its :class:`Axes`.

        Parameters
        ----------
        spec : SubplotSpec or int or tuple of (row, col)
            Which grid cell(s) to occupy. A :class:`SubplotSpec` is used
            directly (e.g. from ``GridSpec[r, c]``). An :class:`int` is a
            zero-based, row-major cell index converted via
            ``divmod(spec, ncols)``: ``0`` is the top-left cell and ``ncols``
            is the first cell of the second row. Note this differs from
            matplotlib's 1-based ``add_subplot`` index. A ``(row, col)``
            tuple selects a single cell.

        Returns
        -------
        Axes
            The subplot axes object. Call plotting methods like ``.imshow()``,
            ``.plot()``, ``.bar()`` to attach data.

        Raises
        ------
        TypeError
            If *spec* is not a SubplotSpec, int, or tuple.

        Examples
        --------
        >>> fig = Figure(2, 2)
        >>> ax1 = fig.add_subplot(0)       # top-left (via index)
        >>> ax2 = fig.add_subplot((0, 1))  # top-right (via tuple)
        """
        if isinstance(spec, SubplotSpec):
            # Auto-sync Figure grid to the parent GridSpec when the GridSpec is
            # larger than the Figure's current dimensions. This allows the
            # common workflow:
            #     gs = GridSpec(2, 2, height_ratios=[3, 1])
            #     fig = Figure(figsize=(...))  # defaults to nrows=1, ncols=1
            #     fig.add_subplot(gs[0, :])    # Figure adopts 2×2 from GridSpec
            # without requiring the user to repeat nrows/ncols/ratios on Figure.
            gs = spec._gs
            if gs is not None:
                if gs.nrows > self._nrows:
                    self._nrows = gs.nrows
                    self._height_ratios = list(gs.height_ratios)
                if gs.ncols > self._ncols:
                    self._ncols = gs.ncols
                    self._width_ratios = list(gs.width_ratios)
        elif isinstance(spec, int):
            row, col = divmod(spec, self._ncols)
            spec = SubplotSpec(None, row, row + 1, col, col + 1)
        elif isinstance(spec, tuple):
            row, col = spec
            spec = SubplotSpec(None, row, row + 1, col, col + 1)
        else:
            raise TypeError(f"add_subplot: unsupported spec type {type(spec)!r}")
        return Axes(self, spec)

    # ── internal registration (called by Axes._attach) ────────────────────────

    def _register_panel(self, ax: Axes, plot) -> None:
        """Register a grid panel's *plot*, allocating its synced trait(s)."""
        pid = plot._id
        if not self.has_trait(f"panel_{pid}_json"):
            self.add_traits(**{f"panel_{pid}_json": traitlets.Unicode("{}").tag(sync=True)})
        # Plots that declare _GEOM_KEYS get a second trait carrying only the
        # heavy geometry, re-sent only when that geometry changes. The light
        # view trait then references it by revision so JS reuses the cached
        # decode across view-only updates (highlight, camera, planes).
        if getattr(plot, "_GEOM_KEYS", None) and not self.has_trait(f"panel_{pid}_geom"):
            self.add_traits(**{f"panel_{pid}_geom": traitlets.Unicode("{}").tag(sync=True)})
            self._geom_rev[pid] = 0
            self._geom_last[pid] = None
        self._plots_map[pid] = plot
        self._axes_map[pid] = ax
        self._push(pid)
        self._push_layout()

    def _push(self, panel_id: str) -> None:
        """Serialise one panel and write to its trait.

        Inside a :meth:`batch` block, pushes are coalesced: each panel is
        recorded as dirty and serialised + sent exactly once when the block
        exits, no matter how many mutations touched it. This collapses the
        many per-frame pushes of a linked-view update (set_data + set_title +
        widget moves on the same panel) into one serialise/transfer per panel
        — the dominant cost over a Pyodide comm boundary.

        Unknown panel ids and panels without a ``panel_<id>_json`` trait are
        silently ignored.
        """
        plot = self._plots_map.get(panel_id)
        if plot is None:
            return
        if self._batching:
            self._batch_dirty.add(panel_id)
            return
        tname = f"panel_{panel_id}_json"
        if not self.has_trait(tname):
            return
        state = plot.to_state_dict()
        geom_keys = getattr(plot, "_GEOM_KEYS", None)
        gname = f"panel_{panel_id}_geom"
        if geom_keys and self.has_trait(gname):
            # Split heavy geometry into its own channel. Detect change by
            # comparing the geom values themselves (the b64 strings / LUT
            # lists) against the last-sent snapshot — an equality check that
            # avoids re-serialising hundreds of KB on every view-only frame.
            # Only on a real change do we serialise the geom blob, bump the
            # revision, and write the geom trait.
            # The split builds new dicts rather than popping from ``state``,
            # so the mapping returned by to_state_dict() is never mutated
            # (it may be the plot's live state).
            geom = {k: state[k] for k in geom_keys if k in state}
            if geom != self._geom_last.get(panel_id):
                self._geom_last[panel_id] = geom
                self._geom_rev[panel_id] = self._geom_rev.get(panel_id, 0) + 1
                setattr(self, gname, json.dumps(geom, sort_keys=True))
            state = {k: v for k, v in state.items() if k not in geom}
            state["_geom_rev"] = self._geom_rev.get(panel_id, 0)
        setattr(self, tname, json.dumps(state))

    @contextlib.contextmanager
    def batch(self):
        """Coalesce all panel pushes within the block into one push per panel.

        Use around multi-panel updates (e.g. a linked-view crosshair handler)
        so a single mouse event produces one serialise + transfer per panel
        instead of one per mutation — a large win under Pyodide / remote
        kernels where every push crosses a comm boundary. Nested ``batch()``
        blocks are transparent: only the outermost one flushes. Dirty panels
        are flushed even if the block raises.

        ::

            with fig.batch():
                v_xz.set_data(slice_xz)
                v_yz.set_data(slice_yz)
                cross.set(cx=..., cy=...)
        """
        if self._batching:  # already batching — nest transparently
            yield
            return
        self._batching = True
        self._batch_dirty = set()
        try:
            yield
        finally:
            self._batching = False
            dirty, self._batch_dirty = self._batch_dirty, set()
            # One push per dirty panel — no matter how many mutations touched
            # it during the block. hold_trait_notifications coalesces the
            # underlying comm traffic into a single sync.
            with self.hold_trait_notifications():
                for pid in dirty:
                    self._push(pid)

    # ── layout ────────────────────────────────────────────────────────────────

    def _compute_cell_sizes(self) -> dict:
        """Return ``{panel_id: (width_px, height_px)}`` for every grid panel.

        Sizes are floored at 64 px in each dimension.
        """
        fw, fh = self.fig_width, self.fig_height
        wr, hr = self._width_ratios, self._height_ratios
        wsum, hsum = sum(wr), sum(hr)
        # Grid tracks are pure ratio math — no aspect-locking.
        # Rule: col_px[i] = fw * width_ratios[i] / Σ width_ratios (and analogous
        # for rows). Every panel gets exactly the canvas size its cell specifies;
        # images are rendered "contain" (letterboxed) in JS if needed.
        col_px = [fw * w / wsum for w in wr]
        row_px = [fh * h / hsum for h in hr]
        sizes: dict = {}
        for pid, ax in self._axes_map.items():
            s = ax._spec
            # Spanning cells sum their tracks before rounding.
            pw = int(round(sum(col_px[s.col_start:s.col_stop])))
            ph = int(round(sum(row_px[s.row_start:s.row_stop])))
            sizes[pid] = (max(64, pw), max(64, ph))
        return sizes

    def _push_layout(self) -> None:
        """Serialise grid, panel, share-group and inset specs to ``layout_json``."""
        cell_sizes = self._compute_cell_sizes()
        all_ids = list(self._axes_map.keys())
        share_groups: dict = {}

        def _mg(flag, key):
            # True -> one group containing every grid panel (only meaningful
            # with 2+ panels); a list is passed through as explicit groups.
            if flag is True and len(all_ids) > 1:
                share_groups[key] = [list(all_ids)]
            elif isinstance(flag, list):
                share_groups[key] = flag

        _mg(self._sharex, "x")
        _mg(self._sharey, "y")
        panel_specs = []
        for pid, ax in self._axes_map.items():
            s = ax._spec
            pw, ph = cell_sizes.get(pid, (200, 200))
            plot = self._plots_map.get(pid)
            panel_specs.append({
                "id": pid,
                "kind": _plot_kind(plot) if plot else "1d",
                "row_start": s.row_start,
                "row_stop": s.row_stop,
                "col_start": s.col_start,
                "col_stop": s.col_stop,
                "panel_width": pw,
                "panel_height": ph,
            })
        inset_specs = []
        for pid, inset_ax in self._insets_map.items():
            plot = self._plots_map.get(pid)
            # Inset pixel size is a fraction of the whole figure, floored at 64.
            pw = max(64, round(self.fig_width * inset_ax.w_frac))
            ph = max(64, round(self.fig_height * inset_ax.h_frac))
            inset_specs.append({
                "id": pid,
                "kind": _plot_kind(plot) if plot else "1d",
                "w_frac": inset_ax.w_frac,
                "h_frac": inset_ax.h_frac,
                "corner": inset_ax.corner,
                "title": inset_ax.title,
                "panel_width": pw,
                "panel_height": ph,
                "inset_state": inset_ax._inset_state,
            })
        self.layout_json = json.dumps({
            "nrows": self._nrows,
            "ncols": self._ncols,
            "width_ratios": self._width_ratios,
            "height_ratios": self._height_ratios,
            "fig_width": self.fig_width,
            "fig_height": self.fig_height,
            "panel_specs": panel_specs,
            "share_groups": share_groups,
            "inset_specs": inset_specs,
            "hspace": self._hspace,
            "wspace": self._wspace,
        })

    # ── inset creation ────────────────────────────────────────────────────────

    def add_inset(self, w_frac: float, h_frac: float, *,
                  corner: str = "top-right", title: str = "") -> "InsetAxes":
        """Create and return a floating inset axes.

        The inset overlays the figure at the specified corner. Call
        plot-factory methods on the returned :class:`InsetAxes` to attach
        data::

            inset = fig.add_inset(0.3, 0.25, corner="top-right", title="Zoom")
            inset.imshow(data)     # returns Plot2D
            inset.plot(profile)    # returns Plot1D

        Parameters
        ----------
        w_frac, h_frac : float
            Width and height as fractions of the figure size (0–1).
        corner : str, optional
            Positioning corner: ``"top-right"`` (default), ``"top-left"``,
            ``"bottom-right"``, or ``"bottom-left"``.
        title : str, optional
            Text displayed in the inset title bar.

        Returns
        -------
        InsetAxes
        """
        return InsetAxes(self, w_frac, h_frac, corner=corner, title=title)

    def _register_inset(self, inset_ax: "InsetAxes", plot) -> None:
        """Register an inset plot, allocating its trait and updating layout.

        Unlike :meth:`_register_panel`, no ``panel_<id>_geom`` trait is
        allocated here, so :meth:`_push` sends inset state as a single blob.
        """
        pid = plot._id
        if not self.has_trait(f"panel_{pid}_json"):
            self.add_traits(**{f"panel_{pid}_json": traitlets.Unicode("{}").tag(sync=True)})
        self._plots_map[pid] = plot
        self._insets_map[pid] = inset_ax
        self._push(pid)
        self._push_layout()

    @traitlets.observe("fig_width", "fig_height")
    def _on_resize(self, change) -> None:
        """Recompute layout and re-push every panel after a figure resize."""
        self._push_layout()
        for pid in self._plots_map:
            self._push(pid)

    @traitlets.observe("event_json")
    def _on_event(self, change) -> None:
        """Dispatch a JS interaction event to the relevant plot and widget callbacks."""
        self._dispatch_event(change["new"])

    def _dispatch_event(self, raw: str) -> None:
        """Process a raw JSON event string from the JS side.

        Called by ``_on_event`` (traitlets observer) and also directly by the
        Pyodide bridge (``anywidget_bridge.js``) when forwarding user interaction
        events from the iframe back to Python callbacks.

        Parameters
        ----------
        raw : str
            JSON-encoded event message. Expected keys: ``event_type``,
            ``panel_id``, and optionally ``source``, ``widget_id``, plus
            any event-specific payload fields. Empty, malformed, non-object,
            Python-originated (``source == "python"``) and unknown-panel
            messages are ignored.
        """
        if not raw or raw == "{}":
            return
        try:
            msg = json.loads(raw)
        except Exception:
            # Deliberate best-effort at the observer boundary: a bad payload
            # from JS must not raise inside the traitlets notification.
            return
        if not isinstance(msg, dict):
            # Valid JSON but not an event object (e.g. a bare list or number);
            # the .get() lookups below would otherwise raise AttributeError.
            return
        if msg.get("source") == "python":
            # Echo of our own _push_widget write — not a user event.
            return
        panel_id = msg.get("panel_id", "")
        event_type = msg.get("event_type", "pointer_move")
        widget_id = msg.get("widget_id")
        # Inset state changes handled before regular plot dispatch
        if event_type == "inset_state_change":
            inset_ax = self._insets_map.get(panel_id)
            if inset_ax is not None:
                new_state = msg.get("new_state", "normal")
                if new_state in ("normal", "minimized", "maximized"):
                    inset_ax._inset_state = new_state
                    self._push_layout()
            return
        plot = self._plots_map.get(panel_id)
        if plot is None:
            return
        # GPU activation status echo (WebGPU path) — not a user event.
        if event_type == "gpu_status":
            if hasattr(plot, "_set_gpu_active"):
                plot._set_gpu_active(bool(msg.get("gpu_active", False)))
            return
        # When the event targets a known widget, let it update its own state
        # first and report it as the event source.
        source = None
        if widget_id and hasattr(plot, "_widgets"):
            widget = plot._widgets.get(widget_id)
            if widget is not None:
                widget._update_from_js(msg, event_type)
                source = widget
        if hasattr(plot, "callbacks"):
            event = Event(
                event_type=event_type,
                source=source,
                time_stamp=msg.get("time_stamp", time.perf_counter()),
                modifiers=msg.get("modifiers", []),
                x=msg.get("x"),
                y=msg.get("y"),
                button=msg.get("button"),
                buttons=msg.get("buttons", 0),
                xdata=msg.get("xdata"),
                ydata=msg.get("ydata"),
                ray=msg.get("ray"),
                line_id=msg.get("line_id"),
                dwell_ms=msg.get("dwell_ms"),
                bar_index=msg.get("bar_index"),
                value=msg.get("value"),
                x_label=msg.get("x_label"),
                group_index=msg.get("group_index"),
                dx=msg.get("dx"),
                dy=msg.get("dy"),
                key=msg.get("key"),
                last_widget_id=msg.get("last_widget_id"),
            )
            plot.callbacks.fire(event)

    def _push_widget(self, panel_id: str, widget_id: str, fields: dict) -> None:
        """Send a targeted widget-position update to JS (no image data).

        The payload is tagged ``source="python"`` so that, if it is echoed
        back through ``event_json``, :meth:`_dispatch_event` ignores it.
        """
        payload = {"source": "python", "panel_id": panel_id,
                   "widget_id": widget_id}
        payload.update(fields)
        self.event_json = json.dumps(payload)

    def _push_panel_fields(self, panel_id: str, fields: dict) -> None:
        """Apply a small set of changed *fields* to a panel, then push once.

        The fields are merged into the panel's ``_state`` and the panel is
        pushed via the normal trait channel (authoritative for Jupyter,
        snapshots, and the Pyodide bridge alike). Inside a :meth:`batch`
        block the push is coalesced, so many such field updates across many
        panels collapse to one serialise + transfer per panel per frame —
        the dominant per-frame cost over a comm boundary.
        """
        plot = self._plots_map.get(panel_id)
        if plot is not None:
            plot._state.update(fields)
            self._push(panel_id)

    # ── helpers ───────────────────────────────────────────────────────────────

    def get_axes(self) -> list:
        """Return a list of all Axes, sorted by grid position.

        Returns
        -------
        list of Axes
            Axes sorted by (row_start, col_start) to match typical left-to-right,
            top-to-bottom iteration order. Inset axes are not included.
        """
        return sorted(self._axes_map.values(),
                      key=lambda a: (a._spec.row_start, a._spec.col_start))

    def _repr_html_(self) -> str:
        """Return a self-contained iframe embedding the live widget.

        Used by Sphinx Gallery (via :class:`~docs._sg_html_scraper.ViewerScraper`)
        and by any HTML-capable notebook frontend that falls back to
        ``_repr_html_`` instead of the full ipywidgets protocol.

        Returns
        -------
        str
            HTML string containing an embedded iframe with srcdoc attribute.
        """
        return repr_html_iframe(self)

    def to_html(self, *, resizable: bool = True) -> str:
        """Return a self-contained HTML page rendering this figure.

        The page inlines the JS renderer and all data — no Jupyter kernel or
        network needed at view time. Load it in any browser context, e.g.
        an Electron ``BrowserWindow`` or ``<webview>``. See
        :mod:`anyplotlib.embed` for the full embedding guide (including live
        Python sync via ``FigureBridge``).
        """
        from anyplotlib.embed import to_html
        return to_html(self, resizable=resizable)

    def save_html(self, path, *, resizable: bool = True):
        """Write :meth:`to_html` output to *path*; returns the ``Path``."""
        from anyplotlib.embed import save_html
        return save_html(self, path, resizable=resizable)

    def close(self) -> None:
        """Close the figure.

        Fires a ``"close"`` event on every panel's :attr:`callbacks`, then
        hides the widget by setting its CSS ``display`` to ``"none"``.
        Subsequent calls are no-ops.
        """
        if getattr(self, "_closed", False):
            return
        self._closed = True
        close_event = Event(event_type="close")
        for plot in self._plots_map.values():
            if hasattr(plot, "callbacks"):
                plot.callbacks.fire(close_event)
        try:
            self.layout.display = "none"
        except Exception:
            # Best-effort: hiding is cosmetic; callbacks have already fired.
            pass

    def __repr__(self) -> str:
        return (f"Figure({self._nrows}x{self._ncols}, "
                f"panels={len(self._plots_map)}, "
                f"size={self.fig_width}x{self.fig_height})")