author     Sam Scholten    2025-10-23 15:01:40 +1000
committer  Sam Scholten    2025-10-23 15:01:40 +1000
commit     4a7026759e099e5c81cc9c77f19182a23d2f0275 (patch)
tree       dcdff1dc81401b4a56248c05f99da47121056d55 /src
Initial release v1.0.0 (tag: v1.0.0)

Event detection and analysis pipeline for transient events in time-series data.

- Event detection based on SNR thresholds
- Configurable background estimation and noise analysis
- Visualization with scopekit integration
- Chunked processing for large files
Diffstat (limited to 'src')
-rw-r--r--  src/transivent/__init__.py          36
-rw-r--r--  src/transivent/analysis.py        1238
-rw-r--r--  src/transivent/event_detector.py   404
-rw-r--r--  src/transivent/event_plotter.py    524
-rw-r--r--  src/transivent/io.py               456
5 files changed, 2658 insertions, 0 deletions
diff --git a/src/transivent/__init__.py b/src/transivent/__init__.py
new file mode 100644
index 0000000..db3e824
--- /dev/null
+++ b/src/transivent/__init__.py
@@ -0,0 +1,36 @@
+"""
+High-level analysis and plotting for transient events.
+"""
+
+from .analysis import (
+ analyze_thresholds,
+ calculate_initial_background,
+ calculate_smoothing_parameters,
+ configure_logging,
+ create_oscilloscope_plot,
+ get_final_events,
+ initialize_state,
+ process_chunk,
+ process_file,
+)
+from .event_detector import detect_events, merge_overlapping_events
+from .event_plotter import EventPlotter
+from .io import get_waveform_params, rd, rd_chunked
+
+__all__ = [
+ "analyze_thresholds",
+ "calculate_initial_background",
+ "calculate_smoothing_parameters",
+ "configure_logging",
+ "create_oscilloscope_plot",
+ "detect_events",
+ "EventPlotter",
+ "get_final_events",
+ "get_waveform_params",
+ "initialize_state",
+ "merge_overlapping_events",
+ "process_chunk",
+ "process_file",
+ "rd",
+ "rd_chunked",
+]
diff --git a/src/transivent/analysis.py b/src/transivent/analysis.py
new file mode 100644
index 0000000..1b5277b
--- /dev/null
+++ b/src/transivent/analysis.py
@@ -0,0 +1,1238 @@
+import base64
+import io
+import os
+import sys
+import time
+import xml.etree.ElementTree as ET
+from typing import Any, Dict, List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from loguru import logger
+from numba import njit
+from PIL import Image
+from scipy.ndimage import gaussian_filter1d, median_filter, uniform_filter1d
+from scipy.signal import savgol_filter
+
+from scopekit.plot import OscilloscopePlot
+
+from .event_detector import (
+    MEDIAN_TO_STD_FACTOR,
+    detect_events,
+    merge_overlapping_events,
+)
+from .event_plotter import EventPlotter
+from .io import _get_xml_sidecar_path, rd, rd_chunked
+
+
+@njit
+def _create_event_mask_numba(t: np.ndarray, events: np.ndarray) -> np.ndarray:
+ """
+ Create a boolean mask to exclude event regions using numba for speed.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ events : np.ndarray
+ Events array with shape (n_events, 2) where each row is [t_start, t_end].
+
+ Returns
+ -------
+ np.ndarray
+ Boolean mask where True means keep the sample, False means exclude.
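+
+    Examples
+    --------
+    A minimal sketch: each [t_start, t_end) interval is masked out (half-open
+    on the right, matching the searchsorted calls below):
+
+    >>> t = np.arange(5, dtype=np.float64)
+    >>> events = np.array([[1.0, 3.0]])
+    >>> _create_event_mask_numba(t, events).tolist()
+    [True, False, False, True, True]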
+ """
+ mask = np.ones(len(t), dtype=np.bool_)
+ if len(events) == 0:
+ return mask
+
+ for i in range(len(events)):
+ t_start = events[i, 0]
+ t_end = events[i, 1]
+
+ start_idx = np.searchsorted(t, t_start, side="left")
+ end_idx = np.searchsorted(t, t_end, side="left")
+
+ if start_idx < end_idx:
+ mask[start_idx:end_idx] = False
+
+ return mask
+
+
+def extract_preview_image(sidecar_path: str, output_path: str) -> Optional[str]:
+ """
+ Extract preview image from XML sidecar and save as PNG.
+
+ Parameters
+ ----------
+ sidecar_path : str
+ Path to the XML sidecar file.
+ output_path : str
+ Path where to save the PNG file.
+
+ Returns
+ -------
+ Optional[str]
+ Path to saved PNG file, or None if no image found.
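+
+    Examples
+    --------
+    A sketch of the sidecar element this function looks for (illustrative
+    snippet, not a complete sidecar file):
+
+    .. code-block:: xml
+
+        <PreviewImage ImageData="iVBORw0KGgoAAAANSUhEUg..." />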
+ """
+ try:
+ tree = ET.parse(sidecar_path)
+ root = tree.getroot()
+
+ # Find PreviewImage element
+ preview_elem = root.find(".//PreviewImage")
+ if preview_elem is None:
+ logger.warning(f"No PreviewImage found in {sidecar_path}")
+ return None
+
+ image_data = preview_elem.get("ImageData")
+ if not image_data:
+ logger.warning(f"Empty ImageData in PreviewImage from {sidecar_path}")
+ return None
+
+ # Decode base64 image data
+ image_bytes = base64.b64decode(image_data)
+
+ # Open with PIL and save as PNG
+ image = Image.open(io.BytesIO(image_bytes))
+ image.save(output_path, "PNG")
+
+ logger.info(f"Saved preview image: {output_path}")
+ return output_path
+
+ except Exception as e:
+ logger.warning(f"Failed to extract preview image from {sidecar_path}: {e}")
+ return None
+
+
+def plot_preview_image(image_path: str, title: str = "Preview Image") -> None:
+ """
+ Display preview image using matplotlib.
+
+ Parameters
+ ----------
+ image_path : str
+ Path to the image file.
+ title : str
+ Title for the plot.
+ """
+ try:
+ image = Image.open(image_path)
+
+ fig, ax = plt.subplots(figsize=(10, 6))
+ ax.imshow(image)
+ ax.set_title(title)
+ ax.axis('off') # Hide axes for cleaner display
+
+ except Exception as e:
+ logger.warning(f"Failed to display preview image {image_path}: {e}")
+
+
+def configure_logging(log_level: str = "INFO") -> None:
+ """
+ Configure loguru logging with specified level.
+
+ Parameters
+ ----------
+ log_level : str, default="INFO"
+ Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL.
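+
+    Examples
+    --------
+    >>> configure_logging("debug")  # case-insensitive; routes output to stderr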
+ """
+ logger.remove()
+ logger.add(
+ sys.stderr,
+ level=log_level.upper(),
+ format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
+ colorize=True,
+ )
+
+
+def load_data(
+ name: str,
+ sampling_interval: float,
+ data_path: str,
+ sidecar: Optional[str] = None,
+ crop: Optional[List[int]] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Stage 1: Load waveform data from file.
+
+ Parameters
+ ----------
+ name : str
+ Filename of the waveform data.
+ sampling_interval : float
+ Sampling interval in seconds.
+ data_path : str
+ Path to data directory.
+ sidecar : str, optional
+ XML sidecar filename.
+ crop : List[int], optional
+ Crop indices [start, end].
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ Time and signal arrays.
+ """
+ logger.success(f"Loading data from {name}")
+ t, x = rd(
+ name,
+ sampling_interval,
+ data_path=data_path,
+ sidecar=sidecar,
+ crop=crop,
+ )
+
+ logger.debug(
+ f"Signal statistics: min={np.min(x):.3g}, max={np.max(x):.3g}, mean={np.mean(x):.3g}, std={np.std(x):.3g}"
+ )
+
+ return t, x
+
+
+def calculate_smoothing_parameters(
+ sampling_interval: float,
+ smooth_win_t: Optional[float],
+ smooth_win_f: Optional[float],
+ min_event_t: float,
+ detection_snr: float,
+ min_event_keep_snr: float,
+ widen_frac: float,
+ signal_polarity: int,
+) -> Tuple[int, int]:
+ """
+ Calculate smoothing window size and minimum event length in samples.
+
+ Parameters
+ ----------
+ sampling_interval : float
+ Sampling interval in seconds.
+ smooth_win_t : Optional[float]
+ Smoothing window in seconds.
+ smooth_win_f : Optional[float]
+ Smoothing window in Hz.
+ min_event_t : float
+ Minimum event duration in seconds.
+ detection_snr : float
+ Detection SNR threshold.
+ min_event_keep_snr : float
+ Minimum event keep SNR threshold.
+ widen_frac : float
+ Fraction to widen detected events.
+ signal_polarity : int
+ Signal polarity (-1 for negative, +1 for positive).
+
+ Returns
+ -------
+ Tuple[int, int]
+ Smoothing window size and minimum event length in samples.
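+
+    Examples
+    --------
+    A minimal sketch with a 4 ns sampling interval and a 1 µs time window:
+
+    >>> smooth_n, min_event_n = calculate_smoothing_parameters(
+    ...     4e-9, 1e-6, None, 0.75e-6, 3.0, 6.0, 10.0, -1
+    ... )
+    >>> # smooth_n: ~250 samples, forced odd; min_event_n: ~187 samples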
+ """
+ if smooth_win_t is not None:
+ smooth_n = int(smooth_win_t / sampling_interval)
+ elif smooth_win_f is not None:
+ smooth_n = int(1 / (smooth_win_f * sampling_interval))
+ else:
+ raise ValueError("Set either smooth_win_t or smooth_win_f")
+
+ if smooth_n % 2 == 0:
+ smooth_n += 1
+
+ min_event_n = max(1, int(min_event_t / sampling_interval))
+
+    smooth_freq_hz = 1 / (smooth_n * sampling_interval)
+    smooth_time_us = smooth_n * sampling_interval * 1e6
+    logger.info(
+        f"--Smooth window: {smooth_n} samples ({smooth_time_us:.1f} µs, {smooth_freq_hz:.1f} Hz)"
+    )
+ logger.info(
+ f"--Min event length: {min_event_n} samples ({min_event_t * 1e6:.1f} µs)"
+ )
+ logger.info(f"--Detection SNR: {detection_snr}")
+ logger.info(f"--Min keep SNR: {min_event_keep_snr}")
+ logger.info(f"--Widen fraction: {widen_frac}")
+ logger.info(
+ f"--Signal polarity: {signal_polarity} ({'negative' if signal_polarity < 0 else 'positive'} events)"
+ )
+
+ return smooth_n, min_event_n
+
+
+def calculate_initial_background(
+ t: np.ndarray, x: np.ndarray, smooth_n: int, filter_type: str = "gaussian"
+) -> np.ndarray:
+ """
+ Stage 2: Calculate initial background estimate.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ smooth_n : int
+ Smoothing window size in samples.
+ filter_type : str, default="gaussian"
+ Filter type: "savgol", "gaussian", "moving_average", "median".
+
+ Returns
+ -------
+ np.ndarray
+ Initial background estimate.
+
+ Notes
+ -----
+    1. Start with Gaussian: the best balance of speed, noise rejection, and event preservation.
+    2. Try Median if you see frequent spikes/glitches in your data.
+    3. Use Moving Average for maximum speed if events are well above noise.
+    4. Reserve Savitzky-Golay for final high-quality analysis of interesting datasets.
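+
+    Examples
+    --------
+    A minimal sketch on a synthetic sine wave:
+
+    >>> t = np.arange(1_000, dtype=np.float32)
+    >>> x = np.sin(t / 100).astype(np.float32)
+    >>> bg = calculate_initial_background(t, x, smooth_n=51, filter_type="gaussian")
+    >>> bg.dtype, bg.shape
+    (dtype('float32'), (1000,))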
+ """
+ logger.info(f"Calculating initial background using {filter_type} filter")
+
+ if filter_type == "savgol":
+ bg_initial = savgol_filter(x, smooth_n, 3).astype(np.float32)
+ elif filter_type == "gaussian":
+ sigma = smooth_n / 6.0 # Convert window to sigma (6-sigma rule)
+ bg_initial = gaussian_filter1d(x.astype(np.float64), sigma).astype(np.float32)
+ elif filter_type == "moving_average":
+ # Use scipy's uniform_filter1d for proper edge handling
+ bg_initial = uniform_filter1d(
+ x.astype(np.float64), size=smooth_n, mode="nearest"
+ ).astype(np.float32)
+ elif filter_type == "median":
+ bg_initial = median_filter(x.astype(np.float64), size=smooth_n).astype(
+ np.float32
+ )
+ else:
+ raise ValueError(
+ f"Unknown filter_type: {filter_type}. Choose from 'savgol', 'gaussian', 'moving_average', 'median'"
+ )
+
+ logger.debug(
+ f"Initial background: mean={np.mean(bg_initial):.3g}, std={np.std(bg_initial):.3g}"
+ )
+ return bg_initial
+
+
+def estimate_noise(x: np.ndarray, bg_initial: np.ndarray) -> np.float32:
+ """
+ Stage 2: Estimate noise level.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Signal array.
+ bg_initial : np.ndarray
+ Initial background estimate.
+
+ Returns
+ -------
+ np.float32
+ Estimated noise level.
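+
+    Examples
+    --------
+    For pure Gaussian noise the MAD-based estimate approaches the true σ
+    (a sketch; the exact value depends on the random draw):
+
+    >>> rng = np.random.default_rng(0)
+    >>> x = rng.normal(0.0, 0.1, 100_000).astype(np.float32)
+    >>> noise = estimate_noise(x, np.zeros_like(x))
+    >>> # noise ≈ 0.1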
+ """
+ global_noise = np.float32(np.median(np.abs(x - bg_initial)) * MEDIAN_TO_STD_FACTOR)
+
+ signal_rms = np.sqrt(np.mean(x**2))
+ signal_range = np.max(x) - np.min(x)
+ noise_pct_rms = 100 * global_noise / signal_rms if signal_rms > 0 else 0
+ noise_pct_range = 100 * global_noise / signal_range if signal_range > 0 else 0
+
+ logger.info(
+ f"Global noise level: {global_noise:.3g} ({noise_pct_rms:.1f}% of RMS, {noise_pct_range:.1f}% of range)"
+ )
+
+ snr_estimate = np.std(x) / global_noise
+ logger.info(f"Estimated signal SNR: {snr_estimate:.2f}")
+
+ return global_noise
+
+
+def detect_initial_events(
+ t: np.ndarray,
+ x: np.ndarray,
+ bg_initial: np.ndarray,
+ global_noise: np.float32,
+ detection_snr: float,
+ min_event_keep_snr: float,
+ widen_frac: float,
+ signal_polarity: int,
+ min_event_n: int,
+) -> np.ndarray:
+ """
+ Stage 3: Detect initial events.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ bg_initial : np.ndarray
+ Initial background estimate.
+ global_noise : np.float32
+ Estimated noise level.
+ detection_snr : float
+ Detection SNR threshold.
+ min_event_keep_snr : float
+ Minimum event keep SNR threshold.
+ widen_frac : float
+ Fraction to widen detected events.
+ signal_polarity : int
+ Signal polarity (-1 for negative, +1 for positive).
+ min_event_n : int
+ Minimum event length in samples.
+
+ Returns
+ -------
+ np.ndarray
+ Array of initial events.
+ """
+ logger.info("Detecting initial events")
+ min_event_amp = np.float32(min_event_keep_snr) * global_noise
+
+ logger.info(f"Detection threshold: {detection_snr}σ below background")
+ logger.info(f"Keep threshold: {min_event_keep_snr}σ below background")
+ logger.info(f"Min event amplitude threshold: {min_event_amp:.3g}")
+
+ events_initial, _ = detect_events(
+ t,
+ x,
+ bg_initial,
+ snr_threshold=np.float32(detection_snr),
+ min_event_len=min_event_n,
+ min_event_amp=min_event_amp,
+ widen_frac=np.float32(widen_frac),
+ global_noise=global_noise,
+ signal_polarity=signal_polarity,
+ )
+
+ logger.info(f"Found {len(events_initial)} initial events after filtering")
+
+ events_initial = merge_overlapping_events(events_initial)
+ logger.info(f"After merging: {len(events_initial)} events")
+
+ return events_initial
+
+
+def calculate_clean_background(
+ t: np.ndarray,
+ x: np.ndarray,
+ events_initial: np.ndarray,
+ smooth_n: int,
+ bg_initial: np.ndarray,
+ filter_type: str = "gaussian",
+ filter_order: int = 2,
+) -> np.ndarray:
+ """
+ Stage 4: Calculate clean background by masking events.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ events_initial : np.ndarray
+ Initial events array.
+ smooth_n : int
+ Smoothing window size in samples.
+ bg_initial : np.ndarray
+ Initial background estimate.
+ filter_type : str, default="gaussian"
+ Filter type: "savgol", "gaussian", "moving_average", "median".
+ filter_order : int, default=2
+ Order of the Savitzky-Golay filter (only used for filter_type="savgol").
+
+ Returns
+ -------
+ np.ndarray
+ Clean background estimate.
+
+    Notes
+    -----
+    1. Start with Gaussian: the best balance of speed, noise rejection, and event preservation.
+    2. Try Median if you see frequent spikes/glitches in your data.
+    3. Use Moving Average for maximum speed if events are well above noise.
+    4. Reserve Savitzky-Golay for final high-quality analysis of interesting datasets.
+ """
+ logger.info(f"Calculating clean background using {filter_type} filter")
+ start_time = time.time()
+
+ # Fast masking with numba
+ mask = _create_event_mask_numba(t, events_initial)
+ mask_time = time.time()
+
+ logger.debug(
+ f"Masked {np.sum(~mask)} samples ({100 * np.sum(~mask) / len(mask):.1f}%) for clean background"
+ )
+
+ t_masked = t[mask]
+ x_masked = x[mask]
+
+ if np.sum(mask) > 2 * smooth_n:
+ # Check if we need interpolation (events detected and masking applied)
+ if len(events_initial) == 0 or np.all(mask):
+ # No events detected or no masking needed - skip interpolation
+ logger.debug("No events to mask - using direct filtering")
+ interp_start = time.time()
+ x_interp = x
+ interp_end = time.time()
+ else:
+ # Events detected - need interpolation
+ interp_start = time.time()
+ x_interp = np.interp(t, t_masked, x_masked)
+ interp_end = time.time()
+
+ filter_start = time.time()
+ if filter_type == "savgol":
+ bg_clean = savgol_filter(x_interp, smooth_n, filter_order).astype(
+ np.float32
+ )
+ elif filter_type == "gaussian":
+ sigma = smooth_n / 6.0 # Convert window to sigma (6-sigma rule)
+ bg_clean = gaussian_filter1d(x_interp.astype(np.float64), sigma).astype(
+ np.float32
+ )
+ elif filter_type == "moving_average":
+ bg_clean = uniform_filter1d(
+ x_interp.astype(np.float64), size=smooth_n, mode="nearest"
+ ).astype(np.float32)
+ elif filter_type == "median":
+ bg_clean = median_filter(x_interp.astype(np.float64), size=smooth_n).astype(
+ np.float32
+ )
+ else:
+ raise ValueError(
+ f"Unknown filter_type: {filter_type}. Choose from 'savgol', 'gaussian', 'moving_average', 'median'"
+ )
+ filter_end = time.time()
+
+ logger.success(
+ f"Timing: mask={mask_time - start_time:.3f}s, interp={interp_end - interp_start:.3f}s, filter={filter_end - filter_start:.3f}s"
+ )
+ logger.debug(
+ f"Clean background: mean={np.mean(bg_clean):.3g}, std={np.std(bg_clean):.3g}"
+ )
+ else:
+ logger.debug(
+ "Insufficient unmasked samples for clean background - using initial"
+ )
+ bg_clean = bg_initial
+
+ return bg_clean
+
+
+def analyze_thresholds(
+ x: np.ndarray,
+ bg_clean: np.ndarray,
+ global_noise: np.float32,
+ detection_snr: float,
+ min_event_keep_snr: float,
+ signal_polarity: int,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Analyze threshold statistics and create threshold arrays.
+
+ Parameters
+ ----------
+ x : np.ndarray
+ Signal array.
+ bg_clean : np.ndarray
+ Clean background estimate.
+ global_noise : np.float32
+ Estimated noise level.
+ detection_snr : float
+ Detection SNR threshold.
+ min_event_keep_snr : float
+ Minimum event keep SNR threshold.
+ signal_polarity : int
+ Signal polarity (-1 for negative, +1 for positive).
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ Detection and keep threshold arrays.
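+
+    Examples
+    --------
+    A sketch for negative polarity, where both thresholds sit below the
+    background:
+
+    >>> x = np.zeros(4, dtype=np.float32)
+    >>> bg = np.zeros(4, dtype=np.float32)
+    >>> det, keep = analyze_thresholds(x, bg, np.float32(0.1), 3.0, 6.0, -1)
+    >>> # det == bg - 0.3, keep == bg - 0.6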
+ """
+ logger.info("Analyzing thresholds")
+
+ if signal_polarity < 0:
+ detection_threshold = bg_clean - detection_snr * global_noise
+ keep_threshold = bg_clean - min_event_keep_snr * global_noise
+ below_detection_pct = 100 * np.sum(x < detection_threshold) / len(x)
+ below_keep_pct = 100 * np.sum(x < keep_threshold) / len(x)
+ logger.info(f"Samples below detection threshold: {below_detection_pct:.2f}%")
+ logger.info(f"Samples below keep threshold: {below_keep_pct:.2f}%")
+ else:
+ detection_threshold = bg_clean + detection_snr * global_noise
+ keep_threshold = bg_clean + min_event_keep_snr * global_noise
+ above_detection_pct = 100 * np.sum(x > detection_threshold) / len(x)
+ above_keep_pct = 100 * np.sum(x > keep_threshold) / len(x)
+ logger.info(f"Samples above detection threshold: {above_detection_pct:.2f}%")
+ logger.info(f"Samples above keep threshold: {above_keep_pct:.2f}%")
+
+ return detection_threshold, keep_threshold
+
+
+def detect_final_events(
+ t: np.ndarray,
+ x: np.ndarray,
+ bg_clean: np.ndarray,
+ global_noise: np.float32,
+ detection_snr: float,
+ min_event_keep_snr: float,
+ widen_frac: float,
+ signal_polarity: int,
+ min_event_n: int,
+) -> np.ndarray:
+ """
+ Stage 5: Detect final events using clean background.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ bg_clean : np.ndarray
+ Clean background estimate.
+ global_noise : np.float32
+ Estimated noise level.
+ detection_snr : float
+ Detection SNR threshold.
+ min_event_keep_snr : float
+ Minimum event keep SNR threshold.
+ widen_frac : float
+ Fraction to widen detected events.
+ signal_polarity : int
+ Signal polarity (-1 for negative, +1 for positive).
+ min_event_n : int
+ Minimum event length in samples.
+
+ Returns
+ -------
+ np.ndarray
+ Array of final events.
+ """
+ logger.info("Detecting final events")
+ min_event_amp = np.float32(min_event_keep_snr) * global_noise
+
+    events, _ = detect_events(  # returned noise equals the input global_noise
+ t,
+ x,
+ bg_clean,
+ snr_threshold=np.float32(detection_snr),
+ min_event_len=min_event_n,
+ min_event_amp=min_event_amp,
+ widen_frac=np.float32(widen_frac),
+ global_noise=global_noise,
+ signal_polarity=signal_polarity,
+ )
+
+ events = merge_overlapping_events(events)
+ logger.info(f"Detected {len(events)} final events")
+
+ return events
+
+
+def analyze_events(
+ t: np.ndarray,
+ x: np.ndarray,
+ bg_clean: np.ndarray,
+ events: np.ndarray,
+ global_noise: np.float32,
+ signal_polarity: int,
+) -> None:
+ """
+ Stage 7: Analyze event characteristics.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ bg_clean : np.ndarray
+ Clean background estimate.
+ events : np.ndarray
+ Events array.
+ global_noise : np.float32
+ Estimated noise level.
+ signal_polarity : int
+ Signal polarity (-1 for negative, +1 for positive).
+ """
+ if len(events) == 0:
+ logger.info("No events to analyze")
+ return
+ if len(events) > 1000:
+ logger.warning(
+ f"Detected {len(events)} events, which is more than 1000. Skipping analysis."
+ )
+ return
+
+ event_durations = (events[:, 1] - events[:, 0]) * 1000000 # Convert to µs
+ event_amplitudes = []
+
+ for t_start, t_end in events:
+ event_mask = (t >= t_start) & (t < t_end)
+ if np.any(event_mask):
+ if signal_polarity < 0:
+ amp = np.min(x[event_mask] - bg_clean[event_mask])
+ else:
+ amp = np.max(x[event_mask] - bg_clean[event_mask])
+ event_amplitudes.append(abs(amp))
+
+ if event_amplitudes:
+ logger.info(
+ f"Event durations (µs): min={np.min(event_durations):.2f}, max={np.max(event_durations):.2f}, mean={np.mean(event_durations):.2f}"
+ )
+ logger.info(
+ f"Event amplitudes: min={np.min(event_amplitudes):.3g}, max={np.max(event_amplitudes):.3g}, mean={np.mean(event_amplitudes):.3g}"
+ )
+ logger.info(
+ f"Event amplitude SNRs: min={np.min(event_amplitudes) / global_noise:.2f}, max={np.max(event_amplitudes) / global_noise:.2f}"
+ )
+
+ final_signal_rms = np.sqrt(np.mean(x**2))
+ final_noise_pct_rms = (
+ 100 * global_noise / final_signal_rms if final_signal_rms > 0 else 0
+ )
+ final_signal_range = np.max(x) - np.min(x)
+ final_noise_pct_range = (
+ 100 * global_noise / final_signal_range if final_signal_range > 0 else 0
+ )
+
+ logger.info(
+ f"Noise summary: {global_noise:.3g} ({final_noise_pct_rms:.1f}% of RMS, {final_noise_pct_range:.1f}% of range)"
+ )
+
+
+def create_oscilloscope_plot(
+ t: np.ndarray,
+ x: np.ndarray,
+ bg_initial: np.ndarray,
+ bg_clean: np.ndarray,
+ events: np.ndarray,
+ detection_threshold: np.ndarray,
+ keep_threshold: np.ndarray,
+ name: str,
+ detection_snr: float,
+ min_event_keep_snr: float,
+ max_plot_points: int,
+ envelope_mode_limit: float,
+ smooth_n: int,
+ global_noise: Optional[np.float32] = None,
+) -> OscilloscopePlot:
+ """
+ Stage 6: Create oscilloscope plot with all visualization elements.
+
+ Parameters
+ ----------
+ t : np.ndarray
+ Time array.
+ x : np.ndarray
+ Signal array.
+ bg_initial : np.ndarray
+ Initial background estimate.
+ bg_clean : np.ndarray
+ Clean background estimate.
+ events : np.ndarray
+ Events array.
+ detection_threshold : np.ndarray
+ Detection threshold array.
+ keep_threshold : np.ndarray
+ Keep threshold array.
+ name : str
+ Name for the plot.
+ detection_snr : float
+ Detection SNR threshold.
+ min_event_keep_snr : float
+ Minimum event keep SNR threshold.
+ max_plot_points : int
+ Maximum plot points for decimation.
+ envelope_mode_limit : float
+ Envelope mode limit.
+ smooth_n : int
+ Smoothing window size in samples.
+ global_noise : Optional[np.float32], default=None
+ Estimated noise level. If provided, will be plotted as a ribbon.
+
+ Returns
+ -------
+ OscilloscopePlot
+ Configured oscilloscope plot.
+ """
+ logger.info("Creating visualization")
+
+ plot_name = name
+ if global_noise is not None:
+ plot_signal_rms = np.sqrt(np.mean(x**2))
+ plot_noise_pct_rms = (
+ 100 * global_noise / plot_signal_rms if plot_signal_rms > 0 else 0
+ )
+ plot_name = f"{name} | Global noise: {global_noise:.3g} ({plot_noise_pct_rms:.1f}% of RMS)"
+
+ plot = OscilloscopePlot(
+ t,
+ x,
+ name=plot_name,
+ max_plot_points=max_plot_points,
+ mode_switch_threshold=envelope_mode_limit,
+ envelope_window_samples=None, # Envelope window now calculated automatically based on zoom
+ )
+
+ plot.add_line(
+ t,
+ bg_clean,
+ label="Background",
+ color="orange",
+ alpha=0.6,
+ linewidth=1.5,
+ display_mode=OscilloscopePlot.MODE_BOTH,
+ )
+
+ if global_noise is not None:
+ plot.add_ribbon(
+ t,
+ bg_clean,
+ global_noise,
+ label="Noise (±1σ)",
+ color="gray",
+ alpha=0.3,
+ display_mode=OscilloscopePlot.MODE_DETAIL,
+ )
+
+ plot.add_line(
+ t,
+ detection_threshold,
+ label=f"Detection ({detection_snr}σ)",
+ color="red",
+ alpha=0.7,
+ linestyle=":",
+ linewidth=1.5,
+ display_mode=OscilloscopePlot.MODE_DETAIL,
+ )
+
+ plot.add_line(
+ t,
+ keep_threshold,
+ label=f"Keep ({min_event_keep_snr}σ)",
+ color="darkred",
+ alpha=0.7,
+ linestyle="--",
+ linewidth=1.5,
+ display_mode=OscilloscopePlot.MODE_DETAIL,
+ )
+
+ if len(events) > 0:
+ plot.add_regions(
+ events,
+ label="Events",
+ color="crimson",
+ alpha=0.4,
+ display_mode=OscilloscopePlot.MODE_BOTH,
+ )
+
+ plot.render()
+ return plot
+
+
+def initialize_state(config: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Initialise the state dictionary for processing.
+
+ Parameters
+ ----------
+ config : Dict[str, Any]
+ Configuration dictionary containing analysis parameters.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The initial state dictionary.
+ """
+ # Normalise keys to lowercase to allow for flexible config from user scripts
+ normalized_config = {k.lower(): v for k, v in config.items()}
+ return {
+ "config": normalized_config,
+ "events": [], # To store lists of events from each chunk
+ "overlap_buffer": {"t": np.array([]), "x": np.array([])}, # For seamless filtering
+ "incomplete_event": None, # To handle events spanning chunks
+ }
+
+
+def get_final_events(state: Dict[str, Any]) -> np.ndarray:
+ """
+ Extract and finalise the list of detected events from the state.
+
+ Parameters
+ ----------
+ state : Dict[str, Any]
+ The final state dictionary after processing all chunks.
+
+ Returns
+ -------
+ np.ndarray
+ The final, merged list of all detected events.
+ """
+ if not state["events"]:
+ return np.empty((0, 2), dtype=np.float32)
+
+ all_events = np.vstack(state["events"])
+ return merge_overlapping_events(all_events)
+
+
+def process_chunk(
+ data: Tuple[np.ndarray, np.ndarray], state: Dict[str, Any]
+) -> Dict[str, Any]:
+ """
+ Process a single data chunk to find events.
+
+ This function contains the core analysis pipeline. It takes a data chunk
+ and the current state, performs event detection, and returns the updated
+ state along with intermediate results for plotting.
+
+ Parameters
+ ----------
+ data : Tuple[np.ndarray, np.ndarray]
+ A tuple containing the time (t) and signal (x) arrays for the chunk.
+ state : Dict[str, Any]
+ The current state dictionary.
+
+ Returns
+ -------
+ Dict[str, Any]
+ A dictionary containing the updated state and intermediate results:
+ - "state": The updated state dictionary.
+ - "bg_initial": The initial background estimate.
+ - "global_noise": The estimated global noise for the chunk.
+ - "events_initial": The initially detected events.
+ - "bg_clean": The cleaned background estimate.
+ - "events": The final detected events for the chunk.
+ """
+ t_chunk, x_chunk = data
+ config = state["config"]
+ overlap_buffer = state["overlap_buffer"]
+ smooth_n = config["smooth_n"]
+
+ # Prepend overlap buffer from previous chunk
+ t = np.concatenate((overlap_buffer["t"], t_chunk))
+ x = np.concatenate((overlap_buffer["x"], x_chunk))
+
+    # If the combined data is still shorter than the smoothing window, there
+    # is nothing to process yet
+ if len(x) < smooth_n:
+ # Not enough data to process, just buffer it for the next chunk
+ state["overlap_buffer"] = {"t": t, "x": x}
+ return {
+ "state": state,
+ "bg_initial": np.array([], dtype=np.float32),
+ "global_noise": np.float32(0),
+ "events_initial": np.array([], dtype=np.float32),
+ "bg_clean": np.array([], dtype=np.float32),
+ "events": np.array([], dtype=np.float32),
+ }
+
+ # Update overlap buffer for next iteration
+ # The buffer should be `smooth_n` points long for filtering
+ overlap_size = min(len(t), smooth_n)
+ state["overlap_buffer"] = {
+ "t": t[-overlap_size:],
+ "x": x[-overlap_size:],
+ }
+
+ # Extract parameters from config
+ min_event_n = config["min_event_n"]
+ filter_type = config["filter_type"]
+ detection_snr = config["detection_snr"]
+ min_event_keep_snr = config["min_event_keep_snr"]
+ widen_frac = config["widen_frac"]
+ signal_polarity = config["signal_polarity"]
+ filter_order = config["filter_order"]
+
+ # Stage 2: Background Calculation
+ bg_initial = calculate_initial_background(t, x, smooth_n, filter_type)
+ global_noise = estimate_noise(x, bg_initial)
+
+ # Stage 3: Initial Event Detection
+ events_initial = detect_initial_events(
+ t,
+ x,
+ bg_initial,
+ global_noise,
+ detection_snr,
+ min_event_keep_snr,
+ widen_frac,
+ signal_polarity,
+ min_event_n,
+ )
+
+ # Stage 4: Clean Background Calculation
+ bg_clean = calculate_clean_background(
+ t, x, events_initial, smooth_n, bg_initial, filter_type, filter_order
+ )
+
+ # Stage 5: Final Event Detection
+ events = detect_final_events(
+ t,
+ x,
+ bg_clean,
+ global_noise,
+ detection_snr,
+ min_event_keep_snr,
+ widen_frac,
+ signal_polarity,
+ min_event_n,
+ )
+
+ # --- Handle events spanning chunk boundaries ---
+ # Check for and merge an incomplete event from the previous chunk
+ if state["incomplete_event"] is not None:
+ if len(events) > 0 and events[0, 0] <= t[0]:
+ # The first event in this chunk is a continuation of the previous one.
+ # Merge by updating the end time of the stored incomplete event.
+ state["incomplete_event"][1] = events[0, 1]
+ # Remove the partial event from this chunk's list.
+ events = events[1:]
+ else:
+ # The incomplete event was not continued. It's now complete.
+ # Add it to the final list and clear the state.
+ state["events"].append(np.array([state["incomplete_event"]]))
+ state["incomplete_event"] = None
+
+ # Check if the last event in this chunk is incomplete
+ if len(events) > 0:
+ # An event is incomplete if its end time is at or beyond the end of the
+ # current processing window `t`. `detect_events` extrapolates the end
+ # time, so a check for >= is sufficient.
+ if events[-1, 1] >= t[-1]:
+ # Store the incomplete event for the next chunk.
+ state["incomplete_event"] = events[-1]
+ # Remove it from this chunk's list.
+ events = events[:-1]
+
+ # Update state with the completed events from this chunk
+ if len(events) > 0:
+ state["events"].append(events)
+
+ return {
+ "state": state,
+ "bg_initial": bg_initial,
+ "global_noise": global_noise,
+ "events_initial": events_initial,
+ "bg_clean": bg_clean,
+ "events": events,
+ }
+
+
+def process_file(
+ name: str,
+ sampling_interval: float,
+ data_path: str,
+ smooth_win_t: Optional[float] = None,
+ smooth_win_f: Optional[float] = None,
+ detection_snr: float = 3.0,
+ min_event_keep_snr: float = 6.0,
+ min_event_t: float = 0.75e-6,
+ widen_frac: float = 10.0,
+ signal_polarity: int = -1,
+ max_plot_points: int = 10000,
+ envelope_mode_limit: float = 10e-3,
+ sidecar: Optional[str] = None,
+ crop: Optional[List[int]] = None,
+ yscale_mode: str = "snr",
+ show_plots: bool = True,
+ filter_type: str = "gaussian",
+ filter_order: int = 2,
+ chunk_size: Optional[int] = None,
+) -> None:
+ """
+ Process a single waveform file for event detection.
+
+ Parameters
+ ----------
+ name : str
+ Filename of the waveform data.
+ sampling_interval : float
+ Sampling interval in seconds.
+ data_path : str
+ Path to data directory.
+ smooth_win_t : Optional[float], default=None
+ Smoothing window in seconds.
+ smooth_win_f : Optional[float], default=None
+ Smoothing window in Hz.
+ detection_snr : float, default=3.0
+ Detection SNR threshold.
+ min_event_keep_snr : float, default=6.0
+ Minimum event keep SNR threshold.
+ min_event_t : float, default=0.75e-6
+ Minimum event duration in seconds.
+ widen_frac : float, default=10.0
+ Fraction to widen detected events.
+ signal_polarity : int, default=-1
+ Signal polarity (-1 for negative, +1 for positive).
+ max_plot_points : int, default=10000
+ Maximum plot points for decimation.
+ envelope_mode_limit : float, default=10e-3
+ Envelope mode limit.
+ sidecar : str, optional
+ XML sidecar filename.
+ crop : List[int], optional
+ Crop indices [start, end].
+ yscale_mode : str, default="snr"
+ Y-axis scaling mode for event plotter.
+ show_plots : bool, default=True
+ Whether to show plots interactively.
+ filter_type : str, default="gaussian"
+ Filter type for background smoothing: "savgol", "gaussian", "moving_average", "median".
+ filter_order : int, default=2
+        Order of the Savitzky-Golay filter (only used for filter_type="savgol").
+    chunk_size : Optional[int], default=None
+        If set, read and process the file in chunks of this many samples via
+        `rd_chunked` instead of loading it whole. Plotting is disabled in
+        chunked mode.
+    """
+ start_time = time.time()
+ logger.info(f"Processing {name} with parameters:")
+
+ analysis_dir = data_path[:-1] if data_path.endswith("/") else data_path
+ analysis_dir += "_analysis/"
+    os.makedirs(analysis_dir, exist_ok=True)
+
+ # Extract and save preview image
+ sidecar_path = _get_xml_sidecar_path(name, data_path, sidecar)
+ logger.info(f"Attempting to extract preview from: {sidecar_path}")
+ preview_path = os.path.join(analysis_dir, f"{name}_preview.png")
+ saved_preview = extract_preview_image(sidecar_path, preview_path)
+
+ if saved_preview and show_plots:
+ plot_preview_image(saved_preview, f"Preview: {name}")
+
+ # Calculate parameters
+ smooth_n, min_event_n = calculate_smoothing_parameters(
+ sampling_interval,
+ smooth_win_t,
+ smooth_win_f,
+ min_event_t,
+ detection_snr,
+ min_event_keep_snr,
+ widen_frac,
+ signal_polarity,
+ )
+
+ # --- Refactored analysis pipeline ---
+ config = {
+ "sampling_interval": sampling_interval,
+ "smooth_win_t": smooth_win_t,
+ "smooth_win_f": smooth_win_f,
+ "detection_snr": detection_snr,
+ "min_event_keep_snr": min_event_keep_snr,
+ "min_event_t": min_event_t,
+ "widen_frac": widen_frac,
+ "signal_polarity": signal_polarity,
+ "filter_type": filter_type,
+ "filter_order": filter_order,
+ "smooth_n": smooth_n,
+ "min_event_n": min_event_n,
+ }
+
+ state = initialize_state(config)
+
+ if chunk_size is None:
+ # --- Original full-file processing ---
+ t, x = load_data(name, sampling_interval, data_path, sidecar, crop)
+
+ # For now, process the entire file as a single chunk
+ process_start_time = time.time()
+ results = process_chunk((t, x), state)
+ final_events = get_final_events(results["state"])
+ logger.debug(f"Core processing took {time.time() - process_start_time:.3f}s")
+
+ # Extract intermediate results for plotting and analysis
+ bg_initial = results["bg_initial"]
+ global_noise = results["global_noise"]
+ bg_clean = results["bg_clean"]
+
+ # Analyze thresholds
+ detection_threshold, keep_threshold = analyze_thresholds(
+ x,
+ bg_clean,
+ global_noise,
+ detection_snr,
+ min_event_keep_snr,
+ signal_polarity,
+ )
+
+ # Stage 7: Event Analysis
+ analyze_events(t, x, bg_clean, final_events, global_noise, signal_polarity)
+
+ logger.debug(f"Total processing time: {time.time() - start_time:.3f}s")
+
+ # Stage 6: Visualization
+ plot = create_oscilloscope_plot(
+ t,
+ x,
+ bg_initial,
+ bg_clean,
+ final_events,
+ detection_threshold,
+ keep_threshold,
+ name,
+ detection_snr,
+ min_event_keep_snr,
+ max_plot_points,
+ envelope_mode_limit,
+ smooth_n,
+ global_noise=global_noise,
+ )
+
+ # Save plots
+
+ plot.save(analysis_dir + f"{name}_trace.png")
+
+ # Create event plotter
+ event_plotter = EventPlotter(
+ plot,
+ final_events,
+ bg_clean=bg_clean,
+ global_noise=global_noise,
+ y_scale_mode=yscale_mode,
+ )
+ event_plotter.plot_events_grid(max_events=16)
+ event_plotter.save(analysis_dir + f"{name}_events.png")
+
+ if show_plots:
+ plt.show(block=True)
+ else:
+ # --- Chunked processing ---
+ logger.info(
+ f"--- Starting chunked processing with chunk size: {chunk_size} ---"
+ )
+
+ chunk_generator = rd_chunked(
+ name,
+ chunk_size=chunk_size,
+ sampling_interval=sampling_interval,
+ data_path=data_path,
+ sidecar=sidecar,
+ )
+
+ process_start_time = time.time()
+
+ for t_chunk, x_chunk in chunk_generator:
+ results = process_chunk((t_chunk, x_chunk), state)
+ state = results["state"]
+
+ # After processing all chunks, add any remaining incomplete event
+ if state.get("incomplete_event") is not None:
+ state["events"].append(np.array([state["incomplete_event"]]))
+ state["incomplete_event"] = None
+
+ final_events = get_final_events(state)
+ logger.debug(f"Core processing took {time.time() - process_start_time:.3f}s")
+
+ logger.success(
+ f"Chunked processing complete. Found {len(final_events)} events."
+ )
+ if len(final_events) > 0:
+ logger.info("Final events (first 10):")
+ for i, event in enumerate(final_events[:10]):
+ logger.info(
+ f" Event {i+1}: start={event[0]:.6f}s, end={event[1]:.6f}s"
+ )
+
+ logger.warning("Plotting is disabled in chunked processing mode.")
diff --git a/src/transivent/event_detector.py b/src/transivent/event_detector.py
new file mode 100644
index 0000000..22d76d4
--- /dev/null
+++ b/src/transivent/event_detector.py
@@ -0,0 +1,404 @@
+from typing import Optional, Tuple
+
+import numpy as np
+from loguru import logger
+from numba import njit
+
+# --- Constants ---
+MEDIAN_TO_STD_FACTOR = (
+ 1.4826 # Factor to convert median absolute deviation to standard deviation
+)
+
+
+@njit
+def detect_events_numba(
+ time: np.ndarray,
+ signal: np.ndarray,
+ bg: np.ndarray,
+ snr_threshold: float,
+ min_event_len: int,
+ min_event_amp: float,
+ widen_frac: float,
+ global_noise: float,
+ signal_polarity: int,
+) -> np.ndarray:
+ """
+ Detect events in signal using Numba for performance.
+
+ Uses pre-allocated NumPy arrays instead of dynamic lists for better performance.
+
+ Parameters
+ ----------
+ time : np.ndarray
+ Time array (float32).
+ signal : np.ndarray
+ Input signal array (float32).
+ bg : np.ndarray
+ Background/baseline array (float32).
+ snr_threshold : float
+ Signal-to-noise ratio threshold for detection.
+ min_event_len : int
+ Minimum event length in samples.
+ min_event_amp : float
+ Minimum event amplitude threshold.
+ widen_frac : float
+ Fraction to widen detected events.
+ global_noise : float
+ Global noise level.
+ signal_polarity : int
+ Signal polarity: -1 for negative events, +1 for positive events.
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (n_events, 2) with start and end indices of events.
+ """
+ # Cast scalar parameters to float32 for consistency
+ snr_threshold = np.float32(snr_threshold)
+ min_event_amp = np.float32(min_event_amp)
+ widen_frac = np.float32(widen_frac)
+ global_noise = np.float32(global_noise)
+
+    # `above` flags samples beyond the detection threshold in the event
+    # direction: below it for negative polarity, above it for positive.
+    if signal_polarity < 0:
+        threshold = bg - snr_threshold * global_noise
+        above = signal < threshold
+    else:
+        threshold = bg + snr_threshold * global_noise
+        above = signal > threshold
+
+ # Pre-allocate maximum possible events (worst case: every other sample is an event)
+ max_events = len(signal) // 2
+ events = np.empty((max_events, 2), dtype=np.int64)
+ event_count = 0
+
+ in_event = False
+ start = 0
+
+ for i in range(len(above)):
+ val = above[i]
+ if val and not in_event:
+ start = i
+ in_event = True
+ elif not val and in_event:
+ end = i
+ event_len = end - start
+ if event_len < min_event_len:
+ in_event = False
+ continue
+
+ # Amplitude filter
+ if min_event_amp > 0.0:
+ if signal_polarity < 0:
+ if np.min(signal[start:end] - bg[start:end]) > -min_event_amp:
+ in_event = False
+ continue
+ else:
+ if np.max(signal[start:end] - bg[start:end]) < min_event_amp:
+ in_event = False
+ continue
+
+ # Widen event
+ widen = int(widen_frac * (end - start))
+ new_start = max(0, start - widen)
+ new_end = min(len(signal), end + widen)
+
+ # Store indices for now, convert to time outside numba
+ events[event_count, 0] = new_start
+ events[event_count, 1] = new_end
+ event_count += 1
+ in_event = False
+
+ # Handle event at end of signal
+ if in_event:
+ end = len(signal)
+ event_len = end - start
+ if event_len >= min_event_len:
+ if min_event_amp > 0.0:
+ if signal_polarity < 0:
+ if np.min(signal[start:end] - bg[start:end]) <= -min_event_amp:
+ widen = int(widen_frac * (end - start))
+ new_start = max(0, start - widen)
+ new_end = min(len(signal), end + widen)
+ # Store indices for now, convert to time outside numba
+ events[event_count, 0] = new_start
+ events[event_count, 1] = new_end
+ event_count += 1
+ else:
+ if np.max(signal[start:end] - bg[start:end]) >= min_event_amp:
+ widen = int(widen_frac * (end - start))
+ new_start = max(0, start - widen)
+ new_end = min(len(signal), end + widen)
+ # Store indices for now, convert to time outside numba
+ events[event_count, 0] = new_start
+ events[event_count, 1] = new_end
+ event_count += 1
+ else:
+ widen = int(widen_frac * (end - start))
+ new_start = max(0, start - widen)
+ new_end = min(len(signal), end + widen)
+ # Store indices for now, convert to time outside numba
+ events[event_count, 0] = new_start
+ events[event_count, 1] = new_end
+ event_count += 1
+
+ # Return only the filled portion
+ return events[:event_count]
+
+
+@njit
+def merge_overlapping_events_numba(events: np.ndarray) -> np.ndarray:
+ """
+ Merge overlapping events using Numba for performance.
+
+ Parameters
+ ----------
+ events : np.ndarray
+ Array of shape (n_events, 2) with start and end times.
+
+ Returns
+ -------
+ np.ndarray
+ Array of merged events with shape (n_merged, 2).
+ """
+ n = len(events)
+ if n == 0:
+ return np.empty((0, 2), dtype=np.float32)
+    arr = events[np.argsort(events[:, 0])]  # sort events by start time
+ merged = np.empty((n, 2), dtype=np.float32)
+ count = 0
+ merged[count] = arr[0]
+ count += 1
+ for i in range(1, arr.shape[0]):
+ start, end = arr[i]
+ last_start, last_end = merged[count - 1]
+ if start <= last_end:
+ merged[count - 1, 1] = max(last_end, end)
+ else:
+ merged[count] = arr[i]
+ count += 1
+ return merged[:count]
+
+
+def detect_events(
+ time: np.ndarray,
+ signal: np.ndarray,
+ bg: np.ndarray,
+ snr_threshold: np.float32 = np.float32(2.0),
+ min_event_len: int = 20,
+ min_event_amp: np.float32 = np.float32(0.0),
+ widen_frac: np.float32 = np.float32(0.5),
+ global_noise: Optional[np.float32] = None,
+ signal_polarity: int = -1,
+) -> Tuple[np.ndarray, np.float32]:
+ """
+ Detect events in signal above background with specified thresholds.
+
+ Parameters
+ ----------
+ time : np.ndarray
+ Time array.
+ signal : np.ndarray
+ Input signal array.
+ bg : np.ndarray
+ Background/baseline array.
+ snr_threshold : np.float32, default=2.0
+ Signal-to-noise ratio threshold for detection.
+ min_event_len : int, default=20
+ Minimum event length in samples.
+ min_event_amp : np.float32, default=0.0
+ Minimum event amplitude threshold.
+ widen_frac : np.float32, default=0.5
+ Fraction to widen detected events.
+ global_noise : np.float32, optional
+ Global noise level. Must be provided.
+ signal_polarity : int, default=-1
+ Signal polarity: -1 for negative events (below background), +1 for positive events (above background).
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.float32]
+ Array of detected events (time ranges) and global noise value.
+
+ Raises
+ ------
+ ValueError
+ If global_noise is not provided or input arrays are invalid.
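+
+    Examples
+    --------
+    A minimal sketch on synthetic data with a single negative-going dip:
+
+    >>> t = np.arange(10_000, dtype=np.float32) * np.float32(1e-4)
+    >>> x = np.zeros_like(t)
+    >>> x[5_000:5_100] = -1.0  # 100-sample dip
+    >>> bg = np.zeros_like(t)
+    >>> events, noise = detect_events(
+    ...     t, x, bg, snr_threshold=np.float32(3.0), min_event_len=20,
+    ...     global_noise=np.float32(0.1), signal_polarity=-1,
+    ... )
+    >>> len(events)
+    1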
+ """
+ if global_noise is None:
+ logger.error("global_noise was not provided to detect_events.")
+ raise ValueError("global_noise must be provided")
+
+ # Validate and convert input arrays
+ time = np.asarray(time, dtype=np.float32)
+ signal = np.asarray(signal, dtype=np.float32)
+ bg = np.asarray(bg, dtype=np.float32)
+
+ # Validate input data
+ _validate_detection_inputs(
+ time, signal, bg, snr_threshold, min_event_len, global_noise
+ )
+
+ events_indices = detect_events_numba(
+ time,
+ signal,
+ bg,
+ np.float32(snr_threshold),
+ int(min_event_len),
+ np.float32(min_event_amp),
+ np.float32(widen_frac),
+ np.float32(global_noise),
+ int(signal_polarity),
+ )
+
+ # Convert indices to time values outside of numba
+ events_array = np.empty_like(events_indices, dtype=np.float32)
+ for i in range(len(events_indices)):
+ start_idx = int(events_indices[i, 0])
+ end_idx = int(events_indices[i, 1])
+ events_array[i, 0] = time[start_idx]
+ if end_idx < len(time):
+ events_array[i, 1] = time[end_idx]
+ else:
+ # Event extends to the end of the signal. Extrapolate end time.
+ sampling_interval = time[1] - time[0] if len(time) > 1 else 0.0
+ events_array[i, 1] = time[-1] + sampling_interval
+
+ logger.info(f"Raw detection found {len(events_array)} events")
+
+ return events_array, np.float32(global_noise)
+
+
+def _validate_detection_inputs(
+ time: np.ndarray,
+ signal: np.ndarray,
+ bg: np.ndarray,
+ snr_threshold: np.float32,
+ min_event_len: int,
+ global_noise: np.float32,
+) -> None:
+ """
+ Validate inputs for event detection.
+
+ Parameters
+ ----------
+ time : np.ndarray
+ Time array.
+ signal : np.ndarray
+ Signal array.
+ bg : np.ndarray
+ Background array.
+ snr_threshold : np.float32
+ SNR threshold.
+ min_event_len : int
+ Minimum event length.
+ global_noise : np.float32
+ Global noise level.
+
+ Raises
+ ------
+ ValueError
+ If inputs are invalid.
+ """
+ # Check array lengths
+ if not (len(time) == len(signal) == len(bg)):
+ logger.warning(
+ f"Validation Warning: Array length mismatch: time={len(time)}, signal={len(signal)}, bg={len(bg)}. "
+ "This may lead to unexpected behaviour."
+ )
+
+ # Check for empty arrays
+ if len(time) == 0:
+ logger.warning(
+ "Validation Warning: Input arrays are empty. This may lead to unexpected behaviour."
+ )
+
+ # Check time monotonicity with a small tolerance for floating-point comparisons
+ if len(time) > 1:
+ # Use a small epsilon for floating-point comparison
+ # np.finfo(time.dtype).eps is the smallest representable positive number such that 1.0 + eps != 1.0
+ # Multiplying by a small factor (e.g., 10) provides a reasonable tolerance.
+ tolerance = np.finfo(time.dtype).eps * 10
+ if not np.all(np.diff(time) > tolerance):
+ # Log the problematic differences for debugging
+ problematic_diffs = np.diff(time)[np.diff(time) <= tolerance]
+ logger.warning(
+ f"Validation Warning: Time array is not strictly monotonic increasing within tolerance {tolerance}. "
+ f"Problematic diffs (first 10): {problematic_diffs[:10]}. This may lead to unexpected behaviour."
+ )
+
+ # Check parameter validity
+ if snr_threshold <= 0:
+ logger.warning(
+ f"Validation Warning: SNR threshold must be positive, got {snr_threshold}. This may lead to unexpected behaviour."
+ )
+
+ if min_event_len <= 0:
+ logger.warning(
+ f"Validation Warning: Minimum event length must be positive, got {min_event_len}. This may lead to unexpected behaviour."
+ )
+
+ if global_noise <= 0:
+ logger.warning(
+ f"Validation Warning: Global noise must be positive, got {global_noise}. This may lead to unexpected behaviour."
+ )
+
+ # Check for NaN/inf values
+ for name, arr in [("time", time), ("signal", signal), ("bg", bg)]:
+ if not np.all(np.isfinite(arr)):
+ logger.warning(
+ f"Validation Warning: {name} array contains NaN or infinite values. This may lead to unexpected behaviour."
+ )
+
+
+def merge_overlapping_events(events: np.ndarray) -> np.ndarray:
+ """
+ Merge overlapping events.
+
+ Parameters
+ ----------
+ events : np.ndarray
+ Array of events with shape (n_events, 2).
+
+ Returns
+ -------
+ np.ndarray
+ Array of merged events.
+
+ Raises
+ ------
+ ValueError
+ If events array has invalid format.
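+
+    Examples
+    --------
+    >>> events = np.array([[0.0, 1.0], [0.5, 2.0], [3.0, 4.0]], dtype=np.float32)
+    >>> merge_overlapping_events(events).tolist()
+    [[0.0, 2.0], [3.0, 4.0]]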
+ """
+ if len(events) == 0:
+ return np.empty((0, 2), dtype=np.float32)
+
+ # Validate events array format
+ events_array = np.asarray(events, dtype=np.float32)
+ if events_array.ndim != 2 or events_array.shape[1] != 2:
+ logger.warning(
+ f"Validation Warning: Events array must have shape (n_events, 2), got {events_array.shape}. This may lead to unexpected behaviour."
+ )
+        # The shape problem is only warned about here; if the array is
+        # fundamentally malformed, merge_overlapping_events_numba will fail
+        # when it indexes the second column.
+
+ # Check for invalid events (start >= end)
+ invalid_mask = events_array[:, 0] >= events_array[:, 1]
+ if np.any(invalid_mask):
+ invalid_indices = np.where(invalid_mask)[0]
+ logger.warning(
+ f"Validation Warning: Invalid events found (start >= end) at indices: {invalid_indices}. This may lead to unexpected behaviour."
+ )
+
+ merged = merge_overlapping_events_numba(events_array)
+
+ if len(merged) != len(events):
+ logger.info(
+ f"Merged {len(events)} → {len(merged)} events ({len(events) - len(merged)} overlaps resolved)"
+ )
+
+ return merged
diff --git a/src/transivent/event_plotter.py b/src/transivent/event_plotter.py
new file mode 100644
index 0000000..c86cb14
--- /dev/null
+++ b/src/transivent/event_plotter.py
@@ -0,0 +1,524 @@
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import matplotlib.figure
+import matplotlib.pyplot as plt
+import numpy as np
+from loguru import logger
+
+from scopekit.display_state import (
+ _create_time_formatter,
+ _determine_offset_display_params,
+ _get_optimal_time_unit_and_scale,
+)
+from scopekit.plot import OscilloscopePlot
+
+
+class EventPlotter:
+ """
+ Provides utility functions for plotting individual events or event grids.
+ """
+
+ def __init__(
+ self,
+ osc_plot: OscilloscopePlot,
+ events: Optional[np.ndarray] = None,
+ trace_idx: int = 0,
+ bg_clean: Optional[np.ndarray] = None,
+ global_noise: Optional[np.float32] = None,
+ y_scale_mode: str = "raw",
+ ):
+ """
+ Initialize the EventPlotter with an OscilloscopePlot instance.
+
+ Parameters
+ ----------
+ osc_plot : OscilloscopePlot
+ An instance of OscilloscopePlot containing the waveform data.
+ events : Optional[np.ndarray], default=None
+ Events array with shape (n_events, 2) where each row is [start_time, end_time].
+ If None, will try to extract events from regions in the OscilloscopePlot.
+ trace_idx : int, default=0
+ Index of the trace to extract events from.
+ bg_clean : Optional[np.ndarray], default=None
+ The clean background signal array. This is needed for plotting background in event views.
+ global_noise : Optional[np.float32], default=None
+ The estimated global noise level. If provided, a noise ribbon will be plotted around bg_clean.
+ y_scale_mode : str, default="raw"
+ Y-axis scaling mode. Options:
+ - "raw": Raw signal values
+ - "percent": Percentage contrast relative to background ((signal - bg) / bg * 100)
+ - "snr": Signal-to-noise ratio ((signal - bg) / noise)
+ """
+ self.osc_plot = osc_plot
+ self.trace_idx = trace_idx
+ self.bg_clean = bg_clean
+ self.global_noise = global_noise # Store global_noise here
+ self.y_scale_mode = y_scale_mode
+
+ # Extract events from regions if not provided
+ if events is None:
+ self.events = self._extract_events_from_regions()
+ else:
+ self.events = events
+
+ if self.events is None or len(self.events) == 0:
+ logger.warning("EventPlotter initialized but no events are available.")
+            self.events = np.empty((0, 2), dtype=np.float32)  # consistent empty (n, 2) shape
+
+ # Validate y_scale_mode
+ valid_modes = ["raw", "percent", "snr"]
+ if self.y_scale_mode not in valid_modes:
+ logger.warning(
+ f"Invalid y_scale_mode '{self.y_scale_mode}'. Using 'raw'. Valid options: {valid_modes}"
+ )
+ self.y_scale_mode = "raw"
+
+ # Warn if scaling mode requires data that's not available
+ if self.y_scale_mode == "percent" and self.bg_clean is None:
+ logger.warning(
+ "y_scale_mode='percent' requires bg_clean data. Falling back to 'raw' mode."
+ )
+ self.y_scale_mode = "raw"
+ elif self.y_scale_mode == "snr" and self.global_noise is None:
+ logger.warning(
+ "y_scale_mode='snr' requires global_noise data. Falling back to 'raw' mode."
+ )
+ self.y_scale_mode = "raw"
+
+ self.fig: Optional[matplotlib.figure.Figure] = None
+
+ def save(self, filepath: str):
+ """
+ Save the current state of the EventPlotter to a file.
+
+ Parameters
+ ----------
+ filepath : str
+ Path to save the EventPlotter state.
+ """
+        if self.fig is not None:
+            self.fig.savefig(filepath)
+            logger.info(f"EventPlotter figure saved to {filepath}")
+        else:
+            logger.warning("No figure to save; call a plot method first.")
+
+ def _extract_events_from_regions(self) -> Optional[np.ndarray]:
+ """
+ Extract events from regions in the OscilloscopePlot.
+
+ Returns
+ -------
+ Optional[np.ndarray]
+ Events array with shape (n_events, 2) where each row is [start_time, end_time].
+ """
+ # First check if events are stored in the data manager
+ if hasattr(self.osc_plot.data, "get_events"):
+ events = self.osc_plot.data.get_events(self.trace_idx)
+ if events is not None:
+ return events
+
+ # If not, try to extract from regions (backward compatibility)
+ if not hasattr(self.osc_plot, "_regions") or not self.osc_plot._regions:
+ return None
+
+ # Extract regions from the specified trace
+ trace_regions = self.osc_plot._regions[self.trace_idx]
+ if not trace_regions:
+ return None
+
+ # Combine all regions into a single array
+ all_events = []
+ for region_def in trace_regions:
+ if "regions" in region_def and region_def["regions"] is not None:
+ all_events.append(region_def["regions"])
+
+ if not all_events:
+ return None
+
+ # Concatenate all region arrays
+ return np.vstack(all_events)
+
+ def _scale_y_data(
+ self, y_data: np.ndarray, bg_data: Optional[np.ndarray], mask: np.ndarray
+ ) -> Tuple[np.ndarray, str]:
+ """
+ Scale y-data according to the current scaling mode.
+
+ Parameters
+ ----------
+ y_data : np.ndarray
+ Raw signal data.
+ bg_data : Optional[np.ndarray]
+ Background data array (same length as full signal).
+ mask : np.ndarray
+ Boolean mask for extracting the relevant portion of bg_data.
+
+ Returns
+ -------
+ Tuple[np.ndarray, str]
+ Scaled y-data and appropriate y-axis label.
+ """
+ if self.y_scale_mode == "percent" and bg_data is not None:
+ bg_event = bg_data[mask]
+ # Avoid division by zero - use small value for near-zero background
+ bg_safe = np.where(np.abs(bg_event) < 1e-12, 1e-12, bg_event)
+ scaled_data = 100 * (y_data - bg_event) / bg_safe
+ return scaled_data, "Contrast (%)"
+ elif self.y_scale_mode == "snr" and self.global_noise is not None:
+ bg_event = bg_data[mask] if bg_data is not None else 0
+ scaled_data = (y_data - bg_event) / self.global_noise
+ return scaled_data, "Signal (σ)"
+ else:
+ return y_data, "Signal"
+
+ def _scale_background_data(self, bg_data: np.ndarray) -> np.ndarray:
+ """
+ Scale background data according to the current scaling mode.
+
+ Parameters
+ ----------
+ bg_data : np.ndarray
+ Background data.
+
+ Returns
+ -------
+ np.ndarray
+ Scaled background data.
+ """
+ if self.y_scale_mode == "percent":
+ # In percentage mode, background becomes 0% contrast
+ return np.zeros_like(bg_data)
+ elif self.y_scale_mode == "snr":
+ # In SNR mode, background becomes 0 sigma
+ return np.zeros_like(bg_data)
+ else:
+ return bg_data
+
+ def _scale_noise_ribbon(
+ self, bg_data: np.ndarray, noise_level: np.float32
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Scale noise ribbon bounds according to the current scaling mode.
+
+ Parameters
+ ----------
+ bg_data : np.ndarray
+ Background data.
+ noise_level : np.float32
+ Noise level (±1σ).
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ Lower and upper bounds for noise ribbon.
+ """
+ if self.y_scale_mode == "percent":
+ # In percentage mode, noise becomes ±(noise/bg * 100)%
+ bg_safe = np.where(np.abs(bg_data) < 1e-12, 1e-12, bg_data)
+ noise_percent = 100 * noise_level / np.abs(bg_safe)
+ return -noise_percent, noise_percent
+ elif self.y_scale_mode == "snr":
+ # In SNR mode, noise becomes ±1σ
+ return np.full_like(bg_data, -1.0), np.full_like(bg_data, 1.0)
+ else:
+ return bg_data - noise_level, bg_data + noise_level
+
+ def plot_single_event(self, event_index: int) -> None:
+ """
+ Plot an individual event.
+
+ Parameters
+ ----------
+ event_index : int
+ Index of the event to plot.
+ """
+ if self.events is None or len(self.events) == 0:
+ logger.warning("No events available to plot.")
+ return
+
+ if not (0 <= event_index < len(self.events)):
+ logger.warning(
+ f"Event index {event_index} out of bounds. Total events: {len(self.events)}."
+ )
+ return
+
+ t_start, t_end = self.events[event_index]
+
+ # Get the raw data for the specific trace
+ t_raw = self.osc_plot.data.t_arrays[self.trace_idx]
+ x_raw = self.osc_plot.data.x_arrays[self.trace_idx]
+
+ # Define a window around the event
+ event_duration = t_end - t_start
+ plot_start_time = t_start - event_duration * 0.5
+ plot_end_time = t_end + event_duration * 0.5
+
+ # Get data within the plot window
+ mask = (t_raw >= plot_start_time) & (t_raw <= plot_end_time)
+ t_event_raw = t_raw[mask]
+ x_event = x_raw[mask]
+
+ if not np.any(mask):
+ logger.warning(
+ f"No data found for event {event_index} in time range [{t_start:.6f}, {t_end:.6f}]"
+ )
+ return
+
+ # Extract event data and make relative to plot window start
+ t_event_raw_relative = t_event_raw - plot_start_time
+
+ # Determine offset display parameters
+ time_span_raw = plot_end_time - plot_start_time
+ event_time_unit, display_scale, offset_time_raw, offset_unit = (
+ _determine_offset_display_params(
+ (plot_start_time, plot_end_time), time_span_raw
+ )
+ )
+
+ # Scale for display using the display_scale from offset params
+ t_event_display = t_event_raw_relative * display_scale
+
+ # Create time formatter for axis
+ time_formatter = _create_time_formatter(offset_time_raw, display_scale)
+
+ # Scale the data according to the selected mode
+ x_event_scaled, ylabel = self._scale_y_data(x_event, self.bg_clean, mask)
+
+ self.fig, ax_ev = plt.subplots(figsize=(6, 3))
+ ax_ev.plot(
+ t_event_display,
+ x_event_scaled,
+ label="Event",
+ color="black",
+ # marker="o",
+ mfc="none",
+ )
+
+ if self.bg_clean is not None:
+ bg_event_scaled = self._scale_background_data(self.bg_clean[mask])
+ ax_ev.plot(
+ t_event_display,
+ bg_event_scaled,
+ label="BG",
+ color="orange",
+ ls="--",
+ )
+ if self.global_noise is not None:
+ # Plot noise ribbon around the background
+ noise_lower, noise_upper = self._scale_noise_ribbon(
+ self.bg_clean[mask], self.global_noise
+ )
+ ax_ev.fill_between(
+ t_event_display,
+ noise_lower,
+ noise_upper,
+ color="gray",
+ alpha=0.3,
+ label="Noise (±1σ)",
+ )
+ else:
+ logger.warning(
+ "Clean background (bg_clean) not provided to EventPlotter, cannot plot."
+ )
+
+ # Set xlabel with offset if applicable
+ if offset_time_raw is not None:
+ # Use the offset unit for display, not the event time unit
+ offset_scale = 1.0
+ if offset_unit == "ms":
+ offset_scale = 1e3
+ elif offset_unit == "us":
+ offset_scale = 1e6
+ elif offset_unit == "ns":
+ offset_scale = 1e9
+
+ offset_display = offset_time_raw * offset_scale
+ ax_ev.set_xlabel(
+ f"Time ({event_time_unit}) + {offset_display:.3g} {offset_unit}"
+ )
+ else:
+ ax_ev.set_xlabel(f"Time ({event_time_unit})")
+
+ # Apply the time formatter to x-axis
+ ax_ev.xaxis.set_major_formatter(time_formatter)
+ # Use a shorter title - just the base trace name without noise info
+ trace_name = self.osc_plot.data.get_trace_name(self.trace_idx)
+ # Strip noise information from the title if present
+ clean_name = trace_name.split(" | ")[0]
+ ax_ev.set_title(f"{clean_name} - Event {event_index + 1}")
+ ax_ev.set_ylabel(ylabel)
+ ax_ev.legend(loc="lower right")
+
+ def plot_events_grid(self, max_events: int = 16) -> None:
+ """
+ Plot multiple events in a subplot grid.
+
+ Parameters
+ ----------
+ max_events : int, default=16
+ Maximum number of events to plot in the grid.
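+
+ Examples
+ --------
+ Hypothetical usage; the grid size follows the event count
+ (e.g. 9 events gives a 3x3 grid)::
+
+ >>> plotter.plot_events_grid(max_events=9) # doctest: +SKIP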
+ """
+ if self.events is None or len(self.events) == 0:
+ logger.warning("No events available to plot.")
+ return
+
+ # Limit number of events (the 6x6 grid supports at most 36)
+ n_events = min(len(self.events), max_events, 36)
+ events_to_plot = self.events[:n_events]
+
+ # Determine grid size
+ if n_events <= 4:
+ rows, cols = 2, 2
+ elif n_events <= 9:
+ rows, cols = 3, 3
+ elif n_events <= 16:
+ rows, cols = 4, 4
+ elif n_events <= 25:
+ rows, cols = 5, 5
+ else:
+ rows, cols = 6, 6
+
+ self.fig, axes = plt.subplots(
+ rows, cols, figsize=(cols * 4, rows * 3), sharey=True
+ ) # Sharey for consistent amplitude scale
+ # Get trace name safely and clean it
+ trace_name = self.osc_plot.data.get_trace_name(self.trace_idx)
+ # Strip noise information from the title if present
+ clean_name = trace_name.split(" | ")[0]
+
+ self.fig.suptitle(
+ f"{clean_name} - Events 1-{n_events} (of {len(self.events)} total)",
+ fontsize=12,
+ )
+
+ # Flatten axes for uniform indexing (also covers a single-Axes return)
+ axes = np.atleast_1d(axes).ravel()
+
+ # Get the raw data for the specific trace once
+ t_raw_full = self.osc_plot.data.t_arrays[self.trace_idx]
+ x_raw_full = self.osc_plot.data.x_arrays[self.trace_idx]
+
+ for i, (t_start, t_end) in enumerate(events_to_plot):
+ ax = axes[i]
+
+ # Define a window around the event
+ event_duration = t_end - t_start
+ plot_start_time = t_start - event_duration * 0.5
+ plot_end_time = t_end + event_duration * 0.5
+
+ # Extract event data
+ mask = (t_raw_full >= plot_start_time) & (t_raw_full <= plot_end_time)
+ t_event_raw = t_raw_full[mask]
+ x_event = x_raw_full[mask]
+
+ if not np.any(mask):
+ ax.text(
+ 0.5,
+ 0.5,
+ f"Event {i + 1}\nNo data",
+ ha="center",
+ va="center",
+ transform=ax.transAxes,
+ )
+ ax.set_xticks([])
+ ax.set_yticks([])
+ continue
+
+ # Make time relative to plot window start
+ t_event_raw_relative = t_event_raw - plot_start_time
+
+ # Determine offset display parameters for this event
+ time_span_raw = plot_end_time - plot_start_time
+ event_time_unit, display_scale, offset_time_raw, offset_unit = (
+ _determine_offset_display_params(
+ (plot_start_time, plot_end_time), time_span_raw
+ )
+ )
+
+ # Scale for display using the display_scale from offset params
+ t_event_display = t_event_raw_relative * display_scale
+
+ # Create time formatter for axis
+ time_formatter = _create_time_formatter(offset_time_raw, display_scale)
+
+ # Scale the data according to the selected mode
+ x_event_scaled, ylabel = self._scale_y_data(x_event, self.bg_clean, mask)
+
+ # Plot event
+ ax.plot(
+ t_event_display,
+ x_event_scaled,
+ "-ok",
+ mfc="none",
+ linewidth=1,
+ label="Signal",
+ ms=4,
+ )
+
+ if self.bg_clean is not None:
+ bg_event_scaled = self._scale_background_data(self.bg_clean[mask])
+ ax.plot(
+ t_event_display,
+ bg_event_scaled,
+ "orange",
+ linestyle="--",
+ alpha=0.7,
+ label="BG",
+ )
+ if self.global_noise is not None:
+ # Plot noise ribbon around the background
+ noise_lower, noise_upper = self._scale_noise_ribbon(
+ self.bg_clean[mask], self.global_noise
+ )
+ ax.fill_between(
+ t_event_display,
+ noise_lower,
+ noise_upper,
+ color="gray",
+ alpha=0.3,
+ label="Noise (±1σ)",
+ )
+ else:
+ logger.warning(
+ f"Background data not available for event {i + 1}. Ensure bg_clean is passed to EventPlotter."
+ )
+
+ # Formatting
+ ax.set_title(f"Event {i + 1}", fontsize=10)
+
+ # Set xlabel with offset if applicable
+ if offset_time_raw is not None:
+ # Use the offset unit for display, not the event time unit
+ offset_scale = 1.0
+ if offset_unit == "ms":
+ offset_scale = 1e3
+ elif offset_unit == "us":
+ offset_scale = 1e6
+ elif offset_unit == "ns":
+ offset_scale = 1e9
+
+ offset_display = offset_time_raw * offset_scale
+ ax.set_xlabel(
+ f"Time ({event_time_unit}) + {offset_display:.3g} {offset_unit}",
+ fontsize=8,
+ )
+ else:
+ ax.set_xlabel(f"Time ({event_time_unit})", fontsize=8)
+
+ ax.set_ylabel(ylabel, fontsize=8)
+ ax.tick_params(labelsize=7)
+
+ # Apply the time formatter to x-axis
+ ax.xaxis.set_major_formatter(time_formatter)
+
+ # Only show legend on first subplot
+ if i == 0:
+ ax.legend(fontsize=7, loc="best")
+
+ # Hide unused subplots
+ for i in range(n_events, len(axes)):
+ axes[i].set_visible(False)
diff --git a/src/transivent/io.py b/src/transivent/io.py
new file mode 100644
index 0000000..b7ad1c5
--- /dev/null
+++ b/src/transivent/io.py
@@ -0,0 +1,456 @@
+import os
+import xml.etree.ElementTree as ET
+from typing import Any, Dict, List, Optional, Tuple, Generator
+from warnings import warn
+
+import numpy as np
+from loguru import logger
+
+
+def _get_xml_sidecar_path(
+ bin_filename: str,
+ data_path: Optional[str] = None,
+ sidecar: Optional[str] = None
+) -> str:
+ """
+ Determine the XML sidecar file path using consistent logic.
+
+ Parameters
+ ----------
+ bin_filename : str
+ Name of the binary waveform file.
+ data_path : str, optional
+ Path to the data directory.
+ sidecar : str, optional
+ Name of the XML sidecar file. If None, auto-detects from bin_filename.
+
+ Returns
+ -------
+ str
+ Full path to the XML sidecar file.
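+
+ Examples
+ --------
+ A sketch of the auto-detection (paths are illustrative): binary samples
+ in ``<name>.Wfm.bin`` pair with XML metadata in ``<name>.bin``::
+
+ >>> _get_xml_sidecar_path("scan.Wfm.bin", data_path="/data")
+ '/data/scan.bin'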
+ """
+ if sidecar is not None:
+ sidecar_path = (
+ os.path.join(data_path, sidecar)
+ if data_path is not None and not os.path.isabs(sidecar)
+ else sidecar
+ )
+ else:
+ # Guess per the sidecar convention: samples in '<name>.Wfm.bin',
+ # XML metadata in '<name>.bin'
+ base = os.path.splitext(bin_filename)[0]
+ if base.endswith(".Wfm"):
+ sidecar_guess = base[:-4] + ".bin"
+ else:
+ sidecar_guess = base + ".bin"
+ sidecar_path = (
+ os.path.join(data_path, sidecar_guess)
+ if data_path is not None
+ else sidecar_guess
+ )
+ return sidecar_path
+
+
+def get_waveform_params(
+ bin_filename: str,
+ data_path: Optional[str] = None,
+ sidecar: Optional[str] = None,
+) -> Dict[str, Any]:
+ """
+ Parse XML sidecar file to extract waveform parameters.
+
+ Given a binary waveform filename, find and parse the corresponding XML sidecar file.
+ If sidecar is provided, use it directly. Otherwise, guess from bin_filename.
+
+ Parameters
+ ----------
+ bin_filename : str
+ Name of the binary waveform file.
+ data_path : str, optional
+ Path to the data directory. If None, uses current directory.
+ sidecar : str, optional
+ Name of the XML sidecar file. If None, guesses from bin_filename.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary with keys: sampling_interval, vertical_scale, vertical_offset,
+ byte_order, signal_format.
+
+ Raises
+ ------
+ FileNotFoundError
+ If the XML sidecar file is not found.
+ RuntimeError
+ If the XML file cannot be parsed.
+
+ Warns
+ -----
+ UserWarning
+ If sampling resolution is not found in XML.
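+
+ Examples
+ --------
+ A sketch of the sidecar structure this parser expects (illustrative
+ values; only ``Prop`` elements with ``Name``/``Value`` attributes are
+ read)::
+
+ <Prop Name="Resolution" Value="1e-09"/>
+ <Prop Name="SignalFormat" Value="FLOAT"/>
+ <Prop Name="ByteOrder" Value="LSB First"/>
+ <Prop Name="SignalHardwareRecordLength" Value="1000000"/>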
+ """
+ sidecar_path = _get_xml_sidecar_path(bin_filename, data_path, sidecar)
+ params = {
+ "sampling_interval": None,
+ "vertical_scale": None,
+ "vertical_offset": None,
+ "byte_order": "LSB", # default
+ "signal_format": "float32", # default
+ "signal_hardware_record_length": None,
+ }
+ found_resolution = False
+ if not os.path.exists(sidecar_path):
+ msg = (
+ f"XML sidecar file not found: {sidecar_path}\n"
+ f" bin_filename: {bin_filename}\n"
+ f" sidecar: {sidecar}\n"
+ f" data_path: {data_path}\n"
+ f" Tried path: {sidecar_path}\n"
+ f"Please check that the XML sidecar exists and the path is correct."
+ )
+ raise FileNotFoundError(msg)
+ try:
+ tree = ET.parse(sidecar_path)
+ root = tree.getroot()
+
+ # Validate XML structure
+ if root is None:
+ raise RuntimeError(f"XML file has no root element: {sidecar_path}")
+
+ # Track which parameters we found for validation
+ found_params = set()
+ signal_resolution = None
+ resolution = None
+
+ for prop in root.iter("Prop"):
+ if prop.attrib is None:
+ logger.warning(f"Found Prop element with no attributes in {sidecar_path}")
+ continue
+
+ name = prop.attrib.get("Name", "")
+ value = prop.attrib.get("Value", "")
+
+ if not name:
+ logger.warning(
+ f"Found Prop element with empty Name attribute in {sidecar_path}"
+ )
+ continue
+
+ try:
+ if name == "Resolution":
+ params["sampling_interval"] = float(value)
+ found_resolution = True
+ found_params.add("Resolution")
+ resolution = float(value)
+ elif name == "SignalResolution":
+ # Store the value even when 'Resolution' is also present,
+ # so the two can be cross-checked below
+ signal_resolution = float(value)
+ found_params.add("SignalResolution")
+ if params["sampling_interval"] is None:
+ params["sampling_interval"] = float(value)
+ found_resolution = True
+ elif name == "ByteOrder":
+ if not value:
+ logger.warning(
+ f"Empty ByteOrder value in {sidecar_path}, using default LSB"
+ )
+ continue
+ params["byte_order"] = "LSB" if "LSB" in value else "MSB"
+ found_params.add("ByteOrder")
+ elif name == "SignalFormat":
+ if not value:
+ logger.warning(
+ f"Empty SignalFormat value in {sidecar_path}, using default float32"
+ )
+ continue
+ if "FLOAT" in value:
+ params["signal_format"] = "float32"
+ elif "INT16" in value:
+ params["signal_format"] = "int16"
+ elif "INT32" in value:
+ params["signal_format"] = "int32"
+ else:
+ logger.warning(
+ f"Unknown SignalFormat '{value}' in {sidecar_path}, using default float32"
+ )
+ found_params.add("SignalFormat")
+ elif name == "SignalHardwareRecordLength":
+ params["signal_hardware_record_length"] = int(value)
+ found_params.add("SignalHardwareRecordLength")
+ except ValueError as e:
+ logger.warning(
+ f"Failed to parse {name} value '{value}' in {sidecar_path}: {e}"
+ )
+ continue
+
+ # Validate critical parameters
+ if not found_resolution:
+ warn(
+ "Neither 'Resolution' nor 'SignalResolution' found in XML; "
+ "sampling_interval remains None. "
+ "Please provide a value or check your XML."
+ )
+ if (
+ "Resolution" in found_params
+ and "SignalResolution" in found_params
+ and not np.isclose(signal_resolution, resolution, rtol=1e-2, atol=1e-9)
+ ):
+ logger.warning(
+ f"FYI: 'Resolution' ({resolution}) != 'SignalResolution' ({signal_resolution}) in {sidecar_path}. "
+ f"Using 'Resolution' ({resolution}). Diff: {abs(signal_resolution - resolution)}"
+ )
+
+ # Log what we found for debugging
+ logger.debug(f"XML parsing found parameters: {found_params}")
+
+ # Validate sampling interval if found
+ if params["sampling_interval"] is not None and params["sampling_interval"] <= 0:
+ logger.warning(
+ f"Invalid sampling interval {params['sampling_interval']} in {sidecar_path}. "
+ "This may lead to issues with time array generation."
+ )
+
+ except ET.ParseError as e:
+ raise RuntimeError(f"XML parsing error in {sidecar_path}: {e}")
+ except Exception as e:
+ raise RuntimeError(f"Failed to parse XML sidecar: {sidecar_path}: {e}")
+ return params
+
+
+def rd(
+ filename: str,
+ sampling_interval: Optional[float] = None,
+ data_path: Optional[str] = None,
+ sidecar: Optional[str] = None,
+ crop: Optional[List[int]] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Read waveform binary file using sidecar XML for parameters.
+
+ Parameters
+ ----------
+ filename : str
+ Name of the binary waveform file.
+ sampling_interval : float, optional
+ Sampling interval in seconds. If None, reads from XML sidecar.
+ data_path : str, optional
+ Path to the data directory.
+ sidecar : str, optional
+ Name of the XML sidecar file.
+ crop : List[int], optional
+ Crop indices [start, end]. If None, uses entire signal.
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ Time array (float32) and scaled signal array (float32).
+
+ Raises
+ ------
+ RuntimeError
+ If sampling interval cannot be determined.
+ FileNotFoundError
+ If the binary file is not found.
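+
+ Examples
+ --------
+ Minimal usage sketch (file names are illustrative)::
+
+ >>> t, x = rd("scan.Wfm.bin", data_path="/data") # doctest: +SKIP
+ >>> dt = t[1] - t[0] # equals the sampling interval from the sidecar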
+ """
+ # Always join data_path and filename if filename is not absolute
+ if data_path is not None and not os.path.isabs(filename):
+ fp = os.path.join(data_path, filename)
+ else:
+ fp = filename
+ params = get_waveform_params(
+ os.path.basename(fp), data_path, sidecar=sidecar
+ )
+ # Use sampling_interval from XML if available, else argument, else raise error
+ si = params["sampling_interval"]
+ if si is None:
+ if sampling_interval is not None:
+ si = sampling_interval
+ else:
+ raise RuntimeError(
+ f"Sampling interval could not be determined for file: {fp}. "
+ + "Please provide it or ensure the XML sidecar is present."
+ )
+ # log info about what we're reading and the parameters
+ rel_fp = os.path.relpath(fp, os.getcwd()) if os.path.isabs(fp) else fp
+ logger.info(f"Reading binary file: {rel_fp}")
+ if sidecar:
+ sidecar_path = _get_xml_sidecar_path(os.path.basename(fp), data_path, sidecar)
+ rel_sidecar = (
+ os.path.relpath(sidecar_path, os.getcwd())
+ if os.path.isabs(sidecar_path)
+ else sidecar_path
+ )
+ logger.info(f"--Using sidecar XML: {rel_sidecar}")
+ logger.info(f"--Sampling interval: {si}")
+ logger.info(f"--Byte order: {params['byte_order']}")
+ logger.info(f"--Signal format: {params['signal_format']}")
+ # Determine dtype
+ dtype = np.float32
+ if params["signal_format"] == "int16":
+ dtype = np.int16
+ elif params["signal_format"] == "int32":
+ dtype = np.int32
+ # Determine byte order
+ byteorder = "<" if params["byte_order"] == "LSB" else ">"
+ try:
+ with open(fp, "rb") as f:
+ import struct
+ # Read the first eight bytes as two little-endian uint32 values:
+ # element size in bytes, then record length in elements
+ header_bytes = f.read(8)
+ elsize, record_length_from_header = struct.unpack('<II', header_bytes)
+ logger.success(f"Bin header: data el. size: {elsize} (bytes)")
+ logger.success(f"Bin header: length: {record_length_from_header} ({elsize}-byte nums)")
+ params["record_length_from_header"] = record_length_from_header
+ if (
+ params["signal_hardware_record_length"] is not None
+ and params["signal_hardware_record_length"] != record_length_from_header
+ ):
+ logger.warning(
+ f"SignalHardwareRecordLength ({params['signal_hardware_record_length']}) "
+ f"does not match header record length ({record_length_from_header}) in {rel_fp}. "
+ "This may indicate a mismatch in expected data length."
+ )
+
+ # Skip the 8-byte header (two uint32 values) when reading the samples
+ arr = np.fromfile(fp, dtype=byteorder + dtype().dtype.char, offset=8)
+
+ # Validate expected data length if available
+ expected_length = params["signal_hardware_record_length"]
+ if expected_length is not None:
+ if len(arr) != expected_length:
+ logger.warning(
+ f"Data length mismatch in {rel_fp}: "
+ f"expected {expected_length} points from SignalHardwareRecordLength, "
+ f"but read {len(arr)} points from binary file"
+ )
+
+ if crop is not None:
+ x = arr[crop[0] : crop[1]]
+ else:
+ x = arr
+ except FileNotFoundError:
+ raise FileNotFoundError(
+ f"The file '{fp}' was not found. "
+ + "Please ensure the file is in the correct directory."
+ )
+ # Data is already in physical units (V); ensure float32 as documented
+ x = x.astype(np.float32, copy=False)
+
+ # Use np.linspace for more robust time array generation
+ num_points = len(x)
+ if num_points > 0:
+ t = np.linspace(0, (num_points - 1) * si, num_points, dtype=np.float32)
+ else:
+ t = np.array([], dtype=np.float32)
+ logger.warning(
+ f"Generated an empty time array for file {rel_fp}. "
+ f"Length of signal: {len(x)}, sampling interval: {si}. "
+ "This might indicate an issue with input data or sampling interval."
+ )
+
+ return t, x
+
+
+def rd_chunked(
+ filename: str,
+ chunk_size: int,
+ sampling_interval: Optional[float] = None,
+ data_path: Optional[str] = None,
+ sidecar: Optional[str] = None,
+) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
+ """
+ Read waveform binary file in chunks using sidecar XML for parameters.
+
+ This is a generator function that yields chunks of data.
+
+ Parameters
+ ----------
+ filename : str
+ Name of the binary waveform file.
+ chunk_size : int
+ Number of points per chunk.
+ sampling_interval : float, optional
+ Sampling interval in seconds. If None, reads from XML sidecar.
+ data_path : str, optional
+ Path to the data directory.
+ sidecar : str, optional
+ Name of the XML sidecar file.
+
+ Yields
+ ------
+ Tuple[np.ndarray, np.ndarray]
+ Time array (float32) and signal array (float32) for each chunk.
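+
+ Examples
+ --------
+ Streaming sketch (``process`` is a hypothetical consumer; file names
+ are illustrative)::
+
+ >>> for t_chunk, x_chunk in rd_chunked("scan.Wfm.bin", 1_000_000): # doctest: +SKIP
+ ... process(t_chunk, x_chunk)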
+ """
+ if data_path is not None and not os.path.isabs(filename):
+ fp = os.path.join(data_path, filename)
+ else:
+ fp = filename
+
+ params = get_waveform_params(os.path.basename(fp), data_path, sidecar=sidecar)
+ si = params["sampling_interval"]
+ if si is None:
+ if sampling_interval is not None:
+ si = sampling_interval
+ else:
+ raise RuntimeError(f"Sampling interval could not be determined for file: {fp}.")
+
+ dtype = np.float32
+ if params["signal_format"] == "int16":
+ dtype = np.int16
+ elif params["signal_format"] == "int32":
+ dtype = np.int32
+
+ byteorder = "<" if params["byte_order"] == "LSB" else ">"
+ full_dtype_str = byteorder + dtype().dtype.char
+
+ header_size_bytes = 8
+
+ try:
+ with open(fp, "rb") as f:
+ # Read header
+ header_bytes = f.read(header_size_bytes)
+ if len(header_bytes) < header_size_bytes:
+ logger.warning("Could not read full header from binary file.")
+ return
+
+ import struct
+
+ elsize, record_length_from_header = struct.unpack("<II", header_bytes)
+ logger.success(f"Bin header: data el. size: {elsize} (bytes)")
+ logger.success(
+ f"Bin header: length: {record_length_from_header} ({elsize}-byte nums)"
+ )
+
+ total_points = params.get("signal_hardware_record_length")
+ if total_points is None:
+ total_points = record_length_from_header
+ logger.warning(
+ f"SignalHardwareRecordLength not found. Using length from header: {total_points} points."
+ )
+ elif total_points != record_length_from_header:
+ logger.warning(
+ f"SignalHardwareRecordLength ({total_points}) "
+ f"does not match header record length ({record_length_from_header}) in {fp}. "
+ "Using header length."
+ )
+ total_points = record_length_from_header
+
+ current_pos = 0
+ while current_pos < total_points:
+ points_to_read = min(chunk_size, total_points - current_pos)
+
+ x_chunk = np.fromfile(f, dtype=full_dtype_str, count=points_to_read)
+
+ if len(x_chunk) == 0:
+ break
+
+ start_time = current_pos * si
+ num_points = len(x_chunk)
+ t_chunk = np.linspace(
+ start_time,
+ start_time + (num_points - 1) * si,
+ num_points,
+ dtype=np.float32,
+ )
+
+ yield t_chunk, x_chunk.astype(np.float32)
+
+ current_pos += points_to_read
+
+ except FileNotFoundError:
+ raise FileNotFoundError(f"The file '{fp}' was not found.")