From bce98fc796e32f4a439307dd3b65ef28dc6a73ad Mon Sep 17 00:00:00 2001 From: Sam Scholten Date: Mon, 27 Oct 2025 18:24:01 +1000 Subject: refactor: Replace Loguru with warnings and print statements Co-authored-by: aider (openrouter/anthropic/claude-sonnet-4) --- src/scopekit/coordinate_manager.py | 207 ++- src/scopekit/data_manager.py | 903 +++++----- src/scopekit/decimation.py | 1273 +++++++------ src/scopekit/display_state.py | 580 +++--- src/scopekit/plot.py | 3450 +++++++++++++++++------------------- 5 files changed, 3060 insertions(+), 3353 deletions(-) diff --git a/src/scopekit/coordinate_manager.py b/src/scopekit/coordinate_manager.py index 6ee1709..47d4e47 100644 --- a/src/scopekit/coordinate_manager.py +++ b/src/scopekit/coordinate_manager.py @@ -1,107 +1,100 @@ -from typing import Tuple - -import numpy as np -from loguru import logger - - -class CoordinateManager: - """ - Handles coordinate transformations between raw time and display coordinates. - - Centralises all coordinate conversion logic to prevent inconsistencies. - """ - - def __init__(self, display_state): - """ - Initialise the coordinate manager. - - Parameters - ---------- - display_state : DisplayState - Reference to the display state object. - """ - self.state = display_state - - def get_current_view_raw(self, ax): - """Get current view in raw coordinates.""" - try: - xlim_display = ax.get_xlim() - logger.debug(f"Converting display xlim {xlim_display} to raw coordinates") - - # Validate display limits - if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]): - logger.warning(f"Invalid display limits: {xlim_display}") - # Try to get a valid view from the figure - if hasattr(ax, "figure") and hasattr(ax.figure, "canvas"): - ax.figure.canvas.draw() - xlim_display = ax.get_xlim() - if not np.isfinite(xlim_display[0]) or not np.isfinite( - xlim_display[1] - ): - # Still invalid, use a default range - logger.warning( - "Still invalid after redraw, using default range" - ) - xlim_display = (0, 1) - - raw_coords = self.xlim_display_to_raw(xlim_display) - logger.debug(f"Converted to raw coordinates: {raw_coords}") - return raw_coords - except Exception as e: - logger.exception(f"Error getting current view: {e}") - # Return a safe default - return (np.float32(0.0), np.float32(1.0)) - - def set_view_raw(self, ax, xlim_raw): - """Set view using raw coordinates.""" - xlim_display = self.xlim_raw_to_display(xlim_raw) - ax.set_xlim(xlim_display) - - def raw_to_display(self, t_raw: np.ndarray) -> np.ndarray: - """Convert raw time to display coordinates.""" - if self.state.offset_time_raw is not None: - return (t_raw - self.state.offset_time_raw) * self.state.current_time_scale - else: - return t_raw * self.state.current_time_scale - - def display_to_raw(self, t_display: np.ndarray) -> np.ndarray: - """Convert display coordinates to raw time.""" - t_raw = t_display / self.state.current_time_scale - if self.state.offset_time_raw is not None: - t_raw += self.state.offset_time_raw - - # Only log for scalar values to avoid excessive output - if isinstance(t_display, (int, float, np.number)): - logger.debug( - f"Converting display time {t_display:.6f} to raw time {t_raw:.6f} (scale={self.state.current_time_scale}, offset={self.state.offset_time_raw})" - ) - return t_raw - - def xlim_display_to_raw( - self, xlim_display: Tuple[float, float] - ) -> Tuple[np.float32, np.float32]: - """Convert display xlim tuple to raw time coordinates.""" - try: - # Ensure values are finite - if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]): - logger.warning( - f"Non-finite display limits: {xlim_display}, using defaults" - ) - return (np.float32(0.0), np.float32(1.0)) - - return ( - self.display_to_raw(np.float32(xlim_display[0])), - self.display_to_raw(np.float32(xlim_display[1])), - ) - except Exception as e: - logger.exception(f"Error converting display to raw coordinates: {e}") - return (np.float32(0.0), np.float32(1.0)) - - def xlim_raw_to_display( - self, xlim_raw: Tuple[np.float32, np.float32] - ) -> Tuple[np.float32, np.float32]: - """Convert raw time xlim tuple to display coordinates.""" - return ( - self.raw_to_display(xlim_raw[0]), - self.raw_to_display(xlim_raw[1]), - ) +from typing import Tuple +import warnings + +import numpy as np + + +class CoordinateManager: + """ + Handles coordinate transformations between raw time and display coordinates. + + Centralises all coordinate conversion logic to prevent inconsistencies. + """ + + def __init__(self, display_state): + """ + Initialise the coordinate manager. + + Parameters + ---------- + display_state : DisplayState + Reference to the display state object. + """ + self.state = display_state + + def get_current_view_raw(self, ax): + """Get current view in raw coordinates.""" + try: + xlim_display = ax.get_xlim() + + # Validate display limits + if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]): + warnings.warn(f"Invalid display limits: {xlim_display}", RuntimeWarning) + # Try to get a valid view from the figure + if hasattr(ax, "figure") and hasattr(ax.figure, "canvas"): + ax.figure.canvas.draw() + xlim_display = ax.get_xlim() + if not np.isfinite(xlim_display[0]) or not np.isfinite( + xlim_display[1] + ): + # Still invalid, use a default range + warnings.warn( + "Still invalid after redraw, using default range", RuntimeWarning + ) + xlim_display = (0, 1) + + raw_coords = self.xlim_display_to_raw(xlim_display) + return raw_coords + except Exception as e: + warnings.warn(f"Error getting current view: {e}", RuntimeWarning) + # Return a safe default + return (np.float32(0.0), np.float32(1.0)) + + def set_view_raw(self, ax, xlim_raw): + """Set view using raw coordinates.""" + xlim_display = self.xlim_raw_to_display(xlim_raw) + ax.set_xlim(xlim_display) + + def raw_to_display(self, t_raw: np.ndarray) -> np.ndarray: + """Convert raw time to display coordinates.""" + if self.state.offset_time_raw is not None: + return (t_raw - self.state.offset_time_raw) * self.state.current_time_scale + else: + return t_raw * self.state.current_time_scale + + def display_to_raw(self, t_display: np.ndarray) -> np.ndarray: + """Convert display coordinates to raw time.""" + t_raw = t_display / self.state.current_time_scale + if self.state.offset_time_raw is not None: + t_raw += self.state.offset_time_raw + + return t_raw + + def xlim_display_to_raw( + self, xlim_display: Tuple[float, float] + ) -> Tuple[np.float32, np.float32]: + """Convert display xlim tuple to raw time coordinates.""" + try: + # Ensure values are finite + if not np.isfinite(xlim_display[0]) or not np.isfinite(xlim_display[1]): + warnings.warn( + f"Non-finite display limits: {xlim_display}, using defaults", RuntimeWarning + ) + return (np.float32(0.0), np.float32(1.0)) + + return ( + self.display_to_raw(np.float32(xlim_display[0])), + self.display_to_raw(np.float32(xlim_display[1])), + ) + except Exception as e: + warnings.warn(f"Error converting display to raw coordinates: {e}", RuntimeWarning) + return (np.float32(0.0), np.float32(1.0)) + + def xlim_raw_to_display( + self, xlim_raw: Tuple[np.float32, np.float32] + ) -> Tuple[np.float32, np.float32]: + """Convert raw time xlim tuple to display coordinates.""" + return ( + self.raw_to_display(xlim_raw[0]), + self.raw_to_display(xlim_raw[1]), + ) diff --git a/src/scopekit/data_manager.py b/src/scopekit/data_manager.py index c2dd09c..78a8ea3 100644 --- a/src/scopekit/data_manager.py +++ b/src/scopekit/data_manager.py @@ -1,452 +1,451 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -from loguru import logger - - -class TimeSeriesDataManager: - """ - Manages time series data storage and basic operations. - - Handles raw data storage, time scaling, and basic data access patterns. - It can also store optional associated data like background estimates, - global noise, and overlay lines. - - Supports multiple traces with shared time axis or individual time axes. - """ - - def __init__( - self, - t: Union[np.ndarray, List[np.ndarray]], - x: Union[np.ndarray, List[np.ndarray]], - name: Union[str, List[str]] = "Time Series", - trace_colors: Optional[List[str]] = None, - ): - """ - Initialise the data manager. - - Parameters - ---------- - t : Union[np.ndarray, List[np.ndarray]] - Time array(s) (raw time in seconds). Can be a single array shared by all traces - or a list of arrays, one per trace. - x : Union[np.ndarray, List[np.ndarray]] - Signal array(s). If t is a single array, x can be a 2D array (traces x samples) - or a list of 1D arrays. If t is a list, x must be a list of equal length. - name : Union[str, List[str]], default="Time Series" - Name(s) for identification. Can be a single string or a list of strings. - trace_colors : Optional[List[str]], default=None - Colors for each trace. If None, default colors will be used. - - Raises - ------ - ValueError - If input arrays have mismatched lengths or time array is not monotonic. - """ - # Convert inputs to standardized format: lists of arrays - self.t_arrays, self.x_arrays, self.names, self.colors = ( - self._standardize_inputs(t, x, name, trace_colors) - ) - - # Validate all data - for i, (t_arr, x_arr) in enumerate(zip(self.t_arrays, self.x_arrays)): - self._validate_core_data(t_arr, x_arr, trace_idx=i) - - # Optional associated data (per trace) - self._overlay_lines: List[List[Dict[str, Any]]] = [ - [] for _ in range(len(self.t_arrays)) - ] - - # For backward compatibility - if len(self.t_arrays) > 0: - self.t = self.t_arrays[0] # Primary time array - self.x = self.x_arrays[0] # Primary signal array - self.name = self.names[0] # Primary name - - def _standardize_inputs( - self, - t: Union[np.ndarray, List[np.ndarray]], - x: Union[np.ndarray, List[np.ndarray]], - name: Union[str, List[str]], - trace_colors: Optional[List[str]], - ) -> Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]]: - """ - Standardize inputs to lists of arrays. - - Parameters - ---------- - t : Union[np.ndarray, List[np.ndarray]] - Time array(s). - x : Union[np.ndarray, List[np.ndarray]] - Signal array(s). - name : Union[str, List[str]] - Name(s) for identification. - trace_colors : Optional[List[str]] - Colors for each trace. - - Returns - ------- - Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]] - Standardized lists of time arrays, signal arrays, names, and colors. - """ - # Default colors for traces - default_colors = [ - "black", - "blue", - "red", - "green", - "purple", - "orange", - "brown", - "pink", - "gray", - "olive", - ] - - # Handle time arrays - if isinstance(t, list): - t_arrays = [np.asarray(t_arr, dtype=np.float32) for t_arr in t] - n_traces = len(t_arrays) - else: - t_arr = np.asarray(t, dtype=np.float32) - - # Check if x is 2D array or list - if isinstance(x, list): - n_traces = len(x) - t_arrays = [t_arr.copy() for _ in range(n_traces)] - elif x.ndim == 2: - n_traces = x.shape[0] - t_arrays = [t_arr.copy() for _ in range(n_traces)] - else: - n_traces = 1 - t_arrays = [t_arr] - - # Handle signal arrays - if isinstance(x, list): - if len(x) != n_traces: - raise ValueError( - f"Number of signal arrays ({len(x)}) must match number of time arrays ({n_traces})" - ) - x_arrays = [np.asarray(x_arr, dtype=np.float32) for x_arr in x] - elif x.ndim == 2: - if x.shape[0] != n_traces: - raise ValueError( - f"First dimension of 2D signal array ({x.shape[0]}) must match number of time arrays ({n_traces})" - ) - x_arrays = [np.asarray(x[i], dtype=np.float32) for i in range(n_traces)] - else: - if n_traces != 1: - raise ValueError( - f"Single signal array provided but expected {n_traces} arrays" - ) - x_arrays = [np.asarray(x, dtype=np.float32)] - - # Handle names - if isinstance(name, list): - if len(name) != n_traces: - logger.warning( - f"Number of names ({len(name)}) doesn't match number of traces ({n_traces}). Using defaults." - ) - names = [f"Trace {i + 1}" for i in range(n_traces)] - else: - names = name - else: - if n_traces == 1: - names = [name] - else: - if ( - name == "Time Series" - ): # Only use default naming if the default name was used - names = [f"Trace {i + 1}" for i in range(n_traces)] - else: - names = [f"{name} {i + 1}" for i in range(n_traces)] - - # Handle colors - if trace_colors is not None: - if len(trace_colors) < n_traces: - logger.warning( - f"Not enough colors provided ({len(trace_colors)}). Using defaults for remaining traces." - ) - colors = trace_colors + [ - default_colors[i % len(default_colors)] - for i in range(len(trace_colors), n_traces) - ] - else: - colors = trace_colors[:n_traces] - else: - colors = [default_colors[i % len(default_colors)] for i in range(n_traces)] - - return t_arrays, x_arrays, names, colors - - def _validate_core_data( - self, t: np.ndarray, x: np.ndarray, trace_idx: int = 0 - ) -> None: - """ - Validate core input data arrays for consistency and correctness. - - Parameters - ---------- - t : np.ndarray - Time array. - x : np.ndarray - Signal array. - trace_idx : int, default=0 - Index of the trace being validated (for error messages). - - Raises - ------ - ValueError - If arrays have mismatched lengths or time array is not monotonic. - """ - if len(t) != len(x): - raise ValueError( - f"Time and signal arrays for trace {trace_idx} must have the same length. Got t={len(t)}, x={len(x)}" - ) - if len(t) == 0: - logger.warning(f"Initialising trace {trace_idx} with empty arrays.") - return - - # Check time array is monotonic - if len(t) > 1: - # Use a small epsilon for floating-point comparison - tolerance = 1e-9 - if not np.all(np.diff(t) > tolerance): - problematic_diffs = np.diff(t)[np.diff(t) <= tolerance] - logger.warning( - f"Time array for trace {trace_idx} is not strictly monotonic increasing within tolerance {tolerance}. " - f"Problematic diffs (first 10): {problematic_diffs[:10]}. " - f"This may affect analysis results." - ) - - # Check for non-uniform sampling - self._check_uniform_sampling(t, trace_idx) - - @property - def overlay_lines(self) -> List[Dict[str, Any]]: - """Get overlay lines data for the primary trace.""" - return self._overlay_lines[0] if self._overlay_lines else [] - - def get_overlay_lines(self, trace_idx: int = 0) -> List[Dict[str, Any]]: - """Get overlay lines data for a specific trace.""" - if trace_idx < 0 or trace_idx >= len(self.t_arrays): - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." - ) - return self._overlay_lines[trace_idx] - - @property - def num_traces(self) -> int: - """Get the number of traces.""" - return len(self.t_arrays) - - def get_trace_color(self, trace_idx: int = 0) -> str: - """Get the color for a specific trace.""" - if trace_idx < 0 or trace_idx >= len(self.t_arrays): - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." - ) - return self.colors[trace_idx] - - def get_trace_name(self, trace_idx: int = 0) -> str: - """Get the name for a specific trace.""" - if trace_idx < 0 or trace_idx >= len(self.t_arrays): - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." - ) - return self.names[trace_idx] - - def set_overlay_lines( - self, - overlay_lines: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], - trace_idx: Optional[int] = None, - ) -> None: - """ - Set overlay lines data. - - Parameters - ---------- - overlay_lines : Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]] - List of dictionaries defining overlay lines, or list of lists for multiple traces. - trace_idx : Optional[int], default=None - If provided, set overlay lines only for the specified trace. - If None, set for all traces if a nested list is provided, or for the first trace if a flat list. - """ - if trace_idx is not None: - # Set for specific trace - if trace_idx < 0 or trace_idx >= len(self.t_arrays): - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." - ) - - # Ensure we have a list of dictionaries - if not isinstance(overlay_lines, list): - raise ValueError( - f"overlay_lines must be a list of dictionaries. Got {type(overlay_lines)}." - ) - - # Check if it's a list of dictionaries (not a nested list) - if len(overlay_lines) > 0 and isinstance(overlay_lines[0], dict): - self._overlay_lines[trace_idx] = overlay_lines - else: - raise ValueError( - "Expected a list of dictionaries for overlay_lines when trace_idx is specified." - ) - else: - # Set for all traces or first trace - if len(overlay_lines) > 0 and isinstance(overlay_lines[0], list): - # Nested list provided - set for multiple traces - if len(overlay_lines) != len(self.t_arrays): - raise ValueError( - f"Number of overlay line lists ({len(overlay_lines)}) must match number of traces ({len(self.t_arrays)})." - ) - - for i, lines in enumerate(overlay_lines): - self._overlay_lines[i] = lines - else: - # Flat list provided - set for first trace - self._overlay_lines[0] = overlay_lines - - def get_time_range(self, trace_idx: int = 0) -> Tuple[np.float32, np.float32]: - """ - Get the full time range of the data. - - Parameters - ---------- - trace_idx : int, default=0 - Index of the trace to get the time range for. - - Returns - ------- - Tuple[np.float32, np.float32] - Start and end time of the data. - """ - if trace_idx < 0 or trace_idx >= len(self.t_arrays): - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." - ) - - t_arr = self.t_arrays[trace_idx] - if t_arr.size == 0: - return np.float32(0.0), np.float32(0.0) - return np.float32(t_arr[0]), np.float32(t_arr[-1]) - - def get_global_time_range(self) -> Tuple[np.float32, np.float32]: - """ - Get the global time range across all traces. - - Returns - ------- - Tuple[np.float32, np.float32] - Global start and end time across all traces. - """ - if len(self.t_arrays) == 0: - return np.float32(0.0), np.float32(0.0) - - t_min = np.float32( - min(t_arr[0] if t_arr.size > 0 else np.inf for t_arr in self.t_arrays) - ) - t_max = np.float32( - max(t_arr[-1] if t_arr.size > 0 else -np.inf for t_arr in self.t_arrays) - ) - - if np.isinf(t_min) or np.isinf(t_max): - return np.float32(0.0), np.float32(0.0) - - return t_min, t_max - - def get_data_in_range( - self, t_start: np.float32, t_end: np.float32, trace_idx: int = 0 - ) -> Tuple[np.ndarray, np.ndarray]: - """ - Extract data within a time range. - - Parameters - ---------- - t_start : np.float32 - Start time in raw seconds. - t_end : np.float32 - End time in raw seconds. - trace_idx : int, default=0 - Index of the trace to get data for. - - Returns - ------- - Tuple[np.ndarray, np.ndarray] - Time and signal arrays. - """ - if trace_idx < 0 or trace_idx >= len(self.t_arrays): - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." - ) - - t_arr = self.t_arrays[trace_idx] - x_arr = self.x_arrays[trace_idx] - - mask = (t_arr >= t_start) & (t_arr <= t_end) - if not np.any(mask): - logger.debug(f"No data in range [{t_start}, {t_end}] for trace {trace_idx}") - return ( - np.array([], dtype=np.float32), - np.array([], dtype=np.float32), - ) - - t_masked = t_arr[mask] - x_masked = x_arr[mask] - - return t_masked, x_masked - - def _check_uniform_sampling(self, t: np.ndarray, trace_idx: int = 0) -> None: - """ - Check if time array is uniformly sampled and issue warnings if not. - - Parameters - ---------- - t : np.ndarray - Time array to check. - trace_idx : int, default=0 - Index of the trace being checked (for warning messages). - """ - if len(t) < 3: - return # Not enough points to check uniformity - - # Calculate time differences - dt = np.diff(t) - - # Calculate statistics - dt_mean = np.mean(dt) - dt_std = np.std(dt) - dt_cv = dt_std / dt_mean if dt_mean > 0 else 0 # Coefficient of variation - - # Check for significant non-uniformity - # CV > 0.01 (1%) indicates potentially problematic non-uniformity - if dt_cv > 0.01: - logger.warning( - f"Non-uniform sampling detected in trace {trace_idx}: " - f"mean dt={dt_mean:.3e}s, std={dt_std:.3e}s, CV={dt_cv:.2%}" - ) - - # More detailed warning for severe non-uniformity - if dt_cv > 0.05: # 5% variation - # Find the most extreme deviations - dt_median = np.median(dt) - rel_deviations = np.abs(dt - dt_median) / dt_median - worst_indices = np.argsort(rel_deviations)[-5:] # 5 worst points - - worst_deviations = [] - for idx in reversed(worst_indices): - if ( - rel_deviations[idx] > 0.1 - ): # Only report significant deviations (>10%) - worst_deviations.append( - f"at t={t[idx]:.3e}s: dt={dt[idx]:.3e}s ({rel_deviations[idx]:.1%} deviation)" - ) - - if worst_deviations: - logger.warning( - f"Severe sampling irregularities detected in trace {trace_idx}. " - f"Worst points: {'; '.join(worst_deviations)}" - ) - logger.warning( - "Non-uniform sampling may affect analysis results, especially for " - "frequency-domain analysis or event detection." - ) +from typing import Any, Dict, List, Optional, Tuple, Union +import warnings + +import numpy as np + + +class TimeSeriesDataManager: + """ + Manages time series data storage and basic operations. + + Handles raw data storage, time scaling, and basic data access patterns. + It can also store optional associated data like background estimates, + global noise, and overlay lines. + + Supports multiple traces with shared time axis or individual time axes. + """ + + def __init__( + self, + t: Union[np.ndarray, List[np.ndarray]], + x: Union[np.ndarray, List[np.ndarray]], + name: Union[str, List[str]] = "Time Series", + trace_colors: Optional[List[str]] = None, + ): + """ + Initialise the data manager. + + Parameters + ---------- + t : Union[np.ndarray, List[np.ndarray]] + Time array(s) (raw time in seconds). Can be a single array shared by all traces + or a list of arrays, one per trace. + x : Union[np.ndarray, List[np.ndarray]] + Signal array(s). If t is a single array, x can be a 2D array (traces x samples) + or a list of 1D arrays. If t is a list, x must be a list of equal length. + name : Union[str, List[str]], default="Time Series" + Name(s) for identification. Can be a single string or a list of strings. + trace_colors : Optional[List[str]], default=None + Colors for each trace. If None, default colors will be used. + + Raises + ------ + ValueError + If input arrays have mismatched lengths or time array is not monotonic. + """ + # Convert inputs to standardized format: lists of arrays + self.t_arrays, self.x_arrays, self.names, self.colors = ( + self._standardize_inputs(t, x, name, trace_colors) + ) + + # Validate all data + for i, (t_arr, x_arr) in enumerate(zip(self.t_arrays, self.x_arrays)): + self._validate_core_data(t_arr, x_arr, trace_idx=i) + + # Optional associated data (per trace) + self._overlay_lines: List[List[Dict[str, Any]]] = [ + [] for _ in range(len(self.t_arrays)) + ] + + # For backward compatibility + if len(self.t_arrays) > 0: + self.t = self.t_arrays[0] # Primary time array + self.x = self.x_arrays[0] # Primary signal array + self.name = self.names[0] # Primary name + + def _standardize_inputs( + self, + t: Union[np.ndarray, List[np.ndarray]], + x: Union[np.ndarray, List[np.ndarray]], + name: Union[str, List[str]], + trace_colors: Optional[List[str]], + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]]: + """ + Standardize inputs to lists of arrays. + + Parameters + ---------- + t : Union[np.ndarray, List[np.ndarray]] + Time array(s). + x : Union[np.ndarray, List[np.ndarray]] + Signal array(s). + name : Union[str, List[str]] + Name(s) for identification. + trace_colors : Optional[List[str]] + Colors for each trace. + + Returns + ------- + Tuple[List[np.ndarray], List[np.ndarray], List[str], List[str]] + Standardized lists of time arrays, signal arrays, names, and colors. + """ + # Default colors for traces + default_colors = [ + "black", + "blue", + "red", + "green", + "purple", + "orange", + "brown", + "pink", + "gray", + "olive", + ] + + # Handle time arrays + if isinstance(t, list): + t_arrays = [np.asarray(t_arr, dtype=np.float32) for t_arr in t] + n_traces = len(t_arrays) + else: + t_arr = np.asarray(t, dtype=np.float32) + + # Check if x is 2D array or list + if isinstance(x, list): + n_traces = len(x) + t_arrays = [t_arr.copy() for _ in range(n_traces)] + elif x.ndim == 2: + n_traces = x.shape[0] + t_arrays = [t_arr.copy() for _ in range(n_traces)] + else: + n_traces = 1 + t_arrays = [t_arr] + + # Handle signal arrays + if isinstance(x, list): + if len(x) != n_traces: + raise ValueError( + f"Number of signal arrays ({len(x)}) must match number of time arrays ({n_traces})" + ) + x_arrays = [np.asarray(x_arr, dtype=np.float32) for x_arr in x] + elif x.ndim == 2: + if x.shape[0] != n_traces: + raise ValueError( + f"First dimension of 2D signal array ({x.shape[0]}) must match number of time arrays ({n_traces})" + ) + x_arrays = [np.asarray(x[i], dtype=np.float32) for i in range(n_traces)] + else: + if n_traces != 1: + raise ValueError( + f"Single signal array provided but expected {n_traces} arrays" + ) + x_arrays = [np.asarray(x, dtype=np.float32)] + + # Handle names + if isinstance(name, list): + if len(name) != n_traces: + warnings.warn( + f"Number of names ({len(name)}) doesn't match number of traces ({n_traces}). Using defaults.", UserWarning + ) + names = [f"Trace {i + 1}" for i in range(n_traces)] + else: + names = name + else: + if n_traces == 1: + names = [name] + else: + if ( + name == "Time Series" + ): # Only use default naming if the default name was used + names = [f"Trace {i + 1}" for i in range(n_traces)] + else: + names = [f"{name} {i + 1}" for i in range(n_traces)] + + # Handle colors + if trace_colors is not None: + if len(trace_colors) < n_traces: + warnings.warn( + f"Not enough colors provided ({len(trace_colors)}). Using defaults for remaining traces.", UserWarning + ) + colors = trace_colors + [ + default_colors[i % len(default_colors)] + for i in range(len(trace_colors), n_traces) + ] + else: + colors = trace_colors[:n_traces] + else: + colors = [default_colors[i % len(default_colors)] for i in range(n_traces)] + + return t_arrays, x_arrays, names, colors + + def _validate_core_data( + self, t: np.ndarray, x: np.ndarray, trace_idx: int = 0 + ) -> None: + """ + Validate core input data arrays for consistency and correctness. + + Parameters + ---------- + t : np.ndarray + Time array. + x : np.ndarray + Signal array. + trace_idx : int, default=0 + Index of the trace being validated (for error messages). + + Raises + ------ + ValueError + If arrays have mismatched lengths or time array is not monotonic. + """ + if len(t) != len(x): + raise ValueError( + f"Time and signal arrays for trace {trace_idx} must have the same length. Got t={len(t)}, x={len(x)}" + ) + if len(t) == 0: + warnings.warn(f"Initialising trace {trace_idx} with empty arrays.", UserWarning) + return + + # Check time array is monotonic + if len(t) > 1: + # Use a small epsilon for floating-point comparison + tolerance = 1e-9 + if not np.all(np.diff(t) > tolerance): + problematic_diffs = np.diff(t)[np.diff(t) <= tolerance] + warnings.warn( + f"Time array for trace {trace_idx} is not strictly monotonic increasing within tolerance {tolerance}. " + f"Problematic diffs (first 10): {problematic_diffs[:10]}. " + f"This may affect analysis results.", UserWarning + ) + + # Check for non-uniform sampling + self._check_uniform_sampling(t, trace_idx) + + @property + def overlay_lines(self) -> List[Dict[str, Any]]: + """Get overlay lines data for the primary trace.""" + return self._overlay_lines[0] if self._overlay_lines else [] + + def get_overlay_lines(self, trace_idx: int = 0) -> List[Dict[str, Any]]: + """Get overlay lines data for a specific trace.""" + if trace_idx < 0 or trace_idx >= len(self.t_arrays): + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." + ) + return self._overlay_lines[trace_idx] + + @property + def num_traces(self) -> int: + """Get the number of traces.""" + return len(self.t_arrays) + + def get_trace_color(self, trace_idx: int = 0) -> str: + """Get the color for a specific trace.""" + if trace_idx < 0 or trace_idx >= len(self.t_arrays): + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." + ) + return self.colors[trace_idx] + + def get_trace_name(self, trace_idx: int = 0) -> str: + """Get the name for a specific trace.""" + if trace_idx < 0 or trace_idx >= len(self.t_arrays): + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." + ) + return self.names[trace_idx] + + def set_overlay_lines( + self, + overlay_lines: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], + trace_idx: Optional[int] = None, + ) -> None: + """ + Set overlay lines data. + + Parameters + ---------- + overlay_lines : Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]] + List of dictionaries defining overlay lines, or list of lists for multiple traces. + trace_idx : Optional[int], default=None + If provided, set overlay lines only for the specified trace. + If None, set for all traces if a nested list is provided, or for the first trace if a flat list. + """ + if trace_idx is not None: + # Set for specific trace + if trace_idx < 0 or trace_idx >= len(self.t_arrays): + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." + ) + + # Ensure we have a list of dictionaries + if not isinstance(overlay_lines, list): + raise ValueError( + f"overlay_lines must be a list of dictionaries. Got {type(overlay_lines)}." + ) + + # Check if it's a list of dictionaries (not a nested list) + if len(overlay_lines) > 0 and isinstance(overlay_lines[0], dict): + self._overlay_lines[trace_idx] = overlay_lines + else: + raise ValueError( + "Expected a list of dictionaries for overlay_lines when trace_idx is specified." + ) + else: + # Set for all traces or first trace + if len(overlay_lines) > 0 and isinstance(overlay_lines[0], list): + # Nested list provided - set for multiple traces + if len(overlay_lines) != len(self.t_arrays): + raise ValueError( + f"Number of overlay line lists ({len(overlay_lines)}) must match number of traces ({len(self.t_arrays)})." + ) + + for i, lines in enumerate(overlay_lines): + self._overlay_lines[i] = lines + else: + # Flat list provided - set for first trace + self._overlay_lines[0] = overlay_lines + + def get_time_range(self, trace_idx: int = 0) -> Tuple[np.float32, np.float32]: + """ + Get the full time range of the data. + + Parameters + ---------- + trace_idx : int, default=0 + Index of the trace to get the time range for. + + Returns + ------- + Tuple[np.float32, np.float32] + Start and end time of the data. + """ + if trace_idx < 0 or trace_idx >= len(self.t_arrays): + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." + ) + + t_arr = self.t_arrays[trace_idx] + if t_arr.size == 0: + return np.float32(0.0), np.float32(0.0) + return np.float32(t_arr[0]), np.float32(t_arr[-1]) + + def get_global_time_range(self) -> Tuple[np.float32, np.float32]: + """ + Get the global time range across all traces. + + Returns + ------- + Tuple[np.float32, np.float32] + Global start and end time across all traces. + """ + if len(self.t_arrays) == 0: + return np.float32(0.0), np.float32(0.0) + + t_min = np.float32( + min(t_arr[0] if t_arr.size > 0 else np.inf for t_arr in self.t_arrays) + ) + t_max = np.float32( + max(t_arr[-1] if t_arr.size > 0 else -np.inf for t_arr in self.t_arrays) + ) + + if np.isinf(t_min) or np.isinf(t_max): + return np.float32(0.0), np.float32(0.0) + + return t_min, t_max + + def get_data_in_range( + self, t_start: np.float32, t_end: np.float32, trace_idx: int = 0 + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Extract data within a time range. + + Parameters + ---------- + t_start : np.float32 + Start time in raw seconds. + t_end : np.float32 + End time in raw seconds. + trace_idx : int, default=0 + Index of the trace to get data for. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Time and signal arrays. + """ + if trace_idx < 0 or trace_idx >= len(self.t_arrays): + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {len(self.t_arrays) - 1}." + ) + + t_arr = self.t_arrays[trace_idx] + x_arr = self.x_arrays[trace_idx] + + mask = (t_arr >= t_start) & (t_arr <= t_end) + if not np.any(mask): + return ( + np.array([], dtype=np.float32), + np.array([], dtype=np.float32), + ) + + t_masked = t_arr[mask] + x_masked = x_arr[mask] + + return t_masked, x_masked + + def _check_uniform_sampling(self, t: np.ndarray, trace_idx: int = 0) -> None: + """ + Check if time array is uniformly sampled and issue warnings if not. + + Parameters + ---------- + t : np.ndarray + Time array to check. + trace_idx : int, default=0 + Index of the trace being checked (for warning messages). + """ + if len(t) < 3: + return # Not enough points to check uniformity + + # Calculate time differences + dt = np.diff(t) + + # Calculate statistics + dt_mean = np.mean(dt) + dt_std = np.std(dt) + dt_cv = dt_std / dt_mean if dt_mean > 0 else 0 # Coefficient of variation + + # Check for significant non-uniformity + # CV > 0.01 (1%) indicates potentially problematic non-uniformity + if dt_cv > 0.01: + warnings.warn( + f"Non-uniform sampling detected in trace {trace_idx}: " + f"mean dt={dt_mean:.3e}s, std={dt_std:.3e}s, CV={dt_cv:.2%}", UserWarning + ) + + # More detailed warning for severe non-uniformity + if dt_cv > 0.05: # 5% variation + # Find the most extreme deviations + dt_median = np.median(dt) + rel_deviations = np.abs(dt - dt_median) / dt_median + worst_indices = np.argsort(rel_deviations)[-5:] # 5 worst points + + worst_deviations = [] + for idx in reversed(worst_indices): + if ( + rel_deviations[idx] > 0.1 + ): # Only report significant deviations (>10%) + worst_deviations.append( + f"at t={t[idx]:.3e}s: dt={dt[idx]:.3e}s ({rel_deviations[idx]:.1%} deviation)" + ) + + if worst_deviations: + warnings.warn( + f"Severe sampling irregularities detected in trace {trace_idx}. " + f"Worst points: {'; '.join(worst_deviations)}", UserWarning + ) + warnings.warn( + "Non-uniform sampling may affect analysis results, especially for " + "frequency-domain analysis or event detection.", UserWarning + ) diff --git a/src/scopekit/decimation.py b/src/scopekit/decimation.py index 16543b1..d60ac3b 100644 --- a/src/scopekit/decimation.py +++ b/src/scopekit/decimation.py @@ -1,671 +1,602 @@ -from typing import Dict, Optional, Tuple - -import numpy as np -from loguru import logger -from numba import njit - - -@njit -def _decimate_time_numba(t: np.ndarray, step: int, n_bins: int) -> np.ndarray: - """ - Numba-optimized time decimation using bin centers. - - Parameters - ---------- - t : np.ndarray - Input time array. - step : int - Step size for binning. - n_bins : int - Number of bins to create. - - Returns - ------- - np.ndarray - Decimated time array with center time of each bin. - """ - t_decimated = np.zeros(n_bins, dtype=np.float32) - - for i in range(n_bins): - start_idx = i * step - end_idx = min((i + 1) * step, len(t)) - center_idx = start_idx + (end_idx - start_idx) // 2 - t_decimated[i] = t[center_idx] - - return t_decimated - - -@njit -def _decimate_mean_numba(x: np.ndarray, step: int, n_bins: int) -> np.ndarray: - """ - Numba-optimized mean decimation. - - Parameters - ---------- - x : np.ndarray - Input signal array. - step : int - Step size for binning. - n_bins : int - Number of bins to create. - - Returns - ------- - np.ndarray - Decimated signal array with mean values. - """ - x_decimated = np.zeros(n_bins, dtype=np.float32) - - for i in range(n_bins): - start_idx = i * step - end_idx = min((i + 1) * step, len(x)) - - if end_idx > start_idx: - # Calculate mean manually for Numba compatibility - bin_sum = 0.0 - bin_count = end_idx - start_idx - for j in range(start_idx, end_idx): - bin_sum += x[j] - x_decimated[i] = bin_sum / bin_count - else: - x_decimated[i] = x[start_idx] if start_idx < len(x) else 0.0 - - return x_decimated - - -@njit -def _decimate_envelope_standard_numba( - x: np.ndarray, step: int, n_bins: int -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Numba-optimized standard envelope decimation. - - Parameters - ---------- - x : np.ndarray - Input signal array. - step : int - Step size for binning. - n_bins : int - Number of bins to create. - - Returns - ------- - Tuple[np.ndarray, np.ndarray, np.ndarray] - Decimated signal (mean), min envelope, max envelope arrays. - """ - x_decimated = np.zeros(n_bins, dtype=np.float32) - x_min_envelope = np.zeros(n_bins, dtype=np.float32) - x_max_envelope = np.zeros(n_bins, dtype=np.float32) - - for i in range(n_bins): - start_idx = i * step - end_idx = min((i + 1) * step, len(x)) - - if end_idx > start_idx: - # Find min and max manually for Numba compatibility - bin_min = x[start_idx] - bin_max = x[start_idx] - bin_sum = 0.0 - - for j in range(start_idx, end_idx): - val = x[j] - if val < bin_min: - bin_min = val - if val > bin_max: - bin_max = val - bin_sum += val - - x_min_envelope[i] = bin_min - x_max_envelope[i] = bin_max - x_decimated[i] = bin_sum / (end_idx - start_idx) - else: - fallback_val = x[start_idx] if start_idx < len(x) else 0.0 - x_min_envelope[i] = fallback_val - x_max_envelope[i] = fallback_val - x_decimated[i] = fallback_val - - return x_decimated, x_min_envelope, x_max_envelope - - -@njit -def _decimate_envelope_highres_numba( - t: np.ndarray, x: np.ndarray, step: int, n_bins: int, envelope_window_samples: int -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Numba-optimized high-resolution envelope decimation. - - Parameters - ---------- - t : np.ndarray - Input time array. - x : np.ndarray - Input signal array. - step : int - Step size for binning. - n_bins : int - Number of bins to create. - envelope_window_samples : int - Window size in samples for high-resolution envelope calculation. - - Returns - ------- - Tuple[np.ndarray, np.ndarray, np.ndarray] - Decimated signal (mean), min envelope, max envelope arrays. - """ - x_decimated = np.zeros(n_bins, dtype=np.float32) - x_min_envelope = np.zeros(n_bins, dtype=np.float32) - x_max_envelope = np.zeros(n_bins, dtype=np.float32) - - half_window = envelope_window_samples // 2 - - for i in range(n_bins): - start_idx = i * step - end_idx = min((i + 1) * step, len(t)) - bin_center = start_idx + (end_idx - start_idx) // 2 - - # Define window around bin center - window_start = max(0, bin_center - half_window) - window_end = min(len(x), bin_center + half_window) - - if window_end > window_start: - # Find min and max in window manually for Numba compatibility - window_min = x[window_start] - window_max = x[window_start] - - for j in range(window_start, window_end): - val = x[j] - if val < window_min: - window_min = val - if val > window_max: - window_max = val - - x_min_envelope[i] = window_min - x_max_envelope[i] = window_max - x_decimated[i] = (window_min + window_max) / 2.0 - else: - fallback_val = x[bin_center] if bin_center < len(x) else 0.0 - x_min_envelope[i] = fallback_val - x_max_envelope[i] = fallback_val - x_decimated[i] = fallback_val - - return x_decimated, x_min_envelope, x_max_envelope - - -class DecimationManager: - """ - Handles data decimation and caching for efficient plotting. - - Manages different decimation strategies and caches results to improve performance. - Pre-calculates decimated data at load time for faster zooming. - """ - - # Cache and performance constants - CACHE_MAX_SIZE = 10 - MIN_VISIBLE_RANGE_DEFAULT = 1e-6 # Default if no global noise is provided - # Threshold for warning about too many points in detail mode - DETAIL_MODE_POINT_WARNING_THRESHOLD = 100000 - - def __init__(self, cache_max_size: int = CACHE_MAX_SIZE): - """ - Initialise the decimation manager. - - Parameters - ---------- - cache_max_size : int, default=PlotConstants.CACHE_MAX_SIZE - Maximum number of cached decimation results. - """ - self._cache: Dict[str, Tuple[np.ndarray, ...]] = {} - self._cache_max_size = cache_max_size - # Stores pre-decimated envelope data for the full dataset for each trace/line - # Structure: {trace_id: {'t': np.ndarray, 'x_min': np.ndarray, 'x_max': np.ndarray, ...}} - self._pre_decimated_envelopes: Dict[int, Dict[str, np.ndarray]] = {} - - def _get_cache_key( - self, - xlim_raw: Tuple[np.float32, np.float32], - max_points: int, - use_envelope: bool, - trace_id: Optional[int] = None, - ) -> str: - """Generate cache key for decimated data.""" - # Round to reasonable precision to improve cache hits - xlim_rounded = (round(float(xlim_raw[0]), 9), round(float(xlim_raw[1]), 9)) - - # Include trace_id in cache key for multi-trace support - trace_suffix = f"_t{trace_id}" if trace_id is not None else "" - - return f"{xlim_rounded}_{max_points}_{use_envelope}{trace_suffix}" - - def _manage_cache_size(self) -> None: - """Remove oldest cache entry if cache is full.""" - if len(self._cache) >= self._cache_max_size: - # Remove oldest entry (simple FIFO) - oldest_key = next(iter(self._cache)) - del self._cache[oldest_key] - - def clear_cache(self) -> None: - """Clear the decimation cache.""" - self._cache.clear() - # Do NOT clear _pre_decimated_envelopes here, as they are persistent for the full dataset - - def _decimate_data( - self, - t: np.ndarray, - x: np.ndarray, - max_points: int, - use_envelope: bool = False, - envelope_window_samples: Optional[int] = None, - return_envelope_min_max: bool = False, # New parameter - ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]: - """ - Unified decimation for time and multiple data arrays. - - Parameters - ---------- - t : np.ndarray - Time array. - x : np.ndarray - Signal array. - max_points : int, default=5000 - Maximum number of points to display. - use_envelope : bool, default=False - Whether to use envelope decimation for the signal array. - envelope_window_samples : Optional[int], default=None - Window size in samples for high-resolution envelope calculation. - return_envelope_min_max : bool, default=False - If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them. - If None, uses simple binning approach. - - Returns - ------- - Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]] - Decimated time, signal, signal min envelope, signal max envelope arrays. - """ - # If input arrays are empty, return empty arrays immediately - if len(t) == 0: - return ( - np.array([], dtype=np.float32), - np.array([], dtype=np.float32), - None, - None, - ) - - # If not using envelope, always return raw data for the view - if ( - not use_envelope and not return_envelope_min_max - ): # If not using envelope and not explicitly asking for min/max - return t, x, None, None # No min/max envelope for raw data - - # If using envelope and data is small enough, return raw data as envelope - if use_envelope and len(t) <= max_points and return_envelope_min_max: - return t, x, x, x # x,x for min/max when no decimation - - # Calculate step size for decimation based on max_points - step = max(1, len(t) // max_points) - - # For envelope mode, calculate adaptive envelope window based on data density - adaptive_envelope_window = None - if use_envelope and len(t) > max_points: - # Calculate envelope window based on how much we're decimating - # This ensures envelope resolution matches display capability - adaptive_envelope_window = max( - 1, step // 2 - ) # Half the step size for smoother envelope - logger.debug( - f"Calculated adaptive envelope window: {adaptive_envelope_window} samples (step={step})" - ) - - # Ensure step is not zero, and calculate number of bins - if step == 0: # Should not happen with max(1, ...) but as a safeguard - step = 1 - n_bins = len(t) // step - if ( - n_bins == 0 - ): # If data is too short for the calculated step, take at least one bin - n_bins = 1 - step = len(t) # Take all points in one bin - - # Ensure arrays are contiguous and correct dtype for Numba - t_contiguous = np.ascontiguousarray(t, dtype=np.float32) - x_contiguous = np.ascontiguousarray(x, dtype=np.float32) - - # Decimate time array using Numba-optimized function - t_decimated = _decimate_time_numba(t_contiguous, step, n_bins) - - # Decimate signal (x) using appropriate Numba-optimized function - x_min_envelope: Optional[np.ndarray] = None - x_max_envelope: Optional[np.ndarray] = None - - if use_envelope: # This block handles the decimation logic (mean or envelope) - if adaptive_envelope_window is not None and adaptive_envelope_window > 1: - logger.debug( - f"Using adaptive high-resolution envelope with window size {adaptive_envelope_window} samples" - ) - - # Use Numba-optimized high-resolution envelope decimation with adaptive window - x_decimated, x_min_envelope, x_max_envelope = ( - _decimate_envelope_highres_numba( - t_contiguous, - x_contiguous, - step, - n_bins, - adaptive_envelope_window, - ) - ) - - envelope_thickness = np.mean(x_max_envelope - x_min_envelope) - logger.debug( - f"Adaptive envelope thickness: mean={envelope_thickness:.3g}, min={np.min(x_max_envelope - x_min_envelope):.3g}, max={np.max(x_max_envelope - x_min_envelope):.3g}" - ) - else: - logger.debug("Using standard bin-based envelope") - - # Use Numba-optimized standard envelope decimation - x_decimated, x_min_envelope, x_max_envelope = ( - _decimate_envelope_standard_numba(x_contiguous, step, n_bins) - ) - - # If we are not returning min/max, then x_decimated should be the mean - # Otherwise, x_decimated is just the mean of the envelope for internal use - if not return_envelope_min_max: - x_decimated = (x_min_envelope + x_max_envelope) / 2 - else: # This block is now reached if use_envelope is False AND len(t) > max_points - logger.debug("Using mean decimation for single line") - - # Use Numba-optimized mean decimation - x_decimated = _decimate_mean_numba(x_contiguous, step, n_bins) - - # If return_envelope_min_max is False, ensure min/max are None - if not return_envelope_min_max: - x_min_envelope = None - x_max_envelope = None - - return t_decimated, x_decimated, x_min_envelope, x_max_envelope - - def pre_decimate_data( - self, - data_id: int, # Changed from trace_id to data_id to be more generic for custom lines - t: np.ndarray, - x: np.ndarray, - max_points: int, - envelope_window_samples: Optional[int] = None, # This parameter is now ignored - ) -> None: - """ - Pre-calculate decimated envelope data for the full dataset. - This is used for fast rendering in zoomed-out (envelope) mode. - - Parameters - ---------- - data_id : int - Unique identifier for this data set (e.g., trace_id or custom line ID). - t : np.ndarray - Time array (raw time in seconds). - x : np.ndarray - Signal array. - max_points : int - Maximum number of points for the pre-decimated data. - envelope_window_samples : Optional[int], default=None - Window size in samples for high-resolution envelope calculation. - This will primarily determine the bin size for pre-decimation. - """ - if len(t) <= max_points: - # For small datasets, just store the original data as the "pre-decimated" envelope - # (min/max will be the same as x) - self._pre_decimated_envelopes[data_id] = { - "t": t, - "x": x, # Store mean/center for consistency - "x_min": x, - "x_max": x, - } - logger.debug( - f"Data ID {data_id} is small enough, storing raw as pre-decimated envelope." - ) - return - - logger.debug( - f"Pre-decimating data for ID {data_id} to {max_points} points for envelope view." - ) - # Perform the decimation using the _decimate_data method - # We force use_envelope=True here for pre-decimation to capture min/max - # envelope_window_samples is now calculated automatically based on max_points - t_decimated, x_decimated, x_min, x_max = self._decimate_data( - t, - x, - max_points=max_points, - use_envelope=True, # Always pre-decimate with envelope - envelope_window_samples=None, # Let _decimate_data calculate adaptive window - return_envelope_min_max=True, # Pre-decimation always stores min/max - ) - - # Store pre-decimated envelope data - self._pre_decimated_envelopes[data_id] = { - "t": t_decimated, - "x": x_decimated, # This is the mean/center of the envelope - "x_min": x_min, - "x_max": x_max, - } - - logger.debug( - f"Pre-decimated envelope calculated for ID {data_id}: {len(t_decimated)} points." - ) - - def decimate_for_view( - self, - t_raw_full: np.ndarray, # Full resolution time array - x_raw_full: np.ndarray, # Full resolution signal array - xlim_raw: Tuple[np.float32, np.float32], - max_points: int, - use_envelope: bool = False, - data_id: Optional[int] = None, # Changed from trace_id to data_id - envelope_window_samples: Optional[int] = None, # This parameter is now ignored - mode_switch_threshold: Optional[ - float - ] = None, # New parameter for mode switching - return_envelope_min_max: bool = False, # New parameter - ) -> Tuple[ - np.ndarray, - np.ndarray, - Optional[np.ndarray], - Optional[np.ndarray], - ]: - """ - Intelligently decimate data for current view with optional envelope mode. - - Parameters - ---------- - t_raw_full : np.ndarray - Full resolution time array (raw time in seconds). - x_raw_full : np.ndarray - Full resolution signal array. - xlim_raw : Tuple[np.float32, np.float32] - Current x-axis limits in raw time (seconds). - max_points : int - Maximum number of points to display. - use_envelope : bool, default=False - Whether the current display mode is envelope. - data_id : Optional[int], default=None - Unique identifier for this data set (e.g., trace_id or custom line ID). - Used to retrieve pre-decimated envelope data. - envelope_window_samples : Optional[int], default=None - Window size in samples for high-resolution envelope calculation. - return_envelope_min_max : bool, default=False - If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them. - mode_switch_threshold : Optional[float], default=None - Time span threshold for switching between envelope and detail modes. - Used to decide whether to use pre-decimated envelope data. - - Returns - ------- - Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]] - Decimated time, signal, signal min envelope, signal max envelope arrays (all in raw time). - """ - logger.debug(f"=== DecimationManager.decimate_for_view data_id={data_id} ===") - logger.debug(f"xlim_raw: {xlim_raw}") - logger.debug(f"use_envelope (requested): {use_envelope}") - logger.debug(f"max_points: {max_points}") - logger.debug( - f"Input data range: t=[{np.min(t_raw_full):.6f}, {np.max(t_raw_full):.6f}], x=[{np.min(x_raw_full):.6f}, {np.max(x_raw_full):.6f}]" - ) - - # Ensure xlim_raw values are valid - if ( - not np.isfinite(xlim_raw[0]) - or not np.isfinite(xlim_raw[1]) - or xlim_raw[0] == xlim_raw[1] - ): - logger.warning( - f"Invalid xlim_raw values: {xlim_raw}. Using full data range." - ) - xlim_raw = (np.min(t_raw_full), np.max(t_raw_full)) - - # Ensure xlim_raw is in ascending order - if xlim_raw[0] > xlim_raw[1]: - logger.warning(f"xlim_raw values out of order: {xlim_raw}. Swapping.") - xlim_raw = (xlim_raw[1], xlim_raw[0]) - - # Calculate current view span - current_view_span = xlim_raw[1] - xlim_raw[0] - - # Check cache first - cache_key = self._get_cache_key( - xlim_raw, max_points, use_envelope, data_id - ) # Cache key doesn't need return_envelope_min_max - if cache_key in self._cache: - logger.debug(f"Using cached decimation for key: {cache_key}") - return self._cache[cache_key] - - # --- Strategy: Use pre-decimated envelope if in envelope mode and view is wide --- - if ( - use_envelope - and data_id is not None - and data_id in self._pre_decimated_envelopes - ): - pre_dec_data = self._pre_decimated_envelopes[data_id] - pre_dec_t = pre_dec_data["t"] - - if len(pre_dec_t) > 1: - pre_dec_span = pre_dec_t[-1] - pre_dec_t[0] - - # Calculate how much detail we would gain by re-decimating - # Find indices for current view in pre-decimated time - mask = (pre_dec_t >= xlim_raw[0]) & (pre_dec_t <= xlim_raw[1]) - pre_dec_points_in_view = np.sum(mask) - - # Estimate how many points we would get from dynamic decimation - t_view_mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1]) - raw_points_in_view = np.sum(t_view_mask) - potential_decimated_points = min(raw_points_in_view, max_points) - - # Use pre-decimated data only if: - # 1. Current view span is very large (> 2x mode_switch_threshold), AND - # 2. Pre-decimated data provides reasonable detail (> max_points/4), AND - # 3. We wouldn't gain much detail from re-decimating (< 2x improvement) - use_pre_decimated = ( - mode_switch_threshold is not None - and current_view_span >= 2 * mode_switch_threshold - and pre_dec_points_in_view > max_points // 4 - and potential_decimated_points < 2 * pre_dec_points_in_view - ) - - if use_pre_decimated and np.any(mask): - logger.debug( - f"Using pre-decimated data for ID {data_id} (envelope mode, very wide view, {pre_dec_points_in_view} points, return_envelope_min_max={return_envelope_min_max})." - ) - - # If we need min/max, return them. Otherwise, return None. - x_min_ret = ( - pre_dec_data["x_min"][mask] if return_envelope_min_max else None - ) - x_max_ret = ( - pre_dec_data["x_max"][mask] if return_envelope_min_max else None - ) - - result = ( - pre_dec_t[mask], - pre_dec_data["x"][mask], # Center of envelope - x_min_ret, - x_max_ret, - ) - self._manage_cache_size() - self._cache[cache_key] = result - return result - else: - logger.debug( - f"Re-decimating for better detail: view_span={current_view_span:.3e}, pre_dec_points={pre_dec_points_in_view}, potential_points={potential_decimated_points}" - ) - else: - logger.debug( - f"Pre-decimated data for ID {data_id} has only one point, falling back to dynamic decimation." - ) - else: - logger.debug( - f"Not using pre-decimated envelope for ID {data_id} (use_envelope={use_envelope}, data_id={data_id in self._pre_decimated_envelopes})." - ) - - # --- Fallback: Dynamic decimation from raw data --- - logger.debug("Performing dynamic decimation from raw data.") - - # ADDED DEBUG LOGS - logger.debug( - f" t_raw_full min/max: {t_raw_full.min():.6f}, {t_raw_full.max():.6f}" - ) - logger.debug(f" xlim_raw: {xlim_raw[0]:.6f}, {xlim_raw[1]:.6f}") - - # Find indices for current view in raw time - mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1]) - - # ADDED DEBUG LOG - logger.debug(f" Mask result: {np.sum(mask)} points selected.") - - if not np.any(mask): - logger.warning( - f"No data in view for xlim_raw: {xlim_raw}. Returning empty arrays." - ) - empty_result = ( - np.array([], dtype=np.float32), - np.array([], dtype=np.float32), - None, - None, - ) - # Cache empty result for this view - self._manage_cache_size() - self._cache[cache_key] = empty_result - return empty_result - - t_view = t_raw_full[mask] - x_view = x_raw_full[mask] - - # Add warning for large number of points in detail mode - if not use_envelope and len(t_view) > self.DETAIL_MODE_POINT_WARNING_THRESHOLD: - logger.warning( - f"Plotting {len(t_view)} points in detail mode. " - f"Performance may be affected. Consider zooming in further." - ) - - # Use unified decimation approach - # envelope_window_samples is now calculated automatically based on max_points and data density - result = self._decimate_data( - t_view, - x_view, - max_points=max_points, - use_envelope=use_envelope, # Use requested envelope mode for dynamic decimation - envelope_window_samples=None, # Let _decimate_data calculate adaptive window - return_envelope_min_max=return_envelope_min_max, # Pass through - ) - - # Cache the result (manage cache size) - self._manage_cache_size() - self._cache[cache_key] = result - - # Log the final result - t_result, x_result, x_min_result, x_max_result = result - logger.debug(f"Returning result: t len={len(t_result)}, x len={len(x_result)}") - logger.debug( - f"Result ranges: t=[{np.min(t_result) if len(t_result) > 0 else 'empty':.6f}, {np.max(t_result) if len(t_result) > 0 else 'empty':.6f}], x=[{np.min(x_result) if len(x_result) > 0 else 'empty':.6f}, {np.max(x_result) if len(x_result) > 0 else 'empty':.6f}]" - ) - logger.debug( - f"Envelope: x_min={'None' if x_min_result is None else f'len={len(x_min_result)}'}, x_max={'None' if x_max_result is None else f'len={len(x_max_result)}'}" - ) - - return result +from typing import Dict, Optional, Tuple +import warnings + +import numpy as np +from numba import njit + + +@njit +def _decimate_time_numba(t: np.ndarray, step: int, n_bins: int) -> np.ndarray: + """ + Numba-optimized time decimation using bin centers. + + Parameters + ---------- + t : np.ndarray + Input time array. + step : int + Step size for binning. + n_bins : int + Number of bins to create. + + Returns + ------- + np.ndarray + Decimated time array with center time of each bin. + """ + t_decimated = np.zeros(n_bins, dtype=np.float32) + + for i in range(n_bins): + start_idx = i * step + end_idx = min((i + 1) * step, len(t)) + center_idx = start_idx + (end_idx - start_idx) // 2 + t_decimated[i] = t[center_idx] + + return t_decimated + + +@njit +def _decimate_mean_numba(x: np.ndarray, step: int, n_bins: int) -> np.ndarray: + """ + Numba-optimized mean decimation. + + Parameters + ---------- + x : np.ndarray + Input signal array. + step : int + Step size for binning. + n_bins : int + Number of bins to create. + + Returns + ------- + np.ndarray + Decimated signal array with mean values. + """ + x_decimated = np.zeros(n_bins, dtype=np.float32) + + for i in range(n_bins): + start_idx = i * step + end_idx = min((i + 1) * step, len(x)) + + if end_idx > start_idx: + # Calculate mean manually for Numba compatibility + bin_sum = 0.0 + bin_count = end_idx - start_idx + for j in range(start_idx, end_idx): + bin_sum += x[j] + x_decimated[i] = bin_sum / bin_count + else: + x_decimated[i] = x[start_idx] if start_idx < len(x) else 0.0 + + return x_decimated + + +@njit +def _decimate_envelope_standard_numba( + x: np.ndarray, step: int, n_bins: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Numba-optimized standard envelope decimation. + + Parameters + ---------- + x : np.ndarray + Input signal array. + step : int + Step size for binning. + n_bins : int + Number of bins to create. + + Returns + ------- + Tuple[np.ndarray, np.ndarray, np.ndarray] + Decimated signal (mean), min envelope, max envelope arrays. + """ + x_decimated = np.zeros(n_bins, dtype=np.float32) + x_min_envelope = np.zeros(n_bins, dtype=np.float32) + x_max_envelope = np.zeros(n_bins, dtype=np.float32) + + for i in range(n_bins): + start_idx = i * step + end_idx = min((i + 1) * step, len(x)) + + if end_idx > start_idx: + # Find min and max manually for Numba compatibility + bin_min = x[start_idx] + bin_max = x[start_idx] + bin_sum = 0.0 + + for j in range(start_idx, end_idx): + val = x[j] + if val < bin_min: + bin_min = val + if val > bin_max: + bin_max = val + bin_sum += val + + x_min_envelope[i] = bin_min + x_max_envelope[i] = bin_max + x_decimated[i] = bin_sum / (end_idx - start_idx) + else: + fallback_val = x[start_idx] if start_idx < len(x) else 0.0 + x_min_envelope[i] = fallback_val + x_max_envelope[i] = fallback_val + x_decimated[i] = fallback_val + + return x_decimated, x_min_envelope, x_max_envelope + + +@njit +def _decimate_envelope_highres_numba( + t: np.ndarray, x: np.ndarray, step: int, n_bins: int, envelope_window_samples: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Numba-optimized high-resolution envelope decimation. + + Parameters + ---------- + t : np.ndarray + Input time array. + x : np.ndarray + Input signal array. + step : int + Step size for binning. + n_bins : int + Number of bins to create. + envelope_window_samples : int + Window size in samples for high-resolution envelope calculation. + + Returns + ------- + Tuple[np.ndarray, np.ndarray, np.ndarray] + Decimated signal (mean), min envelope, max envelope arrays. + """ + x_decimated = np.zeros(n_bins, dtype=np.float32) + x_min_envelope = np.zeros(n_bins, dtype=np.float32) + x_max_envelope = np.zeros(n_bins, dtype=np.float32) + + half_window = envelope_window_samples // 2 + + for i in range(n_bins): + start_idx = i * step + end_idx = min((i + 1) * step, len(t)) + bin_center = start_idx + (end_idx - start_idx) // 2 + + # Define window around bin center + window_start = max(0, bin_center - half_window) + window_end = min(len(x), bin_center + half_window) + + if window_end > window_start: + # Find min and max in window manually for Numba compatibility + window_min = x[window_start] + window_max = x[window_start] + + for j in range(window_start, window_end): + val = x[j] + if val < window_min: + window_min = val + if val > window_max: + window_max = val + + x_min_envelope[i] = window_min + x_max_envelope[i] = window_max + x_decimated[i] = (window_min + window_max) / 2.0 + else: + fallback_val = x[bin_center] if bin_center < len(x) else 0.0 + x_min_envelope[i] = fallback_val + x_max_envelope[i] = fallback_val + x_decimated[i] = fallback_val + + return x_decimated, x_min_envelope, x_max_envelope + + +class DecimationManager: + """ + Handles data decimation and caching for efficient plotting. + + Manages different decimation strategies and caches results to improve performance. + Pre-calculates decimated data at load time for faster zooming. + """ + + # Cache and performance constants + CACHE_MAX_SIZE = 10 + MIN_VISIBLE_RANGE_DEFAULT = 1e-6 # Default if no global noise is provided + # Threshold for warning about too many points in detail mode + DETAIL_MODE_POINT_WARNING_THRESHOLD = 100000 + + def __init__(self, cache_max_size: int = CACHE_MAX_SIZE): + """ + Initialise the decimation manager. + + Parameters + ---------- + cache_max_size : int, default=PlotConstants.CACHE_MAX_SIZE + Maximum number of cached decimation results. + """ + self._cache: Dict[str, Tuple[np.ndarray, ...]] = {} + self._cache_max_size = cache_max_size + # Stores pre-decimated envelope data for the full dataset for each trace/line + # Structure: {trace_id: {'t': np.ndarray, 'x_min': np.ndarray, 'x_max': np.ndarray, ...}} + self._pre_decimated_envelopes: Dict[int, Dict[str, np.ndarray]] = {} + + def _get_cache_key( + self, + xlim_raw: Tuple[np.float32, np.float32], + max_points: int, + use_envelope: bool, + trace_id: Optional[int] = None, + ) -> str: + """Generate cache key for decimated data.""" + # Round to reasonable precision to improve cache hits + xlim_rounded = (round(float(xlim_raw[0]), 9), round(float(xlim_raw[1]), 9)) + + # Include trace_id in cache key for multi-trace support + trace_suffix = f"_t{trace_id}" if trace_id is not None else "" + + return f"{xlim_rounded}_{max_points}_{use_envelope}{trace_suffix}" + + def _manage_cache_size(self) -> None: + """Remove oldest cache entry if cache is full.""" + if len(self._cache) >= self._cache_max_size: + # Remove oldest entry (simple FIFO) + oldest_key = next(iter(self._cache)) + del self._cache[oldest_key] + + def clear_cache(self) -> None: + """Clear the decimation cache.""" + self._cache.clear() + # Do NOT clear _pre_decimated_envelopes here, as they are persistent for the full dataset + + def _decimate_data( + self, + t: np.ndarray, + x: np.ndarray, + max_points: int, + use_envelope: bool = False, + envelope_window_samples: Optional[int] = None, + return_envelope_min_max: bool = False, # New parameter + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]: + """ + Unified decimation for time and multiple data arrays. + + Parameters + ---------- + t : np.ndarray + Time array. + x : np.ndarray + Signal array. + max_points : int, default=5000 + Maximum number of points to display. + use_envelope : bool, default=False + Whether to use envelope decimation for the signal array. + envelope_window_samples : Optional[int], default=None + Window size in samples for high-resolution envelope calculation. + return_envelope_min_max : bool, default=False + If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them. + If None, uses simple binning approach. + + Returns + ------- + Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]] + Decimated time, signal, signal min envelope, signal max envelope arrays. + """ + # If input arrays are empty, return empty arrays immediately + if len(t) == 0: + return ( + np.array([], dtype=np.float32), + np.array([], dtype=np.float32), + None, + None, + ) + + # If not using envelope, always return raw data for the view + if ( + not use_envelope and not return_envelope_min_max + ): # If not using envelope and not explicitly asking for min/max + return t, x, None, None # No min/max envelope for raw data + + # If using envelope and data is small enough, return raw data as envelope + if use_envelope and len(t) <= max_points and return_envelope_min_max: + return t, x, x, x # x,x for min/max when no decimation + + # Calculate step size for decimation based on max_points + step = max(1, len(t) // max_points) + + # For envelope mode, calculate adaptive envelope window based on data density + adaptive_envelope_window = None + if use_envelope and len(t) > max_points: + # Calculate envelope window based on how much we're decimating + # This ensures envelope resolution matches display capability + adaptive_envelope_window = max( + 1, step // 2 + ) # Half the step size for smoother envelope + logger.debug( + f"Calculated adaptive envelope window: {adaptive_envelope_window} samples (step={step})" + ) + + # Ensure step is not zero, and calculate number of bins + if step == 0: # Should not happen with max(1, ...) but as a safeguard + step = 1 + n_bins = len(t) // step + if ( + n_bins == 0 + ): # If data is too short for the calculated step, take at least one bin + n_bins = 1 + step = len(t) # Take all points in one bin + + # Ensure arrays are contiguous and correct dtype for Numba + t_contiguous = np.ascontiguousarray(t, dtype=np.float32) + x_contiguous = np.ascontiguousarray(x, dtype=np.float32) + + # Decimate time array using Numba-optimized function + t_decimated = _decimate_time_numba(t_contiguous, step, n_bins) + + # Decimate signal (x) using appropriate Numba-optimized function + x_min_envelope: Optional[np.ndarray] = None + x_max_envelope: Optional[np.ndarray] = None + + if use_envelope: # This block handles the decimation logic (mean or envelope) + if adaptive_envelope_window is not None and adaptive_envelope_window > 1: + # Use Numba-optimized high-resolution envelope decimation with adaptive window + x_decimated, x_min_envelope, x_max_envelope = ( + _decimate_envelope_highres_numba( + t_contiguous, + x_contiguous, + step, + n_bins, + adaptive_envelope_window, + ) + ) + else: + # Use Numba-optimized standard envelope decimation + x_decimated, x_min_envelope, x_max_envelope = ( + _decimate_envelope_standard_numba(x_contiguous, step, n_bins) + ) + + # If we are not returning min/max, then x_decimated should be the mean + # Otherwise, x_decimated is just the mean of the envelope for internal use + if not return_envelope_min_max: + x_decimated = (x_min_envelope + x_max_envelope) / 2 + else: # This block is now reached if use_envelope is False AND len(t) > max_points + # Use Numba-optimized mean decimation + x_decimated = _decimate_mean_numba(x_contiguous, step, n_bins) + + # If return_envelope_min_max is False, ensure min/max are None + if not return_envelope_min_max: + x_min_envelope = None + x_max_envelope = None + + return t_decimated, x_decimated, x_min_envelope, x_max_envelope + + def pre_decimate_data( + self, + data_id: int, # Changed from trace_id to data_id to be more generic for custom lines + t: np.ndarray, + x: np.ndarray, + max_points: int, + envelope_window_samples: Optional[int] = None, # This parameter is now ignored + ) -> None: + """ + Pre-calculate decimated envelope data for the full dataset. + This is used for fast rendering in zoomed-out (envelope) mode. + + Parameters + ---------- + data_id : int + Unique identifier for this data set (e.g., trace_id or custom line ID). + t : np.ndarray + Time array (raw time in seconds). + x : np.ndarray + Signal array. + max_points : int + Maximum number of points for the pre-decimated data. + envelope_window_samples : Optional[int], default=None + Window size in samples for high-resolution envelope calculation. + This will primarily determine the bin size for pre-decimation. + """ + if len(t) <= max_points: + # For small datasets, just store the original data as the "pre-decimated" envelope + # (min/max will be the same as x) + self._pre_decimated_envelopes[data_id] = { + "t": t, + "x": x, # Store mean/center for consistency + "x_min": x, + "x_max": x, + } + return + # Perform the decimation using the _decimate_data method + # We force use_envelope=True here for pre-decimation to capture min/max + # envelope_window_samples is now calculated automatically based on max_points + t_decimated, x_decimated, x_min, x_max = self._decimate_data( + t, + x, + max_points=max_points, + use_envelope=True, # Always pre-decimate with envelope + envelope_window_samples=None, # Let _decimate_data calculate adaptive window + return_envelope_min_max=True, # Pre-decimation always stores min/max + ) + + # Store pre-decimated envelope data + self._pre_decimated_envelopes[data_id] = { + "t": t_decimated, + "x": x_decimated, # This is the mean/center of the envelope + "x_min": x_min, + "x_max": x_max, + } + + def decimate_for_view( + self, + t_raw_full: np.ndarray, # Full resolution time array + x_raw_full: np.ndarray, # Full resolution signal array + xlim_raw: Tuple[np.float32, np.float32], + max_points: int, + use_envelope: bool = False, + data_id: Optional[int] = None, # Changed from trace_id to data_id + envelope_window_samples: Optional[int] = None, # This parameter is now ignored + mode_switch_threshold: Optional[ + float + ] = None, # New parameter for mode switching + return_envelope_min_max: bool = False, # New parameter + ) -> Tuple[ + np.ndarray, + np.ndarray, + Optional[np.ndarray], + Optional[np.ndarray], + ]: + """ + Intelligently decimate data for current view with optional envelope mode. + + Parameters + ---------- + t_raw_full : np.ndarray + Full resolution time array (raw time in seconds). + x_raw_full : np.ndarray + Full resolution signal array. + xlim_raw : Tuple[np.float32, np.float32] + Current x-axis limits in raw time (seconds). + max_points : int + Maximum number of points to display. + use_envelope : bool, default=False + Whether the current display mode is envelope. + data_id : Optional[int], default=None + Unique identifier for this data set (e.g., trace_id or custom line ID). + Used to retrieve pre-decimated envelope data. + envelope_window_samples : Optional[int], default=None + Window size in samples for high-resolution envelope calculation. + return_envelope_min_max : bool, default=False + If True, returns x_min_envelope and x_max_envelope. Otherwise, returns None for them. + mode_switch_threshold : Optional[float], default=None + Time span threshold for switching between envelope and detail modes. + Used to decide whether to use pre-decimated envelope data. + + Returns + ------- + Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]] + Decimated time, signal, signal min envelope, signal max envelope arrays (all in raw time). + """ + # Ensure xlim_raw values are valid + if ( + not np.isfinite(xlim_raw[0]) + or not np.isfinite(xlim_raw[1]) + or xlim_raw[0] == xlim_raw[1] + ): + warnings.warn( + f"Invalid xlim_raw values: {xlim_raw}. Using full data range.", RuntimeWarning + ) + xlim_raw = (np.min(t_raw_full), np.max(t_raw_full)) + + # Ensure xlim_raw is in ascending order + if xlim_raw[0] > xlim_raw[1]: + warnings.warn(f"xlim_raw values out of order: {xlim_raw}. Swapping.", RuntimeWarning) + xlim_raw = (xlim_raw[1], xlim_raw[0]) + + # Calculate current view span + current_view_span = xlim_raw[1] - xlim_raw[0] + + # Check cache first + cache_key = self._get_cache_key( + xlim_raw, max_points, use_envelope, data_id + ) # Cache key doesn't need return_envelope_min_max + if cache_key in self._cache: + logger.debug(f"Using cached decimation for key: {cache_key}") + return self._cache[cache_key] + + # --- Strategy: Use pre-decimated envelope if in envelope mode and view is wide --- + if ( + use_envelope + and data_id is not None + and data_id in self._pre_decimated_envelopes + ): + pre_dec_data = self._pre_decimated_envelopes[data_id] + pre_dec_t = pre_dec_data["t"] + + if len(pre_dec_t) > 1: + pre_dec_span = pre_dec_t[-1] - pre_dec_t[0] + + # Calculate how much detail we would gain by re-decimating + # Find indices for current view in pre-decimated time + mask = (pre_dec_t >= xlim_raw[0]) & (pre_dec_t <= xlim_raw[1]) + pre_dec_points_in_view = np.sum(mask) + + # Estimate how many points we would get from dynamic decimation + t_view_mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1]) + raw_points_in_view = np.sum(t_view_mask) + potential_decimated_points = min(raw_points_in_view, max_points) + + # Use pre-decimated data only if: + # 1. Current view span is very large (> 2x mode_switch_threshold), AND + # 2. Pre-decimated data provides reasonable detail (> max_points/4), AND + # 3. We wouldn't gain much detail from re-decimating (< 2x improvement) + use_pre_decimated = ( + mode_switch_threshold is not None + and current_view_span >= 2 * mode_switch_threshold + and pre_dec_points_in_view > max_points // 4 + and potential_decimated_points < 2 * pre_dec_points_in_view + ) + + if use_pre_decimated and np.any(mask): + # If we need min/max, return them. Otherwise, return None. + x_min_ret = ( + pre_dec_data["x_min"][mask] if return_envelope_min_max else None + ) + x_max_ret = ( + pre_dec_data["x_max"][mask] if return_envelope_min_max else None + ) + + result = ( + pre_dec_t[mask], + pre_dec_data["x"][mask], # Center of envelope + x_min_ret, + x_max_ret, + ) + self._manage_cache_size() + self._cache[cache_key] = result + return result + + # --- Fallback: Dynamic decimation from raw data --- + # Find indices for current view in raw time + mask = (t_raw_full >= xlim_raw[0]) & (t_raw_full <= xlim_raw[1]) + + if not np.any(mask): + warnings.warn( + f"No data in view for xlim_raw: {xlim_raw}. Returning empty arrays.", RuntimeWarning + ) + empty_result = ( + np.array([], dtype=np.float32), + np.array([], dtype=np.float32), + None, + None, + ) + # Cache empty result for this view + self._manage_cache_size() + self._cache[cache_key] = empty_result + return empty_result + + t_view = t_raw_full[mask] + x_view = x_raw_full[mask] + + # Add warning for large number of points in detail mode + if not use_envelope and len(t_view) > self.DETAIL_MODE_POINT_WARNING_THRESHOLD: + warnings.warn( + f"Plotting {len(t_view)} points in detail mode. " + f"Performance may be affected. Consider zooming in further.", UserWarning + ) + + # Use unified decimation approach + # envelope_window_samples is now calculated automatically based on max_points and data density + result = self._decimate_data( + t_view, + x_view, + max_points=max_points, + use_envelope=use_envelope, # Use requested envelope mode for dynamic decimation + envelope_window_samples=None, # Let _decimate_data calculate adaptive window + return_envelope_min_max=return_envelope_min_max, # Pass through + ) + + # Cache the result (manage cache size) + self._manage_cache_size() + self._cache[cache_key] = result + + return result diff --git a/src/scopekit/display_state.py b/src/scopekit/display_state.py index b66b556..793bc4c 100644 --- a/src/scopekit/display_state.py +++ b/src/scopekit/display_state.py @@ -1,294 +1,286 @@ -from typing import Optional, Tuple - -import numpy as np -from loguru import logger -from matplotlib.ticker import FuncFormatter - -# Time unit boundaries (hysteresis) -PICOSECOND_BOUNDARY = 0.8e-9 -NANOSECOND_BOUNDARY = 0.8e-6 -MICROSECOND_BOUNDARY = 0.8e-3 -MILLISECOND_BOUNDARY = 0.8 - -# Offset thresholds -OFFSET_SPAN_MULTIPLIER = 10 -OFFSET_TIME_THRESHOLD = 1e-3 # 1ms - - -def _get_optimal_time_unit_and_scale( - time_array_or_span: np.ndarray | float, -) -> Tuple[str, np.float32]: - """ - Determines the optimal time unit and scaling factor for a given time array or span. - - Uses hysteresis boundaries to prevent oscillation near unit boundaries. - - Parameters - ---------- - time_array_or_span : np.ndarray | float - A NumPy array representing time in seconds, or a single float representing a time span in seconds. - - Returns - ------- - Tuple[str, np.float32] - A tuple containing the time unit string (e.g., "s", "ms", "us", "ns") - and the corresponding scaling factor (e.0, 1e3, 1e6, 1e9). - """ - if isinstance(time_array_or_span, np.ndarray): - # Handle empty array case to prevent errors - if time_array_or_span.size == 0: - return "s", np.float32(1.0) # Default to seconds if no data - max_val = np.max(time_array_or_span) - else: # Assume it's a float representing a span - max_val = time_array_or_span - - # Use hysteresis boundaries to prevent oscillation near unit boundaries - if max_val < PICOSECOND_BOUNDARY: - return "ps", np.float32(1e12) - elif max_val < NANOSECOND_BOUNDARY: - return "ns", np.float32(1e9) - elif max_val < MICROSECOND_BOUNDARY: - return "us", np.float32(1e6) - elif max_val < MILLISECOND_BOUNDARY: - return "ms", np.float32(1e3) - else: - return "s", np.float32(1.0) - - -def _determine_offset_display_params( - xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32 -) -> Tuple[str, np.float32, Optional[np.float32], Optional[str]]: - """ - Determine display parameters including offset for optimal readability. - - Parameters - ---------- - xlim_raw : Tuple[np.float32, np.float32] - Current x-axis limits in raw time (seconds). - time_span_raw : np.float32 - Time span of current view in seconds. - - Returns - ------- - Tuple[str, np.float32, Optional[np.float32], Optional[str]] - Display unit, display scale, offset time (raw seconds), offset unit string. - If no offset is needed, offset_time and offset_unit will be None. - """ - # Get optimal unit for the time span - display_unit, display_scale = _get_optimal_time_unit_and_scale(time_span_raw) - - # Determine if we need an offset - # Use offset if the start time is significantly larger than the span - xlim_start = xlim_raw[0] - - # Use offset if start time is more than threshold multiplier of the span, and span is small - use_offset = (abs(xlim_start) > OFFSET_SPAN_MULTIPLIER * time_span_raw) and ( - time_span_raw < np.float32(OFFSET_TIME_THRESHOLD) - ) - - if use_offset: - # Choose appropriate unit for the offset - if abs(xlim_start) >= np.float32(1.0): # >= 1 second - offset_unit = "s" - offset_scale = np.float32(1.0) - elif abs(xlim_start) >= np.float32(1e-3): # >= 1 millisecond - offset_unit = "ms" - offset_scale = np.float32(1e3) - elif abs(xlim_start) >= np.float32(1e-6): # >= 1 microsecond - offset_unit = "us" - offset_scale = np.float32(1e6) - else: - offset_unit = "ns" - offset_scale = np.float32(1e9) - - return display_unit, display_scale, xlim_start, offset_unit - else: - return display_unit, display_scale, None, None - - -def _create_time_formatter( - offset_time_raw: Optional[np.float32], display_scale: np.float32 -) -> FuncFormatter: - """ - Create a FuncFormatter for time axis tick labels. - - Parameters - ---------- - offset_time_raw : Optional[np.float32] - Offset time in raw seconds. If None, no offset is applied. - display_scale : np.float32 - Scale factor for display units. - - Returns - ------- - FuncFormatter - Matplotlib formatter for tick labels. - """ - - def formatter(x, pos): - # x is already in display units (relative to offset if applicable) - # Format with appropriate precision based on scale - if display_scale >= np.float32(1e9): # nanoseconds or smaller - return f"{x:.0f}" - elif display_scale >= np.float32(1e6): # microseconds - return f"{x:.0f}" - elif display_scale >= np.float32(1e3): # milliseconds - return f"{x:.1f}" - else: # seconds - return f"{x:.3f}" - - return FuncFormatter(formatter) - - -class DisplayState: - """ - Manages display state and mode switching logic. - - Centralises state management to reduce complexity and flag interactions. - """ - - def __init__( - self, - original_time_unit: str, - original_time_scale: np.float32, - envelope_limit: np.float32, - ): - """ - Initialise display state. - - Parameters - ---------- - original_time_unit : str - Original time unit string. - original_time_scale : np.float32 - Original time scaling factor. - envelope_limit : np.float32 - Time span threshold for envelope mode. - """ - # Time scaling - self.original_time_unit = original_time_unit - self.original_time_scale = original_time_scale - self.current_time_unit = original_time_unit - self.current_time_scale = original_time_scale - - # Display mode - self.current_mode: Optional[str] = None - self.envelope_limit = envelope_limit - - # Offset parameters - self.offset_time_raw: Optional[np.float32] = None - self.offset_unit: Optional[str] = None - - # Single state flag - simplified - self._updating = False - - def get_time_unit_and_scale(self, t: np.ndarray) -> Tuple[str, np.float32]: - """ - Automatically select appropriate time unit and scale for plotting. - - Parameters - ---------- - t : np.ndarray - Time array. - - Returns - ------- - Tuple[str, np.float32] - Time unit string and scaling factor. - """ - # Delegate to the new utility function - return _get_optimal_time_unit_and_scale(t) - - def update_display_params( - self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32 - ) -> bool: - """ - Update display parameters including offset based on current view. - - Parameters - ---------- - xlim_raw : Tuple[np.float32, np.float32] - Current x-axis limits in raw time (seconds). - time_span_raw : np.float32 - Time span of current view in seconds. - - Returns - ------- - bool - True if display parameters changed, False otherwise. - """ - display_unit, display_scale, offset_time, offset_unit = ( - _determine_offset_display_params(xlim_raw, time_span_raw) - ) - - # Check if anything changed - params_changed = ( - display_unit != self.current_time_unit - or display_scale != self.current_time_scale - or offset_time != self.offset_time_raw - or offset_unit != self.offset_unit - ) - - if params_changed: - logger.info( - f"Display params changed: unit={display_unit}, scale={display_scale:.1e}, offset={offset_time}, offset_unit={offset_unit}" - ) - self.current_time_unit = display_unit - self.current_time_scale = display_scale - self.offset_time_raw = offset_time - self.offset_unit = offset_unit - return True - - return False - - def should_use_envelope(self, time_span_raw: np.float32) -> bool: - """Determine if envelope mode should be used based on time span.""" - return time_span_raw > self.envelope_limit - - def should_show_thresholds(self, time_span_raw: np.float32) -> bool: - """Determine if threshold lines should be shown based on time span.""" - return time_span_raw < self.envelope_limit - - def update_time_scale(self, time_span_raw: np.float32) -> bool: - """ - Update time scale based on current view span. - - Returns True if scale changed, False otherwise. - """ - # Delegate to the new utility function for span - new_unit, new_scale = _get_optimal_time_unit_and_scale(time_span_raw) - - if new_scale != self.current_time_scale: - logger.info( - f"Time scale changed from {self.current_time_unit} ({self.current_time_scale:.1e}) to {new_unit} ({new_scale:.1e})" - ) - self.current_time_unit = new_unit - self.current_time_scale = new_scale - return True - - return False - - def reset_to_original_scale(self) -> None: - """Reset time scale to original values.""" - self.current_time_unit = self.original_time_unit - self.current_time_scale = self.original_time_scale - logger.info( - f"Reset to original scale: {self.current_time_unit} ({self.current_time_scale:.1e})" - ) - - def reset_to_initial_state(self) -> None: - """Reset all display parameters to initial values.""" - self.current_time_unit = self.original_time_unit - self.current_time_scale = self.original_time_scale - self.offset_time_raw = None - self.offset_unit = None - self.current_mode = None - self._updating = False - - def set_updating(self, value: bool = True) -> None: - """Set updating state to prevent recursion.""" - self._updating = value - - def is_updating(self) -> bool: - """Check if currently updating.""" - return self._updating +from typing import Optional, Tuple + +import warnings + +import numpy as np +from matplotlib.ticker import FuncFormatter + +# Time unit boundaries (hysteresis) +PICOSECOND_BOUNDARY = 0.8e-9 +NANOSECOND_BOUNDARY = 0.8e-6 +MICROSECOND_BOUNDARY = 0.8e-3 +MILLISECOND_BOUNDARY = 0.8 + +# Offset thresholds +OFFSET_SPAN_MULTIPLIER = 10 +OFFSET_TIME_THRESHOLD = 1e-3 # 1ms + + +def _get_optimal_time_unit_and_scale( + time_array_or_span: np.ndarray | float, +) -> Tuple[str, np.float32]: + """ + Determines the optimal time unit and scaling factor for a given time array or span. + + Uses hysteresis boundaries to prevent oscillation near unit boundaries. + + Parameters + ---------- + time_array_or_span : np.ndarray | float + A NumPy array representing time in seconds, or a single float representing a time span in seconds. + + Returns + ------- + Tuple[str, np.float32] + A tuple containing the time unit string (e.g., "s", "ms", "us", "ns") + and the corresponding scaling factor (e.0, 1e3, 1e6, 1e9). + """ + if isinstance(time_array_or_span, np.ndarray): + # Handle empty array case to prevent errors + if time_array_or_span.size == 0: + return "s", np.float32(1.0) # Default to seconds if no data + max_val = np.max(time_array_or_span) + else: # Assume it's a float representing a span + max_val = time_array_or_span + + # Use hysteresis boundaries to prevent oscillation near unit boundaries + if max_val < PICOSECOND_BOUNDARY: + return "ps", np.float32(1e12) + elif max_val < NANOSECOND_BOUNDARY: + return "ns", np.float32(1e9) + elif max_val < MICROSECOND_BOUNDARY: + return "us", np.float32(1e6) + elif max_val < MILLISECOND_BOUNDARY: + return "ms", np.float32(1e3) + else: + return "s", np.float32(1.0) + + +def _determine_offset_display_params( + xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32 +) -> Tuple[str, np.float32, Optional[np.float32], Optional[str]]: + """ + Determine display parameters including offset for optimal readability. + + Parameters + ---------- + xlim_raw : Tuple[np.float32, np.float32] + Current x-axis limits in raw time (seconds). + time_span_raw : np.float32 + Time span of current view in seconds. + + Returns + ------- + Tuple[str, np.float32, Optional[np.float32], Optional[str]] + Display unit, display scale, offset time (raw seconds), offset unit string. + If no offset is needed, offset_time and offset_unit will be None. + """ + # Get optimal unit for the time span + display_unit, display_scale = _get_optimal_time_unit_and_scale(time_span_raw) + + # Determine if we need an offset + # Use offset if the start time is significantly larger than the span + xlim_start = xlim_raw[0] + + # Use offset if start time is more than threshold multiplier of the span, and span is small + use_offset = (abs(xlim_start) > OFFSET_SPAN_MULTIPLIER * time_span_raw) and ( + time_span_raw < np.float32(OFFSET_TIME_THRESHOLD) + ) + + if use_offset: + # Choose appropriate unit for the offset + if abs(xlim_start) >= np.float32(1.0): # >= 1 second + offset_unit = "s" + offset_scale = np.float32(1.0) + elif abs(xlim_start) >= np.float32(1e-3): # >= 1 millisecond + offset_unit = "ms" + offset_scale = np.float32(1e3) + elif abs(xlim_start) >= np.float32(1e-6): # >= 1 microsecond + offset_unit = "us" + offset_scale = np.float32(1e6) + else: + offset_unit = "ns" + offset_scale = np.float32(1e9) + + return display_unit, display_scale, xlim_start, offset_unit + else: + return display_unit, display_scale, None, None + + +def _create_time_formatter( + offset_time_raw: Optional[np.float32], display_scale: np.float32 +) -> FuncFormatter: + """ + Create a FuncFormatter for time axis tick labels. + + Parameters + ---------- + offset_time_raw : Optional[np.float32] + Offset time in raw seconds. If None, no offset is applied. + display_scale : np.float32 + Scale factor for display units. + + Returns + ------- + FuncFormatter + Matplotlib formatter for tick labels. + """ + + def formatter(x, pos): + # x is already in display units (relative to offset if applicable) + # Format with appropriate precision based on scale + if display_scale >= np.float32(1e9): # nanoseconds or smaller + return f"{x:.0f}" + elif display_scale >= np.float32(1e6): # microseconds + return f"{x:.0f}" + elif display_scale >= np.float32(1e3): # milliseconds + return f"{x:.1f}" + else: # seconds + return f"{x:.3f}" + + return FuncFormatter(formatter) + + +class DisplayState: + """ + Manages display state and mode switching logic. + + Centralises state management to reduce complexity and flag interactions. + """ + + def __init__( + self, + original_time_unit: str, + original_time_scale: np.float32, + envelope_limit: np.float32, + ): + """ + Initialise display state. + + Parameters + ---------- + original_time_unit : str + Original time unit string. + original_time_scale : np.float32 + Original time scaling factor. + envelope_limit : np.float32 + Time span threshold for envelope mode. + """ + # Time scaling + self.original_time_unit = original_time_unit + self.original_time_scale = original_time_scale + self.current_time_unit = original_time_unit + self.current_time_scale = original_time_scale + + # Display mode + self.current_mode: Optional[str] = None + self.envelope_limit = envelope_limit + + # Offset parameters + self.offset_time_raw: Optional[np.float32] = None + self.offset_unit: Optional[str] = None + + # Single state flag - simplified + self._updating = False + + def get_time_unit_and_scale(self, t: np.ndarray) -> Tuple[str, np.float32]: + """ + Automatically select appropriate time unit and scale for plotting. + + Parameters + ---------- + t : np.ndarray + Time array. + + Returns + ------- + Tuple[str, np.float32] + Time unit string and scaling factor. + """ + # Delegate to the new utility function + return _get_optimal_time_unit_and_scale(t) + + def update_display_params( + self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32 + ) -> bool: + """ + Update display parameters including offset based on current view. + + Parameters + ---------- + xlim_raw : Tuple[np.float32, np.float32] + Current x-axis limits in raw time (seconds). + time_span_raw : np.float32 + Time span of current view in seconds. + + Returns + ------- + bool + True if display parameters changed, False otherwise. + """ + display_unit, display_scale, offset_time, offset_unit = ( + _determine_offset_display_params(xlim_raw, time_span_raw) + ) + + # Check if anything changed + params_changed = ( + display_unit != self.current_time_unit + or display_scale != self.current_time_scale + or offset_time != self.offset_time_raw + or offset_unit != self.offset_unit + ) + + if params_changed: + self.current_time_unit = display_unit + self.current_time_scale = display_scale + self.offset_time_raw = offset_time + self.offset_unit = offset_unit + return True + + return False + + def should_use_envelope(self, time_span_raw: np.float32) -> bool: + """Determine if envelope mode should be used based on time span.""" + return time_span_raw > self.envelope_limit + + def should_show_thresholds(self, time_span_raw: np.float32) -> bool: + """Determine if threshold lines should be shown based on time span.""" + return time_span_raw < self.envelope_limit + + def update_time_scale(self, time_span_raw: np.float32) -> bool: + """ + Update time scale based on current view span. + + Returns True if scale changed, False otherwise. + """ + # Delegate to the new utility function for span + new_unit, new_scale = _get_optimal_time_unit_and_scale(time_span_raw) + + if new_scale != self.current_time_scale: + self.current_time_unit = new_unit + self.current_time_scale = new_scale + return True + + return False + + def reset_to_original_scale(self) -> None: + """Reset time scale to original values.""" + self.current_time_unit = self.original_time_unit + self.current_time_scale = self.original_time_scale + + def reset_to_initial_state(self) -> None: + """Reset all display parameters to initial values.""" + self.current_time_unit = self.original_time_unit + self.current_time_scale = self.original_time_scale + self.offset_time_raw = None + self.offset_unit = None + self.current_mode = None + self._updating = False + + def set_updating(self, value: bool = True) -> None: + """Set updating state to prevent recursion.""" + self._updating = value + + def is_updating(self) -> bool: + """Check if currently updating.""" + return self._updating diff --git a/src/scopekit/plot.py b/src/scopekit/plot.py index 83c10bc..f405432 100644 --- a/src/scopekit/plot.py +++ b/src/scopekit/plot.py @@ -1,1829 +1,1621 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import matplotlib as mpl -import matplotlib.pyplot as plt -import numpy as np -from loguru import logger -from matplotlib.ticker import MultipleLocator - -from .coordinate_manager import CoordinateManager -from .data_manager import TimeSeriesDataManager -from .decimation import DecimationManager -from .display_state import ( - DisplayState, - _create_time_formatter, - _get_optimal_time_unit_and_scale, -) - - -class OscilloscopePlot: - """ - General-purpose plotting class for time-series data with zoom and decimation. - - Uses separate managers for data, decimation, and state to reduce complexity. - Supports different visualization elements (lines, envelopes, ribbons, regions) - that can be displayed in different modes (envelope when zoomed out, detail when zoomed in). - """ - - # Mode constants - MODE_ENVELOPE = 1 # Zoomed out mode - MODE_DETAIL = 2 # Zoomed in mode - MODE_BOTH = 3 # Both modes - - # Default styling constants - DEFAULT_MAX_PLOT_POINTS = 10000 - DEFAULT_MODE_SWITCH_THRESHOLD = 10e-3 # 10 ms - DEFAULT_MIN_Y_RANGE_DEFAULT = 1e-9 # Default minimum Y-axis range (e.g., 1 nV) - DEFAULT_Y_MARGIN_FRACTION = 0.15 - DEFAULT_SIGNAL_LINE_WIDTH = 1.0 - DEFAULT_SIGNAL_ALPHA = 0.75 - DEFAULT_ENVELOPE_ALPHA = 0.75 - DEFAULT_REGION_ALPHA = 0.4 - DEFAULT_REGION_ZORDER = -5 - - def __init__( - self, - t: Union[np.ndarray, List[np.ndarray]], - x: Union[np.ndarray, List[np.ndarray]], - name: Union[str, List[str]] = "Waveform", - trace_colors: Optional[List[str]] = None, - # Core display parameters - max_plot_points: int = DEFAULT_MAX_PLOT_POINTS, - mode_switch_threshold: float = DEFAULT_MODE_SWITCH_THRESHOLD, - min_y_range: Optional[float] = None, # New parameter for minimum Y-axis range - y_margin_fraction: float = DEFAULT_Y_MARGIN_FRACTION, - signal_line_width: float = DEFAULT_SIGNAL_LINE_WIDTH, - signal_alpha: float = DEFAULT_SIGNAL_ALPHA, - envelope_alpha: float = DEFAULT_ENVELOPE_ALPHA, - region_alpha: float = DEFAULT_REGION_ALPHA, - region_zorder: int = DEFAULT_REGION_ZORDER, - envelope_window_samples: Optional[int] = None, - ): - """ - Initialize the OscilloscopePlot with time series data. - - Parameters - ---------- - t : Union[np.ndarray, List[np.ndarray]] - Time array(s) (raw time in seconds). Can be a single array shared by all traces - or a list of arrays, one per trace. - x : Union[np.ndarray, List[np.ndarray]] - Signal array(s). If t is a single array, x can be a 2D array (traces x samples) - or a list of 1D arrays. If t is a list, x must be a list of equal length. - name : Union[str, List[str]], default="Waveform" - Name(s) for plot title. Can be a single string or a list of strings. - trace_colors : Optional[List[str]], default=None - Colors for each trace. If None, default colors will be used. - max_plot_points : int, default=10000 - Maximum number of points to display on the plot. Data will be decimated if it exceeds this. - mode_switch_threshold : float, default=10e-3 - Time span (in seconds) above which the plot switches to envelope mode. - min_y_range : Optional[float], default=None - Minimum Y-axis range to enforce. If None, a default small value is used. - y_margin_fraction : float, default=0.05 - Fraction of data range to add as margin to Y-axis limits. - signal_line_width : float, default=1.0 - Line width for the raw signal plot. - signal_alpha : float, default=0.75 - Alpha (transparency) for the raw signal plot. - envelope_alpha : float, default=1.0 - Alpha (transparency) for the envelope fill. - region_alpha : float, default=0.4 - Alpha (transparency) for region highlight fills. - region_zorder : int, default=-5 - Z-order for region highlight fills (lower means further back). - envelope_window_samples : Optional[int], default=None - DEPRECATED: Window size in samples for envelope calculation. - Envelope window is now calculated automatically based on max_plot_points and zoom level. - This parameter is ignored but kept for backward compatibility. - """ - # Store styling parameters directly as instance attributes - self.max_plot_points = max_plot_points - self.mode_switch_threshold = np.float32(mode_switch_threshold) - self.min_y_range = ( - np.float32(min_y_range) - if min_y_range is not None - else self.DEFAULT_MIN_Y_RANGE_DEFAULT - ) - self.y_margin_fraction = np.float32(y_margin_fraction) - self.signal_line_width = signal_line_width - self.signal_alpha = signal_alpha - self.envelope_alpha = envelope_alpha - self.region_alpha = region_alpha - self.region_zorder = region_zorder - # envelope_window_samples is now deprecated - envelope window is calculated automatically - # Keep the parameter for backward compatibility but don't use it - if envelope_window_samples is not None: - logger.warning( - "envelope_window_samples parameter is deprecated. Envelope window is now calculated automatically based on zoom level." - ) - - # Initialize managers - self.data = TimeSeriesDataManager(t, x, name, trace_colors) - self.decimator = DecimationManager() - - # Pre-decimate main signal data for envelope view - for i in range(self.data.num_traces): - self.decimator.pre_decimate_data( - data_id=i, # Use trace_idx as data_id - t=self.data.t_arrays[i], - x=self.data.x_arrays[i], - max_points=self.max_plot_points, - envelope_window_samples=None, # Envelope window calculated automatically - ) - - # Initialize display state using the first trace's time array - initial_time_unit, initial_time_scale = _get_optimal_time_unit_and_scale( - self.data.t_arrays[0] - ) - self.state = DisplayState( - initial_time_unit, initial_time_scale, self.mode_switch_threshold - ) - - # Initialize matplotlib figure and axes to None - self.fig: Optional[mpl.figure.Figure] = None - self.ax: Optional[mpl.axes.Axes] = None - - # Store visualization elements for each trace - self._signal_lines: List[mpl.lines.Line2D] = [] - self._envelope_fills: List[Optional[mpl.collections.PolyCollection]] = [ - None - ] * self.data.num_traces - - # Visualization elements with mode control (definitions, not plot objects) - self._lines: List[List[Dict[str, Any]]] = [ - [] for _ in range(self.data.num_traces) - ] - self._ribbons: List[List[Dict[str, Any]]] = [ - [] for _ in range(self.data.num_traces) - ] - self._regions: List[List[Dict[str, Any]]] = [ - [] for _ in range(self.data.num_traces) - ] - self._envelopes: List[List[Dict[str, Any]]] = [ - [] for _ in range(self.data.num_traces) - ] - - # Line objects for each trace (will be populated as needed during rendering) - self._line_objects: List[List[mpl.artist.Artist]] = [ - [] for _ in range(self.data.num_traces) - ] # Changed type hint to Artist - self._ribbon_objects: List[List[mpl.collections.PolyCollection]] = [ - [] for _ in range(self.data.num_traces) - ] - self._region_objects: List[List[mpl.collections.PolyCollection]] = [ - [] for _ in range(self.data.num_traces) - ] - - # Store current plot data for access by other methods - self._current_plot_data = {} - - # Initialize coordinate manager - self.coord_manager = CoordinateManager(self.state) - - # Store initial view for home button (using global time range) - t_start, t_end = self.data.get_global_time_range() - self._initial_xlim_raw = (t_start, t_end) - - # Legend state for optimization - self._current_legend_handles: List[mpl.artist.Artist] = [] - self._current_legend_labels: List[str] = [] - self._legend: Optional[mpl.legend.Legend] = None - - # Track last mode for each trace to optimize element updates - self._last_mode: Dict[int, Optional[int]] = { - i: None for i in range(self.data.num_traces) - } - - # Store original toolbar methods for restoration - self._original_home = None - self._original_push_current = None - - def save(self, filepath: str) -> None: - """ - Save the current plot to a file. - - Parameters - ---------- - filepath : str - Path to save the plot image. - """ - if self.fig is None or self.ax is None: - raise RuntimeError("Plot has not been initialized yet.") - self.fig.savefig(filepath) - logger.info(f"Plot saved to {filepath}") - - def add_line( - self, - t: Union[np.ndarray, List[np.ndarray]], - data: Union[np.ndarray, List[np.ndarray]], - label: str = "Line", - color: Optional[str] = None, - alpha: float = 0.75, - linestyle: str = "-", - linewidth: float = 1.0, - display_mode: int = MODE_BOTH, - trace_idx: int = 0, - zorder: int = 5, - ) -> None: - """ - Add a line to the plot with mode control. - - Parameters - ---------- - t : Union[np.ndarray, List[np.ndarray]] - Time array(s) for the line data. Must match the length of data. - data : Union[np.ndarray, List[np.ndarray]] - Line data array(s). Can be a single array or a list of arrays. - label : str, default="Line" - Label for the legend. - color : Optional[str], default=None - Color for the line. If None, the trace color will be used. - alpha : float, default=0.75 - Alpha (transparency) for the line. - linestyle : str, default="-" - Line style. - linewidth : float, default=1.0 - Line width. - display_mode : int, default=MODE_BOTH - Which mode(s) to show this line in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). - trace_idx : int, default=0 - Index of the trace to add the line to. - zorder : int, default=5 - Z-order for the line (higher values appear on top). - """ - if trace_idx < 0 or trace_idx >= self.data.num_traces: - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." - ) - - # Validate data length - if isinstance(data, list): - if len(data) != len(t): - raise ValueError( - f"Line data length ({len(data)}) must match time array length ({len(t)})." - ) - else: - if len(data) != len(t): - raise ValueError( - f"Line data length ({len(data)}) must match time array length ({len(t)})." - ) - - # Use trace color if none provided - if color is None: - color = self.data.get_trace_color(trace_idx) - - # Convert inputs to numpy arrays - t_array = np.asarray(t, dtype=np.float32) - data_array = np.asarray(data, dtype=np.float32) - - # Assign a unique ID for this custom line for pre-decimation caching - # We use a negative ID to distinguish from main traces (which use 0, 1, 2...) - # and ensure uniqueness across custom lines. - line_id = -(len(self._lines[trace_idx]) + 1) # Negative, unique per trace - - # Pre-decimate this custom line's data for envelope view - self.decimator.pre_decimate_data( - data_id=line_id, - t=t_array, - x=data_array, - max_points=self.max_plot_points, - envelope_window_samples=None, # Envelope window calculated automatically - ) - - # Store line definition with raw data and its assigned ID - line_def = { - "id": line_id, # Store the ID for retrieval from decimator - "t_raw": t_array, # Store raw time array - "data_raw": data_array, # Store raw data array - "label": label, - "color": color, - "alpha": alpha, - "linestyle": linestyle, - "linewidth": linewidth, - "display_mode": display_mode, - "zorder": zorder, - } - - logger.debug( - f"Adding line '{label}' with display_mode={display_mode} (MODE_ENVELOPE={self.MODE_ENVELOPE}, MODE_DETAIL={self.MODE_DETAIL}, MODE_BOTH={self.MODE_BOTH})" - ) - self._lines[trace_idx].append(line_def) - - def add_ribbon( - self, - t: Union[np.ndarray, List[np.ndarray]], - center_data: Union[np.ndarray, List[np.ndarray]], - width: Union[float, np.ndarray], - label: str = "Ribbon", - color: str = "gray", - alpha: float = 0.6, - display_mode: int = MODE_DETAIL, - trace_idx: int = 0, - zorder: int = 2, - ) -> None: - """ - Add a ribbon (center ± width) with mode control. - - Parameters - ---------- - t : Union[np.ndarray, List[np.ndarray]] - Time array(s) for the ribbon data. Must match the length of center_data. - center_data : Union[np.ndarray, List[np.ndarray]] - Center line data array(s). Can be a single array or a list of arrays. - width : Union[float, np.ndarray] - Width of the ribbon. Can be a single value or an array matching center_data. - label : str, default="Ribbon" - Label for the legend. - color : str, default="gray" - Color for the ribbon. - alpha : float, default=0.6 - Alpha (transparency) for the ribbon. - display_mode : int, default=MODE_DETAIL - Which mode(s) to show this ribbon in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). - trace_idx : int, default=0 - Index of the trace to add the ribbon to. - """ - if trace_idx < 0 or trace_idx >= self.data.num_traces: - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." - ) - - # Validate data length - if isinstance(center_data, list): - if len(center_data) != len(t): - raise ValueError( - f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})." - ) - else: - if len(center_data) != len(t): - raise ValueError( - f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})." - ) - - # Convert center data to numpy array - center_data = np.asarray(center_data, dtype=np.float32) - - # Handle width as scalar or array - if isinstance(width, (int, float, np.number)): - width_array = np.ones_like(center_data) * width - else: - if len(width) != len(center_data): - raise ValueError( - f"Ribbon width array length ({len(width)}) must match center data length ({len(center_data)})." - ) - width_array = np.asarray(width, dtype=np.float32) - - # Assign a unique ID for this custom ribbon - ribbon_id = -( - len(self._ribbons[trace_idx]) + 1001 - ) # Negative, unique per trace, offset from lines - - # Pre-decimate this custom ribbon's center data for envelope view - # We only pre-decimate the center, as width is applied later - self.decimator.pre_decimate_data( - data_id=ribbon_id, - t=np.asarray(t, dtype=np.float32), - x=center_data, - max_points=self.max_plot_points, - envelope_window_samples=None, # Envelope window calculated automatically - ) - - # Store ribbon definition - ribbon_def = { - "id": ribbon_id, - "t_raw": np.asarray(t, dtype=np.float32), - "center_data_raw": center_data, - "width_raw": width_array, - "label": label, - "color": color, - "alpha": alpha, - "display_mode": display_mode, - "zorder": zorder, - } - - self._ribbons[trace_idx].append(ribbon_def) - - def add_envelope( - self, - min_data: Union[np.ndarray, List[np.ndarray]], - max_data: Union[np.ndarray, List[np.ndarray]], - label: str = "Envelope", - color: Optional[str] = None, - alpha: float = 0.4, - display_mode: int = MODE_ENVELOPE, - trace_idx: int = 0, - zorder: int = 1, - ) -> None: - """ - Add envelope data with mode control. - - Parameters - ---------- - min_data : Union[np.ndarray, List[np.ndarray]] - Minimum envelope data array(s). Can be a single array or a list of arrays. - max_data : Union[np.ndarray, List[np.ndarray]] - Maximum envelope data array(s). Can be a single array or a list of arrays. - label : str, default="Envelope" - Label for the legend. - color : Optional[str], default=None - Color for the envelope. If None, the trace color will be used. - alpha : float, default=0.4 - Alpha (transparency) for the envelope. - display_mode : int, default=MODE_ENVELOPE - Which mode(s) to show this envelope in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). - trace_idx : int, default=0 - Index of the trace to add the envelope to. - """ - if trace_idx < 0 or trace_idx >= self.data.num_traces: - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." - ) - - # Validate data length - if isinstance(min_data, list): - if len(min_data) != len(self.data.t_arrays[trace_idx]): - raise ValueError( - f"Envelope min data length ({len(min_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." - ) - else: - if len(min_data) != len(self.data.t_arrays[trace_idx]): - raise ValueError( - f"Envelope min data length ({len(min_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." - ) - - if isinstance(max_data, list): - if len(max_data) != len(self.data.t_arrays[trace_idx]): - raise ValueError( - f"Envelope max data length ({len(max_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." - ) - else: - if len(max_data) != len(self.data.t_arrays[trace_idx]): - raise ValueError( - f"Envelope max data length ({len(max_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." - ) - - # Use trace color if none provided - if color is None: - color = self.data.get_trace_color(trace_idx) - - # Assign a unique ID for this custom envelope - envelope_id = -( - len(self._envelopes[trace_idx]) + 2001 - ) # Negative, unique per trace, offset from ribbons - - # Pre-decimate this custom envelope's data for envelope view - # We'll pre-decimate the average of min/max, and store min/max separately - t_raw = self.data.t_arrays[trace_idx] - avg_data = ( - np.asarray(min_data, dtype=np.float32) - + np.asarray(max_data, dtype=np.float32) - ) / 2 - - self.decimator.pre_decimate_data( - data_id=envelope_id, - t=t_raw, - x=avg_data, # Pass average for decimation - max_points=self.max_plot_points, - envelope_window_samples=None, # Envelope window calculated automatically - ) - - # Store envelope definition - envelope_def = { - "id": envelope_id, - "t_raw": t_raw, - "min_data_raw": np.asarray(min_data, dtype=np.float32), - "max_data_raw": np.asarray(max_data, dtype=np.float32), - "label": label, - "color": color, - "alpha": alpha, - "display_mode": display_mode, - "zorder": zorder, - } - - self._envelopes[trace_idx].append(envelope_def) - - def add_regions( - self, - regions: np.ndarray, - label: str = "Regions", - color: str = "crimson", - alpha: float = 0.4, - display_mode: int = MODE_BOTH, - trace_idx: int = 0, - zorder: int = -5, - ) -> None: - """ - Add region highlights with mode control. - - Parameters - ---------- - regions : np.ndarray - Region data array with shape (N, 2) where each row is [start_time, end_time]. - label : str, default="Regions" - Label for the legend. - color : str, default="crimson" - Color for the regions. - alpha : float, default=0.4 - Alpha (transparency) for the regions. - display_mode : int, default=MODE_BOTH - Which mode(s) to show these regions in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). - trace_idx : int, default=0 - Index of the trace to add the regions to. - """ - if trace_idx < 0 or trace_idx >= self.data.num_traces: - raise ValueError( - f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." - ) - - # Validate regions array - if regions.ndim != 2 or regions.shape[1] != 2: - raise ValueError( - f"Regions array must have shape (N, 2), got {regions.shape}." - ) - - # Store regions definition - region_def = { - "regions": np.asarray(regions, dtype=np.float32), - "label": label, - "color": color, - "alpha": alpha, - "display_mode": display_mode, - "zorder": zorder, - } - - logger.debug( - f"Adding regions '{label}' with {len(regions)} entries, display_mode={display_mode}" - ) - self._regions[trace_idx].append(region_def) - - def _update_signal_display( - self, - trace_idx: int, - t_display: np.ndarray, - x_data: np.ndarray, - envelope_data: Optional[Tuple[np.ndarray, np.ndarray]] = None, - ) -> None: - """ - Update signal display with envelope or raw data for a specific trace. - - Parameters - ---------- - trace_idx : int - Index of the trace to update. - t_display : np.ndarray - Display time array. - x_data : np.ndarray - Signal data array. - envelope_data : Optional[Tuple[np.ndarray, np.ndarray]], default=None - Tuple of (min, max) envelope data if in envelope mode. - """ - logger.debug(f"=== _update_signal_display trace {trace_idx} ===") - logger.debug( - f"t_display: len={len(t_display)}, range=[{np.min(t_display) if len(t_display) > 0 else 'empty':.6f}, {np.max(t_display) if len(t_display) > 0 else 'empty':.6f}]" - ) - logger.debug( - f"x_data: len={len(x_data)}, range=[{np.min(x_data) if len(x_data) > 0 else 'empty':.6f}, {np.max(x_data) if len(x_data) > 0 else 'empty':.6f}]" - ) - logger.debug(f"envelope_data: {envelope_data is not None}") - - if envelope_data is not None: - x_min, x_max = envelope_data - logger.debug( - f"envelope x_min: len={len(x_min)}, range=[{np.min(x_min) if len(x_min) > 0 else 'empty':.6f}, {np.max(x_min) if len(x_min) > 0 else 'empty':.6f}]" - ) - logger.debug( - f"envelope x_max: len={len(x_max)}, range=[{np.max(x_max) if len(x_max) > 0 else 'empty':.6f}, {np.max(x_max) if len(x_max) > 0 else 'empty':.6f}]" - ) - self._show_envelope_mode(trace_idx, t_display, envelope_data) - else: - logger.debug("Showing detail mode (raw signal)") - self._show_detail_mode(trace_idx, t_display, x_data) - - def _show_envelope_mode( - self, - trace_idx: int, - t_display: np.ndarray, - envelope_data: Tuple[np.ndarray, np.ndarray], - ) -> None: - """ - Show envelope display mode for a specific trace. - - Parameters - ---------- - trace_idx : int - Index of the trace to update. - t_display : np.ndarray - Display time array. - envelope_data : Tuple[np.ndarray, np.ndarray] - Tuple of (min, max) envelope data. - """ - logger.debug(f"=== _show_envelope_mode trace {trace_idx} ===") - x_min, x_max = envelope_data - color = self.data.get_trace_color(trace_idx) - name = self.data.get_trace_name(trace_idx) - - logger.debug(f"Envelope data: x_min len={len(x_min)}, x_max len={len(x_max)}") - logger.debug( - f"t_display range: [{np.min(t_display):.6f}, {np.max(t_display):.6f}]" - ) - logger.debug(f"y_range: [{np.min(x_min):.6f}, {np.max(x_max):.6f}]") - - # Clean up previous displays - if self._envelope_fills[trace_idx] is not None: - logger.debug("Removing previous envelope fill") - self._envelope_fills[trace_idx].remove() - - logger.debug("Hiding signal line") - self._signal_lines[trace_idx].set_data([], []) - self._signal_lines[trace_idx].set_visible(False) - - # Show built-in envelope - logger.debug( - f"Creating envelope fill with color={color}, alpha={self.envelope_alpha}" - ) - self._envelope_fills[trace_idx] = self.ax.fill_between( - t_display, - x_min, - x_max, - alpha=self.envelope_alpha, - color=color, - lw=0.1, - label=f"Raw envelope ({name})" - if self.data.num_traces > 1 - else "Raw envelope", - zorder=1, # Keep default envelope at zorder=1 - ) - - # Set current mode - self.state.current_mode = "envelope" - logger.debug("Set current_mode to 'envelope'") - - # Show any custom elements for this mode - self._show_custom_elements(trace_idx, t_display, self.MODE_ENVELOPE) - - def _show_detail_mode( - self, trace_idx: int, t_display: np.ndarray, x_data: np.ndarray - ) -> None: - """ - Show detail display mode for a specific trace. - - Parameters - ---------- - trace_idx : int - Index of the trace to update. - t_display : np.ndarray - Display time array. - x_data : np.ndarray - Signal data array. - """ - logger.debug(f"=== _show_detail_mode trace {trace_idx} ===") - logger.debug( - f"t_display: len={len(t_display)}, range=[{np.min(t_display) if len(t_display) > 0 else 'empty':.6f}, {np.max(t_display) if len(t_display) > 0 else 'empty':.6f}]" - ) - logger.debug( - f"x_data: len={len(x_data)}, range=[{np.min(x_data) if len(x_data) > 0 else 'empty':.6f}, {np.max(x_data) if len(x_data) > 0 else 'empty':.6f}]" - ) - - # Clean up envelope - if self._envelope_fills[trace_idx] is not None: - logger.debug("Removing envelope fill") - self._envelope_fills[trace_idx].remove() - self._envelope_fills[trace_idx] = None - - # Update signal line - line = self._signal_lines[trace_idx] - logger.debug( - f"Setting signal line data: linewidth={self.signal_line_width}, alpha={self.signal_alpha}" - ) - line.set_data(t_display, x_data) - line.set_linewidth(self.signal_line_width) - line.set_alpha(self.signal_alpha) - line.set_visible(True) - - # Set current mode - self.state.current_mode = "detail" - logger.debug("Set current_mode to 'detail'") - - # Show any custom elements for this mode - self._show_custom_elements(trace_idx, t_display, self.MODE_DETAIL) - - def _show_custom_elements( - self, trace_idx: int, t_display: np.ndarray, current_mode: int - ) -> None: - """ - Show custom visualization elements for the current mode. - - Parameters - ---------- - trace_idx : int - Index of the trace to update. - t_display : np.ndarray - Display time array. - current_mode : int - Current display mode (MODE_ENVELOPE or MODE_DETAIL). - """ - logger.debug( - f"=== _show_custom_elements trace {trace_idx}, current_mode={current_mode} ===" - ) - - last_mode = self._last_mode.get(trace_idx) - logger.debug(f"Last mode for trace {trace_idx}: {last_mode}") - - # Always clear and recreate elements when view changes, regardless of mode change - # This ensures custom lines/ribbons are redrawn correctly with current view data - logger.debug( - f"Clearing and recreating elements for trace {trace_idx} (mode: {last_mode} -> {current_mode})" - ) - self._clear_custom_elements(trace_idx) - - # Get current raw x-limits from the main plot data - # This is crucial for decimating custom lines to the current view - current_xlim_raw = self.coord_manager.get_current_view_raw(self.ax) - - # Show lines for current mode - line_objects = [] - for i, line_def in enumerate(self._lines[trace_idx]): - logger.debug( - f"Processing line {i} ('{line_def['label']}'): display_mode={line_def['display_mode']}, current_mode={current_mode}" - ) - if ( - line_def["display_mode"] & current_mode - ): # Bitwise check if mode is enabled - logger.debug( - f"Line {i} ('{line_def['label']}') should be visible in mode {current_mode}" - ) - - # Dynamically decimate the line data for the current view - # Use the same max_plot_points as the main signal for consistency - # For custom lines, we want mean decimation if in envelope mode, not min/max envelope - t_line_raw, line_data, _, _ = self.decimator.decimate_for_view( - line_def["t_raw"], - line_def["data_raw"], - current_xlim_raw, # Decimate to current view - self.max_plot_points, - use_envelope=(current_mode == self.MODE_ENVELOPE), - data_id=line_def[ - "id" - ], # Pass the custom line's ID for pre-decimated data lookup - envelope_window_samples=None, # Envelope window calculated automatically - mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold - return_envelope_min_max=False, # Custom lines never return min/max envelope - ) - - if len(t_line_raw) == 0 or len(line_data) == 0: - logger.warning( - f"Line {i} ('{line_def['label']}') has empty data after decimation for current view, skipping plot." - ) - continue - - # Make sure the time array is in display coordinates - t_line_display = self.coord_manager.raw_to_display(t_line_raw) - - # Always plot as a regular line - (line,) = self.ax.plot( - t_line_display, - line_data, - label=line_def["label"], - color=line_def["color"], - alpha=line_def["alpha"], - linestyle=line_def["linestyle"], - linewidth=line_def["linewidth"], - zorder=line_def["zorder"], - ) - line_objects.append( - (line, line_def) - ) # Store both the line and its definition - logger.debug(f"Added line {i} ('{line_def['label']}') to plot") - else: - logger.debug( - f"Line {i} ('{line_def['label']}') should NOT be visible in mode {current_mode}" - ) - - # Show ribbons for current mode - ribbon_objects = [] - for ribbon_def in self._ribbons[trace_idx]: - logger.debug( - f"Processing ribbon ('{ribbon_def['label']}'): display_mode={ribbon_def['display_mode']}, current_mode={current_mode}" - ) - if ribbon_def["display_mode"] & current_mode: - logger.debug( - f"Ribbon ('{ribbon_def['label']}') should be visible in mode {current_mode}" - ) - - # Ribbons are always plotted as fills, so we need to decimate their center and width - # We'll treat the center_data as the 'signal' for decimation purposes - ( - t_ribbon_raw, - center_data_decimated, - min_center_envelope, - max_center_envelope, - ) = self.decimator.decimate_for_view( - ribbon_def["t_raw"], - ribbon_def["center_data_raw"], - current_xlim_raw, - self.max_plot_points, - use_envelope=( - current_mode == self.MODE_ENVELOPE - ), # Use envelope for ribbons if in envelope mode - data_id=ribbon_def[ - "id" - ], # Pass the custom ribbon's ID for pre-decimated data lookup - return_envelope_min_max=True, # Ribbons always need min/max to draw fill - envelope_window_samples=None, # Envelope window calculated automatically - mode_switch_threshold=self.mode_switch_threshold, - ) - - # Decimate the width array as well, if it's an array - width_decimated = ribbon_def["width_raw"] - if len(ribbon_def["width_raw"]) > len( - t_ribbon_raw - ): # If raw width is longer than decimated time - # For simplicity, we'll just take the mean of the width in each bin - # A more robust solution might involve passing width as another data stream to decimate_for_view - # For now, we'll manually decimate it based on the t_ribbon_raw indices - # Find indices in raw data corresponding to decimated time points - # This is a simplified approach and assumes uniform sampling for width - indices = np.searchsorted(ribbon_def["t_raw"], t_ribbon_raw) - indices = np.clip(indices, 0, len(ribbon_def["width_raw"]) - 1) - width_decimated = ribbon_def["width_raw"][indices] - - # If the ribbon was decimated to an envelope, use that for min/max - if ( - current_mode == self.MODE_ENVELOPE - and min_center_envelope is not None - and max_center_envelope is not None - ): - lower_bound = min_center_envelope - width_decimated - upper_bound = max_center_envelope + width_decimated - else: - lower_bound = center_data_decimated - width_decimated - upper_bound = center_data_decimated + width_decimated - - if len(t_ribbon_raw) == 0 or len(lower_bound) == 0: - logger.warning( - f"Ribbon ('{ribbon_def['label']}') has empty data after decimation, skipping plot." - ) - continue - - # Make sure the time array is in display coordinates - t_ribbon_display = self.coord_manager.raw_to_display(t_ribbon_raw) - - ribbon = self.ax.fill_between( - t_ribbon_display, - lower_bound, - upper_bound, - color=ribbon_def["color"], - alpha=ribbon_def["alpha"], - label=ribbon_def["label"], - zorder=ribbon_def["zorder"], - ) - ribbon_objects.append( - (ribbon, ribbon_def) - ) # Store both the ribbon and its definition - logger.debug(f"Added ribbon ('{ribbon_def['label']}') to plot") - else: - logger.debug( - f"Ribbon ('{ribbon_def['label']}') should NOT be visible in mode {current_mode}" - ) - - # Show custom envelopes for current mode - for envelope_def in self._envelopes[trace_idx]: - logger.debug( - f"Processing custom envelope ('{envelope_def['label']}'): display_mode={envelope_def['display_mode']}, current_mode={current_mode}" - ) - if envelope_def["display_mode"] & current_mode: - logger.debug( - f"Custom envelope ('{envelope_def['label']}') should be visible in mode {current_mode}" - ) - - # For custom envelopes, we need to handle min/max data specially - # We'll decimate the min and max data separately using the envelope's stored data - # Since we stored min/max in the pre-decimated data, we can retrieve them - - # Get the pre-decimated envelope data for this custom envelope - if envelope_def["id"] in self.decimator._pre_decimated_envelopes: - pre_dec_data = self.decimator._pre_decimated_envelopes[ - envelope_def["id"] - ] - # The min/max data was stored in bg_initial/bg_clean during pre-decimation - t_envelope_raw, _, min_data_decimated, max_data_decimated = ( - self.decimator.decimate_for_view( - envelope_def["t_raw"], - ( - envelope_def["min_data_raw"] - + envelope_def["max_data_raw"] - ) - / 2, # Average for decimation - current_xlim_raw, - self.max_plot_points, - use_envelope=True, # Always treat custom envelopes as envelopes - data_id=envelope_def[ - "id" - ], # Pass the custom envelope's ID for pre-decimated data lookup - return_envelope_min_max=True, # Custom envelopes always need min/max to draw fill - envelope_window_samples=None, # Envelope window calculated automatically - mode_switch_threshold=self.mode_switch_threshold, - ) - ) - # For custom envelopes, the min/max are returned directly as the last two return values - else: - # Fallback if no pre-decimated data - logger.warning( - f"No pre-decimated data for custom envelope {envelope_def['id']}, using raw decimation" - ) - t_envelope_raw, _, min_data_decimated, max_data_decimated = ( - self.decimator.decimate_for_view( - envelope_def["t_raw"], - ( - envelope_def["min_data_raw"] - + envelope_def["max_data_raw"] - ) - / 2, - current_xlim_raw, - self.max_plot_points, - use_envelope=True, - data_id=None, # No pre-decimated data available - return_envelope_min_max=True, - envelope_window_samples=None, # Envelope window calculated automatically - mode_switch_threshold=self.mode_switch_threshold, - ) - ) - - if ( - len(t_envelope_raw) == 0 - or min_data_decimated is None - or max_data_decimated is None - or len(min_data_decimated) == 0 - ): - logger.warning( - f"Custom envelope ('{envelope_def['label']}') has empty data after decimation, skipping plot." - ) - continue - - t_envelope_display = self.coord_manager.raw_to_display(t_envelope_raw) - - envelope = self.ax.fill_between( - t_envelope_display, - min_data_decimated, - max_data_decimated, - color=envelope_def["color"], - alpha=envelope_def["alpha"], - label=envelope_def["label"], - zorder=envelope_def["zorder"], - ) - ribbon_objects.append( - (envelope, envelope_def) - ) # Store in ribbon objects - logger.debug( - f"Added custom envelope ('{envelope_def['label']}') to plot" - ) - else: - logger.debug( - f"Custom envelope ('{envelope_def['label']}') should NOT be visible in mode {current_mode}" - ) - - # Store objects with their definitions for future updates - self._line_objects[trace_idx] = line_objects - self._ribbon_objects[trace_idx] = ribbon_objects - - # Update last mode AFTER processing - self._last_mode[trace_idx] = current_mode - - def _update_element_visibility(self, trace_idx: int, current_mode: int) -> None: - """ - Update visibility of existing custom elements based on current mode. - - Parameters - ---------- - trace_idx : int - Index of the trace to update. - current_mode : int - Current display mode (MODE_ENVELOPE or MODE_DETAIL). - """ - logger.debug( - f"Updating element visibility for trace {trace_idx}, current_mode={current_mode}" - ) - # Update line visibility - for line_obj, line_def in self._line_objects[trace_idx]: - should_be_visible = bool(line_def["display_mode"] & current_mode) - if line_obj.get_visible() != should_be_visible: - line_obj.set_visible(should_be_visible) - logger.debug( - f"Set visibility of line '{line_def['label']}' to {should_be_visible}" - ) - - # Update ribbon visibility - for ribbon_obj, ribbon_def in self._ribbon_objects[trace_idx]: - should_be_visible = bool(ribbon_def["display_mode"] & current_mode) - if ribbon_obj.get_visible() != should_be_visible: - ribbon_obj.set_visible(should_be_visible) - logger.debug( - f"Set visibility of ribbon '{ribbon_def['label']}' to {should_be_visible}" - ) - - def _clear_custom_elements(self, trace_idx: int) -> None: - """ - Clear all custom visualization elements for a trace. - - Parameters - ---------- - trace_idx : int - Index of the trace to clear elements for. - """ - logger.debug(f"Clearing custom elements for trace {trace_idx}") - # Clear lines - for line_obj, _ in self._line_objects[trace_idx]: - line_obj.remove() - self._line_objects[trace_idx].clear() - - # Clear ribbons - for ribbon_obj, _ in self._ribbon_objects[trace_idx]: - ribbon_obj.remove() - self._ribbon_objects[trace_idx].clear() - - def _update_tick_locator(self, time_span_raw: np.float32) -> None: - """Update tick locator based on current time scale and span.""" - if self.state.current_time_scale >= np.float32(1e6): # microseconds or smaller - # For microsecond scale, use reasonable intervals - tick_interval = max( - 1, int(time_span_raw * self.state.current_time_scale / 10) - ) - self.ax.xaxis.set_major_locator(MultipleLocator(tick_interval)) - else: - # For larger scales, use matplotlib's default auto locator - self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator()) - - def _update_legend(self) -> None: - """Updates the plot legend, filtering out invisible elements and optimising rebuilds.""" - logger.debug("Updating legend...") - handles, labels = self.ax.get_legend_handles_labels() - - # Filter for unique and visible handles/labels - unique_labels = [] - unique_handles = [] - for h, l in zip(handles, labels): - # Check if the handle has a get_visible method and if it returns True - # For fill_between objects (ribbons, envelopes, regions), get_visible might not exist or behave differently - # For these, we assume they are visible if they are in the list of objects - is_visible = True - if hasattr(h, "get_visible"): - is_visible = h.get_visible() - elif isinstance( - h, mpl.collections.PolyCollection - ): # For fill_between objects - # PolyCollection doesn't have get_visible, but its patches might. - # Or we can assume it's visible if it's part of the current plot. - # For now, assume it's visible if it's a PolyCollection and has data. - is_visible = len(h.get_paths()) > 0 # Check if it has any paths to draw - - if l not in unique_labels and is_visible: - unique_labels.append(l) - unique_handles.append(h) - - logger.debug(f"Unique visible legend items found: {unique_labels}") - - # Create a hash of current handles/labels for efficient comparison - current_hash = hash(tuple(id(h) for h in unique_handles) + tuple(unique_labels)) - - # Check if legend content actually changed - if ( - not hasattr(self, "_last_legend_hash") - or self._last_legend_hash != current_hash - ): - logger.debug("Legend content changed, rebuilding legend.") - if self._legend is not None: - self._legend.remove() # Remove old legend to prevent duplicates - - if unique_handles: # Only create legend if there are handles to show - self._legend = self.ax.legend( - unique_handles, unique_labels, loc="lower right" - ) - logger.debug("New legend created.") - else: - self._legend = None # No legend to show - logger.debug("No legend to show.") - - self._current_legend_handles = unique_handles - self._current_legend_labels = unique_labels - self._last_legend_hash = current_hash - else: - logger.debug("Legend content unchanged, skipping rebuild.") - - def _clear_navigation_history(self): - """Clear matplotlib's navigation history when coordinate system changes.""" - if ( - self.fig - and self.fig.canvas - and hasattr(self.fig.canvas, "toolbar") - and self.fig.canvas.toolbar - ): - toolbar = self.fig.canvas.toolbar - if hasattr(toolbar, "_nav_stack"): - toolbar._nav_stack.clear() - - def _push_current_view(self): - """Push current view to navigation history as new base.""" - if ( - self.fig - and self.fig.canvas - and hasattr(self.fig.canvas, "toolbar") - and self.fig.canvas.toolbar - ): - toolbar = self.fig.canvas.toolbar - if hasattr(toolbar, "push_current"): - toolbar.push_current() - - def _update_axis_formatting(self) -> None: - """Update axis labels and formatters.""" - if self.state.offset_time_raw is not None: - offset_value = self.state.offset_time_raw * ( - 1e3 - if self.state.offset_unit == "ms" - else 1e6 - if self.state.offset_unit == "us" - else 1e9 - if self.state.offset_unit == "ns" - else 1.0 - ) - xlabel = f"Time ({self.state.current_time_unit}) + {offset_value:.3g} {self.state.offset_unit}" - else: - xlabel = f"Time ({self.state.current_time_unit})" - - self.ax.set_xlabel(xlabel) - - formatter = _create_time_formatter( - self.state.offset_time_raw, self.state.current_time_scale - ) - self.ax.xaxis.set_major_formatter(formatter) - - def _update_overlay_lines( - self, plot_data: Dict[str, Any], show_overlays: bool - ) -> None: - """Update overlay lines based on zoom level and data availability.""" - # Clear existing overlay lines from the plot - # This method is not currently used in the provided code, but if it were, - # it would need to be updated to use the new decimation strategy. - # For now, leaving it as is, assuming it's a placeholder or for future_use. - # If it were to be used, it would need to call decimate_for_view for each overlay line. - pass # No _overlay_lines attribute in this class, this method is unused. - - def _update_y_limits(self, plot_data: Dict[str, Any], use_envelope: bool) -> None: - """Update y-axis limits to fit current data.""" - y_min_data = float("inf") - y_max_data = float("-inf") - - # Process each trace - for trace_idx in range(self.data.num_traces): - x_new_key = f"x_new_{trace_idx}" - x_min_key = f"x_min_{trace_idx}" - x_max_key = f"x_max_{trace_idx}" - - if x_new_key not in plot_data: - continue - - # Include signal data - if len(plot_data[x_new_key]) > 0: - y_min_data = min(y_min_data, np.min(plot_data[x_new_key])) - y_max_data = max(y_max_data, np.max(plot_data[x_new_key])) - - # Include envelope data if available - if use_envelope and x_min_key in plot_data and x_max_key in plot_data: - if ( - plot_data[x_min_key] is not None - and plot_data[x_max_key] is not None - and len(plot_data[x_min_key]) > 0 - ): - y_min_data = min(y_min_data, np.min(plot_data[x_min_key])) - y_max_data = max(y_max_data, np.max(plot_data[x_max_key])) - - # Include custom lines - for line_obj, _ in self._line_objects[trace_idx]: - # Check if line_obj is a Line2D or PolyCollection - if isinstance(line_obj, mpl.lines.Line2D): - y_data = line_obj.get_ydata() - if len(y_data) > 0: - y_min_data = min(y_min_data, np.min(y_data)) - y_max_data = max(y_max_data, np.max(y_data)) - elif isinstance(line_obj, mpl.collections.PolyCollection): - # For fill_between objects, iterate through paths to get y-coordinates - for path in line_obj.get_paths(): - vertices = path.vertices - if len(vertices) > 0: - y_min_data = min(y_min_data, np.min(vertices[:, 1])) - y_max_data = max(y_max_data, np.max(vertices[:, 1])) - - # Include ribbon data - for ribbon_obj, _ in self._ribbon_objects[trace_idx]: - # For fill_between objects, we need to get the paths - if hasattr(ribbon_obj, "get_paths") and len(ribbon_obj.get_paths()) > 0: - for path in ribbon_obj.get_paths(): - vertices = path.vertices - if len(vertices) > 0: - y_min_data = min(y_min_data, np.min(vertices[:, 1])) - y_max_data = max(y_max_data, np.max(vertices[:, 1])) - - # Handle case where no data was found - if y_min_data == float("inf") or y_max_data == float("-inf"): - self.ax.set_ylim(0, 1) - return - - data_range = y_max_data - y_min_data - data_mean = (y_min_data + y_max_data) / 2 - - # Use min_y_range to ensure a minimum visible range - min_visible_range = self.min_y_range - - if data_range < min_visible_range: - y_min = data_mean - min_visible_range / 2 - y_max = data_mean + min_visible_range / 2 - else: - y_margin = self.y_margin_fraction * data_range - y_min = y_min_data - y_margin - y_max = y_max_data + y_margin - - logger.debug( - f"Y-limit calculation details: data_range={data_range:.3g}, min_visible_range={min_visible_range:.3g}, data_mean={data_mean:.3g}" - ) # ADDED THIS LINE - logger.debug( - f"Pre-set Y-limits: y_min={y_min:.9f}, y_max={y_max:.9f}" - ) # ADDED THIS LINE - self.ax.set_ylim(y_min, y_max) - - def _update_plot_data(self, ax_obj) -> None: - """Update plot based on current view.""" - if self.state.is_updating(): - return - - self.state.set_updating(True) - - try: - try: - # Add debug logging for current axis limits - display_xlim = ax_obj.get_xlim() - logger.debug(f"Current display xlim: {display_xlim}") - - view_params = self._calculate_view_parameters(ax_obj) - logger.debug( - f"Calculated view parameters: xlim_raw={view_params['xlim_raw']}, time_span_raw={view_params['time_span_raw']}, use_envelope={view_params['use_envelope']}" - ) - - plot_data = self._get_plot_data(view_params) - - # Debug data availability - data_summary = {} - for trace_idx in range(self.data.num_traces): - t_key = f"t_display_{trace_idx}" - if t_key in plot_data: - data_summary[t_key] = len(plot_data[t_key]) - logger.debug(f"Plot data summary: {data_summary}") - - self._render_plot_elements(plot_data, view_params) - self._update_regions_and_legend(view_params["xlim_display"]) - self.fig.canvas.draw_idle() - except Exception as e: - logger.exception(f"Error updating plot: {e}") - # Try to recover by resetting to home view - logger.info("Attempting to recover by resetting to home view") - self.home() - finally: - self.state.set_updating(False) - - def _calculate_view_parameters(self, ax_obj) -> Dict[str, Any]: - """Calculate view parameters from current axis state.""" - try: - xlim_raw = self.coord_manager.get_current_view_raw(ax_obj) - - # Validate xlim_raw values - if not np.isfinite(xlim_raw[0]) or not np.isfinite(xlim_raw[1]): - logger.warning( - f"Invalid xlim_raw from axis: {xlim_raw}. Using initial view." - ) - xlim_raw = self._initial_xlim_raw - - # Ensure xlim_raw is in ascending order - if xlim_raw[0] > xlim_raw[1]: - logger.warning(f"xlim_raw values out of order: {xlim_raw}. Swapping.") - xlim_raw = (xlim_raw[1], xlim_raw[0]) - - time_span_raw = xlim_raw[1] - xlim_raw[0] - use_envelope = self.state.should_use_envelope(time_span_raw) - current_mode = self.MODE_ENVELOPE if use_envelope else self.MODE_DETAIL - - logger.debug(f"=== _calculate_view_parameters ===") - logger.debug(f"xlim_raw: {xlim_raw}") - logger.debug(f"time_span_raw: {time_span_raw:.6e}s") - logger.debug( - f"envelope_limit: {self.mode_switch_threshold:.6e}s" - ) # Use mode_switch_threshold - logger.debug(f"use_envelope: {use_envelope}") - logger.debug( - f"current_mode: {current_mode} ({'ENVELOPE' if current_mode == self.MODE_ENVELOPE else 'DETAIL'})" - ) - - # Update coordinate system if needed - coordinate_system_changed = self.state.update_display_params( - xlim_raw, time_span_raw - ) - if coordinate_system_changed: - logger.debug("Coordinate system changed, updating") - self._update_coordinate_system(xlim_raw, time_span_raw) - - return { - "xlim_raw": xlim_raw, - "time_span_raw": time_span_raw, - "xlim_display": self.coord_manager.xlim_raw_to_display(xlim_raw), - "use_envelope": use_envelope, - "current_mode": current_mode, - } - except Exception as e: - logger.exception(f"Error calculating view parameters: {e}") - # Return safe default values - return { - "xlim_raw": self._initial_xlim_raw, - "time_span_raw": self._initial_xlim_raw[1] - self._initial_xlim_raw[0], - "xlim_display": self.coord_manager.xlim_raw_to_display( - self._initial_xlim_raw - ), - "use_envelope": True, - "current_mode": self.MODE_ENVELOPE, - } - - def _get_plot_data(self, view_params: Dict[str, Any]) -> Dict[str, Any]: - """Get decimated plot data for current view.""" - logger.debug(f"=== _get_plot_data ===") - logger.debug(f"view_params: {view_params}") - - plot_data = {} - - # Process each trace - for trace_idx in range(self.data.num_traces): - logger.debug(f"--- Processing trace {trace_idx} ---") - t_arr = self.data.t_arrays[trace_idx] - x_arr = self.data.x_arrays[trace_idx] - - logger.debug(f"Input data: t_arr len={len(t_arr)}, x_arr len={len(x_arr)}") - - try: - t_raw, x_new, x_min, x_max = self.decimator.decimate_for_view( - t_arr, - x_arr, - view_params["xlim_raw"], - self.max_plot_points, - view_params["use_envelope"], - trace_idx, # Pass trace_id to use pre-decimated data - envelope_window_samples=None, # Envelope window calculated automatically - mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold - return_envelope_min_max=True, # Main signal always returns envelope min/max if use_envelope is True - ) - - logger.debug( - f"Decimated data: t_raw len={len(t_raw)}, x_new len={len(x_new)}" - ) - logger.debug( - f"Envelope data: x_min={'None' if x_min is None else f'len={len(x_min)}'}, x_max={'None' if x_max is None else f'len={len(x_max)}'}" - ) - - if len(t_raw) == 0: - logger.warning( - f"No data in current view for trace {trace_idx}. View range: {view_params['xlim_raw']}" - ) - # Add empty arrays for this trace - plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32) - plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32) - plot_data[f"x_min_{trace_idx}"] = None - plot_data[f"x_max_{trace_idx}"] = None - continue - - t_display = self.coord_manager.raw_to_display(t_raw) - logger.debug( - f"Converted to display coordinates: t_display range=[{np.min(t_display):.6f}, {np.max(t_display):.6f}]" - ) - - # Store data for this trace - plot_data[f"t_display_{trace_idx}"] = t_display - plot_data[f"x_new_{trace_idx}"] = x_new - plot_data[f"x_min_{trace_idx}"] = x_min - plot_data[f"x_max_{trace_idx}"] = x_max - - logger.debug(f"Stored plot data for trace {trace_idx}") - except Exception as e: - logger.exception(f"Error getting plot data for trace {trace_idx}: {e}") - # Add empty arrays for this trace to prevent further errors - plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32) - plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32) - plot_data[f"x_min_{trace_idx}"] = None - plot_data[f"x_max_{trace_idx}"] = None - - logger.debug(f"Final plot_data keys: {list(plot_data.keys())}") - return plot_data - - def _render_plot_elements( - self, plot_data: Dict[str, Any], view_params: Dict[str, Any] - ) -> None: - """Render all plot elements with current data.""" - logger.debug(f"=== _render_plot_elements ===") - logger.debug(f"view_params use_envelope: {view_params['use_envelope']}") - - # Store the current plot data for use by other methods - self._current_plot_data = plot_data - - # Check if we have any data to plot - has_data = False - data_summary = {} - for trace_idx in range(self.data.num_traces): - key = f"t_display_{trace_idx}" - if key in plot_data and len(plot_data[key]) > 0: - has_data = True - data_summary[f"trace_{trace_idx}"] = len(plot_data[key]) - else: - data_summary[f"trace_{trace_idx}"] = 0 - - logger.debug(f"Data summary: {data_summary}, has_data: {has_data}") - - if not has_data: - logger.warning("No data to plot, clearing all elements") - # If no data, clear all lines and return - for i in range(self.data.num_traces): - self._signal_lines[i].set_data([], []) - if self._envelope_fills[i] is not None: - self._envelope_fills[i].remove() - self._envelope_fills[i] = None - - # Clear custom elements - self._clear_custom_elements(i) - - self.ax.set_ylim(0, 1) # Set a default y-limit - return - - # Process each trace - for trace_idx in range(self.data.num_traces): - logger.debug(f"--- Rendering trace {trace_idx} ---") - t_display_key = f"t_display_{trace_idx}" - x_new_key = f"x_new_{trace_idx}" - x_min_key = f"x_min_{trace_idx}" - x_max_key = f"x_max_{trace_idx}" - - if t_display_key not in plot_data or len(plot_data[t_display_key]) == 0: - logger.debug(f"No data for trace {trace_idx}, hiding elements") - # No data for this trace, hide its elements - self._signal_lines[trace_idx].set_data([], []) - if self._envelope_fills[trace_idx] is not None: - self._envelope_fills[trace_idx].remove() - self._envelope_fills[trace_idx] = None - - # Clear custom elements - self._clear_custom_elements(trace_idx) - continue - - # Update signal display - envelope_data = None - if ( - view_params["use_envelope"] - and x_min_key in plot_data - and x_max_key in plot_data - ): - if ( - plot_data[x_min_key] is not None - and plot_data[x_max_key] is not None - ): - envelope_data = (plot_data[x_min_key], plot_data[x_max_key]) - logger.debug(f"Using envelope data for trace {trace_idx}") - else: - logger.debug( - f"Envelope mode requested but no envelope data for trace {trace_idx}" - ) - else: - logger.debug(f"Detail mode for trace {trace_idx}") - - self._update_signal_display( - trace_idx, plot_data[t_display_key], plot_data[x_new_key], envelope_data - ) - - # Update y-limits - logger.debug("Updating y-limits") - self._update_y_limits(plot_data, view_params["use_envelope"]) - - def _update_coordinate_system( - self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32 - ) -> None: - """Update coordinate system and axis formatting.""" - self._clear_region_fills() - self._update_axis_formatting() - self._update_tick_locator(time_span_raw) - - xlim_display = self.coord_manager.xlim_raw_to_display(xlim_raw) - self.ax.set_xlim(xlim_display) - - self._clear_navigation_history() - self._push_current_view() - - def _update_regions_and_legend( - self, xlim_display: Tuple[np.float32, np.float32] - ) -> None: - """Update regions and legend.""" - self._refresh_region_display(xlim_display) - self._update_legend() - - def _refresh_region_display( - self, xlim_display: Tuple[np.float32, np.float32] - ) -> None: - """Refresh region display for current view.""" - logger.debug(f"=== _refresh_region_display ===") - self._clear_region_fills() - - # Get current mode - current_mode = ( - self.MODE_ENVELOPE - if self.state.current_mode == "envelope" - else self.MODE_DETAIL - ) - logger.debug(f"Current display mode for regions: {current_mode}") - - for trace_idx in range(self.data.num_traces): - logger.debug(f"Processing regions for trace {trace_idx}") - # Process each region definition - for region_def in self._regions[trace_idx]: - logger.debug( - f"Region '{region_def['label']}': display_mode={region_def['display_mode']}, current_mode={current_mode}" - ) - # Skip if not visible in current mode - if not (region_def["display_mode"] & current_mode): - logger.debug( - f"Region '{region_def['label']}' not visible in current mode {current_mode}, skipping." - ) - continue - - regions = region_def["regions"] - if regions is None or len(regions) == 0: - logger.debug( - f"No regions data for '{region_def['label']}', skipping." - ) - continue - - logger.debug( - f"Displaying {len(regions)} regions for '{region_def['label']}' in mode {current_mode}" - ) - - color = region_def["color"] - label = region_def["label"] - alpha = region_def["alpha"] - first_visible_region = True - - for t_start, t_end in regions: - t_start_display = self.coord_manager.raw_to_display(t_start) - t_end_display = self.coord_manager.raw_to_display(t_end) - - # Check if region overlaps with current view - if not ( - t_end_display <= xlim_display[0] - or t_start_display >= xlim_display[1] - ): - # Only show label for first visible region - current_label = label if first_visible_region else "" - if first_visible_region and len(regions) > 1: - current_label = f"{label} ({len(regions)})" - - logger.debug( - f"Adding region span from {t_start_display:.6f} to {t_end_display:.6f} (raw: {t_start:.6f} to {t_end:.6f}) for '{label}'" - ) - fill = self.ax.axvspan( - t_start_display, - t_end_display, - alpha=alpha, - color=color, - linewidth=0.5, - label=current_label, - zorder=region_def["zorder"], - ) - self._region_objects[trace_idx].append((fill, region_def)) - first_visible_region = False - else: - logger.debug( - f"Region span from {t_start_display:.6f} to {t_end_display:.6f} (raw: {t_start:.6f} to {t_end:.6f}) for '{label}' is outside current view {xlim_display}, skipping." - ) - - def _clear_region_fills(self) -> None: - """Clear all region fills.""" - logger.debug("Clearing region fills.") - for trace_fills in self._region_objects: - for fill_item in trace_fills: - # Handle both old format (just fill object) and new format (tuple) - if isinstance(fill_item, tuple): - fill, _ = fill_item # Extract the fill object from the tuple - fill.remove() - else: - fill_item.remove() # Old format - direct fill object - trace_fills.clear() - logger.debug("Region fills cleared.") - - def _setup_plot_elements(self) -> None: - """ - Initialise matplotlib plot elements (lines, fills) for each trace. - This is called once during render(). - """ - if self.fig is None or self.ax is None: - raise RuntimeError( - "Figure and Axes must be created before setting up plot elements." - ) - - # Create initial signal line objects for each trace - for i in range(self.data.num_traces): - color = self.data.get_trace_color(i) - name = self.data.get_trace_name(i) - - # Signal line - (line_signal,) = self.ax.plot( - [], - [], - label="Raw data" if self.data.num_traces == 1 else f"Raw data ({name})", - color=color, - alpha=self.signal_alpha, - ) - self._signal_lines.append(line_signal) - - def _connect_callbacks(self) -> None: - """Connect matplotlib callbacks.""" - if self.ax is None: - raise RuntimeError("Axes must be created before connecting callbacks.") - self.ax.callbacks.connect("xlim_changed", self._update_plot_data) - - def _setup_toolbar_overrides(self) -> None: - """Override matplotlib toolbar methods (e.g., home button).""" - if ( - self.fig - and self.fig.canvas - and hasattr(self.fig.canvas, "toolbar") - and self.fig.canvas.toolbar - ): - toolbar = self.fig.canvas.toolbar - - # Store original methods - self._original_home = getattr(toolbar, "home", None) - self._original_push_current = getattr(toolbar, "push_current", None) - - # Create our custom home method - def custom_home(*args, **kwargs): - logger.debug("Toolbar home button pressed - calling custom home") - self.home() - - # Override both the method and try to find the actual button - toolbar.home = custom_home - - # For Qt backend, also override the action - if hasattr(toolbar, "actions"): - for action in toolbar.actions(): - if hasattr(action, "text") and hasattr(action, "objectName"): - action_text = ( - action.text() if callable(action.text) else str(action.text) - ) - action_name = ( - action.objectName() - if callable(action.objectName) - else str(action.objectName) - ) - if action_text == "Home" or "home" in action_name.lower(): - if hasattr(action, "triggered"): - action.triggered.disconnect() - action.triggered.connect(custom_home) - logger.debug("Connected custom home to Qt action") - break - - # For other backends, try to override the button callback - if hasattr(toolbar, "_buttons") and "Home" in toolbar._buttons: - home_button = toolbar._buttons["Home"] - if hasattr(home_button, "configure"): - home_button.configure(command=custom_home) - logger.debug("Connected custom home to Tkinter button") - - def _set_initial_view_and_labels(self) -> None: - """Set initial axis limits, title, and labels.""" - if self.ax is None: - raise RuntimeError( - "Axes must be created before setting initial view and labels." - ) - - # Create title based on number of traces - if self.data.num_traces == 1: - self.ax.set_title(f"{self.data.names[0]}") - else: - # Multiple traces - just show "Multiple Traces" - self.ax.set_title(f"Multiple Traces ({self.data.num_traces})") - self.ax.set_xlabel(f"Time ({self.state.current_time_unit})") - self.ax.set_ylabel("Signal") - - # Set initial xlim - initial_xlim_display = self.coord_manager.xlim_raw_to_display( - self._initial_xlim_raw - ) - self.ax.set_xlim(initial_xlim_display) - - def render(self) -> None: - """ - Renders the oscilloscope plot. This method must be called after all - data and visualization elements have been added. - """ - if self.fig is not None or self.ax is not None: - logger.warning( - "Plot already rendered. Call `home()` to reset or create a new instance." - ) - return - - logger.info("Rendering plot...") - self.fig, self.ax = plt.subplots(figsize=(10, 5)) - - self._setup_plot_elements() - self._connect_callbacks() - self._setup_toolbar_overrides() - self._set_initial_view_and_labels() - - # Calculate initial parameters for the full view - t_start, t_end = self.data.get_global_time_range() - full_time_span = t_end - t_start - - logger.info( - f"Initial render: full time span={full_time_span:.3e}s, envelope_limit={self.mode_switch_threshold:.3e}s" - ) - - # Set initial display state based on full view - self.state.current_time_unit, self.state.current_time_scale = ( - _get_optimal_time_unit_and_scale(full_time_span) - ) - self.state.current_mode = ( - "envelope" if self.state.should_use_envelope(full_time_span) else "detail" - ) - - # Force initial draw of all elements by calling _update_plot_data - # This will also update the legend and regions - self.state.set_updating(False) # Ensure not in updating state for first call - self._update_plot_data(self.ax) - self.fig.canvas.draw_idle() - logger.info("Plot rendering complete.") - - def home(self) -> None: - """Return to initial full view with complete state reset.""" - if self.ax is None: # Fix: Changed '===' to 'is' - logger.warning("Plot not rendered yet. Cannot go home.") - return - - # Disconnect callback temporarily - callback_id = None - for cid, callback in self.ax.callbacks.callbacks["xlim_changed"].items(): - if getattr(callback, "__func__", callback) == self._update_plot_data: - callback_id = cid - break - - if callback_id is not None: - self.ax.callbacks.disconnect(callback_id) - - try: - self.state.set_updating(True) - self.state.reset_to_initial_state() - self.decimator.clear_cache() - self._clear_region_fills() - - # Clear all custom elements and reset _last_mode for each trace to force redraw - for trace_idx in range(self.data.num_traces): - self._clear_custom_elements(trace_idx) - self._last_mode[trace_idx] = None - - # Reset axis formatting - self.ax.set_xlabel(f"Time ({self.state.original_time_unit})") - self.ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter()) - self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator()) - - # Reset view - self.coord_manager.set_view_raw(self.ax, self._initial_xlim_raw) - - # Manually trigger update for the home view - # This will re-evaluate use_envelope, current_mode, and redraw everything - self._update_plot_data(self.ax) - - self.state.set_updating(False) - - finally: - self.ax.callbacks.connect("xlim_changed", self._update_plot_data) - - self.fig.canvas.draw() - logger.info(f"Home view restored: {self.state.original_time_unit} scale") - - def refresh(self) -> None: - """Force a complete refresh of the plot without changing the current view.""" - if self.ax is None: - logger.warning("Plot not rendered yet. Cannot refresh.") - return - - # Temporarily bypass the updating state for forced refresh - was_updating = self.state.is_updating() - self.state.set_updating(False) - try: - self._update_plot_data(self.ax) - finally: - self.state.set_updating(was_updating) - self.fig.canvas.draw_idle() - - def show(self) -> None: - """Display the plot.""" - if self.fig is None: - self.render() # Render if not already rendered - plt.show() +from typing import Any, Dict, List, Optional, Tuple, Union + +import warnings + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.ticker import MultipleLocator + +from .coordinate_manager import CoordinateManager +from .data_manager import TimeSeriesDataManager +from .decimation import DecimationManager +from .display_state import ( + DisplayState, + _create_time_formatter, + _get_optimal_time_unit_and_scale, +) + + +class OscilloscopePlot: + """ + General-purpose plotting class for time-series data with zoom and decimation. + + Uses separate managers for data, decimation, and state to reduce complexity. + Supports different visualization elements (lines, envelopes, ribbons, regions) + that can be displayed in different modes (envelope when zoomed out, detail when zoomed in). + """ + + # Mode constants + MODE_ENVELOPE = 1 # Zoomed out mode + MODE_DETAIL = 2 # Zoomed in mode + MODE_BOTH = 3 # Both modes + + # Default styling constants + DEFAULT_MAX_PLOT_POINTS = 10000 + DEFAULT_MODE_SWITCH_THRESHOLD = 10e-3 # 10 ms + DEFAULT_MIN_Y_RANGE_DEFAULT = 1e-9 # Default minimum Y-axis range (e.g., 1 nV) + DEFAULT_Y_MARGIN_FRACTION = 0.15 + DEFAULT_SIGNAL_LINE_WIDTH = 1.0 + DEFAULT_SIGNAL_ALPHA = 0.75 + DEFAULT_ENVELOPE_ALPHA = 0.75 + DEFAULT_REGION_ALPHA = 0.4 + DEFAULT_REGION_ZORDER = -5 + + def __init__( + self, + t: Union[np.ndarray, List[np.ndarray]], + x: Union[np.ndarray, List[np.ndarray]], + name: Union[str, List[str]] = "Waveform", + trace_colors: Optional[List[str]] = None, + # Core display parameters + max_plot_points: int = DEFAULT_MAX_PLOT_POINTS, + mode_switch_threshold: float = DEFAULT_MODE_SWITCH_THRESHOLD, + min_y_range: Optional[float] = None, # New parameter for minimum Y-axis range + y_margin_fraction: float = DEFAULT_Y_MARGIN_FRACTION, + signal_line_width: float = DEFAULT_SIGNAL_LINE_WIDTH, + signal_alpha: float = DEFAULT_SIGNAL_ALPHA, + envelope_alpha: float = DEFAULT_ENVELOPE_ALPHA, + region_alpha: float = DEFAULT_REGION_ALPHA, + region_zorder: int = DEFAULT_REGION_ZORDER, + envelope_window_samples: Optional[int] = None, + ): + """ + Initialize the OscilloscopePlot with time series data. + + Parameters + ---------- + t : Union[np.ndarray, List[np.ndarray]] + Time array(s) (raw time in seconds). Can be a single array shared by all traces + or a list of arrays, one per trace. + x : Union[np.ndarray, List[np.ndarray]] + Signal array(s). If t is a single array, x can be a 2D array (traces x samples) + or a list of 1D arrays. If t is a list, x must be a list of equal length. + name : Union[str, List[str]], default="Waveform" + Name(s) for plot title. Can be a single string or a list of strings. + trace_colors : Optional[List[str]], default=None + Colors for each trace. If None, default colors will be used. + max_plot_points : int, default=10000 + Maximum number of points to display on the plot. Data will be decimated if it exceeds this. + mode_switch_threshold : float, default=10e-3 + Time span (in seconds) above which the plot switches to envelope mode. + min_y_range : Optional[float], default=None + Minimum Y-axis range to enforce. If None, a default small value is used. + y_margin_fraction : float, default=0.05 + Fraction of data range to add as margin to Y-axis limits. + signal_line_width : float, default=1.0 + Line width for the raw signal plot. + signal_alpha : float, default=0.75 + Alpha (transparency) for the raw signal plot. + envelope_alpha : float, default=1.0 + Alpha (transparency) for the envelope fill. + region_alpha : float, default=0.4 + Alpha (transparency) for region highlight fills. + region_zorder : int, default=-5 + Z-order for region highlight fills (lower means further back). + envelope_window_samples : Optional[int], default=None + DEPRECATED: Window size in samples for envelope calculation. + Envelope window is now calculated automatically based on max_plot_points and zoom level. + This parameter is ignored but kept for backward compatibility. + """ + # Store styling parameters directly as instance attributes + self.max_plot_points = max_plot_points + self.mode_switch_threshold = np.float32(mode_switch_threshold) + self.min_y_range = ( + np.float32(min_y_range) + if min_y_range is not None + else self.DEFAULT_MIN_Y_RANGE_DEFAULT + ) + self.y_margin_fraction = np.float32(y_margin_fraction) + self.signal_line_width = signal_line_width + self.signal_alpha = signal_alpha + self.envelope_alpha = envelope_alpha + self.region_alpha = region_alpha + self.region_zorder = region_zorder + # envelope_window_samples is now deprecated - envelope window is calculated automatically + # Keep the parameter for backward compatibility but don't use it + if envelope_window_samples is not None: + warnings.warn( + "envelope_window_samples parameter is deprecated. Envelope window is now calculated automatically based on zoom level.", DeprecationWarning + ) + + # Initialize managers + self.data = TimeSeriesDataManager(t, x, name, trace_colors) + self.decimator = DecimationManager() + + # Pre-decimate main signal data for envelope view + for i in range(self.data.num_traces): + self.decimator.pre_decimate_data( + data_id=i, # Use trace_idx as data_id + t=self.data.t_arrays[i], + x=self.data.x_arrays[i], + max_points=self.max_plot_points, + envelope_window_samples=None, # Envelope window calculated automatically + ) + + # Initialize display state using the first trace's time array + initial_time_unit, initial_time_scale = _get_optimal_time_unit_and_scale( + self.data.t_arrays[0] + ) + self.state = DisplayState( + initial_time_unit, initial_time_scale, self.mode_switch_threshold + ) + + # Initialize matplotlib figure and axes to None + self.fig: Optional[mpl.figure.Figure] = None + self.ax: Optional[mpl.axes.Axes] = None + + # Store visualization elements for each trace + self._signal_lines: List[mpl.lines.Line2D] = [] + self._envelope_fills: List[Optional[mpl.collections.PolyCollection]] = [ + None + ] * self.data.num_traces + + # Visualization elements with mode control (definitions, not plot objects) + self._lines: List[List[Dict[str, Any]]] = [ + [] for _ in range(self.data.num_traces) + ] + self._ribbons: List[List[Dict[str, Any]]] = [ + [] for _ in range(self.data.num_traces) + ] + self._regions: List[List[Dict[str, Any]]] = [ + [] for _ in range(self.data.num_traces) + ] + self._envelopes: List[List[Dict[str, Any]]] = [ + [] for _ in range(self.data.num_traces) + ] + + # Line objects for each trace (will be populated as needed during rendering) + self._line_objects: List[List[mpl.artist.Artist]] = [ + [] for _ in range(self.data.num_traces) + ] # Changed type hint to Artist + self._ribbon_objects: List[List[mpl.collections.PolyCollection]] = [ + [] for _ in range(self.data.num_traces) + ] + self._region_objects: List[List[mpl.collections.PolyCollection]] = [ + [] for _ in range(self.data.num_traces) + ] + + # Store current plot data for access by other methods + self._current_plot_data = {} + + # Initialize coordinate manager + self.coord_manager = CoordinateManager(self.state) + + # Store initial view for home button (using global time range) + t_start, t_end = self.data.get_global_time_range() + self._initial_xlim_raw = (t_start, t_end) + + # Legend state for optimization + self._current_legend_handles: List[mpl.artist.Artist] = [] + self._current_legend_labels: List[str] = [] + self._legend: Optional[mpl.legend.Legend] = None + + # Track last mode for each trace to optimize element updates + self._last_mode: Dict[int, Optional[int]] = { + i: None for i in range(self.data.num_traces) + } + + # Store original toolbar methods for restoration + self._original_home = None + self._original_push_current = None + + def save(self, filepath: str) -> None: + """ + Save the current plot to a file. + + Parameters + ---------- + filepath : str + Path to save the plot image. + """ + if self.fig is None or self.ax is None: + raise RuntimeError("Plot has not been initialized yet.") + self.fig.savefig(filepath) + print(f"Plot saved to {filepath}") + + def add_line( + self, + t: Union[np.ndarray, List[np.ndarray]], + data: Union[np.ndarray, List[np.ndarray]], + label: str = "Line", + color: Optional[str] = None, + alpha: float = 0.75, + linestyle: str = "-", + linewidth: float = 1.0, + display_mode: int = MODE_BOTH, + trace_idx: int = 0, + zorder: int = 5, + ) -> None: + """ + Add a line to the plot with mode control. + + Parameters + ---------- + t : Union[np.ndarray, List[np.ndarray]] + Time array(s) for the line data. Must match the length of data. + data : Union[np.ndarray, List[np.ndarray]] + Line data array(s). Can be a single array or a list of arrays. + label : str, default="Line" + Label for the legend. + color : Optional[str], default=None + Color for the line. If None, the trace color will be used. + alpha : float, default=0.75 + Alpha (transparency) for the line. + linestyle : str, default="-" + Line style. + linewidth : float, default=1.0 + Line width. + display_mode : int, default=MODE_BOTH + Which mode(s) to show this line in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). + trace_idx : int, default=0 + Index of the trace to add the line to. + zorder : int, default=5 + Z-order for the line (higher values appear on top). + """ + if trace_idx < 0 or trace_idx >= self.data.num_traces: + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." + ) + + # Validate data length + if isinstance(data, list): + if len(data) != len(t): + raise ValueError( + f"Line data length ({len(data)}) must match time array length ({len(t)})." + ) + else: + if len(data) != len(t): + raise ValueError( + f"Line data length ({len(data)}) must match time array length ({len(t)})." + ) + + # Use trace color if none provided + if color is None: + color = self.data.get_trace_color(trace_idx) + + # Convert inputs to numpy arrays + t_array = np.asarray(t, dtype=np.float32) + data_array = np.asarray(data, dtype=np.float32) + + # Assign a unique ID for this custom line for pre-decimation caching + # We use a negative ID to distinguish from main traces (which use 0, 1, 2...) + # and ensure uniqueness across custom lines. + line_id = -(len(self._lines[trace_idx]) + 1) # Negative, unique per trace + + # Pre-decimate this custom line's data for envelope view + self.decimator.pre_decimate_data( + data_id=line_id, + t=t_array, + x=data_array, + max_points=self.max_plot_points, + envelope_window_samples=None, # Envelope window calculated automatically + ) + + # Store line definition with raw data and its assigned ID + line_def = { + "id": line_id, # Store the ID for retrieval from decimator + "t_raw": t_array, # Store raw time array + "data_raw": data_array, # Store raw data array + "label": label, + "color": color, + "alpha": alpha, + "linestyle": linestyle, + "linewidth": linewidth, + "display_mode": display_mode, + "zorder": zorder, + } + + self._lines[trace_idx].append(line_def) + + def add_ribbon( + self, + t: Union[np.ndarray, List[np.ndarray]], + center_data: Union[np.ndarray, List[np.ndarray]], + width: Union[float, np.ndarray], + label: str = "Ribbon", + color: str = "gray", + alpha: float = 0.6, + display_mode: int = MODE_DETAIL, + trace_idx: int = 0, + zorder: int = 2, + ) -> None: + """ + Add a ribbon (center ± width) with mode control. + + Parameters + ---------- + t : Union[np.ndarray, List[np.ndarray]] + Time array(s) for the ribbon data. Must match the length of center_data. + center_data : Union[np.ndarray, List[np.ndarray]] + Center line data array(s). Can be a single array or a list of arrays. + width : Union[float, np.ndarray] + Width of the ribbon. Can be a single value or an array matching center_data. + label : str, default="Ribbon" + Label for the legend. + color : str, default="gray" + Color for the ribbon. + alpha : float, default=0.6 + Alpha (transparency) for the ribbon. + display_mode : int, default=MODE_DETAIL + Which mode(s) to show this ribbon in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). + trace_idx : int, default=0 + Index of the trace to add the ribbon to. + """ + if trace_idx < 0 or trace_idx >= self.data.num_traces: + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." + ) + + # Validate data length + if isinstance(center_data, list): + if len(center_data) != len(t): + raise ValueError( + f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})." + ) + else: + if len(center_data) != len(t): + raise ValueError( + f"Ribbon center data length ({len(center_data)}) must match time array length ({len(t)})." + ) + + # Convert center data to numpy array + center_data = np.asarray(center_data, dtype=np.float32) + + # Handle width as scalar or array + if isinstance(width, (int, float, np.number)): + width_array = np.ones_like(center_data) * width + else: + if len(width) != len(center_data): + raise ValueError( + f"Ribbon width array length ({len(width)}) must match center data length ({len(center_data)})." + ) + width_array = np.asarray(width, dtype=np.float32) + + # Assign a unique ID for this custom ribbon + ribbon_id = -( + len(self._ribbons[trace_idx]) + 1001 + ) # Negative, unique per trace, offset from lines + + # Pre-decimate this custom ribbon's center data for envelope view + # We only pre-decimate the center, as width is applied later + self.decimator.pre_decimate_data( + data_id=ribbon_id, + t=np.asarray(t, dtype=np.float32), + x=center_data, + max_points=self.max_plot_points, + envelope_window_samples=None, # Envelope window calculated automatically + ) + + # Store ribbon definition + ribbon_def = { + "id": ribbon_id, + "t_raw": np.asarray(t, dtype=np.float32), + "center_data_raw": center_data, + "width_raw": width_array, + "label": label, + "color": color, + "alpha": alpha, + "display_mode": display_mode, + "zorder": zorder, + } + + self._ribbons[trace_idx].append(ribbon_def) + + def add_envelope( + self, + min_data: Union[np.ndarray, List[np.ndarray]], + max_data: Union[np.ndarray, List[np.ndarray]], + label: str = "Envelope", + color: Optional[str] = None, + alpha: float = 0.4, + display_mode: int = MODE_ENVELOPE, + trace_idx: int = 0, + zorder: int = 1, + ) -> None: + """ + Add envelope data with mode control. + + Parameters + ---------- + min_data : Union[np.ndarray, List[np.ndarray]] + Minimum envelope data array(s). Can be a single array or a list of arrays. + max_data : Union[np.ndarray, List[np.ndarray]] + Maximum envelope data array(s). Can be a single array or a list of arrays. + label : str, default="Envelope" + Label for the legend. + color : Optional[str], default=None + Color for the envelope. If None, the trace color will be used. + alpha : float, default=0.4 + Alpha (transparency) for the envelope. + display_mode : int, default=MODE_ENVELOPE + Which mode(s) to show this envelope in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). + trace_idx : int, default=0 + Index of the trace to add the envelope to. + """ + if trace_idx < 0 or trace_idx >= self.data.num_traces: + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." + ) + + # Validate data length + if isinstance(min_data, list): + if len(min_data) != len(self.data.t_arrays[trace_idx]): + raise ValueError( + f"Envelope min data length ({len(min_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." + ) + else: + if len(min_data) != len(self.data.t_arrays[trace_idx]): + raise ValueError( + f"Envelope min data length ({len(min_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." + ) + + if isinstance(max_data, list): + if len(max_data) != len(self.data.t_arrays[trace_idx]): + raise ValueError( + f"Envelope max data length ({len(max_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." + ) + else: + if len(max_data) != len(self.data.t_arrays[trace_idx]): + raise ValueError( + f"Envelope max data length ({len(max_data)}) must match time array length ({len(self.data.t_arrays[trace_idx])})." + ) + + # Use trace color if none provided + if color is None: + color = self.data.get_trace_color(trace_idx) + + # Assign a unique ID for this custom envelope + envelope_id = -( + len(self._envelopes[trace_idx]) + 2001 + ) # Negative, unique per trace, offset from ribbons + + # Pre-decimate this custom envelope's data for envelope view + # We'll pre-decimate the average of min/max, and store min/max separately + t_raw = self.data.t_arrays[trace_idx] + avg_data = ( + np.asarray(min_data, dtype=np.float32) + + np.asarray(max_data, dtype=np.float32) + ) / 2 + + self.decimator.pre_decimate_data( + data_id=envelope_id, + t=t_raw, + x=avg_data, # Pass average for decimation + max_points=self.max_plot_points, + envelope_window_samples=None, # Envelope window calculated automatically + ) + + # Store envelope definition + envelope_def = { + "id": envelope_id, + "t_raw": t_raw, + "min_data_raw": np.asarray(min_data, dtype=np.float32), + "max_data_raw": np.asarray(max_data, dtype=np.float32), + "label": label, + "color": color, + "alpha": alpha, + "display_mode": display_mode, + "zorder": zorder, + } + + self._envelopes[trace_idx].append(envelope_def) + + def add_regions( + self, + regions: np.ndarray, + label: str = "Regions", + color: str = "crimson", + alpha: float = 0.4, + display_mode: int = MODE_BOTH, + trace_idx: int = 0, + zorder: int = -5, + ) -> None: + """ + Add region highlights with mode control. + + Parameters + ---------- + regions : np.ndarray + Region data array with shape (N, 2) where each row is [start_time, end_time]. + label : str, default="Regions" + Label for the legend. + color : str, default="crimson" + Color for the regions. + alpha : float, default=0.4 + Alpha (transparency) for the regions. + display_mode : int, default=MODE_BOTH + Which mode(s) to show these regions in (MODE_ENVELOPE, MODE_DETAIL, or MODE_BOTH). + trace_idx : int, default=0 + Index of the trace to add the regions to. + """ + if trace_idx < 0 or trace_idx >= self.data.num_traces: + raise ValueError( + f"Invalid trace index: {trace_idx}. Must be between 0 and {self.data.num_traces - 1}." + ) + + # Validate regions array + if regions.ndim != 2 or regions.shape[1] != 2: + raise ValueError( + f"Regions array must have shape (N, 2), got {regions.shape}." + ) + + # Store regions definition + region_def = { + "regions": np.asarray(regions, dtype=np.float32), + "label": label, + "color": color, + "alpha": alpha, + "display_mode": display_mode, + "zorder": zorder, + } + + self._regions[trace_idx].append(region_def) + + def _update_signal_display( + self, + trace_idx: int, + t_display: np.ndarray, + x_data: np.ndarray, + envelope_data: Optional[Tuple[np.ndarray, np.ndarray]] = None, + ) -> None: + """ + Update signal display with envelope or raw data for a specific trace. + + Parameters + ---------- + trace_idx : int + Index of the trace to update. + t_display : np.ndarray + Display time array. + x_data : np.ndarray + Signal data array. + envelope_data : Optional[Tuple[np.ndarray, np.ndarray]], default=None + Tuple of (min, max) envelope data if in envelope mode. + """ + if envelope_data is not None: + self._show_envelope_mode(trace_idx, t_display, envelope_data) + else: + self._show_detail_mode(trace_idx, t_display, x_data) + + def _show_envelope_mode( + self, + trace_idx: int, + t_display: np.ndarray, + envelope_data: Tuple[np.ndarray, np.ndarray], + ) -> None: + """ + Show envelope display mode for a specific trace. + + Parameters + ---------- + trace_idx : int + Index of the trace to update. + t_display : np.ndarray + Display time array. + envelope_data : Tuple[np.ndarray, np.ndarray] + Tuple of (min, max) envelope data. + """ + x_min, x_max = envelope_data + color = self.data.get_trace_color(trace_idx) + name = self.data.get_trace_name(trace_idx) + + # Clean up previous displays + if self._envelope_fills[trace_idx] is not None: + self._envelope_fills[trace_idx].remove() + + self._signal_lines[trace_idx].set_data([], []) + self._signal_lines[trace_idx].set_visible(False) + + # Show built-in envelope + self._envelope_fills[trace_idx] = self.ax.fill_between( + t_display, + x_min, + x_max, + alpha=self.envelope_alpha, + color=color, + lw=0.1, + label=f"Raw envelope ({name})" + if self.data.num_traces > 1 + else "Raw envelope", + zorder=1, # Keep default envelope at zorder=1 + ) + + # Set current mode + self.state.current_mode = "envelope" + + # Show any custom elements for this mode + self._show_custom_elements(trace_idx, t_display, self.MODE_ENVELOPE) + + def _show_detail_mode( + self, trace_idx: int, t_display: np.ndarray, x_data: np.ndarray + ) -> None: + """ + Show detail display mode for a specific trace. + + Parameters + ---------- + trace_idx : int + Index of the trace to update. + t_display : np.ndarray + Display time array. + x_data : np.ndarray + Signal data array. + """ + # Clean up envelope + if self._envelope_fills[trace_idx] is not None: + self._envelope_fills[trace_idx].remove() + self._envelope_fills[trace_idx] = None + + # Update signal line + line = self._signal_lines[trace_idx] + line.set_data(t_display, x_data) + line.set_linewidth(self.signal_line_width) + line.set_alpha(self.signal_alpha) + line.set_visible(True) + + # Set current mode + self.state.current_mode = "detail" + + # Show any custom elements for this mode + self._show_custom_elements(trace_idx, t_display, self.MODE_DETAIL) + + def _show_custom_elements( + self, trace_idx: int, t_display: np.ndarray, current_mode: int + ) -> None: + """ + Show custom visualization elements for the current mode. + + Parameters + ---------- + trace_idx : int + Index of the trace to update. + t_display : np.ndarray + Display time array. + current_mode : int + Current display mode (MODE_ENVELOPE or MODE_DETAIL). + """ + last_mode = self._last_mode.get(trace_idx) + + # Always clear and recreate elements when view changes, regardless of mode change + # This ensures custom lines/ribbons are redrawn correctly with current view data + self._clear_custom_elements(trace_idx) + + # Get current raw x-limits from the main plot data + # This is crucial for decimating custom lines to the current view + current_xlim_raw = self.coord_manager.get_current_view_raw(self.ax) + + # Show lines for current mode + line_objects = [] + for i, line_def in enumerate(self._lines[trace_idx]): + if ( + line_def["display_mode"] & current_mode + ): # Bitwise check if mode is enabled + + # Dynamically decimate the line data for the current view + # Use the same max_plot_points as the main signal for consistency + # For custom lines, we want mean decimation if in envelope mode, not min/max envelope + t_line_raw, line_data, _, _ = self.decimator.decimate_for_view( + line_def["t_raw"], + line_def["data_raw"], + current_xlim_raw, # Decimate to current view + self.max_plot_points, + use_envelope=(current_mode == self.MODE_ENVELOPE), + data_id=line_def[ + "id" + ], # Pass the custom line's ID for pre-decimated data lookup + envelope_window_samples=None, # Envelope window calculated automatically + mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold + return_envelope_min_max=False, # Custom lines never return min/max envelope + ) + + if len(t_line_raw) == 0 or len(line_data) == 0: + warnings.warn( + f"Line {i} ('{line_def['label']}') has empty data after decimation for current view, skipping plot.", UserWarning + ) + continue + + # Make sure the time array is in display coordinates + t_line_display = self.coord_manager.raw_to_display(t_line_raw) + + # Always plot as a regular line + (line,) = self.ax.plot( + t_line_display, + line_data, + label=line_def["label"], + color=line_def["color"], + alpha=line_def["alpha"], + linestyle=line_def["linestyle"], + linewidth=line_def["linewidth"], + zorder=line_def["zorder"], + ) + line_objects.append( + (line, line_def) + ) # Store both the line and its definition + + # Show ribbons for current mode + ribbon_objects = [] + for ribbon_def in self._ribbons[trace_idx]: + if ribbon_def["display_mode"] & current_mode: + + # Ribbons are always plotted as fills, so we need to decimate their center and width + # We'll treat the center_data as the 'signal' for decimation purposes + ( + t_ribbon_raw, + center_data_decimated, + min_center_envelope, + max_center_envelope, + ) = self.decimator.decimate_for_view( + ribbon_def["t_raw"], + ribbon_def["center_data_raw"], + current_xlim_raw, + self.max_plot_points, + use_envelope=( + current_mode == self.MODE_ENVELOPE + ), # Use envelope for ribbons if in envelope mode + data_id=ribbon_def[ + "id" + ], # Pass the custom ribbon's ID for pre-decimated data lookup + return_envelope_min_max=True, # Ribbons always need min/max to draw fill + envelope_window_samples=None, # Envelope window calculated automatically + mode_switch_threshold=self.mode_switch_threshold, + ) + + # Decimate the width array as well, if it's an array + width_decimated = ribbon_def["width_raw"] + if len(ribbon_def["width_raw"]) > len( + t_ribbon_raw + ): # If raw width is longer than decimated time + # For simplicity, we'll just take the mean of the width in each bin + # A more robust solution might involve passing width as another data stream to decimate_for_view + # For now, we'll manually decimate it based on the t_ribbon_raw indices + # Find indices in raw data corresponding to decimated time points + # This is a simplified approach and assumes uniform sampling for width + indices = np.searchsorted(ribbon_def["t_raw"], t_ribbon_raw) + indices = np.clip(indices, 0, len(ribbon_def["width_raw"]) - 1) + width_decimated = ribbon_def["width_raw"][indices] + + # If the ribbon was decimated to an envelope, use that for min/max + if ( + current_mode == self.MODE_ENVELOPE + and min_center_envelope is not None + and max_center_envelope is not None + ): + lower_bound = min_center_envelope - width_decimated + upper_bound = max_center_envelope + width_decimated + else: + lower_bound = center_data_decimated - width_decimated + upper_bound = center_data_decimated + width_decimated + + if len(t_ribbon_raw) == 0 or len(lower_bound) == 0: + warnings.warn( + f"Ribbon ('{ribbon_def['label']}') has empty data after decimation, skipping plot.", UserWarning + ) + continue + + # Make sure the time array is in display coordinates + t_ribbon_display = self.coord_manager.raw_to_display(t_ribbon_raw) + + ribbon = self.ax.fill_between( + t_ribbon_display, + lower_bound, + upper_bound, + color=ribbon_def["color"], + alpha=ribbon_def["alpha"], + label=ribbon_def["label"], + zorder=ribbon_def["zorder"], + ) + ribbon_objects.append( + (ribbon, ribbon_def) + ) # Store both the ribbon and its definition + + # Show custom envelopes for current mode + for envelope_def in self._envelopes[trace_idx]: + if envelope_def["display_mode"] & current_mode: + + # For custom envelopes, we need to handle min/max data specially + # We'll decimate the min and max data separately using the envelope's stored data + # Since we stored min/max in the pre-decimated data, we can retrieve them + + # Get the pre-decimated envelope data for this custom envelope + if envelope_def["id"] in self.decimator._pre_decimated_envelopes: + pre_dec_data = self.decimator._pre_decimated_envelopes[ + envelope_def["id"] + ] + # The min/max data was stored in bg_initial/bg_clean during pre-decimation + t_envelope_raw, _, min_data_decimated, max_data_decimated = ( + self.decimator.decimate_for_view( + envelope_def["t_raw"], + ( + envelope_def["min_data_raw"] + + envelope_def["max_data_raw"] + ) + / 2, # Average for decimation + current_xlim_raw, + self.max_plot_points, + use_envelope=True, # Always treat custom envelopes as envelopes + data_id=envelope_def[ + "id" + ], # Pass the custom envelope's ID for pre-decimated data lookup + return_envelope_min_max=True, # Custom envelopes always need min/max to draw fill + envelope_window_samples=None, # Envelope window calculated automatically + mode_switch_threshold=self.mode_switch_threshold, + ) + ) + # For custom envelopes, the min/max are returned directly as the last two return values + else: + # Fallback if no pre-decimated data + warnings.warn( + f"No pre-decimated data for custom envelope {envelope_def['id']}, using raw decimation", UserWarning + ) + t_envelope_raw, _, min_data_decimated, max_data_decimated = ( + self.decimator.decimate_for_view( + envelope_def["t_raw"], + ( + envelope_def["min_data_raw"] + + envelope_def["max_data_raw"] + ) + / 2, + current_xlim_raw, + self.max_plot_points, + use_envelope=True, + data_id=None, # No pre-decimated data available + return_envelope_min_max=True, + envelope_window_samples=None, # Envelope window calculated automatically + mode_switch_threshold=self.mode_switch_threshold, + ) + ) + + if ( + len(t_envelope_raw) == 0 + or min_data_decimated is None + or max_data_decimated is None + or len(min_data_decimated) == 0 + ): + warnings.warn( + f"Custom envelope ('{envelope_def['label']}') has empty data after decimation, skipping plot.", UserWarning + ) + continue + + t_envelope_display = self.coord_manager.raw_to_display(t_envelope_raw) + + envelope = self.ax.fill_between( + t_envelope_display, + min_data_decimated, + max_data_decimated, + color=envelope_def["color"], + alpha=envelope_def["alpha"], + label=envelope_def["label"], + zorder=envelope_def["zorder"], + ) + ribbon_objects.append( + (envelope, envelope_def) + ) # Store in ribbon objects + + # Store objects with their definitions for future updates + self._line_objects[trace_idx] = line_objects + self._ribbon_objects[trace_idx] = ribbon_objects + + # Update last mode AFTER processing + self._last_mode[trace_idx] = current_mode + + def _update_element_visibility(self, trace_idx: int, current_mode: int) -> None: + """ + Update visibility of existing custom elements based on current mode. + + Parameters + ---------- + trace_idx : int + Index of the trace to update. + current_mode : int + Current display mode (MODE_ENVELOPE or MODE_DETAIL). + """ + # Update line visibility + for line_obj, line_def in self._line_objects[trace_idx]: + should_be_visible = bool(line_def["display_mode"] & current_mode) + if line_obj.get_visible() != should_be_visible: + line_obj.set_visible(should_be_visible) + + # Update ribbon visibility + for ribbon_obj, ribbon_def in self._ribbon_objects[trace_idx]: + should_be_visible = bool(ribbon_def["display_mode"] & current_mode) + if ribbon_obj.get_visible() != should_be_visible: + ribbon_obj.set_visible(should_be_visible) + + def _clear_custom_elements(self, trace_idx: int) -> None: + """ + Clear all custom visualization elements for a trace. + + Parameters + ---------- + trace_idx : int + Index of the trace to clear elements for. + """ + # Clear lines + for line_obj, _ in self._line_objects[trace_idx]: + line_obj.remove() + self._line_objects[trace_idx].clear() + + # Clear ribbons + for ribbon_obj, _ in self._ribbon_objects[trace_idx]: + ribbon_obj.remove() + self._ribbon_objects[trace_idx].clear() + + def _update_tick_locator(self, time_span_raw: np.float32) -> None: + """Update tick locator based on current time scale and span.""" + if self.state.current_time_scale >= np.float32(1e6): # microseconds or smaller + # For microsecond scale, use reasonable intervals + tick_interval = max( + 1, int(time_span_raw * self.state.current_time_scale / 10) + ) + self.ax.xaxis.set_major_locator(MultipleLocator(tick_interval)) + else: + # For larger scales, use matplotlib's default auto locator + self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator()) + + def _update_legend(self) -> None: + """Updates the plot legend, filtering out invisible elements and optimising rebuilds.""" + handles, labels = self.ax.get_legend_handles_labels() + + # Filter for unique and visible handles/labels + unique_labels = [] + unique_handles = [] + for h, l in zip(handles, labels): + # Check if the handle has a get_visible method and if it returns True + # For fill_between objects (ribbons, envelopes, regions), get_visible might not exist or behave differently + # For these, we assume they are visible if they are in the list of objects + is_visible = True + if hasattr(h, "get_visible"): + is_visible = h.get_visible() + elif isinstance( + h, mpl.collections.PolyCollection + ): # For fill_between objects + # PolyCollection doesn't have get_visible, but its patches might. + # Or we can assume it's visible if it's part of the current plot. + # For now, assume it's visible if it's a PolyCollection and has data. + is_visible = len(h.get_paths()) > 0 # Check if it has any paths to draw + + if l not in unique_labels and is_visible: + unique_labels.append(l) + unique_handles.append(h) + + # Create a hash of current handles/labels for efficient comparison + current_hash = hash(tuple(id(h) for h in unique_handles) + tuple(unique_labels)) + + # Check if legend content actually changed + if ( + not hasattr(self, "_last_legend_hash") + or self._last_legend_hash != current_hash + ): + if self._legend is not None: + self._legend.remove() # Remove old legend to prevent duplicates + + if unique_handles: # Only create legend if there are handles to show + self._legend = self.ax.legend( + unique_handles, unique_labels, loc="lower right" + ) + else: + self._legend = None # No legend to show + + self._current_legend_handles = unique_handles + self._current_legend_labels = unique_labels + self._last_legend_hash = current_hash + + def _clear_navigation_history(self): + """Clear matplotlib's navigation history when coordinate system changes.""" + if ( + self.fig + and self.fig.canvas + and hasattr(self.fig.canvas, "toolbar") + and self.fig.canvas.toolbar + ): + toolbar = self.fig.canvas.toolbar + if hasattr(toolbar, "_nav_stack"): + toolbar._nav_stack.clear() + + def _push_current_view(self): + """Push current view to navigation history as new base.""" + if ( + self.fig + and self.fig.canvas + and hasattr(self.fig.canvas, "toolbar") + and self.fig.canvas.toolbar + ): + toolbar = self.fig.canvas.toolbar + if hasattr(toolbar, "push_current"): + toolbar.push_current() + + def _update_axis_formatting(self) -> None: + """Update axis labels and formatters.""" + if self.state.offset_time_raw is not None: + offset_value = self.state.offset_time_raw * ( + 1e3 + if self.state.offset_unit == "ms" + else 1e6 + if self.state.offset_unit == "us" + else 1e9 + if self.state.offset_unit == "ns" + else 1.0 + ) + xlabel = f"Time ({self.state.current_time_unit}) + {offset_value:.3g} {self.state.offset_unit}" + else: + xlabel = f"Time ({self.state.current_time_unit})" + + self.ax.set_xlabel(xlabel) + + formatter = _create_time_formatter( + self.state.offset_time_raw, self.state.current_time_scale + ) + self.ax.xaxis.set_major_formatter(formatter) + + def _update_overlay_lines( + self, plot_data: Dict[str, Any], show_overlays: bool + ) -> None: + """Update overlay lines based on zoom level and data availability.""" + # Clear existing overlay lines from the plot + # This method is not currently used in the provided code, but if it were, + # it would need to be updated to use the new decimation strategy. + # For now, leaving it as is, assuming it's a placeholder or for future_use. + # If it were to be used, it would need to call decimate_for_view for each overlay line. + pass # No _overlay_lines attribute in this class, this method is unused. + + def _update_y_limits(self, plot_data: Dict[str, Any], use_envelope: bool) -> None: + """Update y-axis limits to fit current data.""" + y_min_data = float("inf") + y_max_data = float("-inf") + + # Process each trace + for trace_idx in range(self.data.num_traces): + x_new_key = f"x_new_{trace_idx}" + x_min_key = f"x_min_{trace_idx}" + x_max_key = f"x_max_{trace_idx}" + + if x_new_key not in plot_data: + continue + + # Include signal data + if len(plot_data[x_new_key]) > 0: + y_min_data = min(y_min_data, np.min(plot_data[x_new_key])) + y_max_data = max(y_max_data, np.max(plot_data[x_new_key])) + + # Include envelope data if available + if use_envelope and x_min_key in plot_data and x_max_key in plot_data: + if ( + plot_data[x_min_key] is not None + and plot_data[x_max_key] is not None + and len(plot_data[x_min_key]) > 0 + ): + y_min_data = min(y_min_data, np.min(plot_data[x_min_key])) + y_max_data = max(y_max_data, np.max(plot_data[x_max_key])) + + # Include custom lines + for line_obj, _ in self._line_objects[trace_idx]: + # Check if line_obj is a Line2D or PolyCollection + if isinstance(line_obj, mpl.lines.Line2D): + y_data = line_obj.get_ydata() + if len(y_data) > 0: + y_min_data = min(y_min_data, np.min(y_data)) + y_max_data = max(y_max_data, np.max(y_data)) + elif isinstance(line_obj, mpl.collections.PolyCollection): + # For fill_between objects, iterate through paths to get y-coordinates + for path in line_obj.get_paths(): + vertices = path.vertices + if len(vertices) > 0: + y_min_data = min(y_min_data, np.min(vertices[:, 1])) + y_max_data = max(y_max_data, np.max(vertices[:, 1])) + + # Include ribbon data + for ribbon_obj, _ in self._ribbon_objects[trace_idx]: + # For fill_between objects, we need to get the paths + if hasattr(ribbon_obj, "get_paths") and len(ribbon_obj.get_paths()) > 0: + for path in ribbon_obj.get_paths(): + vertices = path.vertices + if len(vertices) > 0: + y_min_data = min(y_min_data, np.min(vertices[:, 1])) + y_max_data = max(y_max_data, np.max(vertices[:, 1])) + + # Handle case where no data was found + if y_min_data == float("inf") or y_max_data == float("-inf"): + self.ax.set_ylim(0, 1) + return + + data_range = y_max_data - y_min_data + data_mean = (y_min_data + y_max_data) / 2 + + # Use min_y_range to ensure a minimum visible range + min_visible_range = self.min_y_range + + if data_range < min_visible_range: + y_min = data_mean - min_visible_range / 2 + y_max = data_mean + min_visible_range / 2 + else: + y_margin = self.y_margin_fraction * data_range + y_min = y_min_data - y_margin + y_max = y_max_data + y_margin + + self.ax.set_ylim(y_min, y_max) + + def _update_plot_data(self, ax_obj) -> None: + """Update plot based on current view.""" + if self.state.is_updating(): + return + + self.state.set_updating(True) + + try: + try: + view_params = self._calculate_view_parameters(ax_obj) + plot_data = self._get_plot_data(view_params) + self._render_plot_elements(plot_data, view_params) + self._update_regions_and_legend(view_params["xlim_display"]) + self.fig.canvas.draw_idle() + except Exception as e: + warnings.warn(f"Error updating plot: {e}", RuntimeWarning) + # Try to recover by resetting to home view + print("Attempting to recover by resetting to home view") + self.home() + finally: + self.state.set_updating(False) + + def _calculate_view_parameters(self, ax_obj) -> Dict[str, Any]: + """Calculate view parameters from current axis state.""" + try: + xlim_raw = self.coord_manager.get_current_view_raw(ax_obj) + + # Validate xlim_raw values + if not np.isfinite(xlim_raw[0]) or not np.isfinite(xlim_raw[1]): + warnings.warn( + f"Invalid xlim_raw from axis: {xlim_raw}. Using initial view.", RuntimeWarning + ) + xlim_raw = self._initial_xlim_raw + + # Ensure xlim_raw is in ascending order + if xlim_raw[0] > xlim_raw[1]: + warnings.warn(f"xlim_raw values out of order: {xlim_raw}. Swapping.", RuntimeWarning) + xlim_raw = (xlim_raw[1], xlim_raw[0]) + + time_span_raw = xlim_raw[1] - xlim_raw[0] + use_envelope = self.state.should_use_envelope(time_span_raw) + current_mode = self.MODE_ENVELOPE if use_envelope else self.MODE_DETAIL + + # Update coordinate system if needed + coordinate_system_changed = self.state.update_display_params( + xlim_raw, time_span_raw + ) + if coordinate_system_changed: + self._update_coordinate_system(xlim_raw, time_span_raw) + + return { + "xlim_raw": xlim_raw, + "time_span_raw": time_span_raw, + "xlim_display": self.coord_manager.xlim_raw_to_display(xlim_raw), + "use_envelope": use_envelope, + "current_mode": current_mode, + } + except Exception as e: + warnings.warn(f"Error calculating view parameters: {e}", RuntimeWarning) + # Return safe default values + return { + "xlim_raw": self._initial_xlim_raw, + "time_span_raw": self._initial_xlim_raw[1] - self._initial_xlim_raw[0], + "xlim_display": self.coord_manager.xlim_raw_to_display( + self._initial_xlim_raw + ), + "use_envelope": True, + "current_mode": self.MODE_ENVELOPE, + } + + def _get_plot_data(self, view_params: Dict[str, Any]) -> Dict[str, Any]: + """Get decimated plot data for current view.""" + plot_data = {} + + # Process each trace + for trace_idx in range(self.data.num_traces): + t_arr = self.data.t_arrays[trace_idx] + x_arr = self.data.x_arrays[trace_idx] + + try: + t_raw, x_new, x_min, x_max = self.decimator.decimate_for_view( + t_arr, + x_arr, + view_params["xlim_raw"], + self.max_plot_points, + view_params["use_envelope"], + trace_idx, # Pass trace_id to use pre-decimated data + envelope_window_samples=None, # Envelope window calculated automatically + mode_switch_threshold=self.mode_switch_threshold, # Pass mode switch threshold + return_envelope_min_max=True, # Main signal always returns envelope min/max if use_envelope is True + ) + + if len(t_raw) == 0: + warnings.warn( + f"No data in current view for trace {trace_idx}. View range: {view_params['xlim_raw']}", UserWarning + ) + # Add empty arrays for this trace + plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32) + plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32) + plot_data[f"x_min_{trace_idx}"] = None + plot_data[f"x_max_{trace_idx}"] = None + continue + + t_display = self.coord_manager.raw_to_display(t_raw) + + # Store data for this trace + plot_data[f"t_display_{trace_idx}"] = t_display + plot_data[f"x_new_{trace_idx}"] = x_new + plot_data[f"x_min_{trace_idx}"] = x_min + plot_data[f"x_max_{trace_idx}"] = x_max + + except Exception as e: + warnings.warn(f"Error getting plot data for trace {trace_idx}: {e}", RuntimeWarning) + # Add empty arrays for this trace to prevent further errors + plot_data[f"t_display_{trace_idx}"] = np.array([], dtype=np.float32) + plot_data[f"x_new_{trace_idx}"] = np.array([], dtype=np.float32) + plot_data[f"x_min_{trace_idx}"] = None + plot_data[f"x_max_{trace_idx}"] = None + + return plot_data + + def _render_plot_elements( + self, plot_data: Dict[str, Any], view_params: Dict[str, Any] + ) -> None: + """Render all plot elements with current data.""" + # Store the current plot data for use by other methods + self._current_plot_data = plot_data + + # Check if we have any data to plot + has_data = False + for trace_idx in range(self.data.num_traces): + key = f"t_display_{trace_idx}" + if key in plot_data and len(plot_data[key]) > 0: + has_data = True + break + + if not has_data: + warnings.warn("No data to plot, clearing all elements", UserWarning) + # If no data, clear all lines and return + for i in range(self.data.num_traces): + self._signal_lines[i].set_data([], []) + if self._envelope_fills[i] is not None: + self._envelope_fills[i].remove() + self._envelope_fills[i] = None + + # Clear custom elements + self._clear_custom_elements(i) + + self.ax.set_ylim(0, 1) # Set a default y-limit + return + + # Process each trace + for trace_idx in range(self.data.num_traces): + t_display_key = f"t_display_{trace_idx}" + x_new_key = f"x_new_{trace_idx}" + x_min_key = f"x_min_{trace_idx}" + x_max_key = f"x_max_{trace_idx}" + + if t_display_key not in plot_data or len(plot_data[t_display_key]) == 0: + # No data for this trace, hide its elements + self._signal_lines[trace_idx].set_data([], []) + if self._envelope_fills[trace_idx] is not None: + self._envelope_fills[trace_idx].remove() + self._envelope_fills[trace_idx] = None + + # Clear custom elements + self._clear_custom_elements(trace_idx) + continue + + # Update signal display + envelope_data = None + if ( + view_params["use_envelope"] + and x_min_key in plot_data + and x_max_key in plot_data + ): + if ( + plot_data[x_min_key] is not None + and plot_data[x_max_key] is not None + ): + envelope_data = (plot_data[x_min_key], plot_data[x_max_key]) + + self._update_signal_display( + trace_idx, plot_data[t_display_key], plot_data[x_new_key], envelope_data + ) + + # Update y-limits + self._update_y_limits(plot_data, view_params["use_envelope"]) + + def _update_coordinate_system( + self, xlim_raw: Tuple[np.float32, np.float32], time_span_raw: np.float32 + ) -> None: + """Update coordinate system and axis formatting.""" + self._clear_region_fills() + self._update_axis_formatting() + self._update_tick_locator(time_span_raw) + + xlim_display = self.coord_manager.xlim_raw_to_display(xlim_raw) + self.ax.set_xlim(xlim_display) + + self._clear_navigation_history() + self._push_current_view() + + def _update_regions_and_legend( + self, xlim_display: Tuple[np.float32, np.float32] + ) -> None: + """Update regions and legend.""" + self._refresh_region_display(xlim_display) + self._update_legend() + + def _refresh_region_display( + self, xlim_display: Tuple[np.float32, np.float32] + ) -> None: + """Refresh region display for current view.""" + self._clear_region_fills() + + # Get current mode + current_mode = ( + self.MODE_ENVELOPE + if self.state.current_mode == "envelope" + else self.MODE_DETAIL + ) + + for trace_idx in range(self.data.num_traces): + # Process each region definition + for region_def in self._regions[trace_idx]: + # Skip if not visible in current mode + if not (region_def["display_mode"] & current_mode): + continue + + regions = region_def["regions"] + if regions is None or len(regions) == 0: + continue + + color = region_def["color"] + label = region_def["label"] + alpha = region_def["alpha"] + first_visible_region = True + + for t_start, t_end in regions: + t_start_display = self.coord_manager.raw_to_display(t_start) + t_end_display = self.coord_manager.raw_to_display(t_end) + + # Check if region overlaps with current view + if not ( + t_end_display <= xlim_display[0] + or t_start_display >= xlim_display[1] + ): + # Only show label for first visible region + current_label = label if first_visible_region else "" + if first_visible_region and len(regions) > 1: + current_label = f"{label} ({len(regions)})" + + fill = self.ax.axvspan( + t_start_display, + t_end_display, + alpha=alpha, + color=color, + linewidth=0.5, + label=current_label, + zorder=region_def["zorder"], + ) + self._region_objects[trace_idx].append((fill, region_def)) + first_visible_region = False + + def _clear_region_fills(self) -> None: + """Clear all region fills.""" + for trace_fills in self._region_objects: + for fill_item in trace_fills: + # Handle both old format (just fill object) and new format (tuple) + if isinstance(fill_item, tuple): + fill, _ = fill_item # Extract the fill object from the tuple + fill.remove() + else: + fill_item.remove() # Old format - direct fill object + trace_fills.clear() + + def _setup_plot_elements(self) -> None: + """ + Initialise matplotlib plot elements (lines, fills) for each trace. + This is called once during render(). + """ + if self.fig is None or self.ax is None: + raise RuntimeError( + "Figure and Axes must be created before setting up plot elements." + ) + + # Create initial signal line objects for each trace + for i in range(self.data.num_traces): + color = self.data.get_trace_color(i) + name = self.data.get_trace_name(i) + + # Signal line + (line_signal,) = self.ax.plot( + [], + [], + label="Raw data" if self.data.num_traces == 1 else f"Raw data ({name})", + color=color, + alpha=self.signal_alpha, + ) + self._signal_lines.append(line_signal) + + def _connect_callbacks(self) -> None: + """Connect matplotlib callbacks.""" + if self.ax is None: + raise RuntimeError("Axes must be created before connecting callbacks.") + self.ax.callbacks.connect("xlim_changed", self._update_plot_data) + + def _setup_toolbar_overrides(self) -> None: + """Override matplotlib toolbar methods (e.g., home button).""" + if ( + self.fig + and self.fig.canvas + and hasattr(self.fig.canvas, "toolbar") + and self.fig.canvas.toolbar + ): + toolbar = self.fig.canvas.toolbar + + # Store original methods + self._original_home = getattr(toolbar, "home", None) + self._original_push_current = getattr(toolbar, "push_current", None) + + # Create our custom home method + def custom_home(*args, **kwargs): + self.home() + + # Override both the method and try to find the actual button + toolbar.home = custom_home + + # For Qt backend, also override the action + if hasattr(toolbar, "actions"): + for action in toolbar.actions(): + if hasattr(action, "text") and hasattr(action, "objectName"): + action_text = ( + action.text() if callable(action.text) else str(action.text) + ) + action_name = ( + action.objectName() + if callable(action.objectName) + else str(action.objectName) + ) + if action_text == "Home" or "home" in action_name.lower(): + if hasattr(action, "triggered"): + action.triggered.disconnect() + action.triggered.connect(custom_home) + break + + # For other backends, try to override the button callback + if hasattr(toolbar, "_buttons") and "Home" in toolbar._buttons: + home_button = toolbar._buttons["Home"] + if hasattr(home_button, "configure"): + home_button.configure(command=custom_home) + + def _set_initial_view_and_labels(self) -> None: + """Set initial axis limits, title, and labels.""" + if self.ax is None: + raise RuntimeError( + "Axes must be created before setting initial view and labels." + ) + + # Create title based on number of traces + if self.data.num_traces == 1: + self.ax.set_title(f"{self.data.names[0]}") + else: + # Multiple traces - just show "Multiple Traces" + self.ax.set_title(f"Multiple Traces ({self.data.num_traces})") + self.ax.set_xlabel(f"Time ({self.state.current_time_unit})") + self.ax.set_ylabel("Signal") + + # Set initial xlim + initial_xlim_display = self.coord_manager.xlim_raw_to_display( + self._initial_xlim_raw + ) + self.ax.set_xlim(initial_xlim_display) + + def render(self) -> None: + """ + Renders the oscilloscope plot. This method must be called after all + data and visualization elements have been added. + """ + if self.fig is not None or self.ax is not None: + warnings.warn( + "Plot already rendered. Call `home()` to reset or create a new instance.", UserWarning + ) + return + + print("Rendering plot...") + self.fig, self.ax = plt.subplots(figsize=(10, 5)) + + self._setup_plot_elements() + self._connect_callbacks() + self._setup_toolbar_overrides() + self._set_initial_view_and_labels() + + # Calculate initial parameters for the full view + t_start, t_end = self.data.get_global_time_range() + full_time_span = t_end - t_start + + print( + f"Initial render: full time span={full_time_span:.3e}s, envelope_limit={self.mode_switch_threshold:.3e}s" + ) + + # Set initial display state based on full view + self.state.current_time_unit, self.state.current_time_scale = ( + _get_optimal_time_unit_and_scale(full_time_span) + ) + self.state.current_mode = ( + "envelope" if self.state.should_use_envelope(full_time_span) else "detail" + ) + + # Force initial draw of all elements by calling _update_plot_data + # This will also update the legend and regions + self.state.set_updating(False) # Ensure not in updating state for first call + self._update_plot_data(self.ax) + self.fig.canvas.draw_idle() + print("Plot rendering complete.") + + def home(self) -> None: + """Return to initial full view with complete state reset.""" + if self.ax is None: # Fix: Changed '===' to 'is' + warnings.warn("Plot not rendered yet. Cannot go home.", UserWarning) + return + + # Disconnect callback temporarily + callback_id = None + for cid, callback in self.ax.callbacks.callbacks["xlim_changed"].items(): + if getattr(callback, "__func__", callback) == self._update_plot_data: + callback_id = cid + break + + if callback_id is not None: + self.ax.callbacks.disconnect(callback_id) + + try: + self.state.set_updating(True) + self.state.reset_to_initial_state() + self.decimator.clear_cache() + self._clear_region_fills() + + # Clear all custom elements and reset _last_mode for each trace to force redraw + for trace_idx in range(self.data.num_traces): + self._clear_custom_elements(trace_idx) + self._last_mode[trace_idx] = None + + # Reset axis formatting + self.ax.set_xlabel(f"Time ({self.state.original_time_unit})") + self.ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter()) + self.ax.xaxis.set_major_locator(mpl.ticker.AutoLocator()) + + # Reset view + self.coord_manager.set_view_raw(self.ax, self._initial_xlim_raw) + + # Manually trigger update for the home view + # This will re-evaluate use_envelope, current_mode, and redraw everything + self._update_plot_data(self.ax) + + self.state.set_updating(False) + + finally: + self.ax.callbacks.connect("xlim_changed", self._update_plot_data) + + self.fig.canvas.draw() + print(f"Home view restored: {self.state.original_time_unit} scale") + + def refresh(self) -> None: + """Force a complete refresh of the plot without changing the current view.""" + if self.ax is None: + warnings.warn("Plot not rendered yet. Cannot refresh.", UserWarning) + return + + # Temporarily bypass the updating state for forced refresh + was_updating = self.state.is_updating() + self.state.set_updating(False) + try: + self._update_plot_data(self.ax) + finally: + self.state.set_updating(was_updating) + self.fig.canvas.draw_idle() + + def show(self) -> None: + """Display the plot.""" + if self.fig is None: + self.render() # Render if not already rendered + plt.show() -- cgit v1.2.3