author     Sam Scholten  2025-10-27 18:24:01 +1000
committer  Sam Scholten  2025-10-27 18:24:01 +1000
commit     bce98fc796e32f4a439307dd3b65ef28dc6a73ad (patch)
tree       22c42b1786e219d35dd0ab559e3530f6e9676d84
parent     a873318fedb0ab6caf65cf42d7df3f7cf67a2325 (diff)
download   scopekit-bce98fc796e32f4a439307dd3b65ef28dc6a73ad.tar.gz
           scopekit-bce98fc796e32f4a439307dd3b65ef28dc6a73ad.zip
refactor: Replace Loguru with warnings and print statements
Co-authored-by: aider (openrouter/anthropic/claude-sonnet-4) <aider@aider.chat>
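
The pattern applied in the files shown below is mechanical: loguru's logger.warning() and
logger.exception() calls become stdlib warnings.warn() with an explicit category, while
logger.debug() calls are dropped (the commit title suggests some become print() in files
not shown in full here). A minimal, self-contained sketch of that pattern; check_limits is
a hypothetical example, not a function from this repository:

    import math
    import warnings

    def check_limits(xlim):
        """Hypothetical illustration of the migration pattern, not scopekit code."""
        # Before: logger.warning(f"Invalid display limits: {xlim}")
        # After:  the same message goes through the stdlib warnings machinery.
        if not all(math.isfinite(v) for v in xlim):
            warnings.warn(f"Invalid display limits: {xlim}", RuntimeWarning)
            return (0.0, 1.0)
        return xlim

    print(check_limits((0.0, float("nan"))))  # warns, then returns the default range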
-rw-r--r--   src/scopekit/coordinate_manager.py    207
-rw-r--r--   src/scopekit/data_manager.py           903
-rw-r--r--   src/scopekit/decimation.py            1273
-rw-r--r--   src/scopekit/display_state.py          580
-rw-r--r--   src/scopekit/plot.py                  3450
5 files changed, 3060 insertions(+), 3353 deletions(-)
diff --git a/src/scopekit/coordinate_manager.py b/src/scopekit/coordinate_manager.py
index 6ee1709..47d4e47 100644
--- a/src/scopekit/coordinate_manager.py
+++ b/src/scopekit/coordinate_manager.py
@@ -1,107 +1,100 @@
-from typing import Tuple
-
-import numpy as np
-from loguru import logger
-
-
-class CoordinateManager:
- """
- Handles coordinate transformations between raw time and display coordinates.
-
- Centralises all coordinate conversion logic to prevent inconsistencies.
- """
-
- def __init__(self, display_state):
- """
- Initialise the coordinate manager.
-
- Parameters
- ----------
- display_state : DisplayState
- Reference to the display state object.
- """
- self.state = display_state
-
- def get_current_view_raw(self, ax):
- """Get current view in raw coordinates."""
- try:
- xlim_display = ax.get_xlim()
- logger.debug(f"Converting display xlim {xlim_display} to raw coordinates")
-
- # Validate display limits
- if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]):
- logger.warning(f"Invalid display limits: {xlim_display}")
- # Try to get a valid view from the figure
- if hasattr(ax, "figure") and hasattr(ax.figure, "canvas"):
- ax.figure.canvas.draw()
- xlim_display = ax.get_xlim()
- if not np.isfinite(xlim_display[0]) or not np.isfinite(
- xlim_display[1]
- ):
- # Still invalid, use a default range
- logger.warning(
- "Still invalid after redraw, using default range"
- )
- xlim_display = (0, 1)
-
- raw_coords = self.xlim_display_to_raw(xlim_display)
- logger.debug(f"Converted to raw coordinates: {raw_coords}")
- return raw_coords
- except Exception as e:
- logger.exception(f"Error getting current view: {e}")
- # Return a safe default
- return (np.float32(0.0), np.float32(1.0))
-
- def set_view_raw(self, ax, xlim_raw):
- """Set view using raw coordinates."""
- xlim_display = self.xlim_raw_to_display(xlim_raw)
- ax.set_xlim(xlim_display)
-
- def raw_to_display(self, t_raw: np.ndarray) -> np.ndarray:
- """Convert raw time to display coordinates."""
- if self.state.offset_time_raw is not None:
- return (t_raw - self.state.offset_time_raw) * self.state.current_time_scale
- else:
- return t_raw * self.state.current_time_scale
-
- def display_to_raw(self, t_display: np.ndarray) -> np.ndarray:
- """Convert display coordinates to raw time."""
- t_raw = t_display / self.state.current_time_scale
- if self.state.offset_time_raw is not None:
- t_raw += self.state.offset_time_raw
-
- # Only log for scalar values to avoid excessive output
- if isinstance(t_display, (int, float, np.number)):
- logger.debug(
- f"Converting display time {t_display:.6f} to raw time {t_raw:.6f} (scale={self.state.current_time_scale}, offset={self.state.offset_time_raw})"
- )
- return t_raw
-
- def xlim_display_to_raw(
- self, xlim_display: Tuple[float, float]
- ) -> Tuple[np.float32, np.float32]:
- """Convert display xlim tuple to raw time coordinates."""
- try:
- # Ensure values are finite
- if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]):
- logger.warning(
- f"Non-finite display limits: {xlim_display}, using defaults"
- )
- return (np.float32(0.0), np.float32(1.0))
-
- return (
- self.display_to_raw(np.float32(xlim_display[0])),
- self.display_to_raw(np.float32(xlim_display[1])),
- )
- except Exception as e:
- logger.exception(f"Error converting display to raw coordinates: {e}")
- return (np.float32(0.0), np.float32(1.0))
-
- def xlim_raw_to_display(
- self, xlim_raw: Tuple[np.float32, np.float32]
- ) -> Tuple[np.float32, np.float32]:
- """Convert raw time xlim tuple to display coordinates."""
- return (
- self.raw_to_display(xlim_raw[0]),
- self.raw_to_display(xlim_raw[1]),
- )
+from typing import Tuple
+import warnings
+
+import numpy as np
+
+
+class CoordinateManager:
+ """
+ Handles coordinate transformations between raw time and display coordinates.
+
+ Centralises all coordinate conversion logic to prevent inconsistencies.
+ """
+
+ def __init__(self, display_state):
+ """
+ Initialise the coordinate manager.
+
+ Parameters
+ ----------
+ display_state : DisplayState
+ Reference to the display state object.
+ """
+ self.state = display_state
+
+ def get_current_view_raw(self, ax):
+ """Get current view in raw coordinates."""
+ try:
+ xlim_display = ax.get_xlim()
+
+ # Validate display limits
+ if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]):
+ warnings.warn(f"Invalid display limits: {xlim_display}", RuntimeWarning)
+ # Try to get a valid view from the figure
+ if hasattr(ax, "figure") and hasattr(ax.figure, "canvas"):
+ ax.figure.canvas.draw()
+ xlim_display = ax.get_xlim()
+ if not np.isfinite(xlim_display[0]) or not np.isfinite(
+ xlim_display[1]
+ ):
+ # Still invalid, use a default range
+ warnings.warn(
+ "Still invalid after redraw, using default range", RuntimeWarning
+ )
+ xlim_display = (0, 1)
+
+ raw_coords = self.xlim_display_to_raw(xlim_display)
+ return raw_coords
+ except Exception as e:
+ warnings.warn(f"Error getting current view: {e}", RuntimeWarning)
+ # Return a safe default
+ return (np.float32(0.0), np.float32(1.0))
+
+ def set_view_raw(self, ax, xlim_raw):
+ """Set view using raw coordinates."""
+ xlim_display = self.xlim_raw_to_display(xlim_raw)
+ ax.set_xlim(xlim_display)
+
+ def raw_to_display(self, t_raw: np.ndarray) -> np.ndarray:
+ """Convert raw time to display coordinates."""
+ if self.state.offset_time_raw is not None:
+ return (t_raw - self.state.offset_time_raw) * self.state.current_time_scale
+ else:
+ return t_raw * self.state.current_time_scale
+
+ def display_to_raw(self, t_display: np.ndarray) -> np.ndarray:
+ """Convert display coordinates to raw time."""
+ t_raw = t_display / self.state.current_time_scale
+ if self.state.offset_time_raw is not None:
+ t_raw += self.state.offset_time_raw
+
+ return t_raw
+
+ def xlim_display_to_raw(
+ self, xlim_display: Tuple[float, float]
+ ) -> Tuple[np.float32, np.float32]:
+ """Convert display xlim tuple to raw time coordinates."""
+ try:
+ # Ensure values are finite
+ if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]):
+ warnings.warn(
+ f"Non-finite display limits: {xlim_display}, using defaults", RuntimeWarning
+ )
+ return (np.float32(0.0), np.float32(1.0))
+
+ return (
+ self.display_to_raw(np.float32(xlim_display[0])),
+ self.display_to_raw(np.float32(xlim_display[1])),
+ )
+ except Exception as e:
+ warnings.warn(f"Error converting display to raw coordinates: {e}", RuntimeWarning)
+ return (np.float32(0.0), np.float32(1.0))
+
+ def xlim_raw_to_display(
+ self, xlim_raw: Tuple[np.float32, np.float32]
+ ) -> Tuple[np.float32, np.float32]:
+ """Convert raw time xlim tuple to display coordinates."""
+ return (
+ self.raw_to_display(xlim_raw[0]),
+ self.raw_to_display(xlim_raw[1]),
+ )
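
Because CoordinateManager now reports problems via RuntimeWarning rather than log records,
callers control verbosity with the standard warnings filters. A small sketch of the assumed
usage (the filter calls are stdlib; the commented-out method call stands in for a real
Axes/DisplayState setup):

    import warnings

    # Promote the new RuntimeWarnings to errors, e.g. in a test suite:
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        # coord_manager.get_current_view_raw(ax)  # would now raise on invalid limits

    # Or silence one specific message in batch plotting scripts:
    warnings.filterwarnings(
        "ignore", message=r"Invalid display limits.*", category=RuntimeWarning
    )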
diff --git a/src/scopekit/data_manager.py b/src/scopekit/data_manager.py
index c2dd09c..78a8ea3 100644
--- a/src/scopekit/data_manager.py
+++ b/src/scopekit/data_manager.py
@@ -1,452 +1,451 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-from loguru import logger
-
-
-class TimeSeriesDataManager:
- """
- Manages time series data storage and basic operations.
-
- Handles raw data storage, time scaling, and basic data access patterns.
- It can also store optional associated data like background estimates,
- global noise, and overlay lines.
-
- Supports multiple traces with shared time axis or individual time axes.
- """
-
- def __init__(
- self,
- t: Union[np.ndarray, List[np.ndarray]],
- x: Union[np.ndarray, List[np.ndarray]],
- name: Union[str, List[str]] = "Time Series",
- trace_colors: Optional[List[str]] = None,
- ):
- """
- Initialise the data manager.
-
- Parameters
- ----------
- t : Union[np.ndarray, List[np.ndarray]]
- Time array(s) (raw time in seconds). Can be a single array shared by all traces
- or a list of arrays, one per trace.
- x : Union[np.ndarray, List[np.ndarray]]
- Signal array(s). If t is a single array, x can be a 2D array (traces x samples)
- or a list of 1D arrays. If t is a list, x must be a list of equal length.
- name : Union[str, List[str]], default="Time Series"
- Name(s) for identification. Can be a single string or a list of strings.
- trace_colors : Optional[List[str]], default=None
- Colors for each trace. If None, default colors will be used.
-
- Raises
- ------
- ValueError
- If input arrays have mismatched lengths or time array is not monotonic.
- """
- # Convert inputs to standardized format: lists of arrays
- self.t_arrays, self.x_arrays, self.names, self.colors = (
- self._standardize_inputs(t, x, name, trace_colors)
- )
-
- # Validate all data
- for i, (t_arr, x_arr) in enumerate(zip(self.t_arrays, self.x_arrays)):
- self._validate_core_data(t_arr, x_arr, trace_idx=i)
-
- # Optional associated data (per trace)
- self._overlay_lines: List[List[Dict[str, Any]]] = [
- [] for _ in range(len(self.t_arrays))
- ]
-
- # For backward compatibility
- if len(self.t_arrays) > 0:
- self.t = self.t_arrays[0] # Primary time array
- self.x = self.x_arrays[0] # Primary signal array
- self.name = self.names[0] # Primary name
-
- def _standardize_inputs(
- self,
- t: Union[np.ndarray, List[np.ndarray]],
- x: Union[np.ndarray, List[np.ndarray]],
- name: Union[str, List[str]],
- trace_colors: Optional[List[str]],
- ) -> Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]]:
- """
- Standardize inputs to lists of arrays.
-
- Parameters
- ----------
- t : Union[np.ndarray, List[np.ndarray]]
- Time array(s).
- x : Union[np.ndarray, List[np.ndarray]]
- Signal array(s).
- name : Union[str, List[str]]
- Name(s) for identification.
- trace_colors : Optional[List[str]]
- Colors for each trace.
-
- Returns
- -------
- Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]]
- Standardized lists of time arrays, signal arrays, names, and colors.
- """
- # Default colors for traces
- default_colors = [
- "black",
- "blue",
- "red",
- "green",
- "purple",
- "orange",
- "brown",
- "pink",
- "gray",
- "olive",
- ]
-
- # Handle time arrays
- if isinstance(t, list):
- t_arrays = [np.asarray(t_arr, dtype=np.float32) for t_arr in t]
- n_traces = len(t_arrays)
- else:
- t_arr = np.asarray(t, dtype=np.float32)
-
- # Check if x is 2D array or list
- if isinstance(x, list):
- n_traces = len(x)
- t_arrays = [t_arr.copy() for _ in range(n_traces)]
- elif x.ndim == 2:
- n_traces = x.shape[0]
- t_arrays = [t_arr.copy() for _ in range(n_traces)]
- else:
- n_traces = 1
- t_arrays = [t_arr]
-
- # Handle signal arrays
- if isinstance(x, list):
- if len(x) != n_traces:
- raise ValueError(
- f"Number of signal arrays ({len(x)}) must match number of time arrays ({n_traces})"
- )
- x_arrays = [np.asarray(x_arr, dtype=np.float32) for x_arr in x]
- elif x.ndim == 2:
- if x.shape[0] != n_traces:
- raise ValueError(
- f"First dimension of 2D signal array ({x.shape[0]}) must match number of time arrays ({n_traces})"
- )
- x_arrays = [np.asarray(x[i], dtype=np.float32) for i in range(n_traces)]
- else:
- if n_traces != 1:
- raise ValueError(
- f"Single signal array provided but expected {n_traces} arrays"
- )
- x_arrays = [np.asarray(x, dtype=np.float32)]
-
- # Handle names
- if isinstance(name, list):
- if len(name) != n_traces:
- logger.warning(
- f"Number of names ({len(name)}) doesn't match number of traces ({n_traces}). Using defaults."
- )
- names = [f"Trace {i + 1}" for i in range(n_traces)]
- else:
- names = name
- else:
- if n_traces == 1:
- names = [name]
- else:
- if (
- name == "Time Series"
- ): # Only use default naming if the default name was used
- names = [f"Trace {i + 1}" for i in range(n_traces)]
- else:
- names = [f"{name} {i + 1}" for i in range(n_traces)]
-
- # Handle colors
- if trace_colors is not None:
- if len(trace_colors) < n_traces:
- logger.warning(
- f"Not enough colors provided ({len(trace_colors)}). Using defaults for remaining traces."
- )
- colors = trace_colors + [
- default_colors[i % len(default_colors)]
- for i in range(len(trace_colors), n_traces)
- ]
- else:
- colors = trace_colors[:n_traces]
- else:
- colors = [default_colors[i % len(default_colors)] for i in range(n_traces)]
-
- return t_arrays, x_arrays, names, colors
-
- def _validate_core_data(
- self, t: np.ndarray, x: np.ndarray, trace_idx: int = 0
- ) -> None:
- """
- Validate core input data arrays for consistency and correctness.
-
- Parameters
- ----------
- t : np.ndarray
- Time array.
- x : np.ndarray
- Signal array.
- trace_idx : int, default=0
- Index of the trace being validated (for error messages).
-
- Raises
- ------
- ValueError
- If arrays have mismatched lengths or time array is not monotonic.
- """
- if len(t) != len(x):
- raise ValueError(
- f"Time and signal arrays for trace {trace_idx} must have the same length. Got t={len(t)}, x={len(x)}"
- )
- if len(t) == 0:
- logger.warning(f"Initialising trace {trace_idx} with empty arrays.")
- return
-
- # Check time array is monotonic
- if len(t) > 1:
- # Use a small epsilon for floating-point comparison
- tolerance = 1e-9
- if not np.all(np.diff(t) > tolerance):
- problematic_diffs = np.diff(t)[np.diff(t) <= tolerance]
- logger.warning(
- f"Time array for trace {trace_idx} is not strictly monotonic increasing within tolerance {tolerance}. "
- f"Problematic diffs (first 10): {problematic_diffs[:10]}. "
- f"This may affect analysis results."
- )
-
- # Check for non-uniform sampling
- self._check_uniform_sampling(t, trace_idx)
-
- @property
- def overlay_lines(self) -> List[Dict[str, Any]]:
- """Get overlay lines data for the primary trace."""
- return self._overlay_lines[0] if self._overlay_lines else []
-
- def get_overlay_lines(self, trace_idx: int = 0) -> List[Dict[str, Any]]:
- """Get overlay lines data for a specific trace."""
- if trace_idx < 0 or trace_idx >= len(self.t_arrays):
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
- )
- return self._overlay_lines[trace_idx]
-
- @property
- def num_traces(self) -> int:
- """Get the number of traces."""
- return len(self.t_arrays)
-
- def get_trace_color(self, trace_idx: int = 0) -> str:
- """Get the color for a specific trace."""
- if trace_idx < 0 or trace_idx >= len(self.t_arrays):
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
- )
- return self.colors[trace_idx]
-
- def get_trace_name(self, trace_idx: int = 0) -> str:
- """Get the name for a specific trace."""
- if trace_idx < 0 or trace_idx >= len(self.t_arrays):
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
- )
- return self.names[trace_idx]
-
- def set_overlay_lines(
- self,
- overlay_lines: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
- trace_idx: Optional[int] = None,
- ) -> None:
- """
- Set overlay lines data.
-
- Parameters
- ----------
- overlay_lines : Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]
- List of dictionaries defining overlay lines, or list of lists for multiple traces.
- trace_idx : Optional[int], default=None
- If provided, set overlay lines only for the specified trace.
- If None, set for all traces if a nested list is provided, or for the first trace if a flat list.
- """
- if trace_idx is not None:
- # Set for specific trace
- if trace_idx < 0 or trace_idx >= len(self.t_arrays):
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
- )
-
- # Ensure we have a list of dictionaries
- if not isinstance(overlay_lines, list):
- raise ValueError(
- f"overlay_lines must be a list of dictionaries. Got {type(overlay_lines)}."
- )
-
- # Check if it's a list of dictionaries (not a nested list)
- if len(overlay_lines) > 0 and isinstance(overlay_lines[0], dict):
- self._overlay_lines[trace_idx] = overlay_lines
- else:
- raise ValueError(
- "Expected a list of dictionaries for overlay_lines when trace_idx is specified."
- )
- else:
- # Set for all traces or first trace
- if len(overlay_lines) > 0 and isinstance(overlay_lines[0], list):
- # Nested list provided - set for multiple traces
- if len(overlay_lines) != len(self.t_arrays):
- raise ValueError(
- f"Number of overlay line lists ({len(overlay_lines)}) must match number of traces ({len(self.t_arrays)})."
- )
-
- for i, lines in enumerate(overlay_lines):
- self._overlay_lines[i] = lines
- else:
- # Flat list provided - set for first trace
- self._overlay_lines[0] = overlay_lines
-
- def get_time_range(self, trace_idx: int = 0) -> Tuple[np.float32, np.float32]:
- """
- Get the full time range of the data.
-
- Parameters
- ----------
- trace_idx : int, default=0
- Index of the trace to get the time range for.
-
- Returns
- -------
- Tuple[np.float32, np.float32]
- Start and end time of the data.
- """
- if trace_idx < 0 or trace_idx >= len(self.t_arrays):
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
- )
-
- t_arr = self.t_arrays[trace_idx]
- if t_arr.size == 0:
- return np.float32(0.0), np.float32(0.0)
- return np.float32(t_arr[0]), np.float32(t_arr[-1])
-
- def get_global_time_range(self) -> Tuple[np.float32, np.float32]:
- """
- Get the global time range across all traces.
-
- Returns
- -------
- Tuple[np.float32, np.float32]
- Global start and end time across all traces.
- """
- if len(self.t_arrays) == 0:
- return np.float32(0.0), np.float32(0.0)
-
- t_min = np.float32(
- min(t_arr[0] if t_arr.size > 0 else np.inf for t_arr in self.t_arrays)
- )
- t_max = np.float32(
- max(t_arr[-1] if t_arr.size > 0 else -np.inf for t_arr in self.t_arrays)
- )
-
- if np.isinf(t_min) or np.isinf(t_max):
- return np.float32(0.0), np.float32(0.0)
-
- return t_min, t_max
-
- def get_data_in_range(
- self, t_start: np.float32, t_end: np.float32, trace_idx: int = 0
- ) -> Tuple[np.ndarray, np.ndarray]:
- """
- Extract data within a time range.
-
- Parameters
- ----------
- t_start : np.float32
- Start time in raw seconds.
- t_end : np.float32
- End time in raw seconds.
- trace_idx : int, default=0
- Index of the trace to get data for.
-
- Returns
- -------
- Tuple[np.ndarray, np.ndarray]
- Time and signal arrays.
- """
- if trace_idx < 0 or trace_idx >= len(self.t_arrays):
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
- )
-
- t_arr = self.t_arrays[trace_idx]
- x_arr = self.x_arrays[trace_idx]
-
- mask = (t_arr >= t_start) & (t_arr <= t_end)
- if not np.any(mask):
- logger.debug(f"No data in range [{t_start}, {t_end}] for trace {trace_idx}")
- return (
- np.array([], dtype=np.float32),
- np.array([], dtype=np.float32),
- )
-
- t_masked = t_arr[mask]
- x_masked = x_arr[mask]
-
- return t_masked, x_masked
-
- def _check_uniform_sampling(self, t: np.ndarray, trace_idx: int = 0) -> None:
- """
- Check if time array is uniformly sampled and issue warnings if not.
-
- Parameters
- ----------
- t : np.ndarray
- Time array to check.
- trace_idx : int, default=0
- Index of the trace being checked (for warning messages).
- """
- if len(t) < 3:
- return # Not enough points to check uniformity
-
- # Calculate time differences
- dt = np.diff(t)
-
- # Calculate statistics
- dt_mean = np.mean(dt)
- dt_std = np.std(dt)
- dt_cv = dt_std / dt_mean if dt_mean > 0 else 0 # Coefficient of variation
-
- # Check for significant non-uniformity
- # CV > 0.01 (1%) indicates potentially problematic non-uniformity
- if dt_cv > 0.01:
- logger.warning(
- f"Non-uniform sampling detected in trace {trace_idx}: "
- f"mean dt={dt_mean:.3e}s, std={dt_std:.3e}s, CV={dt_cv:.2%}"
- )
-
- # More detailed warning for severe non-uniformity
- if dt_cv > 0.05: # 5% variation
- # Find the most extreme deviations
- dt_median = np.median(dt)
- rel_deviations = np.abs(dt - dt_median) / dt_median
- worst_indices = np.argsort(rel_deviations)[-5:] # 5 worst points
-
- worst_deviations = []
- for idx in reversed(worst_indices):
- if (
- rel_deviations[idx] > 0.1
- ): # Only report significant deviations (>10%)
- worst_deviations.append(
- f"at t={t[idx]:.3e}s: dt={dt[idx]:.3e}s ({rel_deviations[idx]:.1%} deviation)"
- )
-
- if worst_deviations:
- logger.warning(
- f"Severe sampling irregularities detected in trace {trace_idx}. "
- f"Worst points: {'; '.join(worst_deviations)}"
- )
- logger.warning(
- "Non-uniform sampling may affect analysis results, especially for "
- "frequency-domain analysis or event detection."
- )
+from typing import Any, Dict, List, Optional, Tuple, Union
+import warnings
+
+import numpy as np
+
+
+class TimeSeriesDataManager:
+ """
+ Manages time series data storage and basic operations.
+
+ Handles raw data storage, time scaling, and basic data access patterns.
+ It can also store optional associated data like background estimates,
+ global noise, and overlay lines.
+
+ Supports multiple traces with shared time axis or individual time axes.
+ """
+
+ def __init__(
+ self,
+ t: Union[np.ndarray, List[np.ndarray]],
+ x: Union[np.ndarray, List[np.ndarray]],
+ name: Union[str, List[str]] = "Time Series",
+ trace_colors: Optional[List[str]] = None,
+ ):
+ """
+ Initialise the data manager.
+
+ Parameters
+ ----------
+ t : Union[np.ndarray, List[np.ndarray]]
+ Time array(s) (raw time in seconds). Can be a single array shared by all traces
+ or a list of arrays, one per trace.
+ x : Union[np.ndarray, List[np.ndarray]]
+ Signal array(s). If t is a single array, x can be a 2D array (traces x samples)
+ or a list of 1D arrays. If t is a list, x must be a list of equal length.
+ name : Union[str, List[str]], default="Time Series"
+ Name(s) for identification. Can be a single string or a list of strings.
+ trace_colors : Optional[List[str]], default=None
+ Colors for each trace. If None, default colors will be used.
+
+ Raises
+ ------
+ ValueError
+ If input arrays have mismatched lengths or time array is not monotonic.
+ """
+ # Convert inputs to standardized format: lists of arrays
+ self.t_arrays, self.x_arrays, self.names, self.colors = (
+ self._standardize_inputs(t, x, name, trace_colors)
+ )
+
+ # Validate all data
+ for i, (t_arr, x_arr) in enumerate(zip(self.t_arrays, self.x_arrays)):
+ self._validate_core_data(t_arr, x_arr, trace_idx=i)
+
+ # Optional associated data (per trace)
+ self._overlay_lines: List[List[Dict[str, Any]]] = [
+ [] for _ in range(len(self.t_arrays))
+ ]
+
+ # For backward compatibility
+ if len(self.t_arrays) > 0:
+ self.t = self.t_arrays[0] # Primary time array
+ self.x = self.x_arrays[0] # Primary signal array
+ self.name = self.names[0] # Primary name
+
+ def _standardize_inputs(
+ self,
+ t: Union[np.ndarray, List[np.ndarray]],
+ x: Union[np.ndarray, List[np.ndarray]],
+ name: Union[str, List[str]],
+ trace_colors: Optional[List[str]],
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]]:
+ """
+ Standardize inputs to lists of arrays.
+
+ Parameters
+ ----------
+ t : Union[np.ndarray, List[np.ndarray]]
+ Time array(s).
+ x : Union[np.ndarray, List[np.ndarray]]
+ Signal array(s).
+ name : Union[str, List[str]]
+ Name(s) for identification.
+ trace_colors : Optional[List[str]]
+ Colors for each trace.
+
+ Returns
+ -------
+ Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]]
+ Standardized lists of time arrays, signal arrays, names, and colors.
+ """
+ # Default colors for traces
+ default_colors = [
+ "black",
+ "blue",
+ "red",
+ "green",
+ "purple",
+ "orange",
+ "brown",
+ "pink",
+ "gray",
+ "olive",
+ ]
+
+ # Handle time arrays
+ if isinstance(t, list):
+ t_arrays = [np.asarray(t_arr, dtype=np.float32) for t_arr in t]
+ n_traces = len(t_arrays)
+ else:
+ t_arr = np.asarray(t, dtype=np.float32)
+
+ # Check if x is 2D array or list
+ if isinstance(x, list):
+ n_traces = len(x)
+ t_arrays = [t_arr.copy() for _ in range(n_traces)]
+ elif x.ndim == 2:
+ n_traces = x.shape[0]
+ t_arrays = [t_arr.copy() for _ in range(n_traces)]
+ else:
+ n_traces = 1
+ t_arrays = [t_arr]
+
+ # Handle signal arrays
+ if isinstance(x, list):
+ if len(x) != n_traces:
+ raise ValueError(
+ f"Number of signal arrays ({len(x)}) must match number of time arrays ({n_traces})"
+ )
+ x_arrays = [np.asarray(x_arr, dtype=np.float32) for x_arr in x]
+ elif x.ndim == 2:
+ if x.shape[0] != n_traces:
+ raise ValueError(
+ f"First dimension of 2D signal array ({x.shape[0]}) must match number of time arrays ({n_traces})"
+ )
+ x_arrays = [np.asarray(x[i], dtype=np.float32) for i in range(n_traces)]
+ else:
+ if n_traces != 1:
+ raise ValueError(
+ f"Single signal array provided but expected {n_traces} arrays"
+ )
+ x_arrays = [np.asarray(x, dtype=np.float32)]
+
+ # Handle names
+ if isinstance(name, list):
+ if len(name) != n_traces:
+ warnings.warn(
+ f"Number of names ({len(name)}) doesn't match number of traces ({n_traces}). Using defaults.", UserWarning
+ )
+ names = [f"Trace {i + 1}" for i in range(n_traces)]
+ else:
+ names = name
+ else:
+ if n_traces == 1:
+ names = [name]
+ else:
+ if (
+ name == "Time Series"
+ ): # Only use default naming if the default name was used
+ names = [f"Trace {i + 1}" for i in range(n_traces)]
+ else:
+ names = [f"{name} {i + 1}" for i in range(n_traces)]
+
+ # Handle colors
+ if trace_colors is not None:
+ if len(trace_colors) < n_traces:
+ warnings.warn(
+ f"Not enough colors provided ({len(trace_colors)}). Using defaults for remaining traces.", UserWarning
+ )
+ colors = trace_colors + [
+ default_colors[i % len(default_colors)]
+ for i in range(len(trace_colors), n_traces)
+ ]
+ else:
+ colors = trace_colors[:n_traces]
+ else:
+ colors = [default_colors[i % len(default_colors)] for i in range(n_traces)]
+
+ return t_arrays, x_arrays, names, colors
+
+ def _validate_core_data(
+ self, t: np.ndarray, x: np.ndarray, trace_idx: int = 0
+ ) -> None:
+ """
+ Validate core input data arrays for consistency and correctness.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ trace_idx : int, default=0
+ Index of the trace being validated (for error messages).
+
+ Raises
+ ------
+ ValueError
+ If arrays have mismatched lengths or time array is not monotonic.
+ """
+ if len(t) != len(x):
+ raise ValueError(
+ f"Time and signal arrays for trace {trace_idx} must have the same length. Got t={len(t)}, x={len(x)}"
+ )
+ if len(t) == 0:
+ warnings.warn(f"Initialising trace {trace_idx} with empty arrays.", UserWarning)
+ return
+
+ # Check time array is monotonic
+ if len(t) > 1:
+ # Use a small epsilon for floating-point comparison
+ tolerance = 1e-9
+ if not np.all(np.diff(t) > tolerance):
+ problematic_diffs = np.diff(t)[np.diff(t) <= tolerance]
+ warnings.warn(
+ f"Time array for trace {trace_idx} is not strictly monotonic increasing within tolerance {tolerance}. "
+ f"Problematic diffs (first 10): {problematic_diffs[:10]}. "
+ f"This may affect analysis results.", UserWarning
+ )
+
+ # Check for non-uniform sampling
+ self._check_uniform_sampling(t, trace_idx)
+
+ @property
+ def overlay_lines(self) -> List[Dict[str, Any]]:
+ """Get overlay lines data for the primary trace."""
+ return self._overlay_lines[0] if self._overlay_lines else []
+
+ def get_overlay_lines(self, trace_idx: int = 0) -> List[Dict[str, Any]]:
+ """Get overlay lines data for a specific trace."""
+ if trace_idx < 0 or trace_idx >= len(self.t_arrays):
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
+ )
+ return self._overlay_lines[trace_idx]
+
+ @property
+ def num_traces(self) -> int:
+ """Get the number of traces."""
+ return len(self.t_arrays)
+
+ def get_trace_color(self, trace_idx: int = 0) -> str:
+ """Get the color for a specific trace."""
+ if trace_idx < 0 or trace_idx >= len(self.t_arrays):
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
+ )
+ return self.colors[trace_idx]
+
+ def get_trace_name(self, trace_idx: int = 0) -> str:
+ """Get the name for a specific trace."""
+ if trace_idx < 0 or trace_idx >= len(self.t_arrays):
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
+ )
+ return self.names[trace_idx]
+
+ def set_overlay_lines(
+ self,
+ overlay_lines: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
+ trace_idx: Optional[int] = None,
+ ) -> None:
+ """
+ Set overlay lines data.
+
+ Parameters
+ ----------
+ overlay_lines : Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]
+ List of dictionaries defining overlay lines, or list of lists for multiple traces.
+ trace_idx : Optional[int], default=None
+ If provided, set overlay lines only for the specified trace.
+ If None, set for all traces if a nested list is provided, or for the first trace if a flat list.
+ """
+ if trace_idx is not None:
+ # Set for specific trace
+ if trace_idx < 0 or trace_idx >= len(self.t_arrays):
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
+ )
+
+ # Ensure we have a list of dictionaries
+ if not isinstance(overlay_lines, list):
+ raise ValueError(
+ f"overlay_lines must be a list of dictionaries. Got {type(overlay_lines)}."
+ )
+
+ # Check if it's a list of dictionaries (not a nested list)
+ if len(overlay_lines) > 0 and isinstance(overlay_lines[0], dict):
+ self._overlay_lines[trace_idx] = overlay_lines
+ else:
+ raise ValueError(
+ "Expected a list of dictionaries for overlay_lines when trace_idx is specified."
+ )
+ else:
+ # Set for all traces or first trace
+ if len(overlay_lines) > 0 and isinstance(overlay_lines[0], list):
+ # Nested list provided - set for multiple traces
+ if len(overlay_lines) != len(self.t_arrays):
+ raise ValueError(
+ f"Number of overlay line lists ({len(overlay_lines)}) must match number of traces ({len(self.t_arrays)})."
+ )
+
+ for i, lines in enumerate(overlay_lines):
+ self._overlay_lines[i] = lines
+ else:
+ # Flat list provided - set for first trace
+ self._overlay_lines[0] = overlay_lines
+
+ def get_time_range(self, trace_idx: int = 0) -> Tuple[np.float32, np.float32]:
+ """
+ Get the full time range of the data.
+
+ Parameters
+ ----------
+ trace_idx : int, default=0
+ Index of the trace to get the time range for.
+
+ Returns
+ -------
+ Tuple[np.float32, np.float32]
+ Start and end time of the data.
+ """
+ if trace_idx < 0 or trace_idx >= len(self.t_arrays):
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
+ )
+
+ t_arr = self.t_arrays[trace_idx]
+ if t_arr.size == 0:
+ return np.float32(0.0), np.float32(0.0)
+ return np.float32(t_arr[0]), np.float32(t_arr[-1])
+
+ def get_global_time_range(self) -> Tuple[np.float32, np.float32]:
+ """
+ Get the global time range across all traces.
+
+ Returns
+ -------
+ Tuple[np.float32, np.float32]
+ Global start and end time across all traces.
+ """
+ if len(self.t_arrays) == 0:
+ return np.float32(0.0), np.float32(0.0)
+
+ t_min = np.float32(
+ min(t_arr[0] if t_arr.size > 0 else np.inf for t_arr in self.t_arrays)
+ )
+ t_max = np.float32(
+ max(t_arr[-1] if t_arr.size > 0 else -np.inf for t_arr in self.t_arrays)
+ )
+
+ if np.isinf(t_min) or np.isinf(t_max):
+ return np.float32(0.0), np.float32(0.0)
+
+ return t_min, t_max
+
+ def get_data_in_range(
+ self, t_start: np.float32, t_end: np.float32, trace_idx: int = 0
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Extract data within a time range.
+
+ Parameters
+ ----------
+ t_start : np.float32
+ Start time in raw seconds.
+ t_end : np.float32
+ End time in raw seconds.
+ trace_idx : int, default=0
+ Index of the trace to get data for.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ Time and signal arrays.
+ """
+ if trace_idx < 0 or trace_idx >= len(self.t_arrays):
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}."
+ )
+
+ t_arr = self.t_arrays[trace_idx]
+ x_arr = self.x_arrays[trace_idx]
+
+ mask = (t_arr >= t_start) & (t_arr <= t_end)
+ if not np.any(mask):
+ return (
+ np.array([], dtype=np.float32),
+ np.array([], dtype=np.float32),
+ )
+
+ t_masked = t_arr[mask]
+ x_masked = x_arr[mask]
+
+ return t_masked, x_masked
+
+ def _check_uniform_sampling(self, t: np.ndarray, trace_idx: int = 0) -> None:
+ """
+ Check if time array is uniformly sampled and issue warnings if not.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array to check.
+ trace_idx : int, default=0
+ Index of the trace being checked (for warning messages).
+ """
+ if len(t) < 3:
+ return # Not enough points to check uniformity
+
+ # Calculate time differences
+ dt = np.diff(t)
+
+ # Calculate statistics
+ dt_mean = np.mean(dt)
+ dt_std = np.std(dt)
+ dt_cv = dt_std / dt_mean if dt_mean > 0 else 0 # Coefficient of variation
+
+ # Check for significant non-uniformity
+ # CV > 0.01 (1%) indicates potentially problematic non-uniformity
+ if dt_cv > 0.01:
+ warnings.warn(
+ f"Non-uniform sampling detected in trace {trace_idx}: "
+ f"mean dt={dt_mean:.3e}s, std={dt_std:.3e}s, CV={dt_cv:.2%}", UserWarning
+ )
+
+ # More detailed warning for severe non-uniformity
+ if dt_cv > 0.05: # 5% variation
+ # Find the most extreme deviations
+ dt_median = np.median(dt)
+ rel_deviations = np.abs(dt - dt_median) / dt_median
+ worst_indices = np.argsort(rel_deviations)[-5:] # 5 worst points
+
+ worst_deviations = []
+ for idx in reversed(worst_indices):
+ if (
+ rel_deviations[idx] > 0.1
+ ): # Only report significant deviations (>10%)
+ worst_deviations.append(
+ f"at t={t[idx]:.3e}s: dt={dt[idx]:.3e}s ({rel_deviations[idx]:.1%} deviation)"
+ )
+
+ if worst_deviations:
+ warnings.warn(
+ f"Severe sampling irregularities detected in trace {trace_idx}. "
+ f"Worst points: {'; '.join(worst_deviations)}", UserWarning
+ )
+ warnings.warn(
+ "Non-uniform sampling may affect analysis results, especially for "
+ "frequency-domain analysis or event detection.", UserWarning
+ )
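
A minimal usage sketch of TimeSeriesDataManager as defined above. The deliberately irregular
time axis triggers the non-uniform-sampling UserWarning introduced in this file; the signal
values are made up for illustration:

    import warnings

    import numpy as np

    from scopekit.data_manager import TimeSeriesDataManager

    # Two uniform segments with a gap between them -> dt coefficient of variation well above 1%.
    t = np.concatenate([np.linspace(0.0, 1.0, 900), np.linspace(1.5, 2.0, 100)])
    x = np.sin(2 * np.pi * 5.0 * t).astype(np.float32)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        dm = TimeSeriesDataManager(t, x, name="Demo trace")

    for w in caught:
        print(w.category.__name__, str(w.message)[:60])  # expect non-uniform sampling warnings

    print(dm.num_traces, dm.get_time_range())  # -> 1 (0.0, 2.0)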
diff --git a/src/scopekit/decimation.py b/src/scopekit/decimation.py
index 16543b1..d60ac3b 100644
--- a/src/scopekit/decimation.py
+++ b/src/scopekit/decimation.py
@@ -1,671 +1,602 @@
-from typing import Dict, Optional, Tuple
-
-import numpy as np
-from loguru import logger
-from numba import njit
-
-
-@njit
-def _decimate_time_numba(t: np.ndarray, step: int, n_bins: int) -> np.ndarray:
- """
- Numba-optimized time decimation using bin centers.
-
- Parameters
- ----------
- t : np.ndarray
- Input time array.
- step : int
- Step size for binning.
- n_bins : int
- Number of bins to create.
-
- Returns
- -------
- np.ndarray
- Decimated time array with center time of each bin.
- """
- t_decimated = np.zeros(n_bins, dtype=np.float32)
-
- for i in range(n_bins):
- start_idx = i * step
- end_idx = min((i + 1) * step, len(t))
- center_idx = start_idx + (end_idx - start_idx) // 2
- t_decimated[i] = t[center_idx]
-
- return t_decimated
-
-
-@njit
-def _decimate_mean_numba(x: np.ndarray, step: int, n_bins: int) -> np.ndarray:
- """
- Numba-optimized mean decimation.
-
- Parameters
- ----------
- x : np.ndarray
- Input signal array.
- step : int
- Step size for binning.
- n_bins : int
- Number of bins to create.
-
- Returns
- -------
- np.ndarray
- Decimated signal array with mean values.
- """
- x_decimated = np.zeros(n_bins, dtype=np.float32)
-
- for i in range(n_bins):
- start_idx = i * step
- end_idx = min((i + 1) * step, len(x))
-
- if end_idx > start_idx:
- # Calculate mean manually for Numba compatibility
- bin_sum = 0.0
- bin_count = end_idx - start_idx
- for j in range(start_idx, end_idx):
- bin_sum += x[j]
- x_decimated[i] = bin_sum / bin_count
- else:
- x_decimated[i] = x[start_idx] if start_idx < len(x) else 0.0
-
- return x_decimated
-
-
-@njit
-def _decimate_envelope_standard_numba(
- x: np.ndarray, step: int, n_bins: int
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
- """
- Numba-optimized standard envelope decimation.
-
- Parameters
- ----------
- x : np.ndarray
- Input signal array.
- step : int
- Step size for binning.
- n_bins : int
- Number of bins to create.
-
- Returns
- -------
- Tuple[np.ndarray, np.ndarray, np.ndarray]
- Decimated signal (mean), min envelope, max envelope arrays.
- """
- x_decimated = np.zeros(n_bins, dtype=np.float32)
- x_min_envelope = np.zeros(n_bins, dtype=np.float32)
- x_max_envelope = np.zeros(n_bins, dtype=np.float32)
-
- for i in range(n_bins):
- start_idx = i * step
- end_idx = min((i + 1) * step, len(x))
-
- if end_idx > start_idx:
- # Find min and max manually for Numba compatibility
- bin_min = x[start_idx]
- bin_max = x[start_idx]
- bin_sum = 0.0
-
- for j in range(start_idx, end_idx):
- val = x[j]
- if val < bin_min:
- bin_min = val
- if val > bin_max:
- bin_max = val
- bin_sum += val
-
- x_min_envelope[i] = bin_min
- x_max_envelope[i] = bin_max
- x_decimated[i] = bin_sum / (end_idx - start_idx)
- else:
- fallback_val = x[start_idx] if start_idx < len(x) else 0.0
- x_min_envelope[i] = fallback_val
- x_max_envelope[i] = fallback_val
- x_decimated[i] = fallback_val
-
- return x_decimated, x_min_envelope, x_max_envelope
-
-
-@njit
-def _decimate_envelope_highres_numba(
- t: np.ndarray, x: np.ndarray, step: int, n_bins: int, envelope_window_samples: int
-) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
- """
- Numba-optimized high-resolution envelope decimation.
-
- Parameters
- ----------
- t : np.ndarray
- Input time array.
- x : np.ndarray
- Input signal array.
- step : int
- Step size for binning.
- n_bins : int
- Number of bins to create.
- envelope_window_samples : int
- Window size in samples for high-resolution envelope calculation.
-
- Returns
- -------
- Tuple[np.ndarray, np.ndarray, np.ndarray]
- Decimated signal (mean), min envelope, max envelope arrays.
- """
- x_decimated = np.zeros(n_bins, dtype=np.float32)
- x_min_envelope = np.zeros(n_bins, dtype=np.float32)
- x_max_envelope = np.zeros(n_bins, dtype=np.float32)
-
- half_window = envelope_window_samples // 2
-
- for i in range(n_bins):
- start_idx = i * step
- end_idx = min((i + 1) * step, len(t))
- bin_center = start_idx + (end_idx - start_idx) // 2
-
- # Define window around bin center
- window_start = max(0, bin_center - half_window)
- window_end = min(len(x), bin_center + half_window)
-
- if window_end > window_start:
- # Find min and max in window manually for Numba compatibility
- window_min = x[window_start]
- window_max = x[window_start]
-
- for j in range(window_start, window_end):
- val = x[j]
- if val < window_min:
- window_min = val
- if val > window_max:
- window_max = val
-
- x_min_envelope[i] = window_min
- x_max_envelope[i] = window_max
- x_decimated[i] = (window_min + window_max) / 2.0
- else:
- fallback_val = x[bin_center] if bin_center < len(x) else 0.0
- x_min_envelope[i] = fallback_val
- x_max_envelope[i] = fallback_val
- x_decimated[i] = fallback_val
-
- return x_decimated, x_min_envelope, x_max_envelope
-
-
-class DecimationManager:
- """
- Handles data decimation and caching for efficient plotting.
-
- Manages different decimation strategies and caches results to improve performance.
- Pre-calculates decimated data at load time for faster zooming.
- """
-
- # Cache and performance constants
- CACHE_MAX_SIZE = 10
- MIN_VISIBLE_RANGE_DEFAULT = 1e-6 # Default if no global noise is provided
- # Threshold for warning about too many points in detail mode
- DETAIL_MODE_POINT_WARNING_THRESHOLD = 100000
-
- def __init__(self, cache_max_size: int = CACHE_MAX_SIZE):
- """
- Initialise the decimation manager.
-
- Parameters
- ----------
- cache_max_size : int, default=PlotConstants.CACHE_MAX_SIZE
- Maximum number of cached decimation results.
- """
- self._cache: Dict[str, Tuple[np.ndarray, ...]] = {}
- self._cache_max_size = cache_max_size
- # Stores pre-decimated envelope data for the full dataset for each trace/line
- # Structure: {trace_id: {'t': np.ndarray, 'x_min': np.ndarray, 'x_max': np.ndarray, ...}}
- self._pre_decimated_envelopes: Dict[int, Dict[str, np.ndarray]] = {}
-
- def _get_cache_key(
- self,
- xlim_raw: Tuple[np.float32, np.float32],
- max_points: int,
- use_envelope: bool,
- trace_id: Optional[int] = None,
- ) -> str:
- """Generate cache key for decimated data."""
- # Round to reasonable precision to improve cache hits
- xlim_rounded = (round(float(xlim_raw[0]), 9), round(float(xlim_raw[1]), 9))
-
- # Include trace_id in cache key for multi-trace support
- trace_suffix = f"_t{trace_id}" if trace_id is not None else ""
-
- return f"{xlim_rounded}_{max_points}_{use_envelope}{trace_suffix}"
-
- def _manage_cache_size(self) -> None:
- """Remove oldest cache entry if cache is full."""
- if len(self._cache) >= self._cache_max_size:
- # Remove oldest entry (simple FIFO)
- oldest_key = next(iter(self._cache))
- del self._cache[oldest_key]
-
- def clear_cache(self) -> None:
- """Clear the decimation cache."""
- self._cache.clear()
- # Do NOT clear _pre_decimated_envelopes here, as they are persistent for the full dataset
-
- def _decimate_data(
- self,
- t: np.ndarray,
- x: np.ndarray,
- max_points: int,
- use_envelope: bool = False,
- envelope_window_samples: Optional[int] = None,
- return_envelope_min_max: bool = False, # New parameter
- ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
- """
- Unified decimation for time and multiple data arrays.
-
- Parameters
- ----------
- t : np.ndarray
- Time array.
- x : np.ndarray
- Signal array.
- max_points : int, default=5000
- Maximum number of points to display.
- use_envelope : bool, default=False
- Whether to use envelope decimation for the signal array.
- envelope_window_samples : Optional[int], default=None
- Window size in samples for high-resolution envelope calculation.
- return_envelope_min_max : bool, default=False
- If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them.
- If None, uses simple binning approach.
-
- Returns
- -------
- Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
- Decimated time, signal, signal min envelope, signal max envelope arrays.
- """
- # If input arrays are empty, return empty arrays immediately
- if len(t) == 0:
- return (
- np.array([], dtype=np.float32),
- np.array([], dtype=np.float32),
- None,
- None,
- )
-
- # If not using envelope, always return raw data for the view
- if (
- not use_envelope and not return_envelope_min_max
- ): # If not using envelope and not explicitly asking for min/max
- return t, x, None, None # No min/max envelope for raw data
-
- # If using envelope and data is small enough, return raw data as envelope
- if use_envelope and len(t) <= max_points and return_envelope_min_max:
- return t, x, x, x # x,x for min/max when no decimation
-
- # Calculate step size for decimation based on max_points
- step = max(1, len(t) // max_points)
-
- # For envelope mode, calculate adaptive envelope window based on data density
- adaptive_envelope_window = None
- if use_envelope and len(t) > max_points:
- # Calculate envelope window based on how much we're decimating
- # This ensures envelope resolution matches display capability
- adaptive_envelope_window = max(
- 1, step // 2
- ) # Half the step size for smoother envelope
- logger.debug(
- f"Calculated adaptive envelope window: {adaptive_envelope_window} samples (step={step})"
- )
-
- # Ensure step is not zero, and calculate number of bins
- if step == 0: # Should not happen with max(1, ...) but as a safeguard
- step = 1
- n_bins = len(t) // step
- if (
- n_bins == 0
- ): # If data is too short for the calculated step, take at least one bin
- n_bins = 1
- step = len(t) # Take all points in one bin
-
- # Ensure arrays are contiguous and correct dtype for Numba
- t_contiguous = np.ascontiguousarray(t, dtype=np.float32)
- x_contiguous = np.ascontiguousarray(x, dtype=np.float32)
-
- # Decimate time array using Numba-optimized function
- t_decimated = _decimate_time_numba(t_contiguous, step, n_bins)
-
- # Decimate signal (x) using appropriate Numba-optimized function
- x_min_envelope: Optional[np.ndarray] = None
- x_max_envelope: Optional[np.ndarray] = None
-
- if use_envelope: # This block handles the decimation logic (mean or envelope)
- if adaptive_envelope_window is not None and adaptive_envelope_window > 1:
- logger.debug(
- f"Using adaptive high-resolution envelope with window size {adaptive_envelope_window} samples"
- )
-
- # Use Numba-optimized high-resolution envelope decimation with adaptive window
- x_decimated, x_min_envelope, x_max_envelope = (
- _decimate_envelope_highres_numba(
- t_contiguous,
- x_contiguous,
- step,
- n_bins,
- adaptive_envelope_window,
- )
- )
-
- envelope_thickness = np.mean(x_max_envelope - x_min_envelope)
- logger.debug(
- f"Adaptive envelope thickness: mean={envelope_thickness:.3g}, min={np.min(x_max_envelope - x_min_envelope):.3g}, max={np.max(x_max_envelope - x_min_envelope):.3g}"
- )
- else:
- logger.debug("Using standard bin-based envelope")
-
- # Use Numba-optimized standard envelope decimation
- x_decimated, x_min_envelope, x_max_envelope = (
- _decimate_envelope_standard_numba(x_contiguous, step, n_bins)
- )
-
- # If we are not returning min/max, then x_decimated should be the mean
- # Otherwise, x_decimated is just the mean of the envelope for internal use
- if not return_envelope_min_max:
- x_decimated = (x_min_envelope + x_max_envelope) / 2
- else: # This block is now reached if use_envelope is False AND len(t) > max_points
- logger.debug("Using mean decimation for single line")
-
- # Use Numba-optimized mean decimation
- x_decimated = _decimate_mean_numba(x_contiguous, step, n_bins)
-
- # If return_envelope_min_max is False, ensure min/max are None
- if not return_envelope_min_max:
- x_min_envelope = None
- x_max_envelope = None
-
- return t_decimated, x_decimated, x_min_envelope, x_max_envelope
-
- def pre_decimate_data(
- self,
- data_id: int, # Changed from trace_id to data_id to be more generic for custom lines
- t: np.ndarray,
- x: np.ndarray,
- max_points: int,
- envelope_window_samples: Optional[int] = None, # This parameter is now ignored
- ) -> None:
- """
- Pre-calculate decimated envelope data for the full dataset.
- This is used for fast rendering in zoomed-out (envelope) mode.
-
- Parameters
- ----------
- data_id : int
- Unique identifier for this data set (e.g., trace_id or custom line ID).
- t : np.ndarray
- Time array (raw time in seconds).
- x : np.ndarray
- Signal array.
- max_points : int
- Maximum number of points for the pre-decimated data.
- envelope_window_samples : Optional[int], default=None
- Window size in samples for high-resolution envelope calculation.
- This will primarily determine the bin size for pre-decimation.
- """
- if len(t) <= max_points:
- # For small datasets, just store the original data as the "pre-decimated" envelope
- # (min/max will be the same as x)
- self._pre_decimated_envelopes[data_id] = {
- "t": t,
- "x": x, # Store mean/center for consistency
- "x_min": x,
- "x_max": x,
- }
- logger.debug(
- f"Data ID {data_id} is small enough, storing raw as pre-decimated envelope."
- )
- return
-
- logger.debug(
- f"Pre-decimating data for ID {data_id} to {max_points} points for envelope view."
- )
- # Perform the decimation using the _decimate_data method
- # We force use_envelope=True here for pre-decimation to capture min/max
- # envelope_window_samples is now calculated automatically based on max_points
- t_decimated, x_decimated, x_min, x_max = self._decimate_data(
- t,
- x,
- max_points=max_points,
- use_envelope=True, # Always pre-decimate with envelope
- envelope_window_samples=None, # Let _decimate_data calculate adaptive window
- return_envelope_min_max=True, # Pre-decimation always stores min/max
- )
-
- # Store pre-decimated envelope data
- self._pre_decimated_envelopes[data_id] = {
- "t": t_decimated,
- "x": x_decimated, # This is the mean/center of the envelope
- "x_min": x_min,
- "x_max": x_max,
- }
-
- logger.debug(
- f"Pre-decimated envelope calculated for ID {data_id}: {len(t_decimated)} points."
- )
-
- def decimate_for_view(
- self,
- t_raw_full: np.ndarray, # Full resolution time array
- x_raw_full: np.ndarray, # Full resolution signal array
- xlim_raw: Tuple[np.float32, np.float32],
- max_points: int,
- use_envelope: bool = False,
- data_id: Optional[int] = None, # Changed from trace_id to data_id
- envelope_window_samples: Optional[int] = None, # This parameter is now ignored
- mode_switch_threshold: Optional[
- float
- ] = None, # New parameter for mode switching
- return_envelope_min_max: bool = False, # New parameter
- ) -> Tuple[
- np.ndarray,
- np.ndarray,
- Optional[np.ndarray],
- Optional[np.ndarray],
- ]:
- """
- Intelligently decimate data for current view with optional envelope mode.
-
- Parameters
- ----------
- t_raw_full : np.ndarray
- Full resolution time array (raw time in seconds).
- x_raw_full : np.ndarray
- Full resolution signal array.
- xlim_raw : Tuple[np.float32, np.float32]
- Current x-axis limits in raw time (seconds).
- max_points : int
- Maximum number of points to display.
- use_envelope : bool, default=False
- Whether the current display mode is envelope.
- data_id : Optional[int], default=None
- Unique identifier for this data set (e.g., trace_id or custom line ID).
- Used to retrieve pre-decimated envelope data.
- envelope_window_samples : Optional[int], default=None
- Window size in samples for high-resolution envelope calculation.
- return_envelope_min_max : bool, default=False
- If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them.
- mode_switch_threshold : Optional[float], default=None
- Time span threshold for switching between envelope and detail modes.
- Used to decide whether to use pre-decimated envelope data.
-
- Returns
- -------
- Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
- Decimated time, signal, signal min envelope, signal max envelope arrays (all in raw time).
- """
- logger.debug(f"=== DecimationManager.decimate_for_view data_id={data_id} ===")
- logger.debug(f"xlim_raw: {xlim_raw}")
- logger.debug(f"use_envelope (requested): {use_envelope}")
- logger.debug(f"max_points: {max_points}")
- logger.debug(
- f"Input data range: t=[{np.min(t_raw_full):.6f}, {np.max(t_raw_full):.6f}], x=[{np.min(x_raw_full):.6f}, {np.max(x_raw_full):.6f}]"
- )
-
- # Ensure xlim_raw values are valid
- if (
- not np.isfinite(xlim_raw[0])
- or not np.isfinite(xlim_raw[1])
- or xlim_raw[0] == xlim_raw[1]
- ):
- logger.warning(
- f"Invalid xlim_raw values: {xlim_raw}. Using full data range."
- )
- xlim_raw = (np.min(t_raw_full), np.max(t_raw_full))
-
- # Ensure xlim_raw is in ascending order
- if xlim_raw[0] > xlim_raw[1]:
- logger.warning(f"xlim_raw values out of order: {xlim_raw}. Swapping.")
- xlim_raw = (xlim_raw[1], xlim_raw[0])
-
- # Calculate current view span
- current_view_span = xlim_raw[1] - xlim_raw[0]
-
- # Check cache first
- cache_key = self._get_cache_key(
- xlim_raw, max_points, use_envelope, data_id
- ) # Cache key doesn't need return_envelope_min_max
- if cache_key in self._cache:
- logger.debug(f"Using cached decimation for key: {cache_key}")
- return self._cache[cache_key]
-
- # --- Strategy: Use pre-decimated envelope if in envelope mode and view is wide ---
- if (
- use_envelope
- and data_id is not None
- and data_id in self._pre_decimated_envelopes
- ):
- pre_dec_data = self._pre_decimated_envelopes[data_id]
- pre_dec_t = pre_dec_data["t"]
-
- if len(pre_dec_t) > 1:
- pre_dec_span = pre_dec_t[-1] - pre_dec_t[0]
-
- # Calculate how much detail we would gain by re-decimating
- # Find indices for current view in pre-decimated time
- mask = (pre_dec_t >= xlim_raw[0]) & (pre_dec_t <= xlim_raw[1])
- pre_dec_points_in_view = np.sum(mask)
-
- # Estimate how many points we would get from dynamic decimation
- t_view_mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1])
- raw_points_in_view = np.sum(t_view_mask)
- potential_decimated_points = min(raw_points_in_view, max_points)
-
- # Use pre-decimated data only if:
- # 1. Current view span is very large (> 2x mode_switch_threshold), AND
- # 2. Pre-decimated data provides reasonable detail (> max_points/4), AND
- # 3. We wouldn't gain much detail from re-decimating (< 2x improvement)
- use_pre_decimated = (
- mode_switch_threshold is not None
- and current_view_span >= 2 * mode_switch_threshold
- and pre_dec_points_in_view > max_points // 4
- and potential_decimated_points < 2 * pre_dec_points_in_view
- )
-
- if use_pre_decimated and np.any(mask):
- logger.debug(
- f"Using pre-decimated data for ID {data_id} (envelope mode, very wide view, {pre_dec_points_in_view} points, return_envelope_min_max={return_envelope_min_max})."
- )
-
- # If we need min/max, return them. Otherwise, return None.
- x_min_ret = (
- pre_dec_data["x_min"][mask] if return_envelope_min_max else None
- )
- x_max_ret = (
- pre_dec_data["x_max"][mask] if return_envelope_min_max else None
- )
-
- result = (
- pre_dec_t[mask],
- pre_dec_data["x"][mask], # Center of envelope
- x_min_ret,
- x_max_ret,
- )
- self._manage_cache_size()
- self._cache[cache_key] = result
- return result
- else:
- logger.debug(
- f"Re-decimating for better detail: view_span={current_view_span:.3e}, pre_dec_points={pre_dec_points_in_view}, potential_points={potential_decimated_points}"
- )
- else:
- logger.debug(
- f"Pre-decimated data for ID {data_id} has only one point, falling back to dynamic decimation."
- )
- else:
- logger.debug(
- f"Not using pre-decimated envelope for ID {data_id} (use_envelope={use_envelope}, data_id={data_id in self._pre_decimated_envelopes})."
- )
-
- # --- Fallback: Dynamic decimation from raw data ---
- logger.debug("Performing dynamic decimation from raw data.")
-
- # ADDED DEBUG LOGS
- logger.debug(
- f" t_raw_full min/max: {t_raw_full.min():.6f}, {t_raw_full.max():.6f}"
- )
- logger.debug(f" xlim_raw: {xlim_raw[0]:.6f}, {xlim_raw[1]:.6f}")
-
- # Find indices for current view in raw time
- mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1])
-
- # ADDED DEBUG LOG
- logger.debug(f" Mask result: {np.sum(mask)} points selected.")
-
- if not np.any(mask):
- logger.warning(
- f"No data in view for xlim_raw: {xlim_raw}. Returning empty arrays."
- )
- empty_result = (
- np.array([], dtype=np.float32),
- np.array([], dtype=np.float32),
- None,
- None,
- )
- # Cache empty result for this view
- self._manage_cache_size()
- self._cache[cache_key] = empty_result
- return empty_result
-
- t_view = t_raw_full[mask]
- x_view = x_raw_full[mask]
-
- # Add warning for large number of points in detail mode
- if not use_envelope and len(t_view) > self.DETAIL_MODE_POINT_WARNING_THRESHOLD:
- logger.warning(
- f"Plotting {len(t_view)} points in detail mode. "
- f"Performance may be affected. Consider zooming in further."
- )
-
- # Use unified decimation approach
- # envelope_window_samples is now calculated automatically based on max_points and data density
- result = self._decimate_data(
- t_view,
- x_view,
- max_points=max_points,
- use_envelope=use_envelope, # Use requested envelope mode for dynamic decimation
- envelope_window_samples=None, # Let _decimate_data calculate adaptive window
- return_envelope_min_max=return_envelope_min_max, # Pass through
- )
-
- # Cache the result (manage cache size)
- self._manage_cache_size()
- self._cache[cache_key] = result
-
- # Log the final result
- t_result, x_result, x_min_result, x_max_result = result
- logger.debug(f"Returning result: t len={len(t_result)}, x len={len(x_result)}")
- logger.debug(
- f"Result ranges: t=[{np.min(t_result) if len(t_result) > 0 else 'empty':.6f}, {np.max(t_result) if len(t_result) > 0 else 'empty':.6f}], x=[{np.min(x_result) if len(x_result) > 0 else 'empty':.6f}, {np.max(x_result) if len(x_result) > 0 else 'empty':.6f}]"
- )
- logger.debug(
- f"Envelope: x_min={'None' if x_min_result is None else f'len={len(x_min_result)}'}, x_max={'None' if x_max_result is None else f'len={len(x_max_result)}'}"
- )
-
- return result
+from typing import Dict, Optional, Tuple
+import warnings
+
+import numpy as np
+from numba import njit
+
+
+@njit
+def _decimate_time_numba(t: np.ndarray, step: int, n_bins: int) -> np.ndarray:
+ """
+ Numba-optimized time decimation using bin centers.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Input time array.
+ step : int
+ Step size for binning.
+ n_bins : int
+ Number of bins to create.
+
+ Returns
+ -------
+ np.ndarray
+ Decimated time array with center time of each bin.
+ """
+ t_decimated = np.zeros(n_bins, dtype=np.float32)
+
+ for i in range(n_bins):
+ start_idx = i * step
+ end_idx = min((i + 1) * step, len(t))
+ center_idx = start_idx + (end_idx - start_idx) // 2
+ t_decimated[i] = t[center_idx]
+
+ return t_decimated
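+
+
+# Binning convention shared by the _decimate_*_numba helpers in this module
+# (illustrative numbers; step and n_bins are computed by _decimate_data): with
+# len(t) == 1_000_000 and max_points == 5_000, step = 200 and n_bins = 5_000, so
+# bin i covers samples [i * step, (i + 1) * step), clipped to len(t) for the last bin.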
+
+
+@njit
+def _decimate_mean_numba(x: np.ndarray, step: int, n_bins: int) -> np.ndarray:
+ """
+ Numba-optimized mean decimation.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Input signal array.
+ step : int
+ Step size for binning.
+ n_bins : int
+ Number of bins to create.
+
+ Returns
+ -------
+ np.ndarray
+ Decimated signal array with mean values.
+ """
+ x_decimated = np.zeros(n_bins, dtype=np.float32)
+
+ for i in range(n_bins):
+ start_idx = i * step
+ end_idx = min((i + 1) * step, len(x))
+
+ if end_idx > start_idx:
+ # Calculate mean manually for Numba compatibility
+ bin_sum = 0.0
+ bin_count = end_idx - start_idx
+ for j in range(start_idx, end_idx):
+ bin_sum += x[j]
+ x_decimated[i] = bin_sum / bin_count
+ else:
+ x_decimated[i] = x[start_idx] if start_idx < len(x) else 0.0
+
+ return x_decimated
+
+
+@njit
+def _decimate_envelope_standard_numba(
+ x: np.ndarray, step: int, n_bins: int
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Numba-optimized standard envelope decimation.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Input signal array.
+ step : int
+ Step size for binning.
+ n_bins : int
+ Number of bins to create.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
+ Decimated signal (mean), min envelope, max envelope arrays.
+ """
+ x_decimated = np.zeros(n_bins, dtype=np.float32)
+ x_min_envelope = np.zeros(n_bins, dtype=np.float32)
+ x_max_envelope = np.zeros(n_bins, dtype=np.float32)
+
+ for i in range(n_bins):
+ start_idx = i * step
+ end_idx = min((i + 1) * step, len(x))
+
+ if end_idx > start_idx:
+ # Find min and max manually for Numba compatibility
+ bin_min = x[start_idx]
+ bin_max = x[start_idx]
+ bin_sum = 0.0
+
+ for j in range(start_idx, end_idx):
+ val = x[j]
+ if val < bin_min:
+ bin_min = val
+ if val > bin_max:
+ bin_max = val
+ bin_sum += val
+
+ x_min_envelope[i] = bin_min
+ x_max_envelope[i] = bin_max
+ x_decimated[i] = bin_sum / (end_idx - start_idx)
+ else:
+ fallback_val = x[start_idx] if start_idx < len(x) else 0.0
+ x_min_envelope[i] = fallback_val
+ x_max_envelope[i] = fallback_val
+ x_decimated[i] = fallback_val
+
+ return x_decimated, x_min_envelope, x_max_envelope
+
+
+@njit
+def _decimate_envelope_highres_numba(
+ t: np.ndarray, x: np.ndarray, step: int, n_bins: int, envelope_window_samples: int
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Numba-optimized high-resolution envelope decimation.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Input time array.
+ x : np.ndarray
+ Input signal array.
+ step : int
+ Step size for binning.
+ n_bins : int
+ Number of bins to create.
+ envelope_window_samples : int
+ Window size in samples for high-resolution envelope calculation.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray, np.ndarray]
+ Decimated signal (mean), min envelope, max envelope arrays.
+ """
+ x_decimated = np.zeros(n_bins, dtype=np.float32)
+ x_min_envelope = np.zeros(n_bins, dtype=np.float32)
+ x_max_envelope = np.zeros(n_bins, dtype=np.float32)
+
+ half_window = envelope_window_samples // 2
+
+ for i in range(n_bins):
+ start_idx = i * step
+ end_idx = min((i + 1) * step, len(t))
+ bin_center = start_idx + (end_idx - start_idx) // 2
+
+ # Define window around bin center
+ window_start = max(0, bin_center - half_window)
+ window_end = min(len(x), bin_center + half_window)
+
+ if window_end > window_start:
+ # Find min and max in window manually for Numba compatibility
+ window_min = x[window_start]
+ window_max = x[window_start]
+
+ for j in range(window_start, window_end):
+ val = x[j]
+ if val < window_min:
+ window_min = val
+ if val > window_max:
+ window_max = val
+
+ x_min_envelope[i] = window_min
+ x_max_envelope[i] = window_max
+ x_decimated[i] = (window_min + window_max) / 2.0
+ else:
+ fallback_val = x[bin_center] if bin_center < len(x) else 0.0
+ x_min_envelope[i] = fallback_val
+ x_max_envelope[i] = fallback_val
+ x_decimated[i] = fallback_val
+
+ return x_decimated, x_min_envelope, x_max_envelope
+
+
+class DecimationManager:
+ """
+ Handles data decimation and caching for efficient plotting.
+
+ Manages different decimation strategies and caches results to improve performance.
+ Pre-calculates decimated data at load time for faster zooming.
+ """
+
+ # Cache and performance constants
+ CACHE_MAX_SIZE = 10
+ MIN_VISIBLE_RANGE_DEFAULT = 1e-6 # Default if no global noise is provided
+ # Threshold for warning about too many points in detail mode
+ DETAIL_MODE_POINT_WARNING_THRESHOLD = 100000
+
+ def __init__(self, cache_max_size: int = CACHE_MAX_SIZE):
+ """
+ Initialise the decimation manager.
+
+ Parameters
+ ----------
+        cache_max_size : int, default=CACHE_MAX_SIZE
+ Maximum number of cached decimation results.
+ """
+ self._cache: Dict[str, Tuple[np.ndarray, ...]] = {}
+ self._cache_max_size = cache_max_size
+ # Stores pre-decimated envelope data for the full dataset for each trace/line
+ # Structure: {trace_id: {'t': np.ndarray, 'x_min': np.ndarray, 'x_max': np.ndarray, ...}}
+ self._pre_decimated_envelopes: Dict[int, Dict[str, np.ndarray]] = {}
+
+ def _get_cache_key(
+ self,
+ xlim_raw: Tuple[np.float32, np.float32],
+ max_points: int,
+ use_envelope: bool,
+ trace_id: Optional[int] = None,
+ ) -> str:
+ """Generate cache key for decimated data."""
+ # Round to reasonable precision to improve cache hits
+ xlim_rounded = (round(float(xlim_raw[0]), 9), round(float(xlim_raw[1]), 9))
+
+ # Include trace_id in cache key for multi-trace support
+ trace_suffix = f"_t{trace_id}" if trace_id is not None else ""
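+
+        # Illustrative key for assumed inputs xlim_raw=(0.0, 0.123456789),
+        # max_points=5000, use_envelope=True, trace_id=0:
+        #     "(0.0, 0.123456789)_5000_True_t0"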
+
+ return f"{xlim_rounded}_{max_points}_{use_envelope}{trace_suffix}"
+
+ def _manage_cache_size(self) -> None:
+ """Remove oldest cache entry if cache is full."""
+ if len(self._cache) >= self._cache_max_size:
+ # Remove oldest entry (simple FIFO)
+ oldest_key = next(iter(self._cache))
+ del self._cache[oldest_key]
+
+ def clear_cache(self) -> None:
+ """Clear the decimation cache."""
+ self._cache.clear()
+ # Do NOT clear _pre_decimated_envelopes here, as they are persistent for the full dataset
+
+ def _decimate_data(
+ self,
+ t: np.ndarray,
+ x: np.ndarray,
+ max_points: int,
+ use_envelope: bool = False,
+ envelope_window_samples: Optional[int] = None,
+ return_envelope_min_max: bool = False, # New parameter
+ ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
+ """
+ Unified decimation for time and multiple data arrays.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+        max_points : int
+            Maximum number of points to display.
+        use_envelope : bool, default=False
+            Whether to use envelope decimation for the signal array.
+        envelope_window_samples : Optional[int], default=None
+            Deprecated and ignored; an adaptive window is derived from the decimation step.
+        return_envelope_min_max : bool, default=False
+            If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
+ Decimated time, signal, signal min envelope, signal max envelope arrays.
+ """
+ # If input arrays are empty, return empty arrays immediately
+ if len(t) == 0:
+ return (
+ np.array([], dtype=np.float32),
+ np.array([], dtype=np.float32),
+ None,
+ None,
+ )
+
+ # If not using envelope, always return raw data for the view
+ if (
+ not use_envelope and not return_envelope_min_max
+ ): # If not using envelope and not explicitly asking for min/max
+ return t, x, None, None # No min/max envelope for raw data
+
+ # If using envelope and data is small enough, return raw data as envelope
+ if use_envelope and len(t) <= max_points and return_envelope_min_max:
+ return t, x, x, x # x,x for min/max when no decimation
+
+ # Calculate step size for decimation based on max_points
+ step = max(1, len(t) // max_points)
+
+ # For envelope mode, calculate adaptive envelope window based on data density
+ adaptive_envelope_window = None
+ if use_envelope and len(t) > max_points:
+ # Calculate envelope window based on how much we're decimating
+ # This ensures envelope resolution matches display capability
+ adaptive_envelope_window = max(
+ 1, step // 2
+ ) # Half the step size for smoother envelope
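+            # Illustrative numbers: 1_000_000 samples in view with max_points=5_000
+            # gives step=200, so the adaptive envelope window is 100 samples.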
+
+ # Ensure step is not zero, and calculate number of bins
+ if step == 0: # Should not happen with max(1, ...) but as a safeguard
+ step = 1
+ n_bins = len(t) // step
+ if (
+ n_bins == 0
+ ): # If data is too short for the calculated step, take at least one bin
+ n_bins = 1
+ step = len(t) # Take all points in one bin
+
+ # Ensure arrays are contiguous and correct dtype for Numba
+ t_contiguous = np.ascontiguousarray(t, dtype=np.float32)
+ x_contiguous = np.ascontiguousarray(x, dtype=np.float32)
+
+ # Decimate time array using Numba-optimized function
+ t_decimated = _decimate_time_numba(t_contiguous, step, n_bins)
+
+ # Decimate signal (x) using appropriate Numba-optimized function
+ x_min_envelope: Optional[np.ndarray] = None
+ x_max_envelope: Optional[np.ndarray] = None
+
+ if use_envelope: # This block handles the decimation logic (mean or envelope)
+ if adaptive_envelope_window is not None and adaptive_envelope_window > 1:
+ # Use Numba-optimized high-resolution envelope decimation with adaptive window
+ x_decimated, x_min_envelope, x_max_envelope = (
+ _decimate_envelope_highres_numba(
+ t_contiguous,
+ x_contiguous,
+ step,
+ n_bins,
+ adaptive_envelope_window,
+ )
+ )
+ else:
+ # Use Numba-optimized standard envelope decimation
+ x_decimated, x_min_envelope, x_max_envelope = (
+ _decimate_envelope_standard_numba(x_contiguous, step, n_bins)
+ )
+
+            # If min/max are not requested, collapse the envelope to its midpoint so the
+            # single returned trace still spans the extremes of each bin
+ if not return_envelope_min_max:
+ x_decimated = (x_min_envelope + x_max_envelope) / 2
+ else: # This block is now reached if use_envelope is False AND len(t) > max_points
+ # Use Numba-optimized mean decimation
+ x_decimated = _decimate_mean_numba(x_contiguous, step, n_bins)
+
+ # If return_envelope_min_max is False, ensure min/max are None
+ if not return_envelope_min_max:
+ x_min_envelope = None
+ x_max_envelope = None
+
+ return t_decimated, x_decimated, x_min_envelope, x_max_envelope
+
+ def pre_decimate_data(
+ self,
+ data_id: int, # Changed from trace_id to data_id to be more generic for custom lines
+ t: np.ndarray,
+ x: np.ndarray,
+ max_points: int,
+ envelope_window_samples: Optional[int] = None, # This parameter is now ignored
+ ) -> None:
+ """
+ Pre-calculate decimated envelope data for the full dataset.
+ This is used for fast rendering in zoomed-out (envelope) mode.
+
+ Parameters
+ ----------
+ data_id : int
+ Unique identifier for this data set (e.g., trace_id or custom line ID).
+ t : np.ndarray
+ Time array (raw time in seconds).
+ x : np.ndarray
+ Signal array.
+ max_points : int
+ Maximum number of points for the pre-decimated data.
+        envelope_window_samples : Optional[int], default=None
+            Deprecated and ignored; the bin size is derived automatically from
+            max_points and the data density.
+ """
+ if len(t) <= max_points:
+ # For small datasets, just store the original data as the "pre-decimated" envelope
+ # (min/max will be the same as x)
+ self._pre_decimated_envelopes[data_id] = {
+ "t": t,
+ "x": x, # Store mean/center for consistency
+ "x_min": x,
+ "x_max": x,
+ }
+ return
+ # Perform the decimation using the _decimate_data method
+ # We force use_envelope=True here for pre-decimation to capture min/max
+ # envelope_window_samples is now calculated automatically based on max_points
+ t_decimated, x_decimated, x_min, x_max = self._decimate_data(
+ t,
+ x,
+ max_points=max_points,
+ use_envelope=True, # Always pre-decimate with envelope
+ envelope_window_samples=None, # Let _decimate_data calculate adaptive window
+ return_envelope_min_max=True, # Pre-decimation always stores min/max
+ )
+
+ # Store pre-decimated envelope data
+ self._pre_decimated_envelopes[data_id] = {
+ "t": t_decimated,
+ "x": x_decimated, # This is the mean/center of the envelope
+ "x_min": x_min,
+ "x_max": x_max,
+ }
+
+ def decimate_for_view(
+ self,
+ t_raw_full: np.ndarray, # Full resolution time array
+ x_raw_full: np.ndarray, # Full resolution signal array
+ xlim_raw: Tuple[np.float32, np.float32],
+ max_points: int,
+ use_envelope: bool = False,
+ data_id: Optional[int] = None, # Changed from trace_id to data_id
+ envelope_window_samples: Optional[int] = None, # This parameter is now ignored
+ mode_switch_threshold: Optional[
+ float
+ ] = None, # New parameter for mode switching
+ return_envelope_min_max: bool = False, # New parameter
+ ) -> Tuple[
+ np.ndarray,
+ np.ndarray,
+ Optional[np.ndarray],
+ Optional[np.ndarray],
+ ]:
+ """
+ Intelligently decimate data for current view with optional envelope mode.
+
+ Parameters
+ ----------
+ t_raw_full : np.ndarray
+ Full resolution time array (raw time in seconds).
+ x_raw_full : np.ndarray
+ Full resolution signal array.
+ xlim_raw : Tuple[np.float32, np.float32]
+ Current x-axis limits in raw time (seconds).
+ max_points : int
+ Maximum number of points to display.
+ use_envelope : bool, default=False
+ Whether the current display mode is envelope.
+ data_id : Optional[int], default=None
+ Unique identifier for this data set (e.g., trace_id or custom line ID).
+ Used to retrieve pre-decimated envelope data.
+        envelope_window_samples : Optional[int], default=None
+            Deprecated and ignored; the envelope window is calculated automatically.
+        mode_switch_threshold : Optional[float], default=None
+            Time span threshold for switching between envelope and detail modes.
+            Used to decide whether to use pre-decimated envelope data.
+        return_envelope_min_max : bool, default=False
+            If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
+ Decimated time, signal, signal min envelope, signal max envelope arrays (all in raw time).
+ """
+ # Ensure xlim_raw values are valid
+ if (
+ not np.isfinite(xlim_raw[0])
+ or not np.isfinite(xlim_raw[1])
+ or xlim_raw[0] == xlim_raw[1]
+ ):
+ warnings.warn(
+ f"Invalid xlim_raw values: {xlim_raw}. Using full data range.", RuntimeWarning
+ )
+ xlim_raw = (np.min(t_raw_full), np.max(t_raw_full))
+
+ # Ensure xlim_raw is in ascending order
+ if xlim_raw[0] > xlim_raw[1]:
+ warnings.warn(f"xlim_raw values out of order: {xlim_raw}. Swapping.", RuntimeWarning)
+ xlim_raw = (xlim_raw[1], xlim_raw[0])
+
+ # Calculate current view span
+ current_view_span = xlim_raw[1] - xlim_raw[0]
+
+ # Check cache first
+ cache_key = self._get_cache_key(
+ xlim_raw, max_points, use_envelope, data_id
+ ) # Cache key doesn't need return_envelope_min_max
+ if cache_key in self._cache:
+ logger.debug(f"Using cached decimation for key: {cache_key}")
+ return self._cache[cache_key]
+
+ # --- Strategy: Use pre-decimated envelope if in envelope mode and view is wide ---
+ if (
+ use_envelope
+ and data_id is not None
+ and data_id in self._pre_decimated_envelopes
+ ):
+ pre_dec_data = self._pre_decimated_envelopes[data_id]
+ pre_dec_t = pre_dec_data["t"]
+
+ if len(pre_dec_t) > 1:
+ pre_dec_span = pre_dec_t[-1] - pre_dec_t[0]
+
+ # Calculate how much detail we would gain by re-decimating
+ # Find indices for current view in pre-decimated time
+ mask = (pre_dec_t >= xlim_raw[0]) & (pre_dec_t <= xlim_raw[1])
+ pre_dec_points_in_view = np.sum(mask)
+
+ # Estimate how many points we would get from dynamic decimation
+ t_view_mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1])
+ raw_points_in_view = np.sum(t_view_mask)
+ potential_decimated_points = min(raw_points_in_view, max_points)
+
+ # Use pre-decimated data only if:
+ # 1. Current view span is very large (> 2x mode_switch_threshold), AND
+ # 2. Pre-decimated data provides reasonable detail (> max_points/4), AND
+ # 3. We wouldn't gain much detail from re-decimating (< 2x improvement)
+ use_pre_decimated = (
+ mode_switch_threshold is not None
+ and current_view_span >= 2 * mode_switch_threshold
+ and pre_dec_points_in_view > max_points // 4
+ and potential_decimated_points < 2 * pre_dec_points_in_view
+ )
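+
+                # Illustrative check: with mode_switch_threshold=10e-3 s, a 50e-3 s view,
+                # 3_000 pre-decimated points in view and max_points=5_000, all three
+                # conditions hold and the stored envelope is reused.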
+
+ if use_pre_decimated and np.any(mask):
+ # If we need min/max, return them. Otherwise, return None.
+ x_min_ret = (
+ pre_dec_data["x_min"][mask] if return_envelope_min_max else None
+ )
+ x_max_ret = (
+ pre_dec_data["x_max"][mask] if return_envelope_min_max else None
+ )
+
+ result = (
+ pre_dec_t[mask],
+ pre_dec_data["x"][mask], # Center of envelope
+ x_min_ret,
+ x_max_ret,
+ )
+ self._manage_cache_size()
+ self._cache[cache_key] = result
+ return result
+
+ # --- Fallback: Dynamic decimation from raw data ---
+ # Find indices for current view in raw time
+ mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1])
+
+ if not np.any(mask):
+ warnings.warn(
+ f"No data in view for xlim_raw: {xlim_raw}. Returning empty arrays.", RuntimeWarning
+ )
+ empty_result = (
+ np.array([], dtype=np.float32),
+ np.array([], dtype=np.float32),
+ None,
+ None,
+ )
+ # Cache empty result for this view
+ self._manage_cache_size()
+ self._cache[cache_key] = empty_result
+ return empty_result
+
+ t_view = t_raw_full[mask]
+ x_view = x_raw_full[mask]
+
+ # Add warning for large number of points in detail mode
+ if not use_envelope and len(t_view) > self.DETAIL_MODE_POINT_WARNING_THRESHOLD:
+ warnings.warn(
+ f"Plotting {len(t_view)} points in detail mode. "
+ f"Performance may be affected. Consider zooming in further.", UserWarning
+ )
+
+ # Use unified decimation approach
+ # envelope_window_samples is now calculated automatically based on max_points and data density
+ result = self._decimate_data(
+ t_view,
+ x_view,
+ max_points=max_points,
+ use_envelope=use_envelope, # Use requested envelope mode for dynamic decimation
+ envelope_window_samples=None, # Let _decimate_data calculate adaptive window
+ return_envelope_min_max=return_envelope_min_max, # Pass through
+ )
+
+ # Cache the result (manage cache size)
+ self._manage_cache_size()
+ self._cache[cache_key] = result
+
+ return result
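+
+
+# Minimal usage sketch (illustrative; variable names and values are assumptions):
+#
+#     mgr = DecimationManager()
+#     mgr.pre_decimate_data(data_id=0, t=t_raw, x=x_raw, max_points=5_000)
+#     t_d, x_d, x_lo, x_hi = mgr.decimate_for_view(
+#         t_raw, x_raw, (t_raw[0], t_raw[-1]), max_points=5_000,
+#         use_envelope=True, data_id=0, mode_switch_threshold=10e-3,
+#         return_envelope_min_max=True,
+#     )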
diff --git a/src/scopekit/display_state.py b/src/scopekit/display_state.py
index b66b556..793bc4c 100644
--- a/src/scopekit/display_state.py
+++ b/src/scopekit/display_state.py
@@ -1,294 +1,286 @@
-from typing import Optional, Tuple
-
-import numpy as np
-from loguru import logger
-from matplotlib.ticker import FuncFormatter
-
-# Time unit boundaries (hysteresis)
-PICOSECOND_BOUNDARY = 0.8e-9
-NANOSECOND_BOUNDARY = 0.8e-6
-MICROSECOND_BOUNDARY = 0.8e-3
-MILLISECOND_BOUNDARY = 0.8
-
-# Offset thresholds
-OFFSET_SPAN_MULTIPLIER = 10
-OFFSET_TIME_THRESHOLD = 1e-3 # 1ms
-
-
-def _get_optimal_time_unit_and_scale(
- time_array_or_span: np.ndarray | float,
-) -> Tuple[str, np.float32]:
- """
- Determines the optimal time unit and scaling factor for a given time array or span.
-
- Uses hysteresis boundaries to prevent oscillation near unit boundaries.
-
- Parameters
- ----------
- time_array_or_span : np.ndarray | float
- A NumPy array representing time in seconds, or a single float representing a time span in seconds.
-
- Returns
- -------
- Tuple[str, np.float32]
- A tuple containing the time unit string (e.g., "s", "ms", "us", "ns")
- and the corresponding scaling factor (e.0, 1e3, 1e6, 1e9).
- """
- if isinstance(time_array_or_span, np.ndarray):
- # Handle empty array case to prevent errors
- if time_array_or_span.size == 0:
- return "s", np.float32(1.0) # Default to seconds if no data
- max_val = np.max(time_array_or_span)
- else: # Assume it's a float representing a span
- max_val = time_array_or_span
-
- # Use hysteresis boundaries to prevent oscillation near unit boundaries
- if max_val < PICOSECOND_BOUNDARY:
- return "ps", np.float32(1e12)
- elif max_val < NANOSECOND_BOUNDARY:
- return "ns", np.float32(1e9)
- elif max_val < MICROSECOND_BOUNDARY:
- return "us", np.float32(1e6)
- elif max_val < MILLISECOND_BOUNDARY:
- return "ms", np.float32(1e3)
- else:
- return "s", np.float32(1.0)
-
-
-def _determine_offset_display_params(
- xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32
-) -> Tuple[str, np.float32, Optional[np.float32], Optional[str]]:
- """
- Determine display parameters including offset for optimal readability.
-
- Parameters
- ----------
- xlim_raw : Tuple[np.float32, np.float32]
- Current x-axis limits in raw time (seconds).
- time_span_raw : np.float32
- Time span of current view in seconds.
-
- Returns
- -------
- Tuple[str, np.float32, Optional[np.float32], Optional[str]]
- Display unit, display scale, offset time (raw seconds), offset unit string.
- If no offset is needed, offset_time and offset_unit will be None.
- """
- # Get optimal unit for the time span
- display_unit, display_scale = _get_optimal_time_unit_and_scale(time_span_raw)
-
- # Determine if we need an offset
- # Use offset if the start time is significantly larger than the span
- xlim_start = xlim_raw[0]
-
- # Use offset if start time is more than threshold multiplier of the span, and span is small
- use_offset = (abs(xlim_start) > OFFSET_SPAN_MULTIPLIER * time_span_raw) and (
- time_span_raw < np.float32(OFFSET_TIME_THRESHOLD)
- )
-
- if use_offset:
- # Choose appropriate unit for the offset
- if abs(xlim_start) >= np.float32(1.0): # >= 1 second
- offset_unit = "s"
- offset_scale = np.float32(1.0)
- elif abs(xlim_start) >= np.float32(1e-3): # >= 1 millisecond
- offset_unit = "ms"
- offset_scale = np.float32(1e3)
- elif abs(xlim_start) >= np.float32(1e-6): # >= 1 microsecond
- offset_unit = "us"
- offset_scale = np.float32(1e6)
- else:
- offset_unit = "ns"
- offset_scale = np.float32(1e9)
-
- return display_unit, display_scale, xlim_start, offset_unit
- else:
- return display_unit, display_scale, None, None
-
-
-def _create_time_formatter(
- offset_time_raw: Optional[np.float32], display_scale: np.float32
-) -> FuncFormatter:
- """
- Create a FuncFormatter for time axis tick labels.
-
- Parameters
- ----------
- offset_time_raw : Optional[np.float32]
- Offset time in raw seconds. If None, no offset is applied.
- display_scale : np.float32
- Scale factor for display units.
-
- Returns
- -------
- FuncFormatter
- Matplotlib formatter for tick labels.
- """
-
- def formatter(x, pos):
- # x is already in display units (relative to offset if applicable)
- # Format with appropriate precision based on scale
- if display_scale >= np.float32(1e9): # nanoseconds or smaller
- return f"{x:.0f}"
- elif display_scale >= np.float32(1e6): # microseconds
- return f"{x:.0f}"
- elif display_scale >= np.float32(1e3): # milliseconds
- return f"{x:.1f}"
- else: # seconds
- return f"{x:.3f}"
-
- return FuncFormatter(formatter)
-
-
-class DisplayState:
- """
- Manages display state and mode switching logic.
-
- Centralises state management to reduce complexity and flag interactions.
- """
-
- def __init__(
- self,
- original_time_unit: str,
- original_time_scale: np.float32,
- envelope_limit: np.float32,
- ):
- """
- Initialise display state.
-
- Parameters
- ----------
- original_time_unit : str
- Original time unit string.
- original_time_scale : np.float32
- Original time scaling factor.
- envelope_limit : np.float32
- Time span threshold for envelope mode.
- """
- # Time scaling
- self.original_time_unit = original_time_unit
- self.original_time_scale = original_time_scale
- self.current_time_unit = original_time_unit
- self.current_time_scale = original_time_scale
-
- # Display mode
- self.current_mode: Optional[str] = None
- self.envelope_limit = envelope_limit
-
- # Offset parameters
- self.offset_time_raw: Optional[np.float32] = None
- self.offset_unit: Optional[str] = None
-
- # Single state flag - simplified
- self._updating = False
-
- def get_time_unit_and_scale(self, t: np.ndarray) -> Tuple[str, np.float32]:
- """
- Automatically select appropriate time unit and scale for plotting.
-
- Parameters
- ----------
- t : np.ndarray
- Time array.
-
- Returns
- -------
- Tuple[str, np.float32]
- Time unit string and scaling factor.
- """
- # Delegate to the new utility function
- return _get_optimal_time_unit_and_scale(t)
-
- def update_display_params(
- self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32
- ) -> bool:
- """
- Update display parameters including offset based on current view.
-
- Parameters
- ----------
- xlim_raw : Tuple[np.float32, np.float32]
- Current x-axis limits in raw time (seconds).
- time_span_raw : np.float32
- Time span of current view in seconds.
-
- Returns
- -------
- bool
- True if display parameters changed, False otherwise.
- """
- display_unit, display_scale, offset_time, offset_unit = (
- _determine_offset_display_params(xlim_raw, time_span_raw)
- )
-
- # Check if anything changed
- params_changed = (
- display_unit != self.current_time_unit
- or display_scale != self.current_time_scale
- or offset_time != self.offset_time_raw
- or offset_unit != self.offset_unit
- )
-
- if params_changed:
- logger.info(
- f"Display params changed: unit={display_unit}, scale={display_scale:.1e}, offset={offset_time}, offset_unit={offset_unit}"
- )
- self.current_time_unit = display_unit
- self.current_time_scale = display_scale
- self.offset_time_raw = offset_time
- self.offset_unit = offset_unit
- return True
-
- return False
-
- def should_use_envelope(self, time_span_raw: np.float32) -> bool:
- """Determine if envelope mode should be used based on time span."""
- return time_span_raw > self.envelope_limit
-
- def should_show_thresholds(self, time_span_raw: np.float32) -> bool:
- """Determine if threshold lines should be shown based on time span."""
- return time_span_raw < self.envelope_limit
-
- def update_time_scale(self, time_span_raw: np.float32) -> bool:
- """
- Update time scale based on current view span.
-
- Returns True if scale changed, False otherwise.
- """
- # Delegate to the new utility function for span
- new_unit, new_scale = _get_optimal_time_unit_and_scale(time_span_raw)
-
- if new_scale != self.current_time_scale:
- logger.info(
- f"Time scale changed from {self.current_time_unit} ({self.current_time_scale:.1e}) to {new_unit} ({new_scale:.1e})"
- )
- self.current_time_unit = new_unit
- self.current_time_scale = new_scale
- return True
-
- return False
-
- def reset_to_original_scale(self) -> None:
- """Reset time scale to original values."""
- self.current_time_unit = self.original_time_unit
- self.current_time_scale = self.original_time_scale
- logger.info(
- f"Reset to original scale: {self.current_time_unit} ({self.current_time_scale:.1e})"
- )
-
- def reset_to_initial_state(self) -> None:
- """Reset all display parameters to initial values."""
- self.current_time_unit = self.original_time_unit
- self.current_time_scale = self.original_time_scale
- self.offset_time_raw = None
- self.offset_unit = None
- self.current_mode = None
- self._updating = False
-
- def set_updating(self, value: bool = True) -> None:
- """Set updating state to prevent recursion."""
- self._updating = value
-
- def is_updating(self) -> bool:
- """Check if currently updating."""
- return self._updating
+from typing import Optional, Tuple
+
+import numpy as np
+from matplotlib.ticker import FuncFormatter
+
+# Time unit boundaries (hysteresis)
+PICOSECOND_BOUNDARY = 0.8e-9
+NANOSECOND_BOUNDARY = 0.8e-6
+MICROSECOND_BOUNDARY = 0.8e-3
+MILLISECOND_BOUNDARY = 0.8
+
+# Offset thresholds
+OFFSET_SPAN_MULTIPLIER = 10
+OFFSET_TIME_THRESHOLD = 1e-3 # 1ms
+
+
+def _get_optimal_time_unit_and_scale(
+ time_array_or_span: np.ndarray | float,
+) -> Tuple[str, np.float32]:
+ """
+ Determines the optimal time unit and scaling factor for a given time array or span.
+
+ Uses hysteresis boundaries to prevent oscillation near unit boundaries.
+
+ Parameters
+ ----------
+ time_array_or_span : np.ndarray | float
+ A NumPy array representing time in seconds, or a single float representing a time span in seconds.
+
+ Returns
+ -------
+ Tuple[str, np.float32]
+ A tuple containing the time unit string (e.g., "s", "ms", "us", "ns")
+        and the corresponding scaling factor (e.g., 1.0, 1e3, 1e6, 1e9).
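+
+    Examples
+    --------
+    An illustrative span of 2.5e-6 s lies between NANOSECOND_BOUNDARY and
+    MICROSECOND_BOUNDARY, so it maps to ("us", 1e6).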
+ """
+ if isinstance(time_array_or_span, np.ndarray):
+ # Handle empty array case to prevent errors
+ if time_array_or_span.size == 0:
+ return "s", np.float32(1.0) # Default to seconds if no data
+ max_val = np.max(time_array_or_span)
+ else: # Assume it's a float representing a span
+ max_val = time_array_or_span
+
+ # Use hysteresis boundaries to prevent oscillation near unit boundaries
+ if max_val < PICOSECOND_BOUNDARY:
+ return "ps", np.float32(1e12)
+ elif max_val < NANOSECOND_BOUNDARY:
+ return "ns", np.float32(1e9)
+ elif max_val < MICROSECOND_BOUNDARY:
+ return "us", np.float32(1e6)
+ elif max_val < MILLISECOND_BOUNDARY:
+ return "ms", np.float32(1e3)
+ else:
+ return "s", np.float32(1.0)
+
+
+def _determine_offset_display_params(
+ xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32
+) -> Tuple[str, np.float32, Optional[np.float32], Optional[str]]:
+ """
+ Determine display parameters including offset for optimal readability.
+
+ Parameters
+ ----------
+ xlim_raw : Tuple[np.float32, np.float32]
+ Current x-axis limits in raw time (seconds).
+ time_span_raw : np.float32
+ Time span of current view in seconds.
+
+ Returns
+ -------
+ Tuple[str, np.float32, Optional[np.float32], Optional[str]]
+ Display unit, display scale, offset time (raw seconds), offset unit string.
+ If no offset is needed, offset_time and offset_unit will be None.
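+
+    Examples
+    --------
+    For an illustrative view from 1.0 s to 1.00001 s (span 10 us), the span is below
+    OFFSET_TIME_THRESHOLD and the start exceeds OFFSET_SPAN_MULTIPLIER times the span,
+    so the result is ("us", 1e6, 1.0, "s").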
+ """
+ # Get optimal unit for the time span
+ display_unit, display_scale = _get_optimal_time_unit_and_scale(time_span_raw)
+
+ # Determine if we need an offset
+ # Use offset if the start time is significantly larger than the span
+ xlim_start = xlim_raw[0]
+
+ # Use offset if start time is more than threshold multiplier of the span, and span is small
+ use_offset = (abs(xlim_start) > OFFSET_SPAN_MULTIPLIER * time_span_raw) and (
+ time_span_raw < np.float32(OFFSET_TIME_THRESHOLD)
+ )
+
+ if use_offset:
+ # Choose appropriate unit for the offset
+ if abs(xlim_start) >= np.float32(1.0): # >= 1 second
+ offset_unit = "s"
+ offset_scale = np.float32(1.0)
+ elif abs(xlim_start) >= np.float32(1e-3): # >= 1 millisecond
+ offset_unit = "ms"
+ offset_scale = np.float32(1e3)
+ elif abs(xlim_start) >= np.float32(1e-6): # >= 1 microsecond
+ offset_unit = "us"
+ offset_scale = np.float32(1e6)
+ else:
+ offset_unit = "ns"
+ offset_scale = np.float32(1e9)
+
+ return display_unit, display_scale, xlim_start, offset_unit
+ else:
+ return display_unit, display_scale, None, None
+
+
+def _create_time_formatter(
+ offset_time_raw: Optional[np.float32], display_scale: np.float32
+) -> FuncFormatter:
+ """
+ Create a FuncFormatter for time axis tick labels.
+
+ Parameters
+ ----------
+ offset_time_raw : Optional[np.float32]
+ Offset time in raw seconds. If None, no offset is applied.
+ display_scale : np.float32
+ Scale factor for display units.
+
+ Returns
+ -------
+ FuncFormatter
+ Matplotlib formatter for tick labels.
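+
+    Examples
+    --------
+    Typical usage (illustrative; ``ax``, ``offset`` and ``scale`` are assumed names)::
+
+        ax.xaxis.set_major_formatter(_create_time_formatter(offset, scale))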
+ """
+
+ def formatter(x, pos):
+ # x is already in display units (relative to offset if applicable)
+ # Format with appropriate precision based on scale
+ if display_scale >= np.float32(1e9): # nanoseconds or smaller
+ return f"{x:.0f}"
+ elif display_scale >= np.float32(1e6): # microseconds
+ return f"{x:.0f}"
+ elif display_scale >= np.float32(1e3): # milliseconds
+ return f"{x:.1f}"
+ else: # seconds
+ return f"{x:.3f}"
+
+ return FuncFormatter(formatter)
+
+
+class DisplayState:
+ """
+ Manages display state and mode switching logic.
+
+ Centralises state management to reduce complexity and flag interactions.
+ """
+
+ def __init__(
+ self,
+ original_time_unit: str,
+ original_time_scale: np.float32,
+ envelope_limit: np.float32,
+ ):
+ """
+ Initialise display state.
+
+ Parameters
+ ----------
+ original_time_unit : str
+ Original time unit string.
+ original_time_scale : np.float32
+ Original time scaling factor.
+ envelope_limit : np.float32
+ Time span threshold for envelope mode.
+ """
+ # Time scaling
+ self.original_time_unit = original_time_unit
+ self.original_time_scale = original_time_scale
+ self.current_time_unit = original_time_unit
+ self.current_time_scale = original_time_scale
+
+ # Display mode
+ self.current_mode: Optional[str] = None
+ self.envelope_limit = envelope_limit
+
+ # Offset parameters
+ self.offset_time_raw: Optional[np.float32] = None
+ self.offset_unit: Optional[str] = None
+
+ # Single state flag - simplified
+ self._updating = False
+
+ def get_time_unit_and_scale(self, t: np.ndarray) -> Tuple[str, np.float32]:
+ """
+ Automatically select appropriate time unit and scale for plotting.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+
+ Returns
+ -------
+ Tuple[str, np.float32]
+ Time unit string and scaling factor.
+ """
+ # Delegate to the new utility function
+ return _get_optimal_time_unit_and_scale(t)
+
+ def update_display_params(
+ self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32
+ ) -> bool:
+ """
+ Update display parameters including offset based on current view.
+
+ Parameters
+ ----------
+ xlim_raw : Tuple[np.float32, np.float32]
+ Current x-axis limits in raw time (seconds).
+ time_span_raw : np.float32
+ Time span of current view in seconds.
+
+ Returns
+ -------
+ bool
+ True if display parameters changed, False otherwise.
+ """
+ display_unit, display_scale, offset_time, offset_unit = (
+ _determine_offset_display_params(xlim_raw, time_span_raw)
+ )
+
+ # Check if anything changed
+ params_changed = (
+ display_unit != self.current_time_unit
+ or display_scale != self.current_time_scale
+ or offset_time != self.offset_time_raw
+ or offset_unit != self.offset_unit
+ )
+
+ if params_changed:
+ self.current_time_unit = display_unit
+ self.current_time_scale = display_scale
+ self.offset_time_raw = offset_time
+ self.offset_unit = offset_unit
+ return True
+
+ return False
+
+ def should_use_envelope(self, time_span_raw: np.float32) -> bool:
+ """Determine if envelope mode should be used based on time span."""
+ return time_span_raw > self.envelope_limit
+
+ def should_show_thresholds(self, time_span_raw: np.float32) -> bool:
+ """Determine if threshold lines should be shown based on time span."""
+ return time_span_raw < self.envelope_limit
+
+ def update_time_scale(self, time_span_raw: np.float32) -> bool:
+ """
+ Update time scale based on current view span.
+
+ Returns True if scale changed, False otherwise.
+ """
+ # Delegate to the new utility function for span
+ new_unit, new_scale = _get_optimal_time_unit_and_scale(time_span_raw)
+
+ if new_scale != self.current_time_scale:
+ self.current_time_unit = new_unit
+ self.current_time_scale = new_scale
+ return True
+
+ return False
+
+ def reset_to_original_scale(self) -> None:
+ """Reset time scale to original values."""
+ self.current_time_unit = self.original_time_unit
+ self.current_time_scale = self.original_time_scale
+
+ def reset_to_initial_state(self) -> None:
+ """Reset all display parameters to initial values."""
+ self.current_time_unit = self.original_time_unit
+ self.current_time_scale = self.original_time_scale
+ self.offset_time_raw = None
+ self.offset_unit = None
+ self.current_mode = None
+ self._updating = False
+
+ def set_updating(self, value: bool = True) -> None:
+ """Set updating state to prevent recursion."""
+ self._updating = value
+
+ def is_updating(self) -> bool:
+ """Check if currently updating."""
+ return self._updating
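+
+
+# Minimal usage sketch (illustrative; ax, xlim_raw and t_seconds are assumed names):
+#
+#     unit, scale = _get_optimal_time_unit_and_scale(t_seconds)
+#     state = DisplayState(unit, scale, envelope_limit=np.float32(10e-3))
+#     span = np.float32(xlim_raw[1] - xlim_raw[0])
+#     if state.update_display_params(xlim_raw, span):
+#         ax.xaxis.set_major_formatter(
+#             _create_time_formatter(state.offset_time_raw, state.current_time_scale)
+#         )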
diff --git a/src/scopekit/plot.py b/src/scopekit/plot.py
index 83c10bc..f405432 100644
--- a/src/scopekit/plot.py
+++ b/src/scopekit/plot.py
@@ -1,1829 +1,1621 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-import numpy as np
-from loguru import logger
-from matplotlib.ticker import MultipleLocator
-
-from .coordinate_manager import CoordinateManager
-from .data_manager import TimeSeriesDataManager
-from .decimation import DecimationManager
-from .display_state import (
- DisplayState,
- _create_time_formatter,
- _get_optimal_time_unit_and_scale,
-)
-
-
-class OscilloscopePlot:
- """
- General-purpose plotting class for time-series data with zoom and decimation.
-
- Uses separate managers for data, decimation, and state to reduce complexity.
- Supports different visualization elements (lines, envelopes, ribbons, regions)
- that can be displayed in different modes (envelope when zoomed out, detail when zoomed in).
- """
-
- # Mode constants
- MODE_ENVELOPE = 1 # Zoomed out mode
- MODE_DETAIL = 2 # Zoomed in mode
- MODE_BOTH = 3 # Both modes
-
- # Default styling constants
- DEFAULT_MAX_PLOT_POINTS = 10000
- DEFAULT_MODE_SWITCH_THRESHOLD = 10e-3 # 10 ms
- DEFAULT_MIN_Y_RANGE_DEFAULT = 1e-9 # Default minimum Y-axis range (e.g., 1 nV)
- DEFAULT_Y_MARGIN_FRACTION = 0.15
- DEFAULT_SIGNAL_LINE_WIDTH = 1.0
- DEFAULT_SIGNAL_ALPHA = 0.75
- DEFAULT_ENVELOPE_ALPHA = 0.75
- DEFAULT_REGION_ALPHA = 0.4
- DEFAULT_REGION_ZORDER = -5
-
- def __init__(
- self,
- t: Union[np.ndarray, List[np.ndarray]],
- x: Union[np.ndarray, List[np.ndarray]],
- name: Union[str, List[str]] = "Waveform",
- trace_colors: Optional[List[str]] = None,
- # Core display parameters
- max_plot_points: int = DEFAULT_MAX_PLOT_POINTS,
- mode_switch_threshold: float = DEFAULT_MODE_SWITCH_THRESHOLD,
- min_y_range: Optional[float] = None, # New parameter for minimum Y-axis range
- y_margin_fraction: float = DEFAULT_Y_MARGIN_FRACTION,
- signal_line_width: float = DEFAULT_SIGNAL_LINE_WIDTH,
- signal_alpha: float = DEFAULT_SIGNAL_ALPHA,
- envelope_alpha: float = DEFAULT_ENVELOPE_ALPHA,
- region_alpha: float = DEFAULT_REGION_ALPHA,
- region_zorder: int = DEFAULT_REGION_ZORDER,
- envelope_window_samples: Optional[int] = None,
- ):
- """
- Initialize the OscilloscopePlot with time series data.
-
- Parameters
- ----------
- t : Union[np.ndarray, List[np.ndarray]]
- Time array(s) (raw time in seconds). Can be a single array shared by all traces
- or a list of arrays, one per trace.
- x : Union[np.ndarray, List[np.ndarray]]
- Signal array(s). If t is a single array, x can be a 2D array (traces x samples)
- or a list of 1D arrays. If t is a list, x must be a list of equal length.
- name : Union[str, List[str]], default="Waveform"
- Name(s) for plot title. Can be a single string or a list of strings.
- trace_colors : Optional[List[str]], default=None
- Colors for each trace. If None, default colors will be used.
- max_plot_points : int, default=10000
- Maximum number of points to display on the plot. Data will be decimated if it exceeds this.
- mode_switch_threshold : float, default=10e-3
- Time span (in seconds) above which the plot switches to envelope mode.
- min_y_range : Optional[float], default=None
- Minimum Y-axis range to enforce. If None, a default small value is used.
- y_margin_fraction : float, default=0.05
- Fraction of data range to add as margin to Y-axis limits.
- signal_line_width : float, default=1.0
- Line width for the raw signal plot.
- signal_alpha : float, default=0.75
- Alpha (transparency) for the raw signal plot.
- envelope_alpha : float, default=1.0
- Alpha (transparency) for the envelope fill.
- region_alpha : float, default=0.4
- Alpha (transparency) for region highlight fills.
- region_zorder : int, default=-5
- Z-order for region highlight fills (lower means further back).
- envelope_window_samples : Optional[int], default=None
- DEPRECATED: Window size in samples for envelope calculation.
- Envelope window is now calculated automatically based on max_plot_points and zoom level.
- This parameter is ignored but kept for backward compatibility.
- """
- # Store styling parameters directly as instance attributes
- self.max_plot_points = max_plot_points
- self.mode_switch_threshold = np.float32(mode_switch_threshold)
- self.min_y_range = (
- np.float32(min_y_range)
- if min_y_range is not None
- else self.DEFAULT_MIN_Y_RANGE_DEFAULT
- )
- self.y_margin_fraction = np.float32(y_margin_fraction)
- self.signal_line_width = signal_line_width
- self.signal_alpha = signal_alpha
- self.envelope_alpha = envelope_alpha
- self.region_alpha = region_alpha
- self.region_zorder = region_zorder
- # envelope_window_samples is now deprecated - envelope window is calculated automatically
- # Keep the parameter for backward compatibility but don't use it
- if envelope_window_samples is not None:
- logger.warning(
- "envelope_window_samples parameter is deprecated. Envelope window is now calculated automatically based on zoom level."
- )
-
- # Initialize managers
- self.data = TimeSeriesDataManager(t, x, name, trace_colors)
- self.decimator = DecimationManager()
-
- # Pre-decimate main signal data for envelope view
- for i in range(self.data.num_traces):
- self.decimator.pre_decimate_data(
- data_id=i, # Use trace_idx as data_id
- t=self.data.t_arrays[i],
- x=self.data.x_arrays[i],
- max_points=self.max_plot_points,
- envelope_window_samples=None, # Envelope window calculated automatically
- )
-
- # Initialize display state using the first trace's time array
- initial_time_unit, initial_time_scale = _get_optimal_time_unit_and_scale(
- self.data.t_arrays[0]
- )
- self.state = DisplayState(
- initial_time_unit, initial_time_scale, self.mode_switch_threshold
- )
-
- # Initialize matplotlib figure and axes to None
- self.fig: Optional[mpl.figure.Figure] = None
- self.ax: Optional[mpl.axes.Axes] = None
-
- # Store visualization elements for each trace
- self._signal_lines: List[mpl.lines.Line2D] = []
- self._envelope_fills: List[Optional[mpl.collections.PolyCollection]] = [
- None
- ] * self.data.num_traces
-
- # Visualization elements with mode control (definitions, not plot objects)
- self._lines: List[List[Dict[str, Any]]] = [
- [] for _ in range(self.data.num_traces)
- ]
- self._ribbons: List[List[Dict[str, Any]]] = [
- [] for _ in range(self.data.num_traces)
- ]
- self._regions: List[List[Dict[str, Any]]] = [
- [] for _ in range(self.data.num_traces)
- ]
- self._envelopes: List[List[Dict[str, Any]]] = [
- [] for _ in range(self.data.num_traces)
- ]
-
- # Line objects for each trace (will be populated as needed during rendering)
- self._line_objects: List[List[mpl.artist.Artist]] = [
- [] for _ in range(self.data.num_traces)
- ] # Changed type hint to Artist
- self._ribbon_objects: List[List[mpl.collections.PolyCollection]] = [
- [] for _ in range(self.data.num_traces)
- ]
- self._region_objects: List[List[mpl.collections.PolyCollection]] = [
- [] for _ in range(self.data.num_traces)
- ]
-
- # Store current plot data for access by other methods
- self._current_plot_data = {}
-
- # Initialize coordinate manager
- self.coord_manager = CoordinateManager(self.state)
-
- # Store initial view for home button (using global time range)
- t_start, t_end = self.data.get_global_time_range()
- self._initial_xlim_raw = (t_start, t_end)
-
- # Legend state for optimization
- self._current_legend_handles: List[mpl.artist.Artist] = []
- self._current_legend_labels: List[str] = []
- self._legend: Optional[mpl.legend.Legend] = None
-
- # Track last mode for each trace to optimize element updates
- self._last_mode: Dict[int, Optional[int]] = {
- i: None for i in range(self.data.num_traces)
- }
-
- # Store original toolbar methods for restoration
- self._original_home = None
- self._original_push_current = None
-
- def save(self, filepath: str) -> None:
- """
- Save the current plot to a file.
-
- Parameters
- ----------
- filepath : str
- Path to save the plot image.
- """
- if self.fig is None or self.ax is None:
- raise RuntimeError("Plot has not been initialized yet.")
- self.fig.savefig(filepath)
- logger.info(f"Plot saved to {filepath}")
-
- def add_line(
- self,
- t: Union[np.ndarray, List[np.ndarray]],
- data: Union[np.ndarray, List[np.ndarray]],
- label: str = "Line",
- color: Optional[str] = None,
- alpha: float = 0.75,
- linestyle: str = "-",
- linewidth: float = 1.0,
- display_mode: int = MODE_BOTH,
- trace_idx: int = 0,
- zorder: int = 5,
- ) -> None:
- """
- Add a line to the plot with mode control.
-
- Parameters
- ----------
- t : Union[np.ndarray, List[np.ndarray]]
- Time array(s) for the line data. Must match the length of data.
- data : Union[np.ndarray, List[np.ndarray]]
- Line data array(s). Can be a single array or a list of arrays.
- label : str, default="Line"
- Label for the legend.
- color : Optional[str], default=None
- Color for the line. If None, the trace color will be used.
- alpha : float, default=0.75
- Alpha (transparency) for the line.
- linestyle : str, default="-"
- Line style.
- linewidth : float, default=1.0
- Line width.
- display_mode : int, default=MODE_BOTH
- Which mode(s) to show this line in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
- trace_idx : int, default=0
- Index of the trace to add the line to.
- zorder : int, default=5
- Z-order for the line (higher values appear on top).
- """
- if trace_idx < 0 or trace_idx >= self.data.num_traces:
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
- )
-
- # Validate data length
- if isinstance(data, list):
- if len(data) != len(t):
- raise ValueError(
- f"Line data length ({len(data)}) must match time array length ({len(t)})."
- )
- else:
- if len(data) != len(t):
- raise ValueError(
- f"Line data length ({len(data)}) must match time array length ({len(t)})."
- )
-
- # Use trace color if none provided
- if color is None:
- color = self.data.get_trace_color(trace_idx)
-
- # Convert inputs to numpy arrays
- t_array = np.asarray(t, dtype=np.float32)
- data_array = np.asarray(data, dtype=np.float32)
-
- # Assign a unique ID for this custom line for pre-decimation caching
- # We use a negative ID to distinguish from main traces (which use 0, 1, 2...)
- # and ensure uniqueness across custom lines.
- line_id = -(len(self._lines[trace_idx]) + 1) # Negative, unique per trace
-
- # Pre-decimate this custom line's data for envelope view
- self.decimator.pre_decimate_data(
- data_id=line_id,
- t=t_array,
- x=data_array,
- max_points=self.max_plot_points,
- envelope_window_samples=None, # Envelope window calculated automatically
- )
-
- # Store line definition with raw data and its assigned ID
- line_def = {
- "id": line_id, # Store the ID for retrieval from decimator
- "t_raw": t_array, # Store raw time array
- "data_raw": data_array, # Store raw data array
- "label": label,
- "color": color,
- "alpha": alpha,
- "linestyle": linestyle,
- "linewidth": linewidth,
- "display_mode": display_mode,
- "zorder": zorder,
- }
-
- logger.debug(
- f"Adding line '{label}' with display_mode={display_mode} (MODE_ENVELOPE={self.MODE_ENVELOPE}, MODE_DETAIL={self.MODE_DETAIL}, MODE_BOTH={self.MODE_BOTH})"
- )
- self._lines[trace_idx].append(line_def)
-
- def add_ribbon(
- self,
- t: Union[np.ndarray, List[np.ndarray]],
- center_data: Union[np.ndarray, List[np.ndarray]],
- width: Union[float, np.ndarray],
- label: str = "Ribbon",
- color: str = "gray",
- alpha: float = 0.6,
- display_mode: int = MODE_DETAIL,
- trace_idx: int = 0,
- zorder: int = 2,
- ) -> None:
- """
- Add a ribbon (center ± width) with mode control.
-
- Parameters
- ----------
- t : Union[np.ndarray, List[np.ndarray]]
- Time array(s) for the ribbon data. Must match the length of center_data.
- center_data : Union[np.ndarray, List[np.ndarray]]
- Center line data array(s). Can be a single array or a list of arrays.
- width : Union[float, np.ndarray]
- Width of the ribbon. Can be a single value or an array matching center_data.
- label : str, default="Ribbon"
- Label for the legend.
- color : str, default="gray"
- Color for the ribbon.
- alpha : float, default=0.6
- Alpha (transparency) for the ribbon.
- display_mode : int, default=MODE_DETAIL
- Which mode(s) to show this ribbon in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
- trace_idx : int, default=0
- Index of the trace to add the ribbon to.
- """
- if trace_idx < 0 or trace_idx >= self.data.num_traces:
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
- )
-
- # Validate data length
- if isinstance(center_data, list):
- if len(center_data) != len(t):
- raise ValueError(
- f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})."
- )
- else:
- if len(center_data) != len(t):
- raise ValueError(
- f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})."
- )
-
- # Convert center data to numpy array
- center_data = np.asarray(center_data, dtype=np.float32)
-
- # Handle width as scalar or array
- if isinstance(width, (int, float, np.number)):
- width_array = np.ones_like(center_data) * width
- else:
- if len(width) != len(center_data):
- raise ValueError(
- f"Ribbon width array length ({len(width)}) must match center data length ({len(center_data)})."
- )
- width_array = np.asarray(width, dtype=np.float32)
-
- # Assign a unique ID for this custom ribbon
- ribbon_id = -(
- len(self._ribbons[trace_idx]) + 1001
- ) # Negative, unique per trace, offset from lines
-
- # Pre-decimate this custom ribbon's center data for envelope view
- # We only pre-decimate the center, as width is applied later
- self.decimator.pre_decimate_data(
- data_id=ribbon_id,
- t=np.asarray(t, dtype=np.float32),
- x=center_data,
- max_points=self.max_plot_points,
- envelope_window_samples=None, # Envelope window calculated automatically
- )
-
- # Store ribbon definition
- ribbon_def = {
- "id": ribbon_id,
- "t_raw": np.asarray(t, dtype=np.float32),
- "center_data_raw": center_data,
- "width_raw": width_array,
- "label": label,
- "color": color,
- "alpha": alpha,
- "display_mode": display_mode,
- "zorder": zorder,
- }
-
- self._ribbons[trace_idx].append(ribbon_def)
-
- def add_envelope(
- self,
- min_data: Union[np.ndarray, List[np.ndarray]],
- max_data: Union[np.ndarray, List[np.ndarray]],
- label: str = "Envelope",
- color: Optional[str] = None,
- alpha: float = 0.4,
- display_mode: int = MODE_ENVELOPE,
- trace_idx: int = 0,
- zorder: int = 1,
- ) -> None:
- """
- Add envelope data with mode control.
-
- Parameters
- ----------
- min_data : Union[np.ndarray, List[np.ndarray]]
- Minimum envelope data array(s). Can be a single array or a list of arrays.
- max_data : Union[np.ndarray, List[np.ndarray]]
- Maximum envelope data array(s). Can be a single array or a list of arrays.
- label : str, default="Envelope"
- Label for the legend.
- color : Optional[str], default=None
- Color for the envelope. If None, the trace color will be used.
- alpha : float, default=0.4
- Alpha (transparency) for the envelope.
- display_mode : int, default=MODE_ENVELOPE
- Which mode(s) to show this envelope in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
- trace_idx : int, default=0
- Index of the trace to add the envelope to.
- """
- if trace_idx < 0 or trace_idx >= self.data.num_traces:
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
- )
-
- # Validate data length
- if isinstance(min_data, list):
- if len(min_data) != len(self.data.t_arrays[trace_idx]):
- raise ValueError(
- f"Envelope min data length ({len(min_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})."
- )
- else:
- if len(min_data) != len(self.data.t_arrays[trace_idx]):
- raise ValueError(
- f"Envelope min data length ({len(min_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})."
- )
-
- if isinstance(max_data, list):
- if len(max_data) != len(self.data.t_arrays[trace_idx]):
- raise ValueError(
- f"Envelope max data length ({len(max_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})."
- )
- else:
- if len(max_data) != len(self.data.t_arrays[trace_idx]):
- raise ValueError(
- f"Envelope max data length ({len(max_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})."
- )
-
- # Use trace color if none provided
- if color is None:
- color = self.data.get_trace_color(trace_idx)
-
- # Assign a unique ID for this custom envelope
- envelope_id = -(
- len(self._envelopes[trace_idx]) + 2001
- ) # Negative, unique per trace, offset from ribbons
-
- # Pre-decimate this custom envelope's data for envelope view
- # We'll pre-decimate the average of min/max, and store min/max separately
- t_raw = self.data.t_arrays[trace_idx]
- avg_data = (
- np.asarray(min_data, dtype=np.float32)
- + np.asarray(max_data, dtype=np.float32)
- ) / 2
-
- self.decimator.pre_decimate_data(
- data_id=envelope_id,
- t=t_raw,
- x=avg_data, # Pass average for decimation
- max_points=self.max_plot_points,
- envelope_window_samples=None, # Envelope window calculated automatically
- )
-
- # Store envelope definition
- envelope_def = {
- "id": envelope_id,
- "t_raw": t_raw,
- "min_data_raw": np.asarray(min_data, dtype=np.float32),
- "max_data_raw": np.asarray(max_data, dtype=np.float32),
- "label": label,
- "color": color,
- "alpha": alpha,
- "display_mode": display_mode,
- "zorder": zorder,
- }
-
- self._envelopes[trace_idx].append(envelope_def)
-
- def add_regions(
- self,
- regions: np.ndarray,
- label: str = "Regions",
- color: str = "crimson",
- alpha: float = 0.4,
- display_mode: int = MODE_BOTH,
- trace_idx: int = 0,
- zorder: int = -5,
- ) -> None:
- """
- Add region highlights with mode control.
-
- Parameters
- ----------
- regions : np.ndarray
- Region data array with shape (N, 2) where each row is [start_time, end_time].
- label : str, default="Regions"
- Label for the legend.
- color : str, default="crimson"
- Color for the regions.
- alpha : float, default=0.4
- Alpha (transparency) for the regions.
- display_mode : int, default=MODE_BOTH
- Which mode(s) to show these regions in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
- trace_idx : int, default=0
- Index of the trace to add the regions to.
- """
- if trace_idx < 0 or trace_idx >= self.data.num_traces:
- raise ValueError(
- f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
- )
-
- # Validate regions array
- if regions.ndim != 2 or regions.shape[1] != 2:
- raise ValueError(
- f"Regions array must have shape (N, 2), got {regions.shape}."
- )
-
- # Store regions definition
- region_def = {
- "regions": np.asarray(regions, dtype=np.float32),
- "label": label,
- "color": color,
- "alpha": alpha,
- "display_mode": display_mode,
- "zorder": zorder,
- }
-
- logger.debug(
- f"Adding regions '{label}' with {len(regions)} entries, display_mode={display_mode}"
- )
- self._regions[trace_idx].append(region_def)
-
- def _update_signal_display(
- self,
- trace_idx: int,
- t_display: np.ndarray,
- x_data: np.ndarray,
- envelope_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
- ) -> None:
- """
- Update signal display with envelope or raw data for a specific trace.
-
- Parameters
- ----------
- trace_idx : int
- Index of the trace to update.
- t_display : np.ndarray
- Display time array.
- x_data : np.ndarray
- Signal data array.
- envelope_data : Optional[Tuple[np.ndarray, np.ndarray]], default=None
- Tuple of (min, max) envelope data if in envelope mode.
- """
- logger.debug(f"=== _update_signal_display trace {trace_idx} ===")
- logger.debug(
- f"t_display: len={len(t_display)}, range=[{np.min(t_display) if len(t_display) > 0 else 'empty':.6f}, {np.max(t_display) if len(t_display) > 0 else 'empty':.6f}]"
- )
- logger.debug(
- f"x_data: len={len(x_data)}, range=[{np.min(x_data) if len(x_data) > 0 else 'empty':.6f}, {np.max(x_data) if len(x_data) > 0 else 'empty':.6f}]"
- )
- logger.debug(f"envelope_data: {envelope_data is not None}")
-
- if envelope_data is not None:
- x_min, x_max = envelope_data
- logger.debug(
- f"envelope x_min: len={len(x_min)}, range=[{np.min(x_min) if len(x_min) > 0 else 'empty':.6f}, {np.max(x_min) if len(x_min) > 0 else 'empty':.6f}]"
- )
- logger.debug(
- f"envelope x_max: len={len(x_max)}, range=[{np.max(x_max) if len(x_max) > 0 else 'empty':.6f}, {np.max(x_max) if len(x_max) > 0 else 'empty':.6f}]"
- )
- self._show_envelope_mode(trace_idx, t_display, envelope_data)
- else:
- logger.debug("Showing detail mode (raw signal)")
- self._show_detail_mode(trace_idx, t_display, x_data)
-
- def _show_envelope_mode(
- self,
- trace_idx: int,
- t_display: np.ndarray,
- envelope_data: Tuple[np.ndarray, np.ndarray],
- ) -> None:
- """
- Show envelope display mode for a specific trace.
-
- Parameters
- ----------
- trace_idx : int
- Index of the trace to update.
- t_display : np.ndarray
- Display time array.
- envelope_data : Tuple[np.ndarray, np.ndarray]
- Tuple of (min, max) envelope data.
- """
- logger.debug(f"=== _show_envelope_mode trace {trace_idx} ===")
- x_min, x_max = envelope_data
- color = self.data.get_trace_color(trace_idx)
- name = self.data.get_trace_name(trace_idx)
-
- logger.debug(f"Envelope data: x_min len={len(x_min)}, x_max len={len(x_max)}")
- logger.debug(
- f"t_display range: [{np.min(t_display):.6f}, {np.max(t_display):.6f}]"
- )
- logger.debug(f"y_range: [{np.min(x_min):.6f}, {np.max(x_max):.6f}]")
-
- # Clean up previous displays
- if self._envelope_fills[trace_idx] is not None:
- logger.debug("Removing previous envelope fill")
- self._envelope_fills[trace_idx].remove()
-
- logger.debug("Hiding signal line")
- self._signal_lines[trace_idx].set_data([], [])
- self._signal_lines[trace_idx].set_visible(False)
-
- # Show built-in envelope
- logger.debug(
- f"Creating envelope fill with color={color}, alpha={self.envelope_alpha}"
- )
- self._envelope_fills[trace_idx] = self.ax.fill_between(
- t_display,
- x_min,
- x_max,
- alpha=self.envelope_alpha,
- color=color,
- lw=0.1,
- label=f"Raw envelope ({name})"
- if self.data.num_traces > 1
- else "Raw envelope",
- zorder=1, # Keep default envelope at zorder=1
- )
-
- # Set current mode
- self.state.current_mode = "envelope"
- logger.debug("Set current_mode to 'envelope'")
-
- # Show any custom elements for this mode
- self._show_custom_elements(trace_idx, t_display, self.MODE_ENVELOPE)
-
- def _show_detail_mode(
- self, trace_idx: int, t_display: np.ndarray, x_data: np.ndarray
- ) -> None:
- """
- Show detail display mode for a specific trace.
-
- Parameters
- ----------
- trace_idx : int
- Index of the trace to update.
- t_display : np.ndarray
- Display time array.
- x_data : np.ndarray
- Signal data array.
- """
- logger.debug(f"=== _show_detail_mode trace {trace_idx} ===")
- logger.debug(
- f"t_display: len={len(t_display)}, range=[{np.min(t_display) if len(t_display) > 0 else 'empty':.6f}, {np.max(t_display) if len(t_display) > 0 else 'empty':.6f}]"
- )
- logger.debug(
- f"x_data: len={len(x_data)}, range=[{np.min(x_data) if len(x_data) > 0 else 'empty':.6f}, {np.max(x_data) if len(x_data) > 0 else 'empty':.6f}]"
- )
-
- # Clean up envelope
- if self._envelope_fills[trace_idx] is not None:
- logger.debug("Removing envelope fill")
- self._envelope_fills[trace_idx].remove()
- self._envelope_fills[trace_idx] = None
-
- # Update signal line
- line = self._signal_lines[trace_idx]
- logger.debug(
- f"Setting signal line data: linewidth={self.signal_line_width}, alpha={self.signal_alpha}"
- )
- line.set_data(t_display, x_data)
- line.set_linewidth(self.signal_line_width)
- line.set_alpha(self.signal_alpha)
- line.set_visible(True)
-
- # Set current mode
- self.state.current_mode = "detail"
- logger.debug("Set current_mode to 'detail'")
-
- # Show any custom elements for this mode
- self._show_custom_elements(trace_idx, t_display, self.MODE_DETAIL)
-
- def _show_custom_elements(
- self, trace_idx: int, t_display: np.ndarray, current_mode: int
- ) -> None:
- """
- Show custom visualization elements for the current mode.
-
- Parameters
- ----------
- trace_idx : int
- Index of the trace to update.
- t_display : np.ndarray
- Display time array.
- current_mode : int
- Current display mode (MODE_ENVELOPE or MODE_DETAIL).
- """
- logger.debug(
- f"=== _show_custom_elements trace {trace_idx}, current_mode={current_mode} ==="
- )
-
- last_mode = self._last_mode.get(trace_idx)
- logger.debug(f"Last mode for trace {trace_idx}: {last_mode}")
-
- # Always clear and recreate elements when view changes, regardless of mode change
- # This ensures custom lines/ribbons are redrawn correctly with current view data
- logger.debug(
- f"Clearing and recreating elements for trace {trace_idx} (mode: {last_mode} -> {current_mode})"
- )
- self._clear_custom_elements(trace_idx)
-
- # Get current raw x-limits from the main plot data
- # This is crucial for decimating custom lines to the current view
- current_xlim_raw = self.coord_manager.get_current_view_raw(self.ax)
-
- # Show lines for current mode
- line_objects = []
- for i, line_def in enumerate(self._lines[trace_idx]):
- logger.debug(
- f"Processing line {i} ('{line_def['label']}'): display_mode={line_def['display_mode']}, current_mode={current_mode}"
- )
- if (
- line_def["display_mode"] & current_mode
- ): # Bitwise check if mode is enabled
- logger.debug(
- f"Line {i} ('{line_def['label']}') should be visible in mode {current_mode}"
- )
-
- # Dynamically decimate the line data for the current view
- # Use the same max_plot_points as the main signal for consistency
- # For custom lines, we want mean decimation if in envelope mode, not min/max envelope
- t_line_raw, line_data, _, _ = self.decimator.decimate_for_view(
- line_def["t_raw"],
- line_def["data_raw"],
- current_xlim_raw, # Decimate to current view
- self.max_plot_points,
- use_envelope=(current_mode == self.MODE_ENVELOPE),
- data_id=line_def[
- "id"
- ], # Pass the custom line's ID for pre-decimated data lookup
- envelope_window_samples=None, # Envelope window calculated automatically
- mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold
- return_envelope_min_max=False, # Custom lines never return min/max envelope
- )
-
- if len(t_line_raw) == 0 or len(line_data) == 0:
- logger.warning(
- f"Line {i} ('{line_def['label']}') has empty data after decimation for current view, skipping plot."
- )
- continue
-
- # Make sure the time array is in display coordinates
- t_line_display = self.coord_manager.raw_to_display(t_line_raw)
-
- # Always plot as a regular line
- (line,) = self.ax.plot(
- t_line_display,
- line_data,
- label=line_def["label"],
- color=line_def["color"],
- alpha=line_def["alpha"],
- linestyle=line_def["linestyle"],
- linewidth=line_def["linewidth"],
- zorder=line_def["zorder"],
- )
- line_objects.append(
- (line, line_def)
- ) # Store both the line and its definition
- logger.debug(f"Added line {i} ('{line_def['label']}') to plot")
- else:
- logger.debug(
- f"Line {i} ('{line_def['label']}') should NOT be visible in mode {current_mode}"
- )
-
- # Show ribbons for current mode
- ribbon_objects = []
- for ribbon_def in self._ribbons[trace_idx]:
- logger.debug(
- f"Processing ribbon ('{ribbon_def['label']}'): display_mode={ribbon_def['display_mode']}, current_mode={current_mode}"
- )
- if ribbon_def["display_mode"] & current_mode:
- logger.debug(
- f"Ribbon ('{ribbon_def['label']}') should be visible in mode {current_mode}"
- )
-
- # Ribbons are always plotted as fills, so we need to decimate their center and width
- # We'll treat the center_data as the 'signal' for decimation purposes
- (
- t_ribbon_raw,
- center_data_decimated,
- min_center_envelope,
- max_center_envelope,
- ) = self.decimator.decimate_for_view(
- ribbon_def["t_raw"],
- ribbon_def["center_data_raw"],
- current_xlim_raw,
- self.max_plot_points,
- use_envelope=(
- current_mode == self.MODE_ENVELOPE
- ), # Use envelope for ribbons if in envelope mode
- data_id=ribbon_def[
- "id"
- ], # Pass the custom ribbon's ID for pre-decimated data lookup
- return_envelope_min_max=True, # Ribbons always need min/max to draw fill
- envelope_window_samples=None, # Envelope window calculated automatically
- mode_switch_threshold=self.mode_switch_threshold,
- )
-
- # Decimate the width array as well, if it's an array
- width_decimated = ribbon_def["width_raw"]
- if len(ribbon_def["width_raw"]) > len(
- t_ribbon_raw
- ): # If raw width is longer than decimated time
- # For simplicity, we'll just take the mean of the width in each bin
- # A more robust solution might involve passing width as another data stream to decimate_for_view
- # For now, we'll manually decimate it based on the t_ribbon_raw indices
- # Find indices in raw data corresponding to decimated time points
- # This is a simplified approach and assumes uniform sampling for width
- indices = np.searchsorted(ribbon_def["t_raw"], t_ribbon_raw)
- indices = np.clip(indices, 0, len(ribbon_def["width_raw"]) - 1)
- width_decimated = ribbon_def["width_raw"][indices]
-
- # If the ribbon was decimated to an envelope, use that for min/max
- if (
- current_mode == self.MODE_ENVELOPE
- and min_center_envelope is not None
- and max_center_envelope is not None
- ):
- lower_bound = min_center_envelope - width_decimated
- upper_bound = max_center_envelope + width_decimated
- else:
- lower_bound = center_data_decimated - width_decimated
- upper_bound = center_data_decimated + width_decimated
-
- if len(t_ribbon_raw) == 0 or len(lower_bound) == 0:
- logger.warning(
- f"Ribbon ('{ribbon_def['label']}') has empty data after decimation, skipping plot."
- )
- continue
-
- # Make sure the time array is in display coordinates
- t_ribbon_display = self.coord_manager.raw_to_display(t_ribbon_raw)
-
- ribbon = self.ax.fill_between(
- t_ribbon_display,
- lower_bound,
- upper_bound,
- color=ribbon_def["color"],
- alpha=ribbon_def["alpha"],
- label=ribbon_def["label"],
- zorder=ribbon_def["zorder"],
- )
- ribbon_objects.append(
- (ribbon, ribbon_def)
- ) # Store both the ribbon and its definition
- logger.debug(f"Added ribbon ('{ribbon_def['label']}') to plot")
- else:
- logger.debug(
- f"Ribbon ('{ribbon_def['label']}') should NOT be visible in mode {current_mode}"
- )
-
- # Show custom envelopes for current mode
- for envelope_def in self._envelopes[trace_idx]:
- logger.debug(
- f"Processing custom envelope ('{envelope_def['label']}'): display_mode={envelope_def['display_mode']}, current_mode={current_mode}"
- )
- if envelope_def["display_mode"] & current_mode:
- logger.debug(
- f"Custom envelope ('{envelope_def['label']}') should be visible in mode {current_mode}"
- )
-
- # For custom envelopes, we need to handle min/max data specially
- # We'll decimate the min and max data separately using the envelope's stored data
- # Since we stored min/max in the pre-decimated data, we can retrieve them
-
- # Get the pre-decimated envelope data for this custom envelope
- if envelope_def["id"] in self.decimator._pre_decimated_envelopes:
- pre_dec_data = self.decimator._pre_decimated_envelopes[
- envelope_def["id"]
- ]
- # The min/max data was stored in bg_initial/bg_clean during pre-decimation
- t_envelope_raw, _, min_data_decimated, max_data_decimated = (
- self.decimator.decimate_for_view(
- envelope_def["t_raw"],
- (
- envelope_def["min_data_raw"]
- + envelope_def["max_data_raw"]
- )
- / 2, # Average for decimation
- current_xlim_raw,
- self.max_plot_points,
- use_envelope=True, # Always treat custom envelopes as envelopes
- data_id=envelope_def[
- "id"
- ], # Pass the custom envelope's ID for pre-decimated data lookup
- return_envelope_min_max=True, # Custom envelopes always need min/max to draw fill
- envelope_window_samples=None, # Envelope window calculated automatically
- mode_switch_threshold=self.mode_switch_threshold,
- )
- )
- # For custom envelopes, the min/max are returned directly as the last two return values
- else:
- # Fallback if no pre-decimated data
- logger.warning(
- f"No pre-decimated data for custom envelope {envelope_def['id']}, using raw decimation"
- )
- t_envelope_raw, _, min_data_decimated, max_data_decimated = (
- self.decimator.decimate_for_view(
- envelope_def["t_raw"],
- (
- envelope_def["min_data_raw"]
- + envelope_def["max_data_raw"]
- )
- / 2,
- current_xlim_raw,
- self.max_plot_points,
- use_envelope=True,
- data_id=None, # No pre-decimated data available
- return_envelope_min_max=True,
- envelope_window_samples=None, # Envelope window calculated automatically
- mode_switch_threshold=self.mode_switch_threshold,
- )
- )
-
- if (
- len(t_envelope_raw) == 0
- or min_data_decimated is None
- or max_data_decimated is None
- or len(min_data_decimated) == 0
- ):
- logger.warning(
- f"Custom envelope ('{envelope_def['label']}') has empty data after decimation, skipping plot."
- )
- continue
-
- t_envelope_display = self.coord_manager.raw_to_display(t_envelope_raw)
-
- envelope = self.ax.fill_between(
- t_envelope_display,
- min_data_decimated,
- max_data_decimated,
- color=envelope_def["color"],
- alpha=envelope_def["alpha"],
- label=envelope_def["label"],
- zorder=envelope_def["zorder"],
- )
- ribbon_objects.append(
- (envelope, envelope_def)
- ) # Store in ribbon objects
- logger.debug(
- f"Added custom envelope ('{envelope_def['label']}') to plot"
- )
- else:
- logger.debug(
- f"Custom envelope ('{envelope_def['label']}') should NOT be visible in mode {current_mode}"
- )
-
- # Store objects with their definitions for future updates
- self._line_objects[trace_idx] = line_objects
- self._ribbon_objects[trace_idx] = ribbon_objects
-
- # Update last mode AFTER processing
- self._last_mode[trace_idx] = current_mode
-
- def _update_element_visibility(self, trace_idx: int, current_mode: int) -> None:
- """
- Update visibility of existing custom elements based on current mode.
-
- Parameters
- ----------
- trace_idx : int
- Index of the trace to update.
- current_mode : int
- Current display mode (MODE_ENVELOPE or MODE_DETAIL).
- """
- logger.debug(
- f"Updating element visibility for trace {trace_idx}, current_mode={current_mode}"
- )
- # Update line visibility
- for line_obj, line_def in self._line_objects[trace_idx]:
- should_be_visible = bool(line_def["display_mode"] & current_mode)
- if line_obj.get_visible() != should_be_visible:
- line_obj.set_visible(should_be_visible)
- logger.debug(
- f"Set visibility of line '{line_def['label']}' to {should_be_visible}"
- )
-
- # Update ribbon visibility
- for ribbon_obj, ribbon_def in self._ribbon_objects[trace_idx]:
- should_be_visible = bool(ribbon_def["display_mode"] & current_mode)
- if ribbon_obj.get_visible() != should_be_visible:
- ribbon_obj.set_visible(should_be_visible)
- logger.debug(
- f"Set visibility of ribbon '{ribbon_def['label']}' to {should_be_visible}"
- )
-
- def _clear_custom_elements(self, trace_idx: int) -> None:
- """
- Clear all custom visualization elements for a trace.
-
- Parameters
- ----------
- trace_idx : int
- Index of the trace to clear elements for.
- """
- logger.debug(f"Clearing custom elements for trace {trace_idx}")
- # Clear lines
- for line_obj, _ in self._line_objects[trace_idx]:
- line_obj.remove()
- self._line_objects[trace_idx].clear()
-
- # Clear ribbons
- for ribbon_obj, _ in self._ribbon_objects[trace_idx]:
- ribbon_obj.remove()
- self._ribbon_objects[trace_idx].clear()
-
- def _update_tick_locator(self, time_span_raw: np.float32) -> None:
- """Update tick locator based on current time scale and span."""
- if self.state.current_time_scale >= np.float32(1e6): # microseconds or smaller
- # For microsecond scale, use reasonable intervals
- tick_interval = max(
- 1, int(time_span_raw * self.state.current_time_scale / 10)
- )
- self.ax.xaxis.set_major_locator(MultipleLocator(tick_interval))
- else:
- # For larger scales, use matplotlib's default auto locator
- self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
-
- def _update_legend(self) -> None:
- """Updates the plot legend, filtering out invisible elements and optimising rebuilds."""
- logger.debug("Updating legend...")
- handles, labels = self.ax.get_legend_handles_labels()
-
- # Filter for unique and visible handles/labels
- unique_labels = []
- unique_handles = []
- for h, l in zip(handles, labels):
- # Check if the handle has a get_visible method and if it returns True
- # For fill_between objects (ribbons, envelopes, regions), get_visible might not exist or behave differently
- # For these, we assume they are visible if they are in the list of objects
- is_visible = True
- if hasattr(h, "get_visible"):
- is_visible = h.get_visible()
- elif isinstance(
- h, mpl.collections.PolyCollection
- ): # For fill_between objects
- # PolyCollection doesn't have get_visible, but its patches might.
- # Or we can assume it's visible if it's part of the current plot.
- # For now, assume it's visible if it's a PolyCollection and has data.
- is_visible = len(h.get_paths()) > 0 # Check if it has any paths to draw
-
- if l not in unique_labels and is_visible:
- unique_labels.append(l)
- unique_handles.append(h)
-
- logger.debug(f"Unique visible legend items found: {unique_labels}")
-
- # Create a hash of current handles/labels for efficient comparison
- current_hash = hash(tuple(id(h) for h in unique_handles) + tuple(unique_labels))
-
- # Check if legend content actually changed
- if (
- not hasattr(self, "_last_legend_hash")
- or self._last_legend_hash != current_hash
- ):
- logger.debug("Legend content changed, rebuilding legend.")
- if self._legend is not None:
- self._legend.remove() # Remove old legend to prevent duplicates
-
- if unique_handles: # Only create legend if there are handles to show
- self._legend = self.ax.legend(
- unique_handles, unique_labels, loc="lower right"
- )
- logger.debug("New legend created.")
- else:
- self._legend = None # No legend to show
- logger.debug("No legend to show.")
-
- self._current_legend_handles = unique_handles
- self._current_legend_labels = unique_labels
- self._last_legend_hash = current_hash
- else:
- logger.debug("Legend content unchanged, skipping rebuild.")
-
- def _clear_navigation_history(self):
- """Clear matplotlib's navigation history when coordinate system changes."""
- if (
- self.fig
- and self.fig.canvas
- and hasattr(self.fig.canvas, "toolbar")
- and self.fig.canvas.toolbar
- ):
- toolbar = self.fig.canvas.toolbar
- if hasattr(toolbar, "_nav_stack"):
- toolbar._nav_stack.clear()
-
- def _push_current_view(self):
- """Push current view to navigation history as new base."""
- if (
- self.fig
- and self.fig.canvas
- and hasattr(self.fig.canvas, "toolbar")
- and self.fig.canvas.toolbar
- ):
- toolbar = self.fig.canvas.toolbar
- if hasattr(toolbar, "push_current"):
- toolbar.push_current()
-
- def _update_axis_formatting(self) -> None:
- """Update axis labels and formatters."""
- if self.state.offset_time_raw is not None:
- offset_value = self.state.offset_time_raw * (
- 1e3
- if self.state.offset_unit == "ms"
- else 1e6
- if self.state.offset_unit == "us"
- else 1e9
- if self.state.offset_unit == "ns"
- else 1.0
- )
- xlabel = f"Time ({self.state.current_time_unit}) + {offset_value:.3g} {self.state.offset_unit}"
- else:
- xlabel = f"Time ({self.state.current_time_unit})"
-
- self.ax.set_xlabel(xlabel)
-
- formatter = _create_time_formatter(
- self.state.offset_time_raw, self.state.current_time_scale
- )
- self.ax.xaxis.set_major_formatter(formatter)
-
- def _update_overlay_lines(
- self, plot_data: Dict[str, Any], show_overlays: bool
- ) -> None:
- """Update overlay lines based on zoom level and data availability."""
- # Clear existing overlay lines from the plot
- # This method is not currently used in the provided code, but if it were,
- # it would need to be updated to use the new decimation strategy.
- # For now, leaving it as is, assuming it's a placeholder or for future_use.
- # If it were to be used, it would need to call decimate_for_view for each overlay line.
- pass # No _overlay_lines attribute in this class, this method is unused.
-
- def _update_y_limits(self, plot_data: Dict[str, Any], use_envelope: bool) -> None:
- """Update y-axis limits to fit current data."""
- y_min_data = float("inf")
- y_max_data = float("-inf")
-
- # Process each trace
- for trace_idx in range(self.data.num_traces):
- x_new_key = f"x_new_{trace_idx}"
- x_min_key = f"x_min_{trace_idx}"
- x_max_key = f"x_max_{trace_idx}"
-
- if x_new_key not in plot_data:
- continue
-
- # Include signal data
- if len(plot_data[x_new_key]) > 0:
- y_min_data = min(y_min_data, np.min(plot_data[x_new_key]))
- y_max_data = max(y_max_data, np.max(plot_data[x_new_key]))
-
- # Include envelope data if available
- if use_envelope and x_min_key in plot_data and x_max_key in plot_data:
- if (
- plot_data[x_min_key] is not None
- and plot_data[x_max_key] is not None
- and len(plot_data[x_min_key]) > 0
- ):
- y_min_data = min(y_min_data, np.min(plot_data[x_min_key]))
- y_max_data = max(y_max_data, np.max(plot_data[x_max_key]))
-
- # Include custom lines
- for line_obj, _ in self._line_objects[trace_idx]:
- # Check if line_obj is a Line2D or PolyCollection
- if isinstance(line_obj, mpl.lines.Line2D):
- y_data = line_obj.get_ydata()
- if len(y_data) > 0:
- y_min_data = min(y_min_data, np.min(y_data))
- y_max_data = max(y_max_data, np.max(y_data))
- elif isinstance(line_obj, mpl.collections.PolyCollection):
- # For fill_between objects, iterate through paths to get y-coordinates
- for path in line_obj.get_paths():
- vertices = path.vertices
- if len(vertices) > 0:
- y_min_data = min(y_min_data, np.min(vertices[:, 1]))
- y_max_data = max(y_max_data, np.max(vertices[:, 1]))
-
- # Include ribbon data
- for ribbon_obj, _ in self._ribbon_objects[trace_idx]:
- # For fill_between objects, we need to get the paths
- if hasattr(ribbon_obj, "get_paths") and len(ribbon_obj.get_paths()) > 0:
- for path in ribbon_obj.get_paths():
- vertices = path.vertices
- if len(vertices) > 0:
- y_min_data = min(y_min_data, np.min(vertices[:, 1]))
- y_max_data = max(y_max_data, np.max(vertices[:, 1]))
-
- # Handle case where no data was found
- if y_min_data == float("inf") or y_max_data == float("-inf"):
- self.ax.set_ylim(0, 1)
- return
-
- data_range = y_max_data - y_min_data
- data_mean = (y_min_data + y_max_data) / 2
-
- # Use min_y_range to ensure a minimum visible range
- min_visible_range = self.min_y_range
-
- if data_range < min_visible_range:
- y_min = data_mean - min_visible_range / 2
- y_max = data_mean + min_visible_range / 2
- else:
- y_margin = self.y_margin_fraction * data_range
- y_min = y_min_data - y_margin
- y_max = y_max_data + y_margin
-
- logger.debug(
- f"Y-limit calculation details: data_range={data_range:.3g}, min_visible_range={min_visible_range:.3g}, data_mean={data_mean:.3g}"
- ) # ADDED THIS LINE
- logger.debug(
- f"Pre-set Y-limits: y_min={y_min:.9f}, y_max={y_max:.9f}"
- ) # ADDED THIS LINE
- self.ax.set_ylim(y_min, y_max)
-
- def _update_plot_data(self, ax_obj) -> None:
- """Update plot based on current view."""
- if self.state.is_updating():
- return
-
- self.state.set_updating(True)
-
- try:
- try:
- # Add debug logging for current axis limits
- display_xlim = ax_obj.get_xlim()
- logger.debug(f"Current display xlim: {display_xlim}")
-
- view_params = self._calculate_view_parameters(ax_obj)
- logger.debug(
- f"Calculated view parameters: xlim_raw={view_params['xlim_raw']}, time_span_raw={view_params['time_span_raw']}, use_envelope={view_params['use_envelope']}"
- )
-
- plot_data = self._get_plot_data(view_params)
-
- # Debug data availability
- data_summary = {}
- for trace_idx in range(self.data.num_traces):
- t_key = f"t_display_{trace_idx}"
- if t_key in plot_data:
- data_summary[t_key] = len(plot_data[t_key])
- logger.debug(f"Plot data summary: {data_summary}")
-
- self._render_plot_elements(plot_data, view_params)
- self._update_regions_and_legend(view_params["xlim_display"])
- self.fig.canvas.draw_idle()
- except Exception as e:
- logger.exception(f"Error updating plot: {e}")
- # Try to recover by resetting to home view
- logger.info("Attempting to recover by resetting to home view")
- self.home()
- finally:
- self.state.set_updating(False)
-
- def _calculate_view_parameters(self, ax_obj) -> Dict[str, Any]:
- """Calculate view parameters from current axis state."""
- try:
- xlim_raw = self.coord_manager.get_current_view_raw(ax_obj)
-
- # Validate xlim_raw values
- if not np.isfinite(xlim_raw[0]) or not np.isfinite(xlim_raw[1]):
- logger.warning(
- f"Invalid xlim_raw from axis: {xlim_raw}. Using initial view."
- )
- xlim_raw = self._initial_xlim_raw
-
- # Ensure xlim_raw is in ascending order
- if xlim_raw[0] > xlim_raw[1]:
- logger.warning(f"xlim_raw values out of order: {xlim_raw}. Swapping.")
- xlim_raw = (xlim_raw[1], xlim_raw[0])
-
- time_span_raw = xlim_raw[1] - xlim_raw[0]
- use_envelope = self.state.should_use_envelope(time_span_raw)
- current_mode = self.MODE_ENVELOPE if use_envelope else self.MODE_DETAIL
-
- logger.debug(f"=== _calculate_view_parameters ===")
- logger.debug(f"xlim_raw: {xlim_raw}")
- logger.debug(f"time_span_raw: {time_span_raw:.6e}s")
- logger.debug(
- f"envelope_limit: {self.mode_switch_threshold:.6e}s"
- ) # Use mode_switch_threshold
- logger.debug(f"use_envelope: {use_envelope}")
- logger.debug(
- f"current_mode: {current_mode} ({'ENVELOPE' if current_mode == self.MODE_ENVELOPE else 'DETAIL'})"
- )
-
- # Update coordinate system if needed
- coordinate_system_changed = self.state.update_display_params(
- xlim_raw, time_span_raw
- )
- if coordinate_system_changed:
- logger.debug("Coordinate system changed, updating")
- self._update_coordinate_system(xlim_raw, time_span_raw)
-
- return {
- "xlim_raw": xlim_raw,
- "time_span_raw": time_span_raw,
- "xlim_display": self.coord_manager.xlim_raw_to_display(xlim_raw),
- "use_envelope": use_envelope,
- "current_mode": current_mode,
- }
- except Exception as e:
- logger.exception(f"Error calculating view parameters: {e}")
- # Return safe default values
- return {
- "xlim_raw": self._initial_xlim_raw,
- "time_span_raw": self._initial_xlim_raw[1] - self._initial_xlim_raw[0],
- "xlim_display": self.coord_manager.xlim_raw_to_display(
- self._initial_xlim_raw
- ),
- "use_envelope": True,
- "current_mode": self.MODE_ENVELOPE,
- }
-
- def _get_plot_data(self, view_params: Dict[str, Any]) -> Dict[str, Any]:
- """Get decimated plot data for current view."""
- logger.debug(f"=== _get_plot_data ===")
- logger.debug(f"view_params: {view_params}")
-
- plot_data = {}
-
- # Process each trace
- for trace_idx in range(self.data.num_traces):
- logger.debug(f"--- Processing trace {trace_idx} ---")
- t_arr = self.data.t_arrays[trace_idx]
- x_arr = self.data.x_arrays[trace_idx]
-
- logger.debug(f"Input data: t_arr len={len(t_arr)}, x_arr len={len(x_arr)}")
-
- try:
- t_raw, x_new, x_min, x_max = self.decimator.decimate_for_view(
- t_arr,
- x_arr,
- view_params["xlim_raw"],
- self.max_plot_points,
- view_params["use_envelope"],
- trace_idx, # Pass trace_id to use pre-decimated data
- envelope_window_samples=None, # Envelope window calculated automatically
- mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold
- return_envelope_min_max=True, # Main signal always returns envelope min/max if use_envelope is True
- )
-
- logger.debug(
- f"Decimated data: t_raw len={len(t_raw)}, x_new len={len(x_new)}"
- )
- logger.debug(
- f"Envelope data: x_min={'None' if x_min is None else f'len={len(x_min)}'}, x_max={'None' if x_max is None else f'len={len(x_max)}'}"
- )
-
- if len(t_raw) == 0:
- logger.warning(
- f"No data in current view for trace {trace_idx}. View range: {view_params['xlim_raw']}"
- )
- # Add empty arrays for this trace
- plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32)
- plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32)
- plot_data[f"x_min_{trace_idx}"] = None
- plot_data[f"x_max_{trace_idx}"] = None
- continue
-
- t_display = self.coord_manager.raw_to_display(t_raw)
- logger.debug(
- f"Converted to display coordinates: t_display range=[{np.min(t_display):.6f}, {np.max(t_display):.6f}]"
- )
-
- # Store data for this trace
- plot_data[f"t_display_{trace_idx}"] = t_display
- plot_data[f"x_new_{trace_idx}"] = x_new
- plot_data[f"x_min_{trace_idx}"] = x_min
- plot_data[f"x_max_{trace_idx}"] = x_max
-
- logger.debug(f"Stored plot data for trace {trace_idx}")
- except Exception as e:
- logger.exception(f"Error getting plot data for trace {trace_idx}: {e}")
- # Add empty arrays for this trace to prevent further errors
- plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32)
- plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32)
- plot_data[f"x_min_{trace_idx}"] = None
- plot_data[f"x_max_{trace_idx}"] = None
-
- logger.debug(f"Final plot_data keys: {list(plot_data.keys())}")
- return plot_data
-
- def _render_plot_elements(
- self, plot_data: Dict[str, Any], view_params: Dict[str, Any]
- ) -> None:
- """Render all plot elements with current data."""
- logger.debug(f"=== _render_plot_elements ===")
- logger.debug(f"view_params use_envelope: {view_params['use_envelope']}")
-
- # Store the current plot data for use by other methods
- self._current_plot_data = plot_data
-
- # Check if we have any data to plot
- has_data = False
- data_summary = {}
- for trace_idx in range(self.data.num_traces):
- key = f"t_display_{trace_idx}"
- if key in plot_data and len(plot_data[key]) > 0:
- has_data = True
- data_summary[f"trace_{trace_idx}"] = len(plot_data[key])
- else:
- data_summary[f"trace_{trace_idx}"] = 0
-
- logger.debug(f"Data summary: {data_summary}, has_data: {has_data}")
-
- if not has_data:
- logger.warning("No data to plot, clearing all elements")
- # If no data, clear all lines and return
- for i in range(self.data.num_traces):
- self._signal_lines[i].set_data([], [])
- if self._envelope_fills[i] is not None:
- self._envelope_fills[i].remove()
- self._envelope_fills[i] = None
-
- # Clear custom elements
- self._clear_custom_elements(i)
-
- self.ax.set_ylim(0, 1) # Set a default y-limit
- return
-
- # Process each trace
- for trace_idx in range(self.data.num_traces):
- logger.debug(f"--- Rendering trace {trace_idx} ---")
- t_display_key = f"t_display_{trace_idx}"
- x_new_key = f"x_new_{trace_idx}"
- x_min_key = f"x_min_{trace_idx}"
- x_max_key = f"x_max_{trace_idx}"
-
- if t_display_key not in plot_data or len(plot_data[t_display_key]) == 0:
- logger.debug(f"No data for trace {trace_idx}, hiding elements")
- # No data for this trace, hide its elements
- self._signal_lines[trace_idx].set_data([], [])
- if self._envelope_fills[trace_idx] is not None:
- self._envelope_fills[trace_idx].remove()
- self._envelope_fills[trace_idx] = None
-
- # Clear custom elements
- self._clear_custom_elements(trace_idx)
- continue
-
- # Update signal display
- envelope_data = None
- if (
- view_params["use_envelope"]
- and x_min_key in plot_data
- and x_max_key in plot_data
- ):
- if (
- plot_data[x_min_key] is not None
- and plot_data[x_max_key] is not None
- ):
- envelope_data = (plot_data[x_min_key], plot_data[x_max_key])
- logger.debug(f"Using envelope data for trace {trace_idx}")
- else:
- logger.debug(
- f"Envelope mode requested but no envelope data for trace {trace_idx}"
- )
- else:
- logger.debug(f"Detail mode for trace {trace_idx}")
-
- self._update_signal_display(
- trace_idx, plot_data[t_display_key], plot_data[x_new_key], envelope_data
- )
-
- # Update y-limits
- logger.debug("Updating y-limits")
- self._update_y_limits(plot_data, view_params["use_envelope"])
-
- def _update_coordinate_system(
- self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32
- ) -> None:
- """Update coordinate system and axis formatting."""
- self._clear_region_fills()
- self._update_axis_formatting()
- self._update_tick_locator(time_span_raw)
-
- xlim_display = self.coord_manager.xlim_raw_to_display(xlim_raw)
- self.ax.set_xlim(xlim_display)
-
- self._clear_navigation_history()
- self._push_current_view()
-
- def _update_regions_and_legend(
- self, xlim_display: Tuple[np.float32, np.float32]
- ) -> None:
- """Update regions and legend."""
- self._refresh_region_display(xlim_display)
- self._update_legend()
-
- def _refresh_region_display(
- self, xlim_display: Tuple[np.float32, np.float32]
- ) -> None:
- """Refresh region display for current view."""
- logger.debug(f"=== _refresh_region_display ===")
- self._clear_region_fills()
-
- # Get current mode
- current_mode = (
- self.MODE_ENVELOPE
- if self.state.current_mode == "envelope"
- else self.MODE_DETAIL
- )
- logger.debug(f"Current display mode for regions: {current_mode}")
-
- for trace_idx in range(self.data.num_traces):
- logger.debug(f"Processing regions for trace {trace_idx}")
- # Process each region definition
- for region_def in self._regions[trace_idx]:
- logger.debug(
- f"Region '{region_def['label']}': display_mode={region_def['display_mode']}, current_mode={current_mode}"
- )
- # Skip if not visible in current mode
- if not (region_def["display_mode"] & current_mode):
- logger.debug(
- f"Region '{region_def['label']}' not visible in current mode {current_mode}, skipping."
- )
- continue
-
- regions = region_def["regions"]
- if regions is None or len(regions) == 0:
- logger.debug(
- f"No regions data for '{region_def['label']}', skipping."
- )
- continue
-
- logger.debug(
- f"Displaying {len(regions)} regions for '{region_def['label']}' in mode {current_mode}"
- )
-
- color = region_def["color"]
- label = region_def["label"]
- alpha = region_def["alpha"]
- first_visible_region = True
-
- for t_start, t_end in regions:
- t_start_display = self.coord_manager.raw_to_display(t_start)
- t_end_display = self.coord_manager.raw_to_display(t_end)
-
- # Check if region overlaps with current view
- if not (
- t_end_display <= xlim_display[0]
- or t_start_display >= xlim_display[1]
- ):
- # Only show label for first visible region
- current_label = label if first_visible_region else ""
- if first_visible_region and len(regions) > 1:
- current_label = f"{label} ({len(regions)})"
-
- logger.debug(
- f"Adding region span from {t_start_display:.6f} to {t_end_display:.6f} (raw: {t_start:.6f} to {t_end:.6f}) for '{label}'"
- )
- fill = self.ax.axvspan(
- t_start_display,
- t_end_display,
- alpha=alpha,
- color=color,
- linewidth=0.5,
- label=current_label,
- zorder=region_def["zorder"],
- )
- self._region_objects[trace_idx].append((fill, region_def))
- first_visible_region = False
- else:
- logger.debug(
- f"Region span from {t_start_display:.6f} to {t_end_display:.6f} (raw: {t_start:.6f} to {t_end:.6f}) for '{label}' is outside current view {xlim_display}, skipping."
- )
-
- def _clear_region_fills(self) -> None:
- """Clear all region fills."""
- logger.debug("Clearing region fills.")
- for trace_fills in self._region_objects:
- for fill_item in trace_fills:
- # Handle both old format (just fill object) and new format (tuple)
- if isinstance(fill_item, tuple):
- fill, _ = fill_item # Extract the fill object from the tuple
- fill.remove()
- else:
- fill_item.remove() # Old format - direct fill object
- trace_fills.clear()
- logger.debug("Region fills cleared.")
-
- def _setup_plot_elements(self) -> None:
- """
- Initialise matplotlib plot elements (lines, fills) for each trace.
- This is called once during render().
- """
- if self.fig is None or self.ax is None:
- raise RuntimeError(
- "Figure and Axes must be created before setting up plot elements."
- )
-
- # Create initial signal line objects for each trace
- for i in range(self.data.num_traces):
- color = self.data.get_trace_color(i)
- name = self.data.get_trace_name(i)
-
- # Signal line
- (line_signal,) = self.ax.plot(
- [],
- [],
- label="Raw data" if self.data.num_traces == 1 else f"Raw data ({name})",
- color=color,
- alpha=self.signal_alpha,
- )
- self._signal_lines.append(line_signal)
-
- def _connect_callbacks(self) -> None:
- """Connect matplotlib callbacks."""
- if self.ax is None:
- raise RuntimeError("Axes must be created before connecting callbacks.")
- self.ax.callbacks.connect("xlim_changed", self._update_plot_data)
-
- def _setup_toolbar_overrides(self) -> None:
- """Override matplotlib toolbar methods (e.g., home button)."""
- if (
- self.fig
- and self.fig.canvas
- and hasattr(self.fig.canvas, "toolbar")
- and self.fig.canvas.toolbar
- ):
- toolbar = self.fig.canvas.toolbar
-
- # Store original methods
- self._original_home = getattr(toolbar, "home", None)
- self._original_push_current = getattr(toolbar, "push_current", None)
-
- # Create our custom home method
- def custom_home(*args, **kwargs):
- logger.debug("Toolbar home button pressed - calling custom home")
- self.home()
-
- # Override both the method and try to find the actual button
- toolbar.home = custom_home
-
- # For Qt backend, also override the action
- if hasattr(toolbar, "actions"):
- for action in toolbar.actions():
- if hasattr(action, "text") and hasattr(action, "objectName"):
- action_text = (
- action.text() if callable(action.text) else str(action.text)
- )
- action_name = (
- action.objectName()
- if callable(action.objectName)
- else str(action.objectName)
- )
- if action_text == "Home" or "home" in action_name.lower():
- if hasattr(action, "triggered"):
- action.triggered.disconnect()
- action.triggered.connect(custom_home)
- logger.debug("Connected custom home to Qt action")
- break
-
- # For other backends, try to override the button callback
- if hasattr(toolbar, "_buttons") and "Home" in toolbar._buttons:
- home_button = toolbar._buttons["Home"]
- if hasattr(home_button, "configure"):
- home_button.configure(command=custom_home)
- logger.debug("Connected custom home to Tkinter button")
-
- def _set_initial_view_and_labels(self) -> None:
- """Set initial axis limits, title, and labels."""
- if self.ax is None:
- raise RuntimeError(
- "Axes must be created before setting initial view and labels."
- )
-
- # Create title based on number of traces
- if self.data.num_traces == 1:
- self.ax.set_title(f"{self.data.names[0]}")
- else:
- # Multiple traces - just show "Multiple Traces"
- self.ax.set_title(f"Multiple Traces ({self.data.num_traces})")
- self.ax.set_xlabel(f"Time ({self.state.current_time_unit})")
- self.ax.set_ylabel("Signal")
-
- # Set initial xlim
- initial_xlim_display = self.coord_manager.xlim_raw_to_display(
- self._initial_xlim_raw
- )
- self.ax.set_xlim(initial_xlim_display)
-
- def render(self) -> None:
- """
- Renders the oscilloscope plot. This method must be called after all
- data and visualization elements have been added.
- """
- if self.fig is not None or self.ax is not None:
- logger.warning(
- "Plot already rendered. Call `home()` to reset or create a new instance."
- )
- return
-
- logger.info("Rendering plot...")
- self.fig, self.ax = plt.subplots(figsize=(10, 5))
-
- self._setup_plot_elements()
- self._connect_callbacks()
- self._setup_toolbar_overrides()
- self._set_initial_view_and_labels()
-
- # Calculate initial parameters for the full view
- t_start, t_end = self.data.get_global_time_range()
- full_time_span = t_end - t_start
-
- logger.info(
- f"Initial render: full time span={full_time_span:.3e}s, envelope_limit={self.mode_switch_threshold:.3e}s"
- )
-
- # Set initial display state based on full view
- self.state.current_time_unit, self.state.current_time_scale = (
- _get_optimal_time_unit_and_scale(full_time_span)
- )
- self.state.current_mode = (
- "envelope" if self.state.should_use_envelope(full_time_span) else "detail"
- )
-
- # Force initial draw of all elements by calling _update_plot_data
- # This will also update the legend and regions
- self.state.set_updating(False) # Ensure not in updating state for first call
- self._update_plot_data(self.ax)
- self.fig.canvas.draw_idle()
- logger.info("Plot rendering complete.")
-
- def home(self) -> None:
- """Return to initial full view with complete state reset."""
- if self.ax is None: # Fix: Changed '===' to 'is'
- logger.warning("Plot not rendered yet. Cannot go home.")
- return
-
- # Disconnect callback temporarily
- callback_id = None
- for cid, callback in self.ax.callbacks.callbacks["xlim_changed"].items():
- if getattr(callback, "__func__", callback) == self._update_plot_data:
- callback_id = cid
- break
-
- if callback_id is not None:
- self.ax.callbacks.disconnect(callback_id)
-
- try:
- self.state.set_updating(True)
- self.state.reset_to_initial_state()
- self.decimator.clear_cache()
- self._clear_region_fills()
-
- # Clear all custom elements and reset _last_mode for each trace to force redraw
- for trace_idx in range(self.data.num_traces):
- self._clear_custom_elements(trace_idx)
- self._last_mode[trace_idx] = None
-
- # Reset axis formatting
- self.ax.set_xlabel(f"Time ({self.state.original_time_unit})")
- self.ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter())
- self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
-
- # Reset view
- self.coord_manager.set_view_raw(self.ax, self._initial_xlim_raw)
-
- # Manually trigger update for the home view
- # This will re-evaluate use_envelope, current_mode, and redraw everything
- self._update_plot_data(self.ax)
-
- self.state.set_updating(False)
-
- finally:
- self.ax.callbacks.connect("xlim_changed", self._update_plot_data)
-
- self.fig.canvas.draw()
- logger.info(f"Home view restored: {self.state.original_time_unit} scale")
-
- def refresh(self) -> None:
- """Force a complete refresh of the plot without changing the current view."""
- if self.ax is None:
- logger.warning("Plot not rendered yet. Cannot refresh.")
- return
-
- # Temporarily bypass the updating state for forced refresh
- was_updating = self.state.is_updating()
- self.state.set_updating(False)
- try:
- self._update_plot_data(self.ax)
- finally:
- self.state.set_updating(was_updating)
- self.fig.canvas.draw_idle()
-
- def show(self) -> None:
- """Display the plot."""
- if self.fig is None:
- self.render() # Render if not already rendered
- plt.show()
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import warnings
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.ticker import MultipleLocator
+
+from .coordinate_manager import CoordinateManager
+from .data_manager import TimeSeriesDataManager
+from .decimation import DecimationManager
+from .display_state import (
+ DisplayState,
+ _create_time_formatter,
+ _get_optimal_time_unit_and_scale,
+)
+
+
+class OscilloscopePlot:
+ """
+ General-purpose plotting class for time-series data with zoom and decimation.
+
+ Uses separate managers for data, decimation, and state to reduce complexity.
+ Supports different visualization elements (lines, envelopes, ribbons, regions)
+ that can be displayed in different modes (envelope when zoomed out, detail when zoomed in).
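+
+    Examples
+    --------
+    A minimal usage sketch with synthetic data (the import path assumes the
+    installed ``scopekit`` package; sizes and names are illustrative)::
+
+        import numpy as np
+        from scopekit.plot import OscilloscopePlot
+
+        t = np.linspace(0.0, 2.0, 2_000_000, dtype=np.float32)
+        x = np.sin(2 * np.pi * 1e3 * t).astype(np.float32)
+        plot = OscilloscopePlot(t, x, name="1 kHz tone")
+        plot.render()
+        plot.show()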
+ """
+
+ # Mode constants
+ MODE_ENVELOPE = 1 # Zoomed out mode
+ MODE_DETAIL = 2 # Zoomed in mode
+ MODE_BOTH = 3 # Both modes
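+    # display_mode values are bit flags: MODE_BOTH == MODE_ENVELOPE | MODE_DETAIL,
+    # and element visibility is decided with a bitwise AND against the current mode.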
+
+ # Default styling constants
+ DEFAULT_MAX_PLOT_POINTS = 10000
+ DEFAULT_MODE_SWITCH_THRESHOLD = 10e-3 # 10 ms
+ DEFAULT_MIN_Y_RANGE_DEFAULT = 1e-9 # Default minimum Y-axis range (e.g., 1 nV)
+ DEFAULT_Y_MARGIN_FRACTION = 0.15
+ DEFAULT_SIGNAL_LINE_WIDTH = 1.0
+ DEFAULT_SIGNAL_ALPHA = 0.75
+ DEFAULT_ENVELOPE_ALPHA = 0.75
+ DEFAULT_REGION_ALPHA = 0.4
+ DEFAULT_REGION_ZORDER = -5
+
+ def __init__(
+ self,
+ t: Union[np.ndarray, List[np.ndarray]],
+ x: Union[np.ndarray, List[np.ndarray]],
+ name: Union[str, List[str]] = "Waveform",
+ trace_colors: Optional[List[str]] = None,
+ # Core display parameters
+ max_plot_points: int = DEFAULT_MAX_PLOT_POINTS,
+ mode_switch_threshold: float = DEFAULT_MODE_SWITCH_THRESHOLD,
+        min_y_range: Optional[float] = None,  # Minimum Y-axis range to enforce
+ y_margin_fraction: float = DEFAULT_Y_MARGIN_FRACTION,
+ signal_line_width: float = DEFAULT_SIGNAL_LINE_WIDTH,
+ signal_alpha: float = DEFAULT_SIGNAL_ALPHA,
+ envelope_alpha: float = DEFAULT_ENVELOPE_ALPHA,
+ region_alpha: float = DEFAULT_REGION_ALPHA,
+ region_zorder: int = DEFAULT_REGION_ZORDER,
+ envelope_window_samples: Optional[int] = None,
+ ):
+ """
+ Initialize the OscilloscopePlot with time series data.
+
+ Parameters
+ ----------
+ t : Union[np.ndarray, List[np.ndarray]]
+ Time array(s) (raw time in seconds). Can be a single array shared by all traces
+ or a list of arrays, one per trace.
+ x : Union[np.ndarray, List[np.ndarray]]
+ Signal array(s). If t is a single array, x can be a 2D array (traces x samples)
+ or a list of 1D arrays. If t is a list, x must be a list of equal length.
+ name : Union[str, List[str]], default="Waveform"
+ Name(s) for plot title. Can be a single string or a list of strings.
+ trace_colors : Optional[List[str]], default=None
+ Colors for each trace. If None, default colors will be used.
+ max_plot_points : int, default=10000
+ Maximum number of points to display on the plot. Data will be decimated if it exceeds this.
+ mode_switch_threshold : float, default=10e-3
+ Time span (in seconds) above which the plot switches to envelope mode.
+ min_y_range : Optional[float], default=None
+            Minimum Y-axis range to enforce. If None, a default of 1e-9 is used.
+        y_margin_fraction : float, default=0.15
+ Fraction of data range to add as margin to Y-axis limits.
+ signal_line_width : float, default=1.0
+ Line width for the raw signal plot.
+ signal_alpha : float, default=0.75
+ Alpha (transparency) for the raw signal plot.
+        envelope_alpha : float, default=0.75
+ Alpha (transparency) for the envelope fill.
+ region_alpha : float, default=0.4
+ Alpha (transparency) for region highlight fills.
+ region_zorder : int, default=-5
+ Z-order for region highlight fills (lower means further back).
+ envelope_window_samples : Optional[int], default=None
+ DEPRECATED: Window size in samples for envelope calculation.
+ Envelope window is now calculated automatically based on max_plot_points and zoom level.
+ This parameter is ignored but kept for backward compatibility.
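+
+        Examples
+        --------
+        A sketch of constructing a two-trace plot with a custom decimation budget
+        (the values are illustrative, not recommendations)::
+
+            import numpy as np
+            from scopekit.plot import OscilloscopePlot
+
+            t = np.linspace(0.0, 1.0, 500_000, dtype=np.float32)
+            x = np.vstack([np.sin(2 * np.pi * 5 * t), np.cos(2 * np.pi * 5 * t)])
+            plot = OscilloscopePlot(
+                t,
+                x,
+                name=["sin", "cos"],
+                max_plot_points=5000,
+                mode_switch_threshold=5e-3,  # spans under 5 ms are shown in detail mode
+            )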
+ """
+ # Store styling parameters directly as instance attributes
+ self.max_plot_points = max_plot_points
+ self.mode_switch_threshold = np.float32(mode_switch_threshold)
+ self.min_y_range = (
+ np.float32(min_y_range)
+ if min_y_range is not None
+ else self.DEFAULT_MIN_Y_RANGE_DEFAULT
+ )
+ self.y_margin_fraction = np.float32(y_margin_fraction)
+ self.signal_line_width = signal_line_width
+ self.signal_alpha = signal_alpha
+ self.envelope_alpha = envelope_alpha
+ self.region_alpha = region_alpha
+ self.region_zorder = region_zorder
+ # envelope_window_samples is now deprecated - envelope window is calculated automatically
+ # Keep the parameter for backward compatibility but don't use it
+ if envelope_window_samples is not None:
+            warnings.warn(
+                "envelope_window_samples is deprecated and ignored; the envelope "
+                "window is now calculated automatically based on zoom level.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+ # Initialize managers
+ self.data = TimeSeriesDataManager(t, x, name, trace_colors)
+ self.decimator = DecimationManager()
+
+ # Pre-decimate main signal data for envelope view
+ for i in range(self.data.num_traces):
+ self.decimator.pre_decimate_data(
+ data_id=i, # Use trace_idx as data_id
+ t=self.data.t_arrays[i],
+ x=self.data.x_arrays[i],
+ max_points=self.max_plot_points,
+ envelope_window_samples=None, # Envelope window calculated automatically
+ )
+
+ # Initialize display state using the first trace's time array
+ initial_time_unit, initial_time_scale = _get_optimal_time_unit_and_scale(
+ self.data.t_arrays[0]
+ )
+ self.state = DisplayState(
+ initial_time_unit, initial_time_scale, self.mode_switch_threshold
+ )
+
+ # Initialize matplotlib figure and axes to None
+ self.fig: Optional[mpl.figure.Figure] = None
+ self.ax: Optional[mpl.axes.Axes] = None
+
+ # Store visualization elements for each trace
+ self._signal_lines: List[mpl.lines.Line2D] = []
+ self._envelope_fills: List[Optional[mpl.collections.PolyCollection]] = [
+ None
+ ] * self.data.num_traces
+
+ # Visualization elements with mode control (definitions, not plot objects)
+ self._lines: List[List[Dict[str, Any]]] = [
+ [] for _ in range(self.data.num_traces)
+ ]
+ self._ribbons: List[List[Dict[str, Any]]] = [
+ [] for _ in range(self.data.num_traces)
+ ]
+ self._regions: List[List[Dict[str, Any]]] = [
+ [] for _ in range(self.data.num_traces)
+ ]
+ self._envelopes: List[List[Dict[str, Any]]] = [
+ [] for _ in range(self.data.num_traces)
+ ]
+
+        # Plot objects for custom elements, stored per trace as
+        # (artist, definition) tuples and populated during rendering
+        self._line_objects: List[List[Tuple[mpl.artist.Artist, Dict[str, Any]]]] = [
+            [] for _ in range(self.data.num_traces)
+        ]
+        self._ribbon_objects: List[List[Tuple[mpl.artist.Artist, Dict[str, Any]]]] = [
+            [] for _ in range(self.data.num_traces)
+        ]
+        self._region_objects: List[List[Tuple[mpl.artist.Artist, Dict[str, Any]]]] = [
+            [] for _ in range(self.data.num_traces)
+        ]
+
+ # Store current plot data for access by other methods
+ self._current_plot_data = {}
+
+ # Initialize coordinate manager
+ self.coord_manager = CoordinateManager(self.state)
+
+ # Store initial view for home button (using global time range)
+ t_start, t_end = self.data.get_global_time_range()
+ self._initial_xlim_raw = (t_start, t_end)
+
+ # Legend state for optimization
+ self._current_legend_handles: List[mpl.artist.Artist] = []
+ self._current_legend_labels: List[str] = []
+ self._legend: Optional[mpl.legend.Legend] = None
+
+ # Track last mode for each trace to optimize element updates
+ self._last_mode: Dict[int, Optional[int]] = {
+ i: None for i in range(self.data.num_traces)
+ }
+
+ # Store original toolbar methods for restoration
+ self._original_home = None
+ self._original_push_current = None
+
+ def save(self, filepath: str) -> None:
+ """
+ Save the current plot to a file.
+
+ Parameters
+ ----------
+ filepath : str
+ Path to save the plot image.
+ """
+ if self.fig is None or self.ax is None:
+ raise RuntimeError("Plot has not been initialized yet.")
+ self.fig.savefig(filepath)
+ print(f"Plot saved to {filepath}")
+
+ def add_line(
+ self,
+ t: Union[np.ndarray, List[np.ndarray]],
+ data: Union[np.ndarray, List[np.ndarray]],
+ label: str = "Line",
+ color: Optional[str] = None,
+ alpha: float = 0.75,
+ linestyle: str = "-",
+ linewidth: float = 1.0,
+ display_mode: int = MODE_BOTH,
+ trace_idx: int = 0,
+ zorder: int = 5,
+ ) -> None:
+ """
+ Add a line to the plot with mode control.
+
+ Parameters
+ ----------
+ t : Union[np.ndarray, List[np.ndarray]]
+ Time array(s) for the line data. Must match the length of data.
+ data : Union[np.ndarray, List[np.ndarray]]
+ Line data array(s). Can be a single array or a list of arrays.
+ label : str, default="Line"
+ Label for the legend.
+ color : Optional[str], default=None
+ Color for the line. If None, the trace color will be used.
+ alpha : float, default=0.75
+ Alpha (transparency) for the line.
+ linestyle : str, default="-"
+ Line style.
+ linewidth : float, default=1.0
+ Line width.
+ display_mode : int, default=MODE_BOTH
+ Which mode(s) to show this line in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
+ trace_idx : int, default=0
+ Index of the trace to add the line to.
+ zorder : int, default=5
+ Z-order for the line (higher values appear on top).
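+
+        Examples
+        --------
+        A sketch of overlaying a smoothed copy of the signal, visible only when
+        zoomed in. Assumes ``plot`` was constructed from arrays ``t`` and ``x``
+        and has not been rendered yet::
+
+            smoothed = np.convolve(x, np.ones(51) / 51, mode="same")
+            plot.add_line(t, smoothed, label="Smoothed", display_mode=plot.MODE_DETAIL)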
+ """
+ if trace_idx < 0 or trace_idx >= self.data.num_traces:
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
+ )
+
+        # Validate data length (the same check applies to lists and arrays)
+        if len(data) != len(t):
+            raise ValueError(
+                f"Line data length ({len(data)}) must match time array length ({len(t)})."
+            )
+
+ # Use trace color if none provided
+ if color is None:
+ color = self.data.get_trace_color(trace_idx)
+
+ # Convert inputs to numpy arrays
+ t_array = np.asarray(t, dtype=np.float32)
+ data_array = np.asarray(data, dtype=np.float32)
+
+        # Assign an ID for this custom line for pre-decimation caching.
+        # Negative IDs distinguish custom lines from main traces (which use 0, 1, 2...).
+        # Note that IDs are only unique within a trace: custom lines on different
+        # traces can share an ID in the decimator's pre-decimation cache.
+        line_id = -(len(self._lines[trace_idx]) + 1)
+
+ # Pre-decimate this custom line's data for envelope view
+ self.decimator.pre_decimate_data(
+ data_id=line_id,
+ t=t_array,
+ x=data_array,
+ max_points=self.max_plot_points,
+ envelope_window_samples=None, # Envelope window calculated automatically
+ )
+
+ # Store line definition with raw data and its assigned ID
+ line_def = {
+ "id": line_id, # Store the ID for retrieval from decimator
+ "t_raw": t_array, # Store raw time array
+ "data_raw": data_array, # Store raw data array
+ "label": label,
+ "color": color,
+ "alpha": alpha,
+ "linestyle": linestyle,
+ "linewidth": linewidth,
+ "display_mode": display_mode,
+ "zorder": zorder,
+ }
+
+ self._lines[trace_idx].append(line_def)
+
+ def add_ribbon(
+ self,
+ t: Union[np.ndarray, List[np.ndarray]],
+ center_data: Union[np.ndarray, List[np.ndarray]],
+ width: Union[float, np.ndarray],
+ label: str = "Ribbon",
+ color: str = "gray",
+ alpha: float = 0.6,
+ display_mode: int = MODE_DETAIL,
+ trace_idx: int = 0,
+ zorder: int = 2,
+ ) -> None:
+ """
+ Add a ribbon (center ± width) with mode control.
+
+ Parameters
+ ----------
+ t : Union[np.ndarray, List[np.ndarray]]
+ Time array(s) for the ribbon data. Must match the length of center_data.
+ center_data : Union[np.ndarray, List[np.ndarray]]
+ Center line data array(s). Can be a single array or a list of arrays.
+ width : Union[float, np.ndarray]
+ Width of the ribbon. Can be a single value or an array matching center_data.
+ label : str, default="Ribbon"
+ Label for the legend.
+ color : str, default="gray"
+ Color for the ribbon.
+ alpha : float, default=0.6
+ Alpha (transparency) for the ribbon.
+ display_mode : int, default=MODE_DETAIL
+ Which mode(s) to show this ribbon in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
+        trace_idx : int, default=0
+            Index of the trace to add the ribbon to.
+        zorder : int, default=2
+            Z-order for the ribbon fill (higher values appear on top).
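+
+        Examples
+        --------
+        A sketch of drawing a +/-1 sigma band around a fitted curve, where ``fit``
+        and ``sigma`` are illustrative arrays with one value per time sample::
+
+            plot.add_ribbon(t, fit, width=sigma, label="Fit +/- 1 sigma")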
+ """
+ if trace_idx < 0 or trace_idx >= self.data.num_traces:
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
+ )
+
+        # Validate data length (the same check applies to lists and arrays)
+        if len(center_data) != len(t):
+            raise ValueError(
+                f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})."
+            )
+
+ # Convert center data to numpy array
+ center_data = np.asarray(center_data, dtype=np.float32)
+
+ # Handle width as scalar or array
+ if isinstance(width, (int, float, np.number)):
+ width_array = np.ones_like(center_data) * width
+ else:
+ if len(width) != len(center_data):
+ raise ValueError(
+ f"Ribbon width array length ({len(width)}) must match center data length ({len(center_data)})."
+ )
+ width_array = np.asarray(width, dtype=np.float32)
+
+ # Assign a unique ID for this custom ribbon
+ ribbon_id = -(
+ len(self._ribbons[trace_idx]) + 1001
+ ) # Negative, unique per trace, offset from lines
+
+ # Pre-decimate this custom ribbon's center data for envelope view
+ # We only pre-decimate the center, as width is applied later
+ self.decimator.pre_decimate_data(
+ data_id=ribbon_id,
+ t=np.asarray(t, dtype=np.float32),
+ x=center_data,
+ max_points=self.max_plot_points,
+ envelope_window_samples=None, # Envelope window calculated automatically
+ )
+
+ # Store ribbon definition
+ ribbon_def = {
+ "id": ribbon_id,
+ "t_raw": np.asarray(t, dtype=np.float32),
+ "center_data_raw": center_data,
+ "width_raw": width_array,
+ "label": label,
+ "color": color,
+ "alpha": alpha,
+ "display_mode": display_mode,
+ "zorder": zorder,
+ }
+
+ self._ribbons[trace_idx].append(ribbon_def)
+
+ def add_envelope(
+ self,
+ min_data: Union[np.ndarray, List[np.ndarray]],
+ max_data: Union[np.ndarray, List[np.ndarray]],
+ label: str = "Envelope",
+ color: Optional[str] = None,
+ alpha: float = 0.4,
+ display_mode: int = MODE_ENVELOPE,
+ trace_idx: int = 0,
+ zorder: int = 1,
+ ) -> None:
+ """
+ Add envelope data with mode control.
+
+ Parameters
+ ----------
+ min_data : Union[np.ndarray, List[np.ndarray]]
+ Minimum envelope data array(s). Can be a single array or a list of arrays.
+ max_data : Union[np.ndarray, List[np.ndarray]]
+ Maximum envelope data array(s). Can be a single array or a list of arrays.
+ label : str, default="Envelope"
+ Label for the legend.
+ color : Optional[str], default=None
+ Color for the envelope. If None, the trace color will be used.
+ alpha : float, default=0.4
+ Alpha (transparency) for the envelope.
+ display_mode : int, default=MODE_ENVELOPE
+ Which mode(s) to show this envelope in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
+        trace_idx : int, default=0
+            Index of the trace to add the envelope to.
+        zorder : int, default=1
+            Z-order for the envelope fill (higher values appear on top).
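+
+        Examples
+        --------
+        A sketch of adding a precomputed min/max band that is shown while zoomed
+        out, where ``x_lo`` and ``x_hi`` are illustrative arrays with one value
+        per sample of the trace::
+
+            plot.add_envelope(x_lo, x_hi, label="Expected band", alpha=0.3)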
+ """
+ if trace_idx < 0 or trace_idx >= self.data.num_traces:
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
+ )
+
+        # Validate data lengths against this trace's time array
+        t_len = len(self.data.t_arrays[trace_idx])
+        if len(min_data) != t_len:
+            raise ValueError(
+                f"Envelope min data length ({len(min_data)}) must match time array length ({t_len})."
+            )
+        if len(max_data) != t_len:
+            raise ValueError(
+                f"Envelope max data length ({len(max_data)}) must match time array length ({t_len})."
+            )
+
+ # Use trace color if none provided
+ if color is None:
+ color = self.data.get_trace_color(trace_idx)
+
+ # Assign a unique ID for this custom envelope
+ envelope_id = -(
+ len(self._envelopes[trace_idx]) + 2001
+ ) # Negative, unique per trace, offset from ribbons
+
+ # Pre-decimate this custom envelope's data for envelope view
+ # We'll pre-decimate the average of min/max, and store min/max separately
+ t_raw = self.data.t_arrays[trace_idx]
+ avg_data = (
+ np.asarray(min_data, dtype=np.float32)
+ + np.asarray(max_data, dtype=np.float32)
+ ) / 2
+
+ self.decimator.pre_decimate_data(
+ data_id=envelope_id,
+ t=t_raw,
+ x=avg_data, # Pass average for decimation
+ max_points=self.max_plot_points,
+ envelope_window_samples=None, # Envelope window calculated automatically
+ )
+
+ # Store envelope definition
+ envelope_def = {
+ "id": envelope_id,
+ "t_raw": t_raw,
+ "min_data_raw": np.asarray(min_data, dtype=np.float32),
+ "max_data_raw": np.asarray(max_data, dtype=np.float32),
+ "label": label,
+ "color": color,
+ "alpha": alpha,
+ "display_mode": display_mode,
+ "zorder": zorder,
+ }
+
+ self._envelopes[trace_idx].append(envelope_def)
+
+ def add_regions(
+ self,
+ regions: np.ndarray,
+ label: str = "Regions",
+ color: str = "crimson",
+ alpha: float = 0.4,
+ display_mode: int = MODE_BOTH,
+ trace_idx: int = 0,
+ zorder: int = -5,
+ ) -> None:
+ """
+ Add region highlights with mode control.
+
+ Parameters
+ ----------
+ regions : np.ndarray
+ Region data array with shape (N, 2) where each row is [start_time, end_time].
+ label : str, default="Regions"
+ Label for the legend.
+ color : str, default="crimson"
+ Color for the regions.
+ alpha : float, default=0.4
+ Alpha (transparency) for the regions.
+ display_mode : int, default=MODE_BOTH
+ Which mode(s) to show these regions in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH).
+ trace_idx : int, default=0
+ Index of the trace to add the regions to.
+ zorder : int, default=-5
+ Z-order of the region fills.
+ """
+ if trace_idx < 0 or trace_idx >= self.data.num_traces:
+ raise ValueError(
+ f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}."
+ )
+
+ # Validate regions array
+ if regions.ndim != 2 or regions.shape[1] != 2:
+ raise ValueError(
+ f"Regions array must have shape (N, 2), got {regions.shape}."
+ )
+
+ # Store regions definition
+ region_def = {
+ "regions": np.asarray(regions, dtype=np.float32),
+ "label": label,
+ "color": color,
+ "alpha": alpha,
+ "display_mode": display_mode,
+ "zorder": zorder,
+ }
+
+ self._regions[trace_idx].append(region_def)
+
+ def _update_signal_display(
+ self,
+ trace_idx: int,
+ t_display: np.ndarray,
+ x_data: np.ndarray,
+ envelope_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
+ ) -> None:
+ """
+ Update signal display with envelope or raw data for a specific trace.
+
+ Parameters
+ ----------
+ trace_idx : int
+ Index of the trace to update.
+ t_display : np.ndarray
+ Display time array.
+ x_data : np.ndarray
+ Signal data array.
+ envelope_data : Optional[Tuple[np.ndarray, np.ndarray]], default=None
+ Tuple of (min, max) envelope data if in envelope mode.
+ """
+ if envelope_data is not None:
+ self._show_envelope_mode(trace_idx, t_display, envelope_data)
+ else:
+ self._show_detail_mode(trace_idx, t_display, x_data)
+
+ def _show_envelope_mode(
+ self,
+ trace_idx: int,
+ t_display: np.ndarray,
+ envelope_data: Tuple[np.ndarray, np.ndarray],
+ ) -> None:
+ """
+ Show envelope display mode for a specific trace.
+
+ Parameters
+ ----------
+ trace_idx : int
+ Index of the trace to update.
+ t_display : np.ndarray
+ Display time array.
+ envelope_data : Tuple[np.ndarray, np.ndarray]
+ Tuple of (min, max) envelope data.
+ """
+ x_min, x_max = envelope_data
+ color = self.data.get_trace_color(trace_idx)
+ name = self.data.get_trace_name(trace_idx)
+
+ # Clean up previous displays
+ if self._envelope_fills[trace_idx] is not None:
+ self._envelope_fills[trace_idx].remove()
+
+ self._signal_lines[trace_idx].set_data([], [])
+ self._signal_lines[trace_idx].set_visible(False)
+
+ # Show built-in envelope
+ self._envelope_fills[trace_idx] = self.ax.fill_between(
+ t_display,
+ x_min,
+ x_max,
+ alpha=self.envelope_alpha,
+ color=color,
+ lw=0.1,
+ label=f"Raw envelope ({name})"
+ if self.data.num_traces > 1
+ else "Raw envelope",
+ zorder=1, # Keep default envelope at zorder=1
+ )
+
+ # Set current mode
+ self.state.current_mode = "envelope"
+
+ # Show any custom elements for this mode
+ self._show_custom_elements(trace_idx, t_display, self.MODE_ENVELOPE)
+
+ def _show_detail_mode(
+ self, trace_idx: int, t_display: np.ndarray, x_data: np.ndarray
+ ) -> None:
+ """
+ Show detail display mode for a specific trace.
+
+ Parameters
+ ----------
+ trace_idx : int
+ Index of the trace to update.
+ t_display : np.ndarray
+ Display time array.
+ x_data : np.ndarray
+ Signal data array.
+ """
+ # Clean up envelope
+ if self._envelope_fills[trace_idx] is not None:
+ self._envelope_fills[trace_idx].remove()
+ self._envelope_fills[trace_idx] = None
+
+ # Update signal line
+ line = self._signal_lines[trace_idx]
+ line.set_data(t_display, x_data)
+ line.set_linewidth(self.signal_line_width)
+ line.set_alpha(self.signal_alpha)
+ line.set_visible(True)
+
+ # Set current mode
+ self.state.current_mode = "detail"
+
+ # Show any custom elements for this mode
+ self._show_custom_elements(trace_idx, t_display, self.MODE_DETAIL)
+
+ def _show_custom_elements(
+ self, trace_idx: int, t_display: np.ndarray, current_mode: int
+ ) -> None:
+ """
+ Show custom visualization elements for the current mode.
+
+ Parameters
+ ----------
+ trace_idx : int
+ Index of the trace to update.
+ t_display : np.ndarray
+ Display time array.
+ current_mode : int
+ Current display mode (MODE_ENVELOPE or MODE_DETAIL).
+ """
+ # Always clear and recreate elements when the view changes, regardless of
+ # whether the mode changed; this ensures custom lines/ribbons are redrawn
+ # with data decimated for the current view. _last_mode is still recorded
+ # at the end of this method.
+ self._clear_custom_elements(trace_idx)
+
+ # Get current raw x-limits from the main plot data
+ # This is crucial for decimating custom lines to the current view
+ current_xlim_raw = self.coord_manager.get_current_view_raw(self.ax)
+
+ # Show lines for current mode
+ line_objects = []
+ for i, line_def in enumerate(self._lines[trace_idx]):
+ if (
+ line_def["display_mode"] & current_mode
+ ): # Bitwise check if mode is enabled
+
+ # Dynamically decimate the line data for the current view
+ # Use the same max_plot_points as the main signal for consistency
+ # For custom lines, we want mean decimation if in envelope mode, not min/max envelope
+ t_line_raw, line_data, _, _ = self.decimator.decimate_for_view(
+ line_def["t_raw"],
+ line_def["data_raw"],
+ current_xlim_raw, # Decimate to current view
+ self.max_plot_points,
+ use_envelope=(current_mode == self.MODE_ENVELOPE),
+ data_id=line_def[
+ "id"
+ ], # Pass the custom line's ID for pre-decimated data lookup
+ envelope_window_samples=None, # Envelope window calculated automatically
+ mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold
+ return_envelope_min_max=False, # Custom lines never return min/max envelope
+ )
+
+ if len(t_line_raw) == 0 or len(line_data) == 0:
+ warnings.warn(
+ f"Line {i} ('{line_def['label']}') has empty data after decimation for current view, skipping plot.", UserWarning
+ )
+ continue
+
+ # Make sure the time array is in display coordinates
+ t_line_display = self.coord_manager.raw_to_display(t_line_raw)
+
+ # Always plot as a regular line
+ (line,) = self.ax.plot(
+ t_line_display,
+ line_data,
+ label=line_def["label"],
+ color=line_def["color"],
+ alpha=line_def["alpha"],
+ linestyle=line_def["linestyle"],
+ linewidth=line_def["linewidth"],
+ zorder=line_def["zorder"],
+ )
+ line_objects.append(
+ (line, line_def)
+ ) # Store both the line and its definition
+
+ # Show ribbons for current mode
+ ribbon_objects = []
+ for ribbon_def in self._ribbons[trace_idx]:
+ if ribbon_def["display_mode"] & current_mode:
+
+ # Ribbons are always plotted as fills, so we need to decimate their center and width
+ # We'll treat the center_data as the 'signal' for decimation purposes
+ (
+ t_ribbon_raw,
+ center_data_decimated,
+ min_center_envelope,
+ max_center_envelope,
+ ) = self.decimator.decimate_for_view(
+ ribbon_def["t_raw"],
+ ribbon_def["center_data_raw"],
+ current_xlim_raw,
+ self.max_plot_points,
+ use_envelope=(
+ current_mode == self.MODE_ENVELOPE
+ ), # Use envelope for ribbons if in envelope mode
+ data_id=ribbon_def[
+ "id"
+ ], # Pass the custom ribbon's ID for pre-decimated data lookup
+ return_envelope_min_max=True, # Ribbons always need min/max to draw fill
+ envelope_window_samples=None, # Envelope window calculated automatically
+ mode_switch_threshold=self.mode_switch_threshold,
+ )
+
+ # Decimate the width array as well, if it's an array
+ width_decimated = ribbon_def["width_raw"]
+ if len(ribbon_def["width_raw"]) > len(
+ t_ribbon_raw
+ ): # If raw width is longer than decimated time
+ # Rather than averaging widths per bin, pick the width at the raw sample
+ # aligned with each decimated time point via searchsorted. A more robust
+ # solution would pass the width as another data stream to decimate_for_view;
+ # this shortcut assumes the width varies slowly relative to the sampling.
+ indices = np.searchsorted(ribbon_def["t_raw"], t_ribbon_raw)
+ indices = np.clip(indices, 0, len(ribbon_def["width_raw"]) - 1)
+ width_decimated = ribbon_def["width_raw"][indices]
+
+ # If the ribbon was decimated to an envelope, use that for min/max
+ if (
+ current_mode == self.MODE_ENVELOPE
+ and min_center_envelope is not None
+ and max_center_envelope is not None
+ ):
+ lower_bound = min_center_envelope - width_decimated
+ upper_bound = max_center_envelope + width_decimated
+ else:
+ lower_bound = center_data_decimated - width_decimated
+ upper_bound = center_data_decimated + width_decimated
+
+ if len(t_ribbon_raw) == 0 or len(lower_bound) == 0:
+ warnings.warn(
+ f"Ribbon ('{ribbon_def['label']}') has empty data after decimation, skipping plot.", UserWarning
+ )
+ continue
+
+ # Make sure the time array is in display coordinates
+ t_ribbon_display = self.coord_manager.raw_to_display(t_ribbon_raw)
+
+ ribbon = self.ax.fill_between(
+ t_ribbon_display,
+ lower_bound,
+ upper_bound,
+ color=ribbon_def["color"],
+ alpha=ribbon_def["alpha"],
+ label=ribbon_def["label"],
+ zorder=ribbon_def["zorder"],
+ )
+ ribbon_objects.append(
+ (ribbon, ribbon_def)
+ ) # Store both the ribbon and its definition
+
+ # Show custom envelopes for current mode
+ for envelope_def in self._envelopes[trace_idx]:
+ if envelope_def["display_mode"] & current_mode:
+
+ # Custom envelopes are handled via their pre-decimated data: the average
+ # of min/max was registered under this envelope's ID in add_envelope, and
+ # decimate_for_view returns the min/max bounds for the current view.
+
+ # Use the pre-decimated data if it is available for this custom envelope
+ if envelope_def["id"] in self.decimator._pre_decimated_envelopes:
+ t_envelope_raw, _, min_data_decimated, max_data_decimated = (
+ self.decimator.decimate_for_view(
+ envelope_def["t_raw"],
+ (
+ envelope_def["min_data_raw"]
+ + envelope_def["max_data_raw"]
+ )
+ / 2, # Average for decimation
+ current_xlim_raw,
+ self.max_plot_points,
+ use_envelope=True, # Always treat custom envelopes as envelopes
+ data_id=envelope_def[
+ "id"
+ ], # Pass the custom envelope's ID for pre-decimated data lookup
+ return_envelope_min_max=True, # Custom envelopes always need min/max to draw fill
+ envelope_window_samples=None, # Envelope window calculated automatically
+ mode_switch_threshold=self.mode_switch_threshold,
+ )
+ )
+ # For custom envelopes, the min/max are returned directly as the last two return values
+ else:
+ # Fallback if no pre-decimated data
+ warnings.warn(
+ f"No pre-decimated data for custom envelope {envelope_def['id']}, using raw decimation", UserWarning
+ )
+ t_envelope_raw, _, min_data_decimated, max_data_decimated = (
+ self.decimator.decimate_for_view(
+ envelope_def["t_raw"],
+ (
+ envelope_def["min_data_raw"]
+ + envelope_def["max_data_raw"]
+ )
+ / 2,
+ current_xlim_raw,
+ self.max_plot_points,
+ use_envelope=True,
+ data_id=None, # No pre-decimated data available
+ return_envelope_min_max=True,
+ envelope_window_samples=None, # Envelope window calculated automatically
+ mode_switch_threshold=self.mode_switch_threshold,
+ )
+ )
+
+ if (
+ len(t_envelope_raw) == 0
+ or min_data_decimated is None
+ or max_data_decimated is None
+ or len(min_data_decimated) == 0
+ ):
+ warnings.warn(
+ f"Custom envelope ('{envelope_def['label']}') has empty data after decimation, skipping plot.", UserWarning
+ )
+ continue
+
+ t_envelope_display = self.coord_manager.raw_to_display(t_envelope_raw)
+
+ envelope = self.ax.fill_between(
+ t_envelope_display,
+ min_data_decimated,
+ max_data_decimated,
+ color=envelope_def["color"],
+ alpha=envelope_def["alpha"],
+ label=envelope_def["label"],
+ zorder=envelope_def["zorder"],
+ )
+ ribbon_objects.append(
+ (envelope, envelope_def)
+ ) # Store in ribbon objects
+
+ # Store objects with their definitions for future updates
+ self._line_objects[trace_idx] = line_objects
+ self._ribbon_objects[trace_idx] = ribbon_objects
+
+ # Update last mode AFTER processing
+ self._last_mode[trace_idx] = current_mode
+
+ def _update_element_visibility(self, trace_idx: int, current_mode: int) -> None:
+ """
+ Update visibility of existing custom elements based on current mode.
+
+ Parameters
+ ----------
+ trace_idx : int
+ Index of the trace to update.
+ current_mode : int
+ Current display mode (MODE_ENVELOPE or MODE_DETAIL).
+ """
+ # Update line visibility
+ for line_obj, line_def in self._line_objects[trace_idx]:
+ should_be_visible = bool(line_def["display_mode"] & current_mode)
+ if line_obj.get_visible() != should_be_visible:
+ line_obj.set_visible(should_be_visible)
+
+ # Update ribbon visibility
+ for ribbon_obj, ribbon_def in self._ribbon_objects[trace_idx]:
+ should_be_visible = bool(ribbon_def["display_mode"] & current_mode)
+ if ribbon_obj.get_visible() != should_be_visible:
+ ribbon_obj.set_visible(should_be_visible)
+
+ def _clear_custom_elements(self, trace_idx: int) -> None:
+ """
+ Clear all custom visualization elements for a trace.
+
+ Parameters
+ ----------
+ trace_idx : int
+ Index of the trace to clear elements for.
+ """
+ # Clear lines
+ for line_obj, _ in self._line_objects[trace_idx]:
+ line_obj.remove()
+ self._line_objects[trace_idx].clear()
+
+ # Clear ribbons
+ for ribbon_obj, _ in self._ribbon_objects[trace_idx]:
+ ribbon_obj.remove()
+ self._ribbon_objects[trace_idx].clear()
+
+ def _update_tick_locator(self, time_span_raw: np.float32) -> None:
+ """Update tick locator based on current time scale and span."""
+ if self.state.current_time_scale >= np.float32(1e6): # microseconds or smaller
+ # For microsecond scale, use reasonable intervals
+ tick_interval = max(
+ 1, int(time_span_raw * self.state.current_time_scale / 10)
+ )
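+ # Example: a 50 us visible span at a 1e6 scale covers 50 display units,
+ # so tick_interval = max(1, int(50 / 10)) = 5 display units per major tick.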
+ self.ax.xaxis.set_major_locator(MultipleLocator(tick_interval))
+ else:
+ # For larger scales, use matplotlib's default auto locator
+ self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
+
+ def _update_legend(self) -> None:
+ """Updates the plot legend, filtering out invisible elements and optimising rebuilds."""
+ handles, labels = self.ax.get_legend_handles_labels()
+
+ # Filter for unique and visible handles/labels
+ unique_labels = []
+ unique_handles = []
+ for h, l in zip(handles, labels):
+ # Every matplotlib artist, including the PolyCollection objects returned
+ # by fill_between/axvspan, implements get_visible(), so a single check
+ # covers lines, ribbons, envelopes, and regions alike.
+ is_visible = h.get_visible() if hasattr(h, "get_visible") else True
+
+ if l not in unique_labels and is_visible:
+ unique_labels.append(l)
+ unique_handles.append(h)
+
+ # Create a hash of current handles/labels for efficient comparison
+ current_hash = hash(tuple(id(h) for h in unique_handles) + tuple(unique_labels))
+
+ # Check if legend content actually changed
+ if (
+ not hasattr(self, "_last_legend_hash")
+ or self._last_legend_hash != current_hash
+ ):
+ if self._legend is not None:
+ self._legend.remove() # Remove old legend to prevent duplicates
+
+ if unique_handles: # Only create legend if there are handles to show
+ self._legend = self.ax.legend(
+ unique_handles, unique_labels, loc="lower right"
+ )
+ else:
+ self._legend = None # No legend to show
+
+ self._current_legend_handles = unique_handles
+ self._current_legend_labels = unique_labels
+ self._last_legend_hash = current_hash
+
+ def _clear_navigation_history(self):
+ """Clear matplotlib's navigation history when coordinate system changes."""
+ if (
+ self.fig
+ and self.fig.canvas
+ and hasattr(self.fig.canvas, "toolbar")
+ and self.fig.canvas.toolbar
+ ):
+ toolbar = self.fig.canvas.toolbar
+ if hasattr(toolbar, "_nav_stack"):
+ toolbar._nav_stack.clear()
+
+ def _push_current_view(self):
+ """Push current view to navigation history as new base."""
+ if (
+ self.fig
+ and self.fig.canvas
+ and hasattr(self.fig.canvas, "toolbar")
+ and self.fig.canvas.toolbar
+ ):
+ toolbar = self.fig.canvas.toolbar
+ if hasattr(toolbar, "push_current"):
+ toolbar.push_current()
+
+ def _update_axis_formatting(self) -> None:
+ """Update axis labels and formatters."""
+ if self.state.offset_time_raw is not None:
+ unit_scale = {"ms": 1e3, "us": 1e6, "ns": 1e9}.get(
+ self.state.offset_unit, 1.0
+ )
+ offset_value = self.state.offset_time_raw * unit_scale
+ xlabel = f"Time ({self.state.current_time_unit}) + {offset_value:.3g} {self.state.offset_unit}"
+ else:
+ xlabel = f"Time ({self.state.current_time_unit})"
+
+ self.ax.set_xlabel(xlabel)
+
+ formatter = _create_time_formatter(
+ self.state.offset_time_raw, self.state.current_time_scale
+ )
+ self.ax.xaxis.set_major_formatter(formatter)
+
+ def _update_overlay_lines(
+ self, plot_data: Dict[str, Any], show_overlays: bool
+ ) -> None:
+ """Update overlay lines based on zoom level and data availability."""
+ # Clear existing overlay lines from the plot
+ # This method is not currently used in the provided code, but if it were,
+ # it would need to be updated to use the new decimation strategy.
+ # For now, leaving it as is, assuming it's a placeholder or for future_use.
+ # If it were to be used, it would need to call decimate_for_view for each overlay line.
+ pass # No _overlay_lines attribute in this class, this method is unused.
+
+ def _update_y_limits(self, plot_data: Dict[str, Any], use_envelope: bool) -> None:
+ """Update y-axis limits to fit current data."""
+ y_min_data = float("inf")
+ y_max_data = float("-inf")
+
+ # Process each trace
+ for trace_idx in range(self.data.num_traces):
+ x_new_key = f"x_new_{trace_idx}"
+ x_min_key = f"x_min_{trace_idx}"
+ x_max_key = f"x_max_{trace_idx}"
+
+ if x_new_key not in plot_data:
+ continue
+
+ # Include signal data
+ if len(plot_data[x_new_key]) > 0:
+ y_min_data = min(y_min_data, np.min(plot_data[x_new_key]))
+ y_max_data = max(y_max_data, np.max(plot_data[x_new_key]))
+
+ # Include envelope data if available
+ if use_envelope and x_min_key in plot_data and x_max_key in plot_data:
+ if (
+ plot_data[x_min_key] is not None
+ and plot_data[x_max_key] is not None
+ and len(plot_data[x_min_key]) > 0
+ ):
+ y_min_data = min(y_min_data, np.min(plot_data[x_min_key]))
+ y_max_data = max(y_max_data, np.max(plot_data[x_max_key]))
+
+ # Include custom lines
+ for line_obj, _ in self._line_objects[trace_idx]:
+ # Check if line_obj is a Line2D or PolyCollection
+ if isinstance(line_obj, mpl.lines.Line2D):
+ y_data = line_obj.get_ydata()
+ if len(y_data) > 0:
+ y_min_data = min(y_min_data, np.min(y_data))
+ y_max_data = max(y_max_data, np.max(y_data))
+ elif isinstance(line_obj, mpl.collections.PolyCollection):
+ # For fill_between objects, iterate through paths to get y-coordinates
+ for path in line_obj.get_paths():
+ vertices = path.vertices
+ if len(vertices) > 0:
+ y_min_data = min(y_min_data, np.min(vertices[:, 1]))
+ y_max_data = max(y_max_data, np.max(vertices[:, 1]))
+
+ # Include ribbon data
+ for ribbon_obj, _ in self._ribbon_objects[trace_idx]:
+ # For fill_between objects, we need to get the paths
+ if hasattr(ribbon_obj, "get_paths") and len(ribbon_obj.get_paths()) > 0:
+ for path in ribbon_obj.get_paths():
+ vertices = path.vertices
+ if len(vertices) > 0:
+ y_min_data = min(y_min_data, np.min(vertices[:, 1]))
+ y_max_data = max(y_max_data, np.max(vertices[:, 1]))
+
+ # Handle case where no data was found
+ if y_min_data == float("inf") or y_max_data == float("-inf"):
+ self.ax.set_ylim(0, 1)
+ return
+
+ data_range = y_max_data - y_min_data
+ data_mean = (y_min_data + y_max_data) / 2
+
+ # Use min_y_range to ensure a minimum visible range
+ min_visible_range = self.min_y_range
+
+ if data_range < min_visible_range:
+ y_min = data_mean - min_visible_range / 2
+ y_max = data_mean + min_visible_range / 2
+ else:
+ y_margin = self.y_margin_fraction * data_range
+ y_min = y_min_data - y_margin
+ y_max = y_max_data + y_margin
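+ # e.g. data spanning [0.2, 0.8] with y_margin_fraction = 0.05 gives a
+ # 0.03 margin and limits of (0.17, 0.83), provided min_y_range < 0.6.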
+
+ self.ax.set_ylim(y_min, y_max)
+
+ def _update_plot_data(self, ax_obj) -> None:
+ """Update plot based on current view."""
+ if self.state.is_updating():
+ return
+
+ self.state.set_updating(True)
+
+ try:
+ view_params = self._calculate_view_parameters(ax_obj)
+ plot_data = self._get_plot_data(view_params)
+ self._render_plot_elements(plot_data, view_params)
+ self._update_regions_and_legend(view_params["xlim_display"])
+ self.fig.canvas.draw_idle()
+ except Exception as e:
+ warnings.warn(f"Error updating plot: {e}", RuntimeWarning)
+ # Try to recover by resetting to the home view
+ print("Attempting to recover by resetting to home view")
+ self.home()
+ finally:
+ self.state.set_updating(False)
+
+ def _calculate_view_parameters(self, ax_obj) -> Dict[str, Any]:
+ """Calculate view parameters from current axis state."""
+ try:
+ xlim_raw = self.coord_manager.get_current_view_raw(ax_obj)
+
+ # Validate xlim_raw values
+ if not np.isfinite(xlim_raw[0]) or not np.isfinite(xlim_raw[1]):
+ warnings.warn(
+ f"Invalid xlim_raw from axis: {xlim_raw}. Using initial view.", RuntimeWarning
+ )
+ xlim_raw = self._initial_xlim_raw
+
+ # Ensure xlim_raw is in ascending order
+ if xlim_raw[0] > xlim_raw[1]:
+ warnings.warn(f"xlim_raw values out of order: {xlim_raw}. Swapping.", RuntimeWarning)
+ xlim_raw = (xlim_raw[1], xlim_raw[0])
+
+ time_span_raw = xlim_raw[1] - xlim_raw[0]
+ use_envelope = self.state.should_use_envelope(time_span_raw)
+ current_mode = self.MODE_ENVELOPE if use_envelope else self.MODE_DETAIL
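+ # e.g. with mode_switch_threshold = 1 ms, a 5 ms view would stay in
+ # envelope mode while a 200 us view switches to detail mode (assuming
+ # should_use_envelope compares the span against that threshold).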
+
+ # Update coordinate system if needed
+ coordinate_system_changed = self.state.update_display_params(
+ xlim_raw, time_span_raw
+ )
+ if coordinate_system_changed:
+ self._update_coordinate_system(xlim_raw, time_span_raw)
+
+ return {
+ "xlim_raw": xlim_raw,
+ "time_span_raw": time_span_raw,
+ "xlim_display": self.coord_manager.xlim_raw_to_display(xlim_raw),
+ "use_envelope": use_envelope,
+ "current_mode": current_mode,
+ }
+ except Exception as e:
+ warnings.warn(f"Error calculating view parameters: {e}", RuntimeWarning)
+ # Return safe default values
+ return {
+ "xlim_raw": self._initial_xlim_raw,
+ "time_span_raw": self._initial_xlim_raw[1] - self._initial_xlim_raw[0],
+ "xlim_display": self.coord_manager.xlim_raw_to_display(
+ self._initial_xlim_raw
+ ),
+ "use_envelope": True,
+ "current_mode": self.MODE_ENVELOPE,
+ }
+
+ def _get_plot_data(self, view_params: Dict[str, Any]) -> Dict[str, Any]:
+ """Get decimated plot data for current view."""
+ plot_data = {}
+
+ # Process each trace
+ for trace_idx in range(self.data.num_traces):
+ t_arr = self.data.t_arrays[trace_idx]
+ x_arr = self.data.x_arrays[trace_idx]
+
+ try:
+ t_raw, x_new, x_min, x_max = self.decimator.decimate_for_view(
+ t_arr,
+ x_arr,
+ view_params["xlim_raw"],
+ self.max_plot_points,
+ view_params["use_envelope"],
+ trace_idx, # Pass trace_id to use pre-decimated data
+ envelope_window_samples=None, # Envelope window calculated automatically
+ mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold
+ return_envelope_min_max=True, # Main signal always returns envelope min/max if use_envelope is True
+ )
+
+ if len(t_raw) == 0:
+ warnings.warn(
+ f"No data in current view for trace {trace_idx}. View range: {view_params['xlim_raw']}", UserWarning
+ )
+ # Add empty arrays for this trace
+ plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32)
+ plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32)
+ plot_data[f"x_min_{trace_idx}"] = None
+ plot_data[f"x_max_{trace_idx}"] = None
+ continue
+
+ t_display = self.coord_manager.raw_to_display(t_raw)
+
+ # Store data for this trace
+ plot_data[f"t_display_{trace_idx}"] = t_display
+ plot_data[f"x_new_{trace_idx}"] = x_new
+ plot_data[f"x_min_{trace_idx}"] = x_min
+ plot_data[f"x_max_{trace_idx}"] = x_max
+
+ except Exception as e:
+ warnings.warn(f"Error getting plot data for trace {trace_idx}: {e}", RuntimeWarning)
+ # Add empty arrays for this trace to prevent further errors
+ plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32)
+ plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32)
+ plot_data[f"x_min_{trace_idx}"] = None
+ plot_data[f"x_max_{trace_idx}"] = None
+
+ return plot_data
+
+ def _render_plot_elements(
+ self, plot_data: Dict[str, Any], view_params: Dict[str, Any]
+ ) -> None:
+ """Render all plot elements with current data."""
+ # Store the current plot data for use by other methods
+ self._current_plot_data = plot_data
+
+ # Check if we have any data to plot
+ has_data = False
+ for trace_idx in range(self.data.num_traces):
+ key = f"t_display_{trace_idx}"
+ if key in plot_data and len(plot_data[key]) > 0:
+ has_data = True
+ break
+
+ if not has_data:
+ warnings.warn("No data to plot, clearing all elements", UserWarning)
+ # If no data, clear all lines and return
+ for i in range(self.data.num_traces):
+ self._signal_lines[i].set_data([], [])
+ if self._envelope_fills[i] is not None:
+ self._envelope_fills[i].remove()
+ self._envelope_fills[i] = None
+
+ # Clear custom elements
+ self._clear_custom_elements(i)
+
+ self.ax.set_ylim(0, 1) # Set a default y-limit
+ return
+
+ # Process each trace
+ for trace_idx in range(self.data.num_traces):
+ t_display_key = f"t_display_{trace_idx}"
+ x_new_key = f"x_new_{trace_idx}"
+ x_min_key = f"x_min_{trace_idx}"
+ x_max_key = f"x_max_{trace_idx}"
+
+ if t_display_key not in plot_data or len(plot_data[t_display_key]) == 0:
+ # No data for this trace, hide its elements
+ self._signal_lines[trace_idx].set_data([], [])
+ if self._envelope_fills[trace_idx] is not None:
+ self._envelope_fills[trace_idx].remove()
+ self._envelope_fills[trace_idx] = None
+
+ # Clear custom elements
+ self._clear_custom_elements(trace_idx)
+ continue
+
+ # Update signal display
+ envelope_data = None
+ if (
+ view_params["use_envelope"]
+ and x_min_key in plot_data
+ and x_max_key in plot_data
+ ):
+ if (
+ plot_data[x_min_key] is not None
+ and plot_data[x_max_key] is not None
+ ):
+ envelope_data = (plot_data[x_min_key], plot_data[x_max_key])
+
+ self._update_signal_display(
+ trace_idx, plot_data[t_display_key], plot_data[x_new_key], envelope_data
+ )
+
+ # Update y-limits
+ self._update_y_limits(plot_data, view_params["use_envelope"])
+
+ def _update_coordinate_system(
+ self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32
+ ) -> None:
+ """Update coordinate system and axis formatting."""
+ self._clear_region_fills()
+ self._update_axis_formatting()
+ self._update_tick_locator(time_span_raw)
+
+ xlim_display = self.coord_manager.xlim_raw_to_display(xlim_raw)
+ self.ax.set_xlim(xlim_display)
+
+ self._clear_navigation_history()
+ self._push_current_view()
+
+ def _update_regions_and_legend(
+ self, xlim_display: Tuple[np.float32, np.float32]
+ ) -> None:
+ """Update regions and legend."""
+ self._refresh_region_display(xlim_display)
+ self._update_legend()
+
+ def _refresh_region_display(
+ self, xlim_display: Tuple[np.float32, np.float32]
+ ) -> None:
+ """Refresh region display for current view."""
+ self._clear_region_fills()
+
+ # Get current mode
+ current_mode = (
+ self.MODE_ENVELOPE
+ if self.state.current_mode == "envelope"
+ else self.MODE_DETAIL
+ )
+
+ for trace_idx in range(self.data.num_traces):
+ # Process each region definition
+ for region_def in self._regions[trace_idx]:
+ # Skip if not visible in current mode
+ if not (region_def["display_mode"] & current_mode):
+ continue
+
+ regions = region_def["regions"]
+ if regions is None or len(regions) == 0:
+ continue
+
+ color = region_def["color"]
+ label = region_def["label"]
+ alpha = region_def["alpha"]
+ first_visible_region = True
+
+ for t_start, t_end in regions:
+ t_start_display = self.coord_manager.raw_to_display(t_start)
+ t_end_display = self.coord_manager.raw_to_display(t_end)
+
+ # Check if region overlaps with current view
+ if not (
+ t_end_display <= xlim_display[0]
+ or t_start_display >= xlim_display[1]
+ ):
+ # Only show label for first visible region
+ current_label = label if first_visible_region else ""
+ if first_visible_region and len(regions) > 1:
+ current_label = f"{label} ({len(regions)})"
+
+ fill = self.ax.axvspan(
+ t_start_display,
+ t_end_display,
+ alpha=alpha,
+ color=color,
+ linewidth=0.5,
+ label=current_label,
+ zorder=region_def["zorder"],
+ )
+ self._region_objects[trace_idx].append((fill, region_def))
+ first_visible_region = False
+
+ def _clear_region_fills(self) -> None:
+ """Clear all region fills."""
+ for trace_fills in self._region_objects:
+ for fill_item in trace_fills:
+ # Handle both old format (just fill object) and new format (tuple)
+ if isinstance(fill_item, tuple):
+ fill, _ = fill_item # Extract the fill object from the tuple
+ fill.remove()
+ else:
+ fill_item.remove() # Old format - direct fill object
+ trace_fills.clear()
+
+ def _setup_plot_elements(self) -> None:
+ """
+ Initialise matplotlib plot elements (lines, fills) for each trace.
+ This is called once during render().
+ """
+ if self.fig is None or self.ax is None:
+ raise RuntimeError(
+ "Figure and Axes must be created before setting up plot elements."
+ )
+
+ # Create initial signal line objects for each trace
+ for i in range(self.data.num_traces):
+ color = self.data.get_trace_color(i)
+ name = self.data.get_trace_name(i)
+
+ # Signal line
+ (line_signal,) = self.ax.plot(
+ [],
+ [],
+ label="Raw data" if self.data.num_traces == 1 else f"Raw data ({name})",
+ color=color,
+ alpha=self.signal_alpha,
+ )
+ self._signal_lines.append(line_signal)
+
+ def _connect_callbacks(self) -> None:
+ """Connect matplotlib callbacks."""
+ if self.ax is None:
+ raise RuntimeError("Axes must be created before connecting callbacks.")
+ self.ax.callbacks.connect("xlim_changed", self._update_plot_data)
+
+ def _setup_toolbar_overrides(self) -> None:
+ """Override matplotlib toolbar methods (e.g., home button)."""
+ if (
+ self.fig
+ and self.fig.canvas
+ and hasattr(self.fig.canvas, "toolbar")
+ and self.fig.canvas.toolbar
+ ):
+ toolbar = self.fig.canvas.toolbar
+
+ # Store original methods
+ self._original_home = getattr(toolbar, "home", None)
+ self._original_push_current = getattr(toolbar, "push_current", None)
+
+ # Create our custom home method
+ def custom_home(*args, **kwargs):
+ self.home()
+
+ # Override both the method and try to find the actual button
+ toolbar.home = custom_home
+
+ # For Qt backend, also override the action
+ if hasattr(toolbar, "actions"):
+ for action in toolbar.actions():
+ if hasattr(action, "text") and hasattr(action, "objectName"):
+ action_text = (
+ action.text() if callable(action.text) else str(action.text)
+ )
+ action_name = (
+ action.objectName()
+ if callable(action.objectName)
+ else str(action.objectName)
+ )
+ if action_text == "Home" or "home" in action_name.lower():
+ if hasattr(action, "triggered"):
+ action.triggered.disconnect()
+ action.triggered.connect(custom_home)
+ break
+
+ # For other backends, try to override the button callback
+ if hasattr(toolbar, "_buttons") and "Home" in toolbar._buttons:
+ home_button = toolbar._buttons["Home"]
+ if hasattr(home_button, "configure"):
+ home_button.configure(command=custom_home)
+
+ def _set_initial_view_and_labels(self) -> None:
+ """Set initial axis limits, title, and labels."""
+ if self.ax is None:
+ raise RuntimeError(
+ "Axes must be created before setting initial view and labels."
+ )
+
+ # Create title based on number of traces
+ if self.data.num_traces == 1:
+ self.ax.set_title(f"{self.data.names[0]}")
+ else:
+ # Multiple traces - just show "Multiple Traces"
+ self.ax.set_title(f"Multiple Traces ({self.data.num_traces})")
+ self.ax.set_xlabel(f"Time ({self.state.current_time_unit})")
+ self.ax.set_ylabel("Signal")
+
+ # Set initial xlim
+ initial_xlim_display = self.coord_manager.xlim_raw_to_display(
+ self._initial_xlim_raw
+ )
+ self.ax.set_xlim(initial_xlim_display)
+
+ def render(self) -> None:
+ """
+ Renders the oscilloscope plot. This method must be called after all
+ data and visualization elements have been added.
+ """
+ if self.fig is not None or self.ax is not None:
+ warnings.warn(
+ "Plot already rendered. Call `home()` to reset or create a new instance.", UserWarning
+ )
+ return
+
+ print("Rendering plot...")
+ self.fig, self.ax = plt.subplots(figsize=(10, 5))
+
+ self._setup_plot_elements()
+ self._connect_callbacks()
+ self._setup_toolbar_overrides()
+ self._set_initial_view_and_labels()
+
+ # Calculate initial parameters for the full view
+ t_start, t_end = self.data.get_global_time_range()
+ full_time_span = t_end - t_start
+
+ print(
+ f"Initial render: full time span={full_time_span:.3e}s, envelope_limit={self.mode_switch_threshold:.3e}s"
+ )
+
+ # Set initial display state based on full view
+ self.state.current_time_unit, self.state.current_time_scale = (
+ _get_optimal_time_unit_and_scale(full_time_span)
+ )
+ self.state.current_mode = (
+ "envelope" if self.state.should_use_envelope(full_time_span) else "detail"
+ )
+
+ # Force initial draw of all elements by calling _update_plot_data
+ # This will also update the legend and regions
+ self.state.set_updating(False) # Ensure not in updating state for first call
+ self._update_plot_data(self.ax)
+ self.fig.canvas.draw_idle()
+ print("Plot rendering complete.")
+
+ def home(self) -> None:
+ """Return to initial full view with complete state reset."""
+ if self.ax is None:
+ warnings.warn("Plot not rendered yet. Cannot go home.", UserWarning)
+ return
+
+ # Disconnect callback temporarily
+ callback_id = None
+ for cid, callback in self.ax.callbacks.callbacks["xlim_changed"].items():
+ if getattr(callback, "__func__", callback) == self._update_plot_data:
+ callback_id = cid
+ break
+
+ if callback_id is not None:
+ self.ax.callbacks.disconnect(callback_id)
+
+ try:
+ self.state.set_updating(True)
+ self.state.reset_to_initial_state()
+ self.decimator.clear_cache()
+ self._clear_region_fills()
+
+ # Clear all custom elements and reset _last_mode for each trace to force redraw
+ for trace_idx in range(self.data.num_traces):
+ self._clear_custom_elements(trace_idx)
+ self._last_mode[trace_idx] = None
+
+ # Reset axis formatting
+ self.ax.set_xlabel(f"Time ({self.state.original_time_unit})")
+ self.ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter())
+ self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator())
+
+ # Reset view
+ self.coord_manager.set_view_raw(self.ax, self._initial_xlim_raw)
+
+ # Manually trigger an update for the home view. The updating flag must be
+ # cleared first, otherwise _update_plot_data returns early without
+ # redrawing; this re-evaluates use_envelope, current_mode, and redraws
+ # everything.
+ self.state.set_updating(False)
+ self._update_plot_data(self.ax)
+
+ finally:
+ self.ax.callbacks.connect("xlim_changed", self._update_plot_data)
+
+ self.fig.canvas.draw()
+ print(f"Home view restored: {self.state.original_time_unit} scale")
+
+ def refresh(self) -> None:
+ """Force a complete refresh of the plot without changing the current view."""
+ if self.ax is None:
+ warnings.warn("Plot not rendered yet. Cannot refresh.", UserWarning)
+ return
+
+ # Temporarily bypass the updating state for forced refresh
+ was_updating = self.state.is_updating()
+ self.state.set_updating(False)
+ try:
+ self._update_plot_data(self.ax)
+ finally:
+ self.state.set_updating(was_updating)
+ self.fig.canvas.draw_idle()
+
+ def show(self) -> None:
+ """Display the plot."""
+ if self.fig is None:
+ self.render() # Render if not already rendered
+ plt.show()