From 637ddc52f4dc23ba3aa7cccef014aa85cab36b49 Mon Sep 17 00:00:00 2001 From: Sam Scholten Date: Mon, 30 Mar 2026 11:42:22 +1000 Subject: Release v1.0.0 --- .gitignore | 73 +- Makefile | 18 - PicoStream.spec | 102 -- README.md | 240 +--- assets/icons/app.ico | Bin 0 -> 36439 bytes assets/images/screenshot.png | Bin 0 -> 106092 bytes justfile | 92 +- new_plan.md | 65 - picostream/__init__.py | 8 +- picostream/acquisition_rate.py | 287 ++++ picostream/cli.py | 699 --------- picostream/consumer.py | 160 --- picostream/data_pipeline.py | 622 ++++++++ picostream/device.py | 2794 ++++++++++++++++++++++++++++++++++++ picostream/dfplot.py | 1438 ++++++++++--------- picostream/main.py | 2671 ++++++++++++++++++++++++++++------ picostream/mock_device.py | 1477 +++++++++++++++++++ picostream/pico.py | 535 ------- picostream/reader.py | 259 ---- picostream/ring_buffer.py | 331 +++++ picostream/test_buffered_stream.py | 344 +++++ picostream/test_data_pipeline.py | 496 +++++++ picostream/test_live_plotter.py | 437 ++++++ picostream/test_max_adc_fix.py | 73 + picostream/test_rate_contract.py | 263 ++++ picostream/test_rate_invariants.py | 329 +++++ picostream/test_ring_buffer.py | 508 +++++++ picostream/test_zarr_reader.py | 331 +++++ picostream/test_zarr_viewer.py | 223 +++ picostream/test_zarr_writer.py | 310 ++++ picostream/zarr_reader.py | 405 ++++++ picostream/zarr_viewer.py | 553 +++++++ picostream/zarr_writer.py | 250 ++++ pyproject.toml | 96 +- 34 files changed, 13389 insertions(+), 3100 deletions(-) delete mode 100644 Makefile delete mode 100644 PicoStream.spec create mode 100644 assets/icons/app.ico create mode 100644 assets/images/screenshot.png delete mode 100644 new_plan.md create mode 100644 picostream/acquisition_rate.py delete mode 100644 picostream/cli.py delete mode 100644 picostream/consumer.py create mode 100644 picostream/data_pipeline.py create mode 100644 picostream/device.py create mode 100644 picostream/mock_device.py delete mode 100644 picostream/pico.py delete mode 100644 picostream/reader.py create mode 100644 picostream/ring_buffer.py create mode 100644 picostream/test_buffered_stream.py create mode 100644 picostream/test_data_pipeline.py create mode 100644 picostream/test_live_plotter.py create mode 100644 picostream/test_max_adc_fix.py create mode 100644 picostream/test_rate_contract.py create mode 100644 picostream/test_rate_invariants.py create mode 100644 picostream/test_ring_buffer.py create mode 100644 picostream/test_zarr_reader.py create mode 100644 picostream/test_zarr_viewer.py create mode 100644 picostream/test_zarr_writer.py create mode 100644 picostream/zarr_reader.py create mode 100644 picostream/zarr_viewer.py create mode 100644 picostream/zarr_writer.py diff --git a/.gitignore b/.gitignore index d66a8e1..5c2e902 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,65 @@ -conda_env* -output.hdf5 -*__pycache__* -complexipy.json -build/* -*egg-info* -uv.lock -dist/ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +env/ +ENV/ +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# Type checking / linting +.ruff_cache/ +complexipy.json + +# OS +.DS_Store +Thumbs.db + +# Data files +*.hdf5 +*.h5 +*.zarr +output.* + +# Logs +*.log + +# uv lock file (we use pyproject.toml for deps) +uv.lock + +# Old backups (keep for reference but not committed) +v0_backup/ + +# Conda envs (legacy) +conda_env* diff --git a/Makefile b/Makefile deleted file mode 100644 index c9a7a95..0000000 --- a/Makefile +++ /dev/null @@ -1,18 +0,0 @@ - - - -format: - ruff format - ruff check --fix --select F,B,I . - ruff format - -lint: - prospector --with-tool mypy - -complexity: - complexipy --sort asc . - -complexity-json: - complexipy --sort desc --output-json . - - diff --git a/PicoStream.spec b/PicoStream.spec deleted file mode 100644 index fdadedf..0000000 --- a/PicoStream.spec +++ /dev/null @@ -1,102 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- - -import glob -import os -import sys - - -def find_libffi_dll(): - """ - Find the libffi-*.dll file required for _ctypes on Windows. - Searches in common locations for standard, venv, and Conda Python. - """ - if sys.platform != "win32": - return [] - - print("--- PyInstaller Build Environment ---") - print(f" - Python Executable: {sys.executable}") - print(f" - sys.prefix: {sys.prefix}") - if hasattr(sys, "base_prefix"): - print(f" - sys.base_prefix: {sys.base_prefix}") - else: - print(" - sys.base_prefix: Not available") - conda_prefix = os.environ.get("CONDA_PREFIX") - print(f" - CONDA_PREFIX env var: {conda_prefix}") - - search_paths = [] - # Active environment's DLLs directory - search_paths.append(os.path.join(sys.prefix, "DLLs")) - - # Base Python installation's directories (if in a venv) - if hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix: - search_paths.append(os.path.join(sys.base_prefix, "DLLs")) - search_paths.append(os.path.join(sys.base_prefix, "Library", "bin")) - - # Conda environment's directory (if CONDA_PREFIX is set) - if conda_prefix: - search_paths.append(os.path.join(conda_prefix, "Library", "bin")) - - print("\n--- Potential Search Paths for libffi ---") - for p in search_paths: - print(f" - Path: {p}, Exists? {os.path.isdir(p)}") - - print("\n--- Searching for libffi DLL on Windows ---") - unique_paths = sorted(list(set(p for p in search_paths if os.path.isdir(p)))) - - for path in unique_paths: - print(f" - Checking: {path}") - dll_pattern = os.path.join(path, "libffi-*.dll") - found_dlls = glob.glob(dll_pattern) - if found_dlls: - dll_path = found_dlls[0] - print(f" - Found: {dll_path}") - return [(dll_path, ".")] # (source, destination_in_bundle) - - print("\nERROR: Could not find libffi-*.dll on Windows.") - sys.exit(1) - - -# --- PyInstaller Spec --- -app_name = 'PicoStream' -binaries = find_libffi_dll() if sys.platform == "win32" else [] - -a = Analysis( - ['picostream/main.py'], - pathex=[], - binaries=binaries, - datas=[], - hiddenimports=[], - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=None, - noarchive=False, -) -pyz = PYZ(a.pure, a.zipped_data, cipher=None) - -exe = EXE( - pyz, - a.scripts, - a.binaries, - a.zipfiles, - a.datas, - [], - name=app_name, - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - upx_exclude=[], - runtime_tmpdir=None, - console=True, # useful for debugging for now - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, - icon=None, -) -# NOTE: The absence of a COLLECT block is what defines a one-file build. diff --git a/README.md b/README.md index d332ae0..e130832 100644 --- a/README.md +++ b/README.md @@ -1,216 +1,118 @@ # PicoStream -High-performance data acquisition from PicoScope 5000a series to HDF5, with decoupled live visualisation. Designed for robust, high-speed logging where data integrity is critical. +A fast, simple GUI for streaming from PicoScope 5000a series scopes. Built on [labdaemon](https://github.com/qnslab/labdaemon). -## Quick Start +## What's new in v1.0 -Install with `uv`: +Complete rewrite with: -```bash -uv pip install -e . -``` +- **Dual channels** — acquire A and B simultaneously (16-bit = single channel only) +- **Ring buffer** — configurable lookback (5–120s) with on-demand recording +- **Pre-trigger capture** — include N seconds of data before you hit record +- **Keep or discard** — stop recording and either save or trash the file +- **VisPy plotting** — hardware-accelerated OpenGL, much faster than before +- **Zarr format** — chunked storage, faster than HDF5 for large files +- **PyQt6** — bumped from PyQt5 -Install the [PicoSDK](https://www.picotech.com/downloads) (Linux users: see [Arch AUR wiki](https://aur.archlinux.org/packages/picoscope) for configuration). +![PicoStream GUI](assets/images/screenshot.png) -Launch the GUI: +## Install -```bash -just gui -# or: uv run python -m picostream.main -``` +Requires Python ≥3.10, PicoSDK installed. -Or use the CLI: +Install PicoSDK from [picotech.com/downloads](https://www.picotech.com/downloads). -```bash -uv run picostream -s 62.5 -o data.hdf5 --plot +Linux users: create `/etc/udev/rules.d/99-picoscope.rules`: ``` - -## Key Features - -- **Producer-consumer architecture** with large shared memory buffer pool prevents data loss -- **Decoupled live plotting** reads from HDF5 file, cannot interfere with acquisition -- **Live-only mode** for continuous monitoring without filling disk -- **Numba-accelerated decimation** for efficient visualisation of large datasets -- **GUI and CLI interfaces** for interactive and scripted workflows - -## Usage - -### GUI Application - -```bash -just gui +SUBSYSTEM=="usb", ATTR{idVendor}=="0ce9", MODE="0666" ``` +Then run `sudo udevadm control --reload-rules && sudo udevadm trigger`. -Configure acquisition parameters, start/stop capture, and view live data. Settings persist between sessions. - -### Building Standalone Executable +From PyPI: ```bash -just build-gui +pip install picostream ``` -Executable appears in `dist/PicoStream` (or `dist/PicoStream.exe` on Windows). - -### Command-Line Interface - -Standard acquisition (saves all data): +For development: ```bash -uv run picostream -s 62.5 -o my_data.hdf5 --plot +just sync ``` -Live-only mode (limits file size): +## Run + +After pip install: ```bash -uv run picostream -s 62.5 --plot --max-buff-sec 60 +picostream ``` -View existing file: +During development: ```bash -uv run python -m picostream.dfplot /path/to/data.hdf5 +just gui ``` -Run `uv run picostream --help` for all options. - -## Documentation - -### Architecture - -PicoStream uses a producer-consumer pattern to ensure data integrity: - -- **Producer (`PicoDevice`)**: Interfaces with PicoScope hardware, streams ADC data into shared memory buffers in a dedicated thread -- **Consumer (`Consumer`)**: Retrieves filled buffers from queue, writes to HDF5, returns empty buffers to pool in separate thread -- **Buffer Pool**: Large shared memory buffers (100+ MB) prevent data loss if disk I/O slows -- **Live Plotter (`HDF5LivePlotter`)**: Reads from HDF5 file on disk, completely decoupled from acquisition +## Quickstart -This ensures the critical acquisition path is never blocked by disk I/O or GUI rendering. +1. Enter scope serial (or `MOCK` to simulate), click **Connect** +2. Set channels, voltage ranges, sample rate, resolution +3. **Start Acquisition** (Space bar) +4. **Start Recording** (R) when you want to save +5. **Stop & Keep** (K) or **Stop & Discard** (Del) -### Data Analysis with `PicoStreamReader` +Recordings save to `~/picostream/` by default. -The `PicoStreamReader` class provides efficient access to HDF5 files with on-the-fly decimation. +## Crash Cleanup -```python -import numpy as np -from picostream.reader import PicoStreamReader +If the app crashes during recording, incomplete Zarr files remain in `~/picostream/`. +These are useless. Delete them: -with PicoStreamReader('my_data.hdf5') as reader: - sample_rate_sps = 1e9 / reader.sample_interval_ns - print(f"File contains {reader.num_samples:,} samples at {sample_rate_sps / 1e6:.2f} MS/s") - - # Iterate through file with 10x decimation - for times, voltages_mv in reader.get_block_iter( - chunk_size=10_000_000, decimation_factor=10, decimation_mode='min_max' - ): - print(f"Processed {voltages_mv.size} decimated points") +```bash +rm -rf ~/picostream/*.zarr ``` -### API Reference: `PicoStreamReader` - -#### Initialisation - -**`__init__(self, hdf5_path: str)`** - -Initialises the reader. File is opened when used as context manager. - -#### Metadata Attributes - -Available after opening: +Change save location in the GUI (Record Controls → Change...). -- `num_samples: int` - Total raw samples in dataset -- `sample_interval_ns: float` - Time interval between samples (nanoseconds) -- `voltage_range_v: float` - Configured voltage range (e.g., `20.0` for ±20V) -- `max_adc_val: int` - Maximum ADC count value (e.g., 32767) -- `analog_offset_v: float` - Configured analogue offset (Volts) -- `downsample_mode: str` - Hardware downsampling mode (`'average'` or `'aggregate'`) -- `hardware_downsample_ratio: int` - Hardware downsampling ratio +## Shortcuts -#### Methods +| Key | Action | +|-----|--------| +| Space | Start/stop acquisition | +| 1, 2 | Toggle channels | +| R | Start recording | +| K | Stop + keep | +| Del | Stop + discard | +| Ctrl+O | Open existing Zarr file | -**`get_block_iter(self, chunk_size: int = 1_000_000, decimation_factor: int = 1, decimation_mode: str = "mean") -> Generator`** +## Development -Generator yielding `(times, voltages)` tuples for entire dataset. Recommended for large files. - -- `chunk_size`: Number of raw samples per chunk -- `decimation_factor`: Decimation factor -- `decimation_mode`: `'mean'` or `'min_max'` - -**`get_next_block(self, chunk_size: int, decimation_factor: int = 1, decimation_mode: str = "mean") -> Tuple | None`** - -Retrieves next sequential block. Returns `None` at end of file. Use `reset()` to restart. - -**`get_block(self, size: int, start: int = 0, decimation_factor: int = 1, decimation_mode: str = "mean") -> Tuple`** - -Retrieves specific block from file. - -- `size`: Number of raw samples -- `start`: Starting sample index - -**`reset(self) -> None`** - -Resets internal counter for `get_next_block()`. - -### API Reference: `HDF5LivePlotter` - -PyQt5 widget for real-time visualisation of HDF5 files. Can be used standalone or embedded in custom applications. - -#### Initialisation - -**`__init__(self, hdf5_path: str = "/tmp/data.hdf5", update_interval_ms: int = 50, display_window_seconds: float = 0.5, decimation_factor: int = 150, max_display_points: int = 4000)`** - -- `hdf5_path`: Path to HDF5 file -- `update_interval_ms`: Update frequency (milliseconds) -- `display_window_seconds`: Duration of data to display (seconds) -- `decimation_factor`: Decimation factor for display -- `max_display_points`: Maximum points to display (prevents GUI slowdown) - -#### Methods - -**`set_hdf5_path(self, hdf5_path: str) -> None`** - -Updates monitored HDF5 file path. - -**`start_updates(self) -> None`** - -Begins periodic plot updates. - -**`stop_updates(self) -> None`** - -Stops periodic updates. - -**`save_screenshot(self) -> None`** - -Saves PNG screenshot. Called automatically when 'S' key is pressed. +```bash +just gui # run from source +just test # run tests +just test-cov # run tests with coverage +just format # format all code +just lint # lint code +just typecheck # type check +just build # build PyInstaller executable +just clean # remove build artifacts +just sync # install dependencies +just update # update dependencies +just help # see all commands +``` -#### Usage Example +## Acknowledgements -```python -from PyQt5.QtWidgets import QApplication -from picostream.dfplot import HDF5LivePlotter - -app = QApplication([]) -plotter = HDF5LivePlotter( - hdf5_path="my_data.hdf5", - display_window_seconds=1.0, - decimation_factor=100 -) -plotter.show() -plotter.start_updates() -app.exec_() -``` +This began as a fork of [JoshHarris2108/pico_streaming](https://github.com/JoshHarris2108/pico_streaming) (unlicensed). The original producer-consumer architecture and PicoSDK integration came from Josh's work. ## Changelog -### Version 0.2.0 -- Added live-only mode with `--max-buff-sec` option -- Added GUI application for interactive control -- Improved plotter to handle buffer resets gracefully -- Added total sample count tracking across buffer resets -- Skip verification step in live-only mode for better performance - -### Version 0.1.0 -- Initial release with core streaming and plotting functionality +### v1.0.0 +Complete rewrite — new architecture, new file format, dual channels, GUI-only. -## Acknowledgements +### v0.2.0 +CLI + HDF5 + single channel + PyQt5/PyQtGraph. -This package began as a fork of [JoshHarris2108/pico_streaming](https://github.com/JoshHarris2108/pico_streaming) (unlicensed). Thanks to Josh for the original architecture. +### v0.1.0 +Initial release. diff --git a/assets/icons/app.ico b/assets/icons/app.ico new file mode 100644 index 0000000..e0f4c0f Binary files /dev/null and b/assets/icons/app.ico differ diff --git a/assets/images/screenshot.png b/assets/images/screenshot.png new file mode 100644 index 0000000..3fb0d26 Binary files /dev/null and b/assets/images/screenshot.png differ diff --git a/justfile b/justfile index 237121f..c1e1b1f 100644 --- a/justfile +++ b/justfile @@ -1,12 +1,96 @@ -set windows-shell := ["C:\\Program Files\\Git\\bin\\sh.exe","-c"] +# PicoStream - Justfile +set windows-shell := ["C:\\Program Files\\Git\\bin\\sh.exe", "-c"] + +# Default: show available commands +default: + @just --list # Run the picostream GUI application gui: PYTHONPATH=. uv run picostream/main.py -# Build the picostream GUI executable -build-gui: +# Run with MOCK device (no hardware needed) +mock: + PYTHONPATH=. uv run python -c "import os; os.environ['MOCK_PICO'] = '1'; from picostream.main import main; main()" + +# Format all code +fmt: + uv run ruff format picostream/ + +# Check code formatting without fixing +fmt-check: + uv run ruff format --check picostream/ + +# Lint code +lint: + uv run ruff check picostream/ + +# Fix linting issues +lint-fix: + uv run ruff check --fix picostream/. + +# Type check code +type-check: + uv run pyright picostream/ + +# Run all checks (format check, lint, type-check) +check: fmt-check lint type-check + @echo "✓ All checks passed" + +# Run all checks and fix what can be fixed +fix: lint-fix fmt + @echo "✓ Code fixed" + +# Run all dev tools (tests run even if type-check fails) +devtools: fix fmt-check lint test type-check + +# Run all tests +test: + PYTHONPATH=. uv run pytest picostream -v + +# Run tests with coverage +test-cov: + PYTHONPATH=. uv run pytest picostream --cov=picostream --cov-report=term-missing -v + +# Run tests and generate HTML coverage report +test-cov-html: + PYTHONPATH=. uv run pytest picostream --cov=picostream --cov-report=html -v + @echo "Coverage report: file://$(pwd)/htmlcov/index.html" + +# Profile the PicoStream GUI (outputs to profile.stats) +profile: + PYTHONPATH=. uv run python -m cProfile -o profile.stats picostream/main.py + +# View profile results in browser (snakeviz) +profile-view: + uv run snakeviz profile.stats + +# Print top 20 functions by cumulative time (quick CLI view) +profile-text: + uv run python -c "import pstats; p = pstats.Stats('profile.stats'); p.sort_stats('cumulative').print_stats(20)" + +# Build the picostream GUI executable with PyInstaller +build: uv pip install pyinstaller pyinstaller-hooks-contrib - uv run pyinstaller --clean PicoStream.spec --noconfirm + uv run pyinstaller \ + --name PicoStream \ + --onefile \ + --windowed \ + --collect-all vispy \ + --icon assets/icons/app.ico \ + picostream/main.py + +# Sync dependencies (install/update) +sync: + uv sync + +# Update all dependencies +update: + uv lock --upgrade +# Clean build artifacts and caches +clean: + rm -rf build/ dist/ picostream/*.spec __pycache__ .pytest_cache .ruff_cache + find picostream -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find picostream -type f -name "*.pyc" -delete 2>/dev/null || true diff --git a/new_plan.md b/new_plan.md deleted file mode 100644 index 6c7186e..0000000 --- a/new_plan.md +++ /dev/null @@ -1,65 +0,0 @@ - Design - -The application will be a single-window GUI built with PyQt5. It will be composed of three main components: a main -window, a refactored plotter widget, and a background worker for data acquisition. - - 1 PicoStreamMainWindow (QMainWindow): This will be the application's central component, serving as the main entry point. - • Layout: It will feature a two-panel layout. The left panel will contain all user-configurable settings for the - acquisition (e.g., sample rate, voltage range, output file). The right panel will contain the embedded live plot. - • Control: It will have "Start" and "Stop" buttons to manage the acquisition lifecycle. It will manage the - application's state (e.g., idle, acquiring, error). - • Persistence: It will use QSettings to automatically save user-entered settings on exit and load them on startup. - • Lifecycle: It will be responsible for creating and managing the background worker thread and ensuring a graceful - shutdown. - 2 HDF5LivePlotter (QWidget): The existing plotter will be refactored from a QMainWindow into a QWidget. - • Responsibility: Its sole responsibility will be to monitor the HDF5 file and display the live data. It will no - longer be a top-level window or control the application's lifecycle. - • Integration: An instance of this widget will be created and embedded directly into the right-hand panel of the - PicoStreamMainWindow. - 3 StreamerWorker (QObject): This class will manage the acquisition task in a background thread to keep the GUI - responsive. - • Execution: It will be moved to a QThread. Its primary method will instantiate the Streamer class with parameters - from the GUI and call the blocking Streamer.run() method. - • Communication: It will use Qt signals to report its status (e.g., finished, error) back to the PicoStreamMainWindow - in a thread-safe manner. The main window will connect to these signals to update the UI, for example, by - re-enabling the "Start" button upon completion. - - Phased Implementation Plan - -This plan breaks the work into five distinct, sequential phases. - -Phase 1: Project Restructuring and GUI Shell The goal is to set up the new file structure and a basic, non-functional GUI -window. - - 1 Rename picostream/main.py to picostream/cli.py. - 2 Create a new, empty picostream/main.py to serve as the GUI entry point. - 3 In the new main.py, create a PicoStreamMainWindow class with a simple layout containing placeholders for the settings - panel and the plot. - 4 Update the justfile with a new target to run the GUI application. - -Phase 2: Background Worker Implementation The goal is to run the data acquisition in a background thread, controlled by -the GUI. - - 1 In picostream/main.py, create the StreamerWorker class inheriting from QObject. - 2 Implement the QThread worker pattern in PicoStreamMainWindow to start the acquisition when a "Start" button is clicked - and to signal a stop using the existing shutdown_event. - 3 Connect the worker's finished and error signals to GUI methods that update the UI state (e.g., re-enable buttons). - -Phase 3: GUI Controls and Settings Persistence The goal is to make the acquisition configurable through the GUI and to -remember settings. - - 1 Populate the settings panel in PicoStreamMainWindow with input widgets for all acquisition parameters. - 2 Pass the values from these widgets to the StreamerWorker when starting an acquisition. - 3 Implement load_settings and save_settings methods using QSettings. - -Phase 4: Plotter Integration The goal is to embed the live plot directly into the main window. - - 1 In picostream/dfplot.py, refactor the HDF5LivePlotter class to inherit from QWidget instead of QMainWindow. Remove its - window-management logic. - 2 In PicoStreamMainWindow, replace the plot placeholder with an instance of the refactored HDF5LivePlotter widget. - -Phase 5: Packaging The goal is to create a standalone, distributable executable. - - 1 Add a new build target to the justfile that uses PyInstaller to bundle the application. - 2 Configure the build to handle dependencies, particularly creating a hook for Numba if necessary. - 3 Test the final executable. diff --git a/picostream/__init__.py b/picostream/__init__.py index 1fb1a8e..ea077ae 100644 --- a/picostream/__init__.py +++ b/picostream/__init__.py @@ -1,7 +1,3 @@ -"""PicoStream: High-performance PicoScope data acquisition and visualization.""" +"""PicoStream - fast GUI streaming for PicoScope 5000a series.""" -from .reader import PicoStreamReader -from .dfplot import HDF5LivePlotter - -__version__ = "0.2.0" -__all__ = ["PicoStreamReader", "HDF5LivePlotter"] +__version__ = "1.0.0" diff --git a/picostream/acquisition_rate.py b/picostream/acquisition_rate.py new file mode 100644 index 0000000..f9c3e18 --- /dev/null +++ b/picostream/acquisition_rate.py @@ -0,0 +1,287 @@ +"""Acquisition rate abstraction for consistent rate semantics. + +This module provides a single source of truth for all rate calculations +throughout the picostream pipeline. Rates are represented as immutable +objects that know the relationships between hardware rate, downsampling, +channels, and effective rates at different stages. + +All time/sample conversions should go through the AcquisitionRate object +to ensure consistency and avoid double-counting of downsampling ratios. +""" + +from dataclasses import dataclass +from enum import Enum +from typing import TypeAlias + + +class DownsampleMode(Enum): + """Downsampling modes supported by Picoscope hardware.""" + + NONE = "NONE" + AVERAGE = "AVERAGE" + AGGREGATE = "AGGREGATE" + DECIMATE = "DECIMATE" + + +StorageSampleCount: TypeAlias = int +"""Number of samples in ring buffer (post-hardware-downsampling). + +In AGGREGATE mode, these are interleaved min/max pairs. +""" + +DisplayPointCount: TypeAlias = int +"""Number of points to display on screen (post-software-decimation). + +Always includes min/max pairs for envelope preservation. +""" + + +@dataclass(frozen=True) +class AcquisitionRate: + """Encapsulates all rate information for an acquisition. + + This is the single source of truth for rate semantics throughout + the pipeline. All rate calculations derive from these four core facts. + The object is frozen to prevent accidental misuse or reinterpretation + of rates at different stages. + + Attributes + ---------- + hardware_rate_hz : float + Raw ADC rate from hardware. For multi-channel systems, this is the + total rate across all channels. + Example: 125 MHz for 2 channels at 62.5 MS/s per channel. + + num_channels : int + Number of active acquisition channels. + + downsample_ratio : int + Hardware downsampling factor. 1 means no downsampling. + Example: 10 means every 10th sample is kept. + + downsample_mode : DownsampleMode + Type of downsampling applied: NONE, AVERAGE, AGGREGATE, or DECIMATE. + AGGREGATE produces min/max pairs (2 samples per time point). + """ + + hardware_rate_hz: float + num_channels: int + downsample_ratio: int + downsample_mode: DownsampleMode + + def __post_init__(self) -> None: + """Validate rate parameters.""" + if self.hardware_rate_hz <= 0: + raise ValueError( + f"hardware_rate_hz must be positive, got {self.hardware_rate_hz}" + ) + if self.num_channels <= 0: + raise ValueError(f"num_channels must be positive, got {self.num_channels}") + if self.downsample_ratio < 1: + raise ValueError( + f"downsample_ratio must be >= 1, got {self.downsample_ratio}" + ) + if not isinstance(self.downsample_mode, DownsampleMode): + raise ValueError( + f"downsample_mode must be DownsampleMode, got {type(self.downsample_mode)}" + ) + + @property + def per_channel_rate_hz(self) -> float: + """Per-channel rate after hardware downsampling. + + This is the rate at which individual channels produce time points + after downsampling is applied. + + Returns + ------- + float + Samples per second per channel. + + Example + ------- + hardware_rate_hz=125e6, num_channels=2, downsample_ratio=10 + → per_channel_rate_hz = 125e6 / (2 * 10) = 6.25e6 MS/s + """ + return self.hardware_rate_hz / (self.num_channels * self.downsample_ratio) + + @property + def storage_rate_hz(self) -> float: + """Total rate at which samples arrive at ring buffer. + + This is the rate to use for ring buffer capacity calculations and + writing to storage. It accounts for AGGREGATE mode producing 2 samples + per downsampled point (min/max pairs interleaved). + + For non-AGGREGATE modes: storage_rate = hardware_rate / downsample_ratio + For AGGREGATE mode: storage_rate = (hardware_rate / downsample_ratio) * 2 + + Returns + ------- + float + Total samples per second arriving at the buffer. + + Example + ------- + hardware_rate_hz=125e6, downsample_ratio=10, mode=DECIMATE + → storage_rate_hz = 125e6 / 10 = 12.5 MS/s + + hardware_rate_hz=125e6, downsample_ratio=10, mode=AGGREGATE + → storage_rate_hz = (125e6 / 10) * 2 = 25 MS/s + """ + base_rate = self.hardware_rate_hz / self.downsample_ratio + if ( + self.downsample_mode == DownsampleMode.AGGREGATE + and self.downsample_ratio > 1 + ): + return base_rate * 2 + return base_rate + + @property + def display_rate_per_channel_hz(self) -> float: + """Per-channel rate for time axis generation in display. + + Used by the plotter for converting sample counts to time durations. + This is the "semantic" time rate for the user - how fast time flows + in the plot regardless of how samples are packed in the buffer. + + Returns + ------- + float + Samples per second per channel as perceived in display time. + """ + return self.per_channel_rate_hz + + def samples_to_seconds(self, n_samples: StorageSampleCount) -> float: + """Convert sample count to time duration. + + Operates on per-channel sample counts (time-points), properly handling + AGGREGATE mode where storage includes interleaved min/max pairs. + + The conversion uses per_channel_rate_hz, not storage_rate_hz, because + this method is designed for time-point counts (one value per channel + per time step), not total storage sample counts. + + Parameters + ---------- + n_samples : StorageSampleCount + Number of time-points (samples per channel). For AGGREGATE mode, + this counts each min/max pair as 2 samples but represents 1 time-point. + + Returns + ------- + float + Duration in seconds. + + Example + ------- + If per_channel_rate = 6.25 MS/s and n_samples = 6.25M: + - NONE mode: 6.25e6 / 6.25e6 = 1.0 second + - AGGREGATE mode: (6.25e6 // 2) / 6.25e6 = 0.5 second + (because 6.25M samples = 3.125M pairs = 0.5s of data) + """ + if ( + self.downsample_mode == DownsampleMode.AGGREGATE + and self.downsample_ratio > 1 + ): + # Each min/max pair represents one time point + n_time_points = n_samples // 2 + return n_time_points / self.display_rate_per_channel_hz + return n_samples / self.display_rate_per_channel_hz + + def seconds_to_samples(self, duration_s: float) -> StorageSampleCount: + """Convert time duration to sample count. + + Returns the number of storage samples (including min/max pairs for + AGGREGATE mode) needed to represent the given duration. + + Note: This operates on per-channel rates. The returned count is for + storage capacity calculations. For time-point counts (one per channel), + divide the result by num_channels. + + Parameters + ---------- + duration_s : float + Duration in seconds. + + Returns + ------- + StorageSampleCount + Number of storage samples needed. For AGGREGATE mode with + downsampling, this is 2× the number of time-points. + + Example + ------- + If per_channel_rate = 6.25 MS/s and duration = 1.0s: + - NONE mode: 1.0 * 6.25e6 = 6.25M samples + - AGGREGATE mode: (1.0 * 6.25e6) * 2 = 12.5M samples + (because we store both min and max for each time point) + """ + n_time_points = int(duration_s * self.display_rate_per_channel_hz) + if ( + self.downsample_mode == DownsampleMode.AGGREGATE + and self.downsample_ratio > 1 + ): + return n_time_points * 2 + return n_time_points + + def get_display_duration( + self, n_display_samples: DisplayPointCount, decimation: int + ) -> float: + """Calculate time duration represented by display samples. + + DEPRECATED: This method mixes concerns by trying to reverse-engineer + duration from display points. Use samples_to_seconds() with storage + samples instead. + + Accounts for both software min-max decimation (which produces pairs) + and the hardware downsampling that's already baked into the rate. + + Parameters + ---------- + n_display_samples : DisplayPointCount + Number of display points (after software decimation). + decimation : int + Software decimation factor used for display (minimum 1). + + Returns + ------- + float + Duration in seconds that the display data represents. + + Notes + ----- + This is used by the plotter to know how much time is represented + by the currently displayed points, independent of how many points + are actually being drawn on screen. + """ + if decimation <= 0 or n_display_samples <= 0: + return 0.0 + + # Software min-max decimation produces pairs + n_time_points = n_display_samples // 2 + original_samples = n_time_points * decimation + + return original_samples / self.display_rate_per_channel_hz + + def __str__(self) -> str: + """Human-readable description of the acquisition rate.""" + return ( + f"AcquisitionRate(" + f"hw={self.hardware_rate_hz / 1e6:.1f}MS/s, " + f"ch={self.num_channels}, " + f"ds={self.downsample_ratio}x {self.downsample_mode.value}, " + f"→ {self.per_channel_rate_hz / 1e6:.2f}MS/s/ch, " + f"{self.storage_rate_hz / 1e6:.2f}MS/s storage" + f")" + ) + + def __repr__(self) -> str: + """Development-friendly representation.""" + return ( + f"AcquisitionRate(" + f"hardware_rate_hz={self.hardware_rate_hz}, " + f"num_channels={self.num_channels}, " + f"downsample_ratio={self.downsample_ratio}, " + f"downsample_mode={self.downsample_mode})" + ) diff --git a/picostream/cli.py b/picostream/cli.py deleted file mode 100644 index d8e7ffb..0000000 --- a/picostream/cli.py +++ /dev/null @@ -1,699 +0,0 @@ -from __future__ import annotations - -import queue -import signal -import sys -import threading -import time -from datetime import datetime -from typing import TYPE_CHECKING, List, Optional - -if TYPE_CHECKING: - from PyQt5.QtWidgets import QApplication - -import click -import h5py -import numpy as np -from loguru import logger - -from . import __version__ -from .consumer import Consumer -from .pico import PicoDevice - - -class Streamer: - """Orchestrates the Picoscope data acquisition process. - - This class initializes the Picoscope device (producer), the HDF5 writer - (consumer), and the live plotter. It manages the threads, queues, and - graceful shutdown of the entire application. - - Supports both standard acquisition mode (saves all data) and live-only mode - (limits buffer size using max_buffer_seconds parameter). - """ - - def __init__( - self, - sample_rate_msps: float = 62.5, - resolution_bits: int = 12, - channel_range_str: str = "PS5000A_20V", - enable_live_plot: bool = False, - output_file: str = "./output.hdf5", - debug: bool = False, - plot_window_s: float = 0.5, - plot_points: int = 4000, - hardware_downsample: int = 1, - downsample_mode: str = "average", - offset_v: float = 0.0, - max_buffer_seconds: Optional[float] = None, - is_gui_mode: bool = False, - y_min: Optional[float] = None, - y_max: Optional[float] = None, - bandwidth_limiter: str = "full", - ) -> None: - # --- Configuration --- - self.output_file = output_file - self.debug = debug - self.enable_live_plot = enable_live_plot - self.is_gui_mode = is_gui_mode - self.plot_window_s = plot_window_s - self.max_buffer_seconds = max_buffer_seconds - self.y_min = y_min - self.y_max = y_max - - ( - sample_rate_msps, - pico_downsample_ratio, - pico_ratio_mode, - offset_v, - ) = self._validate_config( - resolution_bits, - sample_rate_msps, - channel_range_str, - hardware_downsample, - downsample_mode, - offset_v, - ) - # Dynamically size buffers to hold a specific duration of data. This makes - # memory usage proportional to the data rate, providing a consistent - # time-based buffer to handle processing latencies. - effective_rate_sps = (sample_rate_msps * 1e6) / pico_downsample_ratio - - # Consumer buffers (for writing to HDF5) are sized to hold 1 second of data. - # This is a good balance, as larger buffers lead to more efficient disk writes - # but use more RAM. - consumer_buffer_duration_s = 1.0 - self.consumer_buffer_size = int( - effective_rate_sps * consumer_buffer_duration_s - ) - if downsample_mode == "aggregate": - self.consumer_buffer_size *= 2 - self.consumer_num_buffers = 5 # A pool of 5 buffers - - # The Picoscope driver buffer is sized to hold 0.5 seconds of data. This - # buffer receives data directly from the hardware. A smaller size ensures - # that the application receives data in timely chunks, reducing latency. - driver_buffer_duration_s = 0.5 - self.pico_driver_buffer_size = int( - effective_rate_sps * driver_buffer_duration_s - ) - self.pico_driver_num_buffers = ( - 1 # A single large buffer is efficient for the driver - ) - - logger.info( - f"Consumer buffer sized to {self.consumer_buffer_size:,} samples " - f"({consumer_buffer_duration_s}s at effective rate)" - ) - logger.info( - f"Pico driver buffer sized to {self.pico_driver_buffer_size:,} samples " - f"({driver_buffer_duration_s}s at effective rate)" - ) - - # --- Plotting Decimation --- - # Calculate the decimation factor needed to achieve the target number of plot points. - points_per_timestep = 2 if downsample_mode == "aggregate" else 1 - samples_in_window = effective_rate_sps * plot_window_s * points_per_timestep - self.decimation_factor = max(1, int(samples_in_window / plot_points)) - logger.info( - f"Plotting with target of {plot_points} points. " - f"Calculated decimation factor: {self.decimation_factor}" - ) - - # Picoscope hardware settings - self.pico_resolution = f"PS5000A_DR_{resolution_bits}BIT" - self.pico_channel_range = channel_range_str - self.pico_sample_interval_ns = int(1000 / sample_rate_msps) - self.pico_sample_unit = "PS5000A_NS" - - # Streaming settings - self.pico_auto_stop = 0 # Don't auto stop - self.pico_auto_stop_stream = False - # --- End Configuration --- - - # --- System Components --- - self.shutdown_event: threading.Event = threading.Event() - data_queue: queue.Queue[int] = queue.Queue() - empty_queue: queue.Queue[int] = queue.Queue() - data_buffers: List[np.ndarray] = [] - - # Pre-allocate a pool of numpy arrays for data transfer and populate the - # empty_queue with their indices. - for idx in range(self.consumer_num_buffers): - data_buffers.append(np.empty((self.consumer_buffer_size,), dtype="int16")) - empty_queue.put(idx) - - # --- Producer --- - self.pico_device: PicoDevice = PicoDevice( - 0, # handle - self.pico_resolution, - self.pico_driver_buffer_size, - self.pico_driver_num_buffers, - self.consumer_buffer_size, - data_queue, - empty_queue, - data_buffers, - self.shutdown_event, - downsample_mode=downsample_mode, - ) - - # Store bandwidth limiter setting - self.bandwidth_limiter = bandwidth_limiter - - self.pico_device.set_channel( - "PS5000A_CHANNEL_A", 1, "PS5000A_DC", self.pico_channel_range, offset_v - ) - self.pico_device.set_channel( - "PS5000A_CHANNEL_B", 0, "PS5000A_DC", self.pico_channel_range, 0.0 - ) - - # Set bandwidth filter for Channel A if specified - if self.bandwidth_limiter == "20MHz": - self.pico_device.set_bandwidth_filter("PS5000A_CHANNEL_A", "PS5000A_BW_20MHZ") - else: - self.pico_device.set_bandwidth_filter("PS5000A_CHANNEL_A", "PS5000A_BW_FULL") - self.pico_device.set_data_buffer("PS5000A_CHANNEL_A", 0, pico_ratio_mode) - self.pico_device.configure_streaming_var( - self.pico_sample_interval_ns, - self.pico_sample_unit, - 0, # pre-trigger samples - pico_downsample_ratio, - pico_ratio_mode, - self.pico_auto_stop, - self.pico_auto_stop_stream, - ) - - # Run streaming once to get the actual sample interval from the driver - self.pico_device.run_streaming() - - # --- Consumer --- - # Prepare metadata for the consumer - acquisition_start_time_utc = datetime.utcnow().isoformat() + "Z" - was_live_mode = self.max_buffer_seconds is not None - - # Get metadata from configured device and pass to consumer - metadata = self.pico_device.get_metadata( - acquisition_start_time_utc=acquisition_start_time_utc, - picostream_version=__version__, - acquisition_command="", # Will be set later in main() - was_live_mode=was_live_mode - ) - - # Calculate max samples for live-only mode - max_samples = None - if self.max_buffer_seconds: - max_samples = int(effective_rate_sps * self.max_buffer_seconds) - if downsample_mode == "aggregate": - max_samples *= 2 - logger.info(f"Live-only mode: limiting buffer to {self.max_buffer_seconds}s ({max_samples:,} samples)") - - self.consumer: Consumer = Consumer( - self.consumer_buffer_size, - data_queue, - empty_queue, - data_buffers, - output_file, - self.shutdown_event, - metadata=metadata, - max_samples=max_samples, - ) - - # --- Threads --- - self.consumer_thread: threading.Thread = threading.Thread( - target=self.consumer.consume - ) - self.pico_thread: threading.Thread = threading.Thread( - target=self.pico_device.run_capture - ) - - # --- Signal Handling --- - # Only set the signal handler if not in GUI mode. - # In GUI mode, the main window handles shutdown signals. - # In CLI plot mode, the main() function handles SIGINT to quit the Qt app. - if not self.is_gui_mode and not self.enable_live_plot: - signal.signal(signal.SIGINT, self.signal_handler) - - # --- Live Plotting (optional) --- - self.start_time: Optional[float] = None - - def update_acquisition_command(self, command: str) -> None: - """Update the acquisition command in the consumer's metadata.""" - self.consumer.metadata["acquisition_command"] = command - - def _validate_config( - self, - resolution_bits: int, - sample_rate_msps: float, - channel_range_str: str, - hardware_downsample: int, - downsample_mode: str, - offset_v: float, - ) -> tuple[float, int, str, float]: - """Validates user-provided settings and returns derived configuration.""" - if resolution_bits == 8: - max_rate_msps = 125.0 - elif resolution_bits in [12, 14, 15, 16]: - max_rate_msps = 62.5 - else: - raise ValueError( - f"Unsupported resolution: {resolution_bits} bits. Must be one of 8, 12, 14, 15, 16." - ) - - if sample_rate_msps <= 0: - sample_rate_msps = max_rate_msps - logger.info(f"Max sample rate requested. Setting to {max_rate_msps} MS/s.") - - if sample_rate_msps > max_rate_msps: - raise ValueError( - f"Sample rate {sample_rate_msps} MS/s exceeds maximum of {max_rate_msps} MS/s for {resolution_bits}-bit resolution." - ) - - # Check if sample rate is excessive for the analog bandwidth. - # Bandwidth is dependent on both resolution and voltage range. - # (Based on PicoScope 5000A/B Series datasheet) - inv_voltage_map = {v: k for k, v in VOLTAGE_RANGE_MAP.items()} - voltage_v = inv_voltage_map.get(channel_range_str, 0) - - if resolution_bits == 16: - bandwidth_mhz = 20 # 20 MHz for all ranges - elif resolution_bits == 15: - # Bandwidth is 70MHz for < ±5V, 60MHz for >= ±5V - bandwidth_mhz = 70 if voltage_v < 5.0 else 60 - else: # 8-14 bits - # Bandwidth is 100MHz for < ±5V, 60MHz for >= ±5V - bandwidth_mhz = 100 if voltage_v < 5.0 else 60 - - # Nyquist rate is 2x bandwidth. A common rule of thumb is 3-5x. - # Warn if sampling faster than 5x the analog bandwidth. - if sample_rate_msps > 5 * bandwidth_mhz: - logger.warning( - f"Sample rate ({sample_rate_msps} MS/s) may be unnecessarily high " - f"for the selected voltage range ({channel_range_str}), which has an " - f"analog bandwidth of {bandwidth_mhz} MHz." - ) - - if downsample_mode == "aggregate" and hardware_downsample <= 1: - raise ValueError( - "Hardware downsample ratio must be > 1 for 'aggregate' mode." - ) - - if hardware_downsample > 1: - pico_downsample_ratio = hardware_downsample - pico_ratio_mode = f"PS5000A_RATIO_MODE_{downsample_mode.upper()}" - logger.info( - f"Hardware down-sampling ({downsample_mode}) enabled " - + f"with ratio {pico_downsample_ratio}." - ) - else: - pico_downsample_ratio = 1 - pico_ratio_mode = "PS5000A_RATIO_MODE_NONE" - - # Validate analog offset - if offset_v != 0.0: - if voltage_v >= 5.0: - raise ValueError( - f"Analog offset is not supported for voltage ranges >= 5V (selected: {channel_range_str})." - ) - if abs(offset_v) > voltage_v: - raise ValueError( - f"Analog offset ({offset_v}V) exceeds the selected voltage range (±{voltage_v}V)." - ) - logger.info(f"Analog offset set to {offset_v:.3f}V.") - - return sample_rate_msps, pico_downsample_ratio, pico_ratio_mode, offset_v - - def signal_handler(self, _sig: int, frame: Optional[object]) -> None: - """Handles Ctrl+C interrupts to initiate a graceful shutdown.""" - logger.warning("Ctrl+C detected. Shutting down.") - self.shutdown() - - def shutdown(self) -> None: - """Performs a graceful shutdown of all components. - - This method calculates final statistics, stops all threads, closes the - plotter, and ensures the Picoscope device is properly closed. - """ - if self.shutdown_event.is_set(): - return - - self._log_acquisition_summary() - self.shutdown_event.set() - - logger.info("Stopping data acquisition and saving...") - - self.pico_device.close_device() - self._join_threads() - - logger.success("Shutdown complete.") - - def _log_acquisition_summary(self) -> None: - """Calculates and logs final acquisition statistics.""" - if not self.start_time: - return - - end_time = time.time() - duration = end_time - self.start_time - total_samples = self.consumer.values_written - effective_rate_msps = (total_samples / duration) / 1e6 if duration > 0 else 0 - configured_rate_msps = 1e3 / self.pico_device.sample_int.value - - logger.info("--- Acquisition Summary ---") - logger.info(f"Total acquisition time: {duration:.2f} s") - logger.info( - "Total samples written: " - + f"{self.consumer.format_sample_count(total_samples)}" - ) - logger.info(f"Configured sample rate: {configured_rate_msps:.2f} MS/s") - logger.info(f"Effective average rate: {effective_rate_msps:.2f} MS/s") - - rate_ratio = ( - effective_rate_msps / configured_rate_msps - if configured_rate_msps > 0 - else 0 - ) - if rate_ratio < 0.95: - logger.warning( - f"Effective rate was only {rate_ratio:.1%} " + "of the configured rate." - ) - else: - logger.success("Effective rate matches configured rate.") - logger.info("--------------------------") - - def _join_threads(self) -> None: - """Waits for the producer and consumer threads to terminate.""" - for thread_name in ["pico_thread", "consumer_thread"]: - thread = getattr(self, thread_name, None) - if thread and thread.is_alive(): - logger.info(f"Waiting for {thread_name} to terminate...") - thread.join(timeout=2.0) - if thread.is_alive(): - logger.critical(f"{thread_name} failed to terminate.") - - def run(self, app: Optional[QApplication] = None) -> None: - """Starts the acquisition threads and optionally the Qt event loop.""" - # Start acquisition threads - self.start_time = time.time() - self.consumer_thread.start() - self.pico_thread.start() - - # Handle Qt event loop if plotting is enabled - if self.enable_live_plot and app: - from .dfplot import HDF5LivePlotter - - plotter = HDF5LivePlotter( - hdf5_path=self.output_file, - display_window_seconds=self.plot_window_s, - decimation_factor=self.decimation_factor, - y_min=self.y_min, - y_max=self.y_max, - ) - plotter.show() - - # Run the Qt event loop. This will block until the plot window is closed. - app.exec_() - - # Once the window is closed, the shutdown event should have been set. - # We call shutdown() to ensure threads are joined and cleanup happens. - self.shutdown() - else: - # In GUI mode, run() returns immediately, allowing the worker's event - # loop to process signals. In CLI mode, we block until completion. - if not self.is_gui_mode: - self.consumer_thread.join() - self.pico_thread.join() - logger.success("Acquisition complete!") - - -# --- Argument Parsing --- -VOLTAGE_RANGE_MAP = { - 0.01: "PS5000A_10MV", - 0.02: "PS5000A_20MV", - 0.05: "PS5000A_50MV", - 0.1: "PS5000A_100MV", - 0.2: "PS5000A_200MV", - 0.5: "PS5000A_500MV", - 1: "PS5000A_1V", - 2: "PS5000A_2V", - 5: "PS5000A_5V", - 10: "PS5000A_10V", - 20: "PS5000A_20V", -} - - -def generate_unique_filename(base_path: str) -> str: - """Generate a unique filename by appending timestamp if file exists.""" - import os - from pathlib import Path - - if not os.path.exists(base_path): - return base_path - - # Split the path into parts - path = Path(base_path) - stem = path.stem - suffix = path.suffix - parent = path.parent - - # Generate timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Create new filename with timestamp - new_filename = f"{stem}_{timestamp}{suffix}" - new_path = parent / new_filename - - return str(new_path) - - -@click.command() -@click.option( - "--sample-rate", - "-s", - type=float, - default=20, - help="Sample rate in MS/s (e.g., 62.5). Use 0 for max rate. [default: 20]", -) -@click.option( - "--resolution", - "-b", - type=click.Choice(["8", "12", "16"]), - default="12", - help="Resolution in bits. [default: 12]", -) -@click.option( - "--rangev", - "-r", - type=click.Choice([str(k) for k in sorted(VOLTAGE_RANGE_MAP.keys())]), - default="20", - help=f"Voltage range in Volts. [default: 20]", -) -@click.option( - "--plot/--no-plot", - "-p", - is_flag=True, - default=True, - help="Enable/disable live plotting. [default: --plot]", -) -@click.option( - "--output", - "-o", - type=click.Path(dir_okay=False, writable=True), - help="Output HDF5 file (default: auto-timestamped).", -) -@click.option( - "--plot-window", - "-w", - type=float, - default=0.5, - help="Live plot display window duration in seconds. [default: 0.5]", -) -@click.option( - "--verbose", "-v", is_flag=True, default=False, help="Enable debug logging." -) -@click.option( - "--plot-npts", - "-n", - type=int, - default=4000, - help="Target number of points for the plot window. [default: 4000]", -) -@click.option( - "--hardware-downsample", - "-h", - type=int, - default=1, - help="Hardware down-sampling ratio (power of 2 for 'average' mode). [default: 1]", -) -@click.option( - "--downsample-mode", - "-m", - type=click.Choice(["average", "aggregate"]), - default="average", - help="Hardware down-sampling mode. [default: average]", -) -@click.option( - "--offset", - type=float, - default=0.0, - help="Analog offset in Volts (only for ranges < 5V). [default: 0.0]", -) -@click.option( - "--bandwidth", - type=click.Choice(["full", "20MHz"]), - default="full", - help="Bandwidth limiter to reduce noise. [default: full]", -) -@click.option( - "--max-buff-sec", - type=float, - help="Maximum buffer duration in seconds for live-only mode (limits file size).", -) -@click.option( - "--force", - "-f", - is_flag=True, - default=False, - help="Overwrite existing output file.", -) -@click.option( - "--y-min", - type=float, - help="Minimum Y-axis limit in mV for live plot.", -) -@click.option( - "--y-max", - type=float, - help="Maximum Y-axis limit in mV for live plot.", -) -def main( - sample_rate: float, - resolution: str, - rangev: str, - plot: bool, - output: Optional[str], - plot_window: float, - verbose: bool, - plot_npts: int, - hardware_downsample: int, - downsample_mode: str, - offset: float, - bandwidth: str, - max_buff_sec: Optional[float], - force: bool, - y_min: Optional[float], - y_max: Optional[float], -) -> None: - """High-speed data acquisition tool for Picoscope 5000a series.""" - # --- Argument Validation and Processing --- - - # Validate Y-axis limits - if (y_min is None) != (y_max is None): - logger.error("Both --y-min and --y-max must be provided together, or neither.") - sys.exit(1) - - if y_min is not None and y_max is not None and y_min >= y_max: - logger.error(f"Invalid Y-axis range: y_min ({y_min}) must be less than y_max ({y_max}).") - sys.exit(1) - - channel_range_str = VOLTAGE_RANGE_MAP[float(rangev)] - resolution_bits = int(resolution) - - app: Optional[QApplication] = None - if plot: - from PyQt5.QtWidgets import QApplication - - app = QApplication(sys.argv) - - # When plotting, SIGINT should gracefully close the Qt application. - # The main loop will then handle the shutdown. - def sigint_handler(_sig: int, _frame: Optional[object]) -> None: - logger.warning("Ctrl+C detected. Closing application.") - QApplication.quit() - - signal.signal(signal.SIGINT, sigint_handler) - - # Configure logging - logger.remove() - log_level = "DEBUG" if verbose else "INFO" - logger.add(sys.stderr, level=log_level) - logger.info(f"Logging configured at level: {log_level}") - - # Auto-generate filename if not specified - if not output: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output = f"./output_{timestamp}.hdf5" - else: - # Check if file exists and handle accordingly - if not force: - original_output = output - output = generate_unique_filename(output) - if output != original_output: - logger.info(f"File '{original_output}' exists. Using '{output}' instead.") - logger.info("Use --force/-f to overwrite existing files.") - - logger.info(f"Output file: {output}") - logger.info(f"Selected voltage range: {rangev}V -> {channel_range_str}") - - try: - # Create and run the streamer - streamer = Streamer( - sample_rate_msps=sample_rate, - resolution_bits=resolution_bits, - channel_range_str=channel_range_str, - enable_live_plot=plot, - output_file=output, - debug=verbose, - plot_window_s=plot_window, - plot_points=plot_npts, - hardware_downsample=hardware_downsample, - downsample_mode=downsample_mode, - offset_v=offset, - max_buffer_seconds=max_buff_sec, - y_min=y_min, - y_max=y_max, - bandwidth_limiter=bandwidth, - ) - - # Update the acquisition command in metadata - acquisition_command = " ".join(sys.argv) - streamer.update_acquisition_command(acquisition_command) - - streamer.run(app) - except RuntimeError as e: - if "PICO_NOT_FOUND" in str(e): - logger.critical( - "Picoscope device not found. Please check connection and ensure no other software is using it." - ) - else: - logger.critical(f"Failed to initialize Picoscope: {e}") - sys.exit(1) - - # --- Verification Step --- - # Skip verification in live-only mode since file size is limited - if not streamer.max_buffer_seconds: - logger.info(f"Verifying output file: {output}") - try: - expected_samples = streamer.consumer.values_written - if expected_samples == 0: - logger.warning("Consumer processed no samples. Nothing to verify.") - else: - with h5py.File(output, "r") as f: - if "adc_counts" not in f: - raise ValueError("Dataset 'adc_counts' not found in HDF5 file.") - - actual_samples = len(f["adc_counts"]) - if actual_samples == expected_samples: - logger.success( - f"Verification PASSED: File contains {actual_samples} samples, as expected." - ) - else: - logger.error( - f"Verification FAILED: Expected {expected_samples} samples, but file has {actual_samples}." - ) - except Exception as e: - logger.error(f"HDF5 file verification failed: {e}") - else: - logger.info("Skipping verification in live-only mode.") - - -if __name__ == "__main__": - main() diff --git a/picostream/consumer.py b/picostream/consumer.py deleted file mode 100644 index f48ff30..0000000 --- a/picostream/consumer.py +++ /dev/null @@ -1,160 +0,0 @@ -from __future__ import annotations - -import os -import queue -import threading -from typing import Any, Dict, List, Optional - -import h5py -import numpy as np -from loguru import logger - - -class Consumer: - """A data consumer that runs in a separate thread. - - This class retrieves data buffers from a queue, writes them to an HDF5 file, - and then returns the buffer index to an "empty" queue for reuse by the - producer. It handles file creation, data writing, and metadata storage. - - Supports live-only mode when max_samples is specified, which limits the - HDF5 file size by resetting and overwriting data when the limit is reached. - """ - - def __init__( - self, - buffer_size: int, - data_queue: queue.Queue[int], - empty_queue: queue.Queue[int], - data_buffers: List[np.ndarray], - file_name: str, - shutdown_event: threading.Event, - metadata: Dict[str, Any], - max_samples: Optional[int] = None, - ): - """Initializes the Consumer. - - Args: - buffer_size: The size of each individual data buffer. - data_queue: A queue for receiving indices of data-filled buffers. - empty_queue: A queue for returning indices of processed (empty) buffers. - data_buffers: A list of pre-allocated NumPy arrays for data. - file_name: The path to the output HDF5 file. - shutdown_event: A threading.Event to signal termination. - metadata: A dictionary of metadata to be saved as HDF5 attributes. - max_samples: Maximum number of samples to keep in file (for live-only mode). - """ - self.buffer_size = buffer_size - self.data_queue = data_queue - self.empty_queue = empty_queue - self.data_buffers = data_buffers - self.file_name = file_name - self.shutdown_event = shutdown_event - self.metadata = metadata - self.max_samples = max_samples - - self.values_written: int = 0 - self.empty_con_queue_count: int = 0 - self.buffer_resets: int = 0 - - def format_sample_count(self, count: int) -> str: - """Format a large integer count into a human-readable string. - - Uses metric prefixes (K, M, G) for thousands, millions, and billions. - - Args: - count: The integer number to format. - - Returns: - A formatted string representation of the count. - """ - if count >= 1_000_000_000: - return f"{count / 1_000_000_000:.2f}G" - if count >= 1_000_000: - return f"{count / 1_000_000:.2f}M" - if count >= 1_000: - return f"{count / 1_000:.2f}K" - else: - return str(count) - - def _processing_loop(self, dset: h5py.Dataset) -> None: - """ - Continuously processes data from the queue and writes to the HDF5 dataset. - - Args: - dset: The HDF5 dataset to write to. - """ - while not self.shutdown_event.is_set(): - try: - # Wait for a buffer index from the producer. - # A timeout allows the loop to periodically check the shutdown event. - idx = self.data_queue.get(timeout=0.1) - - # Append the new data to the HDF5 dataset. - buffer_len = len(self.data_buffers[idx]) - - # Check if we need to reset the buffer (live-only mode) - if self.max_samples and self.values_written + buffer_len > self.max_samples: - # Reset the dataset to start overwriting from the beginning - dset.resize((buffer_len,)) - dset[:] = self.data_buffers[idx] - self.values_written = buffer_len - self.buffer_resets += 1 - logger.debug(f"Buffer reset #{self.buffer_resets} - file size limited to {self.max_samples:,} samples") - else: - # Normal append - dset.resize((self.values_written + buffer_len,)) - dset[self.values_written :] = self.data_buffers[idx] - self.values_written += buffer_len - - # Return the buffer index to the empty queue for reuse. - self.empty_queue.put(idx) - - except queue.Empty: - # This occurs if the producer hasn't provided data within the timeout. - self.empty_con_queue_count += 1 - # This is expected when acquisition stops, so no need to log as a warning - if not self.shutdown_event.is_set(): - logger.debug("Consumer queue was empty.") - - def consume(self) -> None: - """The main loop for the consumer thread. - - This method continuously checks for data from the producer, writes it to - the HDF5 file, and returns the buffer for reuse. It handles file setup, - the main processing loop, and graceful shutdown. - """ - try: - # Ensure a clean slate by removing any pre-existing file. - if os.path.exists(self.file_name): - os.remove(self.file_name) - logger.info(f"Removed existing file: {self.file_name}") - - with h5py.File(self.file_name, "w") as f: - # Write the collected metadata to the HDF5 file's attributes. - for key, value in self.metadata.items(): - if value is not None: - f.attrs[key] = value - - # Create a resizable dataset for the ADC data. - # Chunking is aligned with the buffer size for efficient writes. - dset = f.create_dataset( - "adc_counts", - (0,), - maxshape=(None,), - dtype="int16", - chunks=(self.buffer_size,), - ) - - self._processing_loop(dset) - except (IOError, OSError) as e: - # A critical file error means we cannot continue. - logger.critical(f"Failed to create or write to HDF5 file: {e}") - self.shutdown_event.set() # Signal other threads to shut down. - return - - logger.info( - f"Consumer couldn't obtain data from queue {self.empty_con_queue_count} times." - ) - if self.buffer_resets > 0: - logger.info(f"Buffer was reset {self.buffer_resets} times (live-only mode).") diff --git a/picostream/data_pipeline.py b/picostream/data_pipeline.py new file mode 100644 index 0000000..c5587dd --- /dev/null +++ b/picostream/data_pipeline.py @@ -0,0 +1,622 @@ +"""Data pipeline for processing PicoScope streaming data. + +This module provides a unified data processing pipeline that handles: +- Sample rate calculations for different downsample modes +- Data decimation for display optimization +- ADC to millivolt conversion +- Time axis generation with proper accounting for decimation and modes +- Lag metrics between producer and consumer +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple + +import numpy as np +from loguru import logger + +from picostream.acquisition_rate import ( + AcquisitionRate, + DisplayPointCount, + DownsampleMode, + StorageSampleCount, +) +from picostream.conversion_utils import adc_to_mV, min_max_decimate_numba + +if TYPE_CHECKING: + pass # Already imported above + + +class LagMetrics(NamedTuple): + """Metrics for producer-consumer lag. + + Attributes + ---------- + samples : int + Number of samples of lag. + milliseconds : float + Lag duration in milliseconds. + """ + + samples: int + milliseconds: float + + +@dataclass(frozen=True) +class DisplayGeometry: + """Complete display geometry for a window, computed in one place. + + This is the single source of truth for how many points to cache, + how long the time axis is, and whether min/max pairs are used. + Both the cache sizing and the time axis generation must use values + from the same DisplayGeometry instance to stay in sync. + + Attributes + ---------- + time_axis : np.ndarray + Time values in seconds, right-aligned within the window. + Length equals ``max_display_values``. + max_display_values : int + Total number of display values to keep in the cache per channel. + Equals ``n_time_points * values_per_point``. + n_time_points : int + Number of unique time points (before any min/max doubling). + values_per_point : int + 1 for raw samples, 2 for min/max pairs. + decimation : int + The software decimation factor used to compute this geometry. + window_samples : int + Number of storage samples that fit in the window duration. + """ + + time_axis: np.ndarray + max_display_values: int + n_time_points: int + values_per_point: int + decimation: int + window_samples: int + + +@dataclass(frozen=True) +class PipelineConfig: + """Immutable configuration set at acquisition start. + + This configuration is frozen at acquisition start to avoid race + conditions and ensure consistent processing throughout acquisition. + + Attributes + ---------- + resolution : int + ADC resolution in bits. + voltage_ranges : Dict[int, float] + Mapping of channel index to voltage range in Volts. + offsets_v : Dict[int, float] + Mapping of channel index to DC offset in Volts. + max_adc_value : int + Maximum ADC count value for the configured resolution. + target_plot_points : Optional[int] + Target number of points to display on screen, or None for lossless mode. + """ + + resolution: int + voltage_ranges: Dict[int, float] + offsets_v: Dict[int, float] + max_adc_value: int + target_plot_points: Optional[int] + + +@dataclass +class ProcessedChunk: + """Result of processing a chunk of data. + + Attributes + ---------- + time_axis : np.ndarray + Time axis in seconds (relative to window end). + voltage_data : Dict[int, np.ndarray] + Processed voltage data per channel in mV. + decimation : int + Decimation factor used for this chunk. + """ + + time_axis: np.ndarray + voltage_data: Dict[int, np.ndarray] + decimation: int + + +class DataPipeline: + """Consumer-side data processing: decimation, conversion, time axis. + + This class owns the complete consumer-side data processing pipeline, + consolidating previously scattered logic for decimation, ADC conversion, + and time axis generation into a single source of truth. + + All rate calculations are delegated to the AcquisitionRate object, + ensuring consistent semantics throughout the system. + + Parameters + ---------- + config : PipelineConfig + Immutable configuration for conversion and decimation. + acquisition_rate : AcquisitionRate + Rate information with all downsampling and channel details. + + Attributes + ---------- + config : PipelineConfig + The pipeline configuration. + acquisition_rate : AcquisitionRate + The acquisition rate for consistent rate calculations. + _remainder : Dict[int, np.ndarray] + Leftover samples from previous decimation (per channel). + _last_decimation : int + Last decimation factor used (for detecting changes). + """ + + def __init__(self, config: PipelineConfig, acquisition_rate: AcquisitionRate): + self.config = config + self.acquisition_rate = acquisition_rate + + # Processing state (owned here, not scattered) + self._remainder: Dict[int, np.ndarray] = {} + self._last_decimation: int = 1 + + logger.debug( + "DataPipeline initialised: resolution={} bits, " + "channels={}, mode={}, ratio={} → {:.2f} MS/s per channel", + config.resolution, + acquisition_rate.num_channels, + acquisition_rate.downsample_mode.value, + acquisition_rate.downsample_ratio, + acquisition_rate.per_channel_rate_hz / 1e6, + ) + + def calculate_decimation(self, window_samples: int) -> int: + """Calculate decimation factor for display optimization. + + Parameters + ---------- + window_samples : int + Number of samples in the display window. + + Returns + ------- + int + Decimation factor (minimum 1). + """ + # Lossless mode: no decimation + if self.config.target_plot_points is None: + return 1 + return max(1, window_samples // self.config.target_plot_points) + + def should_invalidate_cache(self, decimation: int) -> bool: + """Check if cache should be invalidated due to decimation change. + + Parameters + ---------- + decimation : int + New decimation factor. + + Returns + ------- + bool + True if decimation has changed and cache should be reset. + """ + return decimation != self._last_decimation + + def invalidate_cache(self) -> None: + """Invalidate processing caches. + + Called when decimation factor changes to reset remainder buffers + and signal that full reprocessing is needed. + """ + self._remainder.clear() + self._last_decimation = 1 + logger.debug("Pipeline cache invalidated") + + def process_channel_data( + self, + raw_data: np.ndarray, + channel: int, + decimation: int, + ) -> Tuple[np.ndarray, bool]: + """Process raw ADC data for a single channel. + + This method decimates the data, converts ADC to mV, and applies offsets. + It handles remainder samples from previous calls to ensure no data loss. + + Parameters + ---------- + raw_data : np.ndarray + Raw int16 ADC data for this channel. + channel : int + Channel index. + decimation : int + Decimation factor for display optimization. + + Returns + ------- + Tuple[np.ndarray, bool] + (voltage_data, has_data) where voltage_data is the processed + voltage in mV and has_data indicates if any data was produced. + """ + if len(raw_data) == 0: + return np.array([], dtype=np.float64), False + + # Combine with remainder from previous chunk + remainder = self._remainder.get(channel) + if remainder is not None and len(remainder) > 0: + combined = np.concatenate([remainder, raw_data]) + else: + combined = raw_data + + n_combined = len(combined) + n_decimated = (n_combined // decimation) * decimation + + # Perform decimation if we have enough samples + if n_decimated > 0: + to_decimate = combined[:n_decimated] + decimated = min_max_decimate_numba(to_decimate, decimation) + else: + decimated = np.array([], dtype=np.int16) + + # Store remainder for next call + remainder_len = n_combined - n_decimated + if remainder_len > 0: + self._remainder[channel] = combined[-remainder_len:].copy() + else: + self._remainder[channel] = np.array([], dtype=np.int16) + + if len(decimated) == 0: + return np.array([], dtype=np.float64), False + + # Convert ADC to mV + voltage_range = self.config.voltage_ranges.get(channel) + if voltage_range is not None and self.config.max_adc_value > 0: + try: + voltage_data = adc_to_mV( + decimated, voltage_range, self.config.max_adc_value + ) + # Apply offset + offset = self.config.offsets_v.get(channel, 0.0) + voltage_data = voltage_data + (offset * 1000.0) + except Exception: + logger.exception("Voltage conversion failed for channel {}", channel) + voltage_data = decimated.astype(np.float64) + else: + voltage_data = decimated.astype(np.float64) + + return voltage_data, True + + def create_time_axis_for_window( + self, + window_duration: float, + decimation: int, + data_duration: Optional[float] = None, + ) -> np.ndarray: + """Create a time axis for a display window. + + Delegates to ``get_window_display_geometry()`` to ensure the time + axis length is always consistent with cache sizing. + + Parameters + ---------- + window_duration : float + Full duration of the display window in seconds. + decimation : int + Current software decimation factor (1 for lossless). + data_duration : Optional[float] + Actual duration that the data represents. If provided and less + than window_duration, the time axis is right-aligned within + the window. Default is None (span full window). + + Returns + ------- + np.ndarray + Time axis array in seconds, right-aligned within the window. + """ + geometry = self.get_window_display_geometry( + window_duration, decimation, data_duration + ) + return geometry.time_axis + + def get_window_display_geometry( + self, + window_duration: float, + decimation: int, + data_duration: Optional[float] = None, + ) -> DisplayGeometry: + """Compute the complete display geometry for a window. + + This is the single authoritative method for determining cache size, + time axis, and values-per-point. Both cache sizing and time axis + generation derive from the same call, so they cannot diverge. + + Parameters + ---------- + window_duration : float + Full duration of the display window in seconds. + decimation : int + Current software decimation factor (1 for lossless). + data_duration : Optional[float] + Actual data duration if less than window. When provided, + the time axis is right-aligned within the window. + + Returns + ------- + DisplayGeometry + Complete geometry with time axis, cache sizing, and metadata. + """ + if window_duration <= 0: + empty = np.array([], dtype=np.float64) + return DisplayGeometry( + time_axis=empty, + max_display_values=0, + n_time_points=0, + values_per_point=1, + decimation=decimation, + window_samples=0, + ) + + window_samples = self.time_to_samples(window_duration) + + is_aggregate = ( + self.acquisition_rate.downsample_mode == DownsampleMode.AGGREGATE + and self.acquisition_rate.downsample_ratio > 1 + ) + # Software decimation ALWAYS produces min/max pairs when decimation > 1, + # regardless of target_plot_points (lossless vs quality mode). + # This is because min_max_decimate_numba always outputs pairs. + has_software_minmax = not is_aggregate and decimation > 1 + + if is_aggregate: + # AGGREGATE mode: window_samples already includes min/max pairs (storage_rate doubles). + # Each "time point" is a min/max pair (2 storage samples). + # So n_time_points = window_samples // 2 (pairs) // decimation. + values_per_point = 2 + n_time_points = int(window_samples // 2 // decimation) + elif has_software_minmax: + # Software decimation produces min/max pairs. + values_per_point = 2 + n_time_points = int(window_samples // decimation) + else: + # Normal mode: one value per sample. + values_per_point = 1 + n_time_points = int(window_samples // decimation) + + max_display_values = n_time_points * values_per_point + + if data_duration is not None and data_duration < window_duration: + start_time = window_duration - data_duration + else: + start_time = 0.0 + + if n_time_points > 0: + time_values = np.linspace(start_time, window_duration, n_time_points) + else: + time_values = np.array([], dtype=np.float64) + + if values_per_point == 2 and len(time_values) > 0: + time_axis = np.repeat(time_values, 2) + else: + time_axis = time_values + + return DisplayGeometry( + time_axis=time_axis, + max_display_values=max_display_values, + n_time_points=n_time_points, + values_per_point=values_per_point, + decimation=decimation, + window_samples=window_samples, + ) + + def get_lag_metrics(self, producer_idx: int, consumer_idx: int) -> LagMetrics: + """Calculate lag between producer and consumer. + + Parameters + ---------- + producer_idx : int + Current write index from producer. + consumer_idx : int + Last read index from consumer. + + Returns + ------- + LagMetrics + Lag in samples and milliseconds. + """ + lag_samples = producer_idx - consumer_idx + if lag_samples < 0: + lag_samples = 0 + + storage_rate = self.acquisition_rate.storage_rate_hz + if storage_rate > 0: + lag_seconds = lag_samples / storage_rate + lag_ms = lag_seconds * 1000.0 + else: + lag_ms = 0.0 + + return LagMetrics(lag_samples, lag_ms) + + def format_lag(self, producer_idx: int, consumer_idx: int) -> str: + """Format lag for UI display. + + Parameters + ---------- + producer_idx : int + Current write index from producer. + consumer_idx : int + Last read index from consumer. + + Returns + ------- + str + Formatted lag string with two lines (e.g., "12ms
2.45MS"). + """ + metrics = self.get_lag_metrics(producer_idx, consumer_idx) + lag_samples_formatted = self._format_sample_count_with_unit(metrics.samples) + return f"{metrics.milliseconds:.0f}ms
{lag_samples_formatted}" + + def _format_sample_count_with_unit(self, count: int) -> str: + """Format sample count with appropriate SI unit prefix. + + Always uses MSamples (mega) as the primary unit since typical + lag values are in the megasample range. + + Parameters + ---------- + count : int + Sample count. + + Returns + ------- + str + Formatted count with unit (e.g., "2.45 MS", "850 kS"). + """ + if count >= 1_000_000: + val = count / 1_000_000 + return f"{val:.2f} MS".replace(".00 ", " ").replace(".0 ", " ") + if count >= 1_000: + return f"{count / 1_000:.1f} kS" + return f"{count} S" + + def samples_to_time(self, n_samples: StorageSampleCount) -> float: + """Convert sample count to duration. + + Parameters + ---------- + n_samples : StorageSampleCount + Number of storage samples. + + Returns + ------- + float + Duration in seconds. + """ + return self.acquisition_rate.samples_to_seconds(n_samples) + + def time_to_samples(self, duration_s: float) -> StorageSampleCount: + """Convert duration to sample count. + + Parameters + ---------- + duration_s : float + Duration in seconds. + + Returns + ------- + StorageSampleCount + Number of samples. + """ + return self.acquisition_rate.seconds_to_samples(duration_s) + + def get_display_duration( + self, n_display_samples: DisplayPointCount, decimation: int + ) -> float: + """Calculate the time duration represented by display samples. + + DEPRECATED: This method mixes concerns by trying to reverse-engineer + duration from display points. Use samples_to_seconds() with storage + samples instead. + + This accounts for both software min-max decimation (which produces + pairs of min/max values) and hardware AGGREGATE mode (which stores + min/max pairs in the ring buffer). + + Parameters + ---------- + n_display_samples : DisplayPointCount + Number of display points (after software decimation). + decimation : int + Software decimation factor used. + + Returns + ------- + float + Duration in seconds that the data represents. + """ + return self.acquisition_rate.get_display_duration(n_display_samples, decimation) + + def get_max_cache_points(self, window_samples: int, decimation: int) -> int: + """Calculate maximum display values for cache. + + Delegates to ``get_window_display_geometry()`` so cache sizing + always matches time axis length exactly. + + Parameters + ---------- + window_samples : int + Number of samples in the display window. + decimation : int + Decimation factor. + + Returns + ------- + int + Maximum number of display values to keep in cache per channel. + """ + window_duration = self.samples_to_time(window_samples) + geometry = self.get_window_display_geometry(window_duration, decimation) + return geometry.max_display_values + + def get_max_time_points(self, window_samples: int, decimation: int) -> int: + """Calculate maximum time points for time axis creation. + + Parameters + ---------- + window_samples : int + Number of samples in the display window. + decimation : int + Decimation factor. + + Returns + ------- + int + Maximum number of time points (not display values). + """ + return window_samples // decimation + + def is_aggregate_mode(self) -> bool: + """Check if current mode is AGGREGATE. + + Returns + ------- + bool + True if mode is AGGREGATE. + """ + return self.acquisition_rate.downsample_mode == DownsampleMode.AGGREGATE + + def calculate_display_points( + self, n_input_samples: int, channel: int, decimation: int + ) -> int: + """Calculate the number of display points that will be produced. + + Accounts for current remainder state and min-max pair expansion + when software decimation is active. + + Parameters + ---------- + n_input_samples : int + Number of input samples to process. + channel : int + Channel index (remainder is per-channel). + decimation : int + Decimation factor. + + Returns + ------- + int + Number of display points that will be produced. + """ + if n_input_samples <= 0 or decimation <= 0: + return 0 + + remainder = self._remainder.get(channel) + remainder_len = len(remainder) if remainder is not None else 0 + total_samples = remainder_len + n_input_samples + + n_groups = total_samples // decimation + n_points = n_groups * 2 # min-max pairs + + return n_points diff --git a/picostream/device.py b/picostream/device.py new file mode 100644 index 0000000..3681f79 --- /dev/null +++ b/picostream/device.py @@ -0,0 +1,2794 @@ +import ctypes +import os +import threading +import time +from datetime import datetime, timezone +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +from loguru import logger + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode + +# Lazy import of PicoSDK +try: + from picosdk.functions import PICO_STATUS, adc2mV, assert_pico_ok + from picosdk.ps5000a import ps5000a as ps # type: ignore +except ImportError: + ps = None + adc2mV = None + assert_pico_ok = None + PICO_STATUS = None + logger.warning("PicoSDK not available - Picoscope functionality will be limited") + +try: + import h5py # type: ignore +except ImportError: + h5py = None + logger.warning("h5py not available - PicoscopeBufferedStream will be limited") + +from labdaemon.exceptions import ( + DeviceConfigurationError, + DeviceConnectionError, + DeviceNotConnectedError, + DeviceOperationError, + TaskCancelledError, +) + + +class Picoscope: + """ + Device plugin for Picoscope 5000a series digitizers. + + This is a synchronous implementation for the LabDaemon framework, providing + blocking data acquisition functionality. + """ + + # Valid voltage ranges for 5000a series + VOLTAGE_RANGES = { + 0.01: "PS5000A_10MV", + 0.02: "PS5000A_20MV", + 0.05: "PS5000A_50MV", + 0.1: "PS5000A_100MV", + 0.2: "PS5000A_200MV", + 0.5: "PS5000A_500MV", + 1.0: "PS5000A_1V", + 2.0: "PS5000A_2V", + 5.0: "PS5000A_5V", + 10.0: "PS5000A_10V", + 20.0: "PS5000A_20V", + } + + # Hardware limits for 5000a series + MAX_SAMPLE_RATE_HZ = 1_000_000_000 # 1 GS/s for single channel + MAX_DOWNSAMPLE_RATIO = 2**32 - 1 # uint32 limit in PicoSDK + SUPPORTED_DOWNSAMPLE_MODES = {"NONE", "AVERAGE", "DECIMATE", "AGGREGATE"} + + def __init__( + self, + device_id: str, + serial_code: Optional[str] = None, + resolution: int = 12, + **kwargs: Any, + ) -> None: + """Initialise the Picoscope device.""" + self.device_id = device_id + self.serial_code = serial_code + self.resolution = resolution + + if ps is None: + raise ImportError( + "PicoSDK not found. Please install it to use the Picoscope plugin." + ) + self._ps = ps + + if resolution not in {8, 12, 14, 15, 16}: + raise DeviceConfigurationError( + f"Invalid resolution: {resolution}. Must be one of: 8, 12, 14, 15, 16" + ) + + self._chandle: Optional[ctypes.c_int16] = None + self._is_connected: bool = False + self._enabled_channels: List[int] = [] + self._channel_ranges: Dict[int, str] = {} + self._channel_voltage_ranges: Dict[int, float] = {} + self._buffers: Dict[int, np.ndarray] = {} + self._sample_interval_s: Optional[float] = None + self._max_adc_value: Optional[ctypes.c_int16] = None + self._current_acquisition_params: Dict[str, Any] = {} + self._downsample_ratio: int = 1 + self._downsample_mode: str = "NONE" + + logger.info("Picoscope plugin initialised for {}", self.device_id) + + def get_connection_id(self) -> Optional[str]: + """ + Get the Picoscope serial code. + + Returns + ------- + Optional[str] + The serial code, or None if auto-detect is used. + """ + return self.serial_code + + def set_connection_id(self, connection_id: Optional[str]) -> None: + """ + Set the Picoscope serial code for connection. + + Pass None or empty string to use auto-detect. + + Parameters + ---------- + connection_id : Optional[str] + Serial code (e.g., "AW123/456"), or None/empty for auto-detect. + + Raises + ------ + ValueError + If connection_id is not a string or None. + """ + if connection_id is None: + self.serial_code = None + elif isinstance(connection_id, str): + self.serial_code = connection_id if connection_id else None + else: + raise ValueError("connection_id must be a string or None") + logger.info( + "Picoscope {} connection_id set to: {}", + self.device_id, + self.serial_code or "auto-detect", + ) + + def connect(self) -> None: + """Establishes a connection to the Picoscope device.""" + if self._is_connected: + logger.info("Device {} is already connected.", self.device_id) + return + + self._chandle = ctypes.c_int16() + res_enum = self._ps.PS5000A_DEVICE_RESOLUTION[ # pyright: ignore[reportAttributeAccessIssue] + f"PS5000A_DR_{self.resolution}BIT" + ] + serial_number = self.serial_code.encode("utf-8") if self.serial_code else None + + try: + status = self._ps.ps5000aOpenUnit( + ctypes.byref(self._chandle), serial_number, res_enum + ) + self._assert_pico_ok(status) + self._is_connected = True + + # Get max ADC value from driver + # Note: PicoScope 5000a returns the SAME max_adc for all resolutions + # because it uses a 16-bit ADC internally. Higher resolutions just use + # more effective bits. We store both resolution and max_adc for clarity. + self._max_adc_value = ctypes.c_int16() + status = self._ps.ps5000aMaximumValue( + self._chandle, ctypes.byref(self._max_adc_value) + ) + self._assert_pico_ok(status) + + # Validate: max_adc should match what the hardware actually uses + # For PicoScope 5000a, this is typically 32767 (15-bit signed) regardless of resolution + expected_max_for_res = (1 << (self.resolution - 1)) - 1 + logger.info( + "Connected to Picoscope {}: {}-bit resolution, " + "driver max_adc={}, theoretical max for resolution={}", + self.device_id, + self.resolution, + self._max_adc_value.value, + expected_max_for_res, + ) + + if self._max_adc_value.value != expected_max_for_res: + logger.warning( + "max_adc mismatch: driver returns {} but {}-bit resolution " + "theoretical max is {}. Using driver value for conversion.", + self._max_adc_value.value, + self.resolution, + expected_max_for_res, + ) + except Exception as e: + self._is_connected = False + raise DeviceConnectionError( + f"Failed to connect to Picoscope device {self.device_id}: {e}" + ) from e + + def disconnect(self) -> None: + """Closes the connection to the Picoscope device. + + Cleanup order is important: this method should only be called after + any streaming or acquisition has been stopped. The typical sequence is: + 1. Stop streaming/acquisition (streaming thread) + 2. Stop hardware (ps5000aStop) + 3. Close unit (ps5000aCloseUnit) - this method + + Reversing this order can cause deadlocks or handle leaks. + """ + if not self._is_connected or self._chandle is None: + logger.info("Device {} is already disconnected.", self.device_id) + return + + try: + status = self._ps.ps5000aCloseUnit(self._chandle) + self._assert_pico_ok(status) + except Exception as e: + logger.exception( + "Error disconnecting from Picoscope device {}: {}", self.device_id, e + ) + finally: + self._chandle = None + self._is_connected = False + logger.info("Disconnected from Picoscope device {}", self.device_id) + + def acquire_block( + self, + sample_rate: float, + num_samples: int, + channels: List[int], + voltage_ranges: List[float], + trigger_source: Optional[str] = None, + trigger_threshold_v: float = 0.5, + trigger_direction: str = "RISING", + trigger_delay_s: float = 0.0, + timeout_s: float = 10.0, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + **kwargs: Any, + ) -> Dict[str, Any]: + """ + Acquires a single block of data from the Picoscope. + + This is a convenience method that arms the acquisition and waits for it + to complete. For more control, use `arm_acquisition` and + `wait_for_acquisition` separately. + + Can be configured to wait for a hardware trigger before starting. + """ + if not self._is_connected or self._chandle is None: + raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.") + + self.arm_acquisition( + sample_rate=sample_rate, + num_samples=num_samples, + channels=channels, + voltage_ranges=voltage_ranges, + trigger_source=trigger_source, + trigger_threshold_v=trigger_threshold_v, + trigger_direction=trigger_direction, + trigger_delay_s=trigger_delay_s, + downsample_ratio=downsample_ratio, + downsample_mode=downsample_mode, + **kwargs, + ) + return self.wait_for_acquisition(timeout_s=timeout_s) + + def arm_acquisition( + self, + sample_rate: float, + num_samples: int, + channels: List[int], + voltage_ranges: List[float], + trigger_source: Optional[str] = None, + trigger_threshold_v: float = 0.5, + trigger_direction: str = "RISING", + trigger_delay_s: float = 0.0, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + **kwargs: Any, + ) -> None: + """ + Arms the Picoscope for a block mode acquisition. + + This configures the channels, timebase, and trigger, then sets up the + data buffers and starts the acquisition. The device will then wait for + a trigger (if configured) or start capturing immediately. This method + returns immediately without waiting for the capture to finish. + + Note + ---- + The num_samples parameter is interpreted as the desired acquisition + duration, relative to the requested sample rate. If the hardware + cannot achieve the requested sample rate, num_samples will be adjusted + to keep the same acquisition duration at the actual achievable rate. + """ + if not self._is_connected or self._chandle is None: + raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.") + if not channels: + raise DeviceConfigurationError("No channels specified for acquisition.") + if len(channels) != len(voltage_ranges): + raise DeviceConfigurationError( + "Length of 'channels' and 'voltage_ranges' must match." + ) + + logger.debug( + "Arming block acquisition on device {} at requested {} Hz for {} samples on channels {}.{}", + self.device_id, + sample_rate, + num_samples, + channels, + f" Waiting for {trigger_source} trigger." if trigger_source else "", + ) + + # Reset device to a clean state before configuring block mode + self._reset_device_state() + + try: + # 1. Configure channels and store downsampling parameters + for ch, v_range in zip(channels, voltage_ranges, strict=True): + self._configure_channel(ch, True, "DC", v_range) + self._store_downsampling_params(downsample_ratio, downsample_mode) + + # 2. Probe actual achievable sample rate + desired_interval = 1.0 / sample_rate + timebase, actual_interval = self._get_timebase( + desired_interval, num_samples + ) + actual_sample_rate = 1.0 / actual_interval + self._sample_interval_s = actual_interval + + # 3. Adjust num_samples to keep the same acquisition duration at the actual rate + # Calculate implied duration from requested parameters + requested_duration = num_samples * desired_interval + adjusted_num_samples = int(round(requested_duration / actual_interval)) + + rate_mismatch_pct = ( + abs(actual_sample_rate - sample_rate) / sample_rate * 100 + ) + if rate_mismatch_pct > 1.0: + logger.warning( + "Sample rate mismatch on device {}: requested {:.0f} Hz, achieved {:.0f} Hz ({:.1f}% difference). Adjusted num_samples: {} -> {} to maintain {:.2f}s duration.", + self.device_id, + sample_rate, + actual_sample_rate, + rate_mismatch_pct, + num_samples, + adjusted_num_samples, + requested_duration, + ) + + num_samples = adjusted_num_samples + logger.debug( + "[PICO] Timebase result: timebase_idx={}, actual_interval={:.3e} s, actual_rate={:.0f} Hz", + timebase, + actual_interval, + actual_sample_rate, + ) + + # 4. Configure trigger + if trigger_source: + self._configure_trigger( + source=trigger_source, + threshold_v=trigger_threshold_v, + direction=trigger_direction, + delay_s=trigger_delay_s, + ) + else: + # Explicitly disable trigger if not requested + self._configure_trigger(source=None) + + # 5. Set up buffers + output_samples = num_samples // downsample_ratio + if output_samples < 1: + raise DeviceConfigurationError( + f"Calculated output samples is {output_samples}. This must be at least 1. " + "This can happen if `num_samples` is smaller than `downsample_ratio`." + ) + self._setup_block_buffers(channels, output_samples) + + # 6. Store acquisition parameters for later use + self._current_acquisition_params = { + "num_samples": output_samples, + "channels": channels, + "requested_sample_rate": sample_rate, + "actual_sample_rate": actual_sample_rate, + "downsample_ratio": downsample_ratio, + "downsample_mode": downsample_mode, + } + + # 7. Run acquisition + self._run_block(num_samples, timebase) + time.sleep(0.5) # ensure ready for trigger? + + except Exception as e: + logger.exception("Error arming acquisition on device {}", self.device_id) + raise DeviceOperationError( + f"Error arming acquisition on device {self.device_id}: {e}" + ) from e + + def wait_for_acquisition( + self, timeout_s: float = 10.0, cancel_event: Optional[threading.Event] = None + ) -> Dict[str, Any]: + """ + Waits for an armed acquisition to complete and retrieves the data. + + Parameters + ---------- + timeout_s : float, default 10.0 + Maximum time to wait for acquisition completion. + cancel_event : Optional[threading.Event], default None + If provided, the acquisition will be cancelled if this event is set. + + Returns + ------- + Dict[str, Any] + Dictionary containing acquired data and metadata. + + Raises + ------ + TimeoutError + If the acquisition does not complete within the timeout. + TaskCancelledError + If the cancel_event is set during acquisition. + + Notes + ----- + When cancel_event is set, this method will call cancel_acquisition() + and raise TaskCancelledError. This allows tasks to pass their + context.cancel_event to enable cooperative cancellation during + long acquisitions. + """ + if not self._is_connected or self._chandle is None: + raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.") + if not self._current_acquisition_params: + raise DeviceOperationError("No acquisition is currently armed.") + + try: + # 1. Wait for completion + self._wait_for_capture(timeout_s=timeout_s, cancel_event=cancel_event) + + # 2. Retrieve and convert data + time_data, voltage_data, raw_adc_data = self._get_block_data() + + # 2b. Ensure block acquisition is fully stopped and buffers are de-registered + try: + status = self._ps.ps5000aStop(self._chandle) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + "ps5000aStop returned status {} after block retrieval on device {}.", + status, + self.device_id, + ) + except Exception: + logger.exception("Error stopping block acquisition after retrieval") + + # De-register block-mode buffers using the same pointers and size=0 + try: + for ch in list(self._enabled_channels): + buf = self._buffers.get(ch) + if buf is None: + continue + status = self._ps.ps5000aSetDataBuffer( + self._chandle, + self._get_channel_enum(ch), + buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + 0, # size=0 de-registers this buffer for block mode + 0, # segmentIndex + self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"], + ) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + "ps5000aSetDataBuffer (deregister) status {} for channel {} on device {}.", + status, + ch, + self.device_id, + ) + except Exception: + logger.exception("Error de-registering block buffers after retrieval") + + # 3. Format and return results + params = self._current_acquisition_params + if self._max_adc_value is None: + raise DeviceOperationError("max_adc_value not available for metadata.") + + downsample_ratio = params.get("downsample_ratio", 1) + result = { + "data_v": voltage_data, + "data_adc": raw_adc_data, + "metadata": { + "sample_rate": params["actual_sample_rate"] / downsample_ratio, + "num_samples": len(time_data), + "channels": params["channels"], + "channel_voltage_ranges": [ + self._channel_voltage_ranges[ch] for ch in params["channels"] + ], + "max_adc": self._max_adc_value.value, + "time_data": time_data, + "requested_sample_rate": params["requested_sample_rate"], + "device_id": self.device_id, + "downsample_ratio": downsample_ratio, + "downsample_mode": params.get("downsample_mode", "NONE"), + }, + } + # Clear params after successful retrieval + self._current_acquisition_params = {} + return result + except TimeoutError: + # This is an expected outcome on timeout. Log and re-raise. + logger.warning("Acquisition timed out on device {}", self.device_id) + self._current_acquisition_params = {} + raise # Re-raise TimeoutError for the task to handle + except TaskCancelledError: + # This is an expected outcome, not an error. Log and re-raise. + logger.warning("Acquisition cancelled on device {}", self.device_id) + self._current_acquisition_params = {} + raise + except Exception as e: + logger.exception( + f"Error waiting for acquisition on device {self.device_id}" + ) + # Clear params on failure + self._current_acquisition_params = {} + raise DeviceOperationError( + f"Error waiting for acquisition on device {self.device_id}: {e}" + ) from e + + def _cancel_acquisition_unlocked(self) -> None: + """Stops a running acquisition. Assumes no lock is held.""" + if not self._is_connected or self._chandle is None: + logger.warning( + "Cannot cancel acquisition, device {} is not connected.", self.device_id + ) + return + + try: + status = self._ps.ps5000aStop(self._chandle) + # Don't assert OK on stop, as it can fail if no acquisition is running. + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + f"ps5000aStop returned status {status} on device {self.device_id}. This is often normal." + ) + logger.info("Acquisition stop command sent to device {}", self.device_id) + except Exception as e: + logger.exception( + f"Error sending stop command on device {self.device_id}: {e}" + ) + + def cancel_acquisition(self) -> None: + """ + Stops a running acquisition on the Picoscope. This method is thread-safe. + """ + # This public method is wrapped by the device lock. + self._cancel_acquisition_unlocked() + + def disable_trigger(self) -> None: + """ + Disables the trigger on the Picoscope. + + This is useful to ensure the device is in a known state after a + triggered acquisition. + """ + if not self._is_connected or self._chandle is None: + logger.warning( + "Cannot disable trigger, device {} is not connected.", self.device_id + ) + return + try: + self._configure_trigger(source=None) + except Exception as e: + logger.exception( + f"Failed to disable trigger on device {self.device_id}: {e}" + ) + # Do not re-raise, as this is often for cleanup. + + def reset_device_state(self) -> None: + """Public method to reset the device to a clean, idle state.""" + self._reset_device_state() + + def _reset_device_state(self) -> None: + """ + Resets the device to a clean, idle state. + + This method is intended to be called before any new acquisition + (block or streaming) to prevent state pollution from previous + operations. It clears all buffers, resets channel configurations, + and disables the hardware trigger. + """ + logger.debug("Resetting device state for {}", self.device_id) + # 0. Stop any running acquisition. This is critical to return the hardware + # to a known idle state before re-configuration. + if self._is_connected and self._chandle is not None: + try: + status = self._ps.ps5000aStop(self._chandle) + if status != PICO_STATUS["PICO_OK"]: + # This is not an error. It's normal for ps5000aStop to return + # non-OK status if the device is already idle. + logger.debug( + f"ps5000aStop returned status {status} during reset. This is normal." + ) + except Exception: + # Log but do not re-raise. We want to continue the reset process. + logger.exception( + f"Error calling ps5000aStop during reset on device {self.device_id}" + ) + + # 1. Disable trigger to ensure device is not waiting for an event + self.disable_trigger() + + # 2. Explicitly disable all channels in the SDK to clear any previous configuration + if self._is_connected and self._chandle is not None: + for ch in range(4): # Channels A, B, C, D (0-3) + try: + channel_enum = self._get_channel_enum(ch) + status = self._ps.ps5000aSetChannel( + self._chandle, + channel_enum, + 0, + 0, + 0, + 0.0, # disabled + ) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + f"Failed to disable channel {ch} during reset: status {status}" + ) + except Exception: + logger.exception( + f"Error disabling channel {ch} during reset on device {self.device_id}" + ) + + # 3. De-register block-mode buffers using the original pointers (size=0) + if self._is_connected and self._chandle is not None: + for ch in range(4): # Channels A, B, C, D (0-3) + try: + buf = self._buffers.get(ch) + if buf is None: + continue # Only de-register buffers we actually set + channel_enum = self._get_channel_enum(ch) + status = self._ps.ps5000aSetDataBuffer( + self._chandle, + channel_enum, + buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + 0, # size=0 de-registers this buffer for block mode + 0, # segmentIndex + self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"], + ) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + f"Failed to de-register block buffer for channel {ch} during reset: status {status}" + ) + except Exception: + logger.exception( + f"Error de-registering block buffer for channel {ch} during reset on device {self.device_id}" + ) + + # 4. Normalise segmentation/rapid-capture state to avoid lingering block configuration + if self._is_connected and self._chandle is not None: + try: + max_samples = ctypes.c_int32() + status = self._ps.ps5000aMemorySegments( + self._chandle, 1, ctypes.byref(max_samples) + ) + self._assert_pico_ok(status) + status = self._ps.ps5000aSetNoOfCaptures(self._chandle, 1) + self._assert_pico_ok(status) + except Exception: + logger.exception( + f"Error normalising segmentation state during reset on device {self.device_id}" + ) + + # 5. Clear software buffers and channel configurations + self._buffers.clear() + self._enabled_channels.clear() + self._channel_ranges.clear() + self._channel_voltage_ranges.clear() + + # 6. Reset acquisition parameters + self._sample_interval_s = None + self._downsample_ratio = 1 + self._downsample_mode = "NONE" + self._current_acquisition_params.clear() + + def _store_downsampling_params( + self, + downsample_ratio: int, + downsample_mode: str, + ) -> None: + """Validates and stores downsampling parameters for later use.""" + # Validate downsample ratio + if downsample_ratio < 1: + raise DeviceConfigurationError( + f"Downsample ratio must be >= 1, got {downsample_ratio}" + ) + if downsample_ratio > self.MAX_DOWNSAMPLE_RATIO: + raise DeviceConfigurationError( + f"Downsample ratio {downsample_ratio} exceeds hardware limit of {self.MAX_DOWNSAMPLE_RATIO}" + ) + + # Validate downsample mode + mode_upper = downsample_mode.upper() + if mode_upper not in self.SUPPORTED_DOWNSAMPLE_MODES: + raise DeviceConfigurationError( + f"Unsupported downsample mode: {downsample_mode}. " + f"Supported modes: {', '.join(self.SUPPORTED_DOWNSAMPLE_MODES)}" + ) + + # Store parameters for use in SDK calls + self._downsample_ratio = downsample_ratio + self._downsample_mode = mode_upper + + logger.debug( + "Stored downsampling parameters for device {}: ratio={}, mode={}", + self.device_id, + downsample_ratio, + downsample_mode, + ) + + def _configure_trigger( + self, + source: Optional[str], + threshold_v: float = 0.5, + direction: str = "RISING", + delay_s: float = 0.0, + ) -> None: + """Configures the trigger for the acquisition.""" + if source is None: + # Disable trigger + status = self._ps.ps5000aSetSimpleTrigger(self._chandle, 0, 0, 0, 0, 0, 0) + self._assert_pico_ok(status) + logger.debug("Trigger disabled for device {}", self.device_id) + return + + source_upper = source.upper() + if source_upper != "EXT": + raise DeviceConfigurationError( + f"Unsupported trigger source: {source}. Only 'EXT' is supported." + ) + + source_enum = self._ps.PS5000A_CHANNEL["PS5000A_EXTERNAL"] + threshold_adc = self._v_to_adc_trigger(threshold_v) + + valid_directions = {"RISING", "FALLING"} + direction_upper = direction.upper() + if direction_upper not in valid_directions: + raise DeviceConfigurationError( + f"Invalid trigger direction: {direction}. Supported: RISING, FALLING" + ) + direction_enum = self._ps.PS5000A_THRESHOLD_DIRECTION[ + f"PS5000A_{direction_upper}" + ] + + if self._sample_interval_s is None: + raise DeviceOperationError( + "Sample interval must be set before configuring trigger delay." + ) + delay_samples = int(delay_s / self._sample_interval_s) + + status = self._ps.ps5000aSetSimpleTrigger( + self._chandle, + 1, # enabled + source_enum, + threshold_adc, + direction_enum, + delay_samples, + 0, # autoTrigger_ms=0 -> wait indefinitely + ) + self._assert_pico_ok(status) + logger.debug( + "Configured trigger on {} at {}V ({}) for device {}", + source, + threshold_v, + direction, + self.device_id, + ) + + def _v_to_adc_trigger(self, voltage_v: float) -> int: + """Converts a voltage to ADC counts for the external trigger input.""" + if not -5.0 <= voltage_v <= 5.0: + raise DeviceConfigurationError( + "External trigger threshold must be between -5V and +5V." + ) + if self._max_adc_value is None: + raise DeviceOperationError( + "Cannot convert voltage to ADC counts, max_adc_value not set." + ) + + # External trigger range is fixed at +/- 5V + trigger_range_mv = 5000.0 + + # Convert voltage to millivolts + voltage_mv = voltage_v * 1000.0 + + # Calculate ADC value + adc_value = int((voltage_mv / trigger_range_mv) * self._max_adc_value.value) + return adc_value + + def _configure_channel( + self, + channel_id: int, + enabled: bool, + coupling: str, + voltage_range: float, + offset: float = 0.0, + ) -> None: + """Configures a single input channel.""" + if not 0 <= channel_id <= 3: + raise DeviceConfigurationError( + f"Invalid channel ID: {channel_id}. Must be between 0 and 3." + ) + if voltage_range not in self.VOLTAGE_RANGES: + raise DeviceConfigurationError( + f"Invalid voltage range: {voltage_range}V. " + f"Valid ranges: {sorted(self.VOLTAGE_RANGES.keys())}V" + ) + range_str = self.VOLTAGE_RANGES[voltage_range] + + coupling_enum = self._ps.PS5000A_COUPLING[f"PS5000A_{coupling.upper()}"] + channel_enum = self._get_channel_enum(channel_id) + range_enum = self._ps.PS5000A_RANGE[range_str] + + status = self._ps.ps5000aSetChannel( + self._chandle, channel_enum, int(enabled), coupling_enum, range_enum, offset + ) + self._assert_pico_ok(status) + + if enabled: + if channel_id not in self._enabled_channels: + self._enabled_channels.append(channel_id) + self._channel_ranges[channel_id] = range_str + self._channel_voltage_ranges[channel_id] = voltage_range + else: + if channel_id in self._enabled_channels: + self._enabled_channels.remove(channel_id) + if channel_id in self._channel_ranges: + del self._channel_ranges[channel_id] + if channel_id in self._channel_voltage_ranges: + del self._channel_voltage_ranges[channel_id] + + def _get_timebase( + self, desired_interval: float, num_samples: int + ) -> Tuple[int, float]: + """ + Calculate the correct timebase for the desired sample interval. + + Uses the appropriate formula based on the device's resolution to calculate + the timebase index, then validates it with the PicoSDK. + """ + import math + + # Determine max sample rate and timebase formula based on resolution + if self.resolution == 8: + max_sample_rate = 1_000_000_000.0 + # For 8-bit: interval = (timebase + 1) * 1ns for timebase 0-4, then (timebase - 4) * 2ns + if desired_interval <= 5e-9: + timebase_idx = math.ceil(desired_interval / 1e-9) - 1 + else: + timebase_idx = math.ceil(desired_interval / 2e-9) + 4 + elif self.resolution == 12: + max_sample_rate = 500_000_000.0 + # For 12-bit: interval = (timebase + 1) * 2ns for timebase 0-2, then (timebase - 2) * 4ns + if desired_interval <= 6e-9: + timebase_idx = math.ceil(desired_interval / 2e-9) - 1 + else: + timebase_idx = math.ceil(desired_interval / 4e-9) + 2 + elif self.resolution == 14: + max_sample_rate = 125_000_000.0 + # For 14-bit: interval = (timebase + 1) * 8ns + timebase_idx = math.ceil(desired_interval / 8e-9) - 1 + else: # 15 and 16 bit + max_sample_rate = 62_500_000.0 + # For 15/16-bit: interval = (timebase + 1) * 16ns + timebase_idx = math.ceil(desired_interval / 16e-9) - 1 + + # Validate that the desired sample rate is within hardware limits for the resolution + desired_sample_rate = 1.0 / desired_interval + if desired_sample_rate > max_sample_rate: + raise DeviceConfigurationError( + f"Requested sample rate {desired_sample_rate / 1e6:.1f} MS/s exceeds " + f"hardware limit of {max_sample_rate / 1e6:.1f} MS/s for {self.resolution}-bit resolution." + ) + + # Ensure timebase is not negative + if timebase_idx < 0: + timebase_idx = 0 + + # Verify the calculated timebase with the SDK and get the actual interval. + # The SDK will error out if this timebase is invalid for the current mode. + timebase = ctypes.c_uint32(int(timebase_idx)) + interval = ctypes.c_float() + max_samples = ctypes.c_int32() + + status = self._ps.ps5000aGetTimebase2( + self._chandle, + timebase, + ctypes.c_int32(num_samples), + ctypes.byref(interval), + ctypes.byref(max_samples), + ctypes.c_uint32(0), + ) + self._assert_pico_ok(status) + + actual_interval = interval.value * 1e-9 + logger.debug( + "Calculated timebase {} for {}-bit resolution: desired interval {:.3e}s -> actual {:.3e}s", + timebase_idx, + self.resolution, + desired_interval, + actual_interval, + ) + return int(timebase_idx), actual_interval + + def _setup_block_buffers(self, channels: List[int], num_samples: int) -> None: + """Set up data buffers for block mode capture.""" + for ch in channels: + self._buffers[ch] = np.zeros(shape=num_samples, dtype=np.int16) + status = self._ps.ps5000aSetDataBuffer( + self._chandle, + self._get_channel_enum(ch), + self._buffers[ch].ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + num_samples, + 0, + self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"], + ) + self._assert_pico_ok(status) + + def _run_block(self, num_samples: int, timebase: int) -> None: + """Initiates a single block acquisition.""" + time_indisposed_ms = ctypes.c_int32() + status = self._ps.ps5000aRunBlock( + self._chandle, + 0, # preTriggerSamples + num_samples, + ctypes.c_uint32(timebase), + ctypes.byref(time_indisposed_ms), + 0, + None, + None, + ) + self._assert_pico_ok(status) + + def _wait_for_capture( + self, timeout_s: float = 10.0, cancel_event: Optional[threading.Event] = None + ) -> None: + """Waits for the block capture to complete.""" + ready = ctypes.c_int16(0) + start_time = time.time() + while ready.value == 0: + if cancel_event and cancel_event.is_set(): + self.cancel_acquisition() + raise TaskCancelledError("Acquisition cancelled by user.") + if time.time() - start_time > timeout_s: + self.cancel_acquisition() + raise TimeoutError( + f"Block capture timed out after {timeout_s} seconds." + ) + status = self._ps.ps5000aIsReady(self._chandle, ctypes.byref(ready)) + self._assert_pico_ok(status) + time.sleep(0.01) + + def _get_block_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Retrieves the captured block data and converts it to voltage.""" + overflow = ctypes.c_int16() + buffer_len = len(self._buffers[self._enabled_channels[0]]) + samples_retrieved = ctypes.c_uint32(buffer_len) + + # Apply downsampling in the GetValues call + downsample_ratio_val = ctypes.c_uint32(self._downsample_ratio) + downsample_mode = self._ps.PS5000A_RATIO_MODE[ + f"PS5000A_RATIO_MODE_{self._downsample_mode}" + ] + + status = self._ps.ps5000aGetValues( + self._chandle, + 0, # startIndex + ctypes.byref(samples_retrieved), # noOfSamples (in/out) + downsample_ratio_val, + downsample_mode, + 0, # segmentIndex + ctypes.byref(overflow), + ) + self._assert_pico_ok(status) + + actual_num_samples = samples_retrieved.value + logger.debug( + "Retrieved {} samples from a buffer of size {}.", + actual_num_samples, + buffer_len, + ) + + raw_adc_data = np.array( + [self._buffers[ch][:actual_num_samples] for ch in self._enabled_channels], + dtype=np.int16, + ) + + # Convert to voltage and check for clipping + voltages_list = [] + if self._max_adc_value is None: + raise DeviceOperationError( + "Cannot convert ADC, max_adc_value not available." + ) + max_adc = self._max_adc_value.value + + for i, ch in enumerate(self._enabled_channels): + adc_buffer = raw_adc_data[i] + voltage_range = self._channel_voltage_ranges[ch] + + # Check for clipping (analogue saturation) + clipped_samples = np.sum((adc_buffer >= max_adc) | (adc_buffer <= -max_adc)) + if clipped_samples > 0: + total_samples = len(adc_buffer) + clip_percent = (clipped_samples / total_samples) * 100 + logger.warning( + "Picoscope channel {} may be clipping. {} of {} samples ({:.2f}%) are at the ADC limit. Consider using a larger voltage range than {}V.", + ch, + clipped_samples, + total_samples, + clip_percent, + voltage_range, + ) + + # Perform vectorized scaling to volts + voltages = (adc_buffer.astype(np.float32) / max_adc) * voltage_range + voltages_list.append(voltages) + + logger.debug( + "Converted ADC to voltage for channel {}: range={}V, ADC min/max={}/{}, max_adc={}, V min/max={:.6f}/{:.6f}", + ch, + voltage_range, + adc_buffer.min(), + adc_buffer.max(), + max_adc, + voltages.min(), + voltages.max(), + ) + + voltage_data = np.array(voltages_list) + + if self._sample_interval_s is None: + raise DeviceOperationError( + "Sample interval not set for block data retrieval." + ) + + downsample_ratio = self._current_acquisition_params.get("downsample_ratio", 1) + effective_interval = self._sample_interval_s * downsample_ratio + + time_data = np.linspace( + 0, (actual_num_samples - 1) * effective_interval, actual_num_samples + ) + return time_data, voltage_data, raw_adc_data + + def _get_channel_enum(self, ch: int) -> int: + """Converts a channel index (0-3) to the corresponding PicoSDK channel enum.""" + return self._ps.PS5000A_CHANNEL[f"PS5000A_CHANNEL_{chr(65 + ch)}"] + + def _assert_pico_ok(self, status: int) -> None: + """Check if a PicoScope status code indicates success.""" + try: + if assert_pico_ok is not None: + assert_pico_ok(status) # type: ignore + except Exception as e: + # Provide more context for common errors + if status == 2: # PICO_NOT_FOUND + raise DeviceConnectionError( + f"Picoscope device not found. Status: {status}" + ) from e + elif status in {286, 282}: # PICO_POWER_SUPPLY_... + logger.warning( + f"Picoscope device {self.device_id} has a power issue (status: {status}). " + "Consider connecting external power or using a USB3.0 port." + ) + return + else: + raise DeviceOperationError( + f"PicoSDK operation failed with status {status}: {e}" + ) from e + + +class PicoscopeSimpleStream(Picoscope): + """ + Picoscope device with simple callback-based streaming. + + This class inherits all block acquisition functionality from Picoscope + and provides direct callback streaming suitable for live monitoring. + """ + + def __init__( + self, + device_id: str, + serial_code: Optional[str] = None, + resolution: int = 12, + **kwargs: Any, + ) -> None: + super().__init__(device_id, serial_code, resolution, **kwargs) + + self._streaming_thread: Optional[threading.Thread] = None + self._stop_streaming_event = threading.Event() + self._streaming_callback_ref: Optional[Callable] = None + self._streaming_voltage_range: Optional[float] = None + self._streaming_sample_rate: Optional[float] = None + self._streaming_channel: int = 0 + self._streaming_num_samples_per_block: int = 16384 + self._sdk_buffer: Optional[np.ndarray] = None + self._callbackFuncPtr: Optional[Any] = ( + None # Will be recreated per streaming session + ) + + # --- Streaming Performance Tracking --- + self._stream_stats_lock = threading.Lock() + self._stream_total_samples = 0 + self._stream_callback_count = 0 + self._stream_overflow_count = 0 + self._stream_last_summary_time = 0.0 + self._stream_start_count = 0 # Track how many times streaming has been started + + # --- Diagnostic Timing for Chunking Issues --- + self._last_callback_time: Optional[float] = None + self._callback_intervals: List[float] = [] + + def configure_streaming( + self, + sample_rate: float, + voltage_range: float, + channel: int = 0, + num_samples_per_block: int = 4096, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + ) -> None: + """ + Configures the parameters for hardware-timed streaming. + + This must be called before `start_streaming`. + + Args: + sample_rate: The desired sample rate in Hz. + voltage_range: The voltage range for the acquisition channel. + channel: The channel to stream from (default 0). + num_samples_per_block: The number of samples in each acquired block. + downsample_ratio: The hardware downsampling factor. + downsample_mode: The downsampling mode (e.g., 'AVERAGE'). + """ + if sample_rate <= 0: + raise DeviceConfigurationError("Sample rate must be positive.") + if num_samples_per_block <= 0: + raise DeviceConfigurationError( + "Number of samples per block must be positive." + ) + + # Reset device to a clean state before configuring streaming + self._reset_device_state() + + # When mode is NONE, force ratio to 1 for SDK compatibility + if downsample_mode.upper() == "NONE": + downsample_ratio = 1 + + # Validate and store downsampling parameters + self._store_downsampling_params(downsample_ratio, downsample_mode) + + self._streaming_sample_rate = sample_rate + self._streaming_voltage_range = voltage_range + self._streaming_channel = channel + self._streaming_num_samples_per_block = num_samples_per_block + logger.debug( + "Streaming configured for device {}: rate={}Hz, range={}V, downsample={}x ({})", + self.device_id, + sample_rate, + voltage_range, + downsample_ratio, + downsample_mode, + ) + + def start_streaming(self, callback: Callable) -> None: + """ + Starts a true hardware-timed stream using the PicoSDK's streaming mode. + + `configure_streaming` must be called before this method. + + Parameters + ---------- + callback : Callable + A function to call with new data. It will receive two arguments: + a (currently None) time array and a NumPy array of voltage data. + """ + if not self._is_connected or self._chandle is None: + raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.") + + if self._streaming_thread and self._streaming_thread.is_alive(): + logger.warning( + "Streaming is already active on device {}.", + self.device_id, + ) + return + + if self._streaming_sample_rate is None or self._streaming_voltage_range is None: + raise DeviceConfigurationError( + "Streaming parameters not configured. Call `configure_streaming` before `start_streaming`." + ) + + # Store callback and reset stop event + self._streaming_callback_ref = callback + self._stop_streaming_event.clear() + + # Reset performance counters and increment start count + with self._stream_stats_lock: + self._stream_total_samples = 0 + self._stream_callback_count = 0 + self._stream_overflow_count = 0 + self._stream_last_summary_time = time.time() + self._stream_start_count += 1 + + # Reset diagnostic timing + self._last_callback_time = None + self._callback_intervals.clear() + + # 1) Configure channel for streaming + channel = self._streaming_channel or 0 + self._configure_channel(channel, True, "DC", self._streaming_voltage_range) + + # 2) Recreate callback function pointer and allocate SDK buffer + self._callbackFuncPtr = self._ps.StreamingReadyType( + self._streaming_sdk_callback + ) + buffer_size = int(self._streaming_num_samples_per_block or 50000) + self._sdk_buffer = np.zeros(shape=buffer_size, dtype=np.int16) + downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[ + f"PS5000A_RATIO_MODE_{self._downsample_mode.upper()}" + ] + status = self._ps.ps5000aSetDataBuffers( + self._chandle, + self._get_channel_enum(channel), + self._sdk_buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + None, # bufferMin + buffer_size, + 0, # segmentIndex + downsample_mode_enum, + ) + self._assert_pico_ok(status) + + # 3) Determine streaming interval and unit + interval, unit_enum, actual_rate = self._get_streaming_interval( + self._streaming_sample_rate + ) + sample_interval = ctypes.c_int32(interval) + + # 4) Start hardware streaming + downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[ + f"PS5000A_RATIO_MODE_{self._downsample_mode}" + ] + status = self._ps.ps5000aRunStreaming( + self._chandle, + ctypes.byref(sample_interval), + unit_enum, + 0, # preTriggerSamples + 10_000_000, # maxPostTriggerSamples (large number for continuous) + 0, # autoStop = False + self._downsample_ratio, # downSampleRatio + downsample_mode_enum, + buffer_size, # overviewBufferSize + ) + self._assert_pico_ok(status) + + logger.info( + "Streaming started on device {} (session #{}). " + "Requested rate: {} Hz, Actual rate: {} Hz. " + "Buffer size: {} samples.", + self.device_id, + self._stream_start_count, + self._streaming_sample_rate, + actual_rate, + buffer_size, + ) + + # 5) Start a lightweight polling thread to trigger SDK callbacks + self._streaming_thread = threading.Thread( + target=self._streaming_polling_loop, + daemon=True, + name=f"picoscope-stream-{self.device_id}", + ) + self._streaming_thread.start() + + def stop_streaming(self, timeout: float = 2.0) -> None: + """Stops the background streaming thread and hardware stream.""" + if not self._streaming_thread or not self._streaming_thread.is_alive(): + logger.info("Streaming is not active on device {}.", self.device_id) + return + + logger.info("Stopping stream on device {}...", self.device_id) + self._stop_streaming_event.set() + + # Avoid self-join deadlock if called from the polling thread itself + if threading.current_thread() is self._streaming_thread: + logger.warning( + "stop_streaming called from within the streaming polling thread on device {}. " + "The thread cannot be joined, but the stop event has been set.", + self.device_id, + ) + return + + # Join the polling thread with a bound + self._streaming_thread.join(timeout) + if self._streaming_thread.is_alive(): + logger.warning( + "Streaming thread on device {} did not stop within {}s.", + self.device_id, + timeout, + ) + + # Stop the hardware stream regardless, swallowing errors + try: + if self._is_connected and self._chandle is not None: + status = self._ps.ps5000aStop(self._chandle) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + "ps5000aStop returned status {} on device {}.", + status, + self.device_id, + ) + except Exception: + logger.exception("Error stopping hardware stream during stop_streaming") + + # De-register streaming buffers using size 0 before clearing references + try: + if ( + self._is_connected + and self._chandle is not None + and self._sdk_buffer is not None + ): + channel_enum = self._get_channel_enum(self._streaming_channel or 0) + status = self._ps.ps5000aSetDataBuffers( + self._chandle, + channel_enum, + self._sdk_buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + None, # bufferMin unused in non-aggregate mode + 0, # size=0 de-registers streaming buffer + 0, # segmentIndex + self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"], + ) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + "ps5000aSetDataBuffers (deregister) returned status {} on device {}.", + status, + self.device_id, + ) + except Exception: + logger.exception( + "Error de-registering streaming buffers during stop_streaming" + ) + + # Clear state + self._streaming_thread = None + self._streaming_callback_ref = None + self._sdk_buffer = None + self._callbackFuncPtr = None + logger.info("Streaming stopped and cleaned up for device {}.", self.device_id) + + def is_streaming(self) -> bool: + """Returns True if the streaming thread is currently active.""" + return self._streaming_thread is not None and self._streaming_thread.is_alive() + + def disconnect(self) -> None: + """ + Closes the connection to the Picoscope device. + + Stops any active streaming before disconnecting. + """ + if self.is_streaming(): + self.stop_streaming() + super().disconnect() + + def _get_streaming_interval(self, sample_rate_hz: float) -> Tuple[int, int, float]: + """ + Calculate the integer interval and unit for streaming mode. + + Parameters + ---------- + sample_rate_hz : float + Desired sample rate in Hz. + + Returns + ------- + Tuple[int, int, float] + (interval, unit_enum, actual_rate_hz) + """ + if sample_rate_hz <= 0: + raise DeviceConfigurationError("Sample rate must be positive.") + + # Choose unit to maximise resolution of integer interval + if sample_rate_hz >= 1_000_000: + unit_str = "PS5000A_NS" + scale = 1_000_000_000 + elif sample_rate_hz >= 1_000: + unit_str = "PS5000A_US" + scale = 1_000_000 + else: + unit_str = "PS5000A_MS" + scale = 1_000 + + interval = max(1, int(round(scale / sample_rate_hz))) + unit_enum = self._ps.PS5000A_TIME_UNITS[unit_str] + actual_rate = scale / interval + return interval, unit_enum, actual_rate + + def _streaming_polling_loop(self) -> None: + """Poll the SDK to trigger data callbacks until stopped.""" + logger.debug( + "Streaming polling loop started for device {}.", + self.device_id, + ) + try: + while not self._stop_streaming_event.is_set(): + # --- Periodic Summary Logging --- + now = time.time() + with self._stream_stats_lock: + last_summary = self._stream_last_summary_time + + if now - last_summary > 5.0: + with self._stream_stats_lock: + elapsed = now - self._stream_last_summary_time + if elapsed > 0: + sample_rate = self._stream_total_samples / elapsed + callback_rate = self._stream_callback_count / elapsed + + # Add diagnostic info for chunking analysis + avg_interval = ( + np.mean(self._callback_intervals) + if self._callback_intervals + else 0 + ) + max_interval = ( + np.max(self._callback_intervals) + if self._callback_intervals + else 0 + ) + + logger.info( + "Streaming summary for {} (session #{}): " + "{} Samples/s, {} Callbacks/s, Overflows: {}, " + "Avg callback interval: {}ms, Max callback interval: {}ms", + self.device_id, + self._stream_start_count, + sample_rate, + callback_rate, + self._stream_overflow_count, + avg_interval * 1000, + max_interval * 1000, + ) + # Reset for next interval + self._stream_total_samples = 0 + self._stream_callback_count = 0 + self._stream_last_summary_time = now + self._callback_intervals.clear() + + # --- Poll for Data --- + status = self._ps.ps5000aGetStreamingLatestValues( + self._chandle, self._callbackFuncPtr, None + ) + # Expected statuses during normal operation + if status not in [ + PICO_STATUS["PICO_OK"], + PICO_STATUS["PICO_BUSY"], + PICO_STATUS["PICO_NO_SAMPLES_AVAILABLE"], + ]: + logger.warning( + "ps5000aGetStreamingLatestValues returned status {} on device {}.", + status, + self.device_id, + ) + + time.sleep(0.005) # Yield to other threads + except Exception: + logger.exception("Error in streaming polling loop") + finally: + # Best-effort stop of hardware when loop exits + try: + if self._is_connected and self._chandle is not None: + status = self._ps.ps5000aStop(self._chandle) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + f"ps5000aStop returned status {status} on device {self.device_id}." + ) + except Exception: + logger.exception("Error stopping hardware stream on polling loop exit") + logger.debug( + "Streaming polling loop finished for device {}.", + self.device_id, + ) + + def _streaming_sdk_callback( + self, + _handle: int, + noOfSamples: int, + startIndex: int, + overflow: int, + _triggerAt: int, + _triggered: int, + _autoStop: int, + _param: int, + ) -> None: + """ + SDK callback executed when new streaming data is available. + + Notes + ----- + - Called from a PicoSDK-managed thread. + - Must be fast and must not acquire the device lock. + - Fail-fast on user-callback exceptions: stop event is set. + """ + try: + # Track timing for diagnostic purposes + current_time = time.time() + interval = None + if self._last_callback_time is not None: + interval = current_time - self._last_callback_time + self._callback_intervals.append(interval) + # Keep only recent intervals to avoid memory growth + if len(self._callback_intervals) > 100: + self._callback_intervals = self._callback_intervals[-50:] + self._last_callback_time = current_time + + with self._stream_stats_lock: + if overflow: + self._stream_overflow_count += 1 + logger.warning( + "Picoscope hardware buffer overflow on device {}. Data loss may have occurred.", + self.device_id, + ) + if noOfSamples > 0: + self._stream_total_samples += noOfSamples + self._stream_callback_count += 1 + + # Log large chunks for diagnostic purposes (only for sessions 2+) + if self._stream_start_count >= 2 and noOfSamples > 40: + interval_ms = interval * 1000 if interval is not None else 0 + logger.debug( + "Large chunk detected on {} session #{}: {} samples (interval: {}ms)", + self.device_id, + self._stream_start_count, + noOfSamples, + interval_ms, + ) + + if noOfSamples <= 0 or self._sdk_buffer is None: + return + + # Slice ADC data from the SDK-managed buffer + adc_view = self._sdk_buffer[startIndex : startIndex + noOfSamples] + + if self._max_adc_value is None or self._streaming_voltage_range is None: + logger.warning( + "Cannot process streaming data: max_adc or voltage_range not set." + ) + return + + # Convert to volts + max_adc = self._max_adc_value.value + volts = ( + adc_view.astype(np.float32) / max_adc + ) * self._streaming_voltage_range + + # Invoke user callback; if it raises, fail fast + if self._streaming_callback_ref is not None: + try: + self._streaming_callback_ref(None, volts) + except Exception: + logger.exception( + "Error in user streaming callback on device {}", self.device_id + ) + self._stop_streaming_event.set() + except Exception: + logger.exception("Unexpected error in SDK streaming callback") + + +class PicoscopeBufferedStream(Picoscope): + """ + Picoscope device with buffered streaming using a ring buffer and Zarr save. + + This class inherits all block acquisition functionality from Picoscope + and implements a producer-consumer pattern with an in-memory ring buffer + for live display and on-demand Zarr saving. + + Parameters + ---------- + device_id : str + Unique identifier for this device instance. + serial_code : Optional[str], default None + Serial code of the Picoscope hardware. If None, auto-detect is used. + resolution : int, default 12 + ADC resolution in bits (8, 12, 14, 15, or 16). + """ + + def __init__( + self, + device_id: str, + serial_code: Optional[str] = None, + resolution: int = 12, + **kwargs: Any, + ) -> None: + super().__init__(device_id, serial_code, resolution, **kwargs) + + # --- Streaming Configuration --- + self._streaming_configured: bool = False + self._streaming_channels: List[int] = [] + self._enabled_channels_set: set[int] = set() + self._streaming_voltage_ranges: Dict[int, float] = {} + self._streaming_offsets_v: Dict[int, float] = {} + self._streaming_sample_rate: Optional[float] = None + self._streaming_sample_interval_ns: Optional[int] = None + self._streaming_sample_unit: Optional[int] = None + self._streaming_downsample_ratio: int = 1 + self._streaming_downsample_mode: str = "NONE" + self._streaming_bandwidth_limiter: str = "FULL" + self._streaming_buffer_duration_s: float = 30.0 + + # --- Ring Buffer --- + self._ring_buffer: Optional[Any] = None + + # --- SDK Buffers --- + self._sdk_buffers: Dict[int, np.ndarray] = {} + self._sdk_buffer_min: Optional[np.ndarray] = None + self._interleaved_buffer: Optional[np.ndarray] = None + + # --- Threads and Control --- + self._producer_thread: Optional[threading.Thread] = None + self._stop_streaming_event = threading.Event() + self._user_callback: Optional[Callable] = None + self._callbackFuncPtr: Optional[Any] = None + + # --- Save Thread State --- + self._save_thread: Optional[threading.Thread] = None + self._save_stop_event: Optional[threading.Event] = None + self._save_output_path: Optional[str] = None + self._save_samples_written: int = 0 + self._save_pre_trigger_samples: int = 0 + self._save_error: Optional[Exception] = None + self._is_saving: bool = False + self._save_start_time_iso: Optional[str] = None + self._save_stop_time_iso: Optional[str] = None + + # --- Performance Tracking --- + self._stream_stats_lock = threading.Lock() + self._stream_total_samples: int = 0 + self._stream_callback_count: int = 0 + self._stream_overflow_count: int = 0 + self._hardware_overflow_count: int = 0 + self._overvoltage_count: int = 0 + + # --- Streaming Error State --- + self._streaming_error: Optional[str] = None + + # --- Pre-allocated callback buffer for zero-copy operation --- + self._callback_buffer: Optional[np.ndarray] = None + + # --- Plotter position tracking for overflow detection --- + self._last_plotter_read_idx: int = 0 + self._plotter_position_lock = threading.Lock() + self._streaming_start_time: float = 0.0 + + def configure_streaming( + self, + sample_rate: float, + channels: List[int], + voltage_ranges: List[float], + buffer_duration_s: float = 30.0, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + offsets_v: Optional[List[float]] = None, + bandwidth_limiter: str = "FULL", + ) -> None: + """Configure buffered streaming acquisition with a ring buffer. + + This sets up the ring buffer for high-speed multi-channel acquisition + and prepares channel configuration. Data is stored in memory and can + be saved on-demand using start_save(). + + The ring buffer always has 2 columns (Channel A and Channel B). + Disabled channels will have zeros written to their column. + + Parameters + ---------- + sample_rate : float + Desired sample rate in Hz per enabled channel. + channels : List[int] + List of channel indices to enable for acquisition (e.g., [0, 1] for A and B). + Empty list is allowed (no acquisition, just configuration). + voltage_ranges : List[float] + Voltage range for each channel in volts. Must match length of channels. + buffer_duration_s : float, default 30.0 + Duration of the ring buffer in seconds. Determines memory usage. + At 62.5 MS/s per channel with 2 channels, 30s requires ~7.5 GB. + downsample_ratio : int, default 1 + Hardware downsampling ratio. + downsample_mode : str, default "NONE" + Downsampling mode ("NONE", "AVERAGE", "DECIMATE", "AGGREGATE"). + offsets_v : Optional[List[float]], default None + Analogue offset in volts for each channel. If None, defaults to 0.0 for all. + + Raises + ------ + DeviceConfigurationError + If parameters are invalid or inconsistent. + ImportError + If ring buffer module is not available. + """ + if sample_rate <= 0: + raise DeviceConfigurationError("Sample rate must be positive.") + if not channels: + raise DeviceConfigurationError( + "At least one channel must be enabled for acquisition. " + "Specify channels=[0] for Channel A or channels=[1] for Channel B." + ) + if len(channels) != len(voltage_ranges): + raise DeviceConfigurationError( + "Length of channels and voltage_ranges must match." + ) + if len(set(channels)) != len(channels): + raise DeviceConfigurationError("Duplicate channel indices not allowed.") + for ch in channels: + if not 0 <= ch <= 3: + raise DeviceConfigurationError( + f"Invalid channel index: {ch}. Must be 0-3." + ) + for v_range in voltage_ranges: + if v_range not in self.VOLTAGE_RANGES: + raise DeviceConfigurationError( + f"Invalid voltage range: {v_range}V. " + f"Valid ranges: {sorted(self.VOLTAGE_RANGES.keys())}V" + ) + + if buffer_duration_s <= 0: + raise DeviceConfigurationError("Buffer duration must be positive.") + + try: + from picostream.ring_buffer import RingBuffer + except ImportError as e: + raise ImportError( + "Ring buffer module not found. " + "Please ensure the picostream app is available." + ) from e + + if downsample_mode.upper() == "NONE": + downsample_ratio = 1 + + self._store_downsampling_params(downsample_ratio, downsample_mode) + self._reset_device_state() + + # Always use 2 columns in ring buffer (A and B) + num_buffer_channels = 2 + + # Store enabled channels and their configurations + self._streaming_channels = list(channels) + self._enabled_channels_set = set(channels) + self._streaming_voltage_ranges = { + ch: v_range for ch, v_range in zip(channels, voltage_ranges, strict=True) + } + + # Hardware samples at the requested rate regardless of downsampling mode + # AGGREGATE mode stores 2 values per downsampled point (min, max) + # but the hardware timing is still based on the original sample rate + self._streaming_sample_rate = sample_rate + + self._streaming_buffer_duration_s = buffer_duration_s + + if offsets_v is None: + offsets_v = [0.0] * len(channels) + elif len(offsets_v) != len(channels): + raise DeviceConfigurationError( + "Length of offsets_v must match length of channels." + ) + + self._streaming_offsets_v = { + ch: offset for ch, offset in zip(channels, offsets_v, strict=True) + } + self._streaming_bandwidth_limiter = bandwidth_limiter.upper() + + # Configure all physical channels - enabled ones for acquisition, + # disabled ones are turned off in hardware + all_channels = [0, 1] # Only support A and B for now + for ch in all_channels: + if ch in self._enabled_channels_set: + idx = channels.index(ch) + v_range = voltage_ranges[idx] + offset = offsets_v[idx] + self._configure_channel(ch, True, "DC", v_range, offset) + else: + # Disable channel in hardware - use dummy values + self._configure_channel(ch, False, "DC", 20.0, 0.0) + + # Create AcquisitionRate to calculate effective storage rate + # hardware_rate_hz is total across all channels (per-channel × num_channels) + total_hardware_rate_hz = sample_rate * len(channels) + self._acquisition_rate = AcquisitionRate( + hardware_rate_hz=total_hardware_rate_hz, + num_channels=len(channels), + downsample_ratio=self._downsample_ratio, + downsample_mode=DownsampleMode(self._downsample_mode), + ) + + # Use storage_rate_hz which accounts for hardware downsampling + # This is the actual rate at which samples arrive at the buffer + self._ring_buffer = RingBuffer( + duration_s=buffer_duration_s, + sample_rate=self._acquisition_rate.storage_rate_hz, + num_channels=num_buffer_channels, + ) + self._ring_buffer.set_acquisition_rate(self._acquisition_rate) + + buffer_memory_mb = self._ring_buffer.buffer.nbytes / (1024 * 1024) + + try: + import psutil # type: ignore + + available_memory_mb = psutil.virtual_memory().available / (1024 * 1024) + if buffer_memory_mb > available_memory_mb * 0.5: + logger.warning( + "Ring buffer (%.1f MB) exceeds 50%% of available RAM (%.1f MB). " + "Consider reducing buffer_duration_s.", + buffer_memory_mb, + available_memory_mb, + ) + except ImportError: + logger.debug("psutil not available, skipping memory validation") + + # SDK buffers only for enabled channels + # Use a single large buffer (500ms) matching reference implementation pattern + samples_per_buffer = int(sample_rate * 0.5) + if samples_per_buffer < 1: + samples_per_buffer = 5000 + + self._sdk_buffers = {} + for ch in channels: + self._sdk_buffers[ch] = np.zeros(shape=samples_per_buffer, dtype=np.int16) + + # Pre-allocate callback output buffer (Fix #1: avoid allocation in hot path) + # Max size is samples_per_buffer * 2 for AGGREGATE mode (min/max pairs) + max_callback_samples = ( + samples_per_buffer * 2 + if downsample_mode.upper() == "AGGREGATE" + else samples_per_buffer + ) + self._callback_buffer = np.zeros((max_callback_samples, 2), dtype=np.int16) + + if self._downsample_mode == "AGGREGATE": + self._sdk_buffer_min = np.zeros(shape=samples_per_buffer, dtype=np.int16) + self._interleaved_buffer = np.zeros( + shape=samples_per_buffer * 2, dtype=np.int16 + ) + + interval, unit_enum, actual_rate = self._get_streaming_interval_ns(sample_rate) + self._streaming_sample_interval_ns = interval + self._streaming_sample_unit = unit_enum + + self._streaming_configured = True + + enabled_names = ["A" if ch == 0 else "B" for ch in sorted(channels)] + logger.info( + "Buffered streaming configured for {}: channels=[{}], " + "rate={} Hz per channel, {:.1f}s buffer ({:.1f} MB), " + "downsample={}x ({})", + self.device_id, + ", ".join(enabled_names), + actual_rate, + buffer_duration_s, + buffer_memory_mb, + downsample_ratio, + downsample_mode, + ) + + def start_streaming(self, callback: Optional[Callable] = None) -> None: + """Start buffered streaming acquisition to the ring buffer. + + Parameters + ---------- + callback : Optional[Callable], default None + Optional callback function called with metadata when data arrives. + Receives a dict: {'type': 'data_ready', 'samples': int, + 'channels': List[int], 'total_samples': int}. + + Raises + ------ + DeviceNotConnectedError + If device is not connected. + DeviceConfigurationError + If streaming is not configured. + """ + if not self._is_connected or self._chandle is None: + raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.") + + if not self._streaming_configured: + raise DeviceConfigurationError( + "Streaming not configured. Call configure_streaming() first." + ) + + if self.is_streaming(): + logger.warning("Streaming already active on device {}.", self.device_id) + return + + self._user_callback = callback + self._stop_streaming_event.clear() + + with self._stream_stats_lock: + self._stream_total_samples = 0 + self._stream_callback_count = 0 + self._stream_overflow_count = 0 + self._hardware_overflow_count = 0 + self._overvoltage_count = 0 + self._streaming_error = None # Clear any previous error + + # Reset plotter tracking for overflow detection + with self._plotter_position_lock: + self._last_plotter_read_idx = 0 + self._streaming_start_time = time.time() + + self._setup_streaming_buffers() + + self._callbackFuncPtr = self._ps.StreamingReadyType(self._streaming_callback) + + if self._streaming_sample_interval_ns is None: + raise DeviceConfigurationError( + "Streaming interval not set. Call configure_streaming first." + ) + sample_interval = ctypes.c_int32(self._streaming_sample_interval_ns) + downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[ + f"PS5000A_RATIO_MODE_{self._downsample_mode}" + ] + + # Use single buffer size for maxPostTriggerSamples (matches reference pattern) + sdk_buffer_size = self._sdk_buffers[self._streaming_channels[0]].shape[0] + max_samples = ctypes.c_int32(sdk_buffer_size) + + status = self._ps.ps5000aRunStreaming( + self._chandle, + ctypes.byref(sample_interval), + self._streaming_sample_unit, + 0, # preTriggerSamples + max_samples.value, + 0, # autoStop = False (continuous streaming) + self._downsample_ratio, + downsample_mode_enum, + sdk_buffer_size, # overviewBufferSize + ) + self._assert_pico_ok(status) + + # Get actual sample interval from SDK (Fix #2: use actual rate, not requested) + # IMPORTANT: The SDK's actual_rate_hz is the PRE-DOWNSAMPLED rate. + # It's what we requested (or the closest achievable). + # The effective per-channel rate after downsampling is actual_rate_hz / downsample_ratio. + self._streaming_sample_interval_ns = sample_interval.value + actual_rate_ns = self._streaming_sample_interval_ns + if self._streaming_sample_unit == self._ps.PS5000A_TIME_UNITS["PS5000A_US"]: + actual_rate_ns *= 1000 + elif self._streaming_sample_unit == self._ps.PS5000A_TIME_UNITS["PS5000A_MS"]: + actual_rate_ns *= 1_000_000 + actual_rate_hz = 1e9 / actual_rate_ns if actual_rate_ns > 0 else 0 + + logger.info( + "Hardware streaming started on {}. " + "Actual interval: {} (unit={}), rate={} Hz", + self.device_id, + self._streaming_sample_interval_ns, + self._streaming_sample_unit, + actual_rate_hz, + ) + + self._producer_thread = threading.Thread( + target=self._producer_loop, + daemon=True, + name=f"picoscope-producer-{self.device_id}", + ) + + self._producer_thread.start() + + logger.info("Buffered streaming started on {}.", self.device_id) + + def stop_streaming(self, timeout: float = 5.0) -> None: + """Stop buffered streaming acquisition. + + Parameters + ---------- + timeout : float, default 5.0 + Maximum time to wait for threads to finish. + """ + if not self.is_streaming(): + logger.info("Streaming not active on device {}.", self.device_id) + return + + logger.info("Stopping buffered stream on {}...", self.device_id) + self._stop_streaming_event.set() + + if self._is_saving: + self.stop_save(keep=True) + + try: + if self._is_connected and self._chandle is not None: + status = self._ps.ps5000aStop(self._chandle) + if status != PICO_STATUS["PICO_OK"]: + logger.debug("ps5000aStop returned status {}", status) + except Exception: + logger.exception("Error stopping hardware stream") + + if self._producer_thread and self._producer_thread.is_alive(): + self._producer_thread.join(timeout) + if self._producer_thread.is_alive(): + logger.warning("Producer thread did not stop in time") + + self._cleanup_streaming_buffers() + + self._producer_thread = None + self._user_callback = None + self._callbackFuncPtr = None + + with self._stream_stats_lock: + logger.info( + "Streaming stopped on {}. Total samples: {}, callbacks: {}, " + "hardware overflows: {}", + self.device_id, + self._stream_total_samples, + self._stream_callback_count, + self._hardware_overflow_count, + ) + + def is_streaming(self) -> bool: + """Return True if streaming thread is active.""" + return self._producer_thread is not None and self._producer_thread.is_alive() + + def get_streaming_error(self) -> Optional[str]: + """Get the last streaming error, if any. + + Returns + ------- + Optional[str] + Error description (e.g., "PICO_NOT_RESPONDING") if streaming + stopped due to an error, None otherwise. + Error is cleared when streaming is restarted. + """ + with self._stream_stats_lock: + return self._streaming_error + + @property + def is_saving(self) -> bool: + """Return True if currently saving to file.""" + return self._is_saving + + def disconnect(self) -> None: + """Disconnect from device, stopping any active streaming first. + + Cleanup must happen in this order to avoid deadlocks and handle leaks: + 1. stop_streaming() - stops the producer thread first + 2. ps5000aStop() - stops hardware acquisition (called by stop_streaming) + 3. ps5000aCloseUnit() - closes device handle last (called by parent) + + Reversing this order can cause deadlocks or resource leaks. + """ + try: + if self.is_streaming(): + self.stop_streaming(timeout=1.0) + except Exception: + logger.exception("Error stopping streaming during disconnect") + finally: + super().disconnect() + + @property + def ring_buffer(self) -> Optional[Any]: + """Return the ring buffer for plotter access. + + Returns + ------- + Optional[RingBuffer] + The ring buffer if configured, None otherwise. + """ + return self._ring_buffer + + def update_plotter_position(self, read_idx: int) -> None: + """Update the plotter's last read position for overflow detection. + + This should be called by the plotter whenever it reads data from + the ring buffer. It allows the producer to detect if the plotter + is falling behind. + + Parameters + ---------- + read_idx : int + The sample index that the plotter last read up to. + """ + with self._plotter_position_lock: + self._last_plotter_read_idx = read_idx + self._last_plotter_update_time = time.time() + + # Log overflow check outside lock to avoid holding lock during computation + if self._ring_buffer is not None: + write_idx = self._ring_buffer.write_idx + buffer_capacity = self._ring_buffer.capacity + lag_samples = write_idx - read_idx + if lag_samples > 0.9 * buffer_capacity: + logger.warning( + "Ring buffer nearly full: {} samples lag ({:.1%} utilised). " + "write_idx={}, read_idx={}, capacity={}. " + "Plotter may not be keeping up - data loss imminent!", + lag_samples, + lag_samples / buffer_capacity if buffer_capacity > 0 else 0, + write_idx, + read_idx, + buffer_capacity, + ) + + def start_save(self, lookback_seconds: float, output_path: str) -> bool: + """Start saving data to a Zarr file. + + Launches a save thread that writes pre-trigger data from the ring + buffer, then continuously drains new data to a Zarr file. + + Parameters + ---------- + lookback_seconds : float + How many seconds of pre-trigger data to include. + output_path : str + Path for the output Zarr directory. + + Returns + ------- + bool + True if save started successfully, False otherwise. + """ + if self._is_saving: + logger.warning("Save already in progress") + return False + + if self._ring_buffer is None: + logger.error("Cannot start save: ring buffer not configured") + return False + + self._save_output_path = output_path + + if self._acquisition_rate is None: + logger.error("Cannot start save: acquisition_rate not configured") + return False + + lookback_samples = self._acquisition_rate.seconds_to_samples(lookback_seconds) + available_samples = min(lookback_samples, self._ring_buffer.write_idx) + + # Capture start/stop timestamps for metadata + self._save_start_time_iso = datetime.now(timezone.utc).isoformat() + self._save_stop_time_iso = None + + self._save_stop_event = threading.Event() + self._save_samples_written = 0 + self._save_pre_trigger_samples = 0 + self._save_error = None + + # Use existing acquisition_rate, should already be set from start_streaming + acquisition_rate = self._acquisition_rate + + # Use the driver's max_adc value for voltage conversion + # The PicoScope 5000a driver always returns 32767 (16-bit signed max) + # regardless of resolution setting. Higher resolutions use more effective + # bits, but the ADC value range stays the same. + driver_max_adc = self._max_adc_value.value if self._max_adc_value else 32767 + + # Build metadata with AcquisitionRate fields for consistent rate semantics + # Use the actual acquisition_rate stored on the ring buffer + total_hardware_rate_hz = self._acquisition_rate.hardware_rate_hz + metadata: Dict[str, Any] = { + # AcquisitionRate fields - stored explicitly for reconstruction + "hardware_rate_hz": total_hardware_rate_hz, + "num_channels": len(self._streaming_channels), + "downsample_ratio": self._downsample_ratio, + "downsample_mode": self._downsample_mode, + # Legacy field for backward compatibility (deprecated) + "sample_rate_hz": acquisition_rate.storage_rate_hz, + "channels": [0, 1], # Always save both columns (A and B) + "enabled_channels": self._streaming_channels, # Which channels were actually acquiring + "voltage_ranges": [ + self._streaming_voltage_ranges.get(ch, 20.0) for ch in [0, 1] + ], + "voltage_ranges_units": "V", + "offsets_v": [self._streaming_offsets_v.get(ch, 0.0) for ch in [0, 1]], + "resolution": self.resolution, + "resolution_units": "bits", + # Use driver's max_adc (always 32767 for PicoScope 5000a) + # for correct voltage conversion in external tools + "max_adc": driver_max_adc, + "pre_trigger_seconds": lookback_seconds, + "buffer_duration_s": self._streaming_buffer_duration_s, + "device_id": self.device_id, + "serial_code": self.serial_code, + "coupling": "DC", # Currently hardcoded to DC in configure_streaming + "bandwidth_limiter": self._streaming_bandwidth_limiter, + } + + # Capture trigger position in main thread BEFORE spawning worker thread. + # This eliminates a race condition where the producer could overwrite + # pre-trigger data while the save worker is starting up. + trigger_pos = self._ring_buffer.get_snapshot() + + self._save_thread = threading.Thread( + target=self._save_worker, + args=( + output_path, + trigger_pos, + available_samples, + metadata, + acquisition_rate, + ), + daemon=True, + name=f"picoscope-save-{self.device_id}", + ) + self._save_thread.start() + self._is_saving = True + + logger.info( + "Started save on {}: path={}, lookback={:.1f}s ({} samples)", + self.device_id, + output_path, + lookback_seconds, + available_samples, + ) + return True + + def _save_worker( + self, + path: str, + trigger_pos: int, + lookback_samples: int, + metadata: Dict[str, Any], + acquisition_rate: AcquisitionRate, + ) -> None: + """Worker thread for saving data to Zarr. + + This runs in a separate thread and continuously drains data + from the ring buffer to a Zarr file until stop_save is called. + Always saves 2 columns (A and B), with zeros for disabled channels. + + Parameters + ---------- + path : str + Output path for the Zarr file. + trigger_pos : int + The write index position captured when start_save() was called. + Used as the trigger point for pre-trigger data capture. + lookback_samples : int + Number of samples to include before the trigger point. + metadata : Dict[str, Any] + Metadata dictionary to store in the Zarr file. + acquisition_rate : AcquisitionRate + Acquisition rate object with all rate information. + + Raises + ------ + Exception + Any exception during saving is stored in _save_error and re-raised + to ensure the thread terminates with an error state. + """ + try: + from picostream.zarr_writer import ZarrStreamWriter + + writer = ZarrStreamWriter( + path=path, + acquisition_rate=acquisition_rate, + num_channels=2, # Always 2 columns + compression=None, + ) + + pre_trigger_start = trigger_pos - lookback_samples + if pre_trigger_start < 0: + pre_trigger_start = 0 + + # Validate that trigger_pos is still valid before reading. + # This catches edge cases where save was significantly delayed + # after start_save() was called. + if not self._ring_buffer.is_valid_range(pre_trigger_start): + msg = ( + f"Pre-trigger data expired before save worker started: " + f"trigger_pos={trigger_pos}, lookback_samples={lookback_samples}, " + f"buffer write_idx={self._ring_buffer.write_idx}, " + f"capacity={self._ring_buffer.capacity}" + ) + logger.error(msg) + self._save_error = RuntimeError(msg) + raise self._save_error + + pre_trigger_data = self._ring_buffer.read_range( + pre_trigger_start, trigger_pos + ) + + if pre_trigger_data.shape[0] > 0: + writer.append(pre_trigger_data) + self._save_pre_trigger_samples = pre_trigger_data.shape[0] + + last_pos = trigger_pos + + while not self._save_stop_event.is_set(): + data = self._ring_buffer.read_since(last_pos) + + if data.shape[0] > 0: + if not self._ring_buffer.is_valid_range(last_pos): + valid_start = ( + self._ring_buffer.write_idx - self._ring_buffer.capacity + ) + if valid_start < 0: + valid_start = 0 + samples_lost = valid_start - last_pos + if samples_lost > 0: + logger.warning( + "Save worker lost {} samples: producer outran consumer. " + "Consider reducing sample rate or increasing buffer size.", + samples_lost, + ) + data = self._ring_buffer.read_since(valid_start) + last_pos = valid_start + + writer.append(data) + last_pos = self._ring_buffer.get_snapshot() + self._save_samples_written = writer.total_samples + else: + time.sleep(0.001) + + if self._save_start_time_iso is not None: + metadata["record_start_time_iso"] = self._save_start_time_iso + if self._save_stop_time_iso is not None: + metadata["record_stop_time_iso"] = self._save_stop_time_iso + + writer.close(metadata) + logger.info( + "Save worker completed: {} samples written to {}", + writer.total_samples, + path, + ) + + except Exception as e: + logger.exception("Save worker error: {}", e) + self._save_error = e + # Re-raise to ensure thread terminates with error state + raise + + def stop_save(self, keep: bool = True) -> Optional[str]: + """Stop the current save operation (non-blocking). + + Signals the save thread to stop and returns immediately. + Poll is_save_finished() to check when the operation completes. + + Parameters + ---------- + keep : bool, default True + If True, keep the saved file. If False, delete it. + + Returns + ------- + Optional[str] + Path where the file is being saved (if keep=True), or None. + The file may still be incomplete - check is_save_finished() before using. + + Raises + ------ + DeviceOperationError + If the save thread already encountered an error during recording. + """ + if not self._is_saving: + logger.info("No save in progress") + return None + + # Capture stop timestamp and store keep flag for cleanup + self._save_stop_time_iso = datetime.now(timezone.utc).isoformat() + self._keep_file_on_stop = keep + + logger.info("Stopping save on {} (keep={})", self.device_id, keep) + self._save_stop_event.set() + + # Check for errors from the save worker (already occurred) + if self._save_error is not None: + save_error = self._save_error + logger.error( + "Save operation failed on {}: {}", + self.device_id, + save_error, + ) + # Clear the error so we don't re-raise it again later + self._save_error = None + raise DeviceOperationError( + f"Save operation failed: {save_error}" + ) from save_error + + return self._save_output_path if keep else None + + def is_save_finished(self) -> bool: + """Check if the save operation has finished. + + After this returns True, you can safely use the saved file + (if keep=True was passed to stop_save()). + + Returns + ------- + bool + True if save is complete and the file is ready, False otherwise. + + Notes + ----- + This method is thread-safe and will not raise exceptions. Any errors + during cleanup are logged but not propagated. + + There is a benign race condition: the save thread may terminate after + the is_alive() check but before cleanup completes. This is handled + gracefully by catching exceptions during cleanup. + """ + if not self._is_saving: + return True + + # Check if thread is still running + if self._save_thread and self._save_thread.is_alive(): + return False + + # Thread has finished, do cleanup + try: + self._cleanup_save() + except Exception: + logger.exception("Error during save cleanup in is_save_finished()") + # Still mark as finished even if cleanup fails, to avoid infinite polling + self._is_saving = False + + return True + + def _cleanup_save(self) -> None: + """Clean up after save thread completes (idempotent).""" + if not self._is_saving: + return + + path = self._save_output_path + keep = getattr(self, "_keep_file_on_stop", True) + + if not keep and path and os.path.exists(path): + try: + import shutil + + shutil.rmtree(path) + logger.info("Discarded save file: {}", path) + except Exception as e: + logger.exception("Error discarding save file: {}", e) + + self._is_saving = False + self._save_thread = None + self._save_stop_event = None + self._save_output_path = None + if hasattr(self, "_keep_file_on_stop"): + delattr(self, "_keep_file_on_stop") + + def get_save_status(self) -> Dict[str, Any]: + """Get the current save status. + + Returns + ------- + Dict[str, Any] + Dictionary with keys: 'state', 'total_seconds', 'pre_trigger_seconds', + 'post_trigger_seconds', 'error' (if applicable). + """ + if not self._is_saving and self._save_error is None: + return {"state": "idle"} + + if self._save_error is not None: + return {"state": "error", "error": str(self._save_error)} + + if self._acquisition_rate is None: + return {"state": "saving", "total_seconds": 0} + + total_seconds = self._acquisition_rate.samples_to_seconds( + self._save_samples_written + ) + pre_trigger_seconds = self._acquisition_rate.samples_to_seconds( + self._save_pre_trigger_samples + ) + post_trigger_seconds = total_seconds - pre_trigger_seconds + + return { + "state": "saving", + "total_seconds": total_seconds, + "pre_trigger_seconds": pre_trigger_seconds, + "post_trigger_seconds": post_trigger_seconds, + } + + def _setup_streaming_buffers(self) -> None: + """Register SDK streaming buffers for enabled channels only.""" + downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[ + f"PS5000A_RATIO_MODE_{self._downsample_mode}" + ] + + for ch in self._streaming_channels: + buf = self._sdk_buffers[ch] + buffer_min_ptr = None + + if ( + self._downsample_mode == "AGGREGATE" + and self._sdk_buffer_min is not None + ): + buffer_min_ptr = self._sdk_buffer_min.ctypes.data_as( + ctypes.POINTER(ctypes.c_int16) + ) + + status = self._ps.ps5000aSetDataBuffers( + self._chandle, + self._get_channel_enum(ch), + buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + buffer_min_ptr, + buf.shape[0], + 0, + downsample_mode_enum, + ) + self._assert_pico_ok(status) + + logger.debug( + "Set up streaming buffers for enabled channels %s", self._streaming_channels + ) + + def _cleanup_streaming_buffers(self) -> None: + """De-register streaming buffers for enabled channels.""" + if not self._is_connected or self._chandle is None: + return + + for ch in self._streaming_channels: + try: + buf = self._sdk_buffers.get(ch) + if buf is None: + continue + + buffer_min_ptr = None + if ( + self._downsample_mode == "AGGREGATE" + and self._sdk_buffer_min is not None + ): + buffer_min_ptr = self._sdk_buffer_min.ctypes.data_as( + ctypes.POINTER(ctypes.c_int16) + ) + + status = self._ps.ps5000aSetDataBuffers( + self._chandle, + self._get_channel_enum(ch), + buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), + buffer_min_ptr, + 0, # size=0 de-registers + 0, + self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"], + ) + if status != PICO_STATUS["PICO_OK"]: + logger.debug( + "Buffer de-register status {} for channel {}.", status, ch + ) + except Exception: + logger.exception("Error de-registering buffer for channel {}", ch) + + def _get_streaming_interval_ns( + self, sample_rate_hz: float + ) -> Tuple[int, int, float]: + """ + Calculate streaming interval in nanoseconds. + + Returns + ------- + Tuple[int, int, float] + (interval, unit_enum, actual_rate_hz) + """ + if sample_rate_hz <= 0: + raise DeviceConfigurationError("Sample rate must be positive.") + + # Use nanoseconds for highest resolution + unit_str = "PS5000A_NS" + interval = max(1, int(round(1e9 / sample_rate_hz))) + unit_enum = self._ps.PS5000A_TIME_UNITS[unit_str] + actual_rate = 1e9 / interval + + return interval, unit_enum, actual_rate + + def _producer_loop(self) -> None: + """Producer thread: poll SDK for data and write to ring buffer.""" + logger.debug("Producer loop started for {}.", self.device_id) + not_responding_count = 0 + + # Main polling loop - start immediately like the reference implementation + # The device may return PICO_NOT_RESPONDING initially; we handle this + # gracefully by counting consecutive failures and only stopping after 3 + main_loop_start = time.perf_counter() + last_status_log_time = main_loop_start + last_callback_seen = 0 + + try: + while not self._stop_streaming_event.is_set(): + status = self._ps.ps5000aGetStreamingLatestValues( + self._chandle, self._callbackFuncPtr, None + ) + + now = time.perf_counter() + elapsed_since_start = now - main_loop_start + + # Handle statuses according to PicoSDK documentation + # These are all normal during streaming - only PICO_NOT_RESPONDING + # and unexpected errors are treated as abnormal + if status == PICO_STATUS["PICO_BUFFER_STALL"]: + with self._stream_stats_lock: + self._hardware_overflow_count += 1 + logger.warning("Hardware buffer overflow - data lost") + elif status == PICO_STATUS["PICO_NOT_RESPONDING"]: + not_responding_count += 1 + if not_responding_count == 1: + # First occurrence - log diagnostic info + with self._stream_stats_lock: + cb_count = self._stream_callback_count + sample_count = self._stream_total_samples + logger.error( + "PICO_NOT_RESPONDING (status 7) first detected after {:.3f}s. " + "Callbacks so far: {}, Samples: {}. " + "This may indicate hardware initialization timing issue.", + elapsed_since_start, + cb_count, + sample_count, + ) + + if not_responding_count >= 3: + # Device has stopped responding persistently - treat as fatal + with self._stream_stats_lock: + self._streaming_error = "PICO_NOT_RESPONDING" + final_cb_count = self._stream_callback_count + final_sample_count = self._stream_total_samples + logger.error( + "Picoscope device {} stopped responding (status 7 - PICO_NOT_RESPONDING) " + "after {} consecutive failures. Total time: {:.3f}s, " + "Total callbacks: {}, Total samples: {}. Streaming will stop.", + self.device_id, + not_responding_count, + elapsed_since_start, + final_cb_count, + final_sample_count, + ) + break # Exit the producer loop + else: + logger.debug( + "Picoscope device {} returned PICO_NOT_RESPONDING (attempt {}, {:.3f}s elapsed)", + self.device_id, + not_responding_count, + elapsed_since_start, + ) + elif status in [ + PICO_STATUS["PICO_OK"], + PICO_STATUS["PICO_NO_SAMPLES_AVAILABLE"], + PICO_STATUS["PICO_BUSY"], + PICO_STATUS["PICO_DATA_NOT_AVAILABLE"], + PICO_STATUS["PICO_DRIVER_FUNCTION"], + ]: + # All normal statuses during streaming - reset not_responding_count + if not_responding_count > 0: + logger.info( + "Device {} responding again after {} NOT_RESPONDING errors " + "(at {:.3f}s elapsed)", + self.device_id, + not_responding_count, + elapsed_since_start, + ) + not_responding_count = 0 + else: + # Unexpected status - log warning (throttled) + if ( + not hasattr(self, "_last_status_warning_time") + or now - self._last_status_warning_time > 5.0 + ): + logger.warning( + "GetStreamingLatestValues returned unexpected status {} after {:.3f}s", + status, + elapsed_since_start, + ) + self._last_status_warning_time = now + + # Log callback stats periodically + with self._stream_stats_lock: + current_callbacks = self._stream_callback_count + current_samples = self._stream_total_samples + + if current_callbacks > last_callback_seen: + last_callback_seen = current_callbacks + # Log every 100 callbacks or every 5 seconds + if current_callbacks % 100 == 1 or now - last_status_log_time > 5.0: + logger.debug( + "Producer stats @ {:.3f}s: callbacks={}, samples={}, " + "rate={:.1f} samples/s", + elapsed_since_start, + current_callbacks, + current_samples, + current_samples / elapsed_since_start + if elapsed_since_start > 0 + else 0, + ) + last_status_log_time = now + + time.sleep(0.001) + + except Exception: + logger.exception( + "Error in producer loop after {:.3f}s", + time.perf_counter() - main_loop_start, + ) + finally: + total_elapsed = time.perf_counter() - main_loop_start + logger.debug( + "Producer loop finished for %s. Total time: %.3fs", + self.device_id, + total_elapsed, + ) + + def _streaming_callback( + self, + _handle: int, + noOfSamples: int, + startIndex: int, + overvoltage_flags: int, + _triggerAt: int, + _triggered: int, + _autoStop: int, + _param: int, + ) -> None: + """SDK callback: copy data from SDK buffers to the ring buffer. + + This is called from the SDK's internal thread when new data is available. + Data is written to fixed column positions (0=Channel A, 1=Channel B). + Disabled channels have zeros written to their column. + """ + # Debug: log first few callbacks + with self._stream_stats_lock: + self._stream_callback_count += 1 + if self._stream_callback_count <= 5: + logger.debug( + "SDK callback #%s: %s samples", + self._stream_callback_count, + noOfSamples, + ) + + if self._stop_streaming_event.is_set(): + return + + if overvoltage_flags: + with self._stream_stats_lock: + self._overvoltage_count += 1 + logger.warning("ADC overvoltage detected (signal clipping)") + + if noOfSamples <= 0: + return + + with self._stream_stats_lock: + self._stream_total_samples += noOfSamples + self._stream_callback_count += 1 + + if self._ring_buffer is None: + logger.error("Ring buffer not initialised") + return + + if ( + self._downsample_mode == "AGGREGATE" + and self._interleaved_buffer is not None + ): + # AGGREGATE mode: SDK provides separate min/max buffers + # We interleave them into output as [min1, max1, min2, max2, ...] + # This produces 2x the samples compared to other modes + src_start = startIndex + total_output_samples = noOfSamples * 2 + + # Use pre-allocated buffer, slice to needed size + output_data = self._callback_buffer[:total_output_samples, :] + output_data[...] = 0 # Clear only viewed elements (not entire buffer) + + for ch in self._streaming_channels: + sdk_buf_max = self._sdk_buffers[ch] + + # Get min and max slices + mins = self._sdk_buffer_min[src_start : src_start + noOfSamples] + maxs = sdk_buf_max[src_start : src_start + noOfSamples] + + # Interleave into output buffer: [min1, max1, min2, max2, ...] + output_data[0::2, ch] = mins + output_data[1::2, ch] = maxs + # Disabled channels remain as zeros + else: + # Normal mode: copy enabled channel data to fixed position + output_data = self._callback_buffer[:noOfSamples, :] + output_data[...] = 0 # Clear only viewed elements (not entire buffer) + for ch in self._streaming_channels: + sdk_buf = self._sdk_buffers[ch] + output_data[:, ch] = sdk_buf[startIndex : startIndex + noOfSamples] + # Disabled channels remain as zeros + + self._ring_buffer.write(output_data) + + if self._user_callback is not None: + try: + self._user_callback( + { + "type": "data_ready", + "samples": noOfSamples, + "channels": self._streaming_channels, + "total_samples": self._stream_total_samples, + } + ) + except Exception: + logger.exception("Error in user callback") diff --git a/picostream/dfplot.py b/picostream/dfplot.py index c9cca03..21b2473 100644 --- a/picostream/dfplot.py +++ b/picostream/dfplot.py @@ -1,660 +1,883 @@ from __future__ import annotations +import argparse import sys -import threading import time -from typing import List, Optional -import argparse +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional -import h5py import numpy as np -import pyqtgraph as pg from loguru import logger -from PyQt5.QtCore import Qt, QTimer -from PyQt5.QtGui import QCloseEvent, QFont, QKeyEvent -from PyQt5.QtWidgets import ( +from PyQt6.QtCore import Qt, QTimer +from PyQt6.QtGui import QCloseEvent, QKeyEvent +from PyQt6.QtWidgets import ( QApplication, - QHBoxLayout, - QLabel, QVBoxLayout, QWidget, ) - -from .conversion_utils import adc_to_mV, min_max_decimate_numba - - -class HDF5LivePlotter(QWidget): - """ - Real-time oscilloscope-style plotter that reads from HDF5 files. - Completely independent of acquisition system for zero-risk operation. - - Supports both standard acquisition files and live-only mode files with - automatic buffer reset detection and graceful handling of file size changes. +from vispy import scene # noqa: E402 +from vispy.app import use_app + +use_app("pyqt6") + +if TYPE_CHECKING: + pass + +from picostream.acquisition_rate import AcquisitionRate # noqa: E402 +from picostream.data_pipeline import DataPipeline, PipelineConfig # noqa: E402 +from picostream.ring_buffer import RingBuffer # noqa: E402 + +TARGET_PLOT_POINTS = 20000 + + +class LivePlotter(QWidget): + """Real-time oscilloscope-style plotter that reads from a RingBuffer. + + This plotter displays live streaming data directly from memory, completely + decoupled from the file saving mechanism. It displays both channels on a + single axis with data scaling to match voltage ranges. + + Parameters + ---------- + ring_buffer : Optional[RingBuffer] + The ring buffer to read data from. None indicates waiting for stream. + channels : List[int] + List of active channel indices (e.g., [0, 1] for channels A and B). + voltage_ranges : Dict[int, float] + Mapping of channel index to voltage range in Volts. + offsets_v : Dict[int, float] + Mapping of channel index to voltage offset in Volts. + resolution : int + ADC resolution in bits (e.g., 12, 14, 15, 16). + downsample_mode : str + Downsample mode string (e.g., "NONE", "AGGREGATE"). + update_interval_ms : int, optional + How often to update the plot in milliseconds. Default is 50ms (20 Hz). + display_window_seconds : float, optional + Time duration of data to display. Default is 5.0 seconds. + decimation_factor : int, optional + Factor for min-max decimation to reduce plotting points. Default is 150. + target_plot_points : int, optional + Target number of points to display on screen for adaptive decimation. + Default is 20000. + on_read_position : Optional[Callable[[int], None]], optional + Callback invoked with the sample index after each read from ring buffer. + Used by producer for overflow detection. + + Attributes + ---------- + ring_buffer : Optional[RingBuffer] + Current ring buffer reference. Updated via set_ring_buffer(). + acquisition_rate : Optional[AcquisitionRate] + Single source of truth for all rate calculations. """ def __init__( self, - hdf5_path: str = "/tmp/data.hdf5", + ring_buffer: Optional[RingBuffer], + channels: List[int], + voltage_ranges: Dict[int, float], + offsets_v: Dict[int, float], + resolution: int, + downsample_mode: str, update_interval_ms: int = 50, - display_window_seconds: float = 0.5, + display_window_seconds: float = 5.0, decimation_factor: int = 150, - y_min: Optional[float] = None, - y_max: Optional[float] = None, + target_plot_points: int = 20000, + on_read_position: Optional[Callable[[int], None]] = None, + downsample_ratio: int = 1, ) -> None: - """Initializes the HDF5LivePlotter window. - - Parameters - ---------- - hdf5_path : str - Path to the HDF5 file to monitor. - update_interval_ms : int - How often to check the file for updates (in ms). - display_window_seconds : float - The time duration of data to display. - decimation_factor : int - The factor by which to decimate data for plotting. - y_min : float, optional - Minimum Y-axis limit in mV. If None, auto-range is used. - y_max : float, optional - Maximum Y-axis limit in mV. If None, auto-range is used. - """ + """Initialise the LivePlotter with metadata and optional ring buffer.""" super().__init__() - # --- Configuration --- - self.hdf5_path: str = hdf5_path + self.ring_buffer: Optional[RingBuffer] = ring_buffer + self._on_read_position: Optional[Callable[[int], None]] = on_read_position + + self.channels: List[int] = channels + self.voltage_ranges: Dict[int, float] = voltage_ranges + self.offsets_v: Dict[int, float] = offsets_v + self.resolution: int = resolution + self.downsample_mode: str = downsample_mode + self.downsample_ratio: int = downsample_ratio + self.update_interval_ms: int = update_interval_ms self.display_window_seconds: float = display_window_seconds self.decimation_factor: int = decimation_factor - self.y_min: Optional[float] = y_min - self.y_max: Optional[float] = y_max + self.target_plot_points: Optional[int] = target_plot_points + + # PicoScope 5000a series uses a 16-bit ADC internally and always returns + # max_adc = 32767 (2^15 - 1) regardless of resolution setting. + # Higher resolutions use more effective bits but the hardware ADC range stays the same. + # This is documented in the hardware API and confirmed in devices/picoscope.py. + max_adc_val = 32767 + self.max_adc_val: int = max_adc_val - # --- UI State --- - self.heartbeat_chars: List[str] = ["|", "/", "-", "\\"] - self.heartbeat_index: int = 0 self.is_saturated: bool = False - # --- Data Buffers --- + self._display_cache: Dict[int, np.ndarray] = {} + self._last_read_position: int = 0 + self._cache_ready: bool = False + self._current_plot_decimation: int = 1 + self._cache_dirty: bool = False + self._cached_time_axis: Optional[np.ndarray] = None + self._target_time_axis: Optional[np.ndarray] = None + self.display_data: np.ndarray = np.array([]) self.time_data: np.ndarray = np.array([]) self.data_start_sample: int = 0 - # --- Buffer Reset Detection --- - self.last_file_size: int = 0 - self.buffer_reset_count: int = 0 - self.total_samples_processed: int = 0 - - # --- HDF5 Metadata --- - self.sample_interval_ns: float = 16.0 # Default, will be read from file - self.hardware_downsample_ratio: int = 1 - self.ch_range: Optional[int] = None - self.voltage_range_v: Optional[float] = None - self.max_adc_val: Optional[int] = None - self.downsample_mode: Optional[str] = None - self.analog_offset_v: float = 0.0 - self.metadata_read: bool = False - - # --- Debug Counters --- + self.v_div_volts: Dict[int, float] = {0: 0.1, 1: 0.1} + self.y_pos_divs: Dict[int, float] = {0: 0.0, 1: 0.0} + self.update_count: int = 0 - self.file_read_count: int = 0 self.display_update_count: int = 0 + self.conversion_error_count: int = 0 - # --- Performance Monitoring --- self.display_latency_ms: float = 0.0 self.last_data_timestamp: Optional[float] = None - # --- Rate Checking --- self.rate_check_start_time: Optional[float] = None self.rate_check_start_samples: int = 0 - # --- Data Freshness Tracking --- - self.last_displayed_size: int = 0 - self.data_change_count: int = 0 - self.stale_update_count: int = 0 - self.last_freshness_check: float = time.time() + self.pipeline: Optional[DataPipeline] = None + self.acquisition_rate: Optional[AcquisitionRate] = None - # --- Error Tracking --- - self.conversion_error_count: int = 0 - self.file_error_count: int = 0 + self._channel_colors = {0: (0.0, 0.75, 1.0, 1.0), 1: (1.0, 0.27, 0.27, 1.0)} - # Setup UI - self.setup_ui() + self._pos_buffers: Dict[int, np.ndarray] = {} + self._time_axis_float32: Optional[np.ndarray] = None - # Setup update timer, which is controlled externally - self.timer = QTimer() - self.timer.timeout.connect(self.update_from_file) + self.setup_ui() logger.info( - f"HDF5LivePlotter initialized: path={hdf5_path}, interval={update_interval_ms}ms" + "LivePlotter initialised: {} channel(s), interval={}ms", + len(self.channels), + self.update_interval_ms, ) - # Initial file check - self.check_file_exists() + self.timer = QTimer() + self.timer.timeout.connect(self.update_from_buffer) + + # Set ring buffer to initialize pipeline if provided + if ring_buffer is not None: + self.set_ring_buffer(ring_buffer) - def reset_for_new_file(self) -> None: - """Reset all state in preparation for a new file. - - This should be called BEFORE set_hdf5_path() when starting a new acquisition - to ensure no stale metadata or state is carried over. + @property + def hardware_rate_hz(self) -> float: + """Hardware sample rate in Hz (total across all channels). + + Returns + ------- + float + The hardware rate from acquisition_rate. + + Raises + ------ + RuntimeError + If called before acquisition_rate is set. """ - logger.debug("Resetting plotter state for new file") - - # Clear display - self.curve.setData([], []) - - # Reset all data buffers - self.display_data = np.array([]) - self.time_data = np.array([]) - self.data_start_sample = 0 - - # Reset buffer tracking - self.last_file_size = 0 - self.buffer_reset_count = 0 - self.total_samples_processed = 0 - - # Clear cached metadata - this is critical! - self.voltage_range_v = None - self.max_adc_val = None - self.downsample_mode = None - self.analog_offset_v = 0.0 - self.sample_interval_ns = 16.0 - self.hardware_downsample_ratio = 1 - self.metadata_read = False - - # Reset rate checking - self.rate_check_start_time = None - self.rate_check_start_samples = 0 - self.last_displayed_size = 0 - - # Reset counters - self.update_count = 0 - self.file_read_count = 0 - self.display_update_count = 0 - self.data_change_count = 0 - self.stale_update_count = 0 - - # Reset status - self.is_saturated = False - self.last_data_timestamp = None - self.display_latency_ms = 0.0 - - def set_hdf5_path(self, hdf5_path: str) -> None: - """Sets the HDF5 file path to monitor. - - Note: Call reset_for_new_file() BEFORE this method when starting a new - acquisition to ensure clean state. + if self.acquisition_rate is not None: + return self.acquisition_rate.hardware_rate_hz + msg = "hardware_rate_hz called before acquisition_rate was set" + raise RuntimeError(msg) + + def _allocate_display_buffers(self, max_points: int) -> None: + """Allocate or reallocate GPU position buffers for all channels. + + Parameters + ---------- + max_points : int + Maximum number of points to display. """ - self.hdf5_path = hdf5_path - logger.info(f"Plotter path updated to: {hdf5_path}") - - # Don't check file or read metadata here - wait for update_from_file() - # to be called, which will only read metadata once the file exists with valid data + for ch in [0, 1]: + new_size = max_points * 2 + if ch not in self._pos_buffers or len(self._pos_buffers[ch]) < new_size: + self._pos_buffers[ch] = np.empty((max_points, 2), dtype=np.float32) + + def set_acquisition_rate(self, rate: AcquisitionRate) -> None: + """Set the acquisition rate and reinitialise the pipeline. - def set_display_window(self, window_seconds: float) -> None: - """Sets the temporal display window width. - Parameters ---------- - window_seconds : float - The time duration of data to display in seconds. + rate : AcquisitionRate + The acquisition rate with all rate information. """ - self.display_window_seconds = window_seconds - logger.info(f"Display window set to {window_seconds}s") + self.acquisition_rate = rate + self._initialise_pipeline(rate) + logger.info("Plotter acquisition rate set: {}", rate) + + def _initialise_pipeline(self, acquisition_rate: AcquisitionRate) -> None: + """Initialise or reinitialise the data pipeline with current config.""" + config = PipelineConfig( + resolution=self.resolution, + voltage_ranges=self.voltage_ranges, + offsets_v=self.offsets_v, + max_adc_value=self.max_adc_val, + target_plot_points=self.target_plot_points, + ) + self.pipeline = DataPipeline(config, acquisition_rate) + + def _update_pipeline_sample_rate(self) -> None: + """Update the plotter from ring buffer's acquisition rate.""" + if self.ring_buffer is not None: + rate = self.ring_buffer.get_acquisition_rate() + if rate is not None: + self.set_acquisition_rate(rate) + + def set_ring_buffer(self, ring_buffer: Optional[RingBuffer]) -> None: + """Update the ring buffer reference. + + Called when the stream stops/starts to reconnect the plotter to + a new or restarted buffer. Resets position tracking and clears + the display cache. - def set_y_limits(self, y_min: Optional[float], y_max: Optional[float]) -> None: - """Sets fixed Y-axis limits or enables auto-ranging. - Parameters ---------- - y_min : float, optional - Minimum Y-axis limit in mV. If None, auto-range is used. - y_max : float, optional - Maximum Y-axis limit in mV. If None, auto-range is used. + ring_buffer : Optional[RingBuffer] + The new ring buffer, or None to indicate waiting for stream. """ - self.y_min = y_min - self.y_max = y_max - - has_fixed_limits = y_min is not None and y_max is not None - - if has_fixed_limits: - self.plot_widget.setYRange(y_min, y_max, padding=0) - self.plot_widget.disableAutoRange(axis='y') - logger.info(f"Y-axis limits set to [{y_min}, {y_max}] mV") + self.ring_buffer = ring_buffer + self.last_position = 0 + + self._display_cache.clear() + self._last_read_position = 0 + self._cache_ready = False + self._current_plot_decimation = 1 + self._cached_time_axis = None + self._time_axis_float32 = None + + if ring_buffer is not None: + self._update_pipeline_sample_rate() + + if ring_buffer is None: + logger.debug("LivePlotter: set to waiting state (no buffer)") else: - self.plot_widget.enableAutoRange(axis='y') - logger.info("Y-axis auto-ranging enabled") + logger.debug("LivePlotter: connected to new ring buffer") - def start_updates(self) -> None: - """Starts the plot update timer.""" - self.timer.start(self.update_interval_ms) + if ring_buffer is not None: + self._update_pipeline_sample_rate() - def stop_updates(self) -> None: - """Stops the plot update timer.""" - self.timer.stop() + if ring_buffer is None: + logger.debug("LivePlotter: set to waiting state (no buffer)") + else: + logger.debug("LivePlotter: connected to new ring buffer") + + def clear_curves(self) -> None: + """Clear all plot curves. + + Called when starting new acquisition to clear any previous + data from the display. + """ + for line in self.lines.values(): + line.set_data(pos=np.zeros((2, 2), dtype=np.float32)) + logger.debug("LivePlotter: curves cleared") + + def _update_graticule(self) -> None: + """Update the oscilloscope-style graticule to match current view bounds. + + Draws a fixed 10x10 division graticule constrained to the data region + with the centre crosshair emphasised. + """ + x_min = 0.0 + x_max = self.display_window_seconds + y_min = -5.0 + y_max = 5.0 + + # 11 vertical lines (0 to 10 divisions) + # 11 horizontal lines (0 to 10 divisions) + # Each line is 2 points, stored as segments + num_v_lines = 11 + num_h_lines = 11 + total_segments = num_v_lines + num_h_lines + + graticule_pos = np.zeros((total_segments * 2, 2), dtype=np.float32) + + # Vertical lines (bottom to top) + for i in range(num_v_lines): + x = x_min + (x_max - x_min) * i / 10.0 + base_idx = i * 2 + graticule_pos[base_idx] = [x, y_min] + graticule_pos[base_idx + 1] = [x, y_max] + + # Horizontal lines + for i in range(num_h_lines): + y = y_min + (y_max - y_min) * i / 10.0 + base_idx = (num_v_lines + i) * 2 + graticule_pos[base_idx] = [x_min, y] + graticule_pos[base_idx + 1] = [x_max, y] + + self.graticule.set_data(pos=graticule_pos) + + # Centre crosshair (horizontal and vertical through origin) + # x=0 vertical, y=0 horizontal + centre_pos = np.array( + [ + [x_min, 0.0], + [x_max, 0.0], + [0.0, y_min], + [0.0, y_max], + ], + dtype=np.float32, + ) + self.graticule_center.set_data(pos=centre_pos) + + # Border rectangle + border_pos = np.array( + [ + [x_min, y_min], + [x_max, y_min], + [x_max, y_max], + [x_min, y_max], + [x_min, y_min], + ], + dtype=np.float32, + ) + self.border.set_data(pos=border_pos) + + def set_display_window(self, window_seconds: float) -> None: + """Set the temporal display window width. - def hide_status_bar(self, hide: bool = True) -> None: - """Hide or show the status bar. - Parameters ---------- - hide : bool - If True, hide the status bar. If False, show it. + window_seconds : float + Duration in seconds to display. """ - if hasattr(self, 'status_container'): - self.status_container.setVisible(not hide) + self.display_window_seconds = window_seconds + + # Invalidate cache when display window changes - time axis must be recalculated + self._cache_ready = False + self._cached_time_axis = None + self._time_axis_float32 = None + self._target_time_axis = None + + self.view.camera.set_range( + x=(0, self.display_window_seconds), + y=(-5, 5), + ) + self._update_graticule() + + logger.info("Display window set to {}s", window_seconds) + + def start_updates(self) -> None: + """Start the plot update timer.""" + self.timer.start(self.update_interval_ms) + + def stop_updates(self) -> None: + """Stop the plot update timer.""" + self.timer.stop() def setup_ui(self) -> None: - """Sets up the main window, widgets, and plot layout.""" + """Set up the main window, widgets, and plot layout.""" layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) - # Status bar - self.status_layout = QHBoxLayout() - self.heartbeat_label = QLabel("UI: -") - self.samples_label = QLabel("Samples: 0") - self.rate_label = QLabel("Rate: -") - self.plotter_latency_label = QLabel("Plotter Latency: 0 ms") - self.error_label = QLabel("Errors: 0") - self.saturation_label = QLabel("Saturation: -") - self.acq_status_label = QLabel( - 'Waiting for file...' + self.canvas = scene.SceneCanvas(keys="interactive", show=False) + self.canvas.create_native() + self.canvas.native.setMinimumHeight(200) + + self.view = self.canvas.central_widget.add_view() + self.view.camera = "panzoom" + self.view.camera.set_range( + x=(0, self.display_window_seconds), + y=(-5, 5), ) - font = QFont() - font.setFamily("Monospace") - font.setFixedPitch(True) - for label in [ - self.heartbeat_label, - self.samples_label, - self.rate_label, - self.plotter_latency_label, - self.error_label, - self.saturation_label, - self.acq_status_label, - ]: - label.setFont(font) - - # Add separators between status items - self.status_layout.addWidget(self.heartbeat_label) - self.status_layout.addWidget(QLabel(" | ")) - self.status_layout.addWidget(self.error_label) - self.status_layout.addWidget(QLabel(" | ")) - self.status_layout.addWidget(self.saturation_label) - self.status_layout.addWidget(QLabel(" | ")) - self.status_layout.addWidget(self.samples_label) - self.status_layout.addWidget(QLabel(" | ")) - self.status_layout.addWidget(self.plotter_latency_label) - self.status_layout.addWidget(QLabel(" | ")) - self.status_layout.addWidget(self.rate_label) - self.status_layout.addWidget(QLabel(" | ")) - self.status_layout.addWidget(self.acq_status_label) - self.status_layout.addStretch() - self.status_container = QWidget() - self.status_container.setLayout(self.status_layout) - layout.addWidget(self.status_container) - - # Plot widget - self.plot_widget = pg.PlotWidget() - self.plot_widget.setLabel("left", "Voltage", "mV") - self.plot_widget.setLabel("bottom", "Time", "s") - self.plot_widget.showGrid(x=True, y=True) - self.plot_widget.setXRange(0, self.display_window_seconds, padding=0) - - # Apply Y-axis limits if provided - has_fixed_y_limits = self.y_min is not None and self.y_max is not None - if has_fixed_y_limits: - self.plot_widget.setYRange(self.y_min, self.y_max, padding=0) - self.plot_widget.disableAutoRange(axis='y') - - # Plot curve - self.curve = self.plot_widget.plot(pen="y", width=1) - - layout.addWidget(self.plot_widget) - - # Performance optimization - self.plot_widget.setDownsampling(mode="peak") - self.plot_widget.setClipToView(True) - - def check_file_exists(self) -> None: - """Checks if the HDF5 file exists and attempts to read metadata. - - Only reads metadata if it hasn't been read yet for this file. - """ - if self.metadata_read: - return - - try: - with h5py.File(self.hdf5_path, "r") as f: - if "adc_counts" in f: - self.acq_status_label.setText( - 'Reading metadata...' - ) - self.read_metadata(f) - else: - self.acq_status_label.setText( - 'Waiting for data...' - ) - except (FileNotFoundError, OSError): - self.acq_status_label.setText( - 'Waiting for file...' + + # Graticule rendered first (order=0) so it appears behind traces + self.graticule = scene.Line( + pos=np.zeros((0, 2), dtype=np.float32), + color=(0.25, 0.25, 0.25, 1.0), + parent=self.view.scene, + connect="segments", + width=1, + ) + self.graticule.order = 0 + self.graticule_center = scene.Line( + pos=np.zeros((0, 2), dtype=np.float32), + color=(0.4, 0.4, 0.4, 1.0), + parent=self.view.scene, + connect="segments", + width=1.5, + ) + self.graticule_center.order = 0 + self.border = scene.Line( + pos=np.zeros((5, 2), dtype=np.float32), + color=(0.7, 0.7, 0.7, 1.0), + parent=self.view.scene, + width=1.5, + ) + self.border.order = 0 + self._update_graticule() + + # Signal traces rendered after (order=1) so they appear on top + self.lines: Dict[int, Any] = {} + for ch in [0, 1]: + self.lines[ch] = scene.Line( + pos=np.zeros((2, 2), dtype=np.float32), + color=self._channel_colors[ch], + parent=self.view.scene, + width=1, ) + self.lines[ch].order = 1 + self.lines[ch].visible = ch in self.channels + + self.canvas.native.setSizePolicy( + self.canvas.native.sizePolicy().Policy.Expanding, + self.canvas.native.sizePolicy().Policy.Expanding, + ) + self.canvas.native.setFocusPolicy(Qt.FocusPolicy.NoFocus) + + layout.addWidget(self.canvas.native, stretch=1) - def read_metadata(self, hdf5_file: h5py.File) -> None: - """Reads metadata attributes from the root of an open HDF5 file. + def set_v_div_settings( + self, v_div_volts: Dict[int, float], y_pos_divs: Dict[int, float] + ) -> None: + """Set V/div and Y-position settings for oscilloscope-style scaling. Parameters ---------- - hdf5_file : h5py.File - An open h5py.File object. + v_div_volts : Dict[int, float] + Mapping of channel index to volts per division. + y_pos_divs : Dict[int, float] + Mapping of channel index to Y-position in divisions from centre. """ - if self.metadata_read: - logger.debug("Metadata already read, skipping") - return - - try: - # Metadata is stored as root-level attributes - base_sample_interval_ns = hdf5_file.attrs["sample_interval_ns"] - self.hardware_downsample_ratio = hdf5_file.attrs.get( - "hardware_downsample_ratio", 1 - ) - self.sample_interval_ns = ( - base_sample_interval_ns * self.hardware_downsample_ratio + old_v_div = self.v_div_volts + old_y_pos = self.y_pos_divs + self.v_div_volts = v_div_volts.copy() + self.y_pos_divs = y_pos_divs.copy() + + if old_v_div != self.v_div_volts or old_y_pos != self.y_pos_divs: + self._cache_dirty = True + + logger.debug("V/div settings: {}, Y-pos: {}", v_div_volts, y_pos_divs) + + def is_channel_clipped(self, channel: int) -> int: + """Check if a channel's data exceeds the current display range. + + This checks for display clipping, where signal data goes beyond the + visible Y-axis range set by V/div and Y-pos settings. It does NOT + indicate hardware ADC saturation. + + Returns + ------- + int + 1 if clipping above display range, -1 if clipping below range, 0 if OK. + """ + if channel not in self._display_cache: + return 0 + + divisions = self._display_cache[channel] + if len(divisions) == 0: + return 0 + + v_div = self.v_div_volts.get(channel, 1.0) + y_pos = self.y_pos_divs.get(channel, 0.0) + + display_range_mv = v_div * 10 * 1000 + centre_mv = y_pos * v_div * 1000 + min_mv = centre_mv - (display_range_mv / 2) + max_mv = centre_mv + (display_range_mv / 2) + + sample_size = min(10000, len(divisions)) + sample_divisions = divisions[:sample_size] + voltage_volts = (sample_divisions - y_pos) * v_div + voltage_mv = voltage_volts * 1000 + + if np.any(voltage_mv > max_mv): + return 1 + if np.any(voltage_mv < min_mv): + return -1 + return 0 + + def set_v_div_data(self, channel: int, v_div_volts: float) -> None: + """Store V/div for a channel to be used in data conversion. + + Parameters + ---------- + channel : int + Channel index. + v_div_volts : float + Volts per division for this channel. + """ + self.v_div_volts[channel] = v_div_volts + + def _read_and_decimate_new_data(self) -> bool: + """Read new data since last frame and update display cache. + + This method implements incremental reading from the ring buffer, + using the DataPipeline for processing. + + Returns + ------- + bool + True if new data was processed and cache updated. + False if no new data available. + + Raises + ------ + RuntimeError + If ring buffer position has been overwritten (overrun). + """ + if self.ring_buffer is None or self.pipeline is None: + return False + + snapshot = self.ring_buffer.get_snapshot() + + self._notify_device_of_read_position(snapshot) + + window_samples = self.pipeline.time_to_samples(self.display_window_seconds) + self._current_plot_decimation = self.pipeline.calculate_decimation( + window_samples + ) + + if self.update_count % 20 == 0: + logger.debug( + f"Buffer: snapshot={snapshot}, write_idx={self.ring_buffer.write_idx}, " + f"capacity={self.ring_buffer.capacity}, " + f"last_pos={self._last_read_position}, " + f"cache_ready={self._cache_ready}, " + f"adaptive_decimation={self._current_plot_decimation}" ) - self.voltage_range_v = hdf5_file.attrs["voltage_range_v"] - self.max_adc_val = hdf5_file.attrs["max_adc"] - self.downsample_mode = hdf5_file.attrs.get("downsample_mode", "average") - self.analog_offset_v = hdf5_file.attrs.get("analog_offset_v", 0.0) - - self.metadata_read = True - - logger.info( - f"Plotter read metadata: voltage_range_v={self.voltage_range_v}V, " - f"max_adc={self.max_adc_val}, downsample_mode={self.downsample_mode}, " - f"analog_offset_v={self.analog_offset_v}V" + if self.pipeline.should_invalidate_cache(self._current_plot_decimation): + self.pipeline.invalidate_cache() + self._cache_ready = False + self._cached_time_axis = None + self._time_axis_float32 = None + self._display_cache.clear() + logger.debug( + f"Decimation factor changed, cache invalidated. " + f"New decimation: {self._current_plot_decimation}" ) - # Update rate label with configured sample rate - configured_rate_sps = 1e9 / self.sample_interval_ns - self.rate_label.setText( - f"Rate: ... : {self._format_rate_sps(configured_rate_sps)}" + if not self._cache_ready: + if window_samples <= 0: + return False + + data = self.ring_buffer.read_last(window_samples) + if len(data) == 0: + return False + + self._initialise_cache_from_data(data, self._current_plot_decimation) + self._last_read_position = snapshot + self._cache_ready = True + return True + + if not self.ring_buffer.is_valid_range(self._last_read_position): + raise RuntimeError( + f"Plotter overrun: last read position {self._last_read_position} " + f"overwritten (buffer capacity {self.ring_buffer.capacity} samples). " + "The plotter cannot keep up with the data rate." ) - except KeyError as e: - logger.debug(f"Metadata not fully available yet: {e}. Will retry.") - - def _update_heartbeat(self) -> None: - """Update UI heartbeat to show the UI thread is alive.""" - self.heartbeat_index = (self.heartbeat_index + 1) % len(self.heartbeat_chars) - self.heartbeat_label.setText( - f"UI: {self.heartbeat_chars[self.heartbeat_index]}" - ) - def _handle_new_data( - self, dataset: h5py.Dataset, start_index: int, current_size: int - ) -> None: - """Process a new window of data.""" - self.data_change_count += 1 - self.last_displayed_size = current_size - self.last_data_timestamp = time.time() - self.acq_status_label.setText('Active') - - # Read only the most recent data window - data_window = dataset[start_index:current_size] - self.file_read_count += 1 - - # Check for ADC saturation - # The Picoscope driver scales data to 16-bit, so saturation occurs at the 16-bit limit. + new_data = self.ring_buffer.read_since(self._last_read_position) + if len(new_data) == 0: + return False + FIXED_MAX_ADC = 32767 - self.is_saturated = np.any(data_window >= FIXED_MAX_ADC) or np.any( - data_window <= -FIXED_MAX_ADC - ) + if np.any(np.abs(new_data) >= FIXED_MAX_ADC): + self.is_saturated = True - logger.debug( - f"Update {self.update_count}: Reading window of {len(data_window):,} samples from index {start_index:,}" - ) + self.pipeline._last_decimation = self._current_plot_decimation - # Update the display with this complete window (only when data changes) - self.update_display(data_window) - - def _handle_buffer_reset(self, new_size: int) -> None: - """Handle detection of a buffer reset (file size decreased).""" - self.buffer_reset_count += 1 - self.total_samples_processed += self.last_file_size - logger.debug(f"Buffer reset detected: size {self.last_file_size:,} → {new_size:,} (reset #{self.buffer_reset_count})") - - # Reset rate calculation to avoid confusion - self.rate_check_start_time = None - self.rate_check_start_samples = 0 - self.last_displayed_size = 0 - - def _handle_stale_data(self) -> None: - """Handle a file check where no new data is found.""" - self.stale_update_count += 1 - self.acq_status_label.setText( - 'Acquiring... ' + geometry = self.pipeline.get_window_display_geometry( + self.display_window_seconds, + self._current_plot_decimation, ) - # Log if we're frequently updating with no new data - if self.stale_update_count % 10 == 0: - logger.debug( - f"File check #{self.update_count} with no new data (stale checks: {self.stale_update_count})" - ) + max_cache_points = geometry.max_display_values - def _update_status_labels(self, current_size: int) -> None: - """Update the various status labels in the UI.""" - # Update samples label (include total if buffer resets occurred) - samples_text = self.format_sample_count(current_size) - if self.buffer_reset_count > 0: - total_processed = self.total_samples_processed + current_size - total_text = self.format_sample_count(total_processed) - self.samples_label.setText(f"Samples: {samples_text} (total: {total_text})") - else: - self.samples_label.setText(f"Samples: {samples_text}") - - # Color-code latency - latency_color = ( - "green" - if self.display_latency_ms < 100 - else "orange" - if self.display_latency_ms < 500 - else "red" - ) - self.plotter_latency_label.setText( - f'Plotter Latency: {self.display_latency_ms:.0f}ms' - ) + for ch in self.channels: + if ch >= new_data.shape[1]: + continue - # Error counter - total_errors = self.conversion_error_count + self.file_error_count - error_color = ( - "green" if total_errors == 0 else "orange" if total_errors < 10 else "red" - ) - self.error_label.setText( - f'Errors: {total_errors}' - ) + ch_samples = new_data[:, ch] - # Saturation status - if self.voltage_range_v is None: - self.saturation_label.setText("Saturation: -") - elif self.is_saturated: - self.saturation_label.setText( - 'Saturation: CLIPPING' - ) - else: - self.saturation_label.setText( - 'Saturation: OK' + voltage_data, has_data = self.pipeline.process_channel_data( + ch_samples, ch, self._current_plot_decimation ) - def _update_rate_label(self, current_size: int) -> None: - """Check and update acquisition rate status.""" - if not self.rate_check_start_time: - return + if not has_data: + continue - elapsed_time = time.perf_counter() - self.rate_check_start_time - if elapsed_time > 1.0: # Check only after 1s for stability - points_per_timestep = 2 if self.downsample_mode == "aggregate" else 1 - samples_acquired = current_size - self.rate_check_start_samples - timesteps_acquired = samples_acquired / points_per_timestep - actual_rate_sps = timesteps_acquired / elapsed_time + v_div = self.v_div_volts.get(ch, 0.1) + y_pos = self.y_pos_divs.get(ch, 0.0) - configured_rate_sps = 1e9 / self.sample_interval_ns - rate_ratio = actual_rate_sps / configured_rate_sps + voltage_volts = voltage_data / 1000.0 + divisions = y_pos + voltage_volts / v_div - configured_rate_str = self._format_rate_sps(configured_rate_sps) - actual_rate_str = self._format_rate_sps(actual_rate_sps) + if ch in self._display_cache: + combined = np.concatenate([self._display_cache[ch], divisions]) + else: + combined = divisions - rate_text = f"Rate: {actual_rate_str} : {configured_rate_str}" - if rate_ratio < 0.95: - self.rate_label.setText(f'{rate_text}') + if len(combined) > max_cache_points: + self._display_cache[ch] = combined[-max_cache_points:] else: - self.rate_label.setText(rate_text) + self._display_cache[ch] = combined - def _process_data_from_file(self, f: h5py.File) -> None: - """Read and process data from an open HDF5 file.""" - if "adc_counts" not in f: - return + self._last_read_position = snapshot - dataset = f["adc_counts"] - current_size = dataset.shape[0] + return True - if current_size == 0: - return + def _notify_device_of_read_position(self, read_idx: int) -> None: + """Notify the device of our last read position for overflow detection. - # Read metadata if not already done - if not self.metadata_read: - self.read_metadata(f) - # If metadata still not available, wait for next update - if not self.metadata_read: - return + Uses the callback registered at construction, avoiding implicit + coupling through ring buffer attributes. - # Start the rate check timer on the first data point - if self.rate_check_start_time is None: - self.rate_check_start_time = time.perf_counter() - self.rate_check_start_samples = current_size + Parameters + ---------- + read_idx : int + The sample index we last read up to. + """ + if self._on_read_position is not None: + try: + self._on_read_position(read_idx) + except Exception: + logger.debug("Failed to notify device of read position") - # Dynamically calculate the number of timesteps for the display window - display_window_timesteps = int( - self.display_window_seconds / (self.sample_interval_ns * 1e-9) - ) + def _initialise_cache_from_data( + self, data: np.ndarray, decimation_factor: int + ) -> None: + """Initialise display cache from full window of data.""" + self._display_cache.clear() + self.is_saturated = False - # In aggregate mode, each timestep has two points (min/max) - display_window_points = display_window_timesteps - if self.downsample_mode == "aggregate": - display_window_points *= 2 - - # Calculate where to start reading to get the last window of points - start_index = max(0, current_size - display_window_points) - self.data_start_sample = start_index - - # Detect buffer resets (file size decreased) - if current_size < self.last_file_size: - self._handle_buffer_reset(current_size) - - # Track data freshness and changes - if current_size > self.last_displayed_size: - self._handle_new_data(dataset, start_index, current_size) - else: - self._handle_stale_data() - - self.last_file_size = current_size + if self.pipeline is not None: + self.pipeline._remainder.clear() + self.pipeline._last_decimation = decimation_factor + + FIXED_MAX_ADC = 32767 + if np.any(np.abs(data) >= FIXED_MAX_ADC): + self.is_saturated = True + + for ch in self.channels: + if ch >= data.shape[1]: + continue + + ch_data = data[:, ch] + + if self.pipeline is not None: + voltage_data, has_data = self.pipeline.process_channel_data( + ch_data, ch, decimation_factor + ) + else: + has_data = False + voltage_data = np.array([], dtype=float) + + if not has_data: + self._display_cache[ch] = np.array([], dtype=float) + continue - self._update_status_labels(current_size) - self._update_rate_label(current_size) + v_div = self.v_div_volts.get(ch, 0.1) + y_pos = self.y_pos_divs.get(ch, 0.0) - def update_from_file(self) -> None: - """Timer-driven function to read data from the HDF5 file and update the plot.""" + voltage_volts = voltage_data / 1000.0 + divisions = y_pos + voltage_volts / v_div + + self._display_cache[ch] = divisions + + self._cache_ready = True + self._cache_dirty = False + + def update_from_buffer(self) -> None: + """Timer-driven function to read from ring buffer and update the plot.""" self.update_count += 1 - self._update_heartbeat() + + if self.ring_buffer is None or self.pipeline is None: + return try: - with h5py.File(self.hdf5_path, "r") as f: - self._process_data_from_file(f) - except (FileNotFoundError, OSError): - self.acq_status_label.setText( - 'Waiting for file...' + if self.acquisition_rate is None: + return + window_samples = self.pipeline.time_to_samples(self.display_window_seconds) + self._current_plot_decimation = self.pipeline.calculate_decimation( + window_samples ) - except Exception as e: - self.file_error_count += 1 - logger.error(f"Update {self.update_count}: Error reading file - {e}") - self.acq_status_label.setText('File error!') - def update_display(self, data_window: np.ndarray) -> None: - """Processes and displays a new window of data. + if self._cache_dirty: + logger.debug("Cache dirty, rebuilding positions") + cache_ready_before = self._cache_ready + self._cache_ready = False + snapshot = self.ring_buffer.get_snapshot() + self._notify_device_of_read_position(snapshot) + + window_samples = self.pipeline.time_to_samples( + self.display_window_seconds + ) + + data = self.ring_buffer.read_last(window_samples) + if len(data) > 0: + self._initialise_cache_from_data( + data, self._current_plot_decimation + ) + self._last_read_position = snapshot + self._cache_ready = cache_ready_before + + has_new_data = self._read_and_decimate_new_data() + if not has_new_data: + return + + self.last_data_timestamp = time.time() - This involves decimation, voltage conversion, and updating the plot curve. + if self.rate_check_start_time is None: + self.rate_check_start_time = time.perf_counter() + self.rate_check_start_samples = self.ring_buffer.write_idx + + if not self._display_cache: + return + + geometry = self.pipeline.get_window_display_geometry( + self.display_window_seconds, + self._current_plot_decimation, + ) + time_axis_float32 = geometry.time_axis.astype(np.float32) + n_time_points = len(time_axis_float32) + + if n_time_points == 0: + return + + for ch in self.channels: + if ch not in self._display_cache: + continue + + div_data = self._display_cache[ch] + n_data_points = len(div_data) + + if n_data_points == 0: + continue + + n_points_to_plot = min(n_data_points, n_time_points) + if n_points_to_plot <= 0: + continue + + if n_data_points != n_time_points and self.update_count % 100 == 0: + logger.debug( + "Data/time axis mismatch: channel {} has {} data points, " + "time axis has {} points (geometry.max_display_values={}), plotting {}", + ch, + n_data_points, + n_time_points, + geometry.max_display_values, + n_points_to_plot, + ) + + pos_buffer = self._pos_buffers.get(ch) + if pos_buffer is None or len(pos_buffer) < n_time_points: + self._allocate_display_buffers(max(n_time_points * 2, 20000)) + pos_buffer = self._pos_buffers[ch] + + pos_buffer[:n_time_points, :] = np.nan + start_idx = n_time_points - n_points_to_plot + pos_buffer[start_idx:n_time_points, 0] = time_axis_float32[ + start_idx:n_time_points + ] + pos_buffer[start_idx:n_time_points, 1] = div_data[-n_points_to_plot:] + + self.lines[ch].set_data(pos=pos_buffer[:n_time_points]) + + for line_ch in self.lines: + if line_ch not in self.channels: + self.lines[line_ch].set_data(pos=np.zeros((2, 2), dtype=np.float32)) + + except RuntimeError as e: + logger.exception("Plotter overrun detected: {}", e) + raise + + except Exception as e: + logger.exception("Error updating from buffer: {}", e) + + def update_display( + self, channel_data: Dict[int, np.ndarray], adaptive_decimation: int = 1 + ) -> None: + """Process and display a new window of data for all channels. Parameters ---------- - data_window : np.ndarray - A NumPy array containing the raw ADC counts for display. + channel_data : Dict[int, np.ndarray] + Mapping of channel index to raw ADC data array. + adaptive_decimation : int + Adaptive decimation factor calculated based on window size. """ - if len(data_window) == 0: + if not channel_data or self.pipeline is None: return - self.display_data = data_window self.display_update_count += 1 - logger.debug( - f"Display update {self.display_update_count}: " - f"Displaying window of {len(self.display_data):,} samples, " - f"starting at sample {self.data_start_sample:,}" + geometry = self.pipeline.get_window_display_geometry( + self.display_window_seconds, + adaptive_decimation, ) + time_axis_float32 = geometry.time_axis.astype(np.float32) - # Apply Numba-optimized decimation for display - if self.decimation_factor > 1: - decimated_data = min_max_decimate_numba( - self.display_data, self.decimation_factor + for ch in self.channels: + raw_data = channel_data.get(ch) + if raw_data is None or len(raw_data) == 0: + continue + + voltage_data, has_data = self.pipeline.process_channel_data( + raw_data, ch, adaptive_decimation ) - else: - decimated_data = self.display_data - # Debug: Log ADC values and metadata - logger.debug(f"ADC range: {decimated_data.min()} to {decimated_data.max()}") - logger.debug(f"voltage_range_v: {self.voltage_range_v}") + if not has_data: + continue - # Convert to voltage if we have calibration data - if self.voltage_range_v is not None and self.max_adc_val is not None: - try: - voltage_data = adc_to_mV( - decimated_data, self.voltage_range_v, self.max_adc_val - ) - if self.analog_offset_v != 0.0: - voltage_data += self.analog_offset_v * 1000 - logger.debug( - f"Voltage conversion successful, range: {voltage_data.min():.1f} to {voltage_data.max():.1f} mV" - ) - except Exception as e: - self.conversion_error_count += 1 - logger.warning(f"Voltage conversion failed: {e}, using raw ADC values") - voltage_data = decimated_data.astype(float) - else: - logger.warning( - "Missing calibration data (voltage_range_v or max_adc_val), using raw ADC values" - ) - voltage_data = decimated_data.astype(float) + div_data = self._convert_voltage_to_divisions(voltage_data, ch) - # Create time axis for the current window - time_axis = self.create_time_axis(len(voltage_data)) + n_points = len(time_axis_float32) - logger.debug( - f"Display update {self.display_update_count}: " - f"Decimated to {len(voltage_data):,} points, " - f"time range: {time_axis[0]:.3f}s to {time_axis[-1]:.3f}s" - ) + pos_buffer = self._pos_buffers.get(ch) + if pos_buffer is None or len(pos_buffer) < n_points: + self._allocate_display_buffers(max(n_points * 2, 20000)) + pos_buffer = self._pos_buffers[ch] + + pos_buffer[:n_points, 0] = time_axis_float32[:n_points] + pos_buffer[:n_points, 1] = div_data[:n_points] + self.lines[ch].set_data(pos=pos_buffer[:n_points]) - # Calculate display latency if self.last_data_timestamp: self.display_latency_ms = (time.time() - self.last_data_timestamp) * 1000 - # Update plot - self.curve.setData(time_axis, voltage_data) + if len(geometry.time_axis) > 0: + self.view.camera.set_range( + x=(0, self.display_window_seconds), + y=(-5, 5), + ) - # Update the X-axis range to match the new time axis, creating a "snapshot" effect. - self.plot_widget.setXRange(time_axis[0], time_axis[-1], padding=0) + def _convert_voltage_to_divisions( + self, voltage_data: np.ndarray, channel: int + ) -> np.ndarray: + """Convert voltage data to oscilloscope divisions. - # Auto-scale the Y-axis occasionally, but only if fixed limits are not set - has_fixed_y_limits = self.y_min is not None and self.y_max is not None - if not has_fixed_y_limits and self.display_update_count % 5 == 1: - self.plot_widget.enableAutoRange(axis="y") + Parameters + ---------- + voltage_data : np.ndarray + Voltage data in millivolts. + channel : int + Channel index. + + Returns + ------- + np.ndarray + Data in oscilloscope divisions. + """ + v_div = self.v_div_volts.get(channel, 0.1) + y_pos = self.y_pos_divs.get(channel, 0.0) + + voltage_volts = voltage_data / 1000.0 + return y_pos + voltage_volts / v_div def _format_rate_sps(self, rate_sps: float) -> str: - """Formats a sample rate in Samples/sec into a human-readable string.""" + """Format a sample rate in Samples/sec into a human-readable string. + + Parameters + ---------- + rate_sps : float + Sample rate in samples per second. + + Returns + ------- + str + Formatted rate string (e.g., "62.50 MS/s"). + """ if rate_sps >= 1e9: return f"{rate_sps / 1e9:.2f} GS/s" if rate_sps >= 1e6: @@ -664,17 +887,17 @@ class HDF5LivePlotter(QWidget): return f"{rate_sps:.2f} S/s" def format_sample_count(self, count: int) -> str: - """Formats a large integer count into a human-readable string with units. + """Format a large integer count into a human-readable string with units. Parameters ---------- count : int - The integer number to format. + Sample count. Returns ------- str - A formatted string (e.g., "1.23M", "2.34G"). + Formatted count (e.g., "1.23M"). """ if count >= 1_000_000_000: return f"{count / 1_000_000_000:.2f}G" @@ -682,125 +905,58 @@ class HDF5LivePlotter(QWidget): return f"{count / 1_000_000:.2f}M" if count >= 1_000: return f"{count / 1_000:.2f}K" - else: - return str(count) - - def create_time_axis(self, n_samples: int) -> np.ndarray: - """Creates a time axis for the displayed data window. + return str(count) - The time axis is absolute, based on the data window's start position in - the overall acquisition. It accounts for the `aggregate` downsample mode, - where the data stream consists of interleaved min/max pairs. - - For min-max decimated data, it generates pairs of time coordinates to - draw vertical lines for each min-max pair. For non-decimated data, it - generates a linearly spaced time axis. + def set_channel_visible(self, channel_index: int, visible: bool) -> None: + """Toggle visibility of a channel's plot curve and update processing. Parameters ---------- - n_samples : int - The number of points for the time axis. This should be - the number of points *after* decimation. - - Returns - ------- - np.ndarray - A NumPy array representing the time axis in seconds. + channel_index : int + Channel index (0 for A, 1 for B). + visible : bool + True to show, False to hide. """ - if n_samples == 0: - return np.array([]) - - time_per_timestep = self.sample_interval_ns * 1e-9 - points_per_timestep = 2 if self.downsample_mode == "aggregate" else 1 - time_per_point = time_per_timestep / points_per_timestep - - start_time = (self.data_start_sample / points_per_timestep) * time_per_timestep - - if self.decimation_factor > 1: - # For min-max, create pairs of time points for vertical lines - num_pairs = n_samples // 2 - time_step_between_groups = self.decimation_factor * time_per_point - group_times = start_time + np.arange(num_pairs) * time_step_between_groups - return np.repeat(group_times, 2) - else: - # For non-decimated data - if self.downsample_mode == "aggregate": - # Data is already min/max pairs from hardware. Create vertical lines. - num_pairs = n_samples // 2 - time_step_between_pairs = time_per_timestep - pair_times = start_time + np.arange(num_pairs) * time_step_between_pairs - return np.repeat(pair_times, 2) - else: - # For linear (non-aggregate) data, create a simple time axis - duration = ( - (n_samples - 1) * time_per_point if n_samples > 1 else 0 - ) - end_time = start_time + duration - return np.linspace(start_time, end_time, n_samples) - - def save_screenshot(self) -> None: - """Save a screenshot of the current plot.""" - try: - import pyqtgraph.exporters - - timestamp = time.strftime("%Y%m%d_%H%M%S") - filename = f"plot_screenshot_{timestamp}.png" - - # Export the plot widget as an image - exporter = pg.exporters.ImageExporter(self.plot_widget.plotItem) - exporter.export(filename) - - logger.info(f"Screenshot saved: {filename}") + if channel_index not in self.lines: + return - # Temporarily stop updates to show the message for 2 seconds - self.timer.stop() - self.acq_status_label.setText( - f'Screenshot saved: {filename}' + self.lines[channel_index].visible = visible + + if visible and channel_index not in self.channels: + self.channels.append(channel_index) + self._cache_ready = False + logger.debug("Channel {} enabled and added to channels list", channel_index) + elif not visible and channel_index in self.channels: + self.channels.remove(channel_index) + if channel_index in self._display_cache: + del self._display_cache[channel_index] + self.lines[channel_index].set_data(pos=np.zeros((2, 2), dtype=np.float32)) + logger.debug( + "Channel {} disabled, removed from processing, and cache cleared", + channel_index, ) - def resume_updates(): - self.update_from_file() # Update once immediately - self.timer.start(self.update_interval_ms) - - QTimer.singleShot(2000, resume_updates) - - except Exception as e: - logger.error(f"Failed to save screenshot: {e}") - # If we failed, restart timer if it was already running - if not self.timer.isActive(): - self.timer.start(self.update_interval_ms) - - def closeEvent(self, event: QCloseEvent) -> None: - """Handles the window close event. - - Stops the plot's internal timer and allows the Qt event loop to exit. - The main application will handle the graceful shutdown. - """ + def closeEvent(self, event: QCloseEvent) -> None: # pyright: ignore + """Handle the window close event.""" logger.info("Close event received. Stopping timer.") self.timer.stop() event.accept() - def keyPressEvent(self, event: QKeyEvent) -> None: - """Handles key presses for application control (e.g., 'Q' to quit).""" - if event.key() in (Qt.Key_S, Qt.Key_Space, Qt.Key_F12): - logger.info("Screenshot key pressed. Saving screenshot.") - self.save_screenshot() - else: - super().keyPressEvent(event) + def keyPressEvent(self, event: QKeyEvent) -> None: # pyright: ignore + """Handle key presses for application control.""" + super().keyPressEvent(event) def main() -> None: - """Standalone HDF5 live plotter.""" - # Create the QApplication instance FIRST. + """Standalone live plotter for testing.""" app = QApplication(sys.argv) - parser = argparse.ArgumentParser(description="Standalone HDF5 live plotter.") - parser.add_argument("hdf5_path", type=str, help="Path to the HDF5 file.") + parser = argparse.ArgumentParser(description="Standalone live plotter.") parser.add_argument( "--window", type=float, - default=0.5, - help="Display window in seconds. [default: 0.5]", + default=5.0, + help="Display window in seconds. [default: 5.0]", ) parser.add_argument( "--decimation", @@ -808,31 +964,31 @@ def main() -> None: default=150, help="Decimation factor for plotting. [default: 150]", ) - parser.add_argument( - "--y-min", - type=float, - help="Minimum Y-axis limit in mV.", - ) - parser.add_argument( - "--y-max", - type=float, - help="Maximum Y-axis limit in mV.", - ) args = parser.parse_args() logger.info("Plotter process starting") + + channels = [0, 1] + voltage_ranges = {0: 20.0, 1: 20.0} + offsets_v = {0: 0.0, 1: 0.0} + resolution = 16 + downsample_mode = "NONE" + try: - plotter = HDF5LivePlotter( - hdf5_path=args.hdf5_path, + plotter = LivePlotter( + ring_buffer=None, + channels=channels, + voltage_ranges=voltage_ranges, + offsets_v=offsets_v, + resolution=resolution, + downsample_mode=downsample_mode, display_window_seconds=args.window, decimation_factor=args.decimation, - y_min=args.y_min, - y_max=args.y_max, ) plotter.show() - sys.exit(app.exec_()) + sys.exit(app.exec()) except Exception as e: - logger.error(f"Error in plotter process: {e}") + logger.exception("Error in plotter process: {}", e) if __name__ == "__main__": diff --git a/picostream/main.py b/picostream/main.py index e7eb709..c6052ba 100644 --- a/picostream/main.py +++ b/picostream/main.py @@ -1,525 +1,2374 @@ +"""GUI for PicoStream application using PicoscopeBufferedStream device.""" + import os -import re import sys -from typing import Any, Dict, Optional +import threading +import time +from datetime import datetime +from pathlib import Path +from typing import Callable, Optional +import labdaemon as ld from loguru import logger -from PyQt5.QtCore import QObject, QSettings, QThread, QTimer, pyqtSignal, pyqtSlot, Qt -from PyQt5.QtGui import QCloseEvent, QFont -from PyQt5.QtWidgets import ( +from PyQt6.QtCore import QSettings, Qt, QTimer +from PyQt6.QtGui import QCloseEvent, QFont, QFontMetrics, QIcon, QKeyEvent +from PyQt6.QtWidgets import ( QApplication, - QCheckBox, QComboBox, QDoubleSpinBox, QFileDialog, QFormLayout, + QFrame, + QGridLayout, QGroupBox, QHBoxLayout, QLabel, - QLCDNumber, QLineEdit, QMainWindow, + QMessageBox, QPushButton, - QSpinBox, QVBoxLayout, QWidget, ) -from picostream.cli import Streamer, VOLTAGE_RANGE_MAP -from picostream.dfplot import HDF5LivePlotter +from picostream.device import PicoscopeBufferedStream +from picostream.dfplot import LivePlotter +from picostream.mock_device import MockPicoscopeBufferedStream +from picostream.zarr_viewer import ZarrViewerWindow -class StreamerWorker(QObject): - """Worker to run the data acquisition in a background thread.""" +class LosslessWarningDialog(QMessageBox): + """Custom dialog for Lossless mode warning with override option.""" - finished = pyqtSignal() - error = pyqtSignal(str) - stopRequested = pyqtSignal() + def __init__( + self, + required_points: int, + max_points: int, + refresh_rate: str, + parent: Optional[QWidget] = None, + ) -> None: + super().__init__(parent) + self.setIcon(QMessageBox.Icon.Warning) + self.setWindowTitle("Lossless Mode Warning") - def __init__(self, settings: Dict[str, Any]) -> None: - """Initialise the worker.""" - super().__init__() - self.streamer: Optional[Streamer] = None - self.settings = settings + self.setText( + f"Lossless mode requires {required_points:,} points, which exceeds " + f"the recommended limit of {max_points:,} points at {refresh_rate}.\n\n" + "Displaying this many points may cause performance issues." + ) - def run(self) -> None: - """Run the data acquisition.""" - try: - self.streamer = Streamer(**self.settings) - self.streamer.run() - except Exception as e: - self.error.emit(str(e)) - self.finished.emit() + self.setInformativeText( + "You can:\n" + "• Use High Quality (~500k points) for smooth performance\n" + "• Force Lossless to proceed anyway\n" + "• Cancel to return to settings" + ) + + self._use_high_button = self.addButton( + "Use High Quality", QMessageBox.ButtonRole.AcceptRole + ) + self._force_lossless_button = self.addButton( + "Force Lossless", QMessageBox.ButtonRole.ApplyRole + ) + self._cancel_button = self.addButton( + "Cancel", QMessageBox.ButtonRole.RejectRole + ) + + self.setDefaultButton(self._use_high_button) + + def user_choice(self) -> str: + """Return the user's choice as a string. + + Returns + ------- + str + One of: "high", "lossless", "cancel" + """ + clicked = self.clickedButton() + if clicked == self._use_high_button: + return "high" + if clicked == self._force_lossless_button: + return "lossless" + return "cancel" + + +QUALITY_POINTS = { + "Low": 25_000, + "Medium": 100_000, + "High": 500_000, + "Lossless": None, # None indicates no decimation (show every sample) +} - @pyqtSlot() - def stop(self) -> None: - """Signal the acquisition to stop.""" - if self.streamer: - self.streamer.shutdown() - self.finished.emit() +# Refresh rate options: (interval_ms, max_lossless_points) +# Limits for VisPy OpenGL rendering - significantly higher than pyqtgraph +# Targets: 20Hz@2M, 10Hz@5M, 5Hz@10M, 2Hz@20M, 1Hz@50M points +REFRESH_RATE_MAP = { + "20 Hz": (50, 2_000_000), # 2M points at 20Hz + "10 Hz": (100, 5_000_000), # 5M points at 10Hz + "5 Hz": (200, 10_000_000), # 10M points at 5Hz + "2 Hz": (500, 20_000_000), # 20M points at 2Hz + "1 Hz": (1000, 50_000_000), # 50M points at 1Hz +} class PicoStreamMainWindow(QMainWindow): """The main window for the PicoStream GUI application.""" def __init__(self) -> None: - """Initialise the main window.""" super().__init__() self.setWindowTitle("PicoStream") - self.setGeometry(100, 100, 1200, 600) + self.setGeometry(100, 100, 1400, 700) - self.settings = QSettings("picostream", "PicoStream") - self.thread: Optional[QThread] = None - self.worker: Optional[StreamerWorker] = None + icon_path = Path(__file__).parent.parent.parent / "assets" / "icons" / "app.ico" + if icon_path.exists(): + self.setWindowIcon(QIcon(str(icon_path))) + + base_font = QApplication.font() + base_font.setPointSize(base_font.pointSize() - 1) + QApplication.setFont(base_font) + + self.settings = QSettings("qolcode", "PicoStream") + self._zarr_viewers: list[ZarrViewerWindow] = [] + self._is_recording: bool = False + self._devices_connected: bool = False + + self._daemon: Optional[ld.LabDaemon] = None + self._device: Optional[PicoscopeBufferedStream] = None + self._stream_check_timer: Optional[QTimer] = None + self._save_poll_timer: Optional[QTimer] = None + self._save_start_time: Optional[float] = None + self._save_timeout_seconds: float = 15.0 + self._pending_save_path: Optional[str] = None central_widget = QWidget() self.setCentralWidget(central_widget) main_layout = QHBoxLayout(central_widget) + central_widget.setFocusPolicy(Qt.FocusPolicy.StrongFocus) - # Left panel for settings settings_panel = QWidget() - settings_panel.setFixedWidth(350) + settings_panel.setFixedWidth(320) settings_layout = QVBoxLayout(settings_panel) - form_layout = QFormLayout() - - self.sample_rate_input = QDoubleSpinBox() - self.sample_rate_input.setRange(1, 125) - self.sample_rate_input.setValue(62.5) - self.sample_rate_input.setSuffix(" MS/s") - form_layout.addRow("Sample Rate:", self.sample_rate_input) - - self.resolution_input = QComboBox() - self.resolution_input.addItems(["12", "14", "15", "16", "8"]) - form_layout.addRow("Resolution (bits):", self.resolution_input) - - self.voltage_range_input = QComboBox() - self.voltage_range_input.addItems(VOLTAGE_RANGE_MAP.values()) - form_layout.addRow("Voltage Range:", self.voltage_range_input) - - self.output_file_input = QLineEdit() - self.output_file_input.setText("output.hdf5") - file_browse_button = QPushButton("Browse...") - file_browse_button.clicked.connect(self.select_output_file) - file_layout = QHBoxLayout() - file_layout.addWidget(self.output_file_input) - file_layout.addWidget(file_browse_button) - form_layout.addRow("Output File:", file_layout) - - self.hw_downsample_input = QSpinBox() - self.hw_downsample_input.setRange(1, 1000) - form_layout.addRow("HW Downsample:", self.hw_downsample_input) - - self.downsample_mode_input = QComboBox() - self.downsample_mode_input.addItems(["average", "aggregate"]) - form_layout.addRow("Downsample Mode:", self.downsample_mode_input) - - self.offset_v_input = QDoubleSpinBox() - self.offset_v_input.setRange(-1.0, 1.0) - self.offset_v_input.setSingleStep(0.01) - self.offset_v_input.setDecimals(3) - self.offset_v_input.setSuffix(" V") - form_layout.addRow("Offset:", self.offset_v_input) - self.bandwidth_limiter_input = QComboBox() - self.bandwidth_limiter_input.addItems(["Full", "20 MHz"]) - form_layout.addRow("Bandwidth:", self.bandwidth_limiter_input) + self.toggle_channel_a = QPushButton("Channel A: ON") + self.toggle_channel_a.setCheckable(True) + self.toggle_channel_a.setChecked(True) + self.toggle_channel_a.clicked.connect(self._on_toggle_channel_a) + self.toggle_channel_a.setToolTip("Toggle Channel A visibility (1)") + self.toggle_channel_a.setStyleSheet( + "QPushButton { background-color: #00BFFF; color: white; }" + "QPushButton:checked { background-color: #00BFFF; color: white; }" + "QPushButton:!checked { background-color: #888888; color: white; }" + ) + self.toggle_channel_b = QPushButton("Channel B: ON") + self.toggle_channel_b.setCheckable(True) + self.toggle_channel_b.setChecked(True) + self.toggle_channel_b.clicked.connect(self._on_toggle_channel_b) + self.toggle_channel_b.setToolTip("Toggle Channel B visibility (2)") + self.toggle_channel_b.setStyleSheet( + "QPushButton { background-color: #FF4444; color: white; }" + "QPushButton:checked { background-color: #FF4444; color: white; }" + "QPushButton:!checked { background-color: #888888; color: white; }" + ) - self.live_only_checkbox = QCheckBox("Live-only (overwrite buffer)") - self.max_buffer_input = QDoubleSpinBox() - self.max_buffer_input.setRange(0.1, 60.0) - self.max_buffer_input.setValue(1.0) - self.max_buffer_input.setSuffix(" s") - self.live_only_checkbox.stateChanged.connect( - lambda state: self.max_buffer_input.setEnabled(state > 0) + self.serial_code_input = QLineEdit() + self.serial_code_input.setPlaceholderText("Empty (auto) / serial # / MOCK") + self.serial_code_input.setStyleSheet( + "QLineEdit { color: black; } QLineEdit::placeholder { color: #888; }" + ) + self.serial_code_input.setToolTip( + "Picoscope serial code (e.g., AW123/456).\n" + "Leave empty to auto-detect connected device.\n" + "Enter MOCK for simulated device (testing)." ) - form_layout.addRow(self.live_only_checkbox, self.max_buffer_input) + self.serial_code_input.setFocusPolicy(Qt.FocusPolicy.ClickFocus) - settings_layout.addLayout(form_layout) + serial_layout = QHBoxLayout() + serial_layout.addWidget(self.serial_code_input, stretch=1) + self.connect_button = QPushButton("Connect Device") + self.connect_button.clicked.connect(self.toggle_device_connection) + self.connect_button.setToolTip( + "Connect to or disconnect from the Picoscope device" + ) + serial_layout.addWidget(self.connect_button) + settings_layout.addLayout(serial_layout) - # Start/Stop buttons self.start_button = QPushButton("Start Acquisition") + self.start_button.clicked.connect(self.start_acquisition) + self.start_button.setEnabled(False) + self.start_button.setToolTip("Begin data acquisition (Space)") + self.stop_button = QPushButton("Stop Acquisition") self.stop_button.setEnabled(False) - + self.stop_button.clicked.connect(self.stop_acquisition) + self.stop_button.setToolTip("Stop active acquisition (Space)") + button_layout = QHBoxLayout() button_layout.addWidget(self.start_button) button_layout.addWidget(self.stop_button) settings_layout.addLayout(button_layout) - # Plot settings group - plot_settings_group = QGroupBox("Plot Settings") - plot_settings_form = QFormLayout(plot_settings_group) - - self.plot_window_input = QDoubleSpinBox() - self.plot_window_input.setRange(0.01, 10.0) - self.plot_window_input.setValue(0.5) - self.plot_window_input.setSingleStep(0.1) - self.plot_window_input.setDecimals(2) - self.plot_window_input.setSuffix(" s") - plot_settings_form.addRow("Display Window:", self.plot_window_input) - - self.y_axis_auto_checkbox = QCheckBox("Auto Y-axis") - self.y_axis_auto_checkbox.setChecked(True) - self.y_axis_auto_checkbox.setLayoutDirection(Qt.RightToLeft) - self.y_min_input = QDoubleSpinBox() - self.y_min_input.setRange(-100000, 100000) - self.y_min_input.setValue(-1000) - self.y_min_input.setSuffix(" mV") - self.y_min_input.setEnabled(False) - self.y_max_input = QDoubleSpinBox() - self.y_max_input.setRange(-100000, 100000) - self.y_max_input.setValue(1000) - self.y_max_input.setSuffix(" mV") - self.y_max_input.setEnabled(False) - - self.y_axis_auto_checkbox.stateChanged.connect( - lambda state: self.y_min_input.setEnabled(state == 0) - ) - self.y_axis_auto_checkbox.stateChanged.connect( - lambda state: self.y_max_input.setEnabled(state == 0) - ) - - # plot_settings_form.addRow() - y_limits_layout = QHBoxLayout() - y_limits_layout.addWidget(self.y_axis_auto_checkbox) - y_limits_layout.addWidget(QLabel("Min:")) - y_limits_layout.addWidget(self.y_min_input) - y_limits_layout.addWidget(QLabel("Max:")) - y_limits_layout.addWidget(self.y_max_input) - plot_settings_form.addRow(y_limits_layout) - - self.apply_plot_settings_button = QPushButton("Apply Plot Settings") - plot_settings_form.addRow(self.apply_plot_settings_button) - - settings_layout.addWidget(plot_settings_group) - - # Status display section using QGroupBox and QFormLayout + record_group = QGroupBox("Record Controls") + record_form = QFormLayout(record_group) + + self.start_record_button = QPushButton("Start Recording") + self.start_record_button.clicked.connect(self._on_start_record) + self.start_record_button.setEnabled(False) + self.start_record_button.setToolTip("Start recording ring buffer to disk (R)") + + self.finish_keep_button = QPushButton("Stop & Keep") + self.finish_keep_button.setStyleSheet( + "background-color: #4CAF50; color: white;" + ) + self.finish_keep_button.clicked.connect(self._on_finish_keep) + self.finish_keep_button.hide() + self.finish_keep_button.setToolTip("Stop recording and keep the file (K)") + + self.finish_discard_button = QPushButton("Stop & Discard") + self.finish_discard_button.setStyleSheet( + "background-color: #f44336; color: white;" + ) + self.finish_discard_button.clicked.connect(self._on_finish_discard) + self.finish_discard_button.hide() + self.finish_discard_button.setToolTip( + "Stop recording and discard the file (Del)" + ) + + self.load_file_button = QPushButton("Load File") + self.load_file_button.clicked.connect(self._on_load_file) + self.load_file_button.setToolTip( + "Open directory containing recorded Zarr files (Ctrl+O)" + ) + + record_buttons_layout = QHBoxLayout() + record_buttons_layout.addWidget(self.start_record_button) + record_buttons_layout.addWidget(self.finish_keep_button) + record_buttons_layout.addWidget(self.finish_discard_button) + record_buttons_layout.addStretch() + record_buttons_layout.addWidget(self.load_file_button) + record_form.addRow(record_buttons_layout) + + # Pre-trigger placed below action buttons + self.pre_trigger_input = QComboBox() + self.pre_trigger_input.addItems(["1s", "2s", "5s", "10s", "20s"]) + self.pre_trigger_input.setCurrentText("10s") + self.pre_trigger_input.setMinimumWidth(80) + self.pre_trigger_input.setToolTip("Samples to capture before trigger point") + self.pre_trigger_input.currentTextChanged.connect(self._on_pre_trigger_changed) + + pre_trigger_layout = QHBoxLayout() + pre_trigger_layout.addStretch() + pre_trigger_layout.addWidget(self.pre_trigger_input) + record_form.addRow("Pre-trigger:", pre_trigger_layout) + + self.record_directory_label = QLineEdit() + self.record_directory_label.setReadOnly(True) + self.record_directory_label.setToolTip("Directory where recordings are saved") + self.record_directory_label.setStyleSheet( + "QLineEdit { background-color: transparent; border: none; }" + ) + self._update_directory_label() + + record_change_button = QPushButton("Change...") + record_change_button.clicked.connect(self._select_record_directory) + record_location_layout = QHBoxLayout() + record_location_layout.addWidget(self.record_directory_label, stretch=1) + record_location_layout.addWidget(record_change_button) + + record_form.addRow("Save to:", record_location_layout) + + settings_layout.addWidget(record_group) + + acq_group = QGroupBox("Acquisition Parameters") + acq_grid = QGridLayout(acq_group) + acq_grid.setHorizontalSpacing(8) + + acq_grid.addWidget(QLabel("Res:"), 0, 0) + self.resolution_input = QComboBox() + self.resolution_input.addItems(["8", "12", "14", "15", "16"]) + self.resolution_input.setFixedWidth(60) + self.resolution_input.setToolTip( + "Resolution in bits - higher bits = higher precision" + ) + self.resolution_input.currentIndexChanged.connect(self._on_resolution_changed) + acq_grid.addWidget(self.resolution_input, 0, 1) + + acq_grid.addWidget(QLabel("Rate (p. ch.):"), 0, 2) + self.sample_rate_input = QComboBox() + self.sample_rate_input.setFixedWidth(90) + self.sample_rate_input.setToolTip("Samples per second per channel") + self.sample_rate_input.currentTextChanged.connect(self._on_sample_rate_changed) + acq_grid.addWidget(self.sample_rate_input, 0, 3) + self._init_sample_rate_dropdown() + + acq_grid.addWidget(QLabel("BW:"), 1, 0) + self.bandwidth_limiter_input = QComboBox() + self.bandwidth_limiter_input.addItems(["Full", "20 MHz"]) + self.bandwidth_limiter_input.setFixedWidth(60) + self.bandwidth_limiter_input.setToolTip("Apply analog bandwidth limiting") + acq_grid.addWidget(self.bandwidth_limiter_input, 1, 1) + + acq_grid.addWidget(QLabel("Buffer:"), 1, 2) + self.ring_buffer_input = QComboBox() + self.ring_buffer_input.addItems(["5s", "10s", "30s", "60s", "120s"]) + self.ring_buffer_input.setCurrentText("30s") + self.ring_buffer_input.setFixedWidth(90) + self.ring_buffer_input.currentTextChanged.connect( + self._on_buffer_duration_changed + ) + self.ring_buffer_input.setToolTip("Duration of lookback buffer in seconds") + acq_grid.addWidget(self.ring_buffer_input, 1, 3) + + acq_grid.addWidget(QLabel("Down:"), 2, 0) + self.hw_downsample_input = QComboBox() + self.hw_downsample_input.addItems( + ["1x", "2x", "5x", "10x", "20x", "50x", "100x"] + ) + self.hw_downsample_input.setFixedWidth(60) + self.hw_downsample_input.setToolTip("Hardware downsampling factor") + acq_grid.addWidget(self.hw_downsample_input, 2, 1) + + acq_grid.addWidget(QLabel("Mode:"), 2, 2) + self.downsample_mode_input = QComboBox() + self.downsample_mode_input.addItems(["none", "average", "aggregate"]) + self.downsample_mode_input.setFixedWidth(90) + self.downsample_mode_input.setToolTip( + "Downsampling algorithm: none (disabled), average across pairs, or aggregate (min/max)" + ) + acq_grid.addWidget(self.downsample_mode_input, 2, 3) + + settings_layout.addWidget(acq_group) + + channel_settings_group = QGroupBox("Channel Settings") + channel_vbox = QVBoxLayout(channel_settings_group) + channel_vbox.setSpacing(4) + + voltage_range_values = [ + "0.01", + "0.02", + "0.05", + "0.1", + "0.2", + "0.5", + "1.0", + "2.0", + "5.0", + "10.0", + "20.0", + ] + + channel_a_row = QHBoxLayout() + channel_a_row.addWidget(self.toggle_channel_a) + channel_a_row.addWidget(QLabel("R:")) + self.voltage_range_a = QComboBox() + self.voltage_range_a.addItems(voltage_range_values) + self.voltage_range_a.setFixedWidth(55) + self.voltage_range_a.setToolTip("Input voltage range in volts") + self.voltage_range_a.currentIndexChanged.connect( + self._on_voltage_range_a_changed + ) + channel_a_row.addWidget(self.voltage_range_a) + channel_a_row.addWidget(QLabel("O:")) + self.offset_a = QDoubleSpinBox() + self.offset_a.setRange(-1.0, 1.0) + self.offset_a.setSingleStep(0.01) + self.offset_a.setDecimals(2) + self.offset_a.setSuffix(" V") + self.offset_a.setFixedWidth(80) + self.offset_a.setToolTip("DC offset in volts") + self.offset_a.valueChanged.connect(self._on_offset_a_changed) + channel_a_row.addWidget(self.offset_a) + channel_vbox.addLayout(channel_a_row) + + channel_b_row = QHBoxLayout() + channel_b_row.addWidget(self.toggle_channel_b) + channel_b_row.addWidget(QLabel("R:")) + self.voltage_range_b = QComboBox() + self.voltage_range_b.addItems(voltage_range_values) + self.voltage_range_b.setFixedWidth(55) + self.voltage_range_b.setToolTip("Input voltage range in volts") + self.voltage_range_b.currentIndexChanged.connect( + self._on_voltage_range_b_changed + ) + channel_b_row.addWidget(self.voltage_range_b) + channel_b_row.addWidget(QLabel("O:")) + self.offset_b = QDoubleSpinBox() + self.offset_b.setRange(-1.0, 1.0) + self.offset_b.setSingleStep(0.01) + self.offset_b.setDecimals(2) + self.offset_b.setSuffix(" V") + self.offset_b.setFixedWidth(80) + self.offset_b.setToolTip("DC offset in volts") + self.offset_b.valueChanged.connect(self._on_offset_b_changed) + channel_b_row.addWidget(self.offset_b) + channel_vbox.addLayout(channel_b_row) + + settings_layout.addWidget(channel_settings_group) + status_group = QGroupBox("Status") - status_form_layout = QFormLayout(status_group) - - # Create status value labels with monospace font - font = QFont() - font.setFamily("Monospace") - font.setFixedPitch(True) - font.setPointSize(9) - - self.status_heartbeat = QLabel("-") - self.status_heartbeat.setFont(font) - status_form_layout.addRow("UI:", self.status_heartbeat) - - self.status_errors = self._create_status_label(font) - status_form_layout.addRow("Errors:", self.status_errors) - - self.status_saturation = QLabel("-") - self.status_saturation.setFont(font) - status_form_layout.addRow("Saturation:", self.status_saturation) - - self.status_samples = QLCDNumber() - self.status_samples.setDigitCount(10) - self.status_samples.setSegmentStyle(QLCDNumber.Flat) - self.status_samples.display("0") - status_form_layout.addRow("Samples:", self.status_samples) - - self.status_latency = QLabel("-") - self.status_latency.setFont(font) - status_form_layout.addRow("Latency:", self.status_latency) - + status_form = QFormLayout(status_group) + status_form.setSpacing(6) + status_form.setContentsMargins(8, 12, 8, 12) + + self.status_state = QLabel("IDLE") + self.status_state.setToolTip("Current acquisition state") + self.status_state.setFrameStyle(QFrame.Shape.StyledPanel | QFrame.Shadow.Sunken) + self.status_state.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.status_state.setMinimumWidth(70) + self._update_state_indicator("idle") + status_form.addRow("State:", self.status_state) + + self.status_saturation = QLabel("OK") + self.status_saturation.setToolTip( + "Hardware ADC saturation: signal exceeds voltage range setting" + ) + self.status_saturation.setFrameStyle( + QFrame.Shape.StyledPanel | QFrame.Shadow.Sunken + ) + self.status_saturation.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.status_saturation.setMinimumWidth(70) + self._update_saturation_indicator(False) + status_form.addRow("Saturation:", self.status_saturation) + + self.status_record = QLabel("Idle") + self.status_record.setToolTip("Recording status") + self.status_record.setFont(QFont("Monospace", 9)) + self.status_record.setFrameStyle( + QFrame.Shape.StyledPanel | QFrame.Shadow.Sunken + ) + self.status_record.setStyleSheet("background-color: #E0E0E0; padding: 2px;") + self.status_record.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.status_record.setMinimumWidth(140) + status_form.addRow("Record:", self.status_record) + + self.status_samples = QLabel("0") + self.status_samples.setToolTip("Total collected samples") + self.status_samples.setFont(QFont("Monospace", 10)) + self.status_samples.setAlignment(Qt.AlignmentFlag.AlignRight) + self.status_samples.setMinimumWidth(70) + status_form.addRow("Acq. Samples:", self.status_samples) + + self.status_duration = QLabel("-") + self.status_duration.setToolTip( + "Total acquisition time (Acq. Samples / acquisition rate)" + ) + self.status_duration.setFont(QFont("Monospace", 10)) + self.status_duration.setAlignment(Qt.AlignmentFlag.AlignRight) + self.status_duration.setMinimumWidth(70) + status_form.addRow("Acq. Duration:", self.status_duration) + self.status_rate = QLabel("-") - self.status_rate.setFont(font) - status_form_layout.addRow("Rate:", self.status_rate) - - self.status_acquisition = QLabel("Waiting for file...") - self.status_acquisition.setFont(font) - self.status_acquisition.setWordWrap(True) - status_form_layout.addRow("Acquisition:", self.status_acquisition) - + self.status_rate.setToolTip("Hardware sample rate (as reported by device)") + self.status_rate.setFont(QFont("Monospace", 10)) + self.status_rate.setAlignment(Qt.AlignmentFlag.AlignRight) + self.status_rate.setMinimumWidth(100) + status_form.addRow("HW Rate:", self.status_rate) + + self.status_lag = QLabel("-") + self.status_lag.setToolTip( + "Producer-consumer lag (how far behind the plotter is)" + ) + self.status_lag.setFont(QFont("Monospace", 10)) + self.status_lag.setAlignment(Qt.AlignmentFlag.AlignRight) + self.status_lag.setMinimumWidth(100) + status_form.addRow("Lag:", self.status_lag) + + # New display status indicators + self.status_display_points = QLabel("-") + self.status_display_points.setToolTip( + "Number of points currently displayed in plot" + ) + self.status_display_points.setFont(QFont("Monospace", 10)) + self.status_display_points.setAlignment(Qt.AlignmentFlag.AlignRight) + self.status_display_points.setMinimumWidth(100) + status_form.addRow("Display Points:", self.status_display_points) + + self.status_decimation = QLabel("-") + self.status_decimation.setToolTip( + "Current software decimation factor (1× = lossless)" + ) + self.status_decimation.setFont(QFont("Monospace", 10)) + self.status_decimation.setAlignment(Qt.AlignmentFlag.AlignRight) + self.status_decimation.setMinimumWidth(100) + status_form.addRow("Decimation:", self.status_decimation) + settings_layout.addWidget(status_group) settings_layout.addStretch() - # Right panel for the plot - self.plotter = HDF5LivePlotter(hdf5_path=self.output_file_input.text()) - - # Hide the plotter's status bar since we're showing it in the sidebar - self.plotter.hide_status_bar(hide=True) + default_target_points = 20000 + + # Will be connected after device is created + self._plotter_read_position_callback: Optional[Callable[[int], None]] = None + + self.plotter = LivePlotter( + ring_buffer=None, + channels=[0, 1], + voltage_ranges={ + 0: float(self.voltage_range_a.currentText()), + 1: float(self.voltage_range_b.currentText()), + }, + offsets_v={0: self.offset_a.value(), 1: self.offset_b.value()}, + resolution=12, + downsample_mode="NONE", + downsample_ratio=1, + target_plot_points=default_target_points, + on_read_position=self._on_plotter_read_position, + ) + + right_panel = QWidget() + right_layout = QVBoxLayout(right_panel) + right_layout.setContentsMargins(0, 0, 0, 0) + right_panel.setFocusPolicy(Qt.FocusPolicy.StrongFocus) + + right_layout.addWidget(self.plotter, 1) + + plot_settings_scale_layout = QHBoxLayout() + plot_settings_scale_layout.setContentsMargins(0, 0, 0, 0) + + plot_settings_group = QGroupBox("Plot Settings") + plot_layout = QGridLayout(plot_settings_group) + + quality_label = QLabel("Quality:") + plot_layout.addWidget(quality_label, 0, 0) + + self.quality_input = QComboBox() + self.quality_input.addItems(["Low", "Medium", "High", "Lossless"]) + self.quality_input.setCurrentText("High") + self.quality_input.setToolTip( + "Plot detail level - fixed point targets: Low (~25k), Medium (~100k), High (~500k). " + "Lossless shows every sample but warns if display may be slow" + ) + self.quality_input.currentTextChanged.connect(self._on_quality_changed) + plot_layout.addWidget(self.quality_input, 0, 1) + + plot_layout.addWidget(QLabel("Refresh:"), 1, 0) + self.refresh_rate_input = QComboBox() + self.refresh_rate_input.addItems(list(REFRESH_RATE_MAP.keys())) + self.refresh_rate_input.setCurrentText("20 Hz") + self.refresh_rate_input.setToolTip( + "Plot refresh rate - lower rates allow more points in Lossless mode" + ) + self.refresh_rate_input.currentTextChanged.connect( + self._on_refresh_rate_changed + ) + plot_layout.addWidget(self.refresh_rate_input, 1, 1) + + plot_layout.addWidget(QLabel("Display Window:"), 2, 0) + self.plot_window_input = QComboBox() + self.plot_window_input.addItems(["1s", "2s", "5s", "10s", "20s", "30s"]) + self.plot_window_input.setCurrentText("5s") + self.plot_window_input.setToolTip( + "Width of display time window (used at next acquisition start)" + ) + plot_layout.addWidget(self.plot_window_input, 2, 1) + + plot_settings_scale_layout.addWidget(plot_settings_group, 1) + + channel_a_scale_group = QGroupBox("Channel A") + channel_a_form = QFormLayout(channel_a_scale_group) + + self.v_div_a_input = QComboBox() + self.v_div_a_input.addItems( + [ + "0.5mV", + "1mV", + "2mV", + "5mV", + "10mV", + "20mV", + "50mV", + "100mV", + "200mV", + "500mV", + "1V", + "2V", + "5V", + "10V", + "20V", + ] + ) + self.v_div_a_input.setCurrentText("100mV") + self.v_div_a_input.currentTextChanged.connect(self._on_v_div_a_changed) + channel_a_form.addRow("V/div:", self.v_div_a_input) + + self.y_pos_a_input = QDoubleSpinBox() + self.y_pos_a_input.setRange(-5, 5) + self.y_pos_a_input.setValue(0) + self.y_pos_a_input.setSingleStep(0.1) + self.y_pos_a_input.setDecimals(1) + self.y_pos_a_input.setSuffix(" div") + self.y_pos_a_input.setToolTip("Vertical position in divisions from centre") + self.y_pos_a_input.valueChanged.connect(self._on_y_pos_a_changed) + channel_a_form.addRow("Y-pos:", self.y_pos_a_input) + + self.clip_a_indicator = QLabel("") + self.clip_a_indicator.setVisible(False) + self.clip_a_indicator.setStyleSheet("font-size: 16px; font-weight: bold;") + self.clip_a_indicator.setToolTip( + "Data exceeds display range: adjust V/div or Y-pos to view full signal" + ) + channel_a_form.addRow("", self.clip_a_indicator) + + channel_a_scale_group.setStyleSheet( + "QGroupBox { color: #00BFFF; font-weight: bold; }" + ) + + channel_b_scale_group = QGroupBox("Channel B") + channel_b_form = QFormLayout(channel_b_scale_group) + + self.v_div_b_input = QComboBox() + self.v_div_b_input.addItems( + [ + "0.5mV", + "1mV", + "2mV", + "5mV", + "10mV", + "20mV", + "50mV", + "100mV", + "200mV", + "500mV", + "1V", + "2V", + "5V", + "10V", + "20V", + ] + ) + self.v_div_b_input.setCurrentText("100mV") + self.v_div_b_input.currentTextChanged.connect(self._on_v_div_b_changed) + channel_b_form.addRow("V/div:", self.v_div_b_input) + + self.y_pos_b_input = QDoubleSpinBox() + self.y_pos_b_input.setRange(-5, 5) + self.y_pos_b_input.setValue(0) + self.y_pos_b_input.setSingleStep(0.1) + self.y_pos_b_input.setDecimals(1) + self.y_pos_b_input.setSuffix(" div") + self.y_pos_b_input.setToolTip("Vertical position in divisions from centre") + self.y_pos_b_input.valueChanged.connect(self._on_y_pos_b_changed) + channel_b_form.addRow("Y-pos:", self.y_pos_b_input) + + self.clip_b_indicator = QLabel("") + self.clip_b_indicator.setVisible(False) + self.clip_b_indicator.setStyleSheet("font-size: 16px; font-weight: bold;") + self.clip_b_indicator.setToolTip( + "Data exceeds display range: adjust V/div or Y-pos to view full signal" + ) + channel_b_form.addRow("", self.clip_b_indicator) + + channel_b_scale_group.setStyleSheet( + "QGroupBox { color: #FF4444; font-weight: bold; }" + ) + + plot_settings_scale_layout.addWidget(channel_a_scale_group, 1) + plot_settings_scale_layout.addWidget(channel_b_scale_group, 1) + + right_layout.addLayout(plot_settings_scale_layout) main_layout.addWidget(settings_panel) - main_layout.addWidget(self.plotter, 1) + main_layout.addWidget(right_panel, 1) - # Create a timer to sync status from plotter to main window self.status_sync_timer = QTimer() - self.status_sync_timer.timeout.connect(self.sync_status_from_plotter) - # Timer will be started when acquisition begins - - # Connect signals - self.start_button.clicked.connect(self.start_acquisition) - self.stop_button.clicked.connect(self.stop_acquisition) - self.apply_plot_settings_button.clicked.connect(self.apply_plot_settings) + self.status_sync_timer.timeout.connect(self._sync_status) self.load_settings() - def _create_status_label(self, font: QFont) -> QLabel: - """Create a status label with error styling capability.""" - label = QLabel("0") - label.setFont(font) - return label + def _on_plotter_read_position(self, read_idx: int) -> None: + """Forward plotter read position to device for overflow detection. + + Parameters + ---------- + read_idx : int + The sample index the plotter last read up to. + """ + if self._device is not None: + self._device.update_plotter_position(read_idx) + + self._apply_clamp_constraints() + + def _on_v_div_a_changed(self) -> None: + """Handle V/div change for channel A.""" + self._update_y_axis_settings() + + def _on_v_div_b_changed(self) -> None: + """Handle V/div change for channel B.""" + self._update_y_axis_settings() + + def _on_y_pos_a_changed(self) -> None: + """Handle Y-position change for channel A.""" + self._update_y_axis_settings() + + def _on_y_pos_b_changed(self) -> None: + """Handle Y-position change for channel B.""" + self._update_y_axis_settings() + + def _apply_plot_settings_to_plotter(self) -> None: + """Apply all plot settings from UI to plotter at acquisition start. + + This ensures plotter settings are applied atomically when acquisition + starts, avoiding race conditions between UI changes and pipeline creation. + """ + # Apply refresh rate + refresh_text = self.refresh_rate_input.currentText() + interval_ms, _ = REFRESH_RATE_MAP.get(refresh_text, (50, 50000)) + self.plotter.update_interval_ms = interval_ms + + # Apply quality (target_plot_points) + quality = self.quality_input.currentText() + if quality == "Lossless": + self.plotter.target_plot_points = None + else: + target_points = self._get_target_points_for_quality(quality) + self.plotter.target_plot_points = target_points + + logger.info( + "Applied plot settings: quality='{}', target_points={}", + quality, + self.plotter.target_plot_points, + ) + + def _update_y_axis_settings(self) -> None: + """Update plotter with V/div and Y-pos settings.""" + v_div_a = self._parse_v_div(self.v_div_a_input.currentText()) + v_div_b = self._parse_v_div(self.v_div_b_input.currentText()) + y_pos_a = self.y_pos_a_input.value() + y_pos_b = self.y_pos_b_input.value() + + self.plotter.set_v_div_settings( + {0: v_div_a, 1: v_div_b}, + {0: y_pos_a, 1: y_pos_b}, + ) + + def _parse_v_div(self, text: str) -> float: + """Parse V/div text to value in volts. + + Parameters + ---------- + text : str + Text like "100mV", "1V", "2V". + + Returns + ------- + float + Value in volts. + """ + text = text.strip().upper() + if text.endswith("MV"): + return float(text.rstrip("MV")) / 1000.0 + if text.endswith("V"): + return float(text.rstrip("V")) + return float(text) + + def _update_state_indicator(self, state: str) -> None: + """Update the state indicator colour based on current state.""" + state_upper = state.upper() + if state_upper == "LIVE": + self.status_state.setText("LIVE") + self.status_state.setStyleSheet( + "background-color: #4CAF50; color: white; font-weight: bold;" + ) + elif state_upper == "RECORDING": + self.status_state.setText("REC") + self.status_state.setStyleSheet( + "background-color: #FF0000; color: white; font-weight: bold;" + ) + elif state_upper == "ERROR": + self.status_state.setText("ERROR") + self.status_state.setStyleSheet( + "background-color: #f44336; color: white; font-weight: bold;" + ) + elif state_upper == "CONNECTED": + self.status_state.setText("CONN") + self.status_state.setStyleSheet( + "background-color: #2196F3; color: white; font-weight: bold;" + ) + else: + self.status_state.setText("IDLE") + self.status_state.setStyleSheet( + "background-color: #9E9E9E; color: white; font-weight: bold;" + ) + + def _update_saturation_indicator(self, is_saturated: bool) -> None: + """Update the hardware saturation indicator. + + Hardware saturation indicates ADC has clipped to max/min values because + the input signal exceeded the voltage range setting. This is distinct from + display clipping, which only indicates data exceeds the current view range. + """ + if is_saturated: + self.status_saturation.setText("CLIPPING") + self.status_saturation.setStyleSheet( + "background-color: #f44336; color: white; font-weight: bold;" + ) + else: + self.status_saturation.setText("OK") + self.status_saturation.setStyleSheet( + "background-color: #4CAF50; color: white; font-weight: bold;" + ) + + def _update_clip_indicators(self) -> None: + """Update per-channel display clipping indicators. + + Display clipping indicates when signal data exceeds the current Y-axis + view range (V/div and Y-pos settings), not hardware ADC saturation. + """ + clip_a = self.plotter.is_channel_clipped(0) + clip_b = self.plotter.is_channel_clipped(1) + + if clip_a == 1: + self.clip_a_indicator.setText("DATA IS ▼") + self.clip_a_indicator.setVisible(True) + elif clip_a == -1: + self.clip_a_indicator.setText("DATA IS ▲") + self.clip_a_indicator.setVisible(True) + else: + self.clip_a_indicator.setText("") + self.clip_a_indicator.setVisible(False) + + if clip_b == 1: + self.clip_b_indicator.setText("DATA IS ▼") + self.clip_b_indicator.setVisible(True) + elif clip_b == -1: + self.clip_b_indicator.setText("DATA IS ▲") + self.clip_b_indicator.setVisible(True) + else: + self.clip_b_indicator.setText("") + self.clip_b_indicator.setVisible(False) + + def _parse_buffer_duration(self) -> float: + """Parse buffer duration from QComboBox text (e.g., '30s' -> 30.0).""" + buffer_text = self.ring_buffer_input.currentText() + return float(buffer_text.rstrip("s")) if buffer_text else 30.0 + + def _parse_downsample_ratio(self) -> int: + """Parse downsample ratio from QComboBox text (e.g., '10x' -> 10).""" + downsample_text = self.hw_downsample_input.currentText() + return int(downsample_text.rstrip("x")) if downsample_text else 1 + + def _apply_clamp_constraints(self) -> None: + """Clamp plot window and pre-trigger to buffer duration.""" + buffer_duration = self._parse_buffer_duration() + + # For combo boxes, we disable items that exceed the buffer duration + plot_window_items = ["1s", "2s", "5s", "10s", "20s", "30s"] + for i, item in enumerate(plot_window_items): + duration_val = float(item.rstrip("s")) + is_enabled = duration_val <= buffer_duration + self.plot_window_input.model().item(i).setEnabled(is_enabled) + + pre_trigger_items = ["1s", "2s", "5s", "10s", "20s"] + for i, item in enumerate(pre_trigger_items): + duration_val = float(item.rstrip("s")) + is_enabled = duration_val <= buffer_duration + self.pre_trigger_input.model().item(i).setEnabled(is_enabled) + + # If current selection is now disabled, clamp to largest valid value + current_plot_window = self.plot_window_input.currentText() + current_plot_val = float(current_plot_window.rstrip("s")) + if current_plot_val > buffer_duration: + # Find largest valid value + for item in reversed(plot_window_items): + if float(item.rstrip("s")) <= buffer_duration: + self.plot_window_input.setCurrentText(item) + break + + current_pre_trigger = self.pre_trigger_input.currentText() + current_pre_val = float(current_pre_trigger.rstrip("s")) + if current_pre_val > buffer_duration: + for item in reversed(pre_trigger_items): + if float(item.rstrip("s")) <= buffer_duration: + self.pre_trigger_input.setCurrentText(item) + break + + def _is_16bit_mode(self) -> bool: + """Check if current resolution is 16-bit (single channel only). + + Returns + ------- + bool + True if 16-bit resolution is selected, False otherwise. + """ + return int(self.resolution_input.currentText()) == 16 + + def _on_toggle_channel_a(self) -> None: + is_checked = self.toggle_channel_a.isChecked() + self.toggle_channel_a.setText( + "Channel A: ON" if is_checked else "Channel A: OFF" + ) + + if is_checked and self._is_16bit_mode() and self.toggle_channel_b.isChecked(): + self.toggle_channel_b.setChecked(False) + self._on_toggle_channel_b() + + self._init_sample_rate_dropdown() + + def _on_toggle_channel_b(self) -> None: + is_checked = self.toggle_channel_b.isChecked() + self.toggle_channel_b.setText( + "Channel B: ON" if is_checked else "Channel B: OFF" + ) + + if is_checked and self._is_16bit_mode() and self.toggle_channel_a.isChecked(): + self.toggle_channel_a.setChecked(False) + self._on_toggle_channel_a() + + self._init_sample_rate_dropdown() + + def _update_directory_label(self) -> None: + """Update the directory label to show the current save location.""" + directory = self._get_record_directory() + display_path = str(directory) + home = str(Path.home()) + if display_path.startswith(home): + display_path = "~" + display_path[len(home) :] + + metrics = QFontMetrics(self.record_directory_label.font()) + display_path = metrics.elidedText( + display_path, + Qt.TextElideMode.ElideLeft, + self.record_directory_label.width(), + ) + self.record_directory_label.setText(display_path) + self.record_directory_label.setCursorPosition(0) + self.record_directory_label.setToolTip(str(directory)) + + def _get_record_directory(self) -> Path: + """Get the configured record directory, creating it if needed. + + Returns + ------- + Path + The directory path where recordings are saved. + """ + directory = self.settings.value( + "record_directory", str(Path.home() / "picostream") + ) + path = Path(directory) + path.mkdir(parents=True, exist_ok=True) + return path + + def _select_record_directory(self) -> None: + """Open file dialog to select record directory.""" + default_dir = str(self._get_record_directory()) + + path = QFileDialog.getExistingDirectory( + self, + "Select Record Directory", + default_dir, + QFileDialog.Option.ShowDirsOnly, + ) + + if path: + self.settings.setValue("record_directory", path) + self._update_directory_label() + logger.info("Record directory changed to: {}", path) + + def _get_next_sequence_number(self, directory: str) -> int: + """Get the next sequence number for today's recordings in the directory. + + Parameters + ---------- + directory : str + The directory to scan for existing recordings. + + Returns + ------- + int + The next sequence number (1-based) for today. + """ + today_str = datetime.now().strftime("%Y%m%d") + pattern = f"record_{today_str}_" + max_seq = 0 + + if not os.path.exists(directory): + return 1 + + for entry in os.listdir(directory): + if entry.startswith(pattern) and entry.endswith(".zarr"): + parts = entry.replace(".zarr", "").split("_") + if len(parts) >= 3: + try: + seq = int(parts[2]) + max_seq = max(max_seq, seq) + except ValueError: + pass + + return max_seq + 1 + + def _get_record_path(self) -> str: + """Generate a record path in the configured directory. + + Returns + ------- + str + Full path to the new recording file. + """ + directory = self._get_record_directory() + + today_str = datetime.now().strftime("%Y%m%d") + time_str = datetime.now().strftime("%H%M%S") + seq_num = self._get_next_sequence_number(str(directory)) + + filename = f"record_{today_str}_{seq_num:03d}_{time_str}.zarr" + return str(directory / filename) + + def _on_start_record(self) -> None: + """Handle Start Recording button click.""" + if self._device is None: + QMessageBox.warning( + self, "Record Error", "Device not available. Is acquisition running?" + ) + return + + record_path = self._get_record_path() + filename = os.path.basename(record_path) + logger.info("Starting recording: {}", filename) + self.status_record.setText(f"Starting: {filename}") + + pre_trigger_text = self.pre_trigger_input.currentText() + lookback_seconds = float(pre_trigger_text.rstrip("s")) + + # Validate pre-trigger doesn't exceed buffer duration + buffer_duration = self._parse_buffer_duration() + if lookback_seconds > buffer_duration: + QMessageBox.warning( + self, + "Invalid Pre-trigger", + f"Pre-trigger ({lookback_seconds}s) exceeds buffer duration ({buffer_duration}s).\n\n" + f"Please reduce pre-trigger or increase buffer duration.", + ) + return + + success = self._device.start_save(lookback_seconds, record_path) + if success: + self._is_recording = True + logger.info("Started recording to: {}", record_path) + self._apply_ui_state("recording") + else: + QMessageBox.warning( + self, "Record Error", "Failed to start recording. Check device status." + ) + + def _on_finish_keep(self) -> None: + """Handle Stop & Keep button click.""" + if not self._device: + return + + try: + self._update_record_status("stopping") + self._apply_ui_state("stopping") + + self._pending_save_path = self._device.stop_save(keep=True) + self._save_start_time = time.time() + + self._save_poll_timer = QTimer(self) + self._save_poll_timer.timeout.connect(self._check_save_done_keep) + self._save_poll_timer.start(100) + except Exception as e: + logger.exception("Error in _on_finish_keep: {}", e) + self._update_record_status("error") + self._apply_ui_state("live") + QMessageBox.warning(self, "Error", f"Error stopping recording:\n{e}") + + def _on_finish_discard(self) -> None: + """Handle Stop & Discard button click.""" + if not self._device: + return + + try: + self._update_record_status("stopping") + self._apply_ui_state("stopping") + + self._device.stop_save(keep=False) + self._save_start_time = time.time() + + self._save_poll_timer = QTimer(self) + self._save_poll_timer.timeout.connect(self._check_save_done_discard) + self._save_poll_timer.start(100) + except Exception as e: + logger.exception("Error in _on_finish_discard: {}", e) + self._update_record_status("error") + self._apply_ui_state("live") + QMessageBox.warning(self, "Error", f"Error stopping recording:\n{e}") + + def _check_save_done_keep(self) -> None: + """Poll for save completion (keep mode).""" + try: + # Check for timeout + if self._save_start_time and ( + time.time() - self._save_start_time > self._save_timeout_seconds + ): + self._save_poll_timer.stop() + logger.error( + "Save operation timed out after {} seconds", + self._save_timeout_seconds, + ) + QMessageBox.warning( + self, + "Save Timeout", + f"Save operation timed out after {self._save_timeout_seconds} seconds.\n\n" + "The file may be incomplete.", + ) + self._is_recording = False + self._update_record_status("error") + self._save_poll_timer = None + self._apply_ui_state("live") + return + + if self._device.is_save_finished(): + self._save_poll_timer.stop() + path = self._pending_save_path + if path: + QMessageBox.information( + self, "Record Complete", f"File saved to:\n{path}" + ) + save_dir = os.path.dirname(path) + self.settings.setValue("last_save_directory", save_dir) + logger.info("Recording finished and kept: {}", path) + else: + logger.info("Recording finished (no path returned)") + + self._is_recording = False + self._update_record_status("complete") + self._save_poll_timer = None + self._apply_ui_state("live") + except Exception as e: + logger.exception("Error in save poll (keep mode): {}", e) + if self._save_poll_timer: + self._save_poll_timer.stop() + self._save_poll_timer = None + self._is_recording = False + self._update_record_status("error") + QMessageBox.warning( + self, "Error", f"Error during save completion check:\n{e}" + ) + self._apply_ui_state("error") + + def _check_save_done_discard(self) -> None: + """Poll for save completion (discard mode).""" + try: + # Check for timeout + if self._save_start_time and ( + time.time() - self._save_start_time > self._save_timeout_seconds + ): + self._save_poll_timer.stop() + logger.error( + "Save operation timed out after {} seconds", + self._save_timeout_seconds, + ) + QMessageBox.warning( + self, + "Save Timeout", + f"Save operation timed out after {self._save_timeout_seconds} seconds.\n\n" + "Operation cancelled.", + ) + self._is_recording = False + self._update_record_status("error") + self._save_poll_timer = None + self._apply_ui_state("live") + return + + if self._device.is_save_finished(): + self._save_poll_timer.stop() + self._is_recording = False + self._update_record_status("idle") + self._save_poll_timer = None + self._apply_ui_state("live") + logger.info("Recording discarded") + except Exception as e: + logger.exception("Error in save poll (discard mode): {}", e) + if self._save_poll_timer: + self._save_poll_timer.stop() + self._save_poll_timer = None + self._is_recording = False + self._update_record_status("error") + QMessageBox.warning( + self, "Error", f"Error during save completion check:\n{e}" + ) + self._apply_ui_state("error") + + def _update_record_status(self, state: str) -> None: + """Update the record status indicator. + + Parameters + ---------- + state : str + One of: "idle", "saving", "stopping", "error", "complete". + """ + if state == "stopping": + self.status_record.setText("Stopping...") + self.status_record.setStyleSheet( + "background-color: #FF9800; color: white; font-weight: bold;" + ) + elif state == "idle": + self.status_record.setText("Idle") + self.status_record.setStyleSheet("background-color: #E0E0E0; color: black;") + elif state == "complete": + self.status_record.setText("Saved") + self.status_record.setStyleSheet( + "background-color: #4CAF50; color: white; font-weight: bold;" + ) + elif state == "error": + self.status_record.setText("Error") + self.status_record.setStyleSheet( + "background-color: #f44336; color: white; font-weight: bold;" + ) + + def _apply_ui_state(self, state: str) -> None: + """Apply a complete, consistent UI state. + + Centralises all UI state changes to prevent partial/inconsistent states. + Valid states: idle, connected, live, recording, stopping, error. + + Parameters + ---------- + state : str + The target UI state. + """ + state = state.lower() + logger.debug("Applying UI state: {}", state) - def _set_status_error_style(self, label: QLabel, is_error: bool) -> None: - """Apply error styling (red background) to a status label.""" - if is_error: - label.setStyleSheet("background-color: #ffcccc; padding: 2px;") + if state == "idle": + self.connect_button.setEnabled(True) + self.connect_button.setText("Connect Device") + self.start_button.setEnabled(False) + self.stop_button.setEnabled(False) + self.start_record_button.setEnabled(False) + self._show_record_buttons(recording=False) + self._set_hardware_inputs_enabled(True) + self.quality_input.setEnabled(True) + self.refresh_rate_input.setEnabled(True) + self.plot_window_input.setEnabled(True) + self._update_state_indicator("idle") + elif state == "connected": + self.connect_button.setEnabled(True) + self.connect_button.setText("Disconnect Device") + self.start_button.setEnabled(True) + self.stop_button.setEnabled(False) + self.start_record_button.setEnabled(False) + self._show_record_buttons(recording=False) + self.serial_code_input.setEnabled(False) + self.resolution_input.setEnabled(False) + self.toggle_channel_a.setEnabled(True) + self.toggle_channel_b.setEnabled(True) + self.voltage_range_a.setEnabled(True) + self.voltage_range_b.setEnabled(True) + self.offset_a.setEnabled(True) + self.offset_b.setEnabled(True) + self._set_acq_params_enabled(True) + self.quality_input.setEnabled(True) + self.refresh_rate_input.setEnabled(True) + self.plot_window_input.setEnabled(True) + self._update_state_indicator("connected") + elif state == "live": + self.connect_button.setEnabled(False) + self.connect_button.setText("Disconnect Device") + self.start_button.setEnabled(False) + self.stop_button.setEnabled(True) + self.start_record_button.setEnabled(True) + self._show_record_buttons(recording=False) + self._set_hardware_inputs_enabled(False) + self.quality_input.setEnabled(False) + self.refresh_rate_input.setEnabled(False) + self.plot_window_input.setEnabled(False) + self._update_state_indicator("live") + elif state == "recording": + self.connect_button.setEnabled(False) + self.connect_button.setText("Disconnect Device") + self.start_button.setEnabled(False) + self.stop_button.setEnabled(False) + self.finish_keep_button.setEnabled(True) + self.finish_discard_button.setEnabled(True) + self._show_record_buttons(recording=True) + self._set_hardware_inputs_enabled(False) + self.quality_input.setEnabled(False) + self.refresh_rate_input.setEnabled(False) + self.plot_window_input.setEnabled(False) + self._update_state_indicator("recording") + elif state == "stopping": + self.connect_button.setEnabled(False) + self.connect_button.setText("Disconnect Device") + self.start_button.setEnabled(False) + self.stop_button.setEnabled(False) + self.finish_keep_button.setEnabled(False) + self.finish_discard_button.setEnabled(False) + self.quality_input.setEnabled(False) + self.refresh_rate_input.setEnabled(False) + self.plot_window_input.setEnabled(False) + self._update_state_indicator("recording") + elif state == "error": + self.connect_button.setEnabled(True) + if self._devices_connected: + self.connect_button.setText("Disconnect Device") + else: + self.connect_button.setText("Connect Device") + self.start_button.setEnabled(False) + self.stop_button.setEnabled(False) + self.start_record_button.setEnabled(False) + self._show_record_buttons(recording=False) + self.serial_code_input.setEnabled(not self._devices_connected) + self.resolution_input.setEnabled(not self._devices_connected) + self.toggle_channel_a.setEnabled(True) + self.toggle_channel_b.setEnabled(True) + self.voltage_range_a.setEnabled(True) + self.voltage_range_b.setEnabled(True) + self.offset_a.setEnabled(True) + self.offset_b.setEnabled(True) + self._set_acq_params_enabled(True) + self.quality_input.setEnabled(True) + self.refresh_rate_input.setEnabled(True) + self.plot_window_input.setEnabled(True) + self._update_state_indicator("error") else: - label.setStyleSheet("") - - def select_output_file(self) -> None: - """Open a dialog to select the output HDF5 file.""" - file_name, _ = QFileDialog.getSaveFileName( - self, "Select Output File", self.output_file_input.text(), "HDF5 Files (*.hdf5)" - ) - if file_name: - self.output_file_input.setText(file_name) - - def apply_plot_settings(self) -> None: - """Apply plot settings to the live plotter.""" - window_seconds = self.plot_window_input.value() - self.plotter.set_display_window(window_seconds) - - if self.y_axis_auto_checkbox.isChecked(): - self.plotter.set_y_limits(None, None) + logger.warning("Unknown UI state: {}", state) + + def _show_record_buttons(self, recording: bool) -> None: + """Toggle between Start Recording and Stop buttons. + + Parameters + ---------- + recording : bool + True to show keep/discard buttons, False to show start button. + """ + if recording: + self.start_record_button.hide() + self.finish_keep_button.show() + self.finish_discard_button.show() + else: + self.start_record_button.show() + self.finish_keep_button.hide() + self.finish_discard_button.hide() + + def _on_load_file(self) -> None: + """Handle Load File button click - open Zarr viewer.""" + default_dir = self._get_record_directory() + last_dir = self.settings.value("last_save_directory", str(default_dir)) + + path = QFileDialog.getExistingDirectory( + self, + "Open Zarr File", + last_dir, + QFileDialog.Option.ShowDirsOnly, + ) + + if not path: + return + + try: + viewer = ZarrViewerWindow(path) + viewer.show() + self._zarr_viewers.append(viewer) + viewer.destroyed.connect( + lambda: ( + self._zarr_viewers.remove(viewer) + if viewer in self._zarr_viewers + else None + ) + ) + logger.info("Opened Zarr viewer for: {}", path) + except Exception as e: + logger.exception("Failed to open Zarr file: {}", e) + QMessageBox.critical(self, "Error", f"Failed to open file:\n{e}") + + def _set_acq_params_enabled(self, enabled: bool) -> None: + """Enable or disable acquisition parameters (sample rate, downsample, etc). + + Parameters + ---------- + enabled : bool + True to enable acq params, False to disable them. + """ + self.sample_rate_input.setEnabled(enabled) + self.hw_downsample_input.setEnabled(enabled) + self.downsample_mode_input.setEnabled(enabled) + self.bandwidth_limiter_input.setEnabled(enabled) + self.ring_buffer_input.setEnabled(enabled) + + def _set_hardware_inputs_enabled(self, enabled: bool) -> None: + """Enable or disable all hardware setting input fields. + + Parameters + ---------- + enabled : bool + True to enable all inputs, False to disable all inputs. + """ + self.serial_code_input.setEnabled(enabled) + self.resolution_input.setEnabled(enabled) + self.toggle_channel_a.setEnabled(enabled) + self.toggle_channel_b.setEnabled(enabled) + self.voltage_range_a.setEnabled(enabled) + self.voltage_range_b.setEnabled(enabled) + self.offset_a.setEnabled(enabled) + self.offset_b.setEnabled(enabled) + self._set_acq_params_enabled(enabled) + + def toggle_device_connection(self) -> None: + """Connect or disconnect from the hardware device.""" + if self._devices_connected: + self._disconnect_device() else: - y_min = self.y_min_input.value() - y_max = self.y_max_input.value() - self.plotter.set_y_limits(y_min, y_max) + self._connect_device() + + def _connect_device(self) -> None: + """Connect to the Picoscope device.""" + logger.debug( + f"[THREAD] _connect_device called from thread: {threading.current_thread().name}" + ) + + serial_code = self.serial_code_input.text().strip() + is_mock = serial_code.upper() == "MOCK" + + if is_mock: + logger.warning("Using MOCK device (test mode)") + + try: + if is_mock: + self._daemon = None + self._device = MockPicoscopeBufferedStream( + device_id="picoscope_stream", + serial_code="MOCK", + resolution=int(self.resolution_input.currentText()), + ) + self._device.connect() + else: + logger.debug("[THREAD] Creating LabDaemon...") + self._daemon = ld.LabDaemon() + self._daemon.register_plugins( + devices={"PicoscopeBufferedStream": PicoscopeBufferedStream}, + tasks={}, + ) + + self._device = self._daemon.add_device( + device_id="picoscope_stream", + device_type="PicoscopeBufferedStream", + resolution=int(self.resolution_input.currentText()), + ) + + if serial_code: + self._device.set_connection_id(serial_code) + + self._daemon.connect_device(self._device, timeout=10.0) + + self._devices_connected = True + logger.info("Device connected successfully") + self._apply_ui_state("connected") + + except ld.DeviceTimeoutError as e: + logger.exception("Device connection timed out") + QMessageBox.critical( + self, + "Connection Timeout", + f"Device connection timed out:\n\n{e}\n\nCheck hardware connections and power.", + ) + self._cleanup_device() + except Exception as e: + logger.exception("Failed to connect to device") + QMessageBox.critical( + self, + "Connection Error", + f"Failed to connect to device:\n\n{e}", + ) + self._cleanup_device() + + def _disconnect_device(self) -> None: + """Disconnect from the Picoscope device.""" + logger.debug( + f"[THREAD] _disconnect_device called from thread: {threading.current_thread().name}" + ) + + try: + if self._stream_check_timer: + self._stream_check_timer.stop() + self._stream_check_timer = None + + if hasattr(self, "status_sync_timer"): + self.status_sync_timer.stop() + + if self._device and self._device.is_saving: + try: + self._device.stop_save(keep=True) + except Exception: + logger.exception("Error stopping save during disconnect") + finally: + self._is_recording = False + + if self._device and self._device.is_streaming(): + try: + self._device.stop_streaming() + except Exception: + logger.exception("Error stopping streaming during disconnect") + + self.plotter.stop_updates() + self.plotter.set_ring_buffer(None) + + self._cleanup_device() + + logger.info("Device disconnected") + + finally: + self._is_recording = False + self._devices_connected = False + self._apply_ui_state("idle") def start_acquisition(self) -> None: - """Start the background data acquisition.""" - self.start_button.setEnabled(False) - self.stop_button.setEnabled(True) + """Start data acquisition (requires device to be connected).""" + logger.debug( + f"[THREAD] start_acquisition called from thread: {threading.current_thread().name}" + ) - output_file = self.output_file_input.text() + if not self._devices_connected or self._device is None: + QMessageBox.warning( + self, + "Not Connected", + "Please connect to device before starting acquisition.", + ) + return + + # Validate Lossless mode if selected + if not self._validate_lossless_at_start(): + return - # Stop plotter updates and clear all state BEFORE changing file path self.plotter.stop_updates() - - # Clear the plotter's display and reset all internal state - self.plotter.reset_for_new_file() - - # Ensure status bar remains hidden after reset - self.plotter.hide_status_bar(hide=True) - - # Remove old file if it exists - if os.path.exists(output_file): - try: - os.remove(output_file) - logger.info(f"Removed existing file: {output_file}") - except OSError as e: - self.on_acquisition_error(f"Failed to remove old file: {e}") - self.on_acquisition_finished() - return + self.plotter.clear_curves() + self.plotter.set_ring_buffer(None) + + self.plotter.rate_check_start_time = None + self.plotter.rate_check_start_samples = 0 + + enabled_channels = [] + voltage_ranges = [] + offsets_v = [] + + if self.toggle_channel_a.isChecked(): + enabled_channels.append(0) + voltage_ranges.append(float(self.voltage_range_a.currentText())) + offsets_v.append(self.offset_a.value()) + + if self.toggle_channel_b.isChecked(): + enabled_channels.append(1) + voltage_ranges.append(float(self.voltage_range_b.currentText())) + offsets_v.append(self.offset_b.value()) + + if not enabled_channels: + QMessageBox.warning( + self, "Configuration Error", "At least one channel must be enabled." + ) + return + + sample_rate_msps = float(self.sample_rate_input.currentText().rstrip(" MS/s")) + sample_rate_hz = sample_rate_msps * 1e6 + + downsample_mode = self.downsample_mode_input.currentText() + downsample_ratio = self._parse_downsample_ratio() + + self.plotter.set_channel_visible(0, 0 in enabled_channels) + self.plotter.set_channel_visible(1, 1 in enabled_channels) + + try: + buffer_duration_s = self._parse_buffer_duration() + + self._device.configure_streaming( + sample_rate=sample_rate_hz, + channels=enabled_channels, + voltage_ranges=voltage_ranges, + buffer_duration_s=buffer_duration_s, + downsample_ratio=downsample_ratio, + downsample_mode=downsample_mode, + offsets_v=offsets_v, + bandwidth_limiter=self.bandwidth_limiter_input.currentText(), + ) + + self._device.start_streaming() + + # Apply plot settings from UI (quality, refresh rate, window) + self._apply_plot_settings_to_plotter() + + # Update plotter with actual acquisition settings before connecting + self.plotter.voltage_ranges = self._device._streaming_voltage_ranges.copy() + self.plotter.offsets_v = self._device._streaming_offsets_v.copy() + + self.plotter.set_ring_buffer(self._device.ring_buffer) + + window_text = self.plot_window_input.currentText() + window_seconds = float(window_text.rstrip("s")) + self.plotter.set_display_window(window_seconds) + + except Exception as e: + logger.exception("Failed to start acquisition") + self._show_acquisition_error(str(e)) + return - # Now set the new file path (plotter is stopped and cleared) - self.plotter.set_hdf5_path(output_file) - - # Apply plot settings before starting - self.apply_plot_settings() - - settings = { - "sample_rate_msps": self.sample_rate_input.value(), - "resolution_bits": int(self.resolution_input.currentText()), - "channel_range_str": self.voltage_range_input.currentText(), - "output_file": output_file, - "hardware_downsample": self.hw_downsample_input.value(), - "downsample_mode": self.downsample_mode_input.currentText(), - "offset_v": self.offset_v_input.value(), - "max_buffer_seconds": self.max_buffer_input.value() - if self.live_only_checkbox.isChecked() - else None, - "enable_live_plot": False, - "is_gui_mode": True, - "bandwidth_limiter": self.bandwidth_limiter_input.currentText().lower(), - } - - self.thread = QThread() - self.worker = StreamerWorker(settings) - self.worker.moveToThread(self.thread) - self.worker.stopRequested.connect(self.worker.stop, Qt.QueuedConnection) - - self.thread.started.connect(self.worker.run) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - self.worker.finished.connect(self.on_acquisition_finished) - self.worker.error.connect(self.on_acquisition_error) - - self.thread.start() - - # Start plotter updates - it will wait for the file to exist with valid data self.plotter.start_updates() - - # Start status sync timer - self.status_sync_timer.start(100) # Update every 100ms - def stop_acquisition(self) -> None: - """Stop the background data acquisition.""" + self.status_sync_timer.start(100) + + self._stream_check_timer = QTimer() + self._stream_check_timer.timeout.connect(self._check_stream_status) + self._stream_check_timer.start(500) + + self._apply_ui_state("live") + logger.info("Acquisition started successfully") + + def _check_stream_status(self) -> None: + """Check if streaming is still active, update UI if stopped.""" + if self._device is None: + return + + if not self._device.is_streaming(): + # Check if streaming stopped due to an error + error = self._device.get_streaming_error() + if error: + logger.error("Streaming stopped with error: {}", error) + self._show_streaming_error(error) + else: + self._on_streaming_stopped() + + def _show_streaming_error(self, error: str) -> None: + """Handle streaming stopped due to an error. + + Parameters + ---------- + error : str + The error description from the device. + """ + logger.debug( + f"[THREAD] _show_streaming_error called from thread: {threading.current_thread().name}" + ) + + if self._stream_check_timer: + self._stream_check_timer.stop() + self._stream_check_timer = None + self.plotter.stop_updates() - # Stop the status sync timer to save CPU - if hasattr(self, 'status_sync_timer'): + self.plotter.set_ring_buffer(None) + + if hasattr(self, "status_sync_timer"): self.status_sync_timer.stop() - if self.worker: - self.worker.stopRequested.emit() - self.stop_button.setEnabled(False) - def on_acquisition_finished(self) -> None: - """Handle acquisition completion (both success and failure).""" - logger.info("Acquisition finished.") + self._is_recording = False + + error_display = error.replace("PICO_", "").replace("_", " ").title() + self.status_record.setText(f"Error: {error_display[:25]}") + + logger.info("Streaming stopped due to error: {}", error) + self._apply_ui_state("error") + + def _on_streaming_stopped(self) -> None: + """Handle streaming stopped (either normally or via error).""" + logger.debug( + f"[THREAD] _on_streaming_stopped called from thread: {threading.current_thread().name}" + ) + + if self._stream_check_timer: + self._stream_check_timer.stop() + self._stream_check_timer = None + self.plotter.stop_updates() - self.start_button.setEnabled(True) - self.stop_button.setEnabled(False) - self.thread = None - self.worker = None + self.plotter.set_ring_buffer(None) + + if hasattr(self, "status_sync_timer"): + self.status_sync_timer.stop() + + self._is_recording = False + + logger.info("Acquisition stopped") + self._apply_ui_state("connected") + + def stop_acquisition(self) -> None: + """Stop data acquisition (device remains connected).""" + logger.debug( + f"[THREAD] stop_acquisition called from thread: {threading.current_thread().name}" + ) + + try: + if self._stream_check_timer: + self._stream_check_timer.stop() + self._stream_check_timer = None + + self.plotter.stop_updates() + self.plotter.set_ring_buffer(None) + + if hasattr(self, "status_sync_timer"): + self.status_sync_timer.stop() + + if self._device and self._device.is_saving: + try: + self._device.stop_save(keep=True) + except Exception: + logger.exception("Error stopping save during acquisition stop") + finally: + self._is_recording = False + + if self._device: + try: + self._device.stop_streaming() + except Exception: + logger.exception("Error stopping streaming") + + logger.info("Acquisition stopped by user") + + except Exception as e: + logger.exception("Unexpected error in stop_acquisition: {}", e) + finally: + self._is_recording = False + self._apply_ui_state("connected") + + def _cleanup_device(self) -> None: + """Clean up the LabDaemon device and daemon.""" + if self._device and self._device.is_saving: + logger.debug("[THREAD] Stopping save...") + self._device.stop_save(keep=True) + time.sleep(0.1) + + if self._daemon: + all_threads_before = [t.name for t in threading.enumerate()] + logger.debug( + f"[THREAD] Threads before daemon.shutdown(): {all_threads_before}" + ) + time.sleep(0.1) + try: + logger.debug("[THREAD] Calling daemon.shutdown()...") + self._daemon.shutdown(timeout=2.0) + logger.debug("[THREAD] daemon.shutdown() completed") + except Exception: + logger.exception("Error during daemon shutdown") + finally: + self._daemon = None + self._device = None + all_threads_after = [t.name for t in threading.enumerate()] + logger.debug( + f"[THREAD] Threads after daemon.shutdown(): {all_threads_after}" + ) + elif self._device: + try: + logger.debug("[THREAD] Disconnecting mock device...") + self._device.disconnect() + except Exception: + logger.exception("Error during mock device disconnect") + finally: + self._device = None - def on_acquisition_error(self, err_msg: str) -> None: - """Handle acquisition error.""" - logger.error(f"Acquisition error: {err_msg}") + def _show_acquisition_error(self, err_msg: str) -> None: + """Show error dialog for acquisition failure.""" + logger.error("Acquisition error: {}", err_msg) + self._update_state_indicator("error") + self.status_record.setText(f"Error: {err_msg[:30]}") - def sync_status_from_plotter(self) -> None: - """Sync status from the plotter to the main window status labels.""" + if "PICO_NOT_FOUND" in err_msg: + QMessageBox.critical( + self, + "Device Not Found", + "Picoscope device not found.\n\n" + "Please check that:\n" + "• The device is connected via USB\n" + "• No other software is using the device\n" + "• The device is powered on", + ) + else: + QMessageBox.critical( + self, "Acquisition Error", f"Failed to start acquisition:\n{err_msg}" + ) + + def _sync_status(self) -> None: + """Sync status from plotter and device.""" try: - # Check if plotter exists and has been initialised - if not hasattr(self, 'plotter') or not self.plotter: + if not hasattr(self, "plotter") or not self.plotter: return - - # Helper function to safely strip HTML tags - def strip_html(text: str) -> str: - if not text: - return text - # Remove all HTML tags - return re.sub(r'<[^>]+>', '', text) - - # Safely get text from label or return empty string - def safe_get_text(label_name: str) -> str: - if hasattr(self.plotter, label_name): - label = getattr(self.plotter, label_name) - if hasattr(label, 'text'): - return strip_html(label.text()) - return "" - - # Update all status labels - self.status_heartbeat.setText(safe_get_text('heartbeat_label')) - - # Update samples with QLCDNumber - samples_text = safe_get_text('samples_label') - try: - samples_int = int(samples_text) - self.status_samples.display(samples_int) - except (ValueError, TypeError): - self.status_samples.display(0) - - self.status_rate.setText(safe_get_text('rate_label')) - self.status_latency.setText(safe_get_text('plotter_latency_label')) - - # Update errors with colour coding - errors_text = safe_get_text('error_label') - self.status_errors.setText(errors_text) - has_errors = errors_text != "0" and errors_text != "-" - self._set_status_error_style(self.status_errors, has_errors) - - self.status_saturation.setText(safe_get_text('saturation_label')) - self.status_acquisition.setText(safe_get_text('acq_status_label')) - + + if self.plotter.ring_buffer is not None: + sample_count = self.plotter.ring_buffer.write_idx + self.status_samples.setText( + self.plotter.format_sample_count(sample_count) + ) + + # Calculate total acquisition duration from samples and rate + if self.plotter.acquisition_rate is None: + msg = ( + "acquisition_rate required for duration calculation but not set" + ) + raise RuntimeError(msg) + + acq_duration_s = self.plotter.acquisition_rate.samples_to_seconds( + sample_count + ) + self.status_duration.setText(f"{acq_duration_s:.1f}s") + + # Only show rate after we have enough samples and time for stable reading + if self.plotter.rate_check_start_time is not None: + elapsed = time.perf_counter() - self.plotter.rate_check_start_time + if elapsed >= 0.5: # Wait 500ms before showing rate + rate_text = self._format_rate_for_display() + self.status_rate.setText(rate_text) + else: + self.status_rate.setText("-") + + # Show lag if pipeline is available + if ( + self.plotter.pipeline is not None + and self.plotter._last_read_position > 0 + ): + lag_text = self.plotter.pipeline.format_lag( + sample_count, self.plotter._last_read_position + ) + self.status_lag.setText(lag_text) + else: + self.status_lag.setText("-") + + self._update_saturation_indicator(self.plotter.is_saturated) + self._update_clip_indicators() + + # Update display status indicators + if self.plotter._display_cache: + total_display_points = sum( + len(data) for data in self.plotter._display_cache.values() + ) + # Average per channel for display + n_channels = len(self.plotter._display_cache) + avg_points = total_display_points // n_channels if n_channels > 0 else 0 + + is_lossless = self.plotter.target_plot_points is None + if is_lossless: + self.status_display_points.setText(f"{avg_points:,} (Lossless)") + else: + self.status_display_points.setText(f"{avg_points:,}") + + decimation = self.plotter._current_plot_decimation + self.status_decimation.setText(f"{decimation}×") + else: + self.status_display_points.setText("-") + self.status_decimation.setText("-") + + if self._device: + save_status = self._device.get_save_status() + state = save_status.get("state", "idle") + + if state == "saving": + total_s = save_status.get("total_seconds", 0) + pre_s = save_status.get("pre_trigger_seconds", 0) + self.status_record.setText( + f"Rec: {total_s:.1f}s ({pre_s:.1f}s pre)" + ) + self._update_state_indicator("recording") + elif state == "error": + error_msg = save_status.get("error", "Unknown error") + self.status_record.setText(f"Error: {error_msg[:20]}") + self._update_state_indicator("error") + else: + self.status_record.setText("Idle") + except Exception as e: - # Don't let errors in status sync crash the app - logger.debug(f"Error syncing status from plotter: {e}") - pass + logger.debug("Error syncing status: {}", e) + + def _format_rate_for_display(self) -> str: + """Format the current acquisition rate for display. + + Returns + ------- + str + Formatted rate string (e.g., "62.50 MS/s"). + + Raises + ------ + RuntimeError + If acquisition_rate is not available. + """ + if self.plotter.rate_check_start_time is None: + return "-" + + elapsed = time.perf_counter() - self.plotter.rate_check_start_time + if elapsed < 0.5: + return "-" + + if self.plotter.acquisition_rate is None: + msg = "acquisition_rate required for rate display but not set" + raise RuntimeError(msg) + + per_channel_rate_hz = self.plotter.acquisition_rate.per_channel_rate_hz + rate_msps = per_channel_rate_hz / 1e6 + return f"{rate_msps:.1f} MS/s" + + def keyPressEvent(self, event: QKeyEvent) -> None: # pyright: ignore + """Handle keyboard shortcuts for common operations.""" + key = event.key() + modifiers = event.modifiers() + + is_ctrl_pressed = modifiers & Qt.KeyboardModifier.ControlModifier + + if key == Qt.Key.Key_Space: + self._handle_space_shortcut() + elif key == Qt.Key.Key_R: + self._handle_r_shortcut() + elif key == Qt.Key.Key_K: + self._handle_k_shortcut() + elif key == Qt.Key.Key_Delete: + self._handle_delete_shortcut() + elif key == Qt.Key.Key_1: + self.toggle_channel_a.click() + elif key == Qt.Key.Key_2: + self.toggle_channel_b.click() + elif key == Qt.Key.Key_O and is_ctrl_pressed: + self._on_load_file() + else: + super().keyPressEvent(event) + + def _handle_space_shortcut(self) -> None: + """Toggle start/stop acquisition based on current state.""" + if self.start_button.isEnabled() and self._devices_connected: + self.start_acquisition() + elif self.stop_button.isEnabled(): + self.stop_acquisition() + + def _handle_r_shortcut(self) -> None: + """Start recording if acquisition is active and not already recording.""" + if ( + self.start_record_button.isEnabled() + and self.start_record_button.isVisible() + ): + self._on_start_record() + + def _handle_k_shortcut(self) -> None: + """Stop recording and keep the file.""" + if self.finish_keep_button.isVisible(): + self._on_finish_keep() - def closeEvent(self, event: QCloseEvent) -> None: + def _handle_delete_shortcut(self) -> None: + """Stop recording and discard the file.""" + if self.finish_discard_button.isVisible(): + self._on_finish_discard() + + def closeEvent(self, event: QCloseEvent) -> None: # pyright: ignore """Handle window close event.""" + logger.debug( + f"[THREAD] closeEvent called from thread: {threading.current_thread().name}" + ) + all_threads = [t.name for t in threading.enumerate()] + logger.debug("[THREAD] Active threads in closeEvent: {}", all_threads) self.save_settings() - - # Stop the status sync timer first - if hasattr(self, 'status_sync_timer'): + + if hasattr(self, "status_sync_timer"): self.status_sync_timer.stop() - self.status_sync_timer = None - - # Stop plotter updates - if hasattr(self, 'plotter') and self.plotter: + + if hasattr(self, "plotter") and self.plotter: self.plotter.stop_updates() - - if self.thread and self.thread.isRunning(): - self.stop_acquisition() - self.thread.wait() + + if self._stream_check_timer: + self._stream_check_timer.stop() + + # Stop save poll timer if active + if self._save_poll_timer and self._save_poll_timer.isActive(): + self._save_poll_timer.stop() + logger.debug("Stopped save poll timer during close") + + self._cleanup_device() + + for viewer in self._zarr_viewers[:]: + try: + viewer.close() + except Exception: + pass + self._zarr_viewers.clear() + + all_threads_final = [t.name for t in threading.enumerate()] + logger.debug( + f"[THREAD] closeEvent completed. Active threads: {all_threads_final}" + ) event.accept() def save_settings(self) -> None: - """Save current settings.""" - self.settings.setValue("sample_rate", self.sample_rate_input.value()) + """Save current settings to QSettings.""" + self.settings.setValue("channel_a_enabled", self.toggle_channel_a.isChecked()) + self.settings.setValue("channel_b_enabled", self.toggle_channel_b.isChecked()) + self.settings.setValue("sample_rate", self.sample_rate_input.currentText()) self.settings.setValue("resolution", self.resolution_input.currentText()) - self.settings.setValue("voltage_range", self.voltage_range_input.currentText()) - self.settings.setValue("output_file", self.output_file_input.text()) - self.settings.setValue("hw_downsample", self.hw_downsample_input.value()) - self.settings.setValue("downsample_mode", self.downsample_mode_input.currentText()) - self.settings.setValue("offset_v", self.offset_v_input.value()) - self.settings.setValue("live_only_mode", self.live_only_checkbox.isChecked()) - self.settings.setValue("max_buffer_seconds", self.max_buffer_input.value()) - self.settings.setValue("plot_window", self.plot_window_input.value()) - self.settings.setValue("y_axis_auto", self.y_axis_auto_checkbox.isChecked()) - - # Only save Y-axis limits if they're actually being used (not in auto mode) - if not self.y_axis_auto_checkbox.isChecked(): - self.settings.setValue("y_min", self.y_min_input.value()) - self.settings.setValue("y_max", self.y_max_input.value()) - - self.settings.setValue("bandwidth_limiter", self.bandwidth_limiter_input.currentText()) + self.settings.setValue("voltage_range_a", self.voltage_range_a.currentText()) + self.settings.setValue("voltage_range_b", self.voltage_range_b.currentText()) + self.settings.setValue("offset_a", self.offset_a.value()) + self.settings.setValue("offset_b", self.offset_b.value()) + self.settings.setValue("hw_downsample", self.hw_downsample_input.currentText()) + self.settings.setValue( + "downsample_mode", self.downsample_mode_input.currentText() + ) + self.settings.setValue("plot_window", self.plot_window_input.currentText()) + self.settings.setValue( + "ring_buffer_duration", self.ring_buffer_input.currentText() + ) + self.settings.setValue( + "pre_trigger_seconds", self.pre_trigger_input.currentText() + ) + + self.settings.setValue( + "bandwidth_limiter", self.bandwidth_limiter_input.currentText() + ) + + self.settings.setValue("quality", self.quality_input.currentText()) + self.settings.setValue("refresh_rate", self.refresh_rate_input.currentText()) + + self.settings.setValue("serial_code", self.serial_code_input.text()) + + self.settings.setValue("v_div_a", self.v_div_a_input.currentText()) + self.settings.setValue("v_div_b", self.v_div_b_input.currentText()) + self.settings.setValue("y_pos_a", self.y_pos_a_input.value()) + self.settings.setValue("y_pos_b", self.y_pos_b_input.value()) def load_settings(self) -> None: - """Load settings.""" - self.sample_rate_input.setValue(self.settings.value("sample_rate", 62.5, type=float)) + """Load settings from QSettings.""" + channel_a_enabled = self.settings.value("channel_a_enabled", True, type=bool) + channel_b_enabled = self.settings.value("channel_b_enabled", True, type=bool) + + # Enforce 16-bit single-channel constraint on loaded settings + resolution = int(self.settings.value("resolution", "12")) + if resolution == 16 and channel_a_enabled and channel_b_enabled: + channel_b_enabled = False + + self.toggle_channel_a.setChecked(channel_a_enabled) + self.toggle_channel_b.setChecked(channel_b_enabled) + self._on_toggle_channel_a() + self._on_toggle_channel_b() + + # Initialize sample rate dropdown based on resolution/channels first + self._init_sample_rate_dropdown() + + sample_rate_val = self.settings.value("sample_rate", "62.5 MS/s") + if isinstance(sample_rate_val, float): + sample_rate_val = f"{sample_rate_val:.1f} MS/s" + elif not isinstance(sample_rate_val, str): + sample_rate_val = "62.5 MS/s" + # Try to set; if invalid, default to 62.5 MS/s + idx = self.sample_rate_input.findText(sample_rate_val) + if idx >= 0: + self.sample_rate_input.setCurrentIndex(idx) + else: + self.sample_rate_input.setCurrentText("62.5 MS/s") + self.resolution_input.setCurrentText(self.settings.value("resolution", "12")) - self.voltage_range_input.setCurrentText( - self.settings.value("voltage_range", "PS5000A_20V") + self.voltage_range_a.setCurrentText( + str(self.settings.value("voltage_range_a", "20.0")) ) - self.output_file_input.setText(self.settings.value("output_file", "output.hdf5")) - self.hw_downsample_input.setValue(self.settings.value("hw_downsample", 1, type=int)) + self.voltage_range_b.setCurrentText( + str(self.settings.value("voltage_range_b", "20.0")) + ) + self.offset_a.setValue(self.settings.value("offset_a", 0.0, type=float)) + self.offset_b.setValue(self.settings.value("offset_b", 0.0, type=float)) + + downsample_val = self.settings.value("hw_downsample", "1x") + if isinstance(downsample_val, int): + downsample_val = f"{downsample_val}x" + elif not isinstance(downsample_val, str): + downsample_val = "1x" + self.hw_downsample_input.setCurrentText(downsample_val) + self.downsample_mode_input.setCurrentText( - self.settings.value("downsample_mode", "average") - ) - self.offset_v_input.setValue(self.settings.value("offset_v", 0.0, type=float)) - live_only = self.settings.value("live_only_mode", False, type=bool) - self.live_only_checkbox.setChecked(live_only) - self.max_buffer_input.setValue( - self.settings.value("max_buffer_seconds", 1.0, type=float) - ) - self.max_buffer_input.setEnabled(live_only) - - self.plot_window_input.setValue(self.settings.value("plot_window", 0.5, type=float)) - y_axis_auto = self.settings.value("y_axis_auto", True, type=bool) - self.y_axis_auto_checkbox.setChecked(y_axis_auto) - - # Only load Y-axis limits if they exist in settings (were saved in manual mode) - # Otherwise use sensible defaults - if self.settings.contains("y_min") and self.settings.contains("y_max"): - self.y_min_input.setValue(self.settings.value("y_min", type=float)) - self.y_max_input.setValue(self.settings.value("y_max", type=float)) - else: - # Use sensible defaults if no saved values - self.y_min_input.setValue(-1000.0) - self.y_max_input.setValue(1000.0) - - self.y_min_input.setEnabled(not y_axis_auto) - self.y_max_input.setEnabled(not y_axis_auto) - + self.settings.value("downsample_mode", "none") + ) + + plot_window_val = self.settings.value("plot_window", "5s") + if isinstance(plot_window_val, float): + plot_window_val = f"{plot_window_val:.0f}s" + elif not isinstance(plot_window_val, str): + plot_window_val = "5s" + self.plot_window_input.setCurrentText(plot_window_val) + + buffer_val = self.settings.value("ring_buffer_duration", "30s") + if isinstance(buffer_val, (int, float)): + buffer_val = f"{int(buffer_val)}s" + elif not isinstance(buffer_val, str): + buffer_val = "30s" + self.ring_buffer_input.setCurrentText(buffer_val) + + pre_trigger_val = self.settings.value("pre_trigger_seconds", "10s") + if isinstance(pre_trigger_val, float): + pre_trigger_val = f"{pre_trigger_val:.0f}s" + elif not isinstance(pre_trigger_val, str): + pre_trigger_val = "10s" + self.pre_trigger_input.setCurrentText(pre_trigger_val) + self.bandwidth_limiter_input.setCurrentText( self.settings.value("bandwidth_limiter", "Full") ) + self.quality_input.setCurrentText(self.settings.value("quality", "High")) + self.refresh_rate_input.setCurrentText( + self.settings.value("refresh_rate", "20 Hz") + ) + + self.serial_code_input.setText(self.settings.value("serial_code", "")) + + self.v_div_a_input.setCurrentText(self.settings.value("v_div_a", "100mV")) + self.v_div_b_input.setCurrentText(self.settings.value("v_div_b", "100mV")) + self.y_pos_a_input.setValue(self.settings.value("y_pos_a", 0.0, type=float)) + self.y_pos_b_input.setValue(self.settings.value("y_pos_b", 0.0, type=float)) + + def _init_sample_rate_dropdown(self) -> None: + """Initialize the sample rate dropdown based on current resolution and channels.""" + resolution = int(self.resolution_input.currentText()) + n_channels = (1 if self.toggle_channel_a.isChecked() else 0) + ( + 1 if self.toggle_channel_b.isChecked() else 0 + ) + if n_channels == 0: + n_channels = 1 # Default to 1 channel if none selected yet + + rates = self._get_available_rates(resolution, n_channels) + current_text = self.sample_rate_input.currentText() + + self.sample_rate_input.clear() + for rate_text, _rate_val in rates: + self.sample_rate_input.addItem(rate_text) + + # Try to restore previous selection, or default to first available + idx = self.sample_rate_input.findText(current_text) + if idx >= 0: + self.sample_rate_input.setCurrentIndex(idx) + else: + # Default to 62.5 MS/s or first available + idx_625 = self.sample_rate_input.findText("62.5 MS/s") + if idx_625 >= 0: + self.sample_rate_input.setCurrentIndex(idx_625) + elif self.sample_rate_input.count() > 0: + self.sample_rate_input.setCurrentIndex(0) + + def _get_available_rates( + self, resolution: int, n_channels: int + ) -> list[tuple[str, float]]: + """Get available sample rates for given resolution and channel count. + + Parameters + ---------- + resolution : int + Bit resolution (8, 12, 14, 15, 16). + n_channels : int + Number of enabled channels. + + Returns + ------- + list[tuple[str, float]] + List of (display_text, rate_msps) tuples. + """ + max_total_rate = { + 8: 125.0, + 12: 62.5, + 14: 31.25, + 15: 15.625, + 16: 15.625, + }.get(resolution, 62.5) + + # Per-channel max rate + max_per_channel = ( + max_total_rate / n_channels if n_channels > 0 else max_total_rate + ) + + # All candidate rates (MS/s) + all_rates = [ + 125.0, + 100.0, + 62.5, + 50.0, + 31.25, + 20.0, + 15.625, + 10.0, + 5.0, + 2.0, + 1.0, + 0.5, + 0.2, + 0.1, + ] + + rates = [] + for rate in all_rates: + if rate <= max_per_channel: + rates.append( + (f"{rate:.1f} MS/s" if rate == int(rate) else f"{rate} MS/s", rate) + ) + + return rates + + def _on_resolution_changed(self) -> None: + """Update sample rate dropdown when resolution changes. + + In 16-bit mode, the PicoScope 5000 series only supports one analog + channel. If both channels were enabled, Channel B is turned off. + """ + if self._is_16bit_mode(): + if self.toggle_channel_a.isChecked() and self.toggle_channel_b.isChecked(): + self.toggle_channel_b.setChecked(False) + self._on_toggle_channel_b() + + tooltip_suffix = " (16-bit: single channel only)" + else: + tooltip_suffix = "" + + self.toggle_channel_a.setToolTip( + f"Toggle Channel A visibility (1){tooltip_suffix}" + ) + self.toggle_channel_b.setToolTip( + f"Toggle Channel B visibility (2){tooltip_suffix}" + ) + + self._init_sample_rate_dropdown() + + def _on_sample_rate_changed(self) -> None: + """Hook for future use when sample rate changes.""" + pass + + def _on_buffer_duration_changed(self) -> None: + self._apply_clamp_constraints() + + def _on_pre_trigger_changed(self) -> None: + """Hook for future use when pre-trigger changes.""" + pass + + def _on_voltage_range_a_changed(self) -> None: + """Hook for future use when channel A voltage range changes.""" + pass + + def _on_offset_a_changed(self) -> None: + """Hook for future use when channel A offset changes.""" + pass + + def _on_voltage_range_b_changed(self) -> None: + """Hook for future use when channel B voltage range changes.""" + pass + + def _on_offset_b_changed(self) -> None: + """Hook for future use when channel B offset changes.""" + pass + + def _get_target_points_for_quality(self, quality: str) -> Optional[int]: + """Get target plot points for a quality setting. + + Parameters + ---------- + quality : str + Quality setting name. + + Returns + ------- + Optional[int] + Target points for the quality setting, or None for Lossless. + """ + return QUALITY_POINTS.get(quality) + + def _on_quality_changed(self) -> None: + """Handle quality dropdown change.""" + quality = self.quality_input.currentText() + logger.info("Quality set to {} (will apply on next acquisition)", quality) + + def _on_refresh_rate_changed(self) -> None: + """Handle refresh rate dropdown change.""" + refresh_text = self.refresh_rate_input.currentText() + interval_ms, _ = REFRESH_RATE_MAP.get(refresh_text, (50, 50000)) + + # Update plotter timer interval only if actively plotting + if self.plotter.timer.isActive(): + self.plotter.update_interval_ms = interval_ms + self.plotter.timer.stop() + self.plotter.timer.start(interval_ms) + + logger.info( + "Refresh rate set to {} ({}ms, will apply on next acquisition)", + refresh_text, + interval_ms, + ) + + def _validate_lossless_at_start(self) -> bool: + """Validate Lossless mode when acquisition starts. + + Shows a warning dialog if Lossless would exceed recommended limits. + Allows user to switch to High Quality, force Lossless anyway, or cancel. + + Returns + ------- + bool + True if acquisition should proceed, False to abort. + """ + quality = self.quality_input.currentText() + if quality != "Lossless": + return True + + window_text = self.plot_window_input.currentText() + window_seconds = float(window_text.rstrip("s")) + + refresh_text = self.refresh_rate_input.currentText() + _, max_points = REFRESH_RATE_MAP.get(refresh_text, (50, 50_000_000)) + + # Calculate required points using current acquisition parameters + sample_rate_text = self.sample_rate_input.currentText() + sample_rate_msps = float(sample_rate_text.rstrip(" MS/s")) + hw_downsample = self._parse_downsample_ratio() + + from picostream.acquisition_rate import AcquisitionRate, DownsampleMode + + mode_text = self.downsample_mode_input.currentText().upper() + try: + downsample_mode = DownsampleMode[mode_text] + except KeyError: + downsample_mode = DownsampleMode.NONE + + num_channels = (1 if self.toggle_channel_a.isChecked() else 0) + ( + 1 if self.toggle_channel_b.isChecked() else 0 + ) + if num_channels == 0: + num_channels = 1 + + temp_rate = AcquisitionRate( + hardware_rate_hz=sample_rate_msps * 1e6 * num_channels, + num_channels=num_channels, + downsample_ratio=hw_downsample, + downsample_mode=downsample_mode, + ) + required_points = temp_rate.seconds_to_samples(window_seconds) + + if required_points <= max_points: + return True + + dialog = LosslessWarningDialog(required_points, max_points, refresh_text, self) + QTimer.singleShot(30000, dialog.reject) + dialog.exec() + choice = dialog.user_choice() + + if choice == "cancel": + return False + + if choice == "high": + self.quality_input.blockSignals(True) + self.quality_input.setCurrentText("High") + self.quality_input.blockSignals(False) + logger.info("User switched to High Quality from Lossless warning") + + # choice == "lossless" - proceed anyway + return True + + +def _is_running_in_terminal() -> bool: + """Check if the application is running with an attached terminal.""" + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + def main() -> None: """GUI entry point.""" logger.remove() - logger.add(sys.stderr, level="INFO") + + log_dir = Path.home() / ".picostream" / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / "picostream.log" + + logger.add( + str(log_file), + level="INFO", + rotation="10 MB", + retention=5, + enqueue=True, + ) + + if _is_running_in_terminal(): + logger.add(sys.stdout, level="INFO", enqueue=True) + app = QApplication(sys.argv) main_win = PicoStreamMainWindow() main_win.show() - sys.exit(app.exec_()) + sys.exit(app.exec()) if __name__ == "__main__": diff --git a/picostream/mock_device.py b/picostream/mock_device.py new file mode 100644 index 0000000..7fd1759 --- /dev/null +++ b/picostream/mock_device.py @@ -0,0 +1,1477 @@ +""" +Mock Picoscope DAQ Device + +Simulates the Picoscope DAQ for testing without hardware. +Implements the same interface as the real device. +""" + +import os +import random +import threading +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +import numpy as np +from loguru import logger + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode + +if TYPE_CHECKING: + pass + + +class MockPicoscope: + """ + Mock implementation of Picoscope DAQ. + + Simulates data acquisition and streaming functionality. + """ + + # Valid voltage ranges matching the real Picoscope 5000a series + VOLTAGE_RANGES = { + 0.01: "PS5000A_10MV", + 0.02: "PS5000A_20MV", + 0.05: "PS5000A_50MV", + 0.1: "PS5000A_100MV", + 0.2: "PS5000A_200MV", + 0.5: "PS5000A_500MV", + 1.0: "PS5000A_1V", + 2.0: "PS5000A_2V", + 5.0: "PS5000A_5V", + 10.0: "PS5000A_10V", + 20.0: "PS5000A_20V", + } + + # Max ADC value matches the real driver (16-bit signed max) + MAX_ADC = 32767 + + def __init__(self, device_id: str, serial_code: Optional[str] = None): + """ + Initialize mock DAQ. + + Parameters + ---------- + device_id : str + Device identifier. + serial_code : str, optional + Serial code (ignored in mock). + """ + self.device_id = device_id + + # Internal state + self._connected = False + self._streaming = False + self._streaming_callback = None + self._streaming_thread = None + self._stop_streaming = threading.Event() + + # Streaming configuration + self._sample_rate = 100000 + self._voltage_range = 1.0 + self._num_samples = 4096 + self._downsample_ratio = 1 + self._downsample_mode = "NONE" + + # Armed acquisition state + self._armed = False + self._armed_config = {} + + # Mock device info + self._serial = serial_code or "MOCK123456" + self._model = "Mock Picoscope 4000" + + logger.info("Mock DAQ {} initialized (serial: {})", device_id, self._serial) + + def get_connection_id(self) -> Optional[str]: + """ + Get the serial code. + + Returns + ------- + Optional[str] + The serial code, or None if auto-detect is used. + """ + return self._serial + + def set_connection_id(self, connection_id: Optional[str]) -> None: + """ + Set the serial code for connection. + + Pass None or empty string to use auto-detect. + + Parameters + ---------- + connection_id : Optional[str] + Serial code, or None/empty for auto-detect. + + Raises + ------ + ValueError + If connection_id is not a string or None. + """ + if connection_id is None: + self._serial = None + elif isinstance(connection_id, str): + self._serial = connection_id if connection_id else None + else: + raise ValueError("connection_id must be a string or None") + + def is_connected(self) -> bool: + """Check if device is connected.""" + return self._connected + + def connect(self) -> None: + """Connect to the mock DAQ.""" + logger.info("Mock DAQ {}: Connecting...", self.device_id) + time.sleep(0.2) # Simulate connection time + self._connected = True + logger.info("Mock DAQ {}: Connected (serial: {})", self.device_id, self._serial) + + def disconnect(self) -> None: + """Disconnect from the mock DAQ.""" + logger.info("Mock DAQ {}: Disconnecting...", self.device_id) + if self._streaming: + self.stop_streaming() + self._connected = False + self._armed = False + logger.info("Mock DAQ {}: Disconnected", self.device_id) + + def configure_streaming( + self, + sample_rate: float, + voltage_range: float, + num_samples_per_block: int, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + ) -> None: + """ + Configure streaming parameters. + + Parameters + ---------- + sample_rate : float + Sample rate in Hz. + voltage_range : float + Voltage range in Volts. + num_samples_per_block : int + Number of samples per block. + downsample_ratio : int + Downsampling ratio. + downsample_mode : str + Downsampling mode ("NONE", "AVERAGE", "DECIMATE"). + """ + if not self._connected: + raise RuntimeError("DAQ not connected") + + # Validate parameters + if not 1000 <= sample_rate <= 125000000: + raise ValueError("Sample rate must be between 1kHz and 125MHz") + if voltage_range not in self.VOLTAGE_RANGES: + valid_ranges = sorted(self.VOLTAGE_RANGES.keys()) + raise ValueError( + f"Invalid voltage range: {voltage_range}V. " + f"Valid ranges: {valid_ranges}V" + ) + if not 256 <= num_samples_per_block <= 10000000: + raise ValueError("Samples per block must be between 256 and 10M") + if downsample_ratio < 1: + raise ValueError("Downsample ratio must be >= 1") + if downsample_mode not in ["NONE", "AVERAGE", "DECIMATE"]: + raise ValueError("Invalid downsample mode") + + self._sample_rate = sample_rate + self._voltage_range = voltage_range + self._num_samples = num_samples_per_block + self._downsample_ratio = downsample_ratio + self._downsample_mode = downsample_mode + + logger.info( + "Mock DAQ {}: Configured streaming: {:.0f}kHz, " + "{}V, {} samples/block, " + "downsample={} ({})", + self.device_id, + sample_rate / 1000, + voltage_range, + num_samples_per_block, + downsample_ratio, + downsample_mode, + ) + + def start_streaming(self, callback) -> None: + """ + Start streaming data. + + Parameters + ---------- + callback : callable + Callback function for streaming data. + """ + if not self._connected: + logger.error( + "Mock DAQ {}: Cannot start streaming - DAQ not connected", + self.device_id, + ) + raise RuntimeError("DAQ not connected") + + if self._streaming: + logger.warning( + "Mock DAQ {}: Already streaming, ignoring start request", self.device_id + ) + return + + self._streaming = True + self._streaming_callback = callback + self._stop_streaming.clear() + + self._streaming_thread = threading.Thread( + target=self._streaming_loop, daemon=True, name=f"MockDAQ-{self.device_id}" + ) + self._streaming_thread.start() + logger.info("Mock DAQ {}: Started streaming successfully", self.device_id) + + def stop_streaming(self, timeout: Optional[float] = None) -> None: + """ + Stop streaming data. + + Parameters + ---------- + timeout : float, optional + Timeout for stopping. + """ + if not self._streaming: + logger.info("Mock DAQ {}: Not streaming, nothing to stop", self.device_id) + return + + logger.info("Mock DAQ {}: Stopping streaming...", self.device_id) + self._streaming = False # Set flag first to prevent new data generation + self._stop_streaming.set() + + if self._streaming_thread: + self._streaming_thread.join(timeout=timeout or 5.0) + if self._streaming_thread.is_alive(): + logger.warning( + "Mock DAQ {}: Streaming thread did not stop in time", self.device_id + ) + else: + logger.info( + "Mock DAQ {}: Streaming thread stopped successfully", self.device_id + ) + + self._streaming_callback = None + self._streaming_thread = None + logger.info("Mock DAQ {}: Stopped streaming", self.device_id) + + def is_streaming(self) -> bool: + """Check if streaming is active.""" + return self._streaming + + def _streaming_loop(self) -> None: + """Generate mock streaming data.""" + block_time = self._num_samples / self._sample_rate + + iteration = 0 + # Use smaller wait intervals to allow faster response to stop requests + wait_interval = min( + 0.05, block_time / 20 + ) # Check stop event more frequently (50ms max) + elapsed_time = 0 + + while not self._stop_streaming.wait(wait_interval): + elapsed_time += wait_interval + + # Generate data every 100ms regardless of sample rate for more responsive streaming + if elapsed_time >= 0.1: + try: + # Generate mock voltage data + voltage_data = self._generate_mock_data(self._num_samples) + + # Apply downsampling if configured + if self._downsample_ratio > 1: + voltage_data = self._apply_downsampling(voltage_data) + + # Convert to list for JSON serialization + data_list = voltage_data.tolist() + + # Send data via callback + if self._streaming_callback: + try: + self._streaming_callback(data_list) + except Exception: + logger.exception( + "Mock DAQ {}: Callback failed", + self.device_id, + ) + else: + logger.warning( + "Mock DAQ {}: NO CALLBACK REGISTERED at iteration {}", + self.device_id, + iteration, + ) + + iteration += 1 + elapsed_time = 0 # Reset elapsed time after generating data + + except Exception: + logger.exception( + "Mock DAQ {}: Error in streaming loop", self.device_id + ) + break + + logger.info( + "Mock DAQ {}: _streaming_loop() exiting after {} iterations", + self.device_id, + iteration, + ) + + def _generate_mock_data(self, num_samples: int) -> np.ndarray: + """Generate realistic mock voltage data with optimized performance.""" + # Time vector + block_time = num_samples / self._sample_rate + t = np.linspace(0, block_time, num_samples) + + # Simplified signal generation for better performance + # Main frequency component (1kHz) + frequency = 2 # updated to 2 + # Use fixed amplitude in volts (not a percentage of range) + # This simulates a real signal with known amplitude + amplitude = 0.5 # 500mV signal + signal = amplitude * np.sin(2 * np.pi * frequency * t) + + # Reduced harmonics for performance + signal += amplitude * 0.1 * np.sin(2 * np.pi * frequency * 3 * t) + + # Simplified noise + noise_level = self._voltage_range * 0.02 + noise = np.random.normal(0, noise_level, num_samples) + voltage_data = signal + noise + + # Reduce spike frequency to improve performance + if random.random() < 0.005: # 0.5% chance of spike (reduced from 1%) + spike_pos = random.randint(0, num_samples - 1) + voltage_data[spike_pos] += random.uniform(-0.5, 0.5) * self._voltage_range + + return voltage_data + + def _apply_downsampling(self, data: np.ndarray) -> np.ndarray: + """Apply downsampling to data.""" + if self._downsample_ratio <= 1: + return data + + new_length = len(data) // self._downsample_ratio + + if self._downsample_mode == "AVERAGE": + # Reshape and average + truncated = data[: new_length * self._downsample_ratio] + return truncated.reshape(new_length, self._downsample_ratio).mean(axis=1) + elif self._downsample_mode == "DECIMATE": + # Simple decimation + return data[:: self._downsample_ratio] + else: # NONE + return data + + def arm_acquisition( + self, + sample_rate: float, + num_samples: int, + channels: List[int], + voltage_ranges: List[float], + trigger_source: Optional[str] = None, + trigger_threshold_v: float = 0.5, + trigger_direction: str = "RISING", + trigger_delay_s: float = 0.0, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + **kwargs, + ) -> None: + """ + Arm the DAQ for hardware-triggered acquisition. + + Parameters + ---------- + sample_rate : float + Sample rate in Hz. + num_samples : int + Number of samples to acquire. + channels : list of int + Channels to acquire. + voltage_ranges : list of float + Voltage range for each channel. + trigger_source : str, optional + Trigger source (e.g., "EXT" for external). + trigger_threshold_v : float, optional + Trigger threshold in volts. + trigger_direction : str, optional + Trigger direction ("RISING" or "FALLING"). + trigger_delay_s : float, optional + Trigger delay in seconds. + downsample_ratio : int, optional + Downsampling ratio. + downsample_mode : str, optional + Downsampling mode ("NONE", "AVERAGE", "DECIMATE"). + """ + if not self._connected: + raise RuntimeError("DAQ not connected") + + # Validate parameters + if not 1000 <= sample_rate <= 125000000: + raise ValueError("Sample rate must be between 1kHz and 125MHz") + if not 256 <= num_samples <= 10000000: + raise ValueError("Number of samples must be between 256 and 10M") + if not channels: + raise ValueError("At least one channel must be specified") + if len(channels) != len(voltage_ranges): + raise ValueError("Number of channels must match number of voltage ranges") + for v_range in voltage_ranges: + if v_range not in self.VOLTAGE_RANGES: + valid_ranges = sorted(self.VOLTAGE_RANGES.keys()) + raise ValueError( + f"Invalid voltage range: {v_range}V. Valid ranges: {valid_ranges}V" + ) + + # Store configuration for later use in wait_for_acquisition + self._armed_config = { + "sample_rate": sample_rate, + "num_samples": num_samples, + "channels": channels, + "voltage_ranges": voltage_ranges, + "trigger_source": trigger_source, + "trigger_threshold_v": trigger_threshold_v, + "trigger_direction": trigger_direction, + "trigger_delay_s": trigger_delay_s, + "downsample_ratio": downsample_ratio, + "downsample_mode": downsample_mode, + **kwargs, # Store any additional kwargs passed from the task + } + + self._armed = True + + logger.info( + "Mock DAQ {}: Armed acquisition: {:.0f}kHz, " + "{} samples, channels {}, " + "trigger={} ({})", + self.device_id, + sample_rate / 1000, + num_samples, + channels, + trigger_source, + trigger_direction, + ) + + def wait_for_acquisition( + self, timeout_s: float = 10.0, cancel_event: Optional[threading.Event] = None + ) -> Dict[str, Any]: + """ + Wait for a hardware-triggered acquisition to complete. + + Parameters + ---------- + timeout_s : float, optional + Timeout in seconds. + cancel_event : threading.Event, optional + Event to signal cancellation. + + Returns + ------- + dict + Acquisition results with keys: data_v, data_adc, metadata. + + Raises + ------ + RuntimeError + If DAQ is not armed or not connected. + TimeoutError + If acquisition does not complete within timeout. + """ + if not self._connected: + raise RuntimeError("DAQ not connected") + + if not self._armed: + raise RuntimeError("DAQ not armed. Call arm_acquisition() first.") + + logger.info( + "Mock DAQ {}: Waiting for hardware-triggered acquisition (timeout={}s)", + self.device_id, + timeout_s, + ) + + # Simulate waiting for trigger and acquisition + # In real hardware, this would block until the trigger fires + # For mock, we simulate the acquisition time + config = self._armed_config + num_samples = config["num_samples"] + sample_rate = config["sample_rate"] + channels = config["channels"] + voltage_ranges = config["voltage_ranges"] + + # Estimate acquisition time + acquisition_time = num_samples / sample_rate + + # Simulate trigger delay + acquisition with cancellation support + start_time = time.time() + while time.time() - start_time < acquisition_time: + if cancel_event and cancel_event.is_set(): + logger.info("Mock DAQ {}: Acquisition cancelled", self.device_id) + self._armed = False + raise RuntimeError("Acquisition was cancelled") + + # Check timeout + + # Sleep in small increments to allow cancellation checks + time.sleep(min(0.01, acquisition_time / 10)) + + # Check timeout + if time.time() - start_time > timeout_s: + self._armed = False + raise TimeoutError(f"Acquisition did not complete within {timeout_s}s") + + # --- Generate time-domain signal that will appear as a Gaussian when mapped to wavelength --- + # The HardwareSweepTask will map time_data_np to wavelengths. + # We need to generate a Gaussian in the time domain such that its peak corresponds + # to the desired wavelength (1554nm) when that mapping occurs. + + # Assume a typical sweep range for the mock to generate a feature within + # This is for the mock's internal calculation, not for the task's actual sweep. + mock_sweep_start_wl = 1550.0 + mock_sweep_stop_wl = 1560.0 + mock_sweep_speed_nm_s = 20.0 + + time_data_np = ( + np.arange(num_samples) / sample_rate + ) # Time axis for the acquired data + + # Calculate the wavelength at each time point if the sweep started at mock_sweep_start_wl + # and swept at mock_sweep_speed_nm_s + mock_wavelengths_at_time = ( + mock_sweep_start_wl + + np.sign(mock_sweep_stop_wl - mock_sweep_start_wl) + * mock_sweep_speed_nm_s + * time_data_np + ) + + # Define Gaussian parameters in terms of wavelength + peak_wavelength = 1554.0 # nm + peak_amplitude = -1.0 # V (negative peak) + fwhm = 2.0 # nm (Full Width at Half Maximum) + sigma = fwhm / ( + 2 * np.sqrt(2 * np.log(2)) + ) # Convert FWHM to standard deviation + + # Generate a baseline (off-resonance) voltage + baseline_voltage = 0.0 # V + + # Calculate Gaussian profile based on the mock_wavelengths_at_time + gaussian_profile = peak_amplitude * np.exp( + -((mock_wavelengths_at_time - peak_wavelength) ** 2) / (2 * sigma**2) + ) + + # Combine baseline and Gaussian + mock_spectrum_v = baseline_voltage + gaussian_profile + + # Add some noise to the spectrum + noise_level = ( + 0.01 * voltage_ranges[0] + ) # 1% of the first channel's voltage range + mock_spectrum_v += np.random.normal(0, noise_level, num_samples) + + # Ensure data stays within voltage range (clipping) + mock_spectrum_v = np.clip( + mock_spectrum_v, -voltage_ranges[0], voltage_ranges[0] + ) + + # Generate mock data for each channel + data_v = {} + data_adc = {} + + for ch_idx, ch in enumerate(channels): + voltage_range = voltage_ranges[ch_idx] + + # For the sweep, we'll use the generated mock_spectrum_v for all channels + voltage_data = mock_spectrum_v + + # Convert to ADC counts (matching real device: 16-bit signed, max = 32767) + adc_data = (voltage_data / voltage_range * self.MAX_ADC).astype(np.int16) + + data_v[ch] = voltage_data.tolist() + data_adc[ch] = adc_data.tolist() + + # Create time axis (as a list for JSON serialization) + time_data_list = time_data_np.tolist() + + # Mark as no longer armed + self._armed = False + + logger.info( + "Mock DAQ {}: Acquisition complete: {} samples at {:.0f}kHz", + self.device_id, + num_samples, + sample_rate / 1000, + ) + + return { + "data_v": data_v, + "data_adc": data_adc, + "metadata": { + "sample_rate": sample_rate, + "num_samples": num_samples, + "channels": channels, + "voltage_ranges": voltage_ranges, + "time_data": time_data_list, + "timestamp": time.time(), + }, + } + + def cancel_acquisition(self) -> None: + """ + Cancel an in-progress or armed acquisition. + """ + if not self._connected: + raise RuntimeError("DAQ not connected") + + if not self._armed: + logger.debug("Mock DAQ {}: No acquisition to cancel", self.device_id) + return + + self._armed = False + self._armed_config = {} + logger.info("Mock DAQ {}: Acquisition cancelled", self.device_id) + + def reset_device_state(self) -> None: + """ + Reset the device to a clean, idle state. + """ + if not self._connected: + raise RuntimeError("DAQ not connected") + + self._armed = False + self._armed_config = {} + logger.info("Mock DAQ {}: Device state reset", self.device_id) + + def acquire_block( + self, + sample_rate: float, + num_samples: int, + voltage_range: float, + channels: Optional[List[int]] = None, + ) -> Dict[str, Any]: + """ + Acquire a single block of data. + + Parameters + ---------- + sample_rate : float + Sample rate in Hz. + num_samples : int + Number of samples to acquire. + voltage_range : float + Voltage range in Volts. + channels : list of int, optional + Channels to acquire (default: [0]). + + Returns + ------- + dict + Acquisition results. + """ + if not self._connected: + raise RuntimeError("DAQ not connected") + + channels = channels or [0] + + # Validate parameters + if not 1000 <= sample_rate <= 125000000: + raise ValueError("Sample rate must be between 1kHz and 125MHz") + if not 256 <= num_samples <= 10000000: + raise ValueError("Number of samples must be between 256 and 10M") + if voltage_range not in self.VOLTAGE_RANGES: + valid_ranges = sorted(self.VOLTAGE_RANGES.keys()) + raise ValueError( + f"Invalid voltage range: {voltage_range}V. " + f"Valid ranges: {valid_ranges}V" + ) + + logger.info( + "Mock DAQ {}: Acquiring {} samples at {:.0f}kHz on channels {}", + self.device_id, + num_samples, + sample_rate / 1000, + channels, + ) + + # Simulate acquisition time + acquisition_time = num_samples / sample_rate + time.sleep(min(acquisition_time, 0.5)) # Cap at 0.5s for testing + + # Generate mock data for each channel + results = {} + t = np.linspace(0, acquisition_time, num_samples) + + for ch in channels: + # Different frequency per channel for variety + # frequency = 1000 + ch * 500 + frequency = 2 + ch * 1 + # Fixed 400mV signal amplitude (not percentage of range) + signal_amplitude = 0.4 + noise_level = voltage_range * 0.01 + + signal = signal_amplitude * np.sin(2 * np.pi * frequency * t) + noise = np.random.normal(0, noise_level, num_samples) + voltage_data = signal + noise + + results[ch] = voltage_data.tolist() + + # Return in format expected by HardwareSweepTask + return { + "data_v": results, + "sample_rate_hz": sample_rate, + "num_samples": num_samples, + "voltage_range_v": voltage_range, + "channels": channels, + "timestamps": t.tolist(), + } + + def get_device_info(self) -> Dict[str, Any]: + """Get device information.""" + return { + "serial": self._serial, + "model": self._model, + "connected": self._connected, + "streaming": self._streaming, + } + + def get_state(self) -> dict: + """Get device state.""" + return { + "connected": self._connected, + "streaming": self._streaming, + "armed": self._armed, + "sample_rate": self._sample_rate, + "voltage_range": self._voltage_range, + "num_samples": self._num_samples, + "downsample_ratio": self._downsample_ratio, + "downsample_mode": self._downsample_mode, + } + + +class MockPicoscopeBufferedStream: + """ + Mock implementation of PicoscopeBufferedStream for testing PicoStream GUI. + + Simulates multi-channel buffered streaming with a ring buffer, save functionality, + and realistic synthetic data (two sine waves + gaussian noise). + """ + + # Same voltage ranges as real device + VOLTAGE_RANGES = { + 0.01: "PS5000A_10MV", + 0.02: "PS5000A_20MV", + 0.05: "PS5000A_50MV", + 0.1: "PS5000A_100MV", + 0.2: "PS5000A_200MV", + 0.5: "PS5000A_500MV", + 1.0: "PS5000A_1V", + 2.0: "PS5000A_2V", + 5.0: "PS5000A_5V", + 10.0: "PS5000A_10V", + 20.0: "PS5000A_20V", + } + + def __init__( + self, + device_id: str, + serial_code: Optional[str] = None, + resolution: int = 12, + **kwargs: Any, + ) -> None: + """Initialize mock buffered stream device.""" + self.device_id = device_id + self.serial_code = serial_code or "MOCK_BUFF_STREAM" + self.resolution = resolution + + # Connection state + self._is_connected = False + self._chandle = None # Mock handle for compatibility + self._max_adc_value = None + + # Streaming state + self._streaming_configured = False + self._streaming_channels: List[int] = [] + self._enabled_channels_set: set[int] = set() + self._streaming_voltage_ranges: Dict[int, float] = {} + self._streaming_offsets_v: Dict[int, float] = {} + self._streaming_sample_rate: Optional[float] = None + self._streaming_sample_interval_ns: int = 16 + self._streaming_downsample_ratio: int = 1 + self._streaming_downsample_mode: str = "NONE" + self._streaming_bandwidth_limiter: str = "FULL" + self._streaming_buffer_duration_s: float = 30.0 + + # Ring buffer + self._ring_buffer: Optional[Any] = None + + # Producer thread + self._producer_thread: Optional[threading.Thread] = None + self._stop_streaming_event = threading.Event() + + # Save state + self._is_saving = False + self._save_thread: Optional[threading.Thread] = None + self._save_stop_event: Optional[threading.Event] = None + self._save_output_path: Optional[str] = None + self._save_samples_written = 0 + self._save_pre_trigger_samples = 0 + self._save_error: Optional[Exception] = None + self._keep_file_on_stop: bool = True + self._save_output_path: Optional[str] = None + self._save_samples_written = 0 + self._save_pre_trigger_samples = 0 + self._save_error: Optional[Exception] = None + + # Signal generation state (for continuous phases) + self._phase_ch0 = 0.0 + self._phase_ch1 = 0.0 + self._sample_count = 0 + + logger.info("MockPicoscopeBufferedStream {} initialized", device_id) + + def get_connection_id(self) -> Optional[str]: + """Get the serial code.""" + return self.serial_code + + def set_connection_id(self, connection_id: Optional[str]) -> None: + """Set the serial code for connection.""" + if connection_id is None: + self.serial_code = None + elif isinstance(connection_id, str): + self.serial_code = connection_id if connection_id else None + else: + raise ValueError("connection_id must be a string or None") + + def connect(self) -> None: + """Connect to the mock device.""" + logger.info("MockPicoscopeBufferedStream {}: Connecting...", self.device_id) + time.sleep(0.1) # Simulate connection time + self._is_connected = True + self._max_adc_value = 32767 if self.resolution <= 12 else 32767 + logger.info( + "MockPicoscopeBufferedStream {}: Connected (serial: {}, resolution: {}-bit)", + self.device_id, + self.serial_code, + self.resolution, + ) + + def disconnect(self) -> None: + """Disconnect from the mock device.""" + logger.info("MockPicoscopeBufferedStream {}: Disconnecting...", self.device_id) + if self.is_streaming(): + self.stop_streaming() + self._is_connected = False + self._streaming_configured = False + self._ring_buffer = None + logger.info("MockPicoscopeBufferedStream {}: Disconnected", self.device_id) + + def configure_streaming( + self, + sample_rate: float, + channels: List[int], + voltage_ranges: List[float], + buffer_duration_s: float = 30.0, + downsample_ratio: int = 1, + downsample_mode: str = "NONE", + offsets_v: Optional[List[float]] = None, + bandwidth_limiter: str = "FULL", + ) -> None: + """Configure buffered streaming with ring buffer.""" + if not self._is_connected: + raise RuntimeError("Device not connected") + + from picostream.ring_buffer import RingBuffer + + # Validate parameters + if len(channels) != len(voltage_ranges): + raise ValueError("Length of channels and voltage_ranges must match") + + # Validate voltage ranges against supported values + for v_range in voltage_ranges: + if v_range not in self.VOLTAGE_RANGES: + valid_ranges = sorted(self.VOLTAGE_RANGES.keys()) + raise ValueError( + f"Invalid voltage range: {v_range}V. Valid ranges: {valid_ranges}V" + ) + + # Clamp sample rate to sustainable level for mock device + max_mock_rate = 1e6 # 1 MS/s is plenty for testing + if sample_rate > max_mock_rate: + logger.warning( + "Mock device cannot sustain {:.1f} MS/s, clamping to {:.1f} MS/s for testing", + sample_rate / 1e6, + max_mock_rate / 1e6, + ) + sample_rate = max_mock_rate + + # Store configuration + self._streaming_channels = list(channels) + + # For mock device, always generate data for both channels + # (real oscilloscopes always have data, you just choose which to display) + self._enabled_channels_set = {0, 1} + + # Ensure both channels have voltage ranges defined + self._streaming_voltage_ranges = {} + for ch, v_range in zip(channels, voltage_ranges, strict=True): + self._streaming_voltage_ranges[ch] = v_range + + # Ensure channel 0 and 1 always have voltage ranges defined (use defaults if not provided) + if 0 not in self._streaming_voltage_ranges: + self._streaming_voltage_ranges[0] = ( + voltage_ranges[0] if voltage_ranges else 20.0 + ) + if 1 not in self._streaming_voltage_ranges: + self._streaming_voltage_ranges[1] = ( + voltage_ranges[-1] + if len(voltage_ranges) > 1 + else (voltage_ranges[0] if voltage_ranges else 20.0) + ) + self._streaming_sample_rate = sample_rate + self._streaming_buffer_duration_s = buffer_duration_s + + if offsets_v is None: + offsets_v = [0.0] * len(channels) + self._streaming_offsets_v = { + ch: offset for ch, offset in zip(channels, offsets_v, strict=True) + } + + if downsample_mode.upper() == "NONE": + downsample_ratio = 1 + self._streaming_downsample_ratio = downsample_ratio + self._streaming_downsample_mode = downsample_mode.upper() + self._streaming_bandwidth_limiter = bandwidth_limiter.upper() + + # Create ring buffer (always 2 channels) + # Calculate total hardware rate (per-channel rate × num_channels) + total_hardware_rate_hz = sample_rate * len(self._streaming_channels) + acquisition_rate = AcquisitionRate( + hardware_rate_hz=total_hardware_rate_hz, + num_channels=len(self._streaming_channels), + downsample_ratio=downsample_ratio, + downsample_mode=DownsampleMode(downsample_mode.upper()), + ) + + self._ring_buffer = RingBuffer( + duration_s=buffer_duration_s, + sample_rate=acquisition_rate.storage_rate_hz, + num_channels=2, + ) + + # Store acquisition rate for later use + self._acquisition_rate = acquisition_rate + + self._streaming_configured = True + self._phase_ch0 = 0.0 + self._phase_ch1 = 0.0 + self._sample_count = 0 + + enabled_names = ["A" if ch == 0 else "B" for ch in sorted(channels)] + logger.info( + "MockPicoscopeBufferedStream {}: Configured streaming channels=[{}], rate={:.2f} Hz, buffer={}s, downsample={}x ({})", + self.device_id, + ", ".join(enabled_names), + sample_rate, + buffer_duration_s, + downsample_ratio, + downsample_mode, + ) + + def start_streaming(self, callback: Optional[Callable] = None) -> None: + """Start buffered streaming to ring buffer.""" + if not self._is_connected: + raise RuntimeError("Device not connected") + if not self._streaming_configured: + raise RuntimeError( + "Streaming not configured. Call configure_streaming() first" + ) + if self.is_streaming(): + logger.warning("Streaming already active on {}", self.device_id) + return + + if self._streaming_sample_rate is None: + raise RuntimeError("Sample rate not configured") + + # Use existing acquisition_rate from configure_streaming + if not hasattr(self, "_acquisition_rate"): + raise RuntimeError( + "configure_streaming() must be called before start_streaming()" + ) + + # Set acquisition rate on ring buffer + if self._ring_buffer is not None: + self._ring_buffer.set_acquisition_rate(self._acquisition_rate) + logger.info( + "Mock ring buffer acquisition rate set: {}", + self._acquisition_rate, + ) + + self._stop_streaming_event.clear() + self._producer_thread = threading.Thread( + target=self._producer_loop, + daemon=True, + name=f"mock-producer-{self.device_id}", + ) + self._producer_thread.start() + logger.info("MockPicoscopeBufferedStream {}: Streaming started", self.device_id) + + def stop_streaming(self, timeout: float = 5.0) -> None: + """Stop buffered streaming.""" + if not self.is_streaming(): + logger.info("Streaming not active on {}", self.device_id) + return + + logger.info( + "MockPicoscopeBufferedStream {}: Stopping streaming...", self.device_id + ) + self._stop_streaming_event.set() + + if self._is_saving: + self.stop_save(keep=True) + + if self._producer_thread and self._producer_thread.is_alive(): + self._producer_thread.join(timeout) + if self._producer_thread.is_alive(): + logger.warning("Producer thread did not stop in time") + + self._producer_thread = None + logger.info("MockPicoscopeBufferedStream {}: Streaming stopped", self.device_id) + + def is_streaming(self) -> bool: + """Return True if streaming thread is active.""" + return self._producer_thread is not None and self._producer_thread.is_alive() + + @property + def ring_buffer(self) -> Optional[Any]: + """Return the ring buffer for plotter access.""" + return self._ring_buffer + + def update_plotter_position(self, read_idx: int) -> None: + """Update the plotter's last read position for overflow detection. + + This mock implementation stores the value but doesn't use it + for actual overflow detection. + + Parameters + ---------- + read_idx : int + The sample index that the plotter last read up to. + """ + logger.debug("Mock update_plotter_position called with read_idx={}", read_idx) + + def get_streaming_error(self) -> Optional[str]: + """Get the last streaming error, if any. + + Returns + ------- + Optional[str] + Error description if streaming stopped due to an error, + None otherwise. + """ + return None + + @property + def is_saving(self) -> bool: + """Return True if currently saving to file.""" + return self._is_saving + + def start_save(self, lookback_seconds: float, output_path: str) -> bool: + """Start saving data to a Zarr file.""" + if self._is_saving: + logger.warning("Save already in progress") + return False + if self._ring_buffer is None: + logger.error("Cannot start save: ring buffer not configured") + return False + if not hasattr(self, "_acquisition_rate"): + logger.error("Cannot start save: acquisition rate not configured") + return False + + self._save_output_path = output_path + + # Use acquisition_rate for consistent rate calculations + if not hasattr(self, "_acquisition_rate"): + raise RuntimeError( + "configure_streaming() must be called before start_save()" + ) + + lookback_samples = self._acquisition_rate.seconds_to_samples(lookback_seconds) + available_samples = min(lookback_samples, self._ring_buffer.write_idx) + + self._save_stop_event = threading.Event() + self._save_samples_written = 0 + self._save_pre_trigger_samples = 0 + self._save_error = None + + # Use AcquisitionRate from ring buffer if available + acquisition_rate: Optional[AcquisitionRate] = None + if self._ring_buffer is not None: + acquisition_rate = self._ring_buffer.get_acquisition_rate() + + # Build metadata with AcquisitionRate fields for consistent rate semantics + # Use the acquisition_rate that was already calculated + total_hardware_rate_hz = self._acquisition_rate.hardware_rate_hz + metadata: Dict[str, Any] = { + # AcquisitionRate fields - stored explicitly for reconstruction + "hardware_rate_hz": total_hardware_rate_hz, + "num_channels": 2, + "downsample_ratio": self._streaming_downsample_ratio, + "downsample_mode": self._streaming_downsample_mode, + # Legacy field for backward compatibility + "sample_rate_hz": acquisition_rate.storage_rate_hz, + "channels": [0, 1], + "enabled_channels": self._streaming_channels, + "voltage_ranges": [ + self._streaming_voltage_ranges.get(ch, 20.0) for ch in [0, 1] + ], + "voltage_ranges_units": "V", + "offsets_v": [self._streaming_offsets_v.get(ch, 0.0) for ch in [0, 1]], + "resolution": self.resolution, + "resolution_units": "bits", + "max_adc": self._max_adc_value if self._max_adc_value else 32767, + "pre_trigger_seconds": lookback_seconds, + "buffer_duration_s": self._streaming_buffer_duration_s, + "device_id": self.device_id, + "serial_code": self.serial_code, + "coupling": "DC", + "bandwidth_limiter": self._streaming_bandwidth_limiter, + } + + self._save_thread = threading.Thread( + target=self._save_worker, + args=(output_path, available_samples, metadata, acquisition_rate), + daemon=True, + name=f"mock-save-{self.device_id}", + ) + self._save_thread.start() + self._is_saving = True + + logger.info( + "Mock save started: path={}, lookback={:.1f}s ({} samples)", + output_path, + lookback_seconds, + available_samples, + ) + return True + + def stop_save(self, keep: bool = True) -> Optional[str]: + """Stop the current save operation.""" + if not self._is_saving: + logger.info("No save in progress") + return None + + self._keep_file_on_stop = keep + logger.info("Stopping mock save (keep={})", keep) + self._save_stop_event.set() + + if self._save_thread and self._save_thread.is_alive(): + self._save_thread.join(timeout=5.0) + + path = self._save_output_path + + if not keep and path and os.path.exists(path): + try: + import shutil + + shutil.rmtree(path) + logger.info("Discarded mock save file: {}", path) + except Exception as e: + logger.exception("Error discarding save file: {}", e) + + self._is_saving = False + self._save_thread = None + self._save_stop_event = None + + return path if keep else None + + def is_save_finished(self) -> bool: + """Check if the save operation has finished. + + After this returns True, you can safely use the saved file. + + Returns + ------- + bool + True if save is complete and the file is ready, False otherwise. + """ + if not self._is_saving: + return True + + if self._save_thread and self._save_thread.is_alive(): + return False + + try: + if not self._keep_file_on_stop and self._save_output_path: + if os.path.exists(self._save_output_path): + import shutil + + shutil.rmtree(self._save_output_path) + self._is_saving = False + self._save_thread = None + self._save_stop_event = None + except Exception: + logger.exception("Error during save cleanup") + + return True + + def get_save_status(self) -> Dict[str, Any]: + """Get the current save status.""" + if not self._is_saving and self._save_error is None: + return {"state": "idle"} + if self._save_error is not None: + return {"state": "error", "error": str(self._save_error)} + + # Use acquisition_rate for consistent rate semantics + if not hasattr(self, "_acquisition_rate"): + return {"state": "saving", "total_seconds": 0} + + total_seconds = self._acquisition_rate.samples_to_seconds( + self._save_samples_written + ) + pre_trigger_seconds = self._acquisition_rate.samples_to_seconds( + self._save_pre_trigger_samples + ) + post_trigger_seconds = total_seconds - pre_trigger_seconds + + return { + "state": "saving", + "total_seconds": total_seconds, + "pre_trigger_seconds": pre_trigger_seconds, + "post_trigger_seconds": post_trigger_seconds, + } + + def _producer_loop(self) -> None: + """Producer thread: generate synthetic data and write to ring buffer.""" + logger.debug("Mock producer loop started for {}", self.device_id) + + if self._streaming_sample_rate is None: + logger.error("Sample rate not set") + return + + # CRITICAL: Always generate at target per-channel rate + # Downsampling happens AFTER by grouping samples in _apply_hardware_downsample() + generation_rate = self._streaming_sample_rate + + # Fixed chunk duration for stable timing + chunk_duration_s = 0.01 + samples_per_chunk = max(1, int(generation_rate * chunk_duration_s)) + + # Signal parameters (different frequencies for each channel) + freq_ch0 = 1.0 + freq_ch1 = 2.5 + + while not self._stop_streaming_event.is_set(): + start_time = time.time() + + # Time array for this chunk (continuous across chunks) + t_chunk = np.linspace( + self._sample_count / generation_rate, + (self._sample_count + samples_per_chunk) / generation_rate, + samples_per_chunk, + ) + max_adc = self._max_adc_value if self._max_adc_value else 32767 + + # Channel A - fixed 500mV signal amplitude (not percentage of range) + if 0 in self._enabled_channels_set: + signal_amplitude = 0.5 # 500mV fixed amplitude + signal_ch0 = signal_amplitude * np.sin(2 * np.pi * freq_ch0 * t_chunk) + signal_ch0 += np.random.normal( + 0, signal_amplitude * 0.05, samples_per_chunk + ) + voltage_range_ch0 = self._streaming_voltage_ranges.get(0, 1.0) + adc_ch0 = (signal_ch0 / voltage_range_ch0 * max_adc).astype(np.int16) + else: + adc_ch0 = np.zeros(samples_per_chunk, dtype=np.int16) + + # Channel B - fixed 500mV signal amplitude (not percentage of range) + if 1 in self._enabled_channels_set: + signal_amplitude = 0.5 # 500mV fixed amplitude + signal_ch1 = signal_amplitude * np.sin(2 * np.pi * freq_ch1 * t_chunk) + signal_ch1 += np.random.normal( + 0, signal_amplitude * 0.05, samples_per_chunk + ) + voltage_range_ch1 = self._streaming_voltage_ranges.get(1, 1.0) + adc_ch1 = (signal_ch1 / voltage_range_ch1 * max_adc).astype(np.int16) + else: + adc_ch1 = np.zeros(samples_per_chunk, dtype=np.int16) + + self._sample_count += samples_per_chunk + + # Stack into output array (always 2 columns) + raw_data = np.column_stack([adc_ch0, adc_ch1]) + + # Apply hardware downsampling to the data before writing to ring buffer + processed_data = self._apply_hardware_downsample(raw_data) + + # Write to ring buffer + if self._ring_buffer is not None: + self._ring_buffer.write(processed_data) + + # Sleep to maintain real-time rate + elapsed = time.time() - start_time + sleep_time = chunk_duration_s - elapsed + if sleep_time > 0: + time.sleep(sleep_time) + elif sleep_time < -0.01: + logger.warning("Mock producer running behind by {:.3f}s", -sleep_time) + + logger.debug("Mock producer loop finished for {}", self.device_id) + + def _apply_hardware_downsample(self, data: np.ndarray) -> np.ndarray: + """Apply hardware downsampling to raw data before writing to ring buffer. + + This simulates the Picoscope hardware downsampling behaviour: + - NONE: Data passes through unchanged + - AVERAGE: Average blocks of samples + - AGGREGATE: Output min/max pairs for each block + - DECIMATE: Sample every Nth value + + Parameters + ---------- + data : np.ndarray + Raw ADC data array, shape (n_samples, num_channels). + + Returns + ------- + np.ndarray + Downsampled data array, always int16 dtype. + For AGGREGATE mode, returns interleaved min/max pairs producing 2 samples per block. + + Raises + ------ + ValueError + If downsample ratio is invalid. + """ + downsample_ratio = self._streaming_downsample_ratio + downsample_mode = self._streaming_downsample_mode + + if downsample_ratio <= 1 or downsample_mode == "NONE": + return data + + num_samples = data.shape[0] + num_full_groups = num_samples // downsample_ratio + + if num_full_groups == 0: + return data + + if downsample_mode == "AVERAGE": + # Reshape to (groups, samples_per_group, channels) and average + truncated = data[: num_full_groups * downsample_ratio] + reshaped = truncated.reshape( + num_full_groups, downsample_ratio, data.shape[1] + ) + return reshaped.mean(axis=1).astype(data.dtype) + + if downsample_mode == "DECIMATE": + # Simple decimation: take every Nth sample + return data[::downsample_ratio] + + if downsample_mode == "AGGREGATE": + # For each block, find min and max across all samples + truncated = data[: num_full_groups * downsample_ratio] + reshaped = truncated.reshape( + num_full_groups, downsample_ratio, data.shape[1] + ) + + # Vectorized min/max calculation + min_vals = reshaped.min(axis=1).astype(data.dtype) + max_vals = reshaped.max(axis=1).astype(data.dtype) + + # Interleave min/max: [min0, max0, min1, max1, ...] + num_channels = data.shape[1] + output = np.empty((num_full_groups * 2, num_channels), dtype=data.dtype) + output[0::2] = min_vals + output[1::2] = max_vals + + return output + + return data + + def _save_worker( + self, + path: str, + lookback_samples: int, + metadata: Dict[str, Any], + acquisition_rate: "AcquisitionRate", + ) -> None: + """Worker thread for saving data to Zarr. + + This runs in a separate thread and continuously drains data + from the ring buffer to a Zarr file until stop_save is called. + Always saves 2 columns (A and B), with zeros for disabled channels. + + Parameters + ---------- + path : str + Output path for the Zarr file. + lookback_samples : int + Number of samples to include before the trigger point. + metadata : Dict[str, Any] + Metadata dictionary to store in the Zarr file. + acquisition_rate : AcquisitionRate + Acquisition rate object with all rate information. + + Raises + ------ + Exception + Any exception during saving is stored in _save_error and re-raised + to ensure the thread terminates with an error state. + """ + try: + from picostream.zarr_writer import ZarrStreamWriter + + writer = ZarrStreamWriter( + path=path, + acquisition_rate=acquisition_rate, + num_channels=2, + compression=None, + ) + + trigger_pos = self._ring_buffer.get_snapshot() + pre_trigger_start = trigger_pos - lookback_samples + if pre_trigger_start < 0: + pre_trigger_start = 0 + + pre_trigger_data = self._ring_buffer.read_range( + pre_trigger_start, trigger_pos + ) + + if pre_trigger_data.shape[0] > 0: + writer.append(pre_trigger_data) + self._save_pre_trigger_samples = pre_trigger_data.shape[0] + + last_pos = trigger_pos + + while not self._save_stop_event.is_set(): + data = self._ring_buffer.read_since(last_pos) + + if data.shape[0] > 0: + if not self._ring_buffer.is_valid_range(last_pos): + valid_start = ( + self._ring_buffer.write_idx - self._ring_buffer.capacity + ) + if valid_start < 0: + valid_start = 0 + data = self._ring_buffer.read_since(valid_start) + last_pos = valid_start + + writer.append(data) + last_pos = self._ring_buffer.get_snapshot() + self._save_samples_written = writer.written + else: + time.sleep(0.001) + + writer.close(metadata) + logger.info( + "Mock save completed: {} samples written to {}", + writer.written, + path, + ) + + except Exception as e: + logger.exception("Mock save worker error: {}", e) + self._save_error = e diff --git a/picostream/pico.py b/picostream/pico.py deleted file mode 100644 index 6b7eada..0000000 --- a/picostream/pico.py +++ /dev/null @@ -1,535 +0,0 @@ -from __future__ import annotations - -import ctypes -import queue -import threading -import time -from typing import Any, Dict, List, Optional - -import numpy as np -from loguru import logger - -from picosdk.functions import PICO_STATUS -from picosdk.errors import CannotFindPicoSDKError - -try: - from picosdk.ps5000a import ps5000a as ps - from picosdk.PicoDeviceEnums import picoEnum as enums - - # Add PS5000A bandwidth limiter enum if not present - if not hasattr(ps, 'PS5000A_BANDWIDTH_LIMITER'): - ps.PS5000A_BANDWIDTH_LIMITER = { - 'PS5000A_BW_FULL': 0, - 'PS5000A_BW_20MHZ': 20000000, - } -except CannotFindPicoSDKError: - logger.critical("PICOSDK IMPORT FAILED: CANNOT FIND SDK LIB - download from pico website") - ps = None - enums = None - - -def check_status(status: int, function_name: str) -> None: - """Check the status returned by a Picoscope SDK call and raise on error. - - Args: - status: The status code returned by the SDK function. - function_name: The name of the function that was called. - - Raises: - Exception: If the status is not PICO_OK. - """ - if status != PICO_STATUS["PICO_OK"]: - # Find the string name of the error code - error_name = next( - (k for k, v in PICO_STATUS.items() if v == status), "PICO_UNKNOWN_ERROR" - ) - raise RuntimeError( - f"{function_name} failed with status {status} ({error_name})" - ) - - -class PicoDevice: - """A class to manage a Picoscope 5000a series device for data streaming. - - This class handles device configuration, buffer management, and the data - capture loop. It acts as the "producer" in a producer-consumer pattern. - - Data Formats: - - average mode: Single stream of averaged ADC values - - aggregate mode: Interleaved stream [min1, max1, min2, max2, ...] - where each pair represents one downsampled timestep - - Buffer Management: - - SDK buffers: Direct hardware interface (bufferA, bufferB for aggregate) - - Application buffers: Larger buffers for efficient file writing - - Interleaved buffer: Temporary buffer for aggregate mode processing - """ - - def __init__( - self, - handle: int, - resolution: str, - pico_buffer_size: int, - pico_num_buffers: int, - comp_buffer_size: int, - data_queue: queue.Queue[int], - empty_queue: queue.Queue[int], - data_buffers: List[np.ndarray], - shutdown_event: threading.Event, - downsample_mode: str = "average", - ) -> None: - """Initializes the PicoDevice and opens a connection to the hardware. - - Args: - handle: The device handle provided by the SDK. - resolution: The desired resolution, e.g., "PS5000A_DR_16BIT". - pico_buffer_size: The size of each buffer allocated within the SDK. - pico_num_buffers: The number of buffers for the SDK to use. - comp_buffer_size: The size of the application-side buffers for writing to disk. - data_queue: Queue to send indices of full buffers to the consumer. - empty_queue: Queue to receive indices of empty buffers from the consumer. - data_buffers: A list of pre-allocated numpy arrays for data transfer. - shutdown_event: A threading.Event to signal shutdown. - """ - # --- Device and Resolution --- - self.handle: ctypes.c_int16 = ctypes.c_int16(handle) - self.resolution: str = resolution - res_enum = ps.PS5000A_DEVICE_RESOLUTION[resolution] - - # --- Picoscope SDK Buffer Configuration --- - self.pico_buffer_size: int = pico_buffer_size - self.pico_num_buffers: int = pico_num_buffers - self.total_samples: int = self.pico_buffer_size * self.pico_num_buffers - - # --- Application Buffer (for file writing) --- - self.comp_buffer_size: int = comp_buffer_size - - # --- Internal Data Buffer (receives data from SDK) --- - self.downsample_mode = downsample_mode - self.bufferA: np.ndarray = np.zeros(shape=self.pico_buffer_size, dtype=np.int16) - self.bufferB: Optional[np.ndarray] = None - self.interleaved_buffer: Optional[np.ndarray] = None - if self.downsample_mode == "aggregate": - self.bufferB = np.zeros(shape=self.pico_buffer_size, dtype=np.int16) - self.interleaved_buffer = np.zeros( - shape=self.pico_buffer_size * 2, dtype=np.int16 - ) - - # --- Ctypes and Callback --- - self.callbackFuncPtr = ps.StreamingReadyType(self.streaming_callback) - self.max_adc: ctypes.c_int16 = ctypes.c_int16() - - # --- Channel Configuration --- - self.channel_range: Optional[int] = None - self.voltage_range_v: Optional[float] = None - self.channel_a_coupling: Optional[str] = None - self.analog_offset_v: float = 0.0 - self.channel_a_range_str: Optional[str] = None - self.bandwidth_limiter: Optional[str] = None - - # --- Threading and Queues --- - self.shutdown_event: threading.Event = shutdown_event - self.data_queue: queue.Queue[int] = data_queue - self.empty_queue: queue.Queue[int] = empty_queue - self.data_buffers: List[np.ndarray] = data_buffers - - # --- Buffer Management State --- - self.buf_idx: int = self.empty_queue.get() - self.buf_used: int = 0 - self.buf_free: int = self.comp_buffer_size - - # --- Streaming Configuration --- - self.streaming_configured: bool = False - self.sample_int: Optional[ctypes.c_int32] = None - self.sample_unit: Optional[int] = None - self.ratio: Optional[int] = None - self.pre_trig_samples: Optional[int] = None - self.down_sample_ratio: Optional[int] = None - self.auto_stop: Optional[int] = None - self.auto_stop_stream: Optional[int] = None - - # --- Status and Performance Metrics --- - self.captured_samples: int = 0 - self.empty_pro_queue_count: int = 0 - self.overvoltage_count: int = 0 - self.hardware_buffer_overflow_count: int = 0 - self.callback_durations: List[float] = [] - - # --- Aggregate Mode Performance Tracking --- - self.interleave_durations: List[float] = [] - - # --- Open device connection --- - status = ps.ps5000aOpenUnit(ctypes.byref(self.handle), None, res_enum) - check_status(status, "ps5000aOpenUnit") - - status = ps.ps5000aMaximumValue(self.handle, ctypes.byref(self.max_adc)) - check_status(status, "ps5000aMaximumValue") - - logger.info(f"Device opened - Resolution: {self.resolution}, max_adc queried: {self.max_adc.value}") - - def set_channel( - self, chan: str, en: int, coup: str, voltage_range_str: str, offset: float - ) -> None: - """Configure a channel on the Picoscope. - - Args: - chan: The channel identifier string, e.g., "PS5000A_CHANNEL_A". - en: Whether the channel is enabled (1) or disabled (0). - coup: The coupling type string, e.g., "PS5000A_DC". - voltage_range_str: The voltage range string, e.g., "PS5000A_20V". - offset: The analog voltage offset in Volts. - """ - channel_range_enum = ps.PS5000A_RANGE[voltage_range_str] - self.channel_range = channel_range_enum - - # Store the actual voltage range for metadata and conversion - range_to_voltage = { - "PS5000A_10MV": 0.01, - "PS5000A_20MV": 0.02, - "PS5000A_50MV": 0.05, - "PS5000A_100MV": 0.1, - "PS5000A_200MV": 0.2, - "PS5000A_500MV": 0.5, - "PS5000A_1V": 1.0, - "PS5000A_2V": 2.0, - "PS5000A_5V": 5.0, - "PS5000A_10V": 10.0, - "PS5000A_20V": 20.0, - } - self.voltage_range_v = range_to_voltage.get(voltage_range_str) - - if chan == "PS5000A_CHANNEL_A" and en: - self.channel_a_coupling = coup - self.channel_a_range_str = voltage_range_str - self.analog_offset_v = offset - - channel_enum = ps.PS5000A_CHANNEL[chan] - coupling_enum = ps.PS5000A_COUPLING[coup] - status = ps.ps5000aSetChannel( - self.handle, channel_enum, en, coupling_enum, channel_range_enum, offset - ) - check_status(status, f"ps5000aSetChannel ({chan})") - logger.info( - f"Channel configured - Range: '{voltage_range_str}' (enum: {channel_range_enum}), " - f"voltage_range_v: {self.voltage_range_v}V, max_adc: {self.max_adc.value}" - ) - - def set_bandwidth_filter(self, channel: str, bandwidth: str) -> None: - """Set the bandwidth filter for a channel to reduce noise. - - Args: - channel: The channel identifier string, e.g., "PS5000A_CHANNEL_A". - bandwidth: The bandwidth limiter string, e.g., "PS5000A_BW_FULL" or "PS5000A_BW_20MHZ". - """ - channel_enum = ps.PS5000A_CHANNEL[channel] - bandwidth_enum = ps.PS5000A_BANDWIDTH_LIMITER[bandwidth] - - if channel == "PS5000A_CHANNEL_A": - self.bandwidth_limiter = bandwidth - - status = ps.ps5000aSetBandwidthFilter( - self.handle, - channel_enum, - bandwidth_enum - ) - check_status(status, f"ps5000aSetBandwidthFilter ({channel}, {bandwidth})") - logger.info(f"Bandwidth filter set - Channel: {channel}, Bandwidth: {bandwidth}") - - def set_data_buffer(self, chan: str, segment: int, rat: str) -> None: - """Set up the data buffer for a specific channel for streaming. - - Args: - chan: The channel identifier string, e.g., "PS5000A_CHANNEL_A". - segment: The memory segment to use (0 for streaming). - rat: The ratio mode string, e.g., "PS5000A_RATIO_MODE_NONE". - """ - channel_enum = ps.PS5000A_CHANNEL[chan] - ratio_enum = ps.PS5000A_RATIO_MODE[rat] - - buffer_min_ptr = None - if self.bufferB is not None: - buffer_min_ptr = self.bufferB.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)) - - status = ps.ps5000aSetDataBuffers( - self.handle, - channel_enum, - self.bufferA.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)), - buffer_min_ptr, - self.pico_buffer_size, - segment, - ratio_enum, - ) - check_status(status, f"ps5000aSetDataBuffers ({chan})") - - def configure_streaming_var( - self, - samp_int: int, - samp_unit: str, - pre_trig_samp: int, - down_samp_rat: int, - rat: str, - auto_stop: int, - auto_stop_stream: int, - ) -> None: - """Store streaming parameters before starting the capture. - - Args: - samp_int: The desired sample interval in `samp_unit` units. - samp_unit: The time unit string, e.g., "PS5000A_NS". - pre_trig_samp: The number of pre-trigger samples. - down_samp_rat: The downsampling ratio. - rat: The ratio mode string, e.g., "PS5000A_RATIO_MODE_NONE". - auto_stop: Whether to stop the capture automatically (1) or not (0). - auto_stop_stream: Deprecated, not used. - """ - self.sample_int = ctypes.c_int32(samp_int) - self.sample_unit = ps.PS5000A_TIME_UNITS[samp_unit] - self.ratio = ps.PS5000A_RATIO_MODE[rat] - self.pre_trig_samples = pre_trig_samp - self.down_sample_ratio = down_samp_rat - self.auto_stop = auto_stop - self.auto_stop_stream = auto_stop_stream - - def get_metadata(self, acquisition_start_time_utc: str, picostream_version: str, acquisition_command: str, was_live_mode: bool) -> Dict[str, Any]: - """Return comprehensive acquisition metadata.""" - metadata = { - "resolution": self.resolution, - "sample_interval_ns": self.sample_int.value if self.sample_int else None, - "voltage_range_v": self.voltage_range_v, - "channel_a_coupling": self.channel_a_coupling, - "channel_a_range": self.channel_a_range_str, - "downsample_mode": self.downsample_mode, - "analog_offset_v": self.analog_offset_v, - "hardware_downsample_ratio": self.down_sample_ratio, - "max_adc": self.max_adc.value, - "data_format_version": "2.0", - "acquisition_start_time_utc": acquisition_start_time_utc, - "picostream_version": picostream_version, - "acquisition_command": acquisition_command, - "was_live_mode": was_live_mode, - "bandwidth_limiter": self.bandwidth_limiter, - } - - if self.downsample_mode == "aggregate": - metadata.update( - { - "aggregate_format": "interleaved_min_max", - "aggregate_description": "Data format: [min1, max1, min2, max2, ...]", - } - ) - - return metadata - - def run_streaming(self) -> None: - """Starts the Picoscope streaming capture.""" - status = ps.ps5000aRunStreaming( - self.handle, - ctypes.byref(self.sample_int), - self.sample_unit, - self.pre_trig_samples, - self.total_samples, - self.auto_stop, - self.down_sample_ratio, - self.ratio, - self.pico_buffer_size, - ) - check_status(status, "ps5000aRunStreaming") - self.streaming_configured = True - logger.info( - f"Streaming configured. Actual sample interval: {self.sample_int.value} ns" - ) - - def streaming_callback( - self, - _handle: int, - noOfSamples: int, - startIndex: int, - overvoltage_flags: int, - _triggerAt: int, - _triggered: int, - _autoStop: int, - _param: int, - ) -> None: - """Callback function executed by the SDK when new streaming data is available. - - This function is the heart of the producer. It copies data from the - Picoscope's internal buffer into the application's shared buffer pool. - When an application buffer is full, its index is placed on the data_queue - for the consumer. - - Note: This function is called from a thread created by the Picoscope SDK. - It must be fast and thread-safe. - """ - if overvoltage_flags: - self.overvoltage_count += 1 - logger.warning( - "Picoscope ADC over-range detected (saturation/clipping)." - ) - - # Stop processing if a shutdown is requested. - if self.shutdown_event.is_set(): - return - - callback_start_time = time.perf_counter() - if noOfSamples > 0: - self.captured_samples += noOfSamples - - source_buffer = self.bufferA - samples_to_process = noOfSamples - - # In aggregate mode, interleave the min/max buffers into one - if ( - self.downsample_mode == "aggregate" - and self.bufferB is not None - and self.interleaved_buffer is not None - ): - interleave_start = time.perf_counter() - - # The SDK provides min/max data in separate buffers (B/A). - # We interleave them into a single [min, max, min, max, ...] - # stream for the consumer. - total_interleaved_samples = noOfSamples * 2 - - # Use numpy's more efficient interleaving - np.stack( - [ - self.bufferB[startIndex : startIndex + noOfSamples], - self.bufferA[startIndex : startIndex + noOfSamples], - ], - axis=1, - out=self.interleaved_buffer[:total_interleaved_samples].reshape( - -1, 2 - ), - ) - - source_buffer = self.interleaved_buffer - samples_to_process = total_interleaved_samples - # After interleaving, the source index is always 0 - source_index = 0 - - # Track interleaving performance - self.interleave_durations.append( - (time.perf_counter() - interleave_start) * 1000 - ) - else: - source_index = startIndex - - # This loop copies data from the source_buffer into - # our larger, shared application buffers (self.data_buffers). - while samples_to_process > 0: - # Determine how much space is left in the current application buffer. - self.buf_free = self.comp_buffer_size - self.buf_used - copy_size = min(samples_to_process, self.buf_free) - - # Copy the data slice. - self.data_buffers[self.buf_idx][ - self.buf_used : self.buf_used + copy_size - ] = source_buffer[source_index : source_index + copy_size] - - # Update pointers and remaining sample counts. - self.buf_used += copy_size - samples_to_process -= copy_size - source_index += copy_size - - # If the current application buffer is full... - if self.buf_used == self.comp_buffer_size: - # ...send its index to the consumer. - self.data_queue.put(self.buf_idx) - try: - # ...and get a new empty buffer from the consumer. - self.buf_idx = self.empty_queue.get_nowait() - self.buf_used = 0 - except queue.Empty: - # This is a critical failure. The consumer is not keeping up. - self.empty_pro_queue_count += 1 - logger.critical( - "Producer queue is empty. Consumer cannot keep up. " - "Shutting down to prevent data loss." - ) - self.shutdown_event.set() - return # Exit immediately. - - duration_ms = (time.perf_counter() - callback_start_time) * 1000 - self.callback_durations.append(duration_ms) - - def run_capture(self) -> None: - """The main capture loop for the producer thread.""" - if not self.streaming_configured: - self.run_streaming() - - # This loop polls the SDK for new data, which triggers the callback. - while not self.shutdown_event.is_set(): - status = ps.ps5000aGetStreamingLatestValues( - self.handle, self.callbackFuncPtr, None - ) - - if status == PICO_STATUS["PICO_BUFFER_STALL"]: - self.hardware_buffer_overflow_count += 1 - logger.critical( - "Picoscope hardware buffer overflow occurred. Data was lost." - ) - elif status not in [ - PICO_STATUS["PICO_OK"], - PICO_STATUS["PICO_NO_SAMPLES_AVAILABLE"], - PICO_STATUS["PICO_BUSY"], - PICO_STATUS["PICO_DATA_NOT_AVAILABLE"], - PICO_STATUS["PICO_DRIVER_FUNCTION"], - ]: - check_status(status, "ps5000aGetStreamingLatestValues") - - # Yield the GIL to other threads. - time.sleep(0.001) - - # --- Shutdown and reporting --- - logger.info( - f"Producer couldn't obtain an empty queue {self.empty_pro_queue_count} times." - ) - if self.hardware_buffer_overflow_count > 0: - logger.critical( - f"Picoscope hardware buffer overflowed {self.hardware_buffer_overflow_count} times. " - "This indicates the application could not process data fast enough from the driver." - ) - if self.overvoltage_count > 0: - logger.warning( - f"Picoscope ADC over-ranged (clipped) {self.overvoltage_count} times." - ) - if self.callback_durations: - logger.info("--- Callback Performance ---") - logger.info(f"Total callbacks: {len(self.callback_durations)}") - logger.info(f"Min duration: {min(self.callback_durations):.2f} ms") - logger.info(f"Max duration: {max(self.callback_durations):.2f} ms") - logger.info(f"Avg duration: {np.mean(self.callback_durations):.2f} ms") - logger.info("--------------------------") - - # Report aggregate mode performance if applicable - if self.downsample_mode == "aggregate" and self.interleave_durations: - logger.info("--- Aggregate Mode Performance ---") - logger.info( - f"Total interleave operations: {len(self.interleave_durations)}" - ) - logger.info( - f"Min interleave duration: {min(self.interleave_durations):.3f} ms" - ) - logger.info( - f"Max interleave duration: {max(self.interleave_durations):.3f} ms" - ) - logger.info( - f"Avg interleave duration: {np.mean(self.interleave_durations):.3f} ms" - ) - logger.info("----------------------------------") - - def close_device(self) -> None: - """Stops the Picoscope and closes the connection.""" - # Check if handle is valid before trying to close. - if self.handle.value > 0: - status_stop = ps.ps5000aStop(self.handle) - if status_stop != PICO_STATUS["PICO_OK"]: - logger.warning(f"ps5000aStop failed with status {status_stop}") - - status_close = ps.ps5000aCloseUnit(self.handle) - if status_close != PICO_STATUS["PICO_OK"]: - logger.warning(f"ps5000aCloseUnit failed with status {status_close}") - - # Invalidate handle to prevent reuse. - self.handle.value = 0 - logger.info("Picoscope connection closed.") diff --git a/picostream/reader.py b/picostream/reader.py deleted file mode 100644 index aa3e425..0000000 --- a/picostream/reader.py +++ /dev/null @@ -1,259 +0,0 @@ -from __future__ import annotations - -from typing import Generator, Tuple - -import h5py -import numpy as np - -from .conversion_utils import adc_to_mV, min_max_decimate_numba - - -class PicoStreamReader: - """A helper class to read and interpret data from picostream HDF5 files. - - This class provides a simple context-manager interface to open HDF5 files, - read metadata, and retrieve blocks of data as time and voltage arrays. - - It correctly handles the `analog_offset_v` if it is present in the file's - metadata. - - Example: - >>> with PicoStreamReader("output.hdf5") as reader: - ... print(f"Sample rate: {1e9 / reader.sample_interval_ns:.2f} S/s") - ... # Iterate through the whole file - ... for times, voltages in reader.get_block_iter(chunk_size=1_000_000): - ... # process data - ... pass - ... - ... # Or iterate with 20x decimation (mean) - ... for times, voltages in reader.get_block_iter( - ... chunk_size=1_000_000, decimation_factor=20, decimation_mode='mean' - ... ): - ... # process decimated data - ... pass - ... - ... # Or get blocks sequentially - ... reader.reset() # Reset internal counter - ... while True: - ... block = reader.get_next_block(chunk_size=500_000) - ... if block is None: - ... break - ... times, voltages = block - ... # process data - """ - - def __init__(self, hdf5_path: str): - """Initializes the PicoStreamReader. - - Args: - hdf5_path: Path to the HDF5 file. - """ - self.hdf5_path = hdf5_path - self._file: h5py.File | None = None - self.dset: h5py.Dataset | None = None - self._current_pos: int = 0 - - # Metadata attributes, populated in __enter__ - self.sample_interval_ns: float = 0.0 - self.voltage_range_v: float = 0.0 - self.max_adc_val: int = 0 - self.downsample_mode: str = "average" - self.hardware_downsample_ratio: int = 1 - self.analog_offset_v: float = 0.0 - - def __enter__(self) -> PicoStreamReader: - """Opens the HDF5 file and reads metadata.""" - self._file = h5py.File(self.hdf5_path, "r") - self.dset = self._file["adc_counts"] - - # Read metadata from file attributes - attrs = self._file.attrs - base_sample_interval_ns = attrs["sample_interval_ns"] - self.hardware_downsample_ratio = attrs.get("hardware_downsample_ratio", 1) - self.sample_interval_ns = ( - base_sample_interval_ns * self.hardware_downsample_ratio - ) - self.voltage_range_v = attrs["voltage_range_v"] - if "max_adc" in attrs: - self.max_adc_val = attrs["max_adc"] - elif "resolution" in attrs: - # Fallback for older files: calculate max_adc from resolution string - res_str = attrs["resolution"] # e.g., "PS5000A_DR_16BIT" - try: - # Extract bit depth (e.g., 16) from the string - res_int = int(res_str.split("_")[-1].replace("BIT", "")) - self.max_adc_val = (2 ** (res_int - 1)) - 1 - except (ValueError, IndexError): - raise KeyError( - f"Could not parse 'resolution' attribute to determine max_adc: {res_str}" - ) - else: - raise KeyError( - "HDF5 file is missing required 'max_adc' or 'resolution' attribute." - ) - self.downsample_mode = attrs.get("downsample_mode", "average") - self.analog_offset_v = attrs.get("analog_offset_v", 0.0) - - self.reset() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Closes the HDF5 file.""" - if self._file: - self._file.close() - - @property - def num_samples(self) -> int: - """Total number of samples in the dataset.""" - if self.dset: - return self.dset.shape[0] - return 0 - - def reset(self) -> None: - """Resets the internal position counter for get_next_block.""" - self._current_pos = 0 - - def get_next_block( - self, - chunk_size: int, - decimation_factor: int = 1, - decimation_mode: str = "mean", - ) -> Tuple[np.ndarray, np.ndarray] | None: - """Retrieves the next block of data. - - Args: - chunk_size: The maximum number of raw samples to retrieve. - decimation_factor: The factor by which to decimate the data. - decimation_mode: The decimation method ('mean' or 'min_max'). - - Returns: - A (times, voltages) tuple, or None if no more data is available. - """ - if self._current_pos >= self.num_samples: - return None - - size = min(chunk_size, self.num_samples - self._current_pos) - block = self.get_block( - size=size, - start=self._current_pos, - decimation_factor=decimation_factor, - decimation_mode=decimation_mode, - ) - self._current_pos += size - return block - - def get_block_iter( - self, - chunk_size: int = 1_000_000, - decimation_factor: int = 1, - decimation_mode: str = "mean", - ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]: - """Yields data blocks as (times, voltages) tuples for the entire dataset. - - Args: - chunk_size: The size of each raw data chunk to yield. - decimation_factor: The factor by which to decimate the data. - decimation_mode: The decimation method ('mean' or 'min_max'). - """ - num_samples = self.num_samples - for start_idx in range(0, num_samples, chunk_size): - size = min(chunk_size, num_samples - start_idx) - yield self.get_block( - size=size, - start=start_idx, - decimation_factor=decimation_factor, - decimation_mode=decimation_mode, - ) - - def get_block( - self, - size: int, - start: int = 0, - decimation_factor: int = 1, - decimation_mode: str = "mean", - ) -> Tuple[np.ndarray, np.ndarray]: - """Retrieves a specific block of data and converts it to time and voltage. - - Optionally, it can decimate the data using either averaging or a - min-max technique. - - Args: - size: The number of raw samples to retrieve from the file. - start: The starting sample index. - decimation_factor: The factor by which to decimate the data. - decimation_mode: The decimation method ('mean' or 'min_max'). - - Returns: - A tuple containing: - - times (np.ndarray): The time axis in seconds. - - voltages (np.ndarray): The voltage data in millivolts. - """ - if not self.dset: - raise RuntimeError("File not open. Use this class as a context manager.") - - adc_data = self.dset[start : start + size] - - if adc_data.size == 0: - return np.array([]), np.array([]) - - # --- Decimation --- - decimated_adc_data = adc_data - final_decimation_factor = 1 - - if decimation_factor > 1 and adc_data.size >= decimation_factor: - final_decimation_factor = decimation_factor - - # Truncate data to be a multiple of the decimation factor - num_windows = adc_data.size // decimation_factor - truncated_len = num_windows * decimation_factor - adc_data_to_process = adc_data[:truncated_len] - - if decimation_mode == "min_max": - decimated_adc_data = min_max_decimate_numba( - adc_data_to_process, decimation_factor - ) - elif decimation_mode == "mean": - means = adc_data_to_process.reshape( - -1, decimation_factor - ).mean(axis=1) - decimated_adc_data = np.repeat(means, 2) - else: - raise ValueError( - f"Unknown decimation_mode: '{decimation_mode}'. Use 'mean' or 'min_max'." - ) - - # --- Voltage Conversion --- - voltages_mv = adc_to_mV(decimated_adc_data, self.voltage_range_v, self.max_adc_val) - if self.analog_offset_v != 0.0: - voltages_mv += self.analog_offset_v * 1000 - - # --- Time Axis Creation --- - points_per_timestep = 2 if self.downsample_mode == "aggregate" else 1 - time_per_timestep = self.sample_interval_ns * 1e-9 - time_per_raw_point = time_per_timestep / points_per_timestep - - start_time = (start / points_per_timestep) * time_per_timestep - - num_output_samples = voltages_mv.size - - if final_decimation_factor > 1: - time_step_between_windows = time_per_raw_point * final_decimation_factor - - # Both 'mean' and 'min_max' now produce pairs of points. - # The time axis is the start time of each decimation window, repeated. - num_pairs = num_output_samples // 2 - window_start_times = ( - start_time + np.arange(num_pairs) * time_step_between_windows - ) - times = np.repeat(window_start_times, 2) - else: - # Original time axis logic for non-decimated data - if self.downsample_mode == "aggregate": - num_pairs = num_output_samples // 2 - pair_times = start_time + np.arange(num_pairs) * time_per_timestep - times = np.repeat(pair_times, 2) - else: - end_time = start_time + (num_output_samples - 1) * time_per_raw_point - times = np.linspace(start_time, end_time, num_output_samples) - - return times, voltages_mv diff --git a/picostream/ring_buffer.py b/picostream/ring_buffer.py new file mode 100644 index 0000000..678f6da --- /dev/null +++ b/picostream/ring_buffer.py @@ -0,0 +1,331 @@ +"""Ring buffer for high-throughput streaming data with lock-free reads/writes. + +This module provides a fixed-size circular buffer implemented as a NumPy array. +It is designed for single-producer, multiple-reader scenarios where the producer +writes continuously and readers access historical or recent data. + +The implementation relies on CPython's GIL atomicity for the write index update +and NumPy's GIL-releasing memcpy for data transfer, providing efficient +lock-free operation suitable for high-speed data acquisition. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import numpy as np +from loguru import logger + +if TYPE_CHECKING: + from picostream.acquisition_rate import AcquisitionRate + + +class RingBuffer: + """A fixed-size ring buffer for multi-channel streaming data. + + This buffer stores int16 samples for multiple channels in a circular fashion. + It supports a single writer and multiple concurrent readers. The write index + is updated atomically (under CPython's GIL), and NumPy slice operations + release the GIL during the actual memory copy. + + Parameters + ---------- + duration_s : float + Buffer duration in seconds. Determines capacity. + sample_rate : float + Expected rate at which samples will arrive at the buffer in Hz. + This should be the effective rate after any hardware downsampling. + num_channels : int, optional + Number of channels. Default is 2. + + Attributes + ---------- + capacity : int + Total samples per channel that fit in the buffer. + buffer : np.ndarray + The underlying storage array, shape (capacity, num_channels), dtype int16. + write_idx : int + Monotonically increasing write position. Use modulo capacity for indexing. + num_channels : int + Number of channels. + storage_rate_hz : float + Rate at which samples arrive at the buffer (post-downsampling if applicable). + + Examples + -------- + >>> buffer = RingBuffer(duration_s=30.0, sample_rate=62.5e6, num_channels=2) + >>> data = np.random.randint(-1000, 1000, size=(1000, 2), dtype=np.int16) + >>> buffer.write(data) + >>> recent = buffer.read_last(500) + """ + + def __init__( + self, + duration_s: float, + sample_rate: float, + num_channels: int = 2, + ) -> None: + """Initialise the ring buffer with the specified dimensions.""" + self.duration_s = duration_s + self.num_channels = num_channels + self._acquisition_rate: Optional[AcquisitionRate] = None + + # Initial capacity based on configured rate for validation + self._initial_sample_rate = sample_rate + + self.capacity = int(duration_s * self._initial_sample_rate) + if self.capacity <= 0: + raise ValueError( + f"Buffer capacity must be positive, got {self.capacity} " + f"(duration_s={duration_s}, sample_rate={sample_rate})" + ) + + self.buffer = np.zeros((self.capacity, num_channels), dtype=np.int16) + self.write_idx = 0 + + logger.debug( + "RingBuffer initialised: capacity={} samples, " + "{:.2f} seconds, {} channels, {:.2f} MB", + self.capacity, + duration_s, + num_channels, + self.buffer.nbytes / (1024 * 1024), + ) + + def write(self, data: np.ndarray) -> None: + """Write data to the ring buffer. + + Data is written at the current write position, wrapping around + to overwrite old data when the buffer is full. The write index + is updated atomically under CPython's GIL. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, num_channels) with dtype int16. + + Raises + ------ + ValueError + If data shape does not match num_channels or dtype is not int16. + """ + if data.dtype != np.int16: + raise ValueError(f"Data must be int16, got {data.dtype}") + + if data.ndim != 2 or data.shape[1] != self.num_channels: + raise ValueError( + f"Data shape must be (n_samples, {self.num_channels}), got {data.shape}" + ) + + n_samples = data.shape[0] + if n_samples == 0: + return + + # If data exceeds capacity, only keep the last 'capacity' samples + if n_samples > self.capacity: + data = data[-self.capacity :] + n_samples = self.capacity + + start_idx = self.write_idx % self.capacity + end_idx = (self.write_idx + n_samples) % self.capacity + + if start_idx < end_idx: + # No wrap: single write + self.buffer[start_idx:end_idx] = data + else: + # Wrap-around: split into tail + head + tail_size = self.capacity - start_idx + self.buffer[start_idx:] = data[:tail_size] + self.buffer[:end_idx] = data[tail_size:] + + self.write_idx += n_samples + + def get_snapshot(self) -> int: + """Return the current write index position. + + Returns + ------- + int + The current write position (monotonically increasing). + """ + return self.write_idx + + def read_last(self, n_samples: int) -> np.ndarray: + """Read the most recent n_samples from the buffer. + + Parameters + ---------- + n_samples : int + Number of samples to read. Will be clamped to available data. + + Returns + ------- + np.ndarray + Array of shape (m, num_channels) where m <= n_samples, + containing the most recent data. Returns a copy, not a view. + """ + available = min(n_samples, self.write_idx, self.capacity) + if available <= 0: + return np.empty((0, self.num_channels), dtype=np.int16) + + return self.read_range(self.write_idx - available, self.write_idx) + + def read_range(self, start_idx: int, end_idx: int) -> np.ndarray: + """Read a specific range of samples from the buffer. + + Handles wrap-around automatically. Always returns a copy. + + Parameters + ---------- + start_idx : int + Starting sample index (inclusive). + end_idx : int + Ending sample index (exclusive). + + Returns + ------- + np.ndarray + Array of shape (n, num_channels) containing the requested range. + Returns empty array if end_idx <= start_idx. + + Raises + ------ + ValueError + If requesting more than capacity samples. + """ + n_requested = end_idx - start_idx + + if n_requested <= 0: + return np.empty((0, self.num_channels), dtype=np.int16) + + if n_requested > self.capacity: + logger.warning( + "Requested {} samples exceeds buffer capacity {}, clamping", + n_requested, + self.capacity, + ) + n_requested = self.capacity + start_idx = end_idx - n_requested + + start_mod = start_idx % self.capacity + end_mod = end_idx % self.capacity + + if start_mod < end_mod: + # No wrap + return self.buffer[start_mod:end_mod].copy() + else: + # Wrap-around + tail = self.buffer[start_mod:] + head = self.buffer[:end_mod] + return np.concatenate([tail, head], axis=0) + + def read_since(self, position: int) -> np.ndarray: + """Read all data written since the given position. + + Convenience method equivalent to read_range(position, get_snapshot()). + + Parameters + ---------- + position : int + Previous write position from get_snapshot(). + + Returns + ------- + np.ndarray + Array of new data written since position. May be empty. + """ + return self.read_range(position, self.write_idx) + + def is_valid_range(self, start_idx: int) -> bool: + """Check if a start index is still within the valid buffer range. + + Used by save threads to verify data hasn't been overwritten. + + Note + ---- + This check is best-effort and does not prevent a race condition: + the producer may overwrite the data after this check returns True + but before the caller reads the data. Callers should handle + RuntimeError from read_range() as a fallback. + + Parameters + ---------- + start_idx : int + The sample index to check. + + Returns + ------- + bool + True if start_idx is still in the buffer (producer hasn't + overwritten it), False otherwise. + """ + return (self.write_idx - start_idx) <= self.capacity + + def get_utilisation(self) -> float: + """Return the fraction of buffer currently filled with data. + + Returns + ------- + float + Fraction between 0.0 and 1.0. + """ + filled = min(self.write_idx, self.capacity) + return filled / self.capacity + + @property + def storage_rate_hz(self) -> float: + """Return the rate at which samples arrive at the buffer. + + This is the effective sample rate after any hardware downsampling. + For example, with 10x hardware downsampling, this will be 1/10th + of the hardware ADC rate. + + Returns + ------- + float + The storage rate from acquisition_rate. + + Raises + ------ + RuntimeError + If acquisition_rate has not been set. + """ + if self._acquisition_rate is None: + msg = "storage_rate_hz called before acquisition_rate was set" + raise RuntimeError(msg) + return self._acquisition_rate.storage_rate_hz + + def set_acquisition_rate(self, rate: AcquisitionRate) -> None: + """Set the acquisition rate object for this buffer. + + This is the single source of truth for all rate calculations. + Must be called before using the buffer for time-based calculations. + + Parameters + ---------- + rate : AcquisitionRate + The acquisition rate object with all rate information. + """ + self._acquisition_rate = rate + logger.info("RingBuffer acquisition rate set: {}", rate) + + def get_acquisition_rate(self) -> Optional[AcquisitionRate]: + """Get the acquisition rate object for this buffer. + + Returns + ------- + Optional[AcquisitionRate] + The acquisition rate, or None if not yet set. + """ + return self._acquisition_rate + + def get_duration_available(self) -> float: + """Return how many seconds of data are currently in the buffer. + + Returns + ------- + float + Duration in seconds of available data. + """ + samples = min(self.write_idx, self.capacity) + return samples / self.storage_rate_hz diff --git a/picostream/test_buffered_stream.py b/picostream/test_buffered_stream.py new file mode 100644 index 0000000..48821b2 --- /dev/null +++ b/picostream/test_buffered_stream.py @@ -0,0 +1,344 @@ +"""Unit tests for PicoscopeBufferedStream with ring buffer and Zarr save.""" + +import os +import shutil +import tempfile +import time +from typing import Generator +from unittest.mock import MagicMock, patch + +import pytest + + +class MockPicoSDK: + """Mock PicoSDK for testing.""" + + PICO_OK = 0 + PICO_BUSY = 1 + PICO_NO_SAMPLES_AVAILABLE = 2 + PICO_BUFFER_STALL = 3 + + PS5000A_CHANNEL = { + "PS5000A_CHANNEL_A": 0, + "PS5000A_CHANNEL_B": 1, + "PS5000A_CHANNEL_C": 2, + "PS5000A_CHANNEL_D": 3, + } + PS5000A_RANGE = { + "PS5000A_10MV": 0, + "PS5000A_100MV": 1, + "PS5000A_1V": 2, + "PS5000A_20V": 3, + } + PS5000A_COUPLING = {"PS5000A_DC": 1} + PS5000A_RATIO_MODE = {"PS5000A_RATIO_MODE_NONE": 0} + PS5000A_TIME_UNITS = {"PS5000A_NS": 0} + PS5000A_DEVICE_RESOLUTION = {"PS5000A_DR_12BIT": 0} + + def __init__(self): + self.StreamingReadyType = MagicMock() + self._streaming_callback = None + + def ps5000aOpenUnit(self, handle, serial, resolution): + handle.value = 1 + return self.PICO_OK + + def ps5000aCloseUnit(self, handle): + return self.PICO_OK + + def ps5000aMaximumValue(self, handle, max_adc): + max_adc.value = 32767 + return self.PICO_OK + + def ps5000aSetChannel(self, handle, channel, enabled, coupling, range_val, offset): + return self.PICO_OK + + def ps5000aSetSimpleTrigger( + self, handle, enabled, source, threshold, direction, delay, auto + ): + return self.PICO_OK + + def ps5000aGetTimebase2( + self, handle, timebase, num_samples, interval, max_samples, segment + ): + interval.value = 16.0 + return self.PICO_OK + + def ps5000aSetDataBuffers( + self, handle, channel, buffer_max, buffer_min, size, segment, mode + ): + return self.PICO_OK + + def ps5000aRunStreaming( + self, + handle, + interval, + unit, + pre, + max_post, + auto_stop, + down_ratio, + mode, + overview, + ): + return self.PICO_OK + + def ps5000aStop(self, handle): + return self.PICO_OK + + def ps5000aGetStreamingLatestValues(self, handle, callback, param): + return self.PICO_NO_SAMPLES_AVAILABLE + + def ps5000aMemorySegments(self, handle, n_segments, max_samples): + # ctypes.c_int32/c_int64 passed in, set .value + if hasattr(max_samples, "value"): + max_samples.value = 100000 + return self.PICO_OK + + def ps5000aSetNoOfCaptures(self, handle, n_captures): + return self.PICO_OK + + +@pytest.fixture +def mock_picosdk() -> Generator[MockPicoSDK, None, None]: + """Provide a mock PicoSDK.""" + mock = MockPicoSDK() + with patch.dict( + "sys.modules", + { + "picosdk": MagicMock(), + "picosdk.ps5000a": MagicMock(), + "picosdk.functions": MagicMock(), + }, + ): + with patch("picostream.device.ps", mock): + with patch("picostream.device.adc2mV", MagicMock()): + with patch("picostream.device.assert_pico_ok", lambda x: None): + with patch( + "picostream.device.PICO_STATUS", + { + "PICO_OK": 0, + "PICO_BUSY": 1, + "PICO_NO_SAMPLES_AVAILABLE": 2, + "PICO_BUFFER_STALL": 3, + }, + ): + yield mock + + +@pytest.fixture +def temp_zarr_path() -> Generator[str, None, None]: + """Provide a temporary path for Zarr files.""" + temp_dir = tempfile.mkdtemp() + zarr_path = os.path.join(temp_dir, "test.zarr") + yield zarr_path + shutil.rmtree(temp_dir, ignore_errors=True) + + +class TestPicoscopeBufferedStream: + """Test cases for PicoscopeBufferedStream.""" + + def test_configure_streaming_creates_ring_buffer( + self, mock_picosdk: MockPicoSDK + ) -> None: + """Test that configure_streaming creates a ring buffer.""" + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + device.configure_streaming( + sample_rate=1_000_000.0, + channels=[0, 1], + voltage_ranges=[1.0, 1.0], + buffer_duration_s=1.0, + ) + + assert device._ring_buffer is not None + assert device._ring_buffer.num_channels == 2 + # Note: acquisition_rate is set during start_streaming, not configure_streaming + # storage_rate_hz will fail fast unless acquisition_rate is set first + + def test_configure_streaming_validates_channels( + self, mock_picosdk: MockPicoSDK + ) -> None: + """Test that configure_streaming validates channel parameters. + + Empty channels list is invalid - at least one channel must be enabled + for acquisition. Hardware rate becomes 0 with 0 channels. + """ + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + # Empty channels list should raise DeviceConfigurationError with helpful message + with pytest.raises(Exception, match="At least one channel must be enabled"): + device.configure_streaming( + sample_rate=1_000_000.0, + channels=[], + voltage_ranges=[], + buffer_duration_s=1.0, + ) + + def test_start_save_lifecycle( + self, mock_picosdk: MockPicoSDK, temp_zarr_path: str + ) -> None: + """Test the start_save / stop_save lifecycle.""" + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + device.configure_streaming( + sample_rate=100_000.0, + channels=[0, 1], + voltage_ranges=[1.0, 1.0], + buffer_duration_s=2.0, + ) + + assert device.start_save(lookback_seconds=0.5, output_path=temp_zarr_path) + assert device._is_saving + + time.sleep(0.1) + + path = device.stop_save(keep=True) + assert path == temp_zarr_path + + # Wait for save to complete (non-blocking API) + deadline = time.time() + 5.0 + while not device.is_save_finished() and time.time() < deadline: + time.sleep(0.01) + + assert not device._is_saving + + def test_stop_save_discard_deletes_file( + self, mock_picosdk: MockPicoSDK, temp_zarr_path: str + ) -> None: + """Test that stop_save(keep=False) deletes the Zarr directory.""" + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + device.configure_streaming( + sample_rate=100_000.0, + channels=[0, 1], + voltage_ranges=[1.0, 1.0], + buffer_duration_s=2.0, + ) + + device.start_save(lookback_seconds=0.5, output_path=temp_zarr_path) + time.sleep(0.1) + + device.stop_save(keep=False) + + # Wait for save to complete (non-blocking API) + deadline = time.time() + 5.0 + while not device.is_save_finished() and time.time() < deadline: + time.sleep(0.01) + + assert not os.path.exists(temp_zarr_path) + + def test_get_save_status_idle(self, mock_picosdk: MockPicoSDK) -> None: + """Test get_save_status returns idle when not saving.""" + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + device.configure_streaming( + sample_rate=100_000.0, + channels=[0, 1], + voltage_ranges=[1.0, 1.0], + buffer_duration_s=1.0, + ) + + status = device.get_save_status() + assert status["state"] == "idle" + + def test_ring_buffer_property(self, mock_picosdk: MockPicoSDK) -> None: + """Test that ring_buffer property returns the buffer.""" + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + assert device.ring_buffer is None + + device.configure_streaming( + sample_rate=1_000_000.0, + channels=[0, 1], + voltage_ranges=[1.0, 1.0], + buffer_duration_s=1.0, + ) + + assert device.ring_buffer is not None + + def test_save_worker_captures_trigger_position_in_main_thread( + self, mock_picosdk: MockPicoSDK, temp_zarr_path: str + ) -> None: + """Test that trigger position is captured in start_save, not in worker. + + This verifies the race condition fix where pre-trigger data could be + overwritten if the worker thread was delayed in starting. + """ + from picostream.device import PicoscopeBufferedStream + + device = PicoscopeBufferedStream("test_device", resolution=12) + device._chandle = MagicMock() + device._is_connected = True + device._max_adc_value = MagicMock() + device._max_adc_value.value = 32767 + + device.configure_streaming( + sample_rate=100_000.0, + channels=[0, 1], + voltage_ranges=[1.0, 1.0], + buffer_duration_s=2.0, + ) + + # Write some data to the ring buffer first + import numpy as np + + test_data = np.ones((1000, 2), dtype=np.int16) * 100 + device._ring_buffer.write(test_data) + + # Capture the expected trigger position before starting save + expected_trigger_pos = device._ring_buffer.get_snapshot() + + # Start save - this should capture the trigger position in main thread + device.start_save(lookback_seconds=0.5, output_path=temp_zarr_path) + + # Verify the save worker received the pre-captured position + # The worker args should include the trigger_pos as second argument + assert device._save_thread is not None + + # Stop the save + device.stop_save(keep=True) + deadline = time.time() + 5.0 + while not device.is_save_finished() and time.time() < deadline: + time.sleep(0.01) + + # The key assertion: trigger_pos was captured BEFORE thread spawn + # If the old bug existed, trigger_pos would be captured in worker + # which could be after more data was written + assert expected_trigger_pos > 0 # Data was written before save diff --git a/picostream/test_data_pipeline.py b/picostream/test_data_pipeline.py new file mode 100644 index 0000000..0550999 --- /dev/null +++ b/picostream/test_data_pipeline.py @@ -0,0 +1,496 @@ +"""Unit tests for DataPipeline class.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.data_pipeline import DataPipeline, PipelineConfig + + +@pytest.fixture +def base_config(): + """Provide a base pipeline configuration for testing.""" + return PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + + +@pytest.fixture +def base_acquisition_rate(): + """Provide a base acquisition rate for testing.""" + return AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + + +@pytest.fixture +def aggregate_acquisition_rate(): + """Provide an acquisition rate with AGGREGATE mode and downsampling.""" + return AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.AGGREGATE, + ) + + +@pytest.fixture +def base_pipeline(base_config, base_acquisition_rate): + """Provide a base pipeline for testing.""" + return DataPipeline(base_config, base_acquisition_rate) + + +@pytest.fixture +def aggregate_pipeline(base_config, aggregate_acquisition_rate): + """Provide a pipeline with AGGREGATE mode.""" + return DataPipeline(base_config, aggregate_acquisition_rate) + + +class TestPipelineConfig: + """Tests for PipelineConfig dataclass.""" + + def test_config_immutable(self, base_config): + """Test that config is frozen and immutable.""" + with pytest.raises(AttributeError): + base_config.resolution = 12 + + +class TestDataPipelineInit: + """Tests for DataPipeline initialisation.""" + + def test_init_with_base_config(self, base_pipeline): + """Test pipeline initialisation.""" + assert base_pipeline.acquisition_rate.downsample_mode == DownsampleMode.NONE + assert base_pipeline._last_decimation == 1 + assert len(base_pipeline._remainder) == 0 + + def test_actual_rate_defaults_to_requested( + self, base_pipeline, base_acquisition_rate + ): + """Test that actual rate comes from acquisition_rate.""" + assert base_pipeline.acquisition_rate.per_channel_rate_hz == 62.5e6 + + def test_actual_rate_uses_hardware_rate_when_set(self, base_acquisition_rate): + """Test that actual rate uses hardware rate when available.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=100e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + assert acquisition_rate.per_channel_rate_hz == 50.0e6 + + +class TestRateCalculations: + """Tests for rate calculation properties.""" + + def test_per_channel_rate_no_downsample(self, base_acquisition_rate): + """Test per channel rate without hardware downsampling.""" + # 125 MS/s total / 2 channels = 62.5 MS/s per channel + assert base_acquisition_rate.per_channel_rate_hz == 62.5e6 + + def test_per_channel_rate_with_downsample(self): + """Test per channel rate with hardware downsampling.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.NONE, + ) + # 125 MS/s total / 2 channels / 10 = 6.25 MS/s per channel + assert acquisition_rate.per_channel_rate_hz == 6.25e6 + + def test_storage_rate_normal_mode(self, base_acquisition_rate): + """Test storage rate in normal mode.""" + # Same as per_channel_rate in NONE mode * 2 channels + assert base_acquisition_rate.storage_rate_hz == 125e6 + + def test_storage_rate_aggregate_mode(self, aggregate_acquisition_rate): + """Test storage rate in AGGREGATE mode (2x samples).""" + # AGGREGATE mode doubles the storage rate (125e6 / 10) * 2 = 25e6 + assert aggregate_acquisition_rate.storage_rate_hz == 25e6 + + +class TestTimeAxis: + """Tests for time axis generation.""" + + def test_create_time_axis_empty(self, base_pipeline): + """Test time axis creation with empty window duration.""" + time_axis = base_pipeline.create_time_axis_for_window(0, 1) + assert len(time_axis) == 0 + + def test_create_time_axis_normal_mode_lossless(self, base_pipeline): + """Test time axis in normal mode without decimation.""" + # 5-second window, decimation=1 (lossless mode) + # 5s * 62.5 MS/s per channel / 2 = 156250000 time points + time_axis = base_pipeline.create_time_axis_for_window(5.0, 1) + + # No decimation means no min/max pairs + n_time_points = int(base_pipeline.time_to_samples(5.0) // 1) + assert len(time_axis) == n_time_points + assert time_axis[0] == 0.0 + assert abs(time_axis[-1] - 5.0) < 1e-12 + + def test_create_time_axis_normal_mode_with_decimation(self, base_pipeline): + """Test time axis in normal mode with decimation.""" + # 5-second window, decimation=100 + # Software decimation produces min/max pairs + time_axis = base_pipeline.create_time_axis_for_window(5.0, 100) + + # With decimation > 1 and target_plot_points set, we get min/max pairs + n_time_points = int(base_pipeline.time_to_samples(5.0) // 100) + assert len(time_axis) == n_time_points * 2 # Doubled for min/max + assert abs(time_axis[-1] - 5.0) < 1e-12 + # First 2 values should be same (first time point repeated for min/max) + assert time_axis[0] == time_axis[1] + + def test_create_time_axis_normal_mode_lossless_no_double( + self, base_acquisition_rate + ): + """Test lossless mode doesn't double time axis even with target_plot_points.""" + # Create config with target_plot_points=None for true lossless + from picostream.data_pipeline import PipelineConfig + + config_lossless = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=None, # Lossless + ) + pipeline_lossless = DataPipeline(config_lossless, base_acquisition_rate) + + time_axis = pipeline_lossless.create_time_axis_for_window(5.0, 1) + n_time_points = int(pipeline_lossless.time_to_samples(5.0) // 1) + + # Lossless mode should NOT double the time axis + assert len(time_axis) == n_time_points + assert time_axis[0] != time_axis[1] if len(time_axis) > 1 else True + + def test_create_time_axis_aggregate_mode(self, aggregate_pipeline): + """Test time axis in AGGREGATE mode with fixed window duration.""" + # 5-second window in AGGREGATE mode + # AGGREGATE mode: time_to_samples() returns samples already accounting for min/max pairs. + # With 10x downsample: display_rate_per_channel = 6.25e6 samples/s + # 5s window = 31.25e6 time points * 2 for min/max = 62.5M storage samples + # Each time point produces 2 time axis values (repeated), so 62.5M total values. + time_axis = aggregate_pipeline.create_time_axis_for_window(5.0, 1) + + # time_to_samples returns storage samples (already doubled for AGGREGATE) + window_samples = aggregate_pipeline.time_to_samples(5.0) + # Each time point is a min/max pair (2 storage samples), doubled for display + expected_length = int(window_samples // 2 * 2) + assert len(time_axis) == expected_length + # Should span exactly 5 seconds + assert abs(time_axis[-1] - 5.0) < 1e-12 + # First 2 values should be same (first time point repeated for min/max) + assert time_axis[0] == time_axis[1] + + def test_create_time_axis_with_data_duration(self, base_pipeline): + """Test right-aligned time axis when data_duration < window_duration.""" + # 5-second window, but only 3 seconds of data + time_axis = base_pipeline.create_time_axis_for_window(5.0, 1, data_duration=3.0) + + # Should start at 5.0 - 3.0 = 2.0 + assert time_axis[0] == 2.0 + assert abs(time_axis[-1] - 5.0) < 1e-12 + + +class TestDecimation: + """Tests for decimation calculations.""" + + def test_calculate_decimation_no_need(self, base_pipeline): + """Test decimation factor when not needed.""" + # Window with fewer samples than target + assert base_pipeline.calculate_decimation(10000) == 1 + + def test_calculate_decimation_needed(self, base_pipeline): + """Test decimation factor calculation.""" + # 400k samples / 20k target = 20x decimation + assert base_pipeline.calculate_decimation(400000) == 20 + + def test_should_invalidate_cache_same_decimation(self, base_pipeline): + """Test cache invalidation detection - same decimation.""" + base_pipeline._last_decimation = 10 + assert not base_pipeline.should_invalidate_cache(10) + + def test_should_invalidate_cache_different(self, base_pipeline): + """Test cache invalidation detection - different decimation.""" + base_pipeline._last_decimation = 10 + assert base_pipeline.should_invalidate_cache(20) + + +class TestChannelProcessing: + """Tests for channel data processing.""" + + def test_process_empty_data(self, base_pipeline): + """Test processing empty data.""" + data, has_data = base_pipeline.process_channel_data( + np.array([], dtype=np.int16), 0, 1 + ) + assert len(data) == 0 + assert not has_data + + def test_process_small_data_no_decimation(self, base_pipeline): + """Test processing small amount of data.""" + raw_data = np.array([1000, 2000, 3000], dtype=np.int16) + data, has_data = base_pipeline.process_channel_data(raw_data, 0, 1) + + assert has_data + assert len(data) == 3 + # Check conversion happened (mV range is +/-20V = 40000mV total) + # 1000 counts / 32767 * 20000mV = expected mV + expected_mv = 1000 * 20000.0 / 32767 + assert abs(data[0] - expected_mv) < 0.1 + + def test_process_with_decimation(self, base_pipeline): + """Test processing with decimation.""" + # 100 samples, decimation factor 10 + raw_data = np.arange(100, dtype=np.int16) + data, has_data = base_pipeline.process_channel_data(raw_data, 0, 10) + + assert has_data + # 100 samples / 10 = 10 groups, each produces 2 values (min/max) + assert len(data) == 20 + + def test_process_with_remainder(self, base_pipeline): + """Test that remainder is handled correctly.""" + # First chunk: 15 samples with decimation 10 + raw_data1 = np.arange(15, dtype=np.int16) + data1, has_data1 = base_pipeline.process_channel_data(raw_data1, 0, 10) + + # Should decimate 10 samples (1 group), keep 5 as remainder + # 1 group produces 2 values (min/max pair) + assert has_data1 + assert len(data1) == 2 # 10 samples / 10 = 1 group = 2 values + + # Second chunk: 15 more samples + raw_data2 = np.arange(15, 30, dtype=np.int16) + data2, has_data2 = base_pipeline.process_channel_data(raw_data2, 0, 10) + + # 5 remainder + 15 new = 20, decimate 10 (2 groups), keep 0 remainder + # 2 groups produce 4 values + assert has_data2 + assert len(data2) == 4 + + def test_process_with_offset(self, base_config, base_acquisition_rate): + """Test that offset is applied correctly.""" + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.5, 1: 0.0}, # 0.5V offset on channel 0 + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, base_acquisition_rate) + + raw_data = np.array([0], dtype=np.int16) + data, has_data = pipeline.process_channel_data(raw_data, 0, 1) + + assert has_data + # 0 ADC counts -> 0 mV + 0.5V offset = 500mV + assert abs(data[0] - 500.0) < 0.1 + + def test_channel_independence(self, base_pipeline): + """Test that channels maintain independent remainders.""" + # Process different amounts for each channel + raw_data = np.arange(15, dtype=np.int16) + data0, _ = base_pipeline.process_channel_data(raw_data, 0, 10) + data1, _ = base_pipeline.process_channel_data(raw_data, 1, 10) + + # Both should have remainders + assert len(base_pipeline._remainder[0]) > 0 + assert len(base_pipeline._remainder[1]) > 0 + + +class TestLagMetrics: + """Tests for lag metrics calculation.""" + + def test_lag_metrics_zero(self, base_pipeline): + """Test lag calculation when consumer caught up.""" + metrics = base_pipeline.get_lag_metrics(1000, 1000) + assert metrics.samples == 0 + assert metrics.milliseconds == 0.0 + + def test_lag_metrics_positive(self, base_pipeline): + """Test lag calculation with positive lag.""" + # 125 MS/s storage rate + # 62500 samples lag at 125 MS/s = 0.5ms + metrics = base_pipeline.get_lag_metrics(100625, 100000) + assert metrics.samples == 625 + assert abs(metrics.milliseconds - 0.005) < 0.001 + + def test_lag_formatting(self, base_pipeline): + """Test lag formatting.""" + lag_str = base_pipeline.format_lag(100625, 100000) + assert "ms" in lag_str + # Now uses unit prefixes (S, kS, MS) instead of "samples" + assert " S" in lag_str or " MS" in lag_str or " kS" in lag_str + + def test_lag_formatting_unit_prefixes(self, base_pipeline): + """Test lag formatting with different unit prefixes.""" + # Small lag - samples shown as-is + small_lag = base_pipeline.format_lag(1000100, 1000000) + assert "100 S" in small_lag + + # Medium lag - kilosamples + medium_lag = base_pipeline.format_lag(1100000, 1000000) + assert "kS" in medium_lag + + # Large lag - megasamples (typical case) + large_lag = base_pipeline.format_lag(5000000, 1000000) + assert "MS" in large_lag + assert "4 MS" in large_lag or "4.0 MS" in large_lag + + +class TestConversions: + """Tests for sample/time conversions.""" + + def test_samples_to_time_normal(self, base_acquisition_rate): + """Test samples to time conversion in normal mode.""" + # 62500000 samples at 125 MS/s total / 2 = 62.5 MS/s per channel = 1 second + duration = base_acquisition_rate.samples_to_seconds(62500000) + assert abs(duration - 1.0) < 0.001 + + def test_samples_to_time_aggregate(self, aggregate_acquisition_rate): + """Test samples to time conversion in AGGREGATE mode.""" + # 125000000 samples (62500000 min/max pairs) at 6.25 MS/s = 10 seconds + duration = aggregate_acquisition_rate.samples_to_seconds(125000000) + assert abs(duration - 10.0) < 0.001 + + def test_time_to_samples(self, base_acquisition_rate): + """Test time to samples conversion.""" + # 1 second at 62.5 MS/s = 62500000 samples + samples = base_acquisition_rate.seconds_to_samples(1.0) + assert samples == 62500000 + + def test_time_to_samples_aggregate(self, aggregate_acquisition_rate): + """Test time to samples returns buffer samples in AGGREGATE mode.""" + # AGGREGATE mode stores min/max pairs, so returns 2 * time points + samples = aggregate_acquisition_rate.seconds_to_samples(1.0) + assert samples == 12500000 # 6.25e6 time points per second * 2 for min/max + + def test_get_display_duration_no_decimation(self, base_pipeline): + """Test display duration with no decimation.""" + # 1000 display samples (500 pairs) at decimation 1 = 500 original samples + # 500 / 62.5e6 = 8 microseconds + duration = base_pipeline.get_display_duration(1000, 1) + expected = 500 / 62.5e6 + assert abs(duration - expected) < 1e-12 + + def test_get_display_duration_with_decimation(self, base_pipeline): + """Test display duration with decimation.""" + # 2000 display samples (1000 pairs) at decimation 10 = 10000 original samples + # 10000 / 62.5e6 = 160 microseconds + duration = base_pipeline.get_display_duration(2000, 10) + expected = 10000 / 62.5e6 + assert abs(duration - expected) < 1e-12 + + def test_get_display_duration_aggregate_mode(self, aggregate_pipeline): + """Test display duration in AGGREGATE mode.""" + # Should use per_channel_rate (6.25 MS/s with downsample=10), not storage_rate + # 1000 display samples (500 pairs) at decimation 1 = 500 original samples + # 500 / 6.25e6 = 80 microseconds + duration = aggregate_pipeline.get_display_duration(1000, 1) + expected = 500 / 6.25e6 + assert abs(duration - expected) < 1e-12 + + def test_get_display_duration_aggregate_mode_uses_display_rate( + self, aggregate_pipeline + ): + """Test that AGGREGATE mode uses display_rate_per_channel.""" + duration = aggregate_pipeline.get_display_duration(1000, 1) + # 1000 samples = 500 time points * 1 decimation = 500 original samples + # 500 / 6.25e6 = 80 microseconds + expected = 500 / 6.25e6 + assert abs(duration - expected) < 1e-12 + + def test_get_display_duration_zero_samples(self, base_pipeline): + """Test display duration with zero samples.""" + assert base_pipeline.get_display_duration(0, 1) == 0.0 + + def test_get_display_duration_zero_decimation(self, base_pipeline): + """Test display duration with zero decimation.""" + assert base_pipeline.get_display_duration(100, 0) == 0.0 + + +class TestCacheManagement: + """Tests for cache invalidation.""" + + def test_invalidate_cache_clears_remainders(self, base_pipeline): + """Test that cache invalidation clears remainders.""" + # Create some remainders + raw_data = np.arange(15, dtype=np.int16) + base_pipeline.process_channel_data(raw_data, 0, 10) + assert len(base_pipeline._remainder) > 0 + + # Invalidate cache + base_pipeline.invalidate_cache() + assert len(base_pipeline._remainder) == 0 + assert base_pipeline._last_decimation == 1 + + def test_max_cache_points(self, base_pipeline): + """Test max cache points calculation.""" + # 100000 samples, decimation 10 + max_points = base_pipeline.get_max_cache_points(100000, 10) + # (100000 // 10) * 2 = 20000 (display values, i.e., min/max pairs) + assert max_points == 20000 + + def test_max_time_points(self, base_pipeline): + """Test max time points calculation.""" + # 100000 samples, decimation 10 + max_points = base_pipeline.get_max_time_points(100000, 10) + # 100000 // 10 = 10000 (time points, not min/max pairs) + assert max_points == 10000 + + +class TestModeChecks: + """Tests for mode checking methods.""" + + def test_is_aggregate_mode_true(self, aggregate_pipeline): + """Test AGGREGATE mode detection.""" + assert aggregate_pipeline.is_aggregate_mode() + + def test_is_aggregate_mode_false(self, base_pipeline): + """Test non-AGGREGATE mode detection.""" + assert not base_pipeline.is_aggregate_mode() + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_negative_lag_clamped(self, base_pipeline): + """Test that negative lag is clamped to zero.""" + # Consumer ahead of producer (shouldn't happen, but handle gracefully) + metrics = base_pipeline.get_lag_metrics(100, 200) + assert metrics.samples == 0 + + def test_process_with_missing_voltage_range( + self, base_config, base_acquisition_rate + ): + """Test processing when voltage range not configured.""" + config = PipelineConfig( + resolution=16, + voltage_ranges={}, # Empty ranges + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, base_acquisition_rate) + + raw_data = np.array([1000, 2000], dtype=np.int16) + data, has_data = pipeline.process_channel_data(raw_data, 0, 1) + + # Should still produce data, just as floats + assert has_data + assert len(data) == 2 + assert data.dtype == np.float64 diff --git a/picostream/test_live_plotter.py b/picostream/test_live_plotter.py new file mode 100644 index 0000000..a2f83d3 --- /dev/null +++ b/picostream/test_live_plotter.py @@ -0,0 +1,437 @@ +"""Tests for the LivePlotter class.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.ring_buffer import RingBuffer + +pytest.importorskip("PyQt6.QtWidgets", reason="PyQt6 not available") + +from picostream.dfplot import LivePlotter + + +@pytest.fixture(scope="session") +def qapp(): + """Provide a Qt application instance for the test session.""" + from PyQt6.QtWidgets import QApplication + + app = QApplication.instance() + if app is None: + app = QApplication([]) + return app + + +@pytest.fixture +def sample_metadata(): + """Provide sample metadata for plotter tests.""" + return { + "sample_rate": 62.5e6, + "channels": [0, 1], + "voltage_ranges": {0: 20.0, 1: 20.0}, + "offsets_v": {0: 0.0, 1: 0.0}, + "resolution": 16, + "downsample_mode": "NONE", + } + + +@pytest.fixture +def populated_buffer(): + """Create a ring buffer with test data.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + + data = np.zeros((500, 2), dtype=np.int16) + data[:, 0] = np.sin(np.linspace(0, 4 * np.pi, 500)) * 10000 + data[:, 1] = np.cos(np.linspace(0, 4 * np.pi, 500)) * 8000 + + buffer.write(data) + return buffer + + +@pytest.fixture +def acquisition_rate_none(): + """Create an AcquisitionRate for NONE mode.""" + return AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + + +@pytest.fixture +def acquisition_rate_aggregate(): + """Create an AcquisitionRate for AGGREGATE mode.""" + return AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.AGGREGATE, + ) + + +class TestLivePlotterBasic: + """Basic functionality tests.""" + + def test_init_with_buffer(self, qapp, sample_metadata, populated_buffer): + """Test plotter initialisation with a ring buffer.""" + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + display_window_seconds=0.5, + ) + + assert plotter.ring_buffer is populated_buffer + # sample_rate is only available after acquisition_rate is set + assert plotter.channels == [0, 1] + assert plotter.resolution == 16 + + plotter.close() + + def test_init_without_buffer(self, qapp, sample_metadata): + """Test plotter initialisation with None buffer (waiting state).""" + plotter = LivePlotter( + ring_buffer=None, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + assert plotter.ring_buffer is None + + plotter.close() + + def test_set_ring_buffer(self, qapp, sample_metadata, populated_buffer): + """Test updating the ring buffer reference.""" + plotter = LivePlotter( + ring_buffer=None, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + assert plotter.ring_buffer is None + + plotter.set_ring_buffer(populated_buffer) + assert plotter.ring_buffer is populated_buffer + assert plotter.last_position == 0 + + plotter.close() + + def test_set_ring_buffer_to_none(self, qapp, sample_metadata, populated_buffer): + """Test clearing the ring buffer reference.""" + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + plotter.set_ring_buffer(None) + assert plotter.ring_buffer is None + + plotter.close() + + +class TestLivePlotterDisplay: + """Display functionality tests.""" + + def test_pipeline_time_axis( + self, qapp, sample_metadata, populated_buffer, acquisition_rate_none + ): + """Test time axis creation via DataPipeline for requested sample count.""" + # Set acquisition rate on buffer so plotter can initialize pipeline + populated_buffer.set_acquisition_rate(acquisition_rate_none) + + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode="NONE", + display_window_seconds=1.0, + decimation_factor=1, + ) + + # Notify plotter that buffer has acquisition rate + plotter._update_pipeline_sample_rate() + + # Request time axis for 1.0s duration with decimation=1 + assert plotter.pipeline is not None + time_axis = plotter.pipeline.create_time_axis_for_window(1.0, 100) + + assert len(time_axis) > 0 + assert time_axis[-1] > time_axis[0] # Monotonic + # Time axis is relative from 0 to window_duration + assert abs(time_axis[-1] - 1.0) < 0.01 # Window is 1.0s + + plotter.close() + + def test_pipeline_time_axis_normal_mode( + self, qapp, sample_metadata, populated_buffer, acquisition_rate_none + ): + """Test time axis via DataPipeline in normal (non-aggregate) mode.""" + populated_buffer.set_acquisition_rate(acquisition_rate_none) + + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode="NONE", + display_window_seconds=5.0, + decimation_factor=100, + ) + + plotter._update_pipeline_sample_rate() + + assert plotter.pipeline is not None + time_axis = plotter.pipeline.create_time_axis_for_window(5.0, 100) + + assert len(time_axis) > 0 + # Normal mode: linear from 0 to window_duration + assert abs(time_axis[-1] - 5.0) < 0.01 + + plotter.close() + + def test_pipeline_time_axis_aggregate_mode( + self, qapp, sample_metadata, populated_buffer, acquisition_rate_aggregate + ): + """Test time axis via DataPipeline in AGGREGATE mode. + + In AGGREGATE mode, n_display_points is doubled (min/max pairs). + """ + populated_buffer.set_acquisition_rate(acquisition_rate_aggregate) + + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode="AGGREGATE", # Use AGGREGATE mode + display_window_seconds=5.0, + decimation_factor=50, + ) + + plotter._update_pipeline_sample_rate() + + # In AGGREGATE mode: + # - Each time point is repeated twice for min/max pairs + assert plotter.pipeline is not None + time_axis = plotter.pipeline.create_time_axis_for_window(5.0, 50) + + # AGGREGATE mode: should have time axis with min/max pairs + assert len(time_axis) > 0 + # Should be multiples of 2 due to min/max pairs + assert len(time_axis) % 2 == 0 + + plotter.close() + + def test_format_sample_count(self, qapp, sample_metadata): + """Test sample count formatting.""" + plotter = LivePlotter( + ring_buffer=None, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + assert plotter.format_sample_count(500) == "500" + assert plotter.format_sample_count(1500) == "1.50K" + assert plotter.format_sample_count(2500000) == "2.50M" + assert plotter.format_sample_count(1500000000) == "1.50G" + + plotter.close() + + def test_format_rate_sps(self, qapp, sample_metadata): + """Test sample rate formatting.""" + plotter = LivePlotter( + ring_buffer=None, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + assert plotter._format_rate_sps(500) == "500.00 S/s" + assert plotter._format_rate_sps(1500) == "1.50 kS/s" + assert plotter._format_rate_sps(2500000) == "2.50 MS/s" + assert plotter._format_rate_sps(1500000000) == "1.50 GS/s" + + plotter.close() + + +class TestLivePlotterChannelVisibility: + """Channel visibility tests.""" + + def test_set_channel_visible(self, qapp, sample_metadata, populated_buffer): + """Test toggling channel visibility.""" + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + assert plotter.lines[0].visible + + plotter.set_channel_visible(0, False) + assert not plotter.lines[0].visible + + plotter.set_channel_visible(0, True) + assert plotter.lines[0].visible + + plotter.close() + + +class TestLivePlotterYLimits: + """Y-axis limit tests for division-based scaling.""" + + def test_y_axis_division_range(self, qapp, sample_metadata, populated_buffer): + """Test that Y-axis is set to division range (-5 to +5).""" + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + # Verify Y-axis range is set to divisions (-5 to +5) via camera + rect = plotter.view.camera.get_state()["rect"] + y_bottom, y_top = rect.bottom, rect.top + # Camera may add small padding, check approximate range + assert abs(y_bottom - (-5)) < 1 + assert abs(y_top - 5) < 1 + + plotter.close() + + def test_v_div_settings(self, qapp, sample_metadata): + """Test setting V/div and Y-position settings.""" + plotter = LivePlotter( + ring_buffer=None, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + # Set V/div settings for both channels + v_div_settings = {0: 0.1, 1: 0.1} # 100mV/div + y_pos_settings = {0: 0.0, 1: 0.0} + + plotter.set_v_div_settings(v_div_settings, y_pos_settings) + + assert plotter.v_div_volts[0] == 0.1 + assert plotter.v_div_volts[1] == 0.1 + assert plotter.y_pos_divs[0] == 0.0 + assert plotter.y_pos_divs[1] == 0.0 + + plotter.close() + + def test_channel_clip_detection(self, qapp, sample_metadata): + """Test channel clipping detection.""" + plotter = LivePlotter( + ring_buffer=None, + channels=[0, 1], + voltage_ranges={0: 5.0, 1: 10.0}, + offsets_v={0: 0.0, 1: 0.0}, + resolution=12, + downsample_mode="NONE", + ) + + # Set V/div = 1V, Y-pos = 0 (range is -5 to +5 divisions) + plotter.v_div_volts[0] = 1.0 + plotter.y_pos_divs[0] = 0.0 + + # No clipping within range (±2 divisions = ±2V) + plotter._display_cache[0] = np.array([-2.0, 2.0]) # ±2 divisions + assert plotter.is_channel_clipped(0) == 0 + + # Clipping above range (6 divisions exceeds +5) + plotter._display_cache[0] = np.array([0.0, 6.0]) + assert plotter.is_channel_clipped(0) == 1 + + # Clipping below range (-6 divisions exceeds -5) + plotter._display_cache[0] = np.array([-6.0, 0.0]) + assert plotter.is_channel_clipped(0) == -1 + + plotter.close() + + +class TestLivePlotterPipeline: + """DataPipeline integration tests.""" + + def test_pipeline_initialised( + self, qapp, sample_metadata, populated_buffer, acquisition_rate_none + ): + """Test that DataPipeline is initialised when acquisition rate is set.""" + populated_buffer.set_acquisition_rate(acquisition_rate_none) + + plotter = LivePlotter( + ring_buffer=populated_buffer, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + plotter._update_pipeline_sample_rate() + + assert plotter.pipeline is not None + assert plotter.acquisition_rate.num_channels == 2 + assert plotter.acquisition_rate.downsample_mode == DownsampleMode.NONE + + plotter.close() + + def test_set_acquisition_rate(self, qapp, sample_metadata): + """Test setting acquisition rate on plotter.""" + plotter = LivePlotter( + ring_buffer=None, + channels=sample_metadata["channels"], + voltage_ranges=sample_metadata["voltage_ranges"], + offsets_v=sample_metadata["offsets_v"], + resolution=sample_metadata["resolution"], + downsample_mode=sample_metadata["downsample_mode"], + ) + + rate = AcquisitionRate( + hardware_rate_hz=100e6, + num_channels=1, + downsample_ratio=2, + downsample_mode=DownsampleMode.AGGREGATE, + ) + + plotter.set_acquisition_rate(rate) + + assert plotter.pipeline is not None + assert plotter.acquisition_rate.num_channels == 1 + assert plotter.resolution == 16 # Remains from initialisation + assert plotter.acquisition_rate.downsample_mode == DownsampleMode.AGGREGATE + assert plotter.acquisition_rate.downsample_ratio == 2 + assert plotter.voltage_ranges[0] == 20.0 + + plotter.close() diff --git a/picostream/test_max_adc_fix.py b/picostream/test_max_adc_fix.py new file mode 100644 index 0000000..731d2c1 --- /dev/null +++ b/picostream/test_max_adc_fix.py @@ -0,0 +1,73 @@ +"""Test that max_adc is calculated correctly from resolution.""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + + +def test_max_adc_calculation(): + """Test max_adc calculation for various resolutions.""" + print("Testing max_adc calculation from resolution:") + print("-" * 50) + + for res in [8, 12, 14, 15, 16]: + max_adc = (1 << (res - 1)) - 1 + print(f" {res}-bit: max_adc = {max_adc}") + + # Verify against what we expect + assert (1 << (12 - 1)) - 1 == 2047, "12-bit max_adc should be 2047" + assert (1 << (14 - 1)) - 1 == 8191, "14-bit max_adc should be 8191" + assert (1 << (16 - 1)) - 1 == 32767, "16-bit max_adc should be 32767" + + print("\n✓ All assertions passed") + + +def test_voltage_conversion(): + """Test that voltage conversion works correctly with calculated max_adc.""" + import numpy as np + + from picostream.conversion_utils import adc_to_mV + + print("\nTesting voltage conversion with corrected max_adc:") + print("-" * 50) + + # Simulate 1V signal on 5V range, 14-bit + resolution = 14 + max_adc = (1 << (resolution - 1)) - 1 # 8191 + voltage_range_v = 5.0 + + # ADC counts for 1V (centered at 0, so ±0.5V) + true_voltage_v = 0.5 # Peak of sine wave + adc_counts = int((true_voltage_v / voltage_range_v) * max_adc) + + print(f" Resolution: {resolution}-bit") + print(f" Using max_adc: {max_adc}") + print(f" Voltage range: ±{voltage_range_v}V") + print(f" True voltage: {true_voltage_v}V") + print(f" ADC counts: {adc_counts}") + + # Convert back + voltage_mv = adc_to_mV( + np.array([adc_counts], dtype=np.int16), voltage_range_v, max_adc + ) + + print(f" Converted voltage: {voltage_mv[0]:.2f} mV = {voltage_mv[0] / 1000:.4f} V") + print(f" Error: {abs(voltage_mv[0] / 1000 - true_voltage_v) * 100:.2f}%") + + # Compare with WRONG max_adc (32767) + wrong_max_adc = 32767 + voltage_mv_wrong = adc_to_mV( + np.array([adc_counts], dtype=np.int16), voltage_range_v, wrong_max_adc + ) + + print(f"\n With WRONG max_adc={wrong_max_adc}:") + print( + f" Converted voltage: {voltage_mv_wrong[0]:.2f} mV = {voltage_mv_wrong[0] / 1000:.4f} V" + ) + print(f" Error factor: {voltage_mv[0] / voltage_mv_wrong[0]:.2f}x") + + +if __name__ == "__main__": + test_max_adc_calculation() + test_voltage_conversion() diff --git a/picostream/test_rate_contract.py b/picostream/test_rate_contract.py new file mode 100644 index 0000000..2e65f39 --- /dev/null +++ b/picostream/test_rate_contract.py @@ -0,0 +1,263 @@ +"""Integration tests for rate handling across components. + +These tests verify that rate information flows correctly from the device +through the ring buffer to the data pipeline, ensuring time calculations +are accurate. +""" + +import numpy as np + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.data_pipeline import DataPipeline, PipelineConfig +from picostream.ring_buffer import RingBuffer + + +class TestRateContractDeviceToPipeline: + """Tests for rate contract between device, ring buffer, and pipeline. + + These tests catch issues like double-counting channels or misapplying + downsampling ratios across component boundaries. + """ + + def _simulate_device_setting_storage_rate( + self, + ring_buffer: RingBuffer, + hardware_rate_hz: float, + downsample_ratio: int, + downsample_mode: str, + n_channels: int, + ) -> AcquisitionRate: + """Simulate how the device sets AcquisitionRate in start_streaming(). + + This replicates the logic in PicoscopeBufferedStream.start_streaming() + to ensure tests reflect actual device behavior. + + Returns + ------- + AcquisitionRate + The acquisition rate object that gets set on the ring buffer. + """ + rate = AcquisitionRate( + hardware_rate_hz=hardware_rate_hz, + num_channels=n_channels, + downsample_ratio=downsample_ratio, + downsample_mode=DownsampleMode(downsample_mode), + ) + ring_buffer.set_acquisition_rate(rate) + return rate + + def test_single_channel_no_downsample_time_calculation(self): + """1 channel, no downsampling: 1 second of data = 1 second displayed.""" + # Setup ring buffer as device would + rb = RingBuffer(duration_s=10.0, sample_rate=62.5e6, num_channels=1) + + # Device configures the actual storage rate + rate = self._simulate_device_setting_storage_rate( + rb, + hardware_rate_hz=62.5e6, + downsample_ratio=1, + downsample_mode="NONE", + n_channels=1, + ) + + # Pipeline config and acquisition rate + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0}, + offsets_v={0: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Write 1 second of data to ring buffer + one_second_samples = int(rate.storage_rate_hz) # 62.5M samples + data = np.random.randint( + -1000, 1000, size=(one_second_samples, 1), dtype=np.int16 + ) + rb.write(data) + + # Read it back and verify time calculation + read_data = rb.read_last(one_second_samples) + duration = pipeline.samples_to_time(len(read_data)) + + # Should be exactly 1 second (within floating point tolerance) + assert abs(duration - 1.0) < 0.001, f"Expected ~1.0s, got {duration}s" + + def test_two_channel_no_downsample_time_calculation(self): + """2 channels, no downsampling: 1 second of data = 1 second displayed.""" + rb = RingBuffer(duration_s=10.0, sample_rate=125e6, num_channels=2) + + # Device: 2 channels at 62.5 MS/s each = 125 MS/s total + rate = self._simulate_device_setting_storage_rate( + rb, + hardware_rate_hz=125e6, # 62.5e6 * 2 + downsample_ratio=1, + downsample_mode="NONE", + n_channels=2, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Write 1 second of data + one_second_samples = int(rate.storage_rate_hz) # 125M samples total + data = np.random.randint( + -1000, 1000, size=(one_second_samples, 2), dtype=np.int16 + ) + rb.write(data) + + # Verify: per_channel_rate should be 62.5 MS/s + assert pipeline.acquisition_rate.per_channel_rate_hz == 62.5e6 + + # Time for 62.5M samples on one channel = 1 second + duration = pipeline.samples_to_time(62500000) + assert abs(duration - 1.0) < 0.001, f"Expected ~1.0s, got {duration}s" + + def test_ten_x_downsample_time_calculation(self): + """10x downsampling: 1 second of real time = 1 second displayed.""" + rb = RingBuffer(duration_s=10.0, sample_rate=12.5e6, num_channels=2) + + # Device: 125 MS/s hardware / 10x downsample = 12.5 MS/s storage + rate = self._simulate_device_setting_storage_rate( + rb, + hardware_rate_hz=125e6, + downsample_ratio=10, + downsample_mode="DECIMATE", + n_channels=2, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Write 1 second of downsampled data + one_second_samples = int(rate.storage_rate_hz) # 12.5M samples + data = np.random.randint( + -1000, 1000, size=(one_second_samples, 2), dtype=np.int16 + ) + rb.write(data) + + # Per-channel rate should be 6.25 MS/s + assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6 + + # Time for 6.25M samples (1 channel) = 1 second + duration = pipeline.samples_to_time(6250000) + assert abs(duration - 1.0) < 0.001, f"Expected ~1.0s, got {duration}s" + + def test_aggregate_mode_time_calculation(self): + """AGGREGATE mode: 1 second of real time = 1 second displayed.""" + rb = RingBuffer(duration_s=10.0, sample_rate=25e6, num_channels=2) + + # Device: 125 MS/s hardware / 10x * 2 (min/max) = 25 MS/s storage + rate = self._simulate_device_setting_storage_rate( + rb, + hardware_rate_hz=125e6, + downsample_ratio=10, + downsample_mode="AGGREGATE", + n_channels=2, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Write 1 second of AGGREGATE data + one_second_samples = int( + rate.storage_rate_hz + ) # 25M samples (12.5M min/max pairs) + data = np.random.randint( + -1000, 1000, size=(one_second_samples, 2), dtype=np.int16 + ) + rb.write(data) + + # Storage rate is 25 MS/s (12.5M pairs/sec * 2 values/pair) + # Per-channel rate is 6.25 MS/s (125M / 2 channels / 10 downsample) + assert rate.storage_rate_hz == 25e6 + assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6 + + # In AGGREGATE mode: + # - 25e6 samples = 12.5e6 min/max pairs + # - samples_to_time: n_time_points = 25e6 // 2 = 12.5e6 pairs + # - duration = 12.5e6 pairs / 6.25e6 per_channel_rate = 2 seconds + duration = pipeline.samples_to_time(25000000) + assert abs(duration - 2.0) < 0.001, f"Expected ~2.0s, got {duration}s" + + +class TestRateContractInvariants: + """Invariants that must hold regardless of configuration.""" + + def test_storage_rate_is_total_not_per_channel(self): + """storage_rate_hz must always represent total rate, not per-channel.""" + # Single channel + rate1 = AcquisitionRate( + hardware_rate_hz=10e6, + num_channels=1, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + assert rate1.storage_rate_hz == 10e6 + + # Two channels - rate doubles (more samples per unit time) + rate2 = AcquisitionRate( + hardware_rate_hz=20e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + assert rate2.storage_rate_hz == 20e6 # 10 MS/s per channel * 2 channels + + def test_pipeline_divides_by_channel_count(self): + """Pipeline must divide total rate by channel count for per-channel rate.""" + # Test 2-channel setup + rate2 = AcquisitionRate( + hardware_rate_hz=100e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate2) + + # Per-channel rate should be 50 MS/s (100 / 2 channels) + assert pipeline.acquisition_rate.per_channel_rate_hz == 50e6 + + # Test 1-channel setup + rate1 = AcquisitionRate( + hardware_rate_hz=50e6, + num_channels=1, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + config1 = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0}, + offsets_v={0: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline1 = DataPipeline(config1, rate1) + + assert pipeline1.acquisition_rate.per_channel_rate_hz == 50e6 diff --git a/picostream/test_rate_invariants.py b/picostream/test_rate_invariants.py new file mode 100644 index 0000000..f9bbbd3 --- /dev/null +++ b/picostream/test_rate_invariants.py @@ -0,0 +1,329 @@ +"""Rate invariant tests - verify rate calculations remain consistent across system. + +These tests catch bugs where rates get multiplied/divided incorrectly at +system boundaries (device → ring buffer → pipeline → display). +""" + +from __future__ import annotations + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.data_pipeline import DataPipeline, PipelineConfig + + +class TestDeviceToRingBufferRateInvariant: + """Verify storage_rate_hz is calculated correctly from hardware parameters.""" + + def test_no_downsampling_no_aggregate(self) -> None: + """Storage rate = hardware rate when no downsampling/AGGREGATE.""" + hardware_rate_hz = 125e6 # 125 MS/s total (62.5 per channel, 2 channels) + + rate = AcquisitionRate( + hardware_rate_hz=hardware_rate_hz, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + + expected_per_channel = hardware_rate_hz / 2 # 2 channels + assert rate.storage_rate_hz == 125e6 + assert rate.per_channel_rate_hz == expected_per_channel == 62.5e6 + + def test_with_downsampling_no_aggregate(self) -> None: + """Per-channel rate = hardware rate / (channels * downsample_ratio).""" + hardware_rate_hz = 125e6 + downsample_ratio = 10 + + rate = AcquisitionRate( + hardware_rate_hz=hardware_rate_hz, + num_channels=2, + downsample_ratio=downsample_ratio, + downsample_mode=DownsampleMode.NONE, + ) + + assert rate.storage_rate_hz == 12.5e6 + assert rate.per_channel_rate_hz == 6.25e6 + + def test_with_aggregate_doubles_storage_rate(self) -> None: + """AGGREGATE mode at 1x downsampling does not double storage rate. + + When downsample_ratio=1, AGGREGATE mode has no effect - data passes + through unchanged. Doubling only occurs when actually downsampling. + """ + hardware_rate_hz = 125e6 + + rate = AcquisitionRate( + hardware_rate_hz=hardware_rate_hz, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.AGGREGATE, + ) + + # At 1x downsampling, AGGREGATE does not apply - rate is unchanged + assert rate.storage_rate_hz == 125e6 + + def test_with_downsample_and_aggregate(self) -> None: + """Downsampling and AGGREGATE combined.""" + hardware_rate_hz = 125e6 + downsample_ratio = 10 + + rate = AcquisitionRate( + hardware_rate_hz=hardware_rate_hz, + num_channels=2, + downsample_ratio=downsample_ratio, + downsample_mode=DownsampleMode.AGGREGATE, + ) + + assert rate.storage_rate_hz == 25e6 + assert rate.per_channel_rate_hz == 6.25e6 + + +class TestRingBufferToPlotterRatePropagation: + """Verify plotter passes storage_rate_hz to pipeline WITHOUT multiplying by channels. + + This is the specific bug that was fixed: plotter was doing + pipeline.set_actual_hardware_rate(storage_rate_hz * len(channels)) + when it should just pass storage_rate_hz directly. + """ + + def test_pipeline_receives_acquire_rate(self) -> None: + """Pipeline should receive total rate, not rate * channel_count.""" + rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.NONE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Pipeline internally divides by channel count to get per-channel rate + per_channel_rate = pipeline.acquisition_rate.per_channel_rate_hz + + # Should be 6.25 MS/s per channel + assert per_channel_rate == 6.25e6 + + def test_storage_rate_from_ring_buffer_used_directly(self) -> None: + """Plotter should use ring_buffer.storage_rate_hz as-is for pipeline.""" + # Create rate representing device setting after streaming starts + rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.NONE, + ) + + # Storage rate should be 12.5e6 (125e6 / 10) + assert rate.storage_rate_hz == 12.5e6 + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Pipeline should report 6.25 MS/s per channel + assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6 + + +class TestEndToEndTimeAxisAccuracy: + """Verify time axis calculations produce correct durations after downsampling. + + This tests the symptom: "time intervals display 10x too large". + """ + + def test_time_axis_accuracy_no_downsampling(self) -> None: + """1 second of data should display as 1.0s without downsampling.""" + rate = AcquisitionRate( + hardware_rate_hz=125e6, # 62.5 per channel, 2 channels + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Create time axis for 1 second display window with decimation=1 + # Signature: (window_duration: float, decimation: int) + time_axis = pipeline.create_time_axis_for_window(1.0, 1) + + # Should span from 0 to 1.0 seconds + assert time_axis[0] == 0.0 + assert abs(time_axis[-1] - 1.0) < 1e-9 + + def test_time_axis_accuracy_with_10x_downsampling(self) -> None: + """1 second of data should display as 1.0s with 10x downsampling. + + This is the critical test for the bug: with 10x downsampling, + if rate was multiplied by channels, time would show as 10x too large. + """ + hardware_rate_hz = 125e6 # 125 MS/s total (62.5 per channel) + downsample_ratio = 10 + + rate = AcquisitionRate( + hardware_rate_hz=hardware_rate_hz, + num_channels=2, + downsample_ratio=downsample_ratio, + downsample_mode=DownsampleMode.NONE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Per-channel effective rate after downsampling + effective_rate_per_channel = pipeline.acquisition_rate.per_channel_rate_hz + assert effective_rate_per_channel == 6.25e6 + + # Create time axis for 1 second window with decimation=1 + # Signature: (window_duration: float, decimation: int) + time_axis = pipeline.create_time_axis_for_window(1.0, 1) + + # Should span from 0 to 1.0 seconds + assert time_axis[0] == 0.0 + assert abs(time_axis[-1] - 1.0) < 1e-9 + + def test_time_axis_accuracy_aggregate_mode(self) -> None: + """Time axis should be accurate in AGGREGATE mode with downsampling.""" + rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.AGGREGATE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # AGGREGATE mode: time_to_samples returns storage samples (already doubled) + # 1s window = 6.25e6 time points per channel * 2 for min/max = 12.5M storage samples + # Each time point produces 2 time axis values (repeated), so 12.5M total values + time_axis = pipeline.create_time_axis_for_window(1.0, 1) + + # Should still span 0 to 1.0 seconds + assert time_axis[0] == 0.0 + assert abs(time_axis[-1] - 1.0) < 1e-9 + # window_samples from time_to_samples // 2 (pairs) * 2 (doubled) = window_samples + window_samples = pipeline.time_to_samples(1.0) + expected_length = int(window_samples // 2 * 2) + assert len(time_axis) == expected_length + + +class TestDisplayDurationCalculation: + """Verify display duration calculations account for decimation correctly.""" + + def test_display_duration_no_decimation_no_downsample(self) -> None: + """Display duration should match actual data duration. + + Note: get_display_duration accounts for software decimation + (min/max pairs) and hardware AGGREGATE mode. + """ + rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # 100 display samples = 50 min/max pairs (software decimation), at decimation 1 + # = 50 original samples at 62.5 MS/s per_channel_rate = 0.8 microseconds + display_duration = pipeline.get_display_duration(100, decimation=1) + # n_display_samples // 2 = 50 pairs * decimation 1 = 50 original samples + # 50 / 62.5e6 = 8e-7 seconds + expected = 50 / 62.5e6 + assert abs(display_duration - expected) < 1e-9 + + def test_display_duration_with_decimation_and_downsampling(self) -> None: + """Display duration should account for both software and hardware decimation.""" + rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.NONE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # 200 display samples = 100 min/max pairs with software decimation 1 + # = 100 original samples at 6.25 MS/s per_channel_rate = 16 microseconds + display_duration = pipeline.get_display_duration(200, decimation=1) + expected = 100 / 6.25e6 + assert abs(display_duration - expected) < 1e-12 + + def test_display_duration_10x_downsampling_sanity_check(self) -> None: + """Sanity check: verify rate calculations with 10x downsampling. + + With 10x downsampling, 125 MS/s hardware becomes 12.5 MS/s storage. + Per channel effective rate is 6.25 MS/s. + """ + rate = AcquisitionRate( + hardware_rate_hz=125e6, + num_channels=2, + downsample_ratio=10, + downsample_mode=DownsampleMode.NONE, + ) + + config = PipelineConfig( + resolution=16, + voltage_ranges={0: 20.0, 1: 20.0}, + offsets_v={0: 0.0, 1: 0.0}, + max_adc_value=32767, + target_plot_points=20000, + ) + pipeline = DataPipeline(config, rate) + + # Verify effective rate is correct + assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6 + + # 1000 display samples = 500 min/max pairs (software decimation) + # at decimation 10 = 5000 original samples + # at 6.25 MS/s = 5000 / 6.25e6 = 0.8 milliseconds + display_duration = pipeline.get_display_duration(1000, decimation=10) + expected_duration = (500 * 10) / 6.25e6 # 0.0008 seconds + assert abs(display_duration - expected_duration) < 1e-9 + + # Verify the value is reasonable (not orders of magnitude off) + assert 0.0001 < display_duration < 0.01 # Between 0.1ms and 10ms diff --git a/picostream/test_ring_buffer.py b/picostream/test_ring_buffer.py new file mode 100644 index 0000000..c622e8f --- /dev/null +++ b/picostream/test_ring_buffer.py @@ -0,0 +1,508 @@ +"""Tests for the RingBuffer class.""" + +from __future__ import annotations + +import threading +import time + +import numpy as np +import pytest + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.ring_buffer import RingBuffer + + +class TestRingBufferBasic: + """Basic functionality tests.""" + + def test_init(self) -> None: + """Test buffer initialisation with valid parameters.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + buffer.set_acquisition_rate(acquisition_rate) + + assert buffer.capacity == 1000 + assert buffer.num_channels == 2 + assert buffer.storage_rate_hz == 1000.0 + assert buffer.write_idx == 0 + assert buffer.buffer.shape == (1000, 2) + assert buffer.buffer.dtype == np.int16 + + def test_init_invalid_capacity(self) -> None: + """Test initialisation with invalid parameters raises ValueError.""" + with pytest.raises(ValueError, match="Buffer capacity must be positive"): + RingBuffer(duration_s=0.0, sample_rate=1000.0) + + with pytest.raises(ValueError, match="Buffer capacity must be positive"): + RingBuffer(duration_s=1.0, sample_rate=0.0) + + def test_get_snapshot(self) -> None: + """Test snapshot returns current write position.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + assert buffer.get_snapshot() == 0 + + data = np.ones((100, 2), dtype=np.int16) + buffer.write(data) + assert buffer.get_snapshot() == 100 + + def test_get_utilisation(self) -> None: + """Test utilisation calculation.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + assert buffer.get_utilisation() == 0.0 + + data = np.ones((50, 2), dtype=np.int16) + buffer.write(data) + assert buffer.get_utilisation() == 0.5 + + buffer.write(np.ones((100, 2), dtype=np.int16)) + assert buffer.get_utilisation() == 1.0 + + def test_get_duration_available(self) -> None: + """Test duration available calculation.""" + buffer = RingBuffer(duration_s=10.0, sample_rate=1000.0, num_channels=2) + + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + buffer.set_acquisition_rate(acquisition_rate) + + assert buffer.get_duration_available() == 0.0 + + data = np.ones((5000, 2), dtype=np.int16) + buffer.write(data) + assert buffer.get_duration_available() == 5.0 + + +class TestRingBufferWrite: + """Tests for the write method.""" + + def test_write_single(self) -> None: + """Test writing a single batch of data.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + data = np.arange(200, dtype=np.int16).reshape(100, 2) + + buffer.write(data) + assert buffer.write_idx == 100 + + result = buffer.read_last(100) + np.testing.assert_array_equal(result, data) + + def test_write_wrong_dtype(self) -> None: + """Test writing wrong dtype raises ValueError.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + data = np.ones((100, 2), dtype=np.float32) + + with pytest.raises(ValueError, match="Data must be int16"): + buffer.write(data) + + def test_write_wrong_shape(self) -> None: + """Test writing wrong shape raises ValueError.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + + # Wrong number of channels + data = np.ones((100, 3), dtype=np.int16) + with pytest.raises(ValueError, match="Data shape must be"): + buffer.write(data) + + # 1D array + data = np.ones(100, dtype=np.int16) + with pytest.raises(ValueError, match="Data shape must be"): + buffer.write(data) + + def test_write_empty(self) -> None: + """Test writing empty array is a no-op.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2) + data = np.zeros((0, 2), dtype=np.int16) + + buffer.write(data) + assert buffer.write_idx == 0 + + +class TestRingBufferReadLast: + """Tests for the read_last method.""" + + def test_read_last_basic(self) -> None: + """Test reading most recent data.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + data1 = np.ones((50, 2), dtype=np.int16) * 1 + data2 = np.ones((50, 2), dtype=np.int16) * 2 + buffer.write(data1) + buffer.write(data2) + + result = buffer.read_last(50) + expected = np.ones((50, 2), dtype=np.int16) * 2 + np.testing.assert_array_equal(result, expected) + + def test_read_last_zero(self) -> None: + """Test reading zero samples returns empty array.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + buffer.write(np.ones((50, 2), dtype=np.int16)) + + result = buffer.read_last(0) + assert result.shape == (0, 2) + + def test_read_last_more_than_available(self) -> None: + """Test reading more than available clamps to available.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + data = np.arange(100, dtype=np.int16).reshape(50, 2) + buffer.write(data) + + result = buffer.read_last(1000) + assert result.shape == (50, 2) + np.testing.assert_array_equal(result, data) + + def test_read_last_before_write(self) -> None: + """Test reading from empty buffer returns empty.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + result = buffer.read_last(50) + assert result.shape == (0, 2) + + +class TestRingBufferReadRange: + """Tests for the read_range method.""" + + def test_read_range_no_wrap(self) -> None: + """Test reading a range without wrap-around.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + data = np.arange(200, dtype=np.int16).reshape(100, 2) + buffer.write(data) + + result = buffer.read_range(0, 50) + expected = data[:50] + np.testing.assert_array_equal(result, expected) + + def test_read_range_empty(self) -> None: + """Test reading empty range returns empty array.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + buffer.write(np.ones((50, 2), dtype=np.int16)) + + result = buffer.read_range(50, 50) + assert result.shape == (0, 2) + + result = buffer.read_range(60, 50) + assert result.shape == (0, 2) + + def test_read_range_clamped(self) -> None: + """Test reading more than capacity clamps to capacity.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + data = np.arange(200, dtype=np.int16).reshape(100, 2) + buffer.write(data) + + result = buffer.read_range(0, 200) + assert result.shape == (100, 2) + + +class TestRingBufferWrapAround: + """Tests for wrap-around behavior.""" + + def test_write_wraparound(self) -> None: + """Test writing past capacity overwrites old data.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + data1 = np.ones((100, 2), dtype=np.int16) * 1 + buffer.write(data1) + assert buffer.write_idx == 100 + + data2 = np.ones((50, 2), dtype=np.int16) * 2 + buffer.write(data2) + assert buffer.write_idx == 150 + + # Read last 100 samples - should have 50 from data1 (positions 50-99) + # and 50 from data2 (positions 100-149) + result = buffer.read_last(100) + assert result[0, 0] == 1 # From data1 (earlier in the read) + assert result[50, 0] == 2 # From data2 (later in the read) + + def test_write_exact_capacity(self) -> None: + """Test writing exactly capacity samples.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + data = np.arange(200, dtype=np.int16).reshape(100, 2) + buffer.write(data) + + result = buffer.read_last(100) + np.testing.assert_array_equal(result, data) + + def test_write_past_capacity(self) -> None: + """Test writing past capacity overwrites from beginning.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + data1 = np.arange(200, dtype=np.int16).reshape(100, 2) + buffer.write(data1) + + data2 = np.arange(200, 400, dtype=np.int16).reshape(100, 2) + buffer.write(data2) + + # Buffer should contain only data2 + result = buffer.read_last(100) + np.testing.assert_array_equal(result, data2) + + def test_read_range_wraparound(self) -> None: + """Test reading a range that spans the wrap point.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + # Fill buffer completely + data1 = np.arange(200, dtype=np.int16).reshape(100, 2) + buffer.write(data1) + assert buffer.write_idx == 100 + + # Write more to cause wrap (only last 100 samples kept) + data2 = np.arange(200, 400, dtype=np.int16).reshape(100, 2) + buffer.write(data2) + assert buffer.write_idx == 200 + + # Buffer now contains data2 at positions 100-199 (write_idx range) + # but stored at buffer indices 0-99 + + # Read from position 150 to 200 (last 50 samples) + result = buffer.read_range(150, 200) + expected = data2[50:100] # Last 50 samples of data2 + np.testing.assert_array_equal(result, expected) + + def test_read_last_wraparound(self) -> None: + """Test read_last with wraparound.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + data1 = np.ones((80, 2), dtype=np.int16) * 1 + data2 = np.ones((50, 2), dtype=np.int16) * 2 + buffer.write(data1) + buffer.write(data2) + + # Read last 60 samples - should span wrap point + result = buffer.read_last(60) + assert result.shape == (60, 2) + assert result[-1, 0] == 2 + assert result[0, 0] == 1 + + +class TestRingBufferReadSince: + """Tests for the read_since method.""" + + def test_read_since_basic(self) -> None: + """Test reading data since a position.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + pos1 = buffer.get_snapshot() + data1 = np.ones((50, 2), dtype=np.int16) * 1 + buffer.write(data1) + + result = buffer.read_since(pos1) + assert result.shape == (50, 2) + np.testing.assert_array_equal(result, data1) + + def test_read_since_empty(self) -> None: + """Test reading since current position returns empty.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + buffer.write(np.ones((50, 2), dtype=np.int16)) + + pos = buffer.get_snapshot() + result = buffer.read_since(pos) + assert result.shape == (0, 2) + + +class TestRingBufferIsValidRange: + """Tests for the is_valid_range method.""" + + def test_is_valid_range_true(self) -> None: + """Test valid range returns True.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + buffer.write(np.ones((50, 2), dtype=np.int16)) + + assert buffer.is_valid_range(0) is True + assert buffer.is_valid_range(25) is True + assert buffer.is_valid_range(49) is True + + def test_is_valid_range_false(self) -> None: + """Test overwritten range returns False.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + data1 = np.ones((100, 2), dtype=np.int16) + buffer.write(data1) + + data2 = np.ones((60, 2), dtype=np.int16) * 2 + buffer.write(data2) + + # Position 0 has been overwritten + assert buffer.is_valid_range(0) is False + assert buffer.is_valid_range(50) is False + + # Position 60 (start of data2) should be valid + assert buffer.is_valid_range(60) is True + + +class TestRingBufferMultiChannel: + """Tests for multi-channel data integrity.""" + + def test_multi_channel_aligned(self) -> None: + """Test channel data stays aligned through wrap-around.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + + # Write data where each sample has a specific pattern + data = np.zeros((150, 2), dtype=np.int16) + data[:, 0] = np.arange(150, dtype=np.int16) # Channel 0: 0, 1, 2, ... + data[:, 1] = np.arange( + 300, 450, dtype=np.int16 + ) # Channel 1: 300, 301, 302, ... + + buffer.write(data) + + # Read last 50 samples - should be the tail of original data + result = buffer.read_last(50) + + # Channel 0 should have consecutive values + expected_ch0 = np.arange(100, 150, dtype=np.int16) + np.testing.assert_array_equal(result[:, 0], expected_ch0) + + # Channel 1 should have the corresponding values + expected_ch1 = np.arange(400, 450, dtype=np.int16) + np.testing.assert_array_equal(result[:, 1], expected_ch1) + + def test_single_channel(self) -> None: + """Test with single channel.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=1) + data = np.arange(100, dtype=np.int16).reshape(100, 1) + + buffer.write(data) + result = buffer.read_last(50) + + expected = np.arange(50, 100, dtype=np.int16).reshape(50, 1) + np.testing.assert_array_equal(result, expected) + + def test_four_channels(self) -> None: + """Test with four channels.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=4) + data = np.arange(400, dtype=np.int16).reshape(100, 4) + + buffer.write(data) + result = buffer.read_last(50) + + expected = data[50:] + np.testing.assert_array_equal(result, expected) + + +class TestRingBufferReturnsCopy: + """Tests that read methods return copies, not views.""" + + def test_read_last_returns_copy(self) -> None: + """Test read_last returns a copy.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + data = np.ones((50, 2), dtype=np.int16) + buffer.write(data) + + result = buffer.read_last(50) + result[0, 0] = 999 + + # Original buffer should be unchanged + result2 = buffer.read_last(50) + assert result2[0, 0] == 1 + + def test_read_range_returns_copy(self) -> None: + """Test read_range returns a copy.""" + buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2) + data = np.ones((50, 2), dtype=np.int16) + buffer.write(data) + + result = buffer.read_range(0, 50) + result[0, 0] = 999 + + result2 = buffer.read_range(0, 50) + assert result2[0, 0] == 1 + + +class TestRingBufferThreadSafety: + """Tests for thread safety under load.""" + + def test_concurrent_write_read(self) -> None: + """Test concurrent writing and reading. + + Runs writer and reader threads concurrently to verify no crashes + or obvious data corruption. + """ + buffer = RingBuffer(duration_s=0.5, sample_rate=10000.0, num_channels=2) + errors: list[str] = [] + stop_event = threading.Event() + + def writer() -> None: + """Writer thread.""" + counter = 0 + while not stop_event.is_set(): + data = np.full((100, 2), counter, dtype=np.int16) + try: + buffer.write(data) + except Exception as e: + errors.append(f"Writer error: {e}") + break + counter = (counter + 1) % 1000 + + def reader() -> None: + """Reader thread.""" + while not stop_event.is_set(): + try: + _ = buffer.read_last(1000) + except Exception as e: + errors.append(f"Reader error: {e}") + break + time.sleep(0.001) + + writer_thread = threading.Thread(target=writer) + reader_thread = threading.Thread(target=reader) + + writer_thread.start() + reader_thread.start() + + time.sleep(0.5) + stop_event.set() + + writer_thread.join(timeout=2.0) + reader_thread.join(timeout=2.0) + + assert not errors, f"Thread errors occurred: {errors}" + + def test_data_integrity_under_load(self) -> None: + """Test data integrity with sequential pattern under thread load. + + Writer writes sequential values, readers verify they can read + consistent data. + """ + buffer = RingBuffer(duration_s=0.1, sample_rate=10000.0, num_channels=2) + stop_event = threading.Event() + + def writer() -> None: + """Writer with sequential pattern.""" + counter = 0 + while not stop_event.is_set(): + data = np.zeros((50, 2), dtype=np.int16) + data[:, 0] = counter + data[:, 1] = counter + 1 + buffer.write(data) + counter += 1 + if counter > 10000: + counter = 0 + + writer_thread = threading.Thread(target=writer) + writer_thread.start() + + time.sleep(0.05) + + # Multiple reads to verify consistency + for _ in range(100): + result = buffer.read_last(100) + if len(result) > 0: + # All values in a column should be the same + # (within the same write batch) + assert ( + np.all(result[:, 0] == result[0, 0]) or len(set(result[:, 0])) <= 2 + ) + time.sleep(0.001) + + stop_event.set() + writer_thread.join(timeout=2.0) diff --git a/picostream/test_zarr_reader.py b/picostream/test_zarr_reader.py new file mode 100644 index 0000000..aabf360 --- /dev/null +++ b/picostream/test_zarr_reader.py @@ -0,0 +1,331 @@ +"""Tests for the PicoZarrReader class.""" + +from __future__ import annotations + +import os +import shutil +import tempfile + +import numpy as np +import pytest + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.zarr_reader import PicoZarrReader +from picostream.zarr_writer import ZarrStreamWriter + + +@pytest.fixture +def temp_zarr_path(): + """Provide a temporary path for Zarr files.""" + path = tempfile.mkdtemp(suffix=".zarr") + yield path + if os.path.exists(path): + shutil.rmtree(path) + + +@pytest.fixture +def sample_metadata(): + """Provide sample metadata for Zarr tests.""" + return { + "hardware_rate_hz": 2000.0, + "channels": [0, 1], + "voltage_ranges": [20.0, 20.0], + "offsets_v": [0.0, 0.0], + "resolution": 16, + "downsample_ratio": 1, + "downsample_mode": "NONE", + "pre_trigger_seconds": 5.0, + } + + +@pytest.fixture +def populated_zarr_file(temp_zarr_path, sample_metadata): + """Create a Zarr file with test data.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=sample_metadata["hardware_rate_hz"], + num_channels=len(sample_metadata["channels"]), + downsample_ratio=sample_metadata["downsample_ratio"], + downsample_mode=DownsampleMode(sample_metadata["downsample_mode"]), + ) + writer = ZarrStreamWriter( + path=temp_zarr_path, + acquisition_rate=acquisition_rate, + num_channels=len(sample_metadata["channels"]), + ) + + data = np.zeros((5000, 2), dtype=np.int16) + data[:, 0] = np.sin(np.linspace(0, 4 * np.pi, 5000)) * 10000 + data[:, 1] = np.cos(np.linspace(0, 4 * np.pi, 5000)) * 8000 + + writer.append(data) + writer.close(sample_metadata) + + return temp_zarr_path + + +class TestPicoZarrReaderBasic: + """Basic functionality tests.""" + + def test_init(self, populated_zarr_file, sample_metadata): + """Test reader initialisation.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader.path.name == os.path.basename(populated_zarr_file) + assert reader.sample_rate == 2000.0 + assert reader.num_channels == 2 + assert reader.total_samples == 5000 + assert reader.duration_s == 2.5 + assert reader.channels == [0, 1] + assert reader.resolution == 16 + + def test_init_nonexistent_path(self): + """Test initialisation with non-existent path raises FileNotFoundError.""" + with pytest.raises(FileNotFoundError): + PicoZarrReader("/nonexistent/path/file.zarr") + + def test_voltage_ranges(self, populated_zarr_file): + """Test voltage ranges are loaded correctly.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader.voltage_ranges == {0: 20.0, 1: 20.0} + assert reader.offsets_v == {0: 0.0, 1: 0.0} + + def test_max_adc_calculation(self, populated_zarr_file): + """Test max ADC value is calculated from resolution.""" + reader = PicoZarrReader(populated_zarr_file) + + expected_max = (1 << (16 - 1)) - 1 + assert reader.max_adc == expected_max + + +class TestPicoZarrReaderRaw: + """Raw data reading tests.""" + + def test_get_raw_basic(self, populated_zarr_file): + """Test reading raw data.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(0.0, 1.0) + + assert data.shape == (1000, 2) + assert data.dtype == np.int16 + + def test_get_raw_empty_range(self, populated_zarr_file): + """Test reading empty range returns empty array.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(0.0, 0.0) + assert data.shape == (0, 2) + + data = reader.get_raw(1.0, 0.5) + assert data.shape == (0, 2) + + def test_get_raw_out_of_bounds(self, populated_zarr_file): + """Test reading beyond file bounds clamps correctly.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(4.0, 10.0) + + assert data.shape[0] == 1000 + + def test_get_raw_returns_copy(self, populated_zarr_file): + """Test that get_raw returns a copy, not a view.""" + reader = PicoZarrReader(populated_zarr_file) + + data1 = reader.get_raw(0.0, 1.0) + data1[0, 0] = 999 + + data2 = reader.get_raw(0.0, 1.0) + assert data2[0, 0] != 999 + + +class TestPicoZarrReaderVoltage: + """Voltage conversion tests.""" + + def test_get_voltage_basic(self, populated_zarr_file): + """Test reading voltage data.""" + reader = PicoZarrReader(populated_zarr_file) + + voltage = reader.get_voltage(0.0, 1.0) + + assert voltage.shape == (1000, 2) + assert voltage.dtype == np.float64 + + def test_get_voltage_empty(self, populated_zarr_file): + """Test reading empty range returns empty float array.""" + reader = PicoZarrReader(populated_zarr_file) + + voltage = reader.get_voltage(0.0, 0.0) + assert voltage.shape == (0, 2) + assert voltage.dtype == np.float64 + + def test_get_voltage_conversion(self, populated_zarr_file): + """Test voltage conversion is correct.""" + reader = PicoZarrReader(populated_zarr_file) + + raw = reader.get_raw(0.0, 1.0) + voltage = reader.get_voltage(0.0, 1.0) + + max_raw = np.max(np.abs(raw)) + max_voltage = np.max(np.abs(voltage)) + + expected_ratio = 20.0 * 1000 / reader.max_adc + actual_ratio = max_voltage / max_raw if max_raw > 0 else 0 + + assert abs(actual_ratio - expected_ratio) < 0.1 + + +class TestPicoZarrReaderTimeAxis: + """Time axis tests.""" + + def test_get_time_axis_basic(self, populated_zarr_file): + """Test time axis creation.""" + reader = PicoZarrReader(populated_zarr_file) + + time_axis = reader.get_time_axis(0.0, 1.0) + + assert len(time_axis) == 1000 + assert time_axis[0] == 0.0 + assert abs(time_axis[-1] - 1.0) < 0.01 + + def test_get_time_axis_empty(self, populated_zarr_file): + """Test empty time axis.""" + reader = PicoZarrReader(populated_zarr_file) + + time_axis = reader.get_time_axis(0.0, 0.0) + assert len(time_axis) == 0 + + def test_get_time_axis_matches_data(self, populated_zarr_file): + """Test time axis length matches data length.""" + reader = PicoZarrReader(populated_zarr_file) + + raw = reader.get_raw(0.5, 1.5) + time_axis = reader.get_time_axis(0.5, 1.5) + + assert len(time_axis) == raw.shape[0] + + +class TestPicoZarrReaderMetadata: + """Metadata tests.""" + + def test_get_metadata(self, populated_zarr_file, sample_metadata): + """Test metadata retrieval.""" + reader = PicoZarrReader(populated_zarr_file) + + metadata = reader.get_metadata() + + assert metadata["sample_rate_hz"] == 2000.0 + assert metadata["channels"] == [0, 1] + assert metadata["total_samples"] == 5000 + + def test_pre_trigger_time(self, populated_zarr_file): + """Test pre-trigger time retrieval.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader.get_pre_trigger_time() == 5.0 + + def test_get_channel_name(self, populated_zarr_file): + """Test channel name generation.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader.get_channel_name(0) == "A" + assert reader.get_channel_name(1) == "B" + + +class TestPicoZarrReaderTimeConversion: + """Time/sample conversion tests.""" + + def test_time_to_sample(self, populated_zarr_file): + """Test time to sample conversion.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader._time_to_sample(0.0) == 0 + assert reader._time_to_sample(1.0) == 1000 + assert reader._time_to_sample(5.0) == 5000 + + def test_time_to_sample_clamping(self, populated_zarr_file): + """Test time to sample clamping.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader._time_to_sample(-1.0) == 0 + assert reader._time_to_sample(10.0) == 5000 + + def test_sample_to_time(self, populated_zarr_file): + """Test sample to time conversion.""" + reader = PicoZarrReader(populated_zarr_file) + + assert reader._sample_to_time(0) == 0.0 + assert reader._sample_to_time(1000) == 1.0 + assert reader._sample_to_time(5000) == 5.0 + + +class TestPicoZarrReaderClose: + """Close/cleanup tests.""" + + def test_close(self, populated_zarr_file): + """Test reader close.""" + reader = PicoZarrReader(populated_zarr_file) + + reader.close() + + with pytest.raises(ValueError, match="Reader has been closed"): + reader.get_raw(0.0, 1.0) + + def test_context_manager(self, populated_zarr_file): + """Test context manager usage.""" + with PicoZarrReader(populated_zarr_file) as reader: + assert reader.total_samples == 5000 + + with pytest.raises(ValueError, match="Reader has been closed"): + reader.get_raw(0.0, 1.0) + + +class TestPicoZarrReaderPartialReads: + """Partial read tests.""" + + def test_partial_read_start(self, populated_zarr_file): + """Test reading from start of file.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(0.0, 0.5) + + assert data.shape[0] == 500 + + def test_partial_read_end(self, populated_zarr_file): + """Test reading to end of file.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(4.5, 5.0) + + assert data.shape[0] == 500 + + def test_partial_read_middle(self, populated_zarr_file): + """Test reading from middle of file.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(2.0, 3.0) + + assert data.shape[0] == 1000 + + def test_arbitrary_time_range(self, populated_zarr_file): + """Test arbitrary time range.""" + reader = PicoZarrReader(populated_zarr_file) + + data = reader.get_raw(1.234, 3.567) + + expected_samples = int((3.567 - 1.234) * 1000) + assert abs(data.shape[0] - expected_samples) <= 1 + + +class TestPicoZarrReaderRepresentation: + """String representation tests.""" + + def test_repr(self, populated_zarr_file): + """Test string representation.""" + reader = PicoZarrReader(populated_zarr_file) + + repr_str = repr(reader) + + assert "PicoZarrReader" in repr_str + assert "2.50s" in repr_str + assert "2 ch" in repr_str diff --git a/picostream/test_zarr_viewer.py b/picostream/test_zarr_viewer.py new file mode 100644 index 0000000..10b5683 --- /dev/null +++ b/picostream/test_zarr_viewer.py @@ -0,0 +1,223 @@ +"""Tests for the ZarrViewerWindow class.""" + +from __future__ import annotations + +import os +import shutil +import tempfile + +import numpy as np +import pytest + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.zarr_writer import ZarrStreamWriter + +pytest.importorskip("PyQt6.QtWidgets", reason="PyQt6 not available") + +from PyQt6.QtWidgets import QApplication + +from picostream.zarr_viewer import ZarrViewerWindow + + +@pytest.fixture(scope="session") +def qapp(): + """Provide a Qt application instance for the test session.""" + app = QApplication.instance() + if app is None: + app = QApplication([]) + return app + + +@pytest.fixture +def temp_zarr_path(): + """Provide a temporary path for Zarr files.""" + path = tempfile.mkdtemp(suffix=".zarr") + yield path + if os.path.exists(path): + shutil.rmtree(path) + + +@pytest.fixture +def sample_metadata(): + """Provide sample metadata for Zarr tests.""" + return { + "hardware_rate_hz": 1000.0, + "channels": [0, 1], + "voltage_ranges": [20.0, 20.0], + "offsets_v": [0.0, 0.0], + "resolution": 16, + "downsample_ratio": 1, + "downsample_mode": "NONE", + "pre_trigger_seconds": 5.0, + } + + +@pytest.fixture +def populated_zarr_file(temp_zarr_path, sample_metadata): + """Create a Zarr file with test data.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=sample_metadata["hardware_rate_hz"], + num_channels=2, + downsample_ratio=sample_metadata["downsample_ratio"], + downsample_mode=DownsampleMode(sample_metadata["downsample_mode"]), + ) + writer = ZarrStreamWriter( + path=temp_zarr_path, + acquisition_rate=acquisition_rate, + num_channels=len(sample_metadata["channels"]), + ) + + data = np.zeros((10000, 2), dtype=np.int16) + data[:, 0] = np.sin(np.linspace(0, 8 * np.pi, 10000)) * 10000 + data[:, 1] = np.cos(np.linspace(0, 8 * np.pi, 10000)) * 8000 + + writer.append(data) + writer.close(sample_metadata) + + return temp_zarr_path + + +class TestZarrViewerWindowBasic: + """Basic viewer window tests.""" + + def test_window_creation(self, qapp, populated_zarr_file): + """Test viewer window can be created.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + assert viewer.reader is not None + assert viewer.reader.duration_s == 10.0 + # Window duration is clamped to valid range + assert ( + viewer.MIN_WINDOW_DURATION + <= viewer.window_duration_s + <= min(viewer.MAX_WINDOW_DURATION, viewer.reader.duration_s) + ) + + viewer.close() + + def test_window_title(self, qapp, populated_zarr_file): + """Test window title is set correctly.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + title = viewer.windowTitle() + assert "PicoStream Viewer" in title + assert "10.0s" in title + + viewer.close() + + +class TestZarrViewerWindowControls: + """Control interaction tests.""" + + def test_window_duration_change(self, qapp, populated_zarr_file): + """Test window duration can be changed.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + # Get initial duration + initial_duration = viewer.window_duration_s + assert ( + viewer.MIN_WINDOW_DURATION + <= initial_duration + <= min(viewer.MAX_WINDOW_DURATION, viewer.reader.duration_s) + ) + + viewer.window_duration_s = 3.0 + assert viewer.window_duration_s == 3.0 + + viewer.close() + + def test_scroll_position(self, qapp, populated_zarr_file): + """Test scroll position.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + assert viewer.scroll_position_s == 0.0 + + viewer.scroll_position_s = 3.0 + assert viewer.scroll_position_s == 3.0 + + viewer.close() + + def test_clamp_scroll_position(self, qapp, populated_zarr_file): + """Test scroll position clamping.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + viewer.window_duration_s = 5.0 + viewer.scroll_position_s = 20.0 + viewer._clamp_scroll_position() + + max_pos = 10.0 - 5.0 + assert viewer.scroll_position_s <= max_pos + + viewer.close() + + +class TestZarrViewerWindowChannelVisibility: + """Channel visibility tests.""" + + def test_channel_toggling(self, qapp, populated_zarr_file): + """Test channel visibility toggling.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + assert viewer.curves[0].isVisible() + + viewer._on_channel_toggled(0, False) + assert not viewer.curves[0].isVisible() + + viewer._on_channel_toggled(0, True) + assert viewer.curves[0].isVisible() + + viewer.close() + + +class TestZarrViewerWindowTimeConversion: + """Time/sample conversion tests.""" + + def test_time_to_sample(self, qapp, populated_zarr_file): + """Test time to sample conversion.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + sample_idx = viewer.reader._time_to_sample(1.0) + assert sample_idx == 500 + + sample_idx = viewer.reader._time_to_sample(5.0) + assert sample_idx == 2500 + + viewer.close() + + +class TestZarrViewerWindowMetadata: + """Metadata display tests.""" + + def test_metadata_loaded(self, qapp, populated_zarr_file): + """Test metadata is loaded correctly.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + assert viewer.reader.sample_rate == 1000.0 + assert viewer.reader.num_channels == 2 + assert viewer.reader.channels == [0, 1] + assert viewer.reader.pre_trigger_seconds == 5.0 + + viewer.close() + + +class TestZarrViewerWindowConstants: + """Constants tests.""" + + def test_window_limits(self, qapp, populated_zarr_file): + """Test window duration limits.""" + viewer = ZarrViewerWindow(populated_zarr_file) + + assert viewer.MIN_WINDOW_DURATION == 0.5 + assert viewer.MAX_WINDOW_DURATION == 30.0 + assert viewer.MAX_ADC_VALUE == 32767 + + viewer.close() + + def test_lossless_mode(self, qapp, populated_zarr_file): + """Test viewer operates in lossless mode (no decimation).""" + viewer = ZarrViewerWindow(populated_zarr_file) + + assert viewer.pipeline is not None + assert viewer.pipeline.config.target_plot_points is None + + viewer.close() diff --git a/picostream/test_zarr_writer.py b/picostream/test_zarr_writer.py new file mode 100644 index 0000000..74e578a --- /dev/null +++ b/picostream/test_zarr_writer.py @@ -0,0 +1,310 @@ +"""Unit tests for ZarrStreamWriter.""" + +import os +import shutil +import tempfile +from typing import Generator + +import numpy as np +import pytest +import zarr + +from picostream.acquisition_rate import AcquisitionRate, DownsampleMode +from picostream.zarr_writer import ZarrStreamWriter + + +@pytest.fixture +def temp_zarr_path() -> Generator[str, None, None]: + """Provide a temporary path for Zarr files.""" + temp_dir = tempfile.mkdtemp() + zarr_path = os.path.join(temp_dir, "test.zarr") + yield zarr_path + shutil.rmtree(temp_dir, ignore_errors=True) + + +class TestZarrStreamWriter: + """Test cases for ZarrStreamWriter.""" + + def test_basic_write_read_back(self, temp_zarr_path: str) -> None: + """Test basic writing and reading back data.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + expected_data = np.array([[100, 200], [300, 400], [500, 600]], dtype=np.int16) + writer.append(expected_data) + + metadata = { + "channels": [0, 1], + "voltage_ranges": [20.0, 20.0], + } + writer.close(metadata) + + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + np.testing.assert_array_equal(stored_data, expected_data) + + def test_batching_many_small_appends(self, temp_zarr_path: str) -> None: + """Test that many small appends are batched correctly.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=100.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + expected_data = [] + for i in range(50): + small_chunk = np.array([[i, i + 1], [i + 2, i + 3]], dtype=np.int16) + writer.append(small_chunk) + expected_data.append(small_chunk) + + writer.close({}) + + expected_concatenated = np.concatenate(expected_data, axis=0) + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + np.testing.assert_array_equal(stored_data, expected_concatenated) + + def test_chunk_alignment(self, temp_zarr_path: str) -> None: + """Test that chunks align with 1-second boundaries.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + assert writer._chunk_size == 1000 + + data = np.ones((2500, 2), dtype=np.int16) + writer.append(data) + writer.close({}) + + root = zarr.open(temp_zarr_path, mode="r") + data_array = root["data"] + + assert data_array.chunks == (1000, 2) + + def test_metadata_round_trip(self, temp_zarr_path: str) -> None: + """Test that metadata is correctly written and readable.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=50000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + test_data = np.ones((100, 2), dtype=np.int16) + writer.append(test_data) + + expected_metadata = { + "channels": [0, 1], + "voltage_ranges": [20.0, 20.0], + "offsets_v": [0.0, 0.0], + "resolution": 12, + "max_adc": 32767, + "pre_trigger_seconds": 10.0, + "compression": None, + } + writer.close(expected_metadata) + + root = zarr.open(temp_zarr_path, mode="r") + + assert list(root.attrs["channels"]) == [0, 1] + assert list(root.attrs["voltage_ranges"]) == [20.0, 20.0] + assert root.attrs["resolution"] == 12 + assert root.attrs["max_adc"] == 32767 + assert root.attrs["format_version"] == "2.0" + assert "total_samples" in root.attrs + assert "duration_s" in root.attrs + assert "start_time_iso" in root.attrs + + def test_empty_file(self, temp_zarr_path: str) -> None: + """Test closing without writing any data.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + writer.close({}) + + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + assert stored_data.shape == (0, 2) + assert root.attrs["total_samples"] == 0 + + def test_large_write(self, temp_zarr_path: str) -> None: + """Test writing a larger amount of data.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=10000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + large_data = np.random.randint(-10000, 10000, size=(60000, 2), dtype=np.int16) + writer.append(large_data) + writer.close({}) + + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + np.testing.assert_array_equal(stored_data, large_data) + assert root.attrs["total_samples"] == 60000 + + def test_invalid_data_dtype(self, temp_zarr_path: str) -> None: + """Test that non-int16 data raises an error.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2) + + invalid_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + with pytest.raises(ValueError, match="Data must be int16"): + writer.append(invalid_data) + + def test_invalid_data_shape(self, temp_zarr_path: str) -> None: + """Test that data with wrong shape raises an error.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2) + + wrong_channels = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16) + + with pytest.raises(ValueError, match="Data shape must be"): + writer.append(wrong_channels) + + def test_context_manager(self, temp_zarr_path: str) -> None: + """Test using the writer as a context manager.""" + data = np.array([[100, 200], [300, 400]], dtype=np.int16) + + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + with ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2) as writer: + writer.append(data) + + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + np.testing.assert_array_equal(stored_data, data) + + def test_discard(self, temp_zarr_path: str) -> None: + """Test discarding a file.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2) + data = np.array([[1, 2], [3, 4]], dtype=np.int16) + writer.append(data) + writer.discard() + + assert not os.path.exists(temp_zarr_path) + + def test_flush_partial_chunk(self, temp_zarr_path: str) -> None: + """Test that flush writes partial chunks.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + data = np.ones((100, 2), dtype=np.int16) + writer.append(data) + + writer.flush() + writer.close({}) + + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + np.testing.assert_array_equal(stored_data, data) + + def test_multiple_flush_calls(self, temp_zarr_path: str) -> None: + """Test multiple flush calls work correctly.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + num_channels = 2 + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels) + + data1 = np.ones((100, 2), dtype=np.int16) + writer.append(data1) + writer.flush() + + data2 = np.ones((50, 2), dtype=np.int16) * 2 + writer.append(data2) + writer.flush() + + expected = np.concatenate([data1, data2], axis=0) + + writer.close({}) + root = zarr.open(temp_zarr_path, mode="r") + stored_data = root["data"][:] + + np.testing.assert_array_equal(stored_data, expected) + + def test_written_counter(self, temp_zarr_path: str) -> None: + """Test that the written counter is accurate.""" + acquisition_rate = AcquisitionRate( + hardware_rate_hz=1000.0, + num_channels=2, + downsample_ratio=1, + downsample_mode=DownsampleMode.NONE, + ) + writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2) + + assert writer.written == 0 + + data1 = np.ones((100, 2), dtype=np.int16) + writer.append(data1) + writer.flush() + assert writer.written == 100 + + data2 = np.ones((200, 2), dtype=np.int16) + writer.append(data2) + writer.flush() + assert writer.written == 300 + + writer.close({}) + assert writer.written == 300 diff --git a/picostream/zarr_reader.py b/picostream/zarr_reader.py new file mode 100644 index 0000000..7c8b1af --- /dev/null +++ b/picostream/zarr_reader.py @@ -0,0 +1,405 @@ +"""Zarr file reader for post-hoc analysis of saved PicoStream data. + +This module provides a reader for Zarr files created by ZarrStreamWriter, +enabling post-acquisition analysis with on-demand disk access. Data is not +loaded entirely into memory; instead, Zarr's chunk-level access is used to +read only the requested time ranges. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Dict, List, Optional + +import numpy as np +import zarr +from loguru import logger + +from picostream.conversion_utils import adc_to_mV + +if TYPE_CHECKING: + from picostream.acquisition_rate import AcquisitionRate + + +class PicoZarrReader: + """Reader for saved PicoStream Zarr files. + + Provides access to ADC data and metadata from Zarr files with on-demand + reading. Voltage conversion and time axis construction are handled + automatically. + + Parameters + ---------- + path : str + Path to the Zarr directory store. + + Attributes + ---------- + path : Path + Path to the Zarr file. + root : zarr.Group + Root Zarr group. + data : zarr.Array + Data array containing int16 ADC samples. + sample_rate : float + Sample rate in Hz. + num_channels : int + Number of channels in the data. + total_samples : int + Total number of samples per channel. + duration_s : float + Total duration in seconds. + channels : List[int] + List of channel indices. + voltage_ranges : Dict[int, float] + Voltage range in Volts for each channel. + offsets_v : Dict[int, float] + Voltage offset in Volts for each channel. + resolution : int + ADC resolution in bits. + max_adc : int + Maximum ADC value. + downsample_ratio : int + Downsample ratio if applied. + downsample_mode : str + Downsample mode (e.g., "NONE", "AGGREGATE"). + pre_trigger_seconds : float + Pre-trigger capture duration in seconds. + format_version : str + File format version string. + compression : Optional[str] + Compression algorithm used, if any. + start_time_iso : str + ISO 8601 timestamp when capture started. + buffer_duration_s : float + Ring buffer duration in seconds. + device_id : str + Device identifier. + serial_code : Optional[str] + Device serial code, if available. + coupling : str + Input coupling mode (e.g., "DC", "AC"). + bandwidth_limiter : str + Bandwidth limiter setting (e.g., "FULL", "20MHZ"). + + Examples + -------- + >>> reader = PicoZarrReader("save_20240115_143022.zarr") + >>> print(f"Duration: {reader.duration_s}s, Channels: {reader.channels}") + >>> raw = reader.get_raw(0, 1.0) # Get first second of raw data + >>> voltage = reader.get_voltage(0, 1.0) # Get first second in mV + >>> reader.close() + """ + + def __init__(self, path: str) -> None: + """Initialise the reader and load metadata from the Zarr file.""" + self.path = Path(path) + + if not self.path.exists(): + raise FileNotFoundError(f"Zarr file not found: {path}") + + self.root = zarr.open(str(self.path), mode="r") + self.data: zarr.Array = self.root["data"] # type: ignore + + attrs = self.root.attrs + + self.num_channels: int = int(self.data.shape[1]) + self.total_samples: int = int(self.data.shape[0]) + + # Reconstruct AcquisitionRate from metadata + hardware_rate = attrs.get("hardware_rate_hz") # type: ignore + ds_ratio = attrs.get("downsample_ratio", 1) # type: ignore + ds_mode = attrs.get("downsample_mode", "NONE") # type: ignore + + if hardware_rate is None: + raise ValueError( + "Missing required 'hardware_rate_hz' metadata field. " + "This file may be from an old format version." + ) + + from picostream.acquisition_rate import AcquisitionRate, DownsampleMode + + num_channels_meta = attrs.get("num_channels", self.num_channels) + assert isinstance(num_channels_meta, (int, float)) + assert isinstance(ds_ratio, (int, float, str)) + downsample_ratio_value: int = int(float(ds_ratio)) + + self.acquisition_rate: Optional[AcquisitionRate] = AcquisitionRate( + hardware_rate_hz=float(hardware_rate), # type: ignore + num_channels=int(num_channels_meta), + downsample_ratio=downsample_ratio_value, + downsample_mode=DownsampleMode(str(ds_mode)), # type: ignore + ) + # Use acquisition_rate for all rate calculations + self.sample_rate: float = self.acquisition_rate.storage_rate_hz + + if self.sample_rate > 0: + self.duration_s: float = self.total_samples / self.sample_rate + else: + self.duration_s = 0.0 + + channels_attr = attrs.get("channels", []) # type: ignore + self.channels: List[int] = [int(ch) for ch in channels_attr] # type: ignore + + voltage_ranges_attr = attrs.get("voltage_ranges", []) # type: ignore + self.voltage_ranges: Dict[int, float] = { + ch: float(vr) + for ch, vr in zip(self.channels, voltage_ranges_attr, strict=True) # type: ignore + } + + offsets_attr = attrs.get("offsets_v", []) # type: ignore + self.offsets_v: Dict[int, float] = { + ch: float(off) + for ch, off in zip(self.channels, offsets_attr, strict=True) # type: ignore + } + + self.resolution: int = int(attrs.get("resolution", 16)) # type: ignore + # Default to 32767 (PicoScope 5000a max) if not specified + self.max_adc: int = int(attrs.get("max_adc", 32767)) # type: ignore + self.downsample_ratio: int = int( + float(ds_ratio) if isinstance(ds_ratio, str) else ds_ratio + ) # type: ignore + self.downsample_mode: str = str(ds_mode) # type: ignore + self.pre_trigger_seconds: float = float(attrs.get("pre_trigger_seconds", 0.0)) # type: ignore + self.format_version: str = str(attrs.get("format_version", "unknown")) # type: ignore + self.compression: Optional[str] = attrs.get("compression") # type: ignore + self.start_time_iso: str = str(attrs.get("start_time_iso", "")) # type: ignore + self.buffer_duration_s: float = float(attrs.get("buffer_duration_s", 0.0)) # type: ignore + self.device_id: str = str(attrs.get("device_id", "")) # type: ignore + self.serial_code: Optional[str] = attrs.get("serial_code") # type: ignore + self.coupling: str = str(attrs.get("coupling", "DC")) # type: ignore + self.bandwidth_limiter: str = str(attrs.get("bandwidth_limiter", "FULL")) # type: ignore + + logger.info( + "PicoZarrReader opened: {}, {} samples, {:.2f}s, {} channels", + self.path.name, + self.total_samples, + self.duration_s, + self.num_channels, + ) + + def _time_to_sample(self, time_s: float) -> int: + """Convert time in seconds to sample index. + + Parameters + ---------- + time_s : float + Time in seconds from start of capture. + + Returns + ------- + int + Sample index, clamped to valid range. + """ + sample_idx = self.acquisition_rate.seconds_to_samples(time_s) + return max(0, min(sample_idx, self.total_samples)) + + def _sample_to_time(self, sample_idx: int) -> float: + """Convert sample index to time in seconds. + + Parameters + ---------- + sample_idx : int + Sample index. + + Returns + ------- + float + Time in seconds. + """ + return self.acquisition_rate.samples_to_seconds(sample_idx) + + def get_raw(self, start_s: float, end_s: float) -> np.ndarray: + """Read raw ADC data from the specified time range. + + Reads int16 ADC counts from disk using Zarr's chunk-level access. + Returns a copy, not a view into the Zarr array. + + Parameters + ---------- + start_s : float + Start time in seconds from capture start. + end_s : float + End time in seconds from capture start (exclusive). + + Returns + ------- + np.ndarray + Array of shape (n_samples, num_channels) with dtype int16. + Returns empty array if start_s >= end_s or out of bounds. + + Raises + ------ + ValueError + If requesting data from a closed reader. + """ + if not hasattr(self, "data"): + raise ValueError("Reader has been closed") + + start_sample = self._time_to_sample(start_s) + end_sample = self._time_to_sample(end_s) + + if start_sample >= end_sample: + return np.empty((0, self.num_channels), dtype=np.int16) + + if start_sample >= self.total_samples: + return np.empty((0, self.num_channels), dtype=np.int16) + + end_sample = min(end_sample, self.total_samples) + + data = self.data[start_sample:end_sample] + + return np.array(data, dtype=np.int16, copy=True) + + def get_voltage(self, start_s: float, end_s: float) -> np.ndarray: + """Read voltage data (in mV) from the specified time range. + + Applies ADC-to-mV conversion using stored voltage ranges and offsets. + + Parameters + ---------- + start_s : float + Start time in seconds from capture start. + end_s : float + End time in seconds from capture start (exclusive). + + Returns + ------- + np.ndarray + Array of shape (n_samples, num_channels) with dtype float64, + containing voltage values in millivolts. + """ + raw_data = self.get_raw(start_s, end_s) + + if raw_data.size == 0: + return np.empty((0, self.num_channels), dtype=np.float64) + + voltage_data = np.empty_like(raw_data, dtype=np.float64) + + for i, ch in enumerate(self.channels): + if i < raw_data.shape[1]: + voltage_range = self.voltage_ranges.get(ch, 1.0) + offset = self.offsets_v.get(ch, 0.0) + + converted = adc_to_mV(raw_data[:, i], voltage_range, self.max_adc) + voltage_data[:, i] = converted + (offset * 1000.0) + + return voltage_data + + def get_time_axis(self, start_s: float, end_s: float) -> np.ndarray: + """Create a time axis for the specified time range. + + Parameters + ---------- + start_s : float + Start time in seconds. + end_s : float + End time in seconds. + + Returns + ------- + np.ndarray + Array of time values in seconds, aligned with get_raw/get_voltage. + """ + start_sample = self._time_to_sample(start_s) + end_sample = self._time_to_sample(end_s) + + n_samples = max(0, end_sample - start_sample) + if n_samples == 0: + return np.array([], dtype=np.float64) + + start_time = self._sample_to_time(start_sample) + end_time = self._sample_to_time(end_sample) + + return np.linspace(start_time, end_time, n_samples, endpoint=False) + + def get_metadata(self) -> Dict: + """Return all metadata as a dictionary. + + Returns + ------- + Dict + Dictionary containing all file metadata. + """ + return { + "path": str(self.path), + "sample_rate_hz": self.sample_rate, + "num_channels": self.num_channels, + "total_samples": self.total_samples, + "duration_s": self.duration_s, + "channels": self.channels, + "voltage_ranges": self.voltage_ranges, + "offsets_v": self.offsets_v, + "resolution": self.resolution, + "max_adc": self.max_adc, + "downsample_ratio": self.downsample_ratio, + "downsample_mode": self.downsample_mode, + "pre_trigger_seconds": self.pre_trigger_seconds, + "format_version": self.format_version, + "compression": self.compression, + "start_time_iso": self.start_time_iso, + "buffer_duration_s": self.buffer_duration_s, + "device_id": self.device_id, + "serial_code": self.serial_code, + "coupling": self.coupling, + "bandwidth_limiter": self.bandwidth_limiter, + } + + def get_pre_trigger_time(self) -> float: + """Get the time at which the save trigger occurred. + + The trigger point is the start of post-trigger data, which + is at pre_trigger_seconds from the beginning of the file. + + Returns + ------- + float + Time in seconds of the trigger point. + """ + return self.pre_trigger_seconds + + def get_channel_name(self, channel_index: int) -> str: + """Get the display name for a channel index. + + Parameters + ---------- + channel_index : int + Channel index (0, 1, etc.). + + Returns + ------- + str + Channel name (e.g., "A", "B", "Ch 0"). + """ + if channel_index < len(self.channels): + ch = self.channels[channel_index] + return chr(ord("A") + ch) + return f"Ch {channel_index}" + + def close(self) -> None: + """Close the reader and release resources. + + After closing, the reader can no longer be used to access data. + """ + if hasattr(self, "root"): + del self.root + del self.data + logger.info("PicoZarrReader closed: {}", self.path.name) + + def __enter__(self) -> PicoZarrReader: + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.close() + + def __repr__(self) -> str: + """String representation of the reader.""" + return ( + f"PicoZarrReader(" + f"path='{self.path.name}', " + f"{self.duration_s:.2f}s, " + f"{self.sample_rate / 1e6:.2f} MS/s, " + f"{len(self.channels)} ch)" + ) diff --git a/picostream/zarr_viewer.py b/picostream/zarr_viewer.py new file mode 100644 index 0000000..0dc7ecf --- /dev/null +++ b/picostream/zarr_viewer.py @@ -0,0 +1,553 @@ +"""Standalone viewer window for saved PicoStream Zarr files. + +This module provides a Qt-based viewer for analysing saved Zarr captures +independently from the live acquisition system. Multiple viewer windows +can be open simultaneously. +""" + +from __future__ import annotations + +import sys +from typing import Dict, Optional + +import numpy as np +import pyqtgraph as pg +from loguru import logger +from PyQt6.QtCore import QSettings, Qt +from PyQt6.QtGui import QCloseEvent, QFont +from PyQt6.QtWidgets import ( + QApplication, + QDoubleSpinBox, + QFileDialog, + QHBoxLayout, + QLabel, + QMainWindow, + QPushButton, + QScrollBar, + QVBoxLayout, + QWidget, +) + +from picostream.data_pipeline import DataPipeline, PipelineConfig +from picostream.zarr_reader import PicoZarrReader + + +class ZarrViewerWindow(QMainWindow): + """Independent viewer window for saved Zarr files. + + Displays saved capture data with a scrollable timeline and configurable + window duration. Uses on-demand disk access via Zarr for memory efficiency. + + Parameters + ---------- + path : str + Path to the Zarr file to display. + parent : Optional[QWidget] + Parent widget, if any. + + Attributes + ---------- + reader : PicoZarrReader + The Zarr reader instance. + window_duration_s : float + Current visible window duration in seconds. + scroll_position_s : float + Current scroll position (start of visible window) in seconds. + """ + + MIN_WINDOW_DURATION: float = 0.5 + MAX_WINDOW_DURATION: float = 30.0 + + # Fixed max ADC value for PicoScope 5000a series (consistent with dfplot) + MAX_ADC_VALUE: int = 32767 + + def __init__(self, path: str, parent: Optional[QWidget] = None) -> None: + """Initialise the viewer window and load the Zarr file.""" + super().__init__(parent) + + self.reader = PicoZarrReader(path) + + # Load window duration from QSettings (default 2.0s) + settings = QSettings("lab", "picostream") + default_window = settings.value("zarr_viewer/window_duration", 2.0, float) + self.window_duration_s: float = max( + self.MIN_WINDOW_DURATION, + min(default_window, self.reader.duration_s), + ) + self.scroll_position_s: float = 0.0 + + # Scale factors: 1.0 = no scaling, auto-calculated from voltage ranges + self.scale_factors: Dict[int, float] = {} + self._calculate_scale_factors() + + # Initialise DataPipeline for consistent processing with dfplot + self._initialise_pipeline() + + self._setup_ui() + self._update_title() + self._update_scrollbar_range() + self._refresh_plot() + + logger.info( + "ZarrViewerWindow opened: {}, duration={:.2f}s", + path, + self.reader.duration_s, + ) + + def _setup_ui(self) -> None: + """Set up the window UI components.""" + self.setMinimumSize(800, 600) + + central_widget = QWidget() + self.setCentralWidget(central_widget) + + main_layout = QVBoxLayout(central_widget) + main_layout.setContentsMargins(10, 10, 10, 10) + + # Info panel + info_layout = QHBoxLayout() + + self.title_label = QLabel() + self.title_label.setFont(QFont("Monospace")) + info_layout.addWidget(self.title_label) + + info_layout.addStretch() + + # Metadata display + meta_text = ( + f"SR: {self.reader.sample_rate / 1e6:.2f} MS/s | " + f"Ch: {', '.join(self.reader.get_channel_name(ch) for ch in self.reader.channels)} | " + f"Pre-trigger: {self.reader.pre_trigger_seconds:.1f}s | " + f"Resolution: {self.reader.resolution}-bit" + ) + self.meta_label = QLabel(meta_text) + self.meta_label.setFont(QFont("Monospace", 9)) + info_layout.addWidget(self.meta_label) + + main_layout.addLayout(info_layout) + + # Plot area + self.plot_widget = pg.PlotWidget() + self.plot_widget.setLabel("left", "Voltage", "mV") + self.plot_widget.setLabel("bottom", "Time", "s") + self.plot_widget.showGrid(x=True, y=True) + self.plot_widget.setDownsampling(mode="peak") + self.plot_widget.setClipToView(True) + self.plot_widget.showAxis("right", False) + + self.viewBox = self.plot_widget.getViewBox() + + # Channel colours matching dfplot (RGBA as hex) + self.channel_pens = { + 0: {"color": "#00BFFF"}, # Cyan (0.0, 0.75, 1.0) → #00BFFF + 1: {"color": "#FF4444"}, # Red (1.0, 0.27, 0.27) → #FF4444 + } + + self.curves: Dict[int, pg.PlotDataItem] = {} + self.curves[0] = pg.PlotDataItem( + pen=self.channel_pens[0]["color"], width=1, name="Ch A" + ) + self.curves[1] = pg.PlotDataItem( + pen=self.channel_pens[1]["color"], width=1, name="Ch B" + ) + + self.viewBox.addItem(self.curves[0]) + self.viewBox.addItem(self.curves[1]) + + # Show/hide curves based on available channels + for ch_idx in self.curves: + self.curves[ch_idx].setVisible(ch_idx in self.reader.channels) + + # Add legend + self.legend = pg.LegendItem(offset=(50, 30)) + self.legend.setParentItem(self.plot_widget.graphicsItem()) + self._update_legend() + + main_layout.addWidget(self.plot_widget, stretch=1) + + # Controls panel + controls_layout = QHBoxLayout() + + # Window duration control + controls_layout.addWidget(QLabel("Window:")) + self.window_spin = QDoubleSpinBox() + self.window_spin.setRange( + self.MIN_WINDOW_DURATION, + min(self.MAX_WINDOW_DURATION, self.reader.duration_s), + ) + self.window_spin.setSingleStep(0.1) + self.window_spin.setDecimals(2) + self.window_spin.setValue(self.window_duration_s) + self.window_spin.setSuffix(" s") + self.window_spin.valueChanged.connect(self._on_window_changed) + controls_layout.addWidget(self.window_spin) + + controls_layout.addSpacing(20) + + # Channel visibility toggles + controls_layout.addWidget(QLabel("Channels:")) + self.channel_buttons: Dict[int, QPushButton] = {} + for ch in self.reader.channels: + btn = QPushButton(f"Ch {chr(ord('A') + ch)}") + btn.setCheckable(True) + btn.setChecked(True) + btn.clicked.connect( + lambda checked, ch=ch: self._on_channel_toggled(ch, checked) + ) + self.channel_buttons[ch] = btn + controls_layout.addWidget(btn) + + controls_layout.addStretch() + + # Info labels + self.position_label = QLabel("0.0s") + self.position_label.setFont(QFont("Monospace")) + controls_layout.addWidget(self.position_label) + + controls_layout.addWidget(QLabel("/")) + + self.duration_label = QLabel(f"{self.reader.duration_s:.1f}s") + self.duration_label.setFont(QFont("Monospace")) + controls_layout.addWidget(self.duration_label) + + main_layout.addLayout(controls_layout) + + # Horizontal scrollbar + self.scrollbar = QScrollBar(Qt.Orientation.Horizontal) + self.scrollbar.valueChanged.connect(self._on_scroll_changed) + main_layout.addWidget(self.scrollbar) + + def _update_title(self) -> None: + """Update the window title with file info.""" + filename = self.reader.path.name + self.setWindowTitle( + f"PicoStream Viewer — {filename} ({self.reader.duration_s:.1f}s)" + ) + self.title_label.setText(f"{filename}") + + def _update_axis_visibility(self) -> None: + pass + + def _calculate_scale_factors(self) -> None: + """Calculate automatic scale factors based on voltage ranges.""" + range_0 = self.reader.voltage_ranges[0] + self.scale_factors[0] = 1.0 + if 1 in self.reader.voltage_ranges: + range_1 = self.reader.voltage_ranges[1] + if range_1 > 0: + self.scale_factors[1] = range_0 / range_1 + else: + self.scale_factors[1] = 1.0 + + def _initialise_pipeline(self) -> None: + """Initialise the DataPipeline for consistent processing with dfplot.""" + if self.reader.acquisition_rate is None: + logger.warning("No acquisition_rate available, pipeline not initialised") + self.pipeline = None + return + + config = PipelineConfig( + resolution=self.reader.resolution, + voltage_ranges=self.reader.voltage_ranges, + offsets_v=self.reader.offsets_v, + max_adc_value=self.MAX_ADC_VALUE, + target_plot_points=None, # Lossless: show all samples + ) + self.pipeline = DataPipeline(config, self.reader.acquisition_rate) + logger.debug("DataPipeline initialised for ZarrViewer") + + def set_scale_factors(self, scale_factors: Dict[int, float]) -> None: + """Set explicit scale factors for channels.""" + self.scale_factors = scale_factors.copy() + self._update_legend() + logger.info("Scale factors set: {}", scale_factors) + + def _update_legend(self) -> None: + """Update the legend with current channel ranges and scale factors.""" + self.legend.clear() + for ch in sorted(self.curves.keys()): + if ch in self.reader.channels: + range_v = self.reader.voltage_ranges.get(ch, 0.0) + scale = self._get_scale_factor(ch) + is_auto = ch not in self.scale_factors + ch_name = "A" if ch == 0 else "B" + if is_auto: + label = f"Ch {ch_name}: ±{range_v}V (auto×{scale:.2f})" + else: + label = f"Ch {ch_name}: ±{range_v}V (×{scale:.2f})" + self.legend.addItem(self.curves[ch], label) + + def _get_scale_factor(self, ch: int) -> float: + """Get scale factor for a channel.""" + return self.scale_factors.get(ch, 1.0) + + def _update_scrollbar_range(self) -> None: + """Update scrollbar range based on file duration and window size.""" + max_scroll = max(0, self.reader.duration_s - self.window_duration_s) + + self.scrollbar.setRange(0, int(max_scroll * 1000)) + self.scrollbar.setPageStep(int(self.window_duration_s * 1000)) + self.scrollbar.setSingleStep(int(self.window_duration_s * 100)) + + def _on_window_changed(self, value: float) -> None: + """Handle window duration spinner change. + + Parameters + ---------- + value : float + New window duration in seconds. + """ + self.window_duration_s = value + + # Save to QSettings for persistence + settings = QSettings("lab", "picostream") + settings.setValue("zarr_viewer/window_duration", value) + + self._update_scrollbar_range() + self._clamp_scroll_position() + self._refresh_plot() + + def _on_scroll_changed(self, value: int) -> None: + """Handle scrollbar position change. + + Parameters + ---------- + value : int + Scroll position in milliseconds. + """ + self.scroll_position_s = value / 1000.0 + self._refresh_plot() + + def _clamp_scroll_position(self) -> None: + """Ensure scroll position is within valid bounds.""" + max_pos = max(0, self.reader.duration_s - self.window_duration_s) + self.scroll_position_s = max(0, min(self.scroll_position_s, max_pos)) + self.scrollbar.setValue(int(self.scroll_position_s * 1000)) + + def _on_channel_toggled(self, channel: int, visible: bool) -> None: + """Toggle channel visibility. + + Parameters + ---------- + channel : int + Channel index. + visible : bool + True to show, False to hide. + """ + if channel in self.curves: + self.curves[channel].setVisible(visible) + logger.debug("Channel {} visibility set to {}", channel, visible) + + def _refresh_plot(self) -> None: + """Refresh the plot with current scroll position and window.""" + start_s = self.scroll_position_s + end_s = min(start_s + self.window_duration_s, self.reader.duration_s) + + if start_s >= end_s: + return + + if self.pipeline is None: + logger.warning("Cannot refresh plot: pipeline not initialised") + return + + try: + raw_data = self.reader.get_raw(start_s, end_s) + + if raw_data.size == 0: + return + + self.position_label.setText(f"{start_s:.2f}s") + + # Lossless mode: decimation is always 1, no software min-max + decimation = 1 + + # Create time axis starting at scroll position (absolute time, not right-aligned) + # Unlike live plotter, Zarr viewer shows data at its actual time position. + # Note: We use reader.get_time_axis() instead of DataPipeline because: + # 1. The pipeline's get_window_display_geometry() is designed for live streaming + # with right-aligned relative windows, while the Zarr viewer needs absolute time. + # 2. The reader already has access to AcquisitionRate for consistent conversions. + time_axis = self.reader.get_time_axis(start_s, end_s) + + # Handle AGGREGATE mode hardware min-max pairs only + is_aggregate = ( + self.reader.acquisition_rate.downsample_mode.name == "AGGREGATE" + and self.reader.acquisition_rate.downsample_ratio > 1 + ) + + if is_aggregate and len(time_axis) > 1: + # AGGREGATE mode: data already has min-max pairs interleaved + # Each pair represents one time point + n_pairs = len(time_axis) // 2 + if n_pairs > 0: + pair_times = time_axis[::2][:n_pairs] + time_axis = np.repeat(pair_times, 2) + + for i, ch in enumerate(self.reader.channels): + if i >= raw_data.shape[1]: + continue + + ch_data = raw_data[:, i] + + # Use DataPipeline for processing (consistent with dfplot) + voltage_data, has_data = self.pipeline.process_channel_data( + ch_data, ch, decimation + ) + + if not has_data: + continue + + # Apply scale factor + voltage_data = voltage_data * self._get_scale_factor(ch) + + # Create display time axis matching data length + n_points = len(voltage_data) + if n_points > 0 and len(time_axis) > 0: + display_time = time_axis[:n_points] + self.curves[ch].setData(display_time, voltage_data) + + if len(time_axis) > 0: + self.viewBox.setXRange(time_axis[0], time_axis[-1], padding=0) + + self.viewBox.enableAutoRange(axis="y") + + except Exception as e: + logger.exception("Error refreshing plot: {}", e) + + def keyPressEvent(self, event) -> None: # pyright: ignore + """Handle keyboard navigation. + + Arrow keys scroll the view. Page Up/Down scroll by window size. + """ + step = self.window_duration_s * 0.1 + + if event.key() == Qt.Key.Key_Left: + self.scroll_position_s = max(0, self.scroll_position_s - step) + self._clamp_scroll_position() + self._refresh_plot() + elif event.key() == Qt.Key.Key_Right: + max_pos = self.reader.duration_s - self.window_duration_s + self.scroll_position_s = min(max_pos, self.scroll_position_s + step) + self._clamp_scroll_position() + self._refresh_plot() + elif event.key() == Qt.Key.Key_Home: + self.scroll_position_s = 0 + self._clamp_scroll_position() + self._refresh_plot() + elif event.key() == Qt.Key.Key_End: + self.scroll_position_s = self.reader.duration_s - self.window_duration_s + self._clamp_scroll_position() + self._refresh_plot() + elif event.key() == Qt.Key.Key_PageUp: + self.scroll_position_s = max( + 0, self.scroll_position_s - self.window_duration_s + ) + self._clamp_scroll_position() + self._refresh_plot() + elif event.key() == Qt.Key.Key_PageDown: + max_pos = self.reader.duration_s - self.window_duration_s + self.scroll_position_s = min( + max_pos, self.scroll_position_s + self.window_duration_s + ) + self._clamp_scroll_position() + self._refresh_plot() + else: + super().keyPressEvent(event) + + def closeEvent(self, event: QCloseEvent) -> None: # pyright: ignore + """Handle window close event.""" + if hasattr(self, "reader"): + logger.info("ZarrViewerWindow closing: {}", self.reader.path.name) + else: + logger.info("ZarrViewerWindow closing") + self._cleanup_resources() + event.accept() + + def _cleanup_resources(self) -> None: + """Clean up resources including Zarr file handles. + + Forces immediate release of file handles by explicitly clearing + references and invoking garbage collection. + """ + try: + # Clear pipeline first + if hasattr(self, "pipeline") and self.pipeline is not None: + self.pipeline.invalidate_cache() + del self.pipeline + + # Close the reader + if hasattr(self, "reader"): + self.reader.close() + del self.reader + + # Force garbage collection to release file handles promptly + import gc + + gc.collect() + + logger.debug("ZarrViewerWindow resources cleaned up") + except Exception as e: + logger.exception("Error during resource cleanup: {}", e) + + +def open_zarr_viewer_dialog( + parent: Optional[QWidget] = None, +) -> Optional[ZarrViewerWindow]: + """Open a file dialog and create a ZarrViewerWindow. + + Parameters + ---------- + parent : Optional[QWidget] + Parent widget for the file dialog. + + Returns + ------- + Optional[ZarrViewerWindow] + The opened viewer window, or None if cancelled. + """ + path = QFileDialog.getExistingDirectory( + parent, + "Open Zarr File", + "", + QFileDialog.Option.ShowDirsOnly, + ) + + if not path: + return None + + try: + return ZarrViewerWindow(path) + except Exception as e: + logger.exception("Failed to open Zarr file: {}", e) + return None + + +def main() -> None: + """Standalone Zarr viewer entry point.""" + app = QApplication(sys.argv) + + import argparse + + parser = argparse.ArgumentParser(description="Standalone Zarr file viewer.") + parser.add_argument("path", nargs="?", help="Path to Zarr file to open") + args = parser.parse_args() + + if args.path: + try: + viewer = ZarrViewerWindow(args.path) + viewer.show() + sys.exit(app.exec()) + except Exception as e: + logger.exception("Failed to open Zarr file: {}", e) + sys.exit(1) + else: + viewer = open_zarr_viewer_dialog() + if viewer: + viewer.show() + sys.exit(app.exec()) + else: + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/picostream/zarr_writer.py b/picostream/zarr_writer.py new file mode 100644 index 0000000..bd7d092 --- /dev/null +++ b/picostream/zarr_writer.py @@ -0,0 +1,250 @@ +"""Zarr-based streaming data writer for PicoStream. + +This module provides a high-performance writer for saving streaming ADC data +to Zarr format with configurable chunking and optional compression. +""" + +from __future__ import annotations + +import os +import shutil +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import numpy as np +import zarr +from loguru import logger + +if TYPE_CHECKING: + from picostream.acquisition_rate import AcquisitionRate + + +class ZarrStreamWriter: + """Streaming data writer using Zarr format. + + This writer appends int16 ADC data to a Zarr directory store with + configurable chunk sizes. It batches small appends internally and + flushes on chunk boundaries to optimise disk I/O. + + Parameters + ---------- + path : str + Output path for the Zarr directory store. + acquisition_rate : AcquisitionRate + Acquisition rate object with all rate information. + num_channels : int + Number of channels in the data. + compression : Optional[str], default None + Compression algorithm to use. Options: None, "lz4". + + Attributes + ---------- + root : zarr.Group + The root Zarr group. + data : zarr.Array + The main data array containing int16 ADC samples. + written : int + Number of samples written so far. + + Examples + -------- + >>> writer = ZarrStreamWriter("output.zarr", acquisition_rate, 2) + >>> data = np.random.randint(-1000, 1000, size=(1000, 2), dtype=np.int16) + >>> writer.append(data) + >>> writer.close({"sample_rate_hz": 62500000.0}) + """ + + def __init__( + self, + path: str, + acquisition_rate: AcquisitionRate, + num_channels: int, + compression: Optional[str] = None, + ) -> None: + """Initialise the Zarr stream writer.""" + self.final_path = path + # Use .incomplete suffix during writing to identify incomplete files + self.path = path + ".incomplete" + self.acquisition_rate = acquisition_rate + self.num_channels = num_channels + self.compression = compression + + chunk_size = max(1000, int(acquisition_rate.storage_rate_hz * 0.2)) + self._chunk_size = chunk_size + + self.root = zarr.open(self.path, mode="w") + + self.data = self.root.create_array( + "data", + shape=(0, num_channels), + chunks=(chunk_size, num_channels), + dtype=np.int16, + ) + + self.written = 0 + self._pending: List[np.ndarray] = [] + self._pending_samples: int = 0 + + logger.debug( + "ZarrStreamWriter initialised: path={}, chunk_size={}, " + "num_channels={}, compression={}", + path, + chunk_size, + num_channels, + compression, + ) + + def append(self, data: np.ndarray) -> None: + """Append data to the Zarr store. + + Data is buffered internally and flushed to disk when the buffer + reaches the chunk size boundary. + + Parameters + ---------- + data : np.ndarray + Array of shape (n_samples, num_channels) with dtype int16. + + Raises + ------ + ValueError + If data shape or dtype is incorrect. + """ + if data.dtype != np.int16: + raise ValueError(f"Data must be int16, got {data.dtype}") + + if data.ndim != 2 or data.shape[1] != self.num_channels: + raise ValueError( + f"Data shape must be (n_samples, {self.num_channels}), got {data.shape}" + ) + + n_samples = data.shape[0] + if n_samples == 0: + return + + self._pending.append(data) + self._pending_samples += n_samples + + if self._pending_samples >= self._chunk_size: + self._flush() + + def _flush(self) -> None: + """Flush pending data to the Zarr array. + + This method concatenates all pending data and writes it to the + Zarr array, resizing as needed. It then clears the pending buffer. + """ + if not self._pending: + return + + concatenated = np.concatenate(self._pending, axis=0) + n_samples = concatenated.shape[0] + + new_size = self.written + n_samples + self.data.resize((new_size, self.num_channels)) + self.data[self.written : new_size] = concatenated + + self.written += n_samples + self._pending = [] + self._pending_samples = 0 + + logger.debug( + "ZarrStreamWriter flushed {} samples, total written: {}", + n_samples, + self.written, + ) + + @property + def total_samples(self) -> int: + """Total samples including pending (not yet flushed) data. + + Returns + ------- + int + Total samples written to disk plus samples buffered in memory. + """ + return self.written + self._pending_samples + + def flush(self) -> None: + """Force-flush any pending data to disk. + + This should be called before closing to ensure all data is written. + """ + self._flush() + logger.debug("ZarrStreamWriter flush complete") + + def close(self, metadata: Dict[str, Any]) -> None: + """Close the writer and finalise the Zarr file. + + This flushes any pending data and writes metadata to the root + attributes. On successful close, the .incomplete suffix is removed + to mark the file as complete. + + Parameters + ---------- + metadata : Dict[str, Any] + Dictionary of metadata to store in the Zarr attributes. + Common keys: hardware_rate_hz, num_channels, downsample_ratio, + downsample_mode, pre_trigger_seconds, compression. + + Notes + ----- + Files with the .incomplete suffix were not successfully closed and + may contain partial or corrupted data. These can be safely deleted. + """ + try: + self._flush() + + metadata["total_samples"] = self.written + storage_rate = self.acquisition_rate.storage_rate_hz + if storage_rate > 0: + metadata["duration_s"] = self.written / storage_rate + + metadata["format_version"] = "2.0" + metadata["compression"] = self.compression + + if "start_time_iso" not in metadata: + metadata["start_time_iso"] = datetime.now(timezone.utc).isoformat() + + for key, value in metadata.items(): + self.root.attrs[key] = value + + # Rename to final path to mark as complete + os.rename(self.path, self.final_path) + + logger.info( + "ZarrStreamWriter closed: {} with {} samples", + self.final_path, + self.written, + ) + except Exception as e: + logger.exception("Error closing ZarrStreamWriter: {}", e) + raise + + def discard(self) -> None: + """Discard the Zarr file and close the writer. + + This deletes the entire Zarr directory and should be called + when the user chooses to discard a capture. + """ + try: + self._pending = [] + self._pending_samples = 0 + + if os.path.exists(self.path): + shutil.rmtree(self.path) + logger.info("ZarrStreamWriter discarded: {}", self.path) + except Exception: + logger.exception("Error discarding Zarr file") + raise + + def __enter__(self) -> ZarrStreamWriter: + """Context manager entry.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit.""" + if exc_type is None: + self.close({}) + else: + self.discard() diff --git a/pyproject.toml b/pyproject.toml index a835307..486d2f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,73 @@ -[build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" - -[project] -name = "picostream" -version = "0.2.0" -dependencies = [ - "numpy", - "loguru", - "picosdk", - "h5py", - "pyqtgraph", - "numba", - "click", -] - - -[tool.setuptools] -packages = ["picostream"] # List the package names directly - -[project.scripts] -picostream = "picostream.main:main" +[project] +name = "picostream" +version = "1.0.0" +description = "High-speed dual-channel data acquisition for PicoScope 5000a series" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "labdaemon>=1.0.0", + "numpy", + "numba", + "PyQt6>=6.0.0", + "pyqtgraph>=0.13.0", + "loguru>=0.6.0", + "zarr>=2.16.0", + "vispy>=0.14.0", + "picosdk", +] + +[dependency-groups] +dev = [ + "pyinstaller", + "pytest>=7.0.0", + "pytest-cov", + "ruff>=0.1.0", + "pyright>=1.1.0", + "snakeviz>=2.2.0", +] + +[tool.pyright] +# Type checking rules to disable for FFI libraries and external tools +# reportAttributeAccessIssue: ctypes/FFI library attributes (picosdk SDK) +# reportOptionalSubscript: ctypes functions that can return None +# reportOptionalMemberAccess: FFI methods with optional returns +# reportCallIssue: calling possibly-None FFI function results +# reportOptionalOperand: arithmetic on possibly-None FFI results +reportAttributeAccessIssue = false +reportOptionalSubscript = false +reportOptionalMemberAccess = false +reportCallIssue = false +reportOptionalOperand = false + +exclude = [ + "picostream/test_buffered_stream.py", + "picostream/test_live_plotter.py", + "picostream/test_zarr_writer.py", + "picostream/tests/", + "picostream/v0_backup/", +] + +[tool.coverage.run] +source = ["picostream"] +omit = [ + "picostream/tests/", + "picostream/v0_backup/**", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if TYPE_CHECKING:", + "raise NotImplementedError", + "if __name__ == .__main__.:", +] + +[tool.setuptools] +packages = ["picostream"] + +[project.scripts] +picostream = "picostream.main:main" + +[project.urls] +Repository = "https://github.com/yourusername/picostream" -- cgit v1.2.3