author     Sam Scholten  2026-03-30 11:42:22 +1000
committer  Sam Scholten  2026-03-30 11:42:22 +1000
commit     637ddc52f4dc23ba3aa7cccef014aa85cab36b49 (patch)
tree       d9116fb184f32741bf1c8571ab6160be0b08acb3
parent     5a7c47d626ff3fc1352b2036001e853ae211d1af (diff)
download   picostream-1.0.tar.gz, picostream-1.0.zip
Release v1.0.0 (tag: v1.0)
-rw-r--r--  .gitignore                           73
-rw-r--r--  Makefile                             18
-rw-r--r--  PicoStream.spec                     102
-rw-r--r--  README.md                           240
-rw-r--r--  assets/icons/app.ico                bin 0 -> 36439 bytes
-rw-r--r--  assets/images/screenshot.png        bin 0 -> 106092 bytes
-rw-r--r--  justfile                             92
-rw-r--r--  new_plan.md                          65
-rw-r--r--  picostream/__init__.py                8
-rw-r--r--  picostream/acquisition_rate.py      287
-rw-r--r--  picostream/cli.py                   699
-rw-r--r--  picostream/consumer.py              160
-rw-r--r--  picostream/data_pipeline.py         622
-rw-r--r--  picostream/device.py               2794
-rw-r--r--  picostream/dfplot.py               1438
-rw-r--r--  picostream/main.py                 2671
-rw-r--r--  picostream/mock_device.py          1477
-rw-r--r--  picostream/pico.py                  535
-rw-r--r--  picostream/reader.py                259
-rw-r--r--  picostream/ring_buffer.py           331
-rw-r--r--  picostream/test_buffered_stream.py  344
-rw-r--r--  picostream/test_data_pipeline.py    496
-rw-r--r--  picostream/test_live_plotter.py     437
-rw-r--r--  picostream/test_max_adc_fix.py       73
-rw-r--r--  picostream/test_rate_contract.py    263
-rw-r--r--  picostream/test_rate_invariants.py  329
-rw-r--r--  picostream/test_ring_buffer.py      508
-rw-r--r--  picostream/test_zarr_reader.py      331
-rw-r--r--  picostream/test_zarr_viewer.py      223
-rw-r--r--  picostream/test_zarr_writer.py      310
-rw-r--r--  picostream/zarr_reader.py           405
-rw-r--r--  picostream/zarr_viewer.py           553
-rw-r--r--  picostream/zarr_writer.py           250
-rw-r--r--  pyproject.toml                       96
34 files changed, 13389 insertions(+), 3100 deletions(-)
diff --git a/.gitignore b/.gitignore
index d66a8e1..5c2e902 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,65 @@
-conda_env*
-output.hdf5
-*__pycache__*
-complexipy.json
-build/*
-*egg-info*
-uv.lock
-dist/
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+venv/
+env/
+ENV/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# Testing
+.pytest_cache/
+.coverage
+htmlcov/
+
+# Type checking / linting
+.ruff_cache/
+complexipy.json
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Data files
+*.hdf5
+*.h5
+*.zarr
+output.*
+
+# Logs
+*.log
+
+# uv lock file (we use pyproject.toml for deps)
+uv.lock
+
+# Old backups (keep for reference but not committed)
+v0_backup/
+
+# Conda envs (legacy)
+conda_env*
diff --git a/Makefile b/Makefile
deleted file mode 100644
index c9a7a95..0000000
--- a/Makefile
+++ /dev/null
@@ -1,18 +0,0 @@
-
-
-
-format:
- ruff format
- ruff check --fix --select F,B,I .
- ruff format
-
-lint:
- prospector --with-tool mypy
-
-complexity:
- complexipy --sort asc .
-
-complexity-json:
- complexipy --sort desc --output-json .
-
-
diff --git a/PicoStream.spec b/PicoStream.spec
deleted file mode 100644
index fdadedf..0000000
--- a/PicoStream.spec
+++ /dev/null
@@ -1,102 +0,0 @@
-# -*- mode: python ; coding: utf-8 -*-
-
-import glob
-import os
-import sys
-
-
-def find_libffi_dll():
- """
- Find the libffi-*.dll file required for _ctypes on Windows.
- Searches in common locations for standard, venv, and Conda Python.
- """
- if sys.platform != "win32":
- return []
-
- print("--- PyInstaller Build Environment ---")
- print(f" - Python Executable: {sys.executable}")
- print(f" - sys.prefix: {sys.prefix}")
- if hasattr(sys, "base_prefix"):
- print(f" - sys.base_prefix: {sys.base_prefix}")
- else:
- print(" - sys.base_prefix: Not available")
- conda_prefix = os.environ.get("CONDA_PREFIX")
- print(f" - CONDA_PREFIX env var: {conda_prefix}")
-
- search_paths = []
- # Active environment's DLLs directory
- search_paths.append(os.path.join(sys.prefix, "DLLs"))
-
- # Base Python installation's directories (if in a venv)
- if hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix:
- search_paths.append(os.path.join(sys.base_prefix, "DLLs"))
- search_paths.append(os.path.join(sys.base_prefix, "Library", "bin"))
-
- # Conda environment's directory (if CONDA_PREFIX is set)
- if conda_prefix:
- search_paths.append(os.path.join(conda_prefix, "Library", "bin"))
-
- print("\n--- Potential Search Paths for libffi ---")
- for p in search_paths:
- print(f" - Path: {p}, Exists? {os.path.isdir(p)}")
-
- print("\n--- Searching for libffi DLL on Windows ---")
- unique_paths = sorted(list(set(p for p in search_paths if os.path.isdir(p))))
-
- for path in unique_paths:
- print(f" - Checking: {path}")
- dll_pattern = os.path.join(path, "libffi-*.dll")
- found_dlls = glob.glob(dll_pattern)
- if found_dlls:
- dll_path = found_dlls[0]
- print(f" - Found: {dll_path}")
- return [(dll_path, ".")] # (source, destination_in_bundle)
-
- print("\nERROR: Could not find libffi-*.dll on Windows.")
- sys.exit(1)
-
-
-# --- PyInstaller Spec ---
-app_name = 'PicoStream'
-binaries = find_libffi_dll() if sys.platform == "win32" else []
-
-a = Analysis(
- ['picostream/main.py'],
- pathex=[],
- binaries=binaries,
- datas=[],
- hiddenimports=[],
- hookspath=[],
- hooksconfig={},
- runtime_hooks=[],
- excludes=[],
- win_no_prefer_redirects=False,
- win_private_assemblies=False,
- cipher=None,
- noarchive=False,
-)
-pyz = PYZ(a.pure, a.zipped_data, cipher=None)
-
-exe = EXE(
- pyz,
- a.scripts,
- a.binaries,
- a.zipfiles,
- a.datas,
- [],
- name=app_name,
- debug=False,
- bootloader_ignore_signals=False,
- strip=False,
- upx=True,
- upx_exclude=[],
- runtime_tmpdir=None,
- console=True, # useful for debugging for now
- disable_windowed_traceback=False,
- argv_emulation=False,
- target_arch=None,
- codesign_identity=None,
- entitlements_file=None,
- icon=None,
-)
-# NOTE: The absence of a COLLECT block is what defines a one-file build.
diff --git a/README.md b/README.md
index d332ae0..e130832 100644
--- a/README.md
+++ b/README.md
@@ -1,216 +1,118 @@
# PicoStream
-High-performance data acquisition from PicoScope 5000a series to HDF5, with decoupled live visualisation. Designed for robust, high-speed logging where data integrity is critical.
+A fast, simple GUI for streaming from PicoScope 5000a series scopes. Built on [labdaemon](https://github.com/qnslab/labdaemon).
-## Quick Start
+## What's new in v1.0
-Install with `uv`:
+Complete rewrite with:
-```bash
-uv pip install -e .
-```
+- **Dual channels** — acquire A and B simultaneously (16-bit = single channel only)
+- **Ring buffer** — configurable lookback (5–120s) with on-demand recording
+- **Pre-trigger capture** — include N seconds of data before you hit record
+- **Keep or discard** — stop recording and either save or trash the file
+- **VisPy plotting** — hardware-accelerated OpenGL, much faster than before
+- **Zarr format** — chunked storage, faster than HDF5 for large files
+- **PyQt6** — bumped from PyQt5
-Install the [PicoSDK](https://www.picotech.com/downloads) (Linux users: see [Arch AUR wiki](https://aur.archlinux.org/packages/picoscope) for configuration).
+![PicoStream GUI](assets/images/screenshot.png)
-Launch the GUI:
+## Install
-```bash
-just gui
-# or: uv run python -m picostream.main
-```
+Requires Python ≥3.10 and an installed PicoSDK.
-Or use the CLI:
+Install PicoSDK from [picotech.com/downloads](https://www.picotech.com/downloads).
-```bash
-uv run picostream -s 62.5 -o data.hdf5 --plot
+Linux users: create `/etc/udev/rules.d/99-picoscope.rules`:
```
-
-## Key Features
-
-- **Producer-consumer architecture** with large shared memory buffer pool prevents data loss
-- **Decoupled live plotting** reads from HDF5 file, cannot interfere with acquisition
-- **Live-only mode** for continuous monitoring without filling disk
-- **Numba-accelerated decimation** for efficient visualisation of large datasets
-- **GUI and CLI interfaces** for interactive and scripted workflows
-
-## Usage
-
-### GUI Application
-
-```bash
-just gui
+SUBSYSTEM=="usb", ATTR{idVendor}=="0ce9", MODE="0666"
```
+Then run `sudo udevadm control --reload-rules && sudo udevadm trigger`.
-Configure acquisition parameters, start/stop capture, and view live data. Settings persist between sessions.
-
-### Building Standalone Executable
+From PyPI:
```bash
-just build-gui
+pip install picostream
```
-Executable appears in `dist/PicoStream` (or `dist/PicoStream.exe` on Windows).
-
-### Command-Line Interface
-
-Standard acquisition (saves all data):
+For development:
```bash
-uv run picostream -s 62.5 -o my_data.hdf5 --plot
+just sync
```
-Live-only mode (limits file size):
+## Run
+
+After pip install:
```bash
-uv run picostream -s 62.5 --plot --max-buff-sec 60
+picostream
```
-View existing file:
+During development:
```bash
-uv run python -m picostream.dfplot /path/to/data.hdf5
+just gui
```
-Run `uv run picostream --help` for all options.
-
-## Documentation
-
-### Architecture
-
-PicoStream uses a producer-consumer pattern to ensure data integrity:
-
-- **Producer (`PicoDevice`)**: Interfaces with PicoScope hardware, streams ADC data into shared memory buffers in a dedicated thread
-- **Consumer (`Consumer`)**: Retrieves filled buffers from queue, writes to HDF5, returns empty buffers to pool in separate thread
-- **Buffer Pool**: Large shared memory buffers (100+ MB) prevent data loss if disk I/O slows
-- **Live Plotter (`HDF5LivePlotter`)**: Reads from HDF5 file on disk, completely decoupled from acquisition
+## Quickstart
-This ensures the critical acquisition path is never blocked by disk I/O or GUI rendering.
+1. Enter scope serial (or `MOCK` to simulate), click **Connect**
+2. Set channels, voltage ranges, sample rate, resolution
+3. **Start Acquisition** (Space bar)
+4. **Start Recording** (R) when you want to save
+5. **Stop & Keep** (K) or **Stop & Discard** (Del)
-### Data Analysis with `PicoStreamReader`
+Recordings save to `~/picostream/` by default.
-The `PicoStreamReader` class provides efficient access to HDF5 files with on-the-fly decimation.
+## Crash Cleanup
-```python
-import numpy as np
-from picostream.reader import PicoStreamReader
+If the app crashes during recording, incomplete Zarr files remain in `~/picostream/`.
+These are useless. Delete them:
-with PicoStreamReader('my_data.hdf5') as reader:
- sample_rate_sps = 1e9 / reader.sample_interval_ns
- print(f"File contains {reader.num_samples:,} samples at {sample_rate_sps / 1e6:.2f} MS/s")
-
- # Iterate through file with 10x decimation
- for times, voltages_mv in reader.get_block_iter(
- chunk_size=10_000_000, decimation_factor=10, decimation_mode='min_max'
- ):
- print(f"Processed {voltages_mv.size} decimated points")
+```bash
+rm -rf ~/picostream/*.zarr
```
-### API Reference: `PicoStreamReader`
-
-#### Initialisation
-
-**`__init__(self, hdf5_path: str)`**
-
-Initialises the reader. File is opened when used as context manager.
-
-#### Metadata Attributes
-
-Available after opening:
+Change save location in the GUI (Record Controls → Change...).
-- `num_samples: int` - Total raw samples in dataset
-- `sample_interval_ns: float` - Time interval between samples (nanoseconds)
-- `voltage_range_v: float` - Configured voltage range (e.g., `20.0` for ±20V)
-- `max_adc_val: int` - Maximum ADC count value (e.g., 32767)
-- `analog_offset_v: float` - Configured analogue offset (Volts)
-- `downsample_mode: str` - Hardware downsampling mode (`'average'` or `'aggregate'`)
-- `hardware_downsample_ratio: int` - Hardware downsampling ratio
+## Shortcuts
-#### Methods
+| Key | Action |
+|-----|--------|
+| Space | Start/stop acquisition |
+| 1, 2 | Toggle channels |
+| R | Start recording |
+| K | Stop + keep |
+| Del | Stop + discard |
+| Ctrl+O | Open existing Zarr file |
-**`get_block_iter(self, chunk_size: int = 1_000_000, decimation_factor: int = 1, decimation_mode: str = "mean") -> Generator`**
+## Development
-Generator yielding `(times, voltages)` tuples for entire dataset. Recommended for large files.
-
-- `chunk_size`: Number of raw samples per chunk
-- `decimation_factor`: Decimation factor
-- `decimation_mode`: `'mean'` or `'min_max'`
-
-**`get_next_block(self, chunk_size: int, decimation_factor: int = 1, decimation_mode: str = "mean") -> Tuple | None`**
-
-Retrieves next sequential block. Returns `None` at end of file. Use `reset()` to restart.
-
-**`get_block(self, size: int, start: int = 0, decimation_factor: int = 1, decimation_mode: str = "mean") -> Tuple`**
-
-Retrieves specific block from file.
-
-- `size`: Number of raw samples
-- `start`: Starting sample index
-
-**`reset(self) -> None`**
-
-Resets internal counter for `get_next_block()`.
-
-### API Reference: `HDF5LivePlotter`
-
-PyQt5 widget for real-time visualisation of HDF5 files. Can be used standalone or embedded in custom applications.
-
-#### Initialisation
-
-**`__init__(self, hdf5_path: str = "/tmp/data.hdf5", update_interval_ms: int = 50, display_window_seconds: float = 0.5, decimation_factor: int = 150, max_display_points: int = 4000)`**
-
-- `hdf5_path`: Path to HDF5 file
-- `update_interval_ms`: Update frequency (milliseconds)
-- `display_window_seconds`: Duration of data to display (seconds)
-- `decimation_factor`: Decimation factor for display
-- `max_display_points`: Maximum points to display (prevents GUI slowdown)
-
-#### Methods
-
-**`set_hdf5_path(self, hdf5_path: str) -> None`**
-
-Updates monitored HDF5 file path.
-
-**`start_updates(self) -> None`**
-
-Begins periodic plot updates.
-
-**`stop_updates(self) -> None`**
-
-Stops periodic updates.
-
-**`save_screenshot(self) -> None`**
-
-Saves PNG screenshot. Called automatically when 'S' key is pressed.
+```bash
+just gui         # run from source
+just test        # run tests
+just test-cov    # run tests with coverage
+just fmt         # format all code
+just lint        # lint code
+just type-check  # type check
+just build       # build PyInstaller executable
+just clean       # remove build artifacts
+just sync        # install dependencies
+just update      # update dependencies
+just             # list all commands
+```
-#### Usage Example
+## Acknowledgements
-```python
-from PyQt5.QtWidgets import QApplication
-from picostream.dfplot import HDF5LivePlotter
-
-app = QApplication([])
-plotter = HDF5LivePlotter(
- hdf5_path="my_data.hdf5",
- display_window_seconds=1.0,
- decimation_factor=100
-)
-plotter.show()
-plotter.start_updates()
-app.exec_()
-```
+This began as a fork of [JoshHarris2108/pico_streaming](https://github.com/JoshHarris2108/pico_streaming) (unlicensed). The original producer-consumer architecture and PicoSDK integration came from Josh's work.
## Changelog
-### Version 0.2.0
-- Added live-only mode with `--max-buff-sec` option
-- Added GUI application for interactive control
-- Improved plotter to handle buffer resets gracefully
-- Added total sample count tracking across buffer resets
-- Skip verification step in live-only mode for better performance
-
-### Version 0.1.0
-- Initial release with core streaming and plotting functionality
+### v1.0.0
+Complete rewrite — new architecture, new file format, dual channels, GUI-only.
-## Acknowledgements
+### v0.2.0
+CLI + HDF5 + single channel + PyQt5/PyQtGraph.
-This package began as a fork of [JoshHarris2108/pico_streaming](https://github.com/JoshHarris2108/pico_streaming) (unlicensed). Thanks to Josh for the original architecture.
+### v0.1.0
+Initial release.
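
For scripted access to recordings outside the GUI, the plain `zarr` package can open a saved store directly. A minimal sketch, assuming only what the README above states (recordings land in `~/picostream/`); the store layout — array names, attribute keys, group vs. single array — is defined by `picostream/zarr_writer.py` and not shown in this diff, and the path below is hypothetical:

```python
import os
import zarr

# Hypothetical recording path; actual filenames are created by the app.
path = os.path.expanduser("~/picostream/recording.zarr")
root = zarr.open(path, mode="r")  # returns a Group or an Array

if isinstance(root, zarr.Group):
    # Enumerate whatever arrays the writer created (names are writer-defined).
    for name, arr in root.arrays():
        print(name, arr.shape, arr.dtype, dict(arr.attrs))
else:
    print(root.shape, root.dtype, dict(root.attrs))
```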
diff --git a/assets/icons/app.ico b/assets/icons/app.ico
new file mode 100644
index 0000000..e0f4c0f
--- /dev/null
+++ b/assets/icons/app.ico
Binary files differ
diff --git a/assets/images/screenshot.png b/assets/images/screenshot.png
new file mode 100644
index 0000000..3fb0d26
--- /dev/null
+++ b/assets/images/screenshot.png
Binary files differ
diff --git a/justfile b/justfile
index 237121f..c1e1b1f 100644
--- a/justfile
+++ b/justfile
@@ -1,12 +1,96 @@
-set windows-shell := ["C:\\Program Files\\Git\\bin\\sh.exe","-c"]
+# PicoStream - Justfile
+set windows-shell := ["C:\\Program Files\\Git\\bin\\sh.exe", "-c"]
+
+# Default: show available commands
+default:
+ @just --list
# Run the picostream GUI application
gui:
PYTHONPATH=. uv run picostream/main.py
-# Build the picostream GUI executable
-build-gui:
+# Run with MOCK device (no hardware needed)
+mock:
+ PYTHONPATH=. uv run python -c "import os; os.environ['MOCK_PICO'] = '1'; from picostream.main import main; main()"
+
+# Format all code
+fmt:
+ uv run ruff format picostream/
+
+# Check code formatting without fixing
+fmt-check:
+ uv run ruff format --check picostream/
+
+# Lint code
+lint:
+ uv run ruff check picostream/
+
+# Fix linting issues
+lint-fix:
+ uv run ruff check --fix picostream/
+
+# Type check code
+type-check:
+ uv run pyright picostream/
+
+# Run all checks (format check, lint, type-check)
+check: fmt-check lint type-check
+ @echo "✓ All checks passed"
+
+# Run all checks and fix what can be fixed
+fix: lint-fix fmt
+ @echo "✓ Code fixed"
+
+# Run all dev tools (tests run even if type-check fails)
+devtools: fix fmt-check lint test type-check
+
+# Run all tests
+test:
+ PYTHONPATH=. uv run pytest picostream -v
+
+# Run tests with coverage
+test-cov:
+ PYTHONPATH=. uv run pytest picostream --cov=picostream --cov-report=term-missing -v
+
+# Run tests and generate HTML coverage report
+test-cov-html:
+ PYTHONPATH=. uv run pytest picostream --cov=picostream --cov-report=html -v
+ @echo "Coverage report: file://$(pwd)/htmlcov/index.html"
+
+# Profile the PicoStream GUI (outputs to profile.stats)
+profile:
+ PYTHONPATH=. uv run python -m cProfile -o profile.stats picostream/main.py
+
+# View profile results in browser (snakeviz)
+profile-view:
+ uv run snakeviz profile.stats
+
+# Print top 20 functions by cumulative time (quick CLI view)
+profile-text:
+ uv run python -c "import pstats; p = pstats.Stats('profile.stats'); p.sort_stats('cumulative').print_stats(20)"
+
+# Build the picostream GUI executable with PyInstaller
+build:
uv pip install pyinstaller pyinstaller-hooks-contrib
- uv run pyinstaller --clean PicoStream.spec --noconfirm
+ uv run pyinstaller \
+ --name PicoStream \
+ --onefile \
+ --windowed \
+ --collect-all vispy \
+ --icon assets/icons/app.ico \
+ picostream/main.py
+
+# Sync dependencies (install/update)
+sync:
+ uv sync
+
+# Update all dependencies
+update:
+ uv lock --upgrade
+
+# Clean build artifacts and caches
+clean:
+ rm -rf build/ dist/ picostream/*.spec __pycache__ .pytest_cache .ruff_cache
+ find picostream -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
+ find picostream -type f -name "*.pyc" -delete 2>/dev/null || true
diff --git a/new_plan.md b/new_plan.md
deleted file mode 100644
index 6c7186e..0000000
--- a/new_plan.md
+++ /dev/null
@@ -1,65 +0,0 @@
- Design
-
-The application will be a single-window GUI built with PyQt5. It will be composed of three main components: a main
-window, a refactored plotter widget, and a background worker for data acquisition.
-
- 1 PicoStreamMainWindow (QMainWindow): This will be the application's central component, serving as the main entry point.
- • Layout: It will feature a two-panel layout. The left panel will contain all user-configurable settings for the
- acquisition (e.g., sample rate, voltage range, output file). The right panel will contain the embedded live plot.
- • Control: It will have "Start" and "Stop" buttons to manage the acquisition lifecycle. It will manage the
- application's state (e.g., idle, acquiring, error).
- • Persistence: It will use QSettings to automatically save user-entered settings on exit and load them on startup.
- • Lifecycle: It will be responsible for creating and managing the background worker thread and ensuring a graceful
- shutdown.
- 2 HDF5LivePlotter (QWidget): The existing plotter will be refactored from a QMainWindow into a QWidget.
- • Responsibility: Its sole responsibility will be to monitor the HDF5 file and display the live data. It will no
- longer be a top-level window or control the application's lifecycle.
- • Integration: An instance of this widget will be created and embedded directly into the right-hand panel of the
- PicoStreamMainWindow.
- 3 StreamerWorker (QObject): This class will manage the acquisition task in a background thread to keep the GUI
- responsive.
- • Execution: It will be moved to a QThread. Its primary method will instantiate the Streamer class with parameters
- from the GUI and call the blocking Streamer.run() method.
- • Communication: It will use Qt signals to report its status (e.g., finished, error) back to the PicoStreamMainWindow
- in a thread-safe manner. The main window will connect to these signals to update the UI, for example, by
- re-enabling the "Start" button upon completion.
-
- Phased Implementation Plan
-
-This plan breaks the work into five distinct, sequential phases.
-
-Phase 1: Project Restructuring and GUI Shell The goal is to set up the new file structure and a basic, non-functional GUI
-window.
-
- 1 Rename picostream/main.py to picostream/cli.py.
- 2 Create a new, empty picostream/main.py to serve as the GUI entry point.
- 3 In the new main.py, create a PicoStreamMainWindow class with a simple layout containing placeholders for the settings
- panel and the plot.
- 4 Update the justfile with a new target to run the GUI application.
-
-Phase 2: Background Worker Implementation The goal is to run the data acquisition in a background thread, controlled by
-the GUI.
-
- 1 In picostream/main.py, create the StreamerWorker class inheriting from QObject.
- 2 Implement the QThread worker pattern in PicoStreamMainWindow to start the acquisition when a "Start" button is clicked
- and to signal a stop using the existing shutdown_event.
- 3 Connect the worker's finished and error signals to GUI methods that update the UI state (e.g., re-enable buttons).
-
-Phase 3: GUI Controls and Settings Persistence The goal is to make the acquisition configurable through the GUI and to
-remember settings.
-
- 1 Populate the settings panel in PicoStreamMainWindow with input widgets for all acquisition parameters.
- 2 Pass the values from these widgets to the StreamerWorker when starting an acquisition.
- 3 Implement load_settings and save_settings methods using QSettings.
-
-Phase 4: Plotter Integration The goal is to embed the live plot directly into the main window.
-
- 1 In picostream/dfplot.py, refactor the HDF5LivePlotter class to inherit from QWidget instead of QMainWindow. Remove its
- window-management logic.
- 2 In PicoStreamMainWindow, replace the plot placeholder with an instance of the refactored HDF5LivePlotter widget.
-
-Phase 5: Packaging The goal is to create a standalone, distributable executable.
-
- 1 Add a new build target to the justfile that uses PyInstaller to bundle the application.
- 2 Configure the build to handle dependencies, particularly creating a hook for Numba if necessary.
- 3 Test the final executable.
diff --git a/picostream/__init__.py b/picostream/__init__.py
index 1fb1a8e..ea077ae 100644
--- a/picostream/__init__.py
+++ b/picostream/__init__.py
@@ -1,7 +1,3 @@
-"""PicoStream: High-performance PicoScope data acquisition and visualization."""
+"""PicoStream - fast GUI streaming for PicoScope 5000a series."""
-from .reader import PicoStreamReader
-from .dfplot import HDF5LivePlotter
-
-__version__ = "0.2.0"
-__all__ = ["PicoStreamReader", "HDF5LivePlotter"]
+__version__ = "1.0.0"
diff --git a/picostream/acquisition_rate.py b/picostream/acquisition_rate.py
new file mode 100644
index 0000000..f9c3e18
--- /dev/null
+++ b/picostream/acquisition_rate.py
@@ -0,0 +1,287 @@
+"""Acquisition rate abstraction for consistent rate semantics.
+
+This module provides a single source of truth for all rate calculations
+throughout the picostream pipeline. Rates are represented as immutable
+objects that know the relationships between hardware rate, downsampling,
+channels, and effective rates at different stages.
+
+All time/sample conversions should go through the AcquisitionRate object
+to ensure consistency and avoid double-counting of downsampling ratios.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import TypeAlias
+
+
+class DownsampleMode(Enum):
+ """Downsampling modes supported by Picoscope hardware."""
+
+ NONE = "NONE"
+ AVERAGE = "AVERAGE"
+ AGGREGATE = "AGGREGATE"
+ DECIMATE = "DECIMATE"
+
+
+StorageSampleCount: TypeAlias = int
+"""Number of samples in ring buffer (post-hardware-downsampling).
+
+In AGGREGATE mode, these are interleaved min/max pairs.
+"""
+
+DisplayPointCount: TypeAlias = int
+"""Number of points to display on screen (post-software-decimation).
+
+Always includes min/max pairs for envelope preservation.
+"""
+
+
+@dataclass(frozen=True)
+class AcquisitionRate:
+ """Encapsulates all rate information for an acquisition.
+
+ This is the single source of truth for rate semantics throughout
+ the pipeline. All rate calculations derive from these four core facts.
+ The object is frozen to prevent accidental misuse or reinterpretation
+ of rates at different stages.
+
+ Attributes
+ ----------
+ hardware_rate_hz : float
+ Raw ADC rate from hardware. For multi-channel systems, this is the
+ total rate across all channels.
+ Example: 125 MHz for 2 channels at 62.5 MS/s per channel.
+
+ num_channels : int
+ Number of active acquisition channels.
+
+ downsample_ratio : int
+ Hardware downsampling factor. 1 means no downsampling.
+ Example: 10 means every 10th sample is kept.
+
+ downsample_mode : DownsampleMode
+ Type of downsampling applied: NONE, AVERAGE, AGGREGATE, or DECIMATE.
+ AGGREGATE produces min/max pairs (2 samples per time point).
+ """
+
+ hardware_rate_hz: float
+ num_channels: int
+ downsample_ratio: int
+ downsample_mode: DownsampleMode
+
+ def __post_init__(self) -> None:
+ """Validate rate parameters."""
+ if self.hardware_rate_hz <= 0:
+ raise ValueError(
+ f"hardware_rate_hz must be positive, got {self.hardware_rate_hz}"
+ )
+ if self.num_channels <= 0:
+ raise ValueError(f"num_channels must be positive, got {self.num_channels}")
+ if self.downsample_ratio < 1:
+ raise ValueError(
+ f"downsample_ratio must be >= 1, got {self.downsample_ratio}"
+ )
+ if not isinstance(self.downsample_mode, DownsampleMode):
+ raise ValueError(
+ f"downsample_mode must be DownsampleMode, got {type(self.downsample_mode)}"
+ )
+
+ @property
+ def per_channel_rate_hz(self) -> float:
+ """Per-channel rate after hardware downsampling.
+
+ This is the rate at which individual channels produce time points
+ after downsampling is applied.
+
+ Returns
+ -------
+ float
+ Samples per second per channel.
+
+ Example
+ -------
+ hardware_rate_hz=125e6, num_channels=2, downsample_ratio=10
+        → per_channel_rate_hz = 125e6 / (2 * 10) = 6.25e6 Hz (6.25 MS/s)
+ """
+ return self.hardware_rate_hz / (self.num_channels * self.downsample_ratio)
+
+ @property
+ def storage_rate_hz(self) -> float:
+ """Total rate at which samples arrive at ring buffer.
+
+ This is the rate to use for ring buffer capacity calculations and
+ writing to storage. It accounts for AGGREGATE mode producing 2 samples
+ per downsampled point (min/max pairs interleaved).
+
+ For non-AGGREGATE modes: storage_rate = hardware_rate / downsample_ratio
+ For AGGREGATE mode: storage_rate = (hardware_rate / downsample_ratio) * 2
+
+ Returns
+ -------
+ float
+ Total samples per second arriving at the buffer.
+
+ Example
+ -------
+ hardware_rate_hz=125e6, downsample_ratio=10, mode=DECIMATE
+ → storage_rate_hz = 125e6 / 10 = 12.5 MS/s
+
+ hardware_rate_hz=125e6, downsample_ratio=10, mode=AGGREGATE
+ → storage_rate_hz = (125e6 / 10) * 2 = 25 MS/s
+ """
+ base_rate = self.hardware_rate_hz / self.downsample_ratio
+ if (
+ self.downsample_mode == DownsampleMode.AGGREGATE
+ and self.downsample_ratio > 1
+ ):
+ return base_rate * 2
+ return base_rate
+
+ @property
+ def display_rate_per_channel_hz(self) -> float:
+ """Per-channel rate for time axis generation in display.
+
+ Used by the plotter for converting sample counts to time durations.
+ This is the "semantic" time rate for the user - how fast time flows
+ in the plot regardless of how samples are packed in the buffer.
+
+ Returns
+ -------
+ float
+ Samples per second per channel as perceived in display time.
+ """
+ return self.per_channel_rate_hz
+
+ def samples_to_seconds(self, n_samples: StorageSampleCount) -> float:
+ """Convert sample count to time duration.
+
+ Operates on per-channel sample counts (time-points), properly handling
+ AGGREGATE mode where storage includes interleaved min/max pairs.
+
+ The conversion uses per_channel_rate_hz, not storage_rate_hz, because
+ this method is designed for time-point counts (one value per channel
+ per time step), not total storage sample counts.
+
+ Parameters
+ ----------
+ n_samples : StorageSampleCount
+ Number of time-points (samples per channel). For AGGREGATE mode,
+ this counts each min/max pair as 2 samples but represents 1 time-point.
+
+ Returns
+ -------
+ float
+ Duration in seconds.
+
+ Example
+ -------
+ If per_channel_rate = 6.25 MS/s and n_samples = 6.25M:
+ - NONE mode: 6.25e6 / 6.25e6 = 1.0 second
+ - AGGREGATE mode: (6.25e6 // 2) / 6.25e6 = 0.5 second
+ (because 6.25M samples = 3.125M pairs = 0.5s of data)
+ """
+ if (
+ self.downsample_mode == DownsampleMode.AGGREGATE
+ and self.downsample_ratio > 1
+ ):
+ # Each min/max pair represents one time point
+ n_time_points = n_samples // 2
+ return n_time_points / self.display_rate_per_channel_hz
+ return n_samples / self.display_rate_per_channel_hz
+
+ def seconds_to_samples(self, duration_s: float) -> StorageSampleCount:
+ """Convert time duration to sample count.
+
+ Returns the number of storage samples (including min/max pairs for
+ AGGREGATE mode) needed to represent the given duration.
+
+        Note: This operates on per-channel rates, so the returned count is
+        per channel and intended for storage capacity calculations. In
+        AGGREGATE mode, divide the result by 2 to get the time-point count;
+        in other modes the result already equals the time-point count.
+
+ Parameters
+ ----------
+ duration_s : float
+ Duration in seconds.
+
+ Returns
+ -------
+ StorageSampleCount
+ Number of storage samples needed. For AGGREGATE mode with
+ downsampling, this is 2× the number of time-points.
+
+ Example
+ -------
+ If per_channel_rate = 6.25 MS/s and duration = 1.0s:
+ - NONE mode: 1.0 * 6.25e6 = 6.25M samples
+ - AGGREGATE mode: (1.0 * 6.25e6) * 2 = 12.5M samples
+ (because we store both min and max for each time point)
+ """
+ n_time_points = int(duration_s * self.display_rate_per_channel_hz)
+ if (
+ self.downsample_mode == DownsampleMode.AGGREGATE
+ and self.downsample_ratio > 1
+ ):
+ return n_time_points * 2
+ return n_time_points
+
+ def get_display_duration(
+ self, n_display_samples: DisplayPointCount, decimation: int
+ ) -> float:
+ """Calculate time duration represented by display samples.
+
+ DEPRECATED: This method mixes concerns by trying to reverse-engineer
+ duration from display points. Use samples_to_seconds() with storage
+ samples instead.
+
+ Accounts for both software min-max decimation (which produces pairs)
+ and the hardware downsampling that's already baked into the rate.
+
+ Parameters
+ ----------
+ n_display_samples : DisplayPointCount
+ Number of display points (after software decimation).
+ decimation : int
+ Software decimation factor used for display (minimum 1).
+
+ Returns
+ -------
+ float
+ Duration in seconds that the display data represents.
+
+ Notes
+ -----
+ This is used by the plotter to know how much time is represented
+ by the currently displayed points, independent of how many points
+ are actually being drawn on screen.
+ """
+ if decimation <= 0 or n_display_samples <= 0:
+ return 0.0
+
+ # Software min-max decimation produces pairs
+ n_time_points = n_display_samples // 2
+ original_samples = n_time_points * decimation
+
+ return original_samples / self.display_rate_per_channel_hz
+
+ def __str__(self) -> str:
+ """Human-readable description of the acquisition rate."""
+ return (
+ f"AcquisitionRate("
+ f"hw={self.hardware_rate_hz / 1e6:.1f}MS/s, "
+ f"ch={self.num_channels}, "
+ f"ds={self.downsample_ratio}x {self.downsample_mode.value}, "
+ f"→ {self.per_channel_rate_hz / 1e6:.2f}MS/s/ch, "
+ f"{self.storage_rate_hz / 1e6:.2f}MS/s storage"
+ f")"
+ )
+
+ def __repr__(self) -> str:
+ """Development-friendly representation."""
+ return (
+ f"AcquisitionRate("
+ f"hardware_rate_hz={self.hardware_rate_hz}, "
+ f"num_channels={self.num_channels}, "
+ f"downsample_ratio={self.downsample_ratio}, "
+ f"downsample_mode={self.downsample_mode})"
+ )
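
To make the new rate semantics concrete, here is a short usage sketch of the class exactly as added above; the numbers mirror the docstring examples (two channels at 62.5 MS/s each with 10× AGGREGATE downsampling):

```python
from picostream.acquisition_rate import AcquisitionRate, DownsampleMode

# 2 channels at 62.5 MS/s each -> 125 MS/s total hardware rate, with
# hardware AGGREGATE downsampling by 10 (interleaved min/max pairs).
rate = AcquisitionRate(
    hardware_rate_hz=125e6,
    num_channels=2,
    downsample_ratio=10,
    downsample_mode=DownsampleMode.AGGREGATE,
)

print(rate.per_channel_rate_hz)             # 125e6 / (2 * 10) = 6.25e6 Hz
print(rate.storage_rate_hz)                 # (125e6 / 10) * 2 = 25e6 Hz
print(rate.seconds_to_samples(1.0))         # 12_500_000 storage samples
print(rate.samples_to_seconds(12_500_000))  # 1.0 (min/max pairs halved)
```

Because the dataclass is frozen and validates its inputs, a single instance can be passed through the pipeline without risk of the downsampling ratio being re-applied at a later stage.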
diff --git a/picostream/cli.py b/picostream/cli.py
deleted file mode 100644
index d8e7ffb..0000000
--- a/picostream/cli.py
+++ /dev/null
@@ -1,699 +0,0 @@
-from __future__ import annotations
-
-import queue
-import signal
-import sys
-import threading
-import time
-from datetime import datetime
-from typing import TYPE_CHECKING, List, Optional
-
-if TYPE_CHECKING:
- from PyQt5.QtWidgets import QApplication
-
-import click
-import h5py
-import numpy as np
-from loguru import logger
-
-from . import __version__
-from .consumer import Consumer
-from .pico import PicoDevice
-
-
-class Streamer:
- """Orchestrates the Picoscope data acquisition process.
-
- This class initializes the Picoscope device (producer), the HDF5 writer
- (consumer), and the live plotter. It manages the threads, queues, and
- graceful shutdown of the entire application.
-
- Supports both standard acquisition mode (saves all data) and live-only mode
- (limits buffer size using max_buffer_seconds parameter).
- """
-
- def __init__(
- self,
- sample_rate_msps: float = 62.5,
- resolution_bits: int = 12,
- channel_range_str: str = "PS5000A_20V",
- enable_live_plot: bool = False,
- output_file: str = "./output.hdf5",
- debug: bool = False,
- plot_window_s: float = 0.5,
- plot_points: int = 4000,
- hardware_downsample: int = 1,
- downsample_mode: str = "average",
- offset_v: float = 0.0,
- max_buffer_seconds: Optional[float] = None,
- is_gui_mode: bool = False,
- y_min: Optional[float] = None,
- y_max: Optional[float] = None,
- bandwidth_limiter: str = "full",
- ) -> None:
- # --- Configuration ---
- self.output_file = output_file
- self.debug = debug
- self.enable_live_plot = enable_live_plot
- self.is_gui_mode = is_gui_mode
- self.plot_window_s = plot_window_s
- self.max_buffer_seconds = max_buffer_seconds
- self.y_min = y_min
- self.y_max = y_max
-
- (
- sample_rate_msps,
- pico_downsample_ratio,
- pico_ratio_mode,
- offset_v,
- ) = self._validate_config(
- resolution_bits,
- sample_rate_msps,
- channel_range_str,
- hardware_downsample,
- downsample_mode,
- offset_v,
- )
- # Dynamically size buffers to hold a specific duration of data. This makes
- # memory usage proportional to the data rate, providing a consistent
- # time-based buffer to handle processing latencies.
- effective_rate_sps = (sample_rate_msps * 1e6) / pico_downsample_ratio
-
- # Consumer buffers (for writing to HDF5) are sized to hold 1 second of data.
- # This is a good balance, as larger buffers lead to more efficient disk writes
- # but use more RAM.
- consumer_buffer_duration_s = 1.0
- self.consumer_buffer_size = int(
- effective_rate_sps * consumer_buffer_duration_s
- )
- if downsample_mode == "aggregate":
- self.consumer_buffer_size *= 2
- self.consumer_num_buffers = 5 # A pool of 5 buffers
-
- # The Picoscope driver buffer is sized to hold 0.5 seconds of data. This
- # buffer receives data directly from the hardware. A smaller size ensures
- # that the application receives data in timely chunks, reducing latency.
- driver_buffer_duration_s = 0.5
- self.pico_driver_buffer_size = int(
- effective_rate_sps * driver_buffer_duration_s
- )
- self.pico_driver_num_buffers = (
- 1 # A single large buffer is efficient for the driver
- )
-
- logger.info(
- f"Consumer buffer sized to {self.consumer_buffer_size:,} samples "
- f"({consumer_buffer_duration_s}s at effective rate)"
- )
- logger.info(
- f"Pico driver buffer sized to {self.pico_driver_buffer_size:,} samples "
- f"({driver_buffer_duration_s}s at effective rate)"
- )
-
- # --- Plotting Decimation ---
- # Calculate the decimation factor needed to achieve the target number of plot points.
- points_per_timestep = 2 if downsample_mode == "aggregate" else 1
- samples_in_window = effective_rate_sps * plot_window_s * points_per_timestep
- self.decimation_factor = max(1, int(samples_in_window / plot_points))
- logger.info(
- f"Plotting with target of {plot_points} points. "
- f"Calculated decimation factor: {self.decimation_factor}"
- )
-
- # Picoscope hardware settings
- self.pico_resolution = f"PS5000A_DR_{resolution_bits}BIT"
- self.pico_channel_range = channel_range_str
- self.pico_sample_interval_ns = int(1000 / sample_rate_msps)
- self.pico_sample_unit = "PS5000A_NS"
-
- # Streaming settings
- self.pico_auto_stop = 0 # Don't auto stop
- self.pico_auto_stop_stream = False
- # --- End Configuration ---
-
- # --- System Components ---
- self.shutdown_event: threading.Event = threading.Event()
- data_queue: queue.Queue[int] = queue.Queue()
- empty_queue: queue.Queue[int] = queue.Queue()
- data_buffers: List[np.ndarray] = []
-
- # Pre-allocate a pool of numpy arrays for data transfer and populate the
- # empty_queue with their indices.
- for idx in range(self.consumer_num_buffers):
- data_buffers.append(np.empty((self.consumer_buffer_size,), dtype="int16"))
- empty_queue.put(idx)
-
- # --- Producer ---
- self.pico_device: PicoDevice = PicoDevice(
- 0, # handle
- self.pico_resolution,
- self.pico_driver_buffer_size,
- self.pico_driver_num_buffers,
- self.consumer_buffer_size,
- data_queue,
- empty_queue,
- data_buffers,
- self.shutdown_event,
- downsample_mode=downsample_mode,
- )
-
- # Store bandwidth limiter setting
- self.bandwidth_limiter = bandwidth_limiter
-
- self.pico_device.set_channel(
- "PS5000A_CHANNEL_A", 1, "PS5000A_DC", self.pico_channel_range, offset_v
- )
- self.pico_device.set_channel(
- "PS5000A_CHANNEL_B", 0, "PS5000A_DC", self.pico_channel_range, 0.0
- )
-
- # Set bandwidth filter for Channel A if specified
- if self.bandwidth_limiter == "20MHz":
- self.pico_device.set_bandwidth_filter("PS5000A_CHANNEL_A", "PS5000A_BW_20MHZ")
- else:
- self.pico_device.set_bandwidth_filter("PS5000A_CHANNEL_A", "PS5000A_BW_FULL")
- self.pico_device.set_data_buffer("PS5000A_CHANNEL_A", 0, pico_ratio_mode)
- self.pico_device.configure_streaming_var(
- self.pico_sample_interval_ns,
- self.pico_sample_unit,
- 0, # pre-trigger samples
- pico_downsample_ratio,
- pico_ratio_mode,
- self.pico_auto_stop,
- self.pico_auto_stop_stream,
- )
-
- # Run streaming once to get the actual sample interval from the driver
- self.pico_device.run_streaming()
-
- # --- Consumer ---
- # Prepare metadata for the consumer
- acquisition_start_time_utc = datetime.utcnow().isoformat() + "Z"
- was_live_mode = self.max_buffer_seconds is not None
-
- # Get metadata from configured device and pass to consumer
- metadata = self.pico_device.get_metadata(
- acquisition_start_time_utc=acquisition_start_time_utc,
- picostream_version=__version__,
- acquisition_command="", # Will be set later in main()
- was_live_mode=was_live_mode
- )
-
- # Calculate max samples for live-only mode
- max_samples = None
- if self.max_buffer_seconds:
- max_samples = int(effective_rate_sps * self.max_buffer_seconds)
- if downsample_mode == "aggregate":
- max_samples *= 2
- logger.info(f"Live-only mode: limiting buffer to {self.max_buffer_seconds}s ({max_samples:,} samples)")
-
- self.consumer: Consumer = Consumer(
- self.consumer_buffer_size,
- data_queue,
- empty_queue,
- data_buffers,
- output_file,
- self.shutdown_event,
- metadata=metadata,
- max_samples=max_samples,
- )
-
- # --- Threads ---
- self.consumer_thread: threading.Thread = threading.Thread(
- target=self.consumer.consume
- )
- self.pico_thread: threading.Thread = threading.Thread(
- target=self.pico_device.run_capture
- )
-
- # --- Signal Handling ---
- # Only set the signal handler if not in GUI mode.
- # In GUI mode, the main window handles shutdown signals.
- # In CLI plot mode, the main() function handles SIGINT to quit the Qt app.
- if not self.is_gui_mode and not self.enable_live_plot:
- signal.signal(signal.SIGINT, self.signal_handler)
-
- # --- Live Plotting (optional) ---
- self.start_time: Optional[float] = None
-
- def update_acquisition_command(self, command: str) -> None:
- """Update the acquisition command in the consumer's metadata."""
- self.consumer.metadata["acquisition_command"] = command
-
- def _validate_config(
- self,
- resolution_bits: int,
- sample_rate_msps: float,
- channel_range_str: str,
- hardware_downsample: int,
- downsample_mode: str,
- offset_v: float,
- ) -> tuple[float, int, str, float]:
- """Validates user-provided settings and returns derived configuration."""
- if resolution_bits == 8:
- max_rate_msps = 125.0
- elif resolution_bits in [12, 14, 15, 16]:
- max_rate_msps = 62.5
- else:
- raise ValueError(
- f"Unsupported resolution: {resolution_bits} bits. Must be one of 8, 12, 14, 15, 16."
- )
-
- if sample_rate_msps <= 0:
- sample_rate_msps = max_rate_msps
- logger.info(f"Max sample rate requested. Setting to {max_rate_msps} MS/s.")
-
- if sample_rate_msps > max_rate_msps:
- raise ValueError(
- f"Sample rate {sample_rate_msps} MS/s exceeds maximum of {max_rate_msps} MS/s for {resolution_bits}-bit resolution."
- )
-
- # Check if sample rate is excessive for the analog bandwidth.
- # Bandwidth is dependent on both resolution and voltage range.
- # (Based on PicoScope 5000A/B Series datasheet)
- inv_voltage_map = {v: k for k, v in VOLTAGE_RANGE_MAP.items()}
- voltage_v = inv_voltage_map.get(channel_range_str, 0)
-
- if resolution_bits == 16:
- bandwidth_mhz = 20 # 20 MHz for all ranges
- elif resolution_bits == 15:
- # Bandwidth is 70MHz for < ±5V, 60MHz for >= ±5V
- bandwidth_mhz = 70 if voltage_v < 5.0 else 60
- else: # 8-14 bits
- # Bandwidth is 100MHz for < ±5V, 60MHz for >= ±5V
- bandwidth_mhz = 100 if voltage_v < 5.0 else 60
-
- # Nyquist rate is 2x bandwidth. A common rule of thumb is 3-5x.
- # Warn if sampling faster than 5x the analog bandwidth.
- if sample_rate_msps > 5 * bandwidth_mhz:
- logger.warning(
- f"Sample rate ({sample_rate_msps} MS/s) may be unnecessarily high "
- f"for the selected voltage range ({channel_range_str}), which has an "
- f"analog bandwidth of {bandwidth_mhz} MHz."
- )
-
- if downsample_mode == "aggregate" and hardware_downsample <= 1:
- raise ValueError(
- "Hardware downsample ratio must be > 1 for 'aggregate' mode."
- )
-
- if hardware_downsample > 1:
- pico_downsample_ratio = hardware_downsample
- pico_ratio_mode = f"PS5000A_RATIO_MODE_{downsample_mode.upper()}"
- logger.info(
- f"Hardware down-sampling ({downsample_mode}) enabled "
- + f"with ratio {pico_downsample_ratio}."
- )
- else:
- pico_downsample_ratio = 1
- pico_ratio_mode = "PS5000A_RATIO_MODE_NONE"
-
- # Validate analog offset
- if offset_v != 0.0:
- if voltage_v >= 5.0:
- raise ValueError(
- f"Analog offset is not supported for voltage ranges >= 5V (selected: {channel_range_str})."
- )
- if abs(offset_v) > voltage_v:
- raise ValueError(
- f"Analog offset ({offset_v}V) exceeds the selected voltage range (±{voltage_v}V)."
- )
- logger.info(f"Analog offset set to {offset_v:.3f}V.")
-
- return sample_rate_msps, pico_downsample_ratio, pico_ratio_mode, offset_v
-
- def signal_handler(self, _sig: int, frame: Optional[object]) -> None:
- """Handles Ctrl+C interrupts to initiate a graceful shutdown."""
- logger.warning("Ctrl+C detected. Shutting down.")
- self.shutdown()
-
- def shutdown(self) -> None:
- """Performs a graceful shutdown of all components.
-
- This method calculates final statistics, stops all threads, closes the
- plotter, and ensures the Picoscope device is properly closed.
- """
- if self.shutdown_event.is_set():
- return
-
- self._log_acquisition_summary()
- self.shutdown_event.set()
-
- logger.info("Stopping data acquisition and saving...")
-
- self.pico_device.close_device()
- self._join_threads()
-
- logger.success("Shutdown complete.")
-
- def _log_acquisition_summary(self) -> None:
- """Calculates and logs final acquisition statistics."""
- if not self.start_time:
- return
-
- end_time = time.time()
- duration = end_time - self.start_time
- total_samples = self.consumer.values_written
- effective_rate_msps = (total_samples / duration) / 1e6 if duration > 0 else 0
- configured_rate_msps = 1e3 / self.pico_device.sample_int.value
-
- logger.info("--- Acquisition Summary ---")
- logger.info(f"Total acquisition time: {duration:.2f} s")
- logger.info(
- "Total samples written: "
- + f"{self.consumer.format_sample_count(total_samples)}"
- )
- logger.info(f"Configured sample rate: {configured_rate_msps:.2f} MS/s")
- logger.info(f"Effective average rate: {effective_rate_msps:.2f} MS/s")
-
- rate_ratio = (
- effective_rate_msps / configured_rate_msps
- if configured_rate_msps > 0
- else 0
- )
- if rate_ratio < 0.95:
- logger.warning(
- f"Effective rate was only {rate_ratio:.1%} " + "of the configured rate."
- )
- else:
- logger.success("Effective rate matches configured rate.")
- logger.info("--------------------------")
-
- def _join_threads(self) -> None:
- """Waits for the producer and consumer threads to terminate."""
- for thread_name in ["pico_thread", "consumer_thread"]:
- thread = getattr(self, thread_name, None)
- if thread and thread.is_alive():
- logger.info(f"Waiting for {thread_name} to terminate...")
- thread.join(timeout=2.0)
- if thread.is_alive():
- logger.critical(f"{thread_name} failed to terminate.")
-
- def run(self, app: Optional[QApplication] = None) -> None:
- """Starts the acquisition threads and optionally the Qt event loop."""
- # Start acquisition threads
- self.start_time = time.time()
- self.consumer_thread.start()
- self.pico_thread.start()
-
- # Handle Qt event loop if plotting is enabled
- if self.enable_live_plot and app:
- from .dfplot import HDF5LivePlotter
-
- plotter = HDF5LivePlotter(
- hdf5_path=self.output_file,
- display_window_seconds=self.plot_window_s,
- decimation_factor=self.decimation_factor,
- y_min=self.y_min,
- y_max=self.y_max,
- )
- plotter.show()
-
- # Run the Qt event loop. This will block until the plot window is closed.
- app.exec_()
-
- # Once the window is closed, the shutdown event should have been set.
- # We call shutdown() to ensure threads are joined and cleanup happens.
- self.shutdown()
- else:
- # In GUI mode, run() returns immediately, allowing the worker's event
- # loop to process signals. In CLI mode, we block until completion.
- if not self.is_gui_mode:
- self.consumer_thread.join()
- self.pico_thread.join()
- logger.success("Acquisition complete!")
-
-
-# --- Argument Parsing ---
-VOLTAGE_RANGE_MAP = {
- 0.01: "PS5000A_10MV",
- 0.02: "PS5000A_20MV",
- 0.05: "PS5000A_50MV",
- 0.1: "PS5000A_100MV",
- 0.2: "PS5000A_200MV",
- 0.5: "PS5000A_500MV",
- 1: "PS5000A_1V",
- 2: "PS5000A_2V",
- 5: "PS5000A_5V",
- 10: "PS5000A_10V",
- 20: "PS5000A_20V",
-}
-
-
-def generate_unique_filename(base_path: str) -> str:
- """Generate a unique filename by appending timestamp if file exists."""
- import os
- from pathlib import Path
-
- if not os.path.exists(base_path):
- return base_path
-
- # Split the path into parts
- path = Path(base_path)
- stem = path.stem
- suffix = path.suffix
- parent = path.parent
-
- # Generate timestamp
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-
- # Create new filename with timestamp
- new_filename = f"{stem}_{timestamp}{suffix}"
- new_path = parent / new_filename
-
- return str(new_path)
-
-
-@click.command()
-@click.option(
- "--sample-rate",
- "-s",
- type=float,
- default=20,
- help="Sample rate in MS/s (e.g., 62.5). Use 0 for max rate. [default: 20]",
-)
-@click.option(
- "--resolution",
- "-b",
- type=click.Choice(["8", "12", "16"]),
- default="12",
- help="Resolution in bits. [default: 12]",
-)
-@click.option(
- "--rangev",
- "-r",
- type=click.Choice([str(k) for k in sorted(VOLTAGE_RANGE_MAP.keys())]),
- default="20",
- help=f"Voltage range in Volts. [default: 20]",
-)
-@click.option(
- "--plot/--no-plot",
- "-p",
- is_flag=True,
- default=True,
- help="Enable/disable live plotting. [default: --plot]",
-)
-@click.option(
- "--output",
- "-o",
- type=click.Path(dir_okay=False, writable=True),
- help="Output HDF5 file (default: auto-timestamped).",
-)
-@click.option(
- "--plot-window",
- "-w",
- type=float,
- default=0.5,
- help="Live plot display window duration in seconds. [default: 0.5]",
-)
-@click.option(
- "--verbose", "-v", is_flag=True, default=False, help="Enable debug logging."
-)
-@click.option(
- "--plot-npts",
- "-n",
- type=int,
- default=4000,
- help="Target number of points for the plot window. [default: 4000]",
-)
-@click.option(
- "--hardware-downsample",
- "-h",
- type=int,
- default=1,
- help="Hardware down-sampling ratio (power of 2 for 'average' mode). [default: 1]",
-)
-@click.option(
- "--downsample-mode",
- "-m",
- type=click.Choice(["average", "aggregate"]),
- default="average",
- help="Hardware down-sampling mode. [default: average]",
-)
-@click.option(
- "--offset",
- type=float,
- default=0.0,
- help="Analog offset in Volts (only for ranges < 5V). [default: 0.0]",
-)
-@click.option(
- "--bandwidth",
- type=click.Choice(["full", "20MHz"]),
- default="full",
- help="Bandwidth limiter to reduce noise. [default: full]",
-)
-@click.option(
- "--max-buff-sec",
- type=float,
- help="Maximum buffer duration in seconds for live-only mode (limits file size).",
-)
-@click.option(
- "--force",
- "-f",
- is_flag=True,
- default=False,
- help="Overwrite existing output file.",
-)
-@click.option(
- "--y-min",
- type=float,
- help="Minimum Y-axis limit in mV for live plot.",
-)
-@click.option(
- "--y-max",
- type=float,
- help="Maximum Y-axis limit in mV for live plot.",
-)
-def main(
- sample_rate: float,
- resolution: str,
- rangev: str,
- plot: bool,
- output: Optional[str],
- plot_window: float,
- verbose: bool,
- plot_npts: int,
- hardware_downsample: int,
- downsample_mode: str,
- offset: float,
- bandwidth: str,
- max_buff_sec: Optional[float],
- force: bool,
- y_min: Optional[float],
- y_max: Optional[float],
-) -> None:
- """High-speed data acquisition tool for Picoscope 5000a series."""
- # --- Argument Validation and Processing ---
-
- # Validate Y-axis limits
- if (y_min is None) != (y_max is None):
- logger.error("Both --y-min and --y-max must be provided together, or neither.")
- sys.exit(1)
-
- if y_min is not None and y_max is not None and y_min >= y_max:
- logger.error(f"Invalid Y-axis range: y_min ({y_min}) must be less than y_max ({y_max}).")
- sys.exit(1)
-
- channel_range_str = VOLTAGE_RANGE_MAP[float(rangev)]
- resolution_bits = int(resolution)
-
- app: Optional[QApplication] = None
- if plot:
- from PyQt5.QtWidgets import QApplication
-
- app = QApplication(sys.argv)
-
- # When plotting, SIGINT should gracefully close the Qt application.
- # The main loop will then handle the shutdown.
- def sigint_handler(_sig: int, _frame: Optional[object]) -> None:
- logger.warning("Ctrl+C detected. Closing application.")
- QApplication.quit()
-
- signal.signal(signal.SIGINT, sigint_handler)
-
- # Configure logging
- logger.remove()
- log_level = "DEBUG" if verbose else "INFO"
- logger.add(sys.stderr, level=log_level)
- logger.info(f"Logging configured at level: {log_level}")
-
- # Auto-generate filename if not specified
- if not output:
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- output = f"./output_{timestamp}.hdf5"
- else:
- # Check if file exists and handle accordingly
- if not force:
- original_output = output
- output = generate_unique_filename(output)
- if output != original_output:
- logger.info(f"File '{original_output}' exists. Using '{output}' instead.")
- logger.info("Use --force/-f to overwrite existing files.")
-
- logger.info(f"Output file: {output}")
- logger.info(f"Selected voltage range: {rangev}V -> {channel_range_str}")
-
- try:
- # Create and run the streamer
- streamer = Streamer(
- sample_rate_msps=sample_rate,
- resolution_bits=resolution_bits,
- channel_range_str=channel_range_str,
- enable_live_plot=plot,
- output_file=output,
- debug=verbose,
- plot_window_s=plot_window,
- plot_points=plot_npts,
- hardware_downsample=hardware_downsample,
- downsample_mode=downsample_mode,
- offset_v=offset,
- max_buffer_seconds=max_buff_sec,
- y_min=y_min,
- y_max=y_max,
- bandwidth_limiter=bandwidth,
- )
-
- # Update the acquisition command in metadata
- acquisition_command = " ".join(sys.argv)
- streamer.update_acquisition_command(acquisition_command)
-
- streamer.run(app)
- except RuntimeError as e:
- if "PICO_NOT_FOUND" in str(e):
- logger.critical(
- "Picoscope device not found. Please check connection and ensure no other software is using it."
- )
- else:
- logger.critical(f"Failed to initialize Picoscope: {e}")
- sys.exit(1)
-
- # --- Verification Step ---
- # Skip verification in live-only mode since file size is limited
- if not streamer.max_buffer_seconds:
- logger.info(f"Verifying output file: {output}")
- try:
- expected_samples = streamer.consumer.values_written
- if expected_samples == 0:
- logger.warning("Consumer processed no samples. Nothing to verify.")
- else:
- with h5py.File(output, "r") as f:
- if "adc_counts" not in f:
- raise ValueError("Dataset 'adc_counts' not found in HDF5 file.")
-
- actual_samples = len(f["adc_counts"])
- if actual_samples == expected_samples:
- logger.success(
- f"Verification PASSED: File contains {actual_samples} samples, as expected."
- )
- else:
- logger.error(
- f"Verification FAILED: Expected {expected_samples} samples, but file has {actual_samples}."
- )
- except Exception as e:
- logger.error(f"HDF5 file verification failed: {e}")
- else:
- logger.info("Skipping verification in live-only mode.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/picostream/consumer.py b/picostream/consumer.py
deleted file mode 100644
index f48ff30..0000000
--- a/picostream/consumer.py
+++ /dev/null
@@ -1,160 +0,0 @@
-from __future__ import annotations
-
-import os
-import queue
-import threading
-from typing import Any, Dict, List, Optional
-
-import h5py
-import numpy as np
-from loguru import logger
-
-
-class Consumer:
- """A data consumer that runs in a separate thread.
-
- This class retrieves data buffers from a queue, writes them to an HDF5 file,
- and then returns the buffer index to an "empty" queue for reuse by the
- producer. It handles file creation, data writing, and metadata storage.
-
- Supports live-only mode when max_samples is specified, which limits the
- HDF5 file size by resetting and overwriting data when the limit is reached.
- """
-
- def __init__(
- self,
- buffer_size: int,
- data_queue: queue.Queue[int],
- empty_queue: queue.Queue[int],
- data_buffers: List[np.ndarray],
- file_name: str,
- shutdown_event: threading.Event,
- metadata: Dict[str, Any],
- max_samples: Optional[int] = None,
- ):
- """Initializes the Consumer.
-
- Args:
- buffer_size: The size of each individual data buffer.
- data_queue: A queue for receiving indices of data-filled buffers.
- empty_queue: A queue for returning indices of processed (empty) buffers.
- data_buffers: A list of pre-allocated NumPy arrays for data.
- file_name: The path to the output HDF5 file.
- shutdown_event: A threading.Event to signal termination.
- metadata: A dictionary of metadata to be saved as HDF5 attributes.
- max_samples: Maximum number of samples to keep in file (for live-only mode).
- """
- self.buffer_size = buffer_size
- self.data_queue = data_queue
- self.empty_queue = empty_queue
- self.data_buffers = data_buffers
- self.file_name = file_name
- self.shutdown_event = shutdown_event
- self.metadata = metadata
- self.max_samples = max_samples
-
- self.values_written: int = 0
- self.empty_con_queue_count: int = 0
- self.buffer_resets: int = 0
-
- def format_sample_count(self, count: int) -> str:
- """Format a large integer count into a human-readable string.
-
- Uses metric prefixes (K, M, G) for thousands, millions, and billions.
-
- Args:
- count: The integer number to format.
-
- Returns:
- A formatted string representation of the count.
- """
- if count >= 1_000_000_000:
- return f"{count / 1_000_000_000:.2f}G"
- if count >= 1_000_000:
- return f"{count / 1_000_000:.2f}M"
- if count >= 1_000:
- return f"{count / 1_000:.2f}K"
- else:
- return str(count)
-
- def _processing_loop(self, dset: h5py.Dataset) -> None:
- """
- Continuously processes data from the queue and writes to the HDF5 dataset.
-
- Args:
- dset: The HDF5 dataset to write to.
- """
- while not self.shutdown_event.is_set():
- try:
- # Wait for a buffer index from the producer.
- # A timeout allows the loop to periodically check the shutdown event.
- idx = self.data_queue.get(timeout=0.1)
-
- # Append the new data to the HDF5 dataset.
- buffer_len = len(self.data_buffers[idx])
-
- # Check if we need to reset the buffer (live-only mode)
- if self.max_samples and self.values_written + buffer_len > self.max_samples:
- # Reset the dataset to start overwriting from the beginning
- dset.resize((buffer_len,))
- dset[:] = self.data_buffers[idx]
- self.values_written = buffer_len
- self.buffer_resets += 1
- logger.debug(f"Buffer reset #{self.buffer_resets} - file size limited to {self.max_samples:,} samples")
- else:
- # Normal append
- dset.resize((self.values_written + buffer_len,))
- dset[self.values_written :] = self.data_buffers[idx]
- self.values_written += buffer_len
-
- # Return the buffer index to the empty queue for reuse.
- self.empty_queue.put(idx)
-
- except queue.Empty:
- # This occurs if the producer hasn't provided data within the timeout.
- self.empty_con_queue_count += 1
- # This is expected when acquisition stops, so no need to log as a warning
- if not self.shutdown_event.is_set():
- logger.debug("Consumer queue was empty.")
-
- def consume(self) -> None:
- """The main loop for the consumer thread.
-
- This method continuously checks for data from the producer, writes it to
- the HDF5 file, and returns the buffer for reuse. It handles file setup,
- the main processing loop, and graceful shutdown.
- """
- try:
- # Ensure a clean slate by removing any pre-existing file.
- if os.path.exists(self.file_name):
- os.remove(self.file_name)
- logger.info(f"Removed existing file: {self.file_name}")
-
- with h5py.File(self.file_name, "w") as f:
- # Write the collected metadata to the HDF5 file's attributes.
- for key, value in self.metadata.items():
- if value is not None:
- f.attrs[key] = value
-
- # Create a resizable dataset for the ADC data.
- # Chunking is aligned with the buffer size for efficient writes.
- dset = f.create_dataset(
- "adc_counts",
- (0,),
- maxshape=(None,),
- dtype="int16",
- chunks=(self.buffer_size,),
- )
-
- self._processing_loop(dset)
- except (IOError, OSError) as e:
- # A critical file error means we cannot continue.
- logger.critical(f"Failed to create or write to HDF5 file: {e}")
- self.shutdown_event.set() # Signal other threads to shut down.
- return
-
- logger.info(
- f"Consumer couldn't obtain data from queue {self.empty_con_queue_count} times."
- )
- if self.buffer_resets > 0:
- logger.info(f"Buffer was reset {self.buffer_resets} times (live-only mode).")
diff --git a/picostream/data_pipeline.py b/picostream/data_pipeline.py
new file mode 100644
index 0000000..c5587dd
--- /dev/null
+++ b/picostream/data_pipeline.py
@@ -0,0 +1,622 @@
+"""Data pipeline for processing PicoScope streaming data.
+
+This module provides a unified data processing pipeline that handles:
+- Sample rate calculations for different downsample modes
+- Data decimation for display optimization
+- ADC to millivolt conversion
+- Time axis generation with proper accounting for decimation and modes
+- Lag metrics between producer and consumer
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, NamedTuple, Optional, Tuple
+
+import numpy as np
+from loguru import logger
+
+from picostream.acquisition_rate import (
+ AcquisitionRate,
+ DisplayPointCount,
+ DownsampleMode,
+ StorageSampleCount,
+)
+from picostream.conversion_utils import adc_to_mV, min_max_decimate_numba
+
+
+class LagMetrics(NamedTuple):
+ """Metrics for producer-consumer lag.
+
+ Attributes
+ ----------
+ samples : int
+ Number of samples of lag.
+ milliseconds : float
+ Lag duration in milliseconds.
+ """
+
+ samples: int
+ milliseconds: float
+
+
+@dataclass(frozen=True)
+class DisplayGeometry:
+ """Complete display geometry for a window, computed in one place.
+
+ This is the single source of truth for how many points to cache,
+ how long the time axis is, and whether min/max pairs are used.
+ Both the cache sizing and the time axis generation must use values
+ from the same DisplayGeometry instance to stay in sync.
+
+ Attributes
+ ----------
+ time_axis : np.ndarray
+ Time values in seconds, right-aligned within the window.
+ Length equals ``max_display_values``.
+ max_display_values : int
+ Total number of display values to keep in the cache per channel.
+ Equals ``n_time_points * values_per_point``.
+ n_time_points : int
+ Number of unique time points (before any min/max doubling).
+ values_per_point : int
+ 1 for raw samples, 2 for min/max pairs.
+ decimation : int
+ The software decimation factor used to compute this geometry.
+ window_samples : int
+ Number of storage samples that fit in the window duration.
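+
+    Notes
+    -----
+    Invariant: ``len(time_axis) == max_display_values ==
+    n_time_points * values_per_point``.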
+ """
+
+ time_axis: np.ndarray
+ max_display_values: int
+ n_time_points: int
+ values_per_point: int
+ decimation: int
+ window_samples: int
+
+
+@dataclass(frozen=True)
+class PipelineConfig:
+ """Immutable configuration set at acquisition start.
+
+ This configuration is frozen at acquisition start to avoid race
+ conditions and ensure consistent processing throughout acquisition.
+
+ Attributes
+ ----------
+ resolution : int
+ ADC resolution in bits.
+ voltage_ranges : Dict[int, float]
+ Mapping of channel index to voltage range in Volts.
+ offsets_v : Dict[int, float]
+ Mapping of channel index to DC offset in Volts.
+ max_adc_value : int
+ Maximum ADC count value for the configured resolution.
+ target_plot_points : Optional[int]
+ Target number of points to display on screen, or None for lossless mode.
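+
+    Examples
+    --------
+    A minimal sketch of a single-channel configuration (all values are
+    illustrative, not defaults):
+
+    >>> config = PipelineConfig(
+    ...     resolution=12,
+    ...     voltage_ranges={0: 1.0},
+    ...     offsets_v={0: 0.0},
+    ...     max_adc_value=32767,
+    ...     target_plot_points=100_000,
+    ... )
+    >>> config.target_plot_points
+    100000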
+ """
+
+ resolution: int
+ voltage_ranges: Dict[int, float]
+ offsets_v: Dict[int, float]
+ max_adc_value: int
+ target_plot_points: Optional[int]
+
+
+@dataclass
+class ProcessedChunk:
+ """Result of processing a chunk of data.
+
+ Attributes
+ ----------
+ time_axis : np.ndarray
+ Time axis in seconds (relative to window end).
+ voltage_data : Dict[int, np.ndarray]
+ Processed voltage data per channel in mV.
+ decimation : int
+ Decimation factor used for this chunk.
+ """
+
+ time_axis: np.ndarray
+ voltage_data: Dict[int, np.ndarray]
+ decimation: int
+
+
+class DataPipeline:
+ """Consumer-side data processing: decimation, conversion, time axis.
+
+ This class owns the complete consumer-side data processing pipeline,
+ consolidating previously scattered logic for decimation, ADC conversion,
+ and time axis generation into a single source of truth.
+
+ All rate calculations are delegated to the AcquisitionRate object,
+ ensuring consistent semantics throughout the system.
+
+ Parameters
+ ----------
+ config : PipelineConfig
+ Immutable configuration for conversion and decimation.
+ acquisition_rate : AcquisitionRate
+ Rate information with all downsampling and channel details.
+
+ Attributes
+ ----------
+ config : PipelineConfig
+ The pipeline configuration.
+ acquisition_rate : AcquisitionRate
+ The acquisition rate for consistent rate calculations.
+ _remainder : Dict[int, np.ndarray]
+ Leftover samples from previous decimation (per channel).
+ _last_decimation : int
+ Last decimation factor used (for detecting changes).
+ """
+
+ def __init__(self, config: PipelineConfig, acquisition_rate: AcquisitionRate):
+ self.config = config
+ self.acquisition_rate = acquisition_rate
+
+ # Processing state (owned here, not scattered)
+ self._remainder: Dict[int, np.ndarray] = {}
+ self._last_decimation: int = 1
+
+ logger.debug(
+ "DataPipeline initialised: resolution={} bits, "
+ "channels={}, mode={}, ratio={} → {:.2f} MS/s per channel",
+ config.resolution,
+ acquisition_rate.num_channels,
+ acquisition_rate.downsample_mode.value,
+ acquisition_rate.downsample_ratio,
+ acquisition_rate.per_channel_rate_hz / 1e6,
+ )
+
+ def calculate_decimation(self, window_samples: int) -> int:
+ """Calculate decimation factor for display optimization.
+
+ Parameters
+ ----------
+ window_samples : int
+ Number of samples in the display window.
+
+ Returns
+ -------
+ int
+ Decimation factor (minimum 1).
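+
+        Examples
+        --------
+        Illustrative arithmetic, assuming ``target_plot_points=100_000``:
+        a 10 M-sample window decimates by ``10_000_000 // 100_000 = 100``,
+        while a window smaller than the target (e.g. 50 k samples) is left
+        undecimated, since ``max(1, 50_000 // 100_000) == 1``.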
+ """
+ # Lossless mode: no decimation
+ if self.config.target_plot_points is None:
+ return 1
+ return max(1, window_samples // self.config.target_plot_points)
+
+ def should_invalidate_cache(self, decimation: int) -> bool:
+ """Check if cache should be invalidated due to decimation change.
+
+ Parameters
+ ----------
+ decimation : int
+ New decimation factor.
+
+ Returns
+ -------
+ bool
+ True if decimation has changed and cache should be reset.
+ """
+ return decimation != self._last_decimation
+
+    def invalidate_cache(self, new_decimation: int = 1) -> None:
+        """Invalidate processing caches.
+
+        Called when the decimation factor changes to reset remainder buffers
+        and signal that full reprocessing is needed.
+
+        Parameters
+        ----------
+        new_decimation : int, default 1
+            The decimation factor now in effect; recorded so that
+            ``should_invalidate_cache()`` does not keep reporting a change
+            for the same factor.
+        """
+        self._remainder.clear()
+        self._last_decimation = new_decimation
+        logger.debug("Pipeline cache invalidated")
+
+ def process_channel_data(
+ self,
+ raw_data: np.ndarray,
+ channel: int,
+ decimation: int,
+ ) -> Tuple[np.ndarray, bool]:
+ """Process raw ADC data for a single channel.
+
+ This method decimates the data, converts ADC to mV, and applies offsets.
+ It handles remainder samples from previous calls to ensure no data loss.
+
+ Parameters
+ ----------
+ raw_data : np.ndarray
+ Raw int16 ADC data for this channel.
+ channel : int
+ Channel index.
+ decimation : int
+ Decimation factor for display optimization.
+
+ Returns
+ -------
+ Tuple[np.ndarray, bool]
+ (voltage_data, has_data) where voltage_data is the processed
+ voltage in mV and has_data indicates if any data was produced.
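+
+        Notes
+        -----
+        Worked example (illustrative): with ``decimation=4``, a 10-sample
+        chunk decimates the first 8 samples (two groups of 4, emitted as
+        min/max pairs) and carries the remaining 2 samples over to the
+        next call via the per-channel remainder buffer.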
+ """
+ if len(raw_data) == 0:
+ return np.array([], dtype=np.float64), False
+
+ # Combine with remainder from previous chunk
+ remainder = self._remainder.get(channel)
+ if remainder is not None and len(remainder) > 0:
+ combined = np.concatenate([remainder, raw_data])
+ else:
+ combined = raw_data
+
+ n_combined = len(combined)
+ n_decimated = (n_combined // decimation) * decimation
+
+ # Perform decimation if we have enough samples
+ if n_decimated > 0:
+ to_decimate = combined[:n_decimated]
+ decimated = min_max_decimate_numba(to_decimate, decimation)
+ else:
+ decimated = np.array([], dtype=np.int16)
+
+ # Store remainder for next call
+ remainder_len = n_combined - n_decimated
+ if remainder_len > 0:
+ self._remainder[channel] = combined[-remainder_len:].copy()
+ else:
+ self._remainder[channel] = np.array([], dtype=np.int16)
+
+ if len(decimated) == 0:
+ return np.array([], dtype=np.float64), False
+
+ # Convert ADC to mV
+ voltage_range = self.config.voltage_ranges.get(channel)
+ if voltage_range is not None and self.config.max_adc_value > 0:
+ try:
+ voltage_data = adc_to_mV(
+ decimated, voltage_range, self.config.max_adc_value
+ )
+ # Apply offset
+ offset = self.config.offsets_v.get(channel, 0.0)
+ voltage_data = voltage_data + (offset * 1000.0)
+ except Exception:
+ logger.exception("Voltage conversion failed for channel {}", channel)
+ voltage_data = decimated.astype(np.float64)
+ else:
+ voltage_data = decimated.astype(np.float64)
+
+ return voltage_data, True
+
+ def create_time_axis_for_window(
+ self,
+ window_duration: float,
+ decimation: int,
+ data_duration: Optional[float] = None,
+ ) -> np.ndarray:
+ """Create a time axis for a display window.
+
+ Delegates to ``get_window_display_geometry()`` to ensure the time
+ axis length is always consistent with cache sizing.
+
+ Parameters
+ ----------
+ window_duration : float
+ Full duration of the display window in seconds.
+ decimation : int
+ Current software decimation factor (1 for lossless).
+ data_duration : Optional[float]
+ Actual duration that the data represents. If provided and less
+ than window_duration, the time axis is right-aligned within
+ the window. Default is None (span full window).
+
+ Returns
+ -------
+ np.ndarray
+ Time axis array in seconds, right-aligned within the window.
+ """
+ geometry = self.get_window_display_geometry(
+ window_duration, decimation, data_duration
+ )
+ return geometry.time_axis
+
+ def get_window_display_geometry(
+ self,
+ window_duration: float,
+ decimation: int,
+ data_duration: Optional[float] = None,
+ ) -> DisplayGeometry:
+ """Compute the complete display geometry for a window.
+
+ This is the single authoritative method for determining cache size,
+ time axis, and values-per-point. Both cache sizing and time axis
+ generation derive from the same call, so they cannot diverge.
+
+ Parameters
+ ----------
+ window_duration : float
+ Full duration of the display window in seconds.
+ decimation : int
+ Current software decimation factor (1 for lossless).
+ data_duration : Optional[float]
+ Actual data duration if less than window. When provided,
+ the time axis is right-aligned within the window.
+
+ Returns
+ -------
+ DisplayGeometry
+ Complete geometry with time axis, cache sizing, and metadata.
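+
+        Examples
+        --------
+        Worked example (illustrative): in AGGREGATE mode with
+        ``decimation=1`` and ``window_samples=1000`` (stored as min/max
+        pairs), ``values_per_point=2``, ``n_time_points=1000 // 2 == 500``,
+        and ``max_display_values=1000``; the time axis repeats each of the
+        500 time values twice, once for the min and once for the max.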
+ """
+ if window_duration <= 0:
+ empty = np.array([], dtype=np.float64)
+ return DisplayGeometry(
+ time_axis=empty,
+ max_display_values=0,
+ n_time_points=0,
+ values_per_point=1,
+ decimation=decimation,
+ window_samples=0,
+ )
+
+ window_samples = self.time_to_samples(window_duration)
+
+ is_aggregate = (
+ self.acquisition_rate.downsample_mode == DownsampleMode.AGGREGATE
+ and self.acquisition_rate.downsample_ratio > 1
+ )
+ # Software decimation ALWAYS produces min/max pairs when decimation > 1,
+ # regardless of target_plot_points (lossless vs quality mode).
+ # This is because min_max_decimate_numba always outputs pairs.
+ has_software_minmax = not is_aggregate and decimation > 1
+
+ if is_aggregate:
+ # AGGREGATE mode: window_samples already includes min/max pairs (storage_rate doubles).
+ # Each "time point" is a min/max pair (2 storage samples).
+ # So n_time_points = window_samples // 2 (pairs) // decimation.
+ values_per_point = 2
+ n_time_points = int(window_samples // 2 // decimation)
+ elif has_software_minmax:
+ # Software decimation produces min/max pairs.
+ values_per_point = 2
+ n_time_points = int(window_samples // decimation)
+ else:
+ # Normal mode: one value per sample.
+ values_per_point = 1
+ n_time_points = int(window_samples // decimation)
+
+ max_display_values = n_time_points * values_per_point
+
+ if data_duration is not None and data_duration < window_duration:
+ start_time = window_duration - data_duration
+ else:
+ start_time = 0.0
+
+ if n_time_points > 0:
+ time_values = np.linspace(start_time, window_duration, n_time_points)
+ else:
+ time_values = np.array([], dtype=np.float64)
+
+ if values_per_point == 2 and len(time_values) > 0:
+ time_axis = np.repeat(time_values, 2)
+ else:
+ time_axis = time_values
+
+ return DisplayGeometry(
+ time_axis=time_axis,
+ max_display_values=max_display_values,
+ n_time_points=n_time_points,
+ values_per_point=values_per_point,
+ decimation=decimation,
+ window_samples=window_samples,
+ )
+
+ def get_lag_metrics(self, producer_idx: int, consumer_idx: int) -> LagMetrics:
+ """Calculate lag between producer and consumer.
+
+ Parameters
+ ----------
+ producer_idx : int
+ Current write index from producer.
+ consumer_idx : int
+ Last read index from consumer.
+
+ Returns
+ -------
+ LagMetrics
+ Lag in samples and milliseconds.
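+
+        Examples
+        --------
+        Worked example (illustrative): at a storage rate of 1 MS/s, a
+        producer at index 5_000_000 and a consumer at 4_900_000 give
+        ``LagMetrics(samples=100_000, milliseconds=100.0)``.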
+ """
+ lag_samples = producer_idx - consumer_idx
+ if lag_samples < 0:
+ lag_samples = 0
+
+ storage_rate = self.acquisition_rate.storage_rate_hz
+ if storage_rate > 0:
+ lag_seconds = lag_samples / storage_rate
+ lag_ms = lag_seconds * 1000.0
+ else:
+ lag_ms = 0.0
+
+ return LagMetrics(lag_samples, lag_ms)
+
+ def format_lag(self, producer_idx: int, consumer_idx: int) -> str:
+ """Format lag for UI display.
+
+ Parameters
+ ----------
+ producer_idx : int
+ Current write index from producer.
+ consumer_idx : int
+ Last read index from consumer.
+
+ Returns
+ -------
+ str
+            Formatted lag string with two lines (e.g., "12ms<br>2.45 MS").
+ """
+ metrics = self.get_lag_metrics(producer_idx, consumer_idx)
+ lag_samples_formatted = self._format_sample_count_with_unit(metrics.samples)
+ return f"{metrics.milliseconds:.0f}ms<br>{lag_samples_formatted}"
+
+ def _format_sample_count_with_unit(self, count: int) -> str:
+ """Format sample count with appropriate SI unit prefix.
+
+        Uses MS (mega) for counts of one million or more, falling back to
+        kS and plain S for smaller counts; typical lag values fall in the
+        megasample range.
+
+ Parameters
+ ----------
+ count : int
+ Sample count.
+
+ Returns
+ -------
+ str
+            Formatted count with unit (e.g., "2.45 MS", "850.0 kS").
+ """
+        if count >= 1_000_000:
+            val = count / 1_000_000
+            # :.2f always emits two decimals, so only whole numbers
+            # (".00 ") need stripping; a ".0 " suffix can never occur.
+            return f"{val:.2f} MS".replace(".00 ", " ")
+ if count >= 1_000:
+ return f"{count / 1_000:.1f} kS"
+ return f"{count} S"
+
+ def samples_to_time(self, n_samples: StorageSampleCount) -> float:
+ """Convert sample count to duration.
+
+ Parameters
+ ----------
+ n_samples : StorageSampleCount
+ Number of storage samples.
+
+ Returns
+ -------
+ float
+ Duration in seconds.
+ """
+ return self.acquisition_rate.samples_to_seconds(n_samples)
+
+ def time_to_samples(self, duration_s: float) -> StorageSampleCount:
+ """Convert duration to sample count.
+
+ Parameters
+ ----------
+ duration_s : float
+ Duration in seconds.
+
+ Returns
+ -------
+ StorageSampleCount
+ Number of samples.
+ """
+ return self.acquisition_rate.seconds_to_samples(duration_s)
+
+ def get_display_duration(
+ self, n_display_samples: DisplayPointCount, decimation: int
+ ) -> float:
+ """Calculate the time duration represented by display samples.
+
+ DEPRECATED: This method mixes concerns by trying to reverse-engineer
+ duration from display points. Use samples_to_seconds() with storage
+ samples instead.
+
+ This accounts for both software min-max decimation (which produces
+ pairs of min/max values) and hardware AGGREGATE mode (which stores
+ min/max pairs in the ring buffer).
+
+ Parameters
+ ----------
+ n_display_samples : DisplayPointCount
+ Number of display points (after software decimation).
+ decimation : int
+ Software decimation factor used.
+
+ Returns
+ -------
+ float
+ Duration in seconds that the data represents.
+ """
+ return self.acquisition_rate.get_display_duration(n_display_samples, decimation)
+
+ def get_max_cache_points(self, window_samples: int, decimation: int) -> int:
+ """Calculate maximum display values for cache.
+
+ Delegates to ``get_window_display_geometry()`` so cache sizing
+ always matches time axis length exactly.
+
+ Parameters
+ ----------
+ window_samples : int
+ Number of samples in the display window.
+ decimation : int
+ Decimation factor.
+
+ Returns
+ -------
+ int
+ Maximum number of display values to keep in cache per channel.
+ """
+ window_duration = self.samples_to_time(window_samples)
+ geometry = self.get_window_display_geometry(window_duration, decimation)
+ return geometry.max_display_values
+
+ def get_max_time_points(self, window_samples: int, decimation: int) -> int:
+ """Calculate maximum time points for time axis creation.
+
+ Parameters
+ ----------
+ window_samples : int
+ Number of samples in the display window.
+ decimation : int
+ Decimation factor.
+
+ Returns
+ -------
+ int
+ Maximum number of time points (not display values).
+ """
+ return window_samples // decimation
+
+ def is_aggregate_mode(self) -> bool:
+ """Check if current mode is AGGREGATE.
+
+ Returns
+ -------
+ bool
+ True if mode is AGGREGATE.
+ """
+ return self.acquisition_rate.downsample_mode == DownsampleMode.AGGREGATE
+
+ def calculate_display_points(
+ self, n_input_samples: int, channel: int, decimation: int
+ ) -> int:
+ """Calculate the number of display points that will be produced.
+
+ Accounts for current remainder state and min-max pair expansion
+ when software decimation is active.
+
+ Parameters
+ ----------
+ n_input_samples : int
+ Number of input samples to process.
+ channel : int
+ Channel index (remainder is per-channel).
+ decimation : int
+ Decimation factor.
+
+ Returns
+ -------
+ int
+ Number of display points that will be produced.
+ """
+ if n_input_samples <= 0 or decimation <= 0:
+ return 0
+
+ remainder = self._remainder.get(channel)
+ remainder_len = len(remainder) if remainder is not None else 0
+ total_samples = remainder_len + n_input_samples
+
+        n_groups = total_samples // decimation
+        # Min/max pairs are only produced when software decimation is active
+        # (decimation > 1), matching get_window_display_geometry().
+        values_per_point = 2 if decimation > 1 else 1
+
+        return n_groups * values_per_point
diff --git a/picostream/device.py b/picostream/device.py
new file mode 100644
index 0000000..3681f79
--- /dev/null
+++ b/picostream/device.py
@@ -0,0 +1,2794 @@
+import ctypes
+import math
+import os
+import threading
+import time
+from datetime import datetime, timezone
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+from loguru import logger
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+
+# Lazy import of PicoSDK
+try:
+ from picosdk.functions import PICO_STATUS, adc2mV, assert_pico_ok
+ from picosdk.ps5000a import ps5000a as ps # type: ignore
+except ImportError:
+ ps = None
+ adc2mV = None
+ assert_pico_ok = None
+ PICO_STATUS = None
+ logger.warning("PicoSDK not available - Picoscope functionality will be limited")
+
+try:
+ import h5py # type: ignore
+except ImportError:
+ h5py = None
+ logger.warning("h5py not available - PicoscopeBufferedStream will be limited")
+
+from labdaemon.exceptions import (
+ DeviceConfigurationError,
+ DeviceConnectionError,
+ DeviceNotConnectedError,
+ DeviceOperationError,
+ TaskCancelledError,
+)
+
+
+class Picoscope:
+ """
+ Device plugin for Picoscope 5000a series digitizers.
+
+ This is a synchronous implementation for the LabDaemon framework, providing
+ blocking data acquisition functionality.
+ """
+
+ # Valid voltage ranges for 5000a series
+ VOLTAGE_RANGES = {
+ 0.01: "PS5000A_10MV",
+ 0.02: "PS5000A_20MV",
+ 0.05: "PS5000A_50MV",
+ 0.1: "PS5000A_100MV",
+ 0.2: "PS5000A_200MV",
+ 0.5: "PS5000A_500MV",
+ 1.0: "PS5000A_1V",
+ 2.0: "PS5000A_2V",
+ 5.0: "PS5000A_5V",
+ 10.0: "PS5000A_10V",
+ 20.0: "PS5000A_20V",
+ }
+
+ # Hardware limits for 5000a series
+ MAX_SAMPLE_RATE_HZ = 1_000_000_000 # 1 GS/s for single channel
+ MAX_DOWNSAMPLE_RATIO = 2**32 - 1 # uint32 limit in PicoSDK
+ SUPPORTED_DOWNSAMPLE_MODES = {"NONE", "AVERAGE", "DECIMATE", "AGGREGATE"}
+
+ def __init__(
+ self,
+ device_id: str,
+ serial_code: Optional[str] = None,
+ resolution: int = 12,
+ **kwargs: Any,
+ ) -> None:
+ """Initialise the Picoscope device."""
+ self.device_id = device_id
+ self.serial_code = serial_code
+ self.resolution = resolution
+
+ if ps is None:
+ raise ImportError(
+ "PicoSDK not found. Please install it to use the Picoscope plugin."
+ )
+ self._ps = ps
+
+ if resolution not in {8, 12, 14, 15, 16}:
+ raise DeviceConfigurationError(
+ f"Invalid resolution: {resolution}. Must be one of: 8, 12, 14, 15, 16"
+ )
+
+ self._chandle: Optional[ctypes.c_int16] = None
+ self._is_connected: bool = False
+ self._enabled_channels: List[int] = []
+ self._channel_ranges: Dict[int, str] = {}
+ self._channel_voltage_ranges: Dict[int, float] = {}
+ self._buffers: Dict[int, np.ndarray] = {}
+ self._sample_interval_s: Optional[float] = None
+ self._max_adc_value: Optional[ctypes.c_int16] = None
+ self._current_acquisition_params: Dict[str, Any] = {}
+ self._downsample_ratio: int = 1
+ self._downsample_mode: str = "NONE"
+
+ logger.info("Picoscope plugin initialised for {}", self.device_id)
+
+ def get_connection_id(self) -> Optional[str]:
+ """
+ Get the Picoscope serial code.
+
+ Returns
+ -------
+ Optional[str]
+ The serial code, or None if auto-detect is used.
+ """
+ return self.serial_code
+
+ def set_connection_id(self, connection_id: Optional[str]) -> None:
+ """
+ Set the Picoscope serial code for connection.
+
+ Pass None or empty string to use auto-detect.
+
+ Parameters
+ ----------
+ connection_id : Optional[str]
+ Serial code (e.g., "AW123/456"), or None/empty for auto-detect.
+
+ Raises
+ ------
+ ValueError
+ If connection_id is not a string or None.
+ """
+ if connection_id is None:
+ self.serial_code = None
+ elif isinstance(connection_id, str):
+ self.serial_code = connection_id if connection_id else None
+ else:
+ raise ValueError("connection_id must be a string or None")
+ logger.info(
+ "Picoscope {} connection_id set to: {}",
+ self.device_id,
+ self.serial_code or "auto-detect",
+ )
+
+ def connect(self) -> None:
+ """Establishes a connection to the Picoscope device."""
+ if self._is_connected:
+ logger.info("Device {} is already connected.", self.device_id)
+ return
+
+ self._chandle = ctypes.c_int16()
+ res_enum = self._ps.PS5000A_DEVICE_RESOLUTION[ # pyright: ignore[reportAttributeAccessIssue]
+ f"PS5000A_DR_{self.resolution}BIT"
+ ]
+ serial_number = self.serial_code.encode("utf-8") if self.serial_code else None
+
+ try:
+ status = self._ps.ps5000aOpenUnit(
+ ctypes.byref(self._chandle), serial_number, res_enum
+ )
+ self._assert_pico_ok(status)
+ self._is_connected = True
+
+ # Get max ADC value from driver
+ # Note: PicoScope 5000a returns the SAME max_adc for all resolutions
+ # because it uses a 16-bit ADC internally. Higher resolutions just use
+ # more effective bits. We store both resolution and max_adc for clarity.
+ self._max_adc_value = ctypes.c_int16()
+ status = self._ps.ps5000aMaximumValue(
+ self._chandle, ctypes.byref(self._max_adc_value)
+ )
+ self._assert_pico_ok(status)
+
+ # Validate: max_adc should match what the hardware actually uses
+ # For PicoScope 5000a, this is typically 32767 (15-bit signed) regardless of resolution
+ expected_max_for_res = (1 << (self.resolution - 1)) - 1
+ logger.info(
+ "Connected to Picoscope {}: {}-bit resolution, "
+ "driver max_adc={}, theoretical max for resolution={}",
+ self.device_id,
+ self.resolution,
+ self._max_adc_value.value,
+ expected_max_for_res,
+ )
+
+ if self._max_adc_value.value != expected_max_for_res:
+ logger.warning(
+ "max_adc mismatch: driver returns {} but {}-bit resolution "
+ "theoretical max is {}. Using driver value for conversion.",
+ self._max_adc_value.value,
+ self.resolution,
+ expected_max_for_res,
+ )
+ except Exception as e:
+ self._is_connected = False
+ raise DeviceConnectionError(
+ f"Failed to connect to Picoscope device {self.device_id}: {e}"
+ ) from e
+
+ def disconnect(self) -> None:
+ """Closes the connection to the Picoscope device.
+
+ Cleanup order is important: this method should only be called after
+ any streaming or acquisition has been stopped. The typical sequence is:
+ 1. Stop streaming/acquisition (streaming thread)
+ 2. Stop hardware (ps5000aStop)
+ 3. Close unit (ps5000aCloseUnit) - this method
+
+ Reversing this order can cause deadlocks or handle leaks.
+ """
+ if not self._is_connected or self._chandle is None:
+ logger.info("Device {} is already disconnected.", self.device_id)
+ return
+
+ try:
+ status = self._ps.ps5000aCloseUnit(self._chandle)
+ self._assert_pico_ok(status)
+ except Exception as e:
+ logger.exception(
+ "Error disconnecting from Picoscope device {}: {}", self.device_id, e
+ )
+ finally:
+ self._chandle = None
+ self._is_connected = False
+ logger.info("Disconnected from Picoscope device {}", self.device_id)
+
+ def acquire_block(
+ self,
+ sample_rate: float,
+ num_samples: int,
+ channels: List[int],
+ voltage_ranges: List[float],
+ trigger_source: Optional[str] = None,
+ trigger_threshold_v: float = 0.5,
+ trigger_direction: str = "RISING",
+ trigger_delay_s: float = 0.0,
+ timeout_s: float = 10.0,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ **kwargs: Any,
+ ) -> Dict[str, Any]:
+ """
+ Acquires a single block of data from the Picoscope.
+
+ This is a convenience method that arms the acquisition and waits for it
+ to complete. For more control, use `arm_acquisition` and
+ `wait_for_acquisition` separately.
+
+ Can be configured to wait for a hardware trigger before starting.
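+
+        Examples
+        --------
+        A hedged sketch (device id and values are illustrative):
+
+        >>> scope = Picoscope("scope0")  # doctest: +SKIP
+        >>> scope.connect()  # doctest: +SKIP
+        >>> result = scope.acquire_block(  # doctest: +SKIP
+        ...     sample_rate=10e6,
+        ...     num_samples=100_000,
+        ...     channels=[0],
+        ...     voltage_ranges=[1.0],
+        ... )
+        >>> sorted(result.keys())  # doctest: +SKIP
+        ['data_adc', 'data_v', 'metadata']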
+ """
+ if not self._is_connected or self._chandle is None:
+ raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.")
+
+ self.arm_acquisition(
+ sample_rate=sample_rate,
+ num_samples=num_samples,
+ channels=channels,
+ voltage_ranges=voltage_ranges,
+ trigger_source=trigger_source,
+ trigger_threshold_v=trigger_threshold_v,
+ trigger_direction=trigger_direction,
+ trigger_delay_s=trigger_delay_s,
+ downsample_ratio=downsample_ratio,
+ downsample_mode=downsample_mode,
+ **kwargs,
+ )
+ return self.wait_for_acquisition(timeout_s=timeout_s)
+
+ def arm_acquisition(
+ self,
+ sample_rate: float,
+ num_samples: int,
+ channels: List[int],
+ voltage_ranges: List[float],
+ trigger_source: Optional[str] = None,
+ trigger_threshold_v: float = 0.5,
+ trigger_direction: str = "RISING",
+ trigger_delay_s: float = 0.0,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ **kwargs: Any,
+ ) -> None:
+ """
+ Arms the Picoscope for a block mode acquisition.
+
+ This configures the channels, timebase, and trigger, then sets up the
+ data buffers and starts the acquisition. The device will then wait for
+ a trigger (if configured) or start capturing immediately. This method
+ returns immediately without waiting for the capture to finish.
+
+ Note
+ ----
+ The num_samples parameter is interpreted as the desired acquisition
+ duration, relative to the requested sample rate. If the hardware
+ cannot achieve the requested sample rate, num_samples will be adjusted
+ to keep the same acquisition duration at the actual achievable rate.
+ """
+ if not self._is_connected or self._chandle is None:
+ raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.")
+ if not channels:
+ raise DeviceConfigurationError("No channels specified for acquisition.")
+ if len(channels) != len(voltage_ranges):
+ raise DeviceConfigurationError(
+ "Length of 'channels' and 'voltage_ranges' must match."
+ )
+
+ logger.debug(
+ "Arming block acquisition on device {} at requested {} Hz for {} samples on channels {}.{}",
+ self.device_id,
+ sample_rate,
+ num_samples,
+ channels,
+ f" Waiting for {trigger_source} trigger." if trigger_source else "",
+ )
+
+ # Reset device to a clean state before configuring block mode
+ self._reset_device_state()
+
+ try:
+ # 1. Configure channels and store downsampling parameters
+ for ch, v_range in zip(channels, voltage_ranges, strict=True):
+ self._configure_channel(ch, True, "DC", v_range)
+ self._store_downsampling_params(downsample_ratio, downsample_mode)
+
+ # 2. Probe actual achievable sample rate
+ desired_interval = 1.0 / sample_rate
+ timebase, actual_interval = self._get_timebase(
+ desired_interval, num_samples
+ )
+ actual_sample_rate = 1.0 / actual_interval
+ self._sample_interval_s = actual_interval
+
+ # 3. Adjust num_samples to keep the same acquisition duration at the actual rate
+ # Calculate implied duration from requested parameters
+ requested_duration = num_samples * desired_interval
+ adjusted_num_samples = int(round(requested_duration / actual_interval))
+
+ rate_mismatch_pct = (
+ abs(actual_sample_rate - sample_rate) / sample_rate * 100
+ )
+ if rate_mismatch_pct > 1.0:
+ logger.warning(
+ "Sample rate mismatch on device {}: requested {:.0f} Hz, achieved {:.0f} Hz ({:.1f}% difference). Adjusted num_samples: {} -> {} to maintain {:.2f}s duration.",
+ self.device_id,
+ sample_rate,
+ actual_sample_rate,
+ rate_mismatch_pct,
+ num_samples,
+ adjusted_num_samples,
+ requested_duration,
+ )
+
+ num_samples = adjusted_num_samples
+ logger.debug(
+ "[PICO] Timebase result: timebase_idx={}, actual_interval={:.3e} s, actual_rate={:.0f} Hz",
+ timebase,
+ actual_interval,
+ actual_sample_rate,
+ )
+
+ # 4. Configure trigger
+ if trigger_source:
+ self._configure_trigger(
+ source=trigger_source,
+ threshold_v=trigger_threshold_v,
+ direction=trigger_direction,
+ delay_s=trigger_delay_s,
+ )
+ else:
+ # Explicitly disable trigger if not requested
+ self._configure_trigger(source=None)
+
+ # 5. Set up buffers
+ output_samples = num_samples // downsample_ratio
+ if output_samples < 1:
+ raise DeviceConfigurationError(
+ f"Calculated output samples is {output_samples}. This must be at least 1. "
+ "This can happen if `num_samples` is smaller than `downsample_ratio`."
+ )
+ self._setup_block_buffers(channels, output_samples)
+
+ # 6. Store acquisition parameters for later use
+ self._current_acquisition_params = {
+ "num_samples": output_samples,
+ "channels": channels,
+ "requested_sample_rate": sample_rate,
+ "actual_sample_rate": actual_sample_rate,
+ "downsample_ratio": downsample_ratio,
+ "downsample_mode": downsample_mode,
+ }
+
+ # 7. Run acquisition
+ self._run_block(num_samples, timebase)
+            time.sleep(0.5)  # brief settle delay so the device is armed before a trigger arrives
+
+ except Exception as e:
+ logger.exception("Error arming acquisition on device {}", self.device_id)
+ raise DeviceOperationError(
+ f"Error arming acquisition on device {self.device_id}: {e}"
+ ) from e
+
+ def wait_for_acquisition(
+ self, timeout_s: float = 10.0, cancel_event: Optional[threading.Event] = None
+ ) -> Dict[str, Any]:
+ """
+ Waits for an armed acquisition to complete and retrieves the data.
+
+ Parameters
+ ----------
+ timeout_s : float, default 10.0
+ Maximum time to wait for acquisition completion.
+ cancel_event : Optional[threading.Event], default None
+ If provided, the acquisition will be cancelled if this event is set.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary containing acquired data and metadata.
+
+ Raises
+ ------
+ TimeoutError
+ If the acquisition does not complete within the timeout.
+ TaskCancelledError
+ If the cancel_event is set during acquisition.
+
+ Notes
+ -----
+ When cancel_event is set, this method will call cancel_acquisition()
+ and raise TaskCancelledError. This allows tasks to pass their
+ context.cancel_event to enable cooperative cancellation during
+ long acquisitions.
+ """
+ if not self._is_connected or self._chandle is None:
+ raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.")
+ if not self._current_acquisition_params:
+ raise DeviceOperationError("No acquisition is currently armed.")
+
+ try:
+ # 1. Wait for completion
+ self._wait_for_capture(timeout_s=timeout_s, cancel_event=cancel_event)
+
+ # 2. Retrieve and convert data
+ time_data, voltage_data, raw_adc_data = self._get_block_data()
+
+ # 2b. Ensure block acquisition is fully stopped and buffers are de-registered
+ try:
+ status = self._ps.ps5000aStop(self._chandle)
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ "ps5000aStop returned status {} after block retrieval on device {}.",
+ status,
+ self.device_id,
+ )
+ except Exception:
+ logger.exception("Error stopping block acquisition after retrieval")
+
+ # De-register block-mode buffers using the same pointers and size=0
+ try:
+ for ch in list(self._enabled_channels):
+ buf = self._buffers.get(ch)
+ if buf is None:
+ continue
+ status = self._ps.ps5000aSetDataBuffer(
+ self._chandle,
+ self._get_channel_enum(ch),
+ buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ 0, # size=0 de-registers this buffer for block mode
+ 0, # segmentIndex
+ self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"],
+ )
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ "ps5000aSetDataBuffer (deregister) status {} for channel {} on device {}.",
+ status,
+ ch,
+ self.device_id,
+ )
+ except Exception:
+ logger.exception("Error de-registering block buffers after retrieval")
+
+ # 3. Format and return results
+ params = self._current_acquisition_params
+ if self._max_adc_value is None:
+ raise DeviceOperationError("max_adc_value not available for metadata.")
+
+ downsample_ratio = params.get("downsample_ratio", 1)
+ result = {
+ "data_v": voltage_data,
+ "data_adc": raw_adc_data,
+ "metadata": {
+ "sample_rate": params["actual_sample_rate"] / downsample_ratio,
+ "num_samples": len(time_data),
+ "channels": params["channels"],
+ "channel_voltage_ranges": [
+ self._channel_voltage_ranges[ch] for ch in params["channels"]
+ ],
+ "max_adc": self._max_adc_value.value,
+ "time_data": time_data,
+ "requested_sample_rate": params["requested_sample_rate"],
+ "device_id": self.device_id,
+ "downsample_ratio": downsample_ratio,
+ "downsample_mode": params.get("downsample_mode", "NONE"),
+ },
+ }
+ # Clear params after successful retrieval
+ self._current_acquisition_params = {}
+ return result
+ except TimeoutError:
+ # This is an expected outcome on timeout. Log and re-raise.
+ logger.warning("Acquisition timed out on device {}", self.device_id)
+ self._current_acquisition_params = {}
+ raise # Re-raise TimeoutError for the task to handle
+ except TaskCancelledError:
+ # This is an expected outcome, not an error. Log and re-raise.
+ logger.warning("Acquisition cancelled on device {}", self.device_id)
+ self._current_acquisition_params = {}
+ raise
+ except Exception as e:
+ logger.exception(
+ f"Error waiting for acquisition on device {self.device_id}"
+ )
+ # Clear params on failure
+ self._current_acquisition_params = {}
+ raise DeviceOperationError(
+ f"Error waiting for acquisition on device {self.device_id}: {e}"
+ ) from e
+
+ def _cancel_acquisition_unlocked(self) -> None:
+ """Stops a running acquisition. Assumes no lock is held."""
+ if not self._is_connected or self._chandle is None:
+ logger.warning(
+ "Cannot cancel acquisition, device {} is not connected.", self.device_id
+ )
+ return
+
+ try:
+ status = self._ps.ps5000aStop(self._chandle)
+ # Don't assert OK on stop, as it can fail if no acquisition is running.
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ f"ps5000aStop returned status {status} on device {self.device_id}. This is often normal."
+ )
+ logger.info("Acquisition stop command sent to device {}", self.device_id)
+ except Exception as e:
+ logger.exception(
+ f"Error sending stop command on device {self.device_id}: {e}"
+ )
+
+ def cancel_acquisition(self) -> None:
+ """
+ Stops a running acquisition on the Picoscope. This method is thread-safe.
+ """
+ # This public method is wrapped by the device lock.
+ self._cancel_acquisition_unlocked()
+
+ def disable_trigger(self) -> None:
+ """
+ Disables the trigger on the Picoscope.
+
+ This is useful to ensure the device is in a known state after a
+ triggered acquisition.
+ """
+ if not self._is_connected or self._chandle is None:
+ logger.warning(
+ "Cannot disable trigger, device {} is not connected.", self.device_id
+ )
+ return
+ try:
+ self._configure_trigger(source=None)
+ except Exception as e:
+ logger.exception(
+ f"Failed to disable trigger on device {self.device_id}: {e}"
+ )
+ # Do not re-raise, as this is often for cleanup.
+
+ def reset_device_state(self) -> None:
+ """Public method to reset the device to a clean, idle state."""
+ self._reset_device_state()
+
+ def _reset_device_state(self) -> None:
+ """
+ Resets the device to a clean, idle state.
+
+ This method is intended to be called before any new acquisition
+ (block or streaming) to prevent state pollution from previous
+ operations. It clears all buffers, resets channel configurations,
+ and disables the hardware trigger.
+ """
+ logger.debug("Resetting device state for {}", self.device_id)
+ # 0. Stop any running acquisition. This is critical to return the hardware
+ # to a known idle state before re-configuration.
+ if self._is_connected and self._chandle is not None:
+ try:
+ status = self._ps.ps5000aStop(self._chandle)
+ if status != PICO_STATUS["PICO_OK"]:
+ # This is not an error. It's normal for ps5000aStop to return
+ # non-OK status if the device is already idle.
+ logger.debug(
+ f"ps5000aStop returned status {status} during reset. This is normal."
+ )
+ except Exception:
+ # Log but do not re-raise. We want to continue the reset process.
+ logger.exception(
+ f"Error calling ps5000aStop during reset on device {self.device_id}"
+ )
+
+ # 1. Disable trigger to ensure device is not waiting for an event
+ self.disable_trigger()
+
+ # 2. Explicitly disable all channels in the SDK to clear any previous configuration
+ if self._is_connected and self._chandle is not None:
+ for ch in range(4): # Channels A, B, C, D (0-3)
+ try:
+ channel_enum = self._get_channel_enum(ch)
+ status = self._ps.ps5000aSetChannel(
+ self._chandle,
+ channel_enum,
+ 0,
+ 0,
+ 0,
+ 0.0, # disabled
+ )
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ f"Failed to disable channel {ch} during reset: status {status}"
+ )
+ except Exception:
+ logger.exception(
+ f"Error disabling channel {ch} during reset on device {self.device_id}"
+ )
+
+ # 3. De-register block-mode buffers using the original pointers (size=0)
+ if self._is_connected and self._chandle is not None:
+ for ch in range(4): # Channels A, B, C, D (0-3)
+ try:
+ buf = self._buffers.get(ch)
+ if buf is None:
+ continue # Only de-register buffers we actually set
+ channel_enum = self._get_channel_enum(ch)
+ status = self._ps.ps5000aSetDataBuffer(
+ self._chandle,
+ channel_enum,
+ buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ 0, # size=0 de-registers this buffer for block mode
+ 0, # segmentIndex
+ self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"],
+ )
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ f"Failed to de-register block buffer for channel {ch} during reset: status {status}"
+ )
+ except Exception:
+ logger.exception(
+ f"Error de-registering block buffer for channel {ch} during reset on device {self.device_id}"
+ )
+
+ # 4. Normalise segmentation/rapid-capture state to avoid lingering block configuration
+ if self._is_connected and self._chandle is not None:
+ try:
+ max_samples = ctypes.c_int32()
+ status = self._ps.ps5000aMemorySegments(
+ self._chandle, 1, ctypes.byref(max_samples)
+ )
+ self._assert_pico_ok(status)
+ status = self._ps.ps5000aSetNoOfCaptures(self._chandle, 1)
+ self._assert_pico_ok(status)
+ except Exception:
+ logger.exception(
+ f"Error normalising segmentation state during reset on device {self.device_id}"
+ )
+
+ # 5. Clear software buffers and channel configurations
+ self._buffers.clear()
+ self._enabled_channels.clear()
+ self._channel_ranges.clear()
+ self._channel_voltage_ranges.clear()
+
+ # 6. Reset acquisition parameters
+ self._sample_interval_s = None
+ self._downsample_ratio = 1
+ self._downsample_mode = "NONE"
+ self._current_acquisition_params.clear()
+
+ def _store_downsampling_params(
+ self,
+ downsample_ratio: int,
+ downsample_mode: str,
+ ) -> None:
+ """Validates and stores downsampling parameters for later use."""
+ # Validate downsample ratio
+ if downsample_ratio < 1:
+ raise DeviceConfigurationError(
+ f"Downsample ratio must be >= 1, got {downsample_ratio}"
+ )
+ if downsample_ratio > self.MAX_DOWNSAMPLE_RATIO:
+ raise DeviceConfigurationError(
+ f"Downsample ratio {downsample_ratio} exceeds hardware limit of {self.MAX_DOWNSAMPLE_RATIO}"
+ )
+
+ # Validate downsample mode
+ mode_upper = downsample_mode.upper()
+ if mode_upper not in self.SUPPORTED_DOWNSAMPLE_MODES:
+ raise DeviceConfigurationError(
+ f"Unsupported downsample mode: {downsample_mode}. "
+ f"Supported modes: {', '.join(self.SUPPORTED_DOWNSAMPLE_MODES)}"
+ )
+
+ # Store parameters for use in SDK calls
+ self._downsample_ratio = downsample_ratio
+ self._downsample_mode = mode_upper
+
+ logger.debug(
+ "Stored downsampling parameters for device {}: ratio={}, mode={}",
+ self.device_id,
+ downsample_ratio,
+ downsample_mode,
+ )
+
+ def _configure_trigger(
+ self,
+ source: Optional[str],
+ threshold_v: float = 0.5,
+ direction: str = "RISING",
+ delay_s: float = 0.0,
+ ) -> None:
+ """Configures the trigger for the acquisition."""
+ if source is None:
+ # Disable trigger
+ status = self._ps.ps5000aSetSimpleTrigger(self._chandle, 0, 0, 0, 0, 0, 0)
+ self._assert_pico_ok(status)
+ logger.debug("Trigger disabled for device {}", self.device_id)
+ return
+
+ source_upper = source.upper()
+ if source_upper != "EXT":
+ raise DeviceConfigurationError(
+ f"Unsupported trigger source: {source}. Only 'EXT' is supported."
+ )
+
+ source_enum = self._ps.PS5000A_CHANNEL["PS5000A_EXTERNAL"]
+ threshold_adc = self._v_to_adc_trigger(threshold_v)
+
+ valid_directions = {"RISING", "FALLING"}
+ direction_upper = direction.upper()
+ if direction_upper not in valid_directions:
+ raise DeviceConfigurationError(
+ f"Invalid trigger direction: {direction}. Supported: RISING, FALLING"
+ )
+ direction_enum = self._ps.PS5000A_THRESHOLD_DIRECTION[
+ f"PS5000A_{direction_upper}"
+ ]
+
+ if self._sample_interval_s is None:
+ raise DeviceOperationError(
+ "Sample interval must be set before configuring trigger delay."
+ )
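+        # Worked example (illustrative): with a 10 ns sample interval, a
+        # 1 µs trigger delay becomes int(1e-6 / 1e-8) = 100 delay samples.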
+ delay_samples = int(delay_s / self._sample_interval_s)
+
+ status = self._ps.ps5000aSetSimpleTrigger(
+ self._chandle,
+ 1, # enabled
+ source_enum,
+ threshold_adc,
+ direction_enum,
+ delay_samples,
+ 0, # autoTrigger_ms=0 -> wait indefinitely
+ )
+ self._assert_pico_ok(status)
+ logger.debug(
+ "Configured trigger on {} at {}V ({}) for device {}",
+ source,
+ threshold_v,
+ direction,
+ self.device_id,
+ )
+
+ def _v_to_adc_trigger(self, voltage_v: float) -> int:
+ """Converts a voltage to ADC counts for the external trigger input."""
+ if not -5.0 <= voltage_v <= 5.0:
+ raise DeviceConfigurationError(
+ "External trigger threshold must be between -5V and +5V."
+ )
+ if self._max_adc_value is None:
+ raise DeviceOperationError(
+ "Cannot convert voltage to ADC counts, max_adc_value not set."
+ )
+
+ # External trigger range is fixed at +/- 5V
+ trigger_range_mv = 5000.0
+
+ # Convert voltage to millivolts
+ voltage_mv = voltage_v * 1000.0
+
+ # Calculate ADC value
+ adc_value = int((voltage_mv / trigger_range_mv) * self._max_adc_value.value)
+ return adc_value
+
+ def _configure_channel(
+ self,
+ channel_id: int,
+ enabled: bool,
+ coupling: str,
+ voltage_range: float,
+ offset: float = 0.0,
+ ) -> None:
+ """Configures a single input channel."""
+ if not 0 <= channel_id <= 3:
+ raise DeviceConfigurationError(
+ f"Invalid channel ID: {channel_id}. Must be between 0 and 3."
+ )
+ if voltage_range not in self.VOLTAGE_RANGES:
+ raise DeviceConfigurationError(
+ f"Invalid voltage range: {voltage_range}V. "
+ f"Valid ranges: {sorted(self.VOLTAGE_RANGES.keys())}V"
+ )
+ range_str = self.VOLTAGE_RANGES[voltage_range]
+
+ coupling_enum = self._ps.PS5000A_COUPLING[f"PS5000A_{coupling.upper()}"]
+ channel_enum = self._get_channel_enum(channel_id)
+ range_enum = self._ps.PS5000A_RANGE[range_str]
+
+ status = self._ps.ps5000aSetChannel(
+ self._chandle, channel_enum, int(enabled), coupling_enum, range_enum, offset
+ )
+ self._assert_pico_ok(status)
+
+ if enabled:
+ if channel_id not in self._enabled_channels:
+ self._enabled_channels.append(channel_id)
+ self._channel_ranges[channel_id] = range_str
+ self._channel_voltage_ranges[channel_id] = voltage_range
+ else:
+ if channel_id in self._enabled_channels:
+ self._enabled_channels.remove(channel_id)
+ if channel_id in self._channel_ranges:
+ del self._channel_ranges[channel_id]
+ if channel_id in self._channel_voltage_ranges:
+ del self._channel_voltage_ranges[channel_id]
+
+ def _get_timebase(
+ self, desired_interval: float, num_samples: int
+ ) -> Tuple[int, float]:
+ """
+ Calculate the correct timebase for the desired sample interval.
+
+ Uses the appropriate formula based on the device's resolution to calculate
+ the timebase index, then validates it with the PicoSDK.
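+
+        Worked example (illustrative, using the 12-bit formula below): a
+        desired interval of 8 ns (125 MS/s) exceeds the 6 ns threshold, so
+        timebase = ceil(8e-9 / 4e-9) + 2 = 4; the SDK call then reports the
+        actual interval achieved.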
+ """
+ # Determine max sample rate and timebase formula based on resolution
+ if self.resolution == 8:
+ max_sample_rate = 1_000_000_000.0
+ # For 8-bit: interval = (timebase + 1) * 1ns for timebase 0-4, then (timebase - 4) * 2ns
+ if desired_interval <= 5e-9:
+ timebase_idx = math.ceil(desired_interval / 1e-9) - 1
+ else:
+ timebase_idx = math.ceil(desired_interval / 2e-9) + 4
+ elif self.resolution == 12:
+ max_sample_rate = 500_000_000.0
+ # For 12-bit: interval = (timebase + 1) * 2ns for timebase 0-2, then (timebase - 2) * 4ns
+ if desired_interval <= 6e-9:
+ timebase_idx = math.ceil(desired_interval / 2e-9) - 1
+ else:
+ timebase_idx = math.ceil(desired_interval / 4e-9) + 2
+ elif self.resolution == 14:
+ max_sample_rate = 125_000_000.0
+ # For 14-bit: interval = (timebase + 1) * 8ns
+ timebase_idx = math.ceil(desired_interval / 8e-9) - 1
+ else: # 15 and 16 bit
+ max_sample_rate = 62_500_000.0
+ # For 15/16-bit: interval = (timebase + 1) * 16ns
+ timebase_idx = math.ceil(desired_interval / 16e-9) - 1
+
+ # Validate that the desired sample rate is within hardware limits for the resolution
+ desired_sample_rate = 1.0 / desired_interval
+ if desired_sample_rate > max_sample_rate:
+ raise DeviceConfigurationError(
+ f"Requested sample rate {desired_sample_rate / 1e6:.1f} MS/s exceeds "
+ f"hardware limit of {max_sample_rate / 1e6:.1f} MS/s for {self.resolution}-bit resolution."
+ )
+
+ # Ensure timebase is not negative
+ if timebase_idx < 0:
+ timebase_idx = 0
+
+ # Verify the calculated timebase with the SDK and get the actual interval.
+ # The SDK will error out if this timebase is invalid for the current mode.
+ timebase = ctypes.c_uint32(int(timebase_idx))
+ interval = ctypes.c_float()
+ max_samples = ctypes.c_int32()
+
+ status = self._ps.ps5000aGetTimebase2(
+ self._chandle,
+ timebase,
+ ctypes.c_int32(num_samples),
+ ctypes.byref(interval),
+ ctypes.byref(max_samples),
+ ctypes.c_uint32(0),
+ )
+ self._assert_pico_ok(status)
+
+ actual_interval = interval.value * 1e-9
+ logger.debug(
+ "Calculated timebase {} for {}-bit resolution: desired interval {:.3e}s -> actual {:.3e}s",
+ timebase_idx,
+ self.resolution,
+ desired_interval,
+ actual_interval,
+ )
+ return int(timebase_idx), actual_interval
+
+ def _setup_block_buffers(self, channels: List[int], num_samples: int) -> None:
+ """Set up data buffers for block mode capture."""
+ for ch in channels:
+ self._buffers[ch] = np.zeros(shape=num_samples, dtype=np.int16)
+ status = self._ps.ps5000aSetDataBuffer(
+ self._chandle,
+ self._get_channel_enum(ch),
+ self._buffers[ch].ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ num_samples,
+ 0,
+ self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"],
+ )
+ self._assert_pico_ok(status)
+
+ def _run_block(self, num_samples: int, timebase: int) -> None:
+ """Initiates a single block acquisition."""
+ time_indisposed_ms = ctypes.c_int32()
+ status = self._ps.ps5000aRunBlock(
+ self._chandle,
+ 0, # preTriggerSamples
+ num_samples,
+ ctypes.c_uint32(timebase),
+ ctypes.byref(time_indisposed_ms),
+ 0,
+ None,
+ None,
+ )
+ self._assert_pico_ok(status)
+
+ def _wait_for_capture(
+ self, timeout_s: float = 10.0, cancel_event: Optional[threading.Event] = None
+ ) -> None:
+ """Waits for the block capture to complete."""
+ ready = ctypes.c_int16(0)
+ start_time = time.time()
+ while ready.value == 0:
+ if cancel_event and cancel_event.is_set():
+ self.cancel_acquisition()
+ raise TaskCancelledError("Acquisition cancelled by user.")
+ if time.time() - start_time > timeout_s:
+ self.cancel_acquisition()
+ raise TimeoutError(
+ f"Block capture timed out after {timeout_s} seconds."
+ )
+ status = self._ps.ps5000aIsReady(self._chandle, ctypes.byref(ready))
+ self._assert_pico_ok(status)
+ time.sleep(0.01)
+
+ def _get_block_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Retrieves the captured block data and converts it to voltage."""
+ overflow = ctypes.c_int16()
+ buffer_len = len(self._buffers[self._enabled_channels[0]])
+ samples_retrieved = ctypes.c_uint32(buffer_len)
+
+ # Apply downsampling in the GetValues call
+ downsample_ratio_val = ctypes.c_uint32(self._downsample_ratio)
+ downsample_mode = self._ps.PS5000A_RATIO_MODE[
+ f"PS5000A_RATIO_MODE_{self._downsample_mode}"
+ ]
+
+ status = self._ps.ps5000aGetValues(
+ self._chandle,
+ 0, # startIndex
+ ctypes.byref(samples_retrieved), # noOfSamples (in/out)
+ downsample_ratio_val,
+ downsample_mode,
+ 0, # segmentIndex
+ ctypes.byref(overflow),
+ )
+ self._assert_pico_ok(status)
+
+ actual_num_samples = samples_retrieved.value
+ logger.debug(
+ "Retrieved {} samples from a buffer of size {}.",
+ actual_num_samples,
+ buffer_len,
+ )
+
+ raw_adc_data = np.array(
+ [self._buffers[ch][:actual_num_samples] for ch in self._enabled_channels],
+ dtype=np.int16,
+ )
+
+ # Convert to voltage and check for clipping
+ voltages_list = []
+ if self._max_adc_value is None:
+ raise DeviceOperationError(
+ "Cannot convert ADC, max_adc_value not available."
+ )
+ max_adc = self._max_adc_value.value
+
+ for i, ch in enumerate(self._enabled_channels):
+ adc_buffer = raw_adc_data[i]
+ voltage_range = self._channel_voltage_ranges[ch]
+
+ # Check for clipping (analogue saturation)
+ clipped_samples = np.sum((adc_buffer >= max_adc) | (adc_buffer <= -max_adc))
+ if clipped_samples > 0:
+ total_samples = len(adc_buffer)
+ clip_percent = (clipped_samples / total_samples) * 100
+ logger.warning(
+ "Picoscope channel {} may be clipping. {} of {} samples ({:.2f}%) are at the ADC limit. Consider using a larger voltage range than {}V.",
+ ch,
+ clipped_samples,
+ total_samples,
+ clip_percent,
+ voltage_range,
+ )
+
+ # Perform vectorized scaling to volts
+ voltages = (adc_buffer.astype(np.float32) / max_adc) * voltage_range
+ voltages_list.append(voltages)
+
+ logger.debug(
+ "Converted ADC to voltage for channel {}: range={}V, ADC min/max={}/{}, max_adc={}, V min/max={:.6f}/{:.6f}",
+ ch,
+ voltage_range,
+ adc_buffer.min(),
+ adc_buffer.max(),
+ max_adc,
+ voltages.min(),
+ voltages.max(),
+ )
+
+ voltage_data = np.array(voltages_list)
+
+ if self._sample_interval_s is None:
+ raise DeviceOperationError(
+ "Sample interval not set for block data retrieval."
+ )
+
+ downsample_ratio = self._current_acquisition_params.get("downsample_ratio", 1)
+ effective_interval = self._sample_interval_s * downsample_ratio
+
+ time_data = np.linspace(
+ 0, (actual_num_samples - 1) * effective_interval, actual_num_samples
+ )
+ return time_data, voltage_data, raw_adc_data
+
+ def _get_channel_enum(self, ch: int) -> int:
+ """Converts a channel index (0-3) to the corresponding PicoSDK channel enum."""
+ return self._ps.PS5000A_CHANNEL[f"PS5000A_CHANNEL_{chr(65 + ch)}"]
+
+ def _assert_pico_ok(self, status: int) -> None:
+ """Check if a PicoScope status code indicates success."""
+ try:
+ if assert_pico_ok is not None:
+ assert_pico_ok(status) # type: ignore
+ except Exception as e:
+ # Provide more context for common errors
+            if status == 3:  # PICO_NOT_FOUND (0x03)
+ raise DeviceConnectionError(
+ f"Picoscope device not found. Status: {status}"
+ ) from e
+            elif status in {282, 286}:  # PICO_POWER_SUPPLY_NOT_CONNECTED / PICO_USB3_0_DEVICE_NON_USB3_0_PORT
+ logger.warning(
+ f"Picoscope device {self.device_id} has a power issue (status: {status}). "
+ "Consider connecting external power or using a USB3.0 port."
+ )
+ return
+ else:
+ raise DeviceOperationError(
+ f"PicoSDK operation failed with status {status}: {e}"
+ ) from e
+
+
+class PicoscopeSimpleStream(Picoscope):
+ """
+ Picoscope device with simple callback-based streaming.
+
+ This class inherits all block acquisition functionality from Picoscope
+ and provides direct callback streaming suitable for live monitoring.
+ """
+
+ def __init__(
+ self,
+ device_id: str,
+ serial_code: Optional[str] = None,
+ resolution: int = 12,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(device_id, serial_code, resolution, **kwargs)
+
+ self._streaming_thread: Optional[threading.Thread] = None
+ self._stop_streaming_event = threading.Event()
+ self._streaming_callback_ref: Optional[Callable] = None
+ self._streaming_voltage_range: Optional[float] = None
+ self._streaming_sample_rate: Optional[float] = None
+ self._streaming_channel: int = 0
+ self._streaming_num_samples_per_block: int = 16384
+ self._sdk_buffer: Optional[np.ndarray] = None
+ self._callbackFuncPtr: Optional[Any] = (
+ None # Will be recreated per streaming session
+ )
+
+ # --- Streaming Performance Tracking ---
+ self._stream_stats_lock = threading.Lock()
+ self._stream_total_samples = 0
+ self._stream_callback_count = 0
+ self._stream_overflow_count = 0
+ self._stream_last_summary_time = 0.0
+ self._stream_start_count = 0 # Track how many times streaming has been started
+
+ # --- Diagnostic Timing for Chunking Issues ---
+ self._last_callback_time: Optional[float] = None
+ self._callback_intervals: List[float] = []
+
+ def configure_streaming(
+ self,
+ sample_rate: float,
+ voltage_range: float,
+ channel: int = 0,
+ num_samples_per_block: int = 4096,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ ) -> None:
+ """
+ Configures the parameters for hardware-timed streaming.
+
+ This must be called before `start_streaming`.
+
+        Parameters
+        ----------
+        sample_rate : float
+            The desired sample rate in Hz.
+        voltage_range : float
+            The voltage range for the acquisition channel in Volts.
+        channel : int, default 0
+            The channel to stream from.
+        num_samples_per_block : int, default 4096
+            The number of samples in each acquired block.
+        downsample_ratio : int, default 1
+            The hardware downsampling factor.
+        downsample_mode : str, default "NONE"
+            The downsampling mode (e.g., 'AVERAGE').
+ """
+ if sample_rate <= 0:
+ raise DeviceConfigurationError("Sample rate must be positive.")
+ if num_samples_per_block <= 0:
+ raise DeviceConfigurationError(
+ "Number of samples per block must be positive."
+ )
+
+ # Reset device to a clean state before configuring streaming
+ self._reset_device_state()
+
+ # When mode is NONE, force ratio to 1 for SDK compatibility
+ if downsample_mode.upper() == "NONE":
+ downsample_ratio = 1
+
+ # Validate and store downsampling parameters
+ self._store_downsampling_params(downsample_ratio, downsample_mode)
+
+ self._streaming_sample_rate = sample_rate
+ self._streaming_voltage_range = voltage_range
+ self._streaming_channel = channel
+ self._streaming_num_samples_per_block = num_samples_per_block
+ logger.debug(
+ "Streaming configured for device {}: rate={}Hz, range={}V, downsample={}x ({})",
+ self.device_id,
+ sample_rate,
+ voltage_range,
+ downsample_ratio,
+ downsample_mode,
+ )
+
+ def start_streaming(self, callback: Callable) -> None:
+ """
+ Starts a true hardware-timed stream using the PicoSDK's streaming mode.
+
+ `configure_streaming` must be called before this method.
+
+ Parameters
+ ----------
+ callback : Callable
+ A function to call with new data. It will receive two arguments:
+ a (currently None) time array and a NumPy array of voltage data.
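+
+        Example callback (a sketch; names are illustrative)::
+
+            def on_data(_time_axis, volts: np.ndarray) -> None:
+                # volts is float32, already scaled to the configured range
+                print(volts.min(), volts.max())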
+ """
+ if not self._is_connected or self._chandle is None:
+ raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.")
+
+ if self._streaming_thread and self._streaming_thread.is_alive():
+ logger.warning(
+ "Streaming is already active on device {}.",
+ self.device_id,
+ )
+ return
+
+ if self._streaming_sample_rate is None or self._streaming_voltage_range is None:
+ raise DeviceConfigurationError(
+ "Streaming parameters not configured. Call `configure_streaming` before `start_streaming`."
+ )
+
+ # Store callback and reset stop event
+ self._streaming_callback_ref = callback
+ self._stop_streaming_event.clear()
+
+ # Reset performance counters and increment start count
+ with self._stream_stats_lock:
+ self._stream_total_samples = 0
+ self._stream_callback_count = 0
+ self._stream_overflow_count = 0
+ self._stream_last_summary_time = time.time()
+ self._stream_start_count += 1
+
+ # Reset diagnostic timing
+ self._last_callback_time = None
+ self._callback_intervals.clear()
+
+ # 1) Configure channel for streaming
+ channel = self._streaming_channel or 0
+ self._configure_channel(channel, True, "DC", self._streaming_voltage_range)
+
+ # 2) Recreate callback function pointer and allocate SDK buffer
+ self._callbackFuncPtr = self._ps.StreamingReadyType(
+ self._streaming_sdk_callback
+ )
+ buffer_size = int(self._streaming_num_samples_per_block or 50000)
+ self._sdk_buffer = np.zeros(shape=buffer_size, dtype=np.int16)
+ downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[
+ f"PS5000A_RATIO_MODE_{self._downsample_mode.upper()}"
+ ]
+ status = self._ps.ps5000aSetDataBuffers(
+ self._chandle,
+ self._get_channel_enum(channel),
+ self._sdk_buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ None, # bufferMin
+ buffer_size,
+ 0, # segmentIndex
+ downsample_mode_enum,
+ )
+ self._assert_pico_ok(status)
+
+ # 3) Determine streaming interval and unit
+ interval, unit_enum, actual_rate = self._get_streaming_interval(
+ self._streaming_sample_rate
+ )
+ sample_interval = ctypes.c_int32(interval)
+
+        # 4) Start hardware streaming (reusing the ratio-mode enum from step 2)
+ status = self._ps.ps5000aRunStreaming(
+ self._chandle,
+ ctypes.byref(sample_interval),
+ unit_enum,
+ 0, # preTriggerSamples
+ 10_000_000, # maxPostTriggerSamples (large number for continuous)
+ 0, # autoStop = False
+ self._downsample_ratio, # downSampleRatio
+ downsample_mode_enum,
+ buffer_size, # overviewBufferSize
+ )
+ self._assert_pico_ok(status)
+
+ logger.info(
+ "Streaming started on device {} (session #{}). "
+ "Requested rate: {} Hz, Actual rate: {} Hz. "
+ "Buffer size: {} samples.",
+ self.device_id,
+ self._stream_start_count,
+ self._streaming_sample_rate,
+ actual_rate,
+ buffer_size,
+ )
+
+ # 5) Start a lightweight polling thread to trigger SDK callbacks
+ self._streaming_thread = threading.Thread(
+ target=self._streaming_polling_loop,
+ daemon=True,
+ name=f"picoscope-stream-{self.device_id}",
+ )
+ self._streaming_thread.start()
+
+ def stop_streaming(self, timeout: float = 2.0) -> None:
+ """Stops the background streaming thread and hardware stream."""
+ if not self._streaming_thread or not self._streaming_thread.is_alive():
+ logger.info("Streaming is not active on device {}.", self.device_id)
+ return
+
+ logger.info("Stopping stream on device {}...", self.device_id)
+ self._stop_streaming_event.set()
+
+ # Avoid self-join deadlock if called from the polling thread itself
+ if threading.current_thread() is self._streaming_thread:
+ logger.warning(
+ "stop_streaming called from within the streaming polling thread on device {}. "
+ "The thread cannot be joined, but the stop event has been set.",
+ self.device_id,
+ )
+ return
+
+ # Join the polling thread with a bound
+ self._streaming_thread.join(timeout)
+ if self._streaming_thread.is_alive():
+ logger.warning(
+ "Streaming thread on device {} did not stop within {}s.",
+ self.device_id,
+ timeout,
+ )
+
+ # Stop the hardware stream regardless, swallowing errors
+ try:
+ if self._is_connected and self._chandle is not None:
+ status = self._ps.ps5000aStop(self._chandle)
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ "ps5000aStop returned status {} on device {}.",
+ status,
+ self.device_id,
+ )
+ except Exception:
+ logger.exception("Error stopping hardware stream during stop_streaming")
+
+ # De-register streaming buffers using size 0 before clearing references
+ try:
+ if (
+ self._is_connected
+ and self._chandle is not None
+ and self._sdk_buffer is not None
+ ):
+ channel_enum = self._get_channel_enum(self._streaming_channel or 0)
+ status = self._ps.ps5000aSetDataBuffers(
+ self._chandle,
+ channel_enum,
+ self._sdk_buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ None, # bufferMin unused in non-aggregate mode
+ 0, # size=0 de-registers streaming buffer
+ 0, # segmentIndex
+ self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"],
+ )
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ "ps5000aSetDataBuffers (deregister) returned status {} on device {}.",
+ status,
+ self.device_id,
+ )
+ except Exception:
+ logger.exception(
+ "Error de-registering streaming buffers during stop_streaming"
+ )
+
+ # Clear state
+ self._streaming_thread = None
+ self._streaming_callback_ref = None
+ self._sdk_buffer = None
+ self._callbackFuncPtr = None
+ logger.info("Streaming stopped and cleaned up for device {}.", self.device_id)
+
+ def is_streaming(self) -> bool:
+ """Returns True if the streaming thread is currently active."""
+ return self._streaming_thread is not None and self._streaming_thread.is_alive()
+
+ def disconnect(self) -> None:
+ """
+ Closes the connection to the Picoscope device.
+
+ Stops any active streaming before disconnecting.
+ """
+ if self.is_streaming():
+ self.stop_streaming()
+ super().disconnect()
+
+ def _get_streaming_interval(self, sample_rate_hz: float) -> Tuple[int, int, float]:
+ """
+ Calculate the integer interval and unit for streaming mode.
+
+ Parameters
+ ----------
+ sample_rate_hz : float
+ Desired sample rate in Hz.
+
+ Returns
+ -------
+ Tuple[int, int, float]
+ (interval, unit_enum, actual_rate_hz)
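+
+        Examples
+        --------
+        For ``sample_rate_hz=2.5e6`` the nanosecond unit is selected, so
+        ``interval = round(1e9 / 2.5e6) = 400`` and the actual rate is
+        ``1e9 / 400 = 2.5 MHz`` exactly; rates that do not divide the scale
+        evenly are rounded to the nearest achievable interval.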
+ """
+ if sample_rate_hz <= 0:
+ raise DeviceConfigurationError("Sample rate must be positive.")
+
+ # Choose unit to maximise resolution of integer interval
+ if sample_rate_hz >= 1_000_000:
+ unit_str = "PS5000A_NS"
+ scale = 1_000_000_000
+ elif sample_rate_hz >= 1_000:
+ unit_str = "PS5000A_US"
+ scale = 1_000_000
+ else:
+ unit_str = "PS5000A_MS"
+ scale = 1_000
+
+ interval = max(1, int(round(scale / sample_rate_hz)))
+ unit_enum = self._ps.PS5000A_TIME_UNITS[unit_str]
+ actual_rate = scale / interval
+ return interval, unit_enum, actual_rate
+
+ def _streaming_polling_loop(self) -> None:
+ """Poll the SDK to trigger data callbacks until stopped."""
+ logger.debug(
+ "Streaming polling loop started for device {}.",
+ self.device_id,
+ )
+ try:
+ while not self._stop_streaming_event.is_set():
+ # --- Periodic Summary Logging ---
+ now = time.time()
+ with self._stream_stats_lock:
+ last_summary = self._stream_last_summary_time
+
+ if now - last_summary > 5.0:
+ with self._stream_stats_lock:
+ elapsed = now - self._stream_last_summary_time
+ if elapsed > 0:
+ sample_rate = self._stream_total_samples / elapsed
+ callback_rate = self._stream_callback_count / elapsed
+
+ # Add diagnostic info for chunking analysis
+ avg_interval = (
+ np.mean(self._callback_intervals)
+ if self._callback_intervals
+ else 0
+ )
+ max_interval = (
+ np.max(self._callback_intervals)
+ if self._callback_intervals
+ else 0
+ )
+
+ logger.info(
+ "Streaming summary for {} (session #{}): "
+ "{} Samples/s, {} Callbacks/s, Overflows: {}, "
+ "Avg callback interval: {}ms, Max callback interval: {}ms",
+ self.device_id,
+ self._stream_start_count,
+ sample_rate,
+ callback_rate,
+ self._stream_overflow_count,
+ avg_interval * 1000,
+ max_interval * 1000,
+ )
+ # Reset for next interval
+ self._stream_total_samples = 0
+ self._stream_callback_count = 0
+ self._stream_last_summary_time = now
+ self._callback_intervals.clear()
+
+ # --- Poll for Data ---
+ status = self._ps.ps5000aGetStreamingLatestValues(
+ self._chandle, self._callbackFuncPtr, None
+ )
+ # Expected statuses during normal operation
+ if status not in [
+ PICO_STATUS["PICO_OK"],
+ PICO_STATUS["PICO_BUSY"],
+ PICO_STATUS["PICO_NO_SAMPLES_AVAILABLE"],
+ ]:
+ logger.warning(
+ "ps5000aGetStreamingLatestValues returned status {} on device {}.",
+ status,
+ self.device_id,
+ )
+
+ time.sleep(0.005) # Yield to other threads
+ except Exception:
+ logger.exception("Error in streaming polling loop")
+ finally:
+ # Best-effort stop of hardware when loop exits
+ try:
+ if self._is_connected and self._chandle is not None:
+ status = self._ps.ps5000aStop(self._chandle)
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ f"ps5000aStop returned status {status} on device {self.device_id}."
+ )
+ except Exception:
+ logger.exception("Error stopping hardware stream on polling loop exit")
+ logger.debug(
+ "Streaming polling loop finished for device {}.",
+ self.device_id,
+ )
+
+ def _streaming_sdk_callback(
+ self,
+ _handle: int,
+ noOfSamples: int,
+ startIndex: int,
+ overflow: int,
+ _triggerAt: int,
+ _triggered: int,
+ _autoStop: int,
+ _param: int,
+ ) -> None:
+ """
+ SDK callback executed when new streaming data is available.
+
+ Notes
+ -----
+ - Called from a PicoSDK-managed thread.
+ - Must be fast and must not acquire the device lock.
+ - Fail-fast on user-callback exceptions: stop event is set.
+ """
+ try:
+ # Track timing for diagnostic purposes
+ current_time = time.time()
+ interval = None
+ if self._last_callback_time is not None:
+ interval = current_time - self._last_callback_time
+ self._callback_intervals.append(interval)
+ # Keep only recent intervals to avoid memory growth
+ if len(self._callback_intervals) > 100:
+ self._callback_intervals = self._callback_intervals[-50:]
+ self._last_callback_time = current_time
+
+ with self._stream_stats_lock:
+ if overflow:
+ self._stream_overflow_count += 1
+ logger.warning(
+ "Picoscope hardware buffer overflow on device {}. Data loss may have occurred.",
+ self.device_id,
+ )
+ if noOfSamples > 0:
+ self._stream_total_samples += noOfSamples
+ self._stream_callback_count += 1
+
+ # Log large chunks for diagnostic purposes (only for sessions 2+)
+ if self._stream_start_count >= 2 and noOfSamples > 40:
+ interval_ms = interval * 1000 if interval is not None else 0
+ logger.debug(
+ "Large chunk detected on {} session #{}: {} samples (interval: {}ms)",
+ self.device_id,
+ self._stream_start_count,
+ noOfSamples,
+ interval_ms,
+ )
+
+ if noOfSamples <= 0 or self._sdk_buffer is None:
+ return
+
+ # Slice ADC data from the SDK-managed buffer
+ adc_view = self._sdk_buffer[startIndex : startIndex + noOfSamples]
+
+ if self._max_adc_value is None or self._streaming_voltage_range is None:
+ logger.warning(
+ "Cannot process streaming data: max_adc or voltage_range not set."
+ )
+ return
+
+ # Convert to volts
+ max_adc = self._max_adc_value.value
+ volts = (
+ adc_view.astype(np.float32) / max_adc
+ ) * self._streaming_voltage_range
+
+ # Invoke user callback; if it raises, fail fast
+ if self._streaming_callback_ref is not None:
+ try:
+ self._streaming_callback_ref(None, volts)
+ except Exception:
+ logger.exception(
+ "Error in user streaming callback on device {}", self.device_id
+ )
+ self._stop_streaming_event.set()
+ except Exception:
+ logger.exception("Unexpected error in SDK streaming callback")
+
+
+class PicoscopeBufferedStream(Picoscope):
+ """
+ Picoscope device with buffered streaming using a ring buffer and Zarr save.
+
+ This class inherits all block acquisition functionality from Picoscope
+ and implements a producer-consumer pattern with an in-memory ring buffer
+ for live display and on-demand Zarr saving.
+
+ Parameters
+ ----------
+ device_id : str
+ Unique identifier for this device instance.
+ serial_code : Optional[str], default None
+ Serial code of the Picoscope hardware. If None, auto-detect is used.
+ resolution : int, default 12
+ ADC resolution in bits (8, 12, 14, 15, or 16).
+ """
+
+ def __init__(
+ self,
+ device_id: str,
+ serial_code: Optional[str] = None,
+ resolution: int = 12,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(device_id, serial_code, resolution, **kwargs)
+
+ # --- Streaming Configuration ---
+ self._streaming_configured: bool = False
+ self._streaming_channels: List[int] = []
+ self._enabled_channels_set: set[int] = set()
+ self._streaming_voltage_ranges: Dict[int, float] = {}
+ self._streaming_offsets_v: Dict[int, float] = {}
+ self._streaming_sample_rate: Optional[float] = None
+ self._streaming_sample_interval_ns: Optional[int] = None
+ self._streaming_sample_unit: Optional[int] = None
+ self._streaming_downsample_ratio: int = 1
+ self._streaming_downsample_mode: str = "NONE"
+ self._streaming_bandwidth_limiter: str = "FULL"
+ self._streaming_buffer_duration_s: float = 30.0
+
+ # --- Ring Buffer ---
+ self._ring_buffer: Optional[Any] = None
+
+ # --- SDK Buffers ---
+ self._sdk_buffers: Dict[int, np.ndarray] = {}
+ self._sdk_buffer_min: Optional[np.ndarray] = None
+ self._interleaved_buffer: Optional[np.ndarray] = None
+
+ # --- Threads and Control ---
+ self._producer_thread: Optional[threading.Thread] = None
+ self._stop_streaming_event = threading.Event()
+ self._user_callback: Optional[Callable] = None
+ self._callbackFuncPtr: Optional[Any] = None
+
+ # --- Save Thread State ---
+ self._save_thread: Optional[threading.Thread] = None
+ self._save_stop_event: Optional[threading.Event] = None
+ self._save_output_path: Optional[str] = None
+ self._save_samples_written: int = 0
+ self._save_pre_trigger_samples: int = 0
+ self._save_error: Optional[Exception] = None
+ self._is_saving: bool = False
+ self._save_start_time_iso: Optional[str] = None
+ self._save_stop_time_iso: Optional[str] = None
+
+ # --- Performance Tracking ---
+ self._stream_stats_lock = threading.Lock()
+ self._stream_total_samples: int = 0
+ self._stream_callback_count: int = 0
+ self._stream_overflow_count: int = 0
+ self._hardware_overflow_count: int = 0
+ self._overvoltage_count: int = 0
+
+ # --- Streaming Error State ---
+ self._streaming_error: Optional[str] = None
+
+ # --- Pre-allocated callback buffer for zero-copy operation ---
+ self._callback_buffer: Optional[np.ndarray] = None
+
+ # --- Plotter position tracking for overflow detection ---
+        self._last_plotter_read_idx: int = 0
+        self._last_plotter_update_time: float = 0.0
+        self._plotter_position_lock = threading.Lock()
+        self._streaming_start_time: float = 0.0
+
+ def configure_streaming(
+ self,
+ sample_rate: float,
+ channels: List[int],
+ voltage_ranges: List[float],
+ buffer_duration_s: float = 30.0,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ offsets_v: Optional[List[float]] = None,
+ bandwidth_limiter: str = "FULL",
+ ) -> None:
+ """Configure buffered streaming acquisition with a ring buffer.
+
+ This sets up the ring buffer for high-speed multi-channel acquisition
+ and prepares channel configuration. Data is stored in memory and can
+ be saved on-demand using start_save().
+
+ The ring buffer always has 2 columns (Channel A and Channel B).
+ Disabled channels will have zeros written to their column.
+
+ Parameters
+ ----------
+ sample_rate : float
+ Desired sample rate in Hz per enabled channel.
+ channels : List[int]
+ List of channel indices to enable for acquisition (e.g., [0, 1] for A and B).
+ Empty list is allowed (no acquisition, just configuration).
+ voltage_ranges : List[float]
+ Voltage range for each channel in volts. Must match length of channels.
+ buffer_duration_s : float, default 30.0
+ Duration of the ring buffer in seconds. Determines memory usage.
+ At 62.5 MS/s per channel with 2 channels, 30s requires ~7.5 GB.
+ downsample_ratio : int, default 1
+ Hardware downsampling ratio.
+ downsample_mode : str, default "NONE"
+ Downsampling mode ("NONE", "AVERAGE", "DECIMATE", "AGGREGATE").
+        offsets_v : Optional[List[float]], default None
+            Analogue offset in volts for each channel. If None, defaults to 0.0 for all.
+        bandwidth_limiter : str, default "FULL"
+            Bandwidth limiter mode for the enabled channels (stored upper-cased).
+
+ Raises
+ ------
+ DeviceConfigurationError
+ If parameters are invalid or inconsistent.
+ ImportError
+ If ring buffer module is not available.
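+
+        Examples
+        --------
+        A minimal sketch (assumes a connected device ``dev``)::
+
+            dev.configure_streaming(
+                sample_rate=10e6,
+                channels=[0, 1],
+                voltage_ranges=[2.0, 5.0],
+                buffer_duration_s=10.0,
+            )
+            dev.start_streaming()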
+ """
+ if sample_rate <= 0:
+ raise DeviceConfigurationError("Sample rate must be positive.")
+ if not channels:
+ raise DeviceConfigurationError(
+ "At least one channel must be enabled for acquisition. "
+ "Specify channels=[0] for Channel A or channels=[1] for Channel B."
+ )
+ if len(channels) != len(voltage_ranges):
+ raise DeviceConfigurationError(
+ "Length of channels and voltage_ranges must match."
+ )
+ if len(set(channels)) != len(channels):
+ raise DeviceConfigurationError("Duplicate channel indices not allowed.")
+ for ch in channels:
+ if not 0 <= ch <= 3:
+ raise DeviceConfigurationError(
+ f"Invalid channel index: {ch}. Must be 0-3."
+ )
+ for v_range in voltage_ranges:
+ if v_range not in self.VOLTAGE_RANGES:
+ raise DeviceConfigurationError(
+ f"Invalid voltage range: {v_range}V. "
+ f"Valid ranges: {sorted(self.VOLTAGE_RANGES.keys())}V"
+ )
+
+ if buffer_duration_s <= 0:
+ raise DeviceConfigurationError("Buffer duration must be positive.")
+
+ try:
+ from picostream.ring_buffer import RingBuffer
+ except ImportError as e:
+ raise ImportError(
+ "Ring buffer module not found. "
+ "Please ensure the picostream app is available."
+ ) from e
+
+ if downsample_mode.upper() == "NONE":
+ downsample_ratio = 1
+
+ self._store_downsampling_params(downsample_ratio, downsample_mode)
+ self._reset_device_state()
+
+ # Always use 2 columns in ring buffer (A and B)
+ num_buffer_channels = 2
+
+ # Store enabled channels and their configurations
+ self._streaming_channels = list(channels)
+ self._enabled_channels_set = set(channels)
+ self._streaming_voltage_ranges = {
+ ch: v_range for ch, v_range in zip(channels, voltage_ranges, strict=True)
+ }
+
+ # Hardware samples at the requested rate regardless of downsampling mode
+ # AGGREGATE mode stores 2 values per downsampled point (min, max)
+ # but the hardware timing is still based on the original sample rate
+ self._streaming_sample_rate = sample_rate
+
+ self._streaming_buffer_duration_s = buffer_duration_s
+
+ if offsets_v is None:
+ offsets_v = [0.0] * len(channels)
+ elif len(offsets_v) != len(channels):
+ raise DeviceConfigurationError(
+ "Length of offsets_v must match length of channels."
+ )
+
+ self._streaming_offsets_v = {
+ ch: offset for ch, offset in zip(channels, offsets_v, strict=True)
+ }
+ self._streaming_bandwidth_limiter = bandwidth_limiter.upper()
+
+ # Configure all physical channels - enabled ones for acquisition,
+ # disabled ones are turned off in hardware
+ all_channels = [0, 1] # Only support A and B for now
+ for ch in all_channels:
+ if ch in self._enabled_channels_set:
+ idx = channels.index(ch)
+ v_range = voltage_ranges[idx]
+ offset = offsets_v[idx]
+ self._configure_channel(ch, True, "DC", v_range, offset)
+ else:
+ # Disable channel in hardware - use dummy values
+ self._configure_channel(ch, False, "DC", 20.0, 0.0)
+
+ # Create AcquisitionRate to calculate effective storage rate
+ # hardware_rate_hz is total across all channels (per-channel × num_channels)
+ total_hardware_rate_hz = sample_rate * len(channels)
+ self._acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=total_hardware_rate_hz,
+ num_channels=len(channels),
+ downsample_ratio=self._downsample_ratio,
+ downsample_mode=DownsampleMode(self._downsample_mode),
+ )
+
+ # Use storage_rate_hz which accounts for hardware downsampling
+ # This is the actual rate at which samples arrive at the buffer
+ self._ring_buffer = RingBuffer(
+ duration_s=buffer_duration_s,
+ sample_rate=self._acquisition_rate.storage_rate_hz,
+ num_channels=num_buffer_channels,
+ )
+ self._ring_buffer.set_acquisition_rate(self._acquisition_rate)
+
+ buffer_memory_mb = self._ring_buffer.buffer.nbytes / (1024 * 1024)
+
+ try:
+ import psutil # type: ignore
+
+ available_memory_mb = psutil.virtual_memory().available / (1024 * 1024)
+ if buffer_memory_mb > available_memory_mb * 0.5:
+ logger.warning(
+ "Ring buffer (%.1f MB) exceeds 50%% of available RAM (%.1f MB). "
+ "Consider reducing buffer_duration_s.",
+ buffer_memory_mb,
+ available_memory_mb,
+ )
+ except ImportError:
+ logger.debug("psutil not available, skipping memory validation")
+
+ # SDK buffers only for enabled channels
+ # Use a single large buffer (500ms) matching reference implementation pattern
+ samples_per_buffer = int(sample_rate * 0.5)
+ if samples_per_buffer < 1:
+ samples_per_buffer = 5000
+
+ self._sdk_buffers = {}
+ for ch in channels:
+ self._sdk_buffers[ch] = np.zeros(shape=samples_per_buffer, dtype=np.int16)
+
+ # Pre-allocate callback output buffer (Fix #1: avoid allocation in hot path)
+ # Max size is samples_per_buffer * 2 for AGGREGATE mode (min/max pairs)
+ max_callback_samples = (
+ samples_per_buffer * 2
+ if downsample_mode.upper() == "AGGREGATE"
+ else samples_per_buffer
+ )
+ self._callback_buffer = np.zeros((max_callback_samples, 2), dtype=np.int16)
+
+ if self._downsample_mode == "AGGREGATE":
+ self._sdk_buffer_min = np.zeros(shape=samples_per_buffer, dtype=np.int16)
+ self._interleaved_buffer = np.zeros(
+ shape=samples_per_buffer * 2, dtype=np.int16
+ )
+
+ interval, unit_enum, actual_rate = self._get_streaming_interval_ns(sample_rate)
+ self._streaming_sample_interval_ns = interval
+ self._streaming_sample_unit = unit_enum
+
+ self._streaming_configured = True
+
+ enabled_names = ["A" if ch == 0 else "B" for ch in sorted(channels)]
+ logger.info(
+ "Buffered streaming configured for {}: channels=[{}], "
+ "rate={} Hz per channel, {:.1f}s buffer ({:.1f} MB), "
+ "downsample={}x ({})",
+ self.device_id,
+ ", ".join(enabled_names),
+ actual_rate,
+ buffer_duration_s,
+ buffer_memory_mb,
+ downsample_ratio,
+ downsample_mode,
+ )
+
+ def start_streaming(self, callback: Optional[Callable] = None) -> None:
+ """Start buffered streaming acquisition to the ring buffer.
+
+ Parameters
+ ----------
+ callback : Optional[Callable], default None
+ Optional callback function called with metadata when data arrives.
+ Receives a dict: {'type': 'data_ready', 'samples': int,
+ 'channels': List[int], 'total_samples': int}.
+
+ Raises
+ ------
+ DeviceNotConnectedError
+ If device is not connected.
+ DeviceConfigurationError
+ If streaming is not configured.
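+
+        Examples
+        --------
+        A sketch of a metadata callback::
+
+            def on_meta(info: dict) -> None:
+                print(info["samples"], info["total_samples"])
+
+            dev.start_streaming(callback=on_meta)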
+ """
+ if not self._is_connected or self._chandle is None:
+ raise DeviceNotConnectedError(f"Device {self.device_id} is not connected.")
+
+ if not self._streaming_configured:
+ raise DeviceConfigurationError(
+ "Streaming not configured. Call configure_streaming() first."
+ )
+
+ if self.is_streaming():
+ logger.warning("Streaming already active on device {}.", self.device_id)
+ return
+
+ self._user_callback = callback
+ self._stop_streaming_event.clear()
+
+ with self._stream_stats_lock:
+ self._stream_total_samples = 0
+ self._stream_callback_count = 0
+ self._stream_overflow_count = 0
+ self._hardware_overflow_count = 0
+ self._overvoltage_count = 0
+ self._streaming_error = None # Clear any previous error
+
+ # Reset plotter tracking for overflow detection
+ with self._plotter_position_lock:
+ self._last_plotter_read_idx = 0
+ self._streaming_start_time = time.time()
+
+ self._setup_streaming_buffers()
+
+ self._callbackFuncPtr = self._ps.StreamingReadyType(self._streaming_callback)
+
+ if self._streaming_sample_interval_ns is None:
+ raise DeviceConfigurationError(
+ "Streaming interval not set. Call configure_streaming first."
+ )
+ sample_interval = ctypes.c_int32(self._streaming_sample_interval_ns)
+ downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[
+ f"PS5000A_RATIO_MODE_{self._downsample_mode}"
+ ]
+
+ # Use single buffer size for maxPostTriggerSamples (matches reference pattern)
+ sdk_buffer_size = self._sdk_buffers[self._streaming_channels[0]].shape[0]
+ max_samples = ctypes.c_int32(sdk_buffer_size)
+
+ status = self._ps.ps5000aRunStreaming(
+ self._chandle,
+ ctypes.byref(sample_interval),
+ self._streaming_sample_unit,
+ 0, # preTriggerSamples
+ max_samples.value,
+ 0, # autoStop = False (continuous streaming)
+ self._downsample_ratio,
+ downsample_mode_enum,
+ sdk_buffer_size, # overviewBufferSize
+ )
+ self._assert_pico_ok(status)
+
+ # Get actual sample interval from SDK (Fix #2: use actual rate, not requested)
+ # IMPORTANT: The SDK's actual_rate_hz is the PRE-DOWNSAMPLED rate.
+ # It's what we requested (or the closest achievable).
+ # The effective per-channel rate after downsampling is actual_rate_hz / downsample_ratio.
+ self._streaming_sample_interval_ns = sample_interval.value
+        actual_interval_ns = self._streaming_sample_interval_ns
+        if self._streaming_sample_unit == self._ps.PS5000A_TIME_UNITS["PS5000A_US"]:
+            actual_interval_ns *= 1000
+        elif self._streaming_sample_unit == self._ps.PS5000A_TIME_UNITS["PS5000A_MS"]:
+            actual_interval_ns *= 1_000_000
+        actual_rate_hz = 1e9 / actual_interval_ns if actual_interval_ns > 0 else 0
+
+ logger.info(
+ "Hardware streaming started on {}. "
+ "Actual interval: {} (unit={}), rate={} Hz",
+ self.device_id,
+ self._streaming_sample_interval_ns,
+ self._streaming_sample_unit,
+ actual_rate_hz,
+ )
+
+ self._producer_thread = threading.Thread(
+ target=self._producer_loop,
+ daemon=True,
+ name=f"picoscope-producer-{self.device_id}",
+ )
+
+ self._producer_thread.start()
+
+ logger.info("Buffered streaming started on {}.", self.device_id)
+
+ def stop_streaming(self, timeout: float = 5.0) -> None:
+ """Stop buffered streaming acquisition.
+
+ Parameters
+ ----------
+ timeout : float, default 5.0
+ Maximum time to wait for threads to finish.
+ """
+ if not self.is_streaming():
+ logger.info("Streaming not active on device {}.", self.device_id)
+ return
+
+ logger.info("Stopping buffered stream on {}...", self.device_id)
+ self._stop_streaming_event.set()
+
+ if self._is_saving:
+ self.stop_save(keep=True)
+
+ try:
+ if self._is_connected and self._chandle is not None:
+ status = self._ps.ps5000aStop(self._chandle)
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug("ps5000aStop returned status {}", status)
+ except Exception:
+ logger.exception("Error stopping hardware stream")
+
+ if self._producer_thread and self._producer_thread.is_alive():
+ self._producer_thread.join(timeout)
+ if self._producer_thread.is_alive():
+ logger.warning("Producer thread did not stop in time")
+
+ self._cleanup_streaming_buffers()
+
+ self._producer_thread = None
+ self._user_callback = None
+ self._callbackFuncPtr = None
+
+ with self._stream_stats_lock:
+ logger.info(
+ "Streaming stopped on {}. Total samples: {}, callbacks: {}, "
+ "hardware overflows: {}",
+ self.device_id,
+ self._stream_total_samples,
+ self._stream_callback_count,
+ self._hardware_overflow_count,
+ )
+
+ def is_streaming(self) -> bool:
+ """Return True if streaming thread is active."""
+ return self._producer_thread is not None and self._producer_thread.is_alive()
+
+ def get_streaming_error(self) -> Optional[str]:
+ """Get the last streaming error, if any.
+
+ Returns
+ -------
+ Optional[str]
+ Error description (e.g., "PICO_NOT_RESPONDING") if streaming
+ stopped due to an error, None otherwise.
+ Error is cleared when streaming is restarted.
+ """
+ with self._stream_stats_lock:
+ return self._streaming_error
+
+ @property
+ def is_saving(self) -> bool:
+ """Return True if currently saving to file."""
+ return self._is_saving
+
+ def disconnect(self) -> None:
+ """Disconnect from device, stopping any active streaming first.
+
+ Cleanup must happen in this order to avoid deadlocks and handle leaks:
+ 1. stop_streaming() - stops the producer thread first
+ 2. ps5000aStop() - stops hardware acquisition (called by stop_streaming)
+ 3. ps5000aCloseUnit() - closes device handle last (called by parent)
+
+ Reversing this order can cause deadlocks or resource leaks.
+ """
+ try:
+ if self.is_streaming():
+ self.stop_streaming(timeout=1.0)
+ except Exception:
+ logger.exception("Error stopping streaming during disconnect")
+ finally:
+ super().disconnect()
+
+ @property
+ def ring_buffer(self) -> Optional[Any]:
+ """Return the ring buffer for plotter access.
+
+ Returns
+ -------
+ Optional[RingBuffer]
+ The ring buffer if configured, None otherwise.
+ """
+ return self._ring_buffer
+
+ def update_plotter_position(self, read_idx: int) -> None:
+ """Update the plotter's last read position for overflow detection.
+
+ This should be called by the plotter whenever it reads data from
+ the ring buffer. It allows the producer to detect if the plotter
+ is falling behind.
+
+ Parameters
+ ----------
+ read_idx : int
+ The sample index that the plotter last read up to.
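+
+        Examples
+        --------
+        A sketch of a plotter read loop (``dev`` is the owning device)::
+
+            data = dev.ring_buffer.read_since(last_idx)
+            last_idx = dev.ring_buffer.get_snapshot()
+            dev.update_plotter_position(last_idx)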
+ """
+ with self._plotter_position_lock:
+ self._last_plotter_read_idx = read_idx
+ self._last_plotter_update_time = time.time()
+
+ # Log overflow check outside lock to avoid holding lock during computation
+ if self._ring_buffer is not None:
+ write_idx = self._ring_buffer.write_idx
+ buffer_capacity = self._ring_buffer.capacity
+ lag_samples = write_idx - read_idx
+ if lag_samples > 0.9 * buffer_capacity:
+ logger.warning(
+ "Ring buffer nearly full: {} samples lag ({:.1%} utilised). "
+ "write_idx={}, read_idx={}, capacity={}. "
+ "Plotter may not be keeping up - data loss imminent!",
+ lag_samples,
+ lag_samples / buffer_capacity if buffer_capacity > 0 else 0,
+ write_idx,
+ read_idx,
+ buffer_capacity,
+ )
+
+ def start_save(self, lookback_seconds: float, output_path: str) -> bool:
+ """Start saving data to a Zarr file.
+
+ Launches a save thread that writes pre-trigger data from the ring
+ buffer, then continuously drains new data to a Zarr file.
+
+ Parameters
+ ----------
+ lookback_seconds : float
+ How many seconds of pre-trigger data to include.
+ output_path : str
+ Path for the output Zarr directory.
+
+ Returns
+ -------
+ bool
+ True if save started successfully, False otherwise.
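+
+        Examples
+        --------
+        A sketch of a save cycle (assumes ``dev`` is already streaming)::
+
+            if dev.start_save(lookback_seconds=5.0, output_path="/tmp/run.zarr"):
+                time.sleep(10)  # record ~10 s beyond the 5 s lookback
+                dev.stop_save(keep=True)
+                while not dev.is_save_finished():
+                    time.sleep(0.1)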
+ """
+ if self._is_saving:
+ logger.warning("Save already in progress")
+ return False
+
+ if self._ring_buffer is None:
+ logger.error("Cannot start save: ring buffer not configured")
+ return False
+
+ self._save_output_path = output_path
+
+ if self._acquisition_rate is None:
+ logger.error("Cannot start save: acquisition_rate not configured")
+ return False
+
+ lookback_samples = self._acquisition_rate.seconds_to_samples(lookback_seconds)
+ available_samples = min(lookback_samples, self._ring_buffer.write_idx)
+
+ # Capture start/stop timestamps for metadata
+ self._save_start_time_iso = datetime.now(timezone.utc).isoformat()
+ self._save_stop_time_iso = None
+
+ self._save_stop_event = threading.Event()
+ self._save_samples_written = 0
+ self._save_pre_trigger_samples = 0
+ self._save_error = None
+
+        # Use the existing acquisition_rate; it is set in configure_streaming()
+        acquisition_rate = self._acquisition_rate
+
+ # Use the driver's max_adc value for voltage conversion
+ # The PicoScope 5000a driver always returns 32767 (16-bit signed max)
+ # regardless of resolution setting. Higher resolutions use more effective
+ # bits, but the ADC value range stays the same.
+ driver_max_adc = self._max_adc_value.value if self._max_adc_value else 32767
+
+ # Build metadata with AcquisitionRate fields for consistent rate semantics
+ # Use the actual acquisition_rate stored on the ring buffer
+ total_hardware_rate_hz = self._acquisition_rate.hardware_rate_hz
+ metadata: Dict[str, Any] = {
+ # AcquisitionRate fields - stored explicitly for reconstruction
+ "hardware_rate_hz": total_hardware_rate_hz,
+ "num_channels": len(self._streaming_channels),
+ "downsample_ratio": self._downsample_ratio,
+ "downsample_mode": self._downsample_mode,
+ # Legacy field for backward compatibility (deprecated)
+ "sample_rate_hz": acquisition_rate.storage_rate_hz,
+ "channels": [0, 1], # Always save both columns (A and B)
+ "enabled_channels": self._streaming_channels, # Which channels were actually acquiring
+ "voltage_ranges": [
+ self._streaming_voltage_ranges.get(ch, 20.0) for ch in [0, 1]
+ ],
+ "voltage_ranges_units": "V",
+ "offsets_v": [self._streaming_offsets_v.get(ch, 0.0) for ch in [0, 1]],
+ "resolution": self.resolution,
+ "resolution_units": "bits",
+ # Use driver's max_adc (always 32767 for PicoScope 5000a)
+ # for correct voltage conversion in external tools
+ "max_adc": driver_max_adc,
+ "pre_trigger_seconds": lookback_seconds,
+ "buffer_duration_s": self._streaming_buffer_duration_s,
+ "device_id": self.device_id,
+ "serial_code": self.serial_code,
+ "coupling": "DC", # Currently hardcoded to DC in configure_streaming
+ "bandwidth_limiter": self._streaming_bandwidth_limiter,
+ }
+
+ # Capture trigger position in main thread BEFORE spawning worker thread.
+ # This eliminates a race condition where the producer could overwrite
+ # pre-trigger data while the save worker is starting up.
+ trigger_pos = self._ring_buffer.get_snapshot()
+
+ self._save_thread = threading.Thread(
+ target=self._save_worker,
+ args=(
+ output_path,
+ trigger_pos,
+ available_samples,
+ metadata,
+ acquisition_rate,
+ ),
+ daemon=True,
+ name=f"picoscope-save-{self.device_id}",
+ )
+ self._save_thread.start()
+ self._is_saving = True
+
+ logger.info(
+ "Started save on {}: path={}, lookback={:.1f}s ({} samples)",
+ self.device_id,
+ output_path,
+ lookback_seconds,
+ available_samples,
+ )
+ return True
+
+ def _save_worker(
+ self,
+ path: str,
+ trigger_pos: int,
+ lookback_samples: int,
+ metadata: Dict[str, Any],
+ acquisition_rate: AcquisitionRate,
+ ) -> None:
+ """Worker thread for saving data to Zarr.
+
+ This runs in a separate thread and continuously drains data
+ from the ring buffer to a Zarr file until stop_save is called.
+ Always saves 2 columns (A and B), with zeros for disabled channels.
+
+ Parameters
+ ----------
+ path : str
+ Output path for the Zarr file.
+ trigger_pos : int
+ The write index position captured when start_save() was called.
+ Used as the trigger point for pre-trigger data capture.
+ lookback_samples : int
+ Number of samples to include before the trigger point.
+ metadata : Dict[str, Any]
+ Metadata dictionary to store in the Zarr file.
+ acquisition_rate : AcquisitionRate
+ Acquisition rate object with all rate information.
+
+ Raises
+ ------
+ Exception
+ Any exception during saving is stored in _save_error and re-raised
+ to ensure the thread terminates with an error state.
+ """
+ try:
+ from picostream.zarr_writer import ZarrStreamWriter
+
+ writer = ZarrStreamWriter(
+ path=path,
+ acquisition_rate=acquisition_rate,
+ num_channels=2, # Always 2 columns
+ compression=None,
+ )
+
+ pre_trigger_start = trigger_pos - lookback_samples
+ if pre_trigger_start < 0:
+ pre_trigger_start = 0
+
+ # Validate that trigger_pos is still valid before reading.
+ # This catches edge cases where save was significantly delayed
+ # after start_save() was called.
+ if not self._ring_buffer.is_valid_range(pre_trigger_start):
+ msg = (
+ f"Pre-trigger data expired before save worker started: "
+ f"trigger_pos={trigger_pos}, lookback_samples={lookback_samples}, "
+ f"buffer write_idx={self._ring_buffer.write_idx}, "
+ f"capacity={self._ring_buffer.capacity}"
+ )
+ logger.error(msg)
+ self._save_error = RuntimeError(msg)
+ raise self._save_error
+
+ pre_trigger_data = self._ring_buffer.read_range(
+ pre_trigger_start, trigger_pos
+ )
+
+ if pre_trigger_data.shape[0] > 0:
+ writer.append(pre_trigger_data)
+ self._save_pre_trigger_samples = pre_trigger_data.shape[0]
+
+ last_pos = trigger_pos
+
+ while not self._save_stop_event.is_set():
+ data = self._ring_buffer.read_since(last_pos)
+
+ if data.shape[0] > 0:
+ if not self._ring_buffer.is_valid_range(last_pos):
+ valid_start = (
+ self._ring_buffer.write_idx - self._ring_buffer.capacity
+ )
+ if valid_start < 0:
+ valid_start = 0
+ samples_lost = valid_start - last_pos
+ if samples_lost > 0:
+ logger.warning(
+ "Save worker lost {} samples: producer outran consumer. "
+ "Consider reducing sample rate or increasing buffer size.",
+ samples_lost,
+ )
+ data = self._ring_buffer.read_since(valid_start)
+ last_pos = valid_start
+
+ writer.append(data)
+ last_pos = self._ring_buffer.get_snapshot()
+ self._save_samples_written = writer.total_samples
+ else:
+ time.sleep(0.001)
+
+ if self._save_start_time_iso is not None:
+ metadata["record_start_time_iso"] = self._save_start_time_iso
+ if self._save_stop_time_iso is not None:
+ metadata["record_stop_time_iso"] = self._save_stop_time_iso
+
+ writer.close(metadata)
+ logger.info(
+ "Save worker completed: {} samples written to {}",
+ writer.total_samples,
+ path,
+ )
+
+ except Exception as e:
+ logger.exception("Save worker error: {}", e)
+ self._save_error = e
+ # Re-raise to ensure thread terminates with error state
+ raise
+
+ def stop_save(self, keep: bool = True) -> Optional[str]:
+ """Stop the current save operation (non-blocking).
+
+ Signals the save thread to stop and returns immediately.
+ Poll is_save_finished() to check when the operation completes.
+
+ Parameters
+ ----------
+ keep : bool, default True
+ If True, keep the saved file. If False, delete it.
+
+ Returns
+ -------
+ Optional[str]
+ Path where the file is being saved (if keep=True), or None.
+ The file may still be incomplete - check is_save_finished() before using.
+
+ Raises
+ ------
+ DeviceOperationError
+ If the save thread already encountered an error during recording.
+ """
+ if not self._is_saving:
+ logger.info("No save in progress")
+ return None
+
+ # Capture stop timestamp and store keep flag for cleanup
+ self._save_stop_time_iso = datetime.now(timezone.utc).isoformat()
+ self._keep_file_on_stop = keep
+
+ logger.info("Stopping save on {} (keep={})", self.device_id, keep)
+ self._save_stop_event.set()
+
+ # Check for errors from the save worker (already occurred)
+ if self._save_error is not None:
+ save_error = self._save_error
+ logger.error(
+ "Save operation failed on {}: {}",
+ self.device_id,
+ save_error,
+ )
+ # Clear the error so we don't re-raise it again later
+ self._save_error = None
+ raise DeviceOperationError(
+ f"Save operation failed: {save_error}"
+ ) from save_error
+
+ return self._save_output_path if keep else None
+
+ def is_save_finished(self) -> bool:
+ """Check if the save operation has finished.
+
+ After this returns True, you can safely use the saved file
+ (if keep=True was passed to stop_save()).
+
+ Returns
+ -------
+ bool
+ True if save is complete and the file is ready, False otherwise.
+
+ Notes
+ -----
+ This method is thread-safe and will not raise exceptions. Any errors
+ during cleanup are logged but not propagated.
+
+ There is a benign race condition: the save thread may terminate after
+ the is_alive() check but before cleanup completes. This is handled
+ gracefully by catching exceptions during cleanup.
+ """
+ if not self._is_saving:
+ return True
+
+ # Check if thread is still running
+ if self._save_thread and self._save_thread.is_alive():
+ return False
+
+ # Thread has finished, do cleanup
+ try:
+ self._cleanup_save()
+ except Exception:
+ logger.exception("Error during save cleanup in is_save_finished()")
+ # Still mark as finished even if cleanup fails, to avoid infinite polling
+ self._is_saving = False
+
+ return True
+
+ def _cleanup_save(self) -> None:
+ """Clean up after save thread completes (idempotent)."""
+ if not self._is_saving:
+ return
+
+ path = self._save_output_path
+ keep = getattr(self, "_keep_file_on_stop", True)
+
+ if not keep and path and os.path.exists(path):
+ try:
+ import shutil
+
+ shutil.rmtree(path)
+ logger.info("Discarded save file: {}", path)
+ except Exception as e:
+ logger.exception("Error discarding save file: {}", e)
+
+ self._is_saving = False
+ self._save_thread = None
+ self._save_stop_event = None
+ self._save_output_path = None
+ if hasattr(self, "_keep_file_on_stop"):
+ delattr(self, "_keep_file_on_stop")
+
+ def get_save_status(self) -> Dict[str, Any]:
+ """Get the current save status.
+
+ Returns
+ -------
+ Dict[str, Any]
+ Dictionary with keys: 'state', 'total_seconds', 'pre_trigger_seconds',
+ 'post_trigger_seconds', 'error' (if applicable).
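+
+        Examples
+        --------
+        A representative return value while recording::
+
+            {"state": "saving", "total_seconds": 12.5,
+             "pre_trigger_seconds": 5.0, "post_trigger_seconds": 7.5}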
+ """
+ if not self._is_saving and self._save_error is None:
+ return {"state": "idle"}
+
+ if self._save_error is not None:
+ return {"state": "error", "error": str(self._save_error)}
+
+ if self._acquisition_rate is None:
+ return {"state": "saving", "total_seconds": 0}
+
+ total_seconds = self._acquisition_rate.samples_to_seconds(
+ self._save_samples_written
+ )
+ pre_trigger_seconds = self._acquisition_rate.samples_to_seconds(
+ self._save_pre_trigger_samples
+ )
+ post_trigger_seconds = total_seconds - pre_trigger_seconds
+
+ return {
+ "state": "saving",
+ "total_seconds": total_seconds,
+ "pre_trigger_seconds": pre_trigger_seconds,
+ "post_trigger_seconds": post_trigger_seconds,
+ }
+
+ def _setup_streaming_buffers(self) -> None:
+ """Register SDK streaming buffers for enabled channels only."""
+ downsample_mode_enum = self._ps.PS5000A_RATIO_MODE[
+ f"PS5000A_RATIO_MODE_{self._downsample_mode}"
+ ]
+
+ for ch in self._streaming_channels:
+ buf = self._sdk_buffers[ch]
+ buffer_min_ptr = None
+
+ if (
+ self._downsample_mode == "AGGREGATE"
+ and self._sdk_buffer_min is not None
+ ):
+ buffer_min_ptr = self._sdk_buffer_min.ctypes.data_as(
+ ctypes.POINTER(ctypes.c_int16)
+ )
+
+ status = self._ps.ps5000aSetDataBuffers(
+ self._chandle,
+ self._get_channel_enum(ch),
+ buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ buffer_min_ptr,
+ buf.shape[0],
+ 0,
+ downsample_mode_enum,
+ )
+ self._assert_pico_ok(status)
+
+ logger.debug(
+ "Set up streaming buffers for enabled channels %s", self._streaming_channels
+ )
+
+ def _cleanup_streaming_buffers(self) -> None:
+ """De-register streaming buffers for enabled channels."""
+ if not self._is_connected or self._chandle is None:
+ return
+
+ for ch in self._streaming_channels:
+ try:
+ buf = self._sdk_buffers.get(ch)
+ if buf is None:
+ continue
+
+ buffer_min_ptr = None
+ if (
+ self._downsample_mode == "AGGREGATE"
+ and self._sdk_buffer_min is not None
+ ):
+ buffer_min_ptr = self._sdk_buffer_min.ctypes.data_as(
+ ctypes.POINTER(ctypes.c_int16)
+ )
+
+ status = self._ps.ps5000aSetDataBuffers(
+ self._chandle,
+ self._get_channel_enum(ch),
+ buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
+ buffer_min_ptr,
+ 0, # size=0 de-registers
+ 0,
+ self._ps.PS5000A_RATIO_MODE["PS5000A_RATIO_MODE_NONE"],
+ )
+ if status != PICO_STATUS["PICO_OK"]:
+ logger.debug(
+ "Buffer de-register status {} for channel {}.", status, ch
+ )
+ except Exception:
+ logger.exception("Error de-registering buffer for channel {}", ch)
+
+ def _get_streaming_interval_ns(
+ self, sample_rate_hz: float
+ ) -> Tuple[int, int, float]:
+ """
+ Calculate streaming interval in nanoseconds.
+
+ Returns
+ -------
+ Tuple[int, int, float]
+ (interval, unit_enum, actual_rate_hz)
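+
+        Examples
+        --------
+        At 62.5 MS/s, ``interval = round(1e9 / 62.5e6) = 16`` ns and the
+        actual rate is ``1e9 / 16 = 62.5 MHz``; requested rates that do not
+        divide 1 GHz evenly are rounded to the nearest achievable interval.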
+ """
+ if sample_rate_hz <= 0:
+ raise DeviceConfigurationError("Sample rate must be positive.")
+
+ # Use nanoseconds for highest resolution
+ unit_str = "PS5000A_NS"
+ interval = max(1, int(round(1e9 / sample_rate_hz)))
+ unit_enum = self._ps.PS5000A_TIME_UNITS[unit_str]
+ actual_rate = 1e9 / interval
+
+ return interval, unit_enum, actual_rate
+
+ def _producer_loop(self) -> None:
+ """Producer thread: poll SDK for data and write to ring buffer."""
+ logger.debug("Producer loop started for {}.", self.device_id)
+ not_responding_count = 0
+
+ # Main polling loop - start immediately like the reference implementation
+ # The device may return PICO_NOT_RESPONDING initially; we handle this
+ # gracefully by counting consecutive failures and only stopping after 3
+ main_loop_start = time.perf_counter()
+ last_status_log_time = main_loop_start
+ last_callback_seen = 0
+
+ try:
+ while not self._stop_streaming_event.is_set():
+ status = self._ps.ps5000aGetStreamingLatestValues(
+ self._chandle, self._callbackFuncPtr, None
+ )
+
+ now = time.perf_counter()
+ elapsed_since_start = now - main_loop_start
+
+ # Handle statuses according to PicoSDK documentation
+ # These are all normal during streaming - only PICO_NOT_RESPONDING
+ # and unexpected errors are treated as abnormal
+ if status == PICO_STATUS["PICO_BUFFER_STALL"]:
+ with self._stream_stats_lock:
+ self._hardware_overflow_count += 1
+ logger.warning("Hardware buffer overflow - data lost")
+ elif status == PICO_STATUS["PICO_NOT_RESPONDING"]:
+ not_responding_count += 1
+ if not_responding_count == 1:
+ # First occurrence - log diagnostic info
+ with self._stream_stats_lock:
+ cb_count = self._stream_callback_count
+ sample_count = self._stream_total_samples
+ logger.error(
+ "PICO_NOT_RESPONDING (status 7) first detected after {:.3f}s. "
+ "Callbacks so far: {}, Samples: {}. "
+ "This may indicate hardware initialization timing issue.",
+ elapsed_since_start,
+ cb_count,
+ sample_count,
+ )
+
+ if not_responding_count >= 3:
+ # Device has stopped responding persistently - treat as fatal
+ with self._stream_stats_lock:
+ self._streaming_error = "PICO_NOT_RESPONDING"
+ final_cb_count = self._stream_callback_count
+ final_sample_count = self._stream_total_samples
+ logger.error(
+ "Picoscope device {} stopped responding (status 7 - PICO_NOT_RESPONDING) "
+ "after {} consecutive failures. Total time: {:.3f}s, "
+ "Total callbacks: {}, Total samples: {}. Streaming will stop.",
+ self.device_id,
+ not_responding_count,
+ elapsed_since_start,
+ final_cb_count,
+ final_sample_count,
+ )
+ break # Exit the producer loop
+ else:
+ logger.debug(
+ "Picoscope device {} returned PICO_NOT_RESPONDING (attempt {}, {:.3f}s elapsed)",
+ self.device_id,
+ not_responding_count,
+ elapsed_since_start,
+ )
+ elif status in [
+ PICO_STATUS["PICO_OK"],
+ PICO_STATUS["PICO_NO_SAMPLES_AVAILABLE"],
+ PICO_STATUS["PICO_BUSY"],
+ PICO_STATUS["PICO_DATA_NOT_AVAILABLE"],
+ PICO_STATUS["PICO_DRIVER_FUNCTION"],
+ ]:
+ # All normal statuses during streaming - reset not_responding_count
+ if not_responding_count > 0:
+ logger.info(
+ "Device {} responding again after {} NOT_RESPONDING errors "
+ "(at {:.3f}s elapsed)",
+ self.device_id,
+ not_responding_count,
+ elapsed_since_start,
+ )
+ not_responding_count = 0
+ else:
+ # Unexpected status - log warning (throttled)
+ if (
+ not hasattr(self, "_last_status_warning_time")
+ or now - self._last_status_warning_time > 5.0
+ ):
+ logger.warning(
+ "GetStreamingLatestValues returned unexpected status {} after {:.3f}s",
+ status,
+ elapsed_since_start,
+ )
+ self._last_status_warning_time = now
+
+ # Log callback stats periodically
+ with self._stream_stats_lock:
+ current_callbacks = self._stream_callback_count
+ current_samples = self._stream_total_samples
+
+ if current_callbacks > last_callback_seen:
+ last_callback_seen = current_callbacks
+ # Log every 100 callbacks or every 5 seconds
+ if current_callbacks % 100 == 1 or now - last_status_log_time > 5.0:
+ logger.debug(
+ "Producer stats @ {:.3f}s: callbacks={}, samples={}, "
+ "rate={:.1f} samples/s",
+ elapsed_since_start,
+ current_callbacks,
+ current_samples,
+ current_samples / elapsed_since_start
+ if elapsed_since_start > 0
+ else 0,
+ )
+ last_status_log_time = now
+
+ time.sleep(0.001)
+
+ except Exception:
+ logger.exception(
+ "Error in producer loop after {:.3f}s",
+ time.perf_counter() - main_loop_start,
+ )
+ finally:
+ total_elapsed = time.perf_counter() - main_loop_start
+ logger.debug(
+ "Producer loop finished for %s. Total time: %.3fs",
+ self.device_id,
+ total_elapsed,
+ )
+
+ def _streaming_callback(
+ self,
+ _handle: int,
+ noOfSamples: int,
+ startIndex: int,
+ overvoltage_flags: int,
+ _triggerAt: int,
+ _triggered: int,
+ _autoStop: int,
+ _param: int,
+ ) -> None:
+ """SDK callback: copy data from SDK buffers to the ring buffer.
+
+ This is called from the SDK's internal thread when new data is available.
+ Data is written to fixed column positions (0=Channel A, 1=Channel B).
+ Disabled channels have zeros written to their column.
+ """
+ # Debug: log first few callbacks
+ with self._stream_stats_lock:
+ self._stream_callback_count += 1
+ if self._stream_callback_count <= 5:
+ logger.debug(
+ "SDK callback #%s: %s samples",
+ self._stream_callback_count,
+ noOfSamples,
+ )
+
+ if self._stop_streaming_event.is_set():
+ return
+
+ if overvoltage_flags:
+ with self._stream_stats_lock:
+ self._overvoltage_count += 1
+ logger.warning("ADC overvoltage detected (signal clipping)")
+
+ if noOfSamples <= 0:
+ return
+
+        with self._stream_stats_lock:
+            # The callback counter was already incremented at the top of this
+            # callback, so only the sample total is updated here.
+            self._stream_total_samples += noOfSamples
+
+ if self._ring_buffer is None:
+ logger.error("Ring buffer not initialised")
+ return
+
+ if (
+ self._downsample_mode == "AGGREGATE"
+ and self._interleaved_buffer is not None
+ ):
+ # AGGREGATE mode: SDK provides separate min/max buffers
+ # We interleave them into output as [min1, max1, min2, max2, ...]
+ # This produces 2x the samples compared to other modes
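+            # e.g. mins=[1, 3], maxs=[2, 4] interleave to the column [1, 2, 3, 4]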
+ src_start = startIndex
+ total_output_samples = noOfSamples * 2
+
+ # Use pre-allocated buffer, slice to needed size
+ output_data = self._callback_buffer[:total_output_samples, :]
+ output_data[...] = 0 # Clear only viewed elements (not entire buffer)
+
+ for ch in self._streaming_channels:
+ sdk_buf_max = self._sdk_buffers[ch]
+
+ # Get min and max slices
+ mins = self._sdk_buffer_min[src_start : src_start + noOfSamples]
+ maxs = sdk_buf_max[src_start : src_start + noOfSamples]
+
+ # Interleave into output buffer: [min1, max1, min2, max2, ...]
+ output_data[0::2, ch] = mins
+ output_data[1::2, ch] = maxs
+ # Disabled channels remain as zeros
+ else:
+ # Normal mode: copy enabled channel data to fixed position
+ output_data = self._callback_buffer[:noOfSamples, :]
+ output_data[...] = 0 # Clear only viewed elements (not entire buffer)
+ for ch in self._streaming_channels:
+ sdk_buf = self._sdk_buffers[ch]
+ output_data[:, ch] = sdk_buf[startIndex : startIndex + noOfSamples]
+ # Disabled channels remain as zeros
+
+ self._ring_buffer.write(output_data)
+
+ if self._user_callback is not None:
+ try:
+ self._user_callback(
+ {
+ "type": "data_ready",
+ "samples": noOfSamples,
+ "channels": self._streaming_channels,
+ "total_samples": self._stream_total_samples,
+ }
+ )
+ except Exception:
+ logger.exception("Error in user callback")
diff --git a/picostream/dfplot.py b/picostream/dfplot.py
index c9cca03..21b2473 100644
--- a/picostream/dfplot.py
+++ b/picostream/dfplot.py
@@ -1,660 +1,883 @@
from __future__ import annotations
+import argparse
import sys
-import threading
import time
-from typing import List, Optional
-import argparse
+from typing import Any, Callable, Dict, List, Optional
-import h5py
import numpy as np
-import pyqtgraph as pg
from loguru import logger
-from PyQt5.QtCore import Qt, QTimer
-from PyQt5.QtGui import QCloseEvent, QFont, QKeyEvent
-from PyQt5.QtWidgets import (
+from PyQt6.QtCore import Qt, QTimer
+from PyQt6.QtGui import QCloseEvent, QKeyEvent
+from PyQt6.QtWidgets import (
QApplication,
- QHBoxLayout,
- QLabel,
QVBoxLayout,
QWidget,
)
-
-from .conversion_utils import adc_to_mV, min_max_decimate_numba
-
-
-class HDF5LivePlotter(QWidget):
- """
- Real-time oscilloscope-style plotter that reads from HDF5 files.
- Completely independent of acquisition system for zero-risk operation.
-
- Supports both standard acquisition files and live-only mode files with
- automatic buffer reset detection and graceful handling of file size changes.
+from vispy.app import use_app
+
+use_app("pyqt6")
+
+from vispy import scene  # noqa: E402
+
+from picostream.acquisition_rate import AcquisitionRate # noqa: E402
+from picostream.data_pipeline import DataPipeline, PipelineConfig # noqa: E402
+from picostream.ring_buffer import RingBuffer # noqa: E402
+
+TARGET_PLOT_POINTS = 20000
+
+
+class LivePlotter(QWidget):
+ """Real-time oscilloscope-style plotter that reads from a RingBuffer.
+
+ This plotter displays live streaming data directly from memory, completely
+ decoupled from the file saving mechanism. It displays both channels on a
+ single axis with data scaling to match voltage ranges.
+
+ Parameters
+ ----------
+ ring_buffer : Optional[RingBuffer]
+ The ring buffer to read data from. None indicates waiting for stream.
+ channels : List[int]
+ List of active channel indices (e.g., [0, 1] for channels A and B).
+ voltage_ranges : Dict[int, float]
+ Mapping of channel index to voltage range in Volts.
+ offsets_v : Dict[int, float]
+ Mapping of channel index to voltage offset in Volts.
+ resolution : int
+ ADC resolution in bits (e.g., 12, 14, 15, 16).
+ downsample_mode : str
+ Downsample mode string (e.g., "NONE", "AGGREGATE").
+ update_interval_ms : int, optional
+ How often to update the plot in milliseconds. Default is 50ms (20 Hz).
+ display_window_seconds : float, optional
+ Time duration of data to display. Default is 5.0 seconds.
+ decimation_factor : int, optional
+ Factor for min-max decimation to reduce plotting points. Default is 150.
+ target_plot_points : int, optional
+ Target number of points to display on screen for adaptive decimation.
+ Default is 20000.
+ on_read_position : Optional[Callable[[int], None]], optional
+ Callback invoked with the sample index after each read from ring buffer.
+ Used by producer for overflow detection.
+
+ Attributes
+ ----------
+ ring_buffer : Optional[RingBuffer]
+ Current ring buffer reference. Updated via set_ring_buffer().
+ acquisition_rate : Optional[AcquisitionRate]
+ Single source of truth for all rate calculations.
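+
+    Examples
+    --------
+    A minimal sketch (assumes a Qt event loop and a configured
+    ``PicoscopeBufferedStream`` named ``dev``)::
+
+        plotter = LivePlotter(
+            ring_buffer=dev.ring_buffer,
+            channels=[0],
+            voltage_ranges={0: 2.0},
+            offsets_v={0: 0.0},
+            resolution=12,
+            downsample_mode="NONE",
+        )
+        plotter.show()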
"""
def __init__(
self,
- hdf5_path: str = "/tmp/data.hdf5",
+ ring_buffer: Optional[RingBuffer],
+ channels: List[int],
+ voltage_ranges: Dict[int, float],
+ offsets_v: Dict[int, float],
+ resolution: int,
+ downsample_mode: str,
update_interval_ms: int = 50,
- display_window_seconds: float = 0.5,
+ display_window_seconds: float = 5.0,
decimation_factor: int = 150,
- y_min: Optional[float] = None,
- y_max: Optional[float] = None,
+ target_plot_points: int = 20000,
+ on_read_position: Optional[Callable[[int], None]] = None,
+ downsample_ratio: int = 1,
) -> None:
- """Initializes the HDF5LivePlotter window.
-
- Parameters
- ----------
- hdf5_path : str
- Path to the HDF5 file to monitor.
- update_interval_ms : int
- How often to check the file for updates (in ms).
- display_window_seconds : float
- The time duration of data to display.
- decimation_factor : int
- The factor by which to decimate data for plotting.
- y_min : float, optional
- Minimum Y-axis limit in mV. If None, auto-range is used.
- y_max : float, optional
- Maximum Y-axis limit in mV. If None, auto-range is used.
- """
+ """Initialise the LivePlotter with metadata and optional ring buffer."""
super().__init__()
- # --- Configuration ---
- self.hdf5_path: str = hdf5_path
+ self.ring_buffer: Optional[RingBuffer] = ring_buffer
+ self._on_read_position: Optional[Callable[[int], None]] = on_read_position
+
+ self.channels: List[int] = channels
+ self.voltage_ranges: Dict[int, float] = voltage_ranges
+ self.offsets_v: Dict[int, float] = offsets_v
+ self.resolution: int = resolution
+ self.downsample_mode: str = downsample_mode
+ self.downsample_ratio: int = downsample_ratio
+
self.update_interval_ms: int = update_interval_ms
self.display_window_seconds: float = display_window_seconds
self.decimation_factor: int = decimation_factor
- self.y_min: Optional[float] = y_min
- self.y_max: Optional[float] = y_max
+        self.target_plot_points: int = target_plot_points
+
+        # PicoScope 5000a series uses a 16-bit ADC internally and always returns
+        # max_adc = 32767 (2^15 - 1) regardless of resolution setting.
+        # Higher resolutions use more effective bits but the hardware ADC range stays the same.
+        # This is documented in the hardware API and handled in picostream/device.py.
+        self.max_adc_val: int = 32767
- # --- UI State ---
- self.heartbeat_chars: List[str] = ["|", "/", "-", "\\"]
- self.heartbeat_index: int = 0
self.is_saturated: bool = False
- # --- Data Buffers ---
+ self._display_cache: Dict[int, np.ndarray] = {}
+ self._last_read_position: int = 0
+ self._cache_ready: bool = False
+ self._current_plot_decimation: int = 1
+ self._cache_dirty: bool = False
+ self._cached_time_axis: Optional[np.ndarray] = None
+ self._target_time_axis: Optional[np.ndarray] = None
+
self.display_data: np.ndarray = np.array([])
self.time_data: np.ndarray = np.array([])
self.data_start_sample: int = 0
- # --- Buffer Reset Detection ---
- self.last_file_size: int = 0
- self.buffer_reset_count: int = 0
- self.total_samples_processed: int = 0
-
- # --- HDF5 Metadata ---
- self.sample_interval_ns: float = 16.0 # Default, will be read from file
- self.hardware_downsample_ratio: int = 1
- self.ch_range: Optional[int] = None
- self.voltage_range_v: Optional[float] = None
- self.max_adc_val: Optional[int] = None
- self.downsample_mode: Optional[str] = None
- self.analog_offset_v: float = 0.0
- self.metadata_read: bool = False
-
- # --- Debug Counters ---
+ self.v_div_volts: Dict[int, float] = {0: 0.1, 1: 0.1}
+ self.y_pos_divs: Dict[int, float] = {0: 0.0, 1: 0.0}
+
self.update_count: int = 0
- self.file_read_count: int = 0
self.display_update_count: int = 0
+ self.conversion_error_count: int = 0
- # --- Performance Monitoring ---
self.display_latency_ms: float = 0.0
self.last_data_timestamp: Optional[float] = None
- # --- Rate Checking ---
self.rate_check_start_time: Optional[float] = None
self.rate_check_start_samples: int = 0
- # --- Data Freshness Tracking ---
- self.last_displayed_size: int = 0
- self.data_change_count: int = 0
- self.stale_update_count: int = 0
- self.last_freshness_check: float = time.time()
+ self.pipeline: Optional[DataPipeline] = None
+ self.acquisition_rate: Optional[AcquisitionRate] = None
- # --- Error Tracking ---
- self.conversion_error_count: int = 0
- self.file_error_count: int = 0
+ self._channel_colors = {0: (0.0, 0.75, 1.0, 1.0), 1: (1.0, 0.27, 0.27, 1.0)}
- # Setup UI
- self.setup_ui()
+ self._pos_buffers: Dict[int, np.ndarray] = {}
+ self._time_axis_float32: Optional[np.ndarray] = None
- # Setup update timer, which is controlled externally
- self.timer = QTimer()
- self.timer.timeout.connect(self.update_from_file)
+ self.setup_ui()
logger.info(
- f"HDF5LivePlotter initialized: path={hdf5_path}, interval={update_interval_ms}ms"
+ "LivePlotter initialised: {} channel(s), interval={}ms",
+ len(self.channels),
+ self.update_interval_ms,
)
- # Initial file check
- self.check_file_exists()
+ self.timer = QTimer()
+ self.timer.timeout.connect(self.update_from_buffer)
+
+ # Set ring buffer to initialize pipeline if provided
+ if ring_buffer is not None:
+ self.set_ring_buffer(ring_buffer)
- def reset_for_new_file(self) -> None:
- """Reset all state in preparation for a new file.
-
- This should be called BEFORE set_hdf5_path() when starting a new acquisition
- to ensure no stale metadata or state is carried over.
+ @property
+ def hardware_rate_hz(self) -> float:
+ """Hardware sample rate in Hz (total across all channels).
+
+ Returns
+ -------
+ float
+ The hardware rate from acquisition_rate.
+
+ Raises
+ ------
+ RuntimeError
+ If called before acquisition_rate is set.
"""
- logger.debug("Resetting plotter state for new file")
-
- # Clear display
- self.curve.setData([], [])
-
- # Reset all data buffers
- self.display_data = np.array([])
- self.time_data = np.array([])
- self.data_start_sample = 0
-
- # Reset buffer tracking
- self.last_file_size = 0
- self.buffer_reset_count = 0
- self.total_samples_processed = 0
-
- # Clear cached metadata - this is critical!
- self.voltage_range_v = None
- self.max_adc_val = None
- self.downsample_mode = None
- self.analog_offset_v = 0.0
- self.sample_interval_ns = 16.0
- self.hardware_downsample_ratio = 1
- self.metadata_read = False
-
- # Reset rate checking
- self.rate_check_start_time = None
- self.rate_check_start_samples = 0
- self.last_displayed_size = 0
-
- # Reset counters
- self.update_count = 0
- self.file_read_count = 0
- self.display_update_count = 0
- self.data_change_count = 0
- self.stale_update_count = 0
-
- # Reset status
- self.is_saturated = False
- self.last_data_timestamp = None
- self.display_latency_ms = 0.0
-
- def set_hdf5_path(self, hdf5_path: str) -> None:
- """Sets the HDF5 file path to monitor.
-
- Note: Call reset_for_new_file() BEFORE this method when starting a new
- acquisition to ensure clean state.
+ if self.acquisition_rate is not None:
+ return self.acquisition_rate.hardware_rate_hz
+ msg = "hardware_rate_hz called before acquisition_rate was set"
+ raise RuntimeError(msg)
+
+ def _allocate_display_buffers(self, max_points: int) -> None:
+ """Allocate or reallocate GPU position buffers for all channels.
+
+ Parameters
+ ----------
+ max_points : int
+ Maximum number of points to display.
"""
- self.hdf5_path = hdf5_path
- logger.info(f"Plotter path updated to: {hdf5_path}")
-
- # Don't check file or read metadata here - wait for update_from_file()
- # to be called, which will only read metadata once the file exists with valid data
+        # Each buffer holds one (x, y) float32 row per point for a VisPy
+        # Line. Only reallocate when the requested size actually grows; the
+        # previous comparison against max_points * 2 forced a reallocation
+        # on every call because the buffer only ever holds max_points rows.
+        for ch in [0, 1]:
+            if ch not in self._pos_buffers or len(self._pos_buffers[ch]) < max_points:
+                self._pos_buffers[ch] = np.empty((max_points, 2), dtype=np.float32)
+
+ def set_acquisition_rate(self, rate: AcquisitionRate) -> None:
+ """Set the acquisition rate and reinitialise the pipeline.
- def set_display_window(self, window_seconds: float) -> None:
- """Sets the temporal display window width.
-
Parameters
----------
- window_seconds : float
- The time duration of data to display in seconds.
+ rate : AcquisitionRate
+ The acquisition rate with all rate information.
"""
- self.display_window_seconds = window_seconds
- logger.info(f"Display window set to {window_seconds}s")
+ self.acquisition_rate = rate
+ self._initialise_pipeline(rate)
+ logger.info("Plotter acquisition rate set: {}", rate)
+
+ def _initialise_pipeline(self, acquisition_rate: AcquisitionRate) -> None:
+ """Initialise or reinitialise the data pipeline with current config."""
+ config = PipelineConfig(
+ resolution=self.resolution,
+ voltage_ranges=self.voltage_ranges,
+ offsets_v=self.offsets_v,
+ max_adc_value=self.max_adc_val,
+ target_plot_points=self.target_plot_points,
+ )
+ self.pipeline = DataPipeline(config, acquisition_rate)
+
+ def _update_pipeline_sample_rate(self) -> None:
+ """Update the plotter from ring buffer's acquisition rate."""
+ if self.ring_buffer is not None:
+ rate = self.ring_buffer.get_acquisition_rate()
+ if rate is not None:
+ self.set_acquisition_rate(rate)
+
+ def set_ring_buffer(self, ring_buffer: Optional[RingBuffer]) -> None:
+ """Update the ring buffer reference.
+
+ Called when the stream stops/starts to reconnect the plotter to
+ a new or restarted buffer. Resets position tracking and clears
+ the display cache.
- def set_y_limits(self, y_min: Optional[float], y_max: Optional[float]) -> None:
- """Sets fixed Y-axis limits or enables auto-ranging.
-
Parameters
----------
- y_min : float, optional
- Minimum Y-axis limit in mV. If None, auto-range is used.
- y_max : float, optional
- Maximum Y-axis limit in mV. If None, auto-range is used.
+ ring_buffer : Optional[RingBuffer]
+ The new ring buffer, or None to indicate waiting for stream.
"""
- self.y_min = y_min
- self.y_max = y_max
-
- has_fixed_limits = y_min is not None and y_max is not None
-
- if has_fixed_limits:
- self.plot_widget.setYRange(y_min, y_max, padding=0)
- self.plot_widget.disableAutoRange(axis='y')
- logger.info(f"Y-axis limits set to [{y_min}, {y_max}] mV")
+ self.ring_buffer = ring_buffer
+ self.last_position = 0
+
+ self._display_cache.clear()
+ self._last_read_position = 0
+ self._cache_ready = False
+ self._current_plot_decimation = 1
+ self._cached_time_axis = None
+ self._time_axis_float32 = None
+
+ if ring_buffer is not None:
+ self._update_pipeline_sample_rate()
+
+ if ring_buffer is None:
+ logger.debug("LivePlotter: set to waiting state (no buffer)")
else:
- self.plot_widget.enableAutoRange(axis='y')
- logger.info("Y-axis auto-ranging enabled")
+ logger.debug("LivePlotter: connected to new ring buffer")
- def start_updates(self) -> None:
- """Starts the plot update timer."""
- self.timer.start(self.update_interval_ms)
-    def stop_updates(self) -> None:
-        """Stops the plot update timer."""
-        self.timer.stop()
+
+ def clear_curves(self) -> None:
+ """Clear all plot curves.
+
+ Called when starting new acquisition to clear any previous
+ data from the display.
+ """
+ for line in self.lines.values():
+ line.set_data(pos=np.zeros((2, 2), dtype=np.float32))
+ logger.debug("LivePlotter: curves cleared")
+
+ def _update_graticule(self) -> None:
+ """Update the oscilloscope-style graticule to match current view bounds.
+
+ Draws a fixed 10x10 division graticule constrained to the data region
+ with the centre crosshair emphasised.
+ """
+ x_min = 0.0
+ x_max = self.display_window_seconds
+ y_min = -5.0
+ y_max = 5.0
+
+ # 11 vertical lines (0 to 10 divisions)
+ # 11 horizontal lines (0 to 10 divisions)
+ # Each line is 2 points, stored as segments
+ num_v_lines = 11
+ num_h_lines = 11
+ total_segments = num_v_lines + num_h_lines
+
+ graticule_pos = np.zeros((total_segments * 2, 2), dtype=np.float32)
+
+ # Vertical lines (bottom to top)
+ for i in range(num_v_lines):
+ x = x_min + (x_max - x_min) * i / 10.0
+ base_idx = i * 2
+ graticule_pos[base_idx] = [x, y_min]
+ graticule_pos[base_idx + 1] = [x, y_max]
+
+ # Horizontal lines
+ for i in range(num_h_lines):
+ y = y_min + (y_max - y_min) * i / 10.0
+ base_idx = (num_v_lines + i) * 2
+ graticule_pos[base_idx] = [x_min, y]
+ graticule_pos[base_idx + 1] = [x_max, y]
+
+ self.graticule.set_data(pos=graticule_pos)
+
+        # Centre crosshair: emphasised lines through the middle of the
+        # graticule (y = 0 horizontally, x at mid-window vertically)
+        x_mid = (x_min + x_max) / 2.0
+        centre_pos = np.array(
+            [
+                [x_min, 0.0],
+                [x_max, 0.0],
+                [x_mid, y_min],
+                [x_mid, y_max],
+            ],
+            dtype=np.float32,
+        )
+ self.graticule_center.set_data(pos=centre_pos)
+
+ # Border rectangle
+ border_pos = np.array(
+ [
+ [x_min, y_min],
+ [x_max, y_min],
+ [x_max, y_max],
+ [x_min, y_max],
+ [x_min, y_min],
+ ],
+ dtype=np.float32,
+ )
+ self.border.set_data(pos=border_pos)
+
+ def set_display_window(self, window_seconds: float) -> None:
+ """Set the temporal display window width.
- def hide_status_bar(self, hide: bool = True) -> None:
- """Hide or show the status bar.
-
Parameters
----------
- hide : bool
- If True, hide the status bar. If False, show it.
+ window_seconds : float
+ Duration in seconds to display.
"""
- if hasattr(self, 'status_container'):
- self.status_container.setVisible(not hide)
+ self.display_window_seconds = window_seconds
+
+ # Invalidate cache when display window changes - time axis must be recalculated
+ self._cache_ready = False
+ self._cached_time_axis = None
+ self._time_axis_float32 = None
+ self._target_time_axis = None
+
+ self.view.camera.set_range(
+ x=(0, self.display_window_seconds),
+ y=(-5, 5),
+ )
+ self._update_graticule()
+
+ logger.info("Display window set to {}s", window_seconds)
+
+ def start_updates(self) -> None:
+ """Start the plot update timer."""
+ self.timer.start(self.update_interval_ms)
+
+ def stop_updates(self) -> None:
+ """Stop the plot update timer."""
+ self.timer.stop()
def setup_ui(self) -> None:
- """Sets up the main window, widgets, and plot layout."""
+ """Set up the main window, widgets, and plot layout."""
layout = QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
- # Status bar
- self.status_layout = QHBoxLayout()
- self.heartbeat_label = QLabel("UI: -")
- self.samples_label = QLabel("Samples: 0")
- self.rate_label = QLabel("Rate: -")
- self.plotter_latency_label = QLabel("Plotter Latency: 0 ms")
- self.error_label = QLabel("Errors: 0")
- self.saturation_label = QLabel("Saturation: -")
- self.acq_status_label = QLabel(
- '<span style="color: orange">Waiting for file...</span>'
+ self.canvas = scene.SceneCanvas(keys="interactive", show=False)
+ self.canvas.create_native()
+ self.canvas.native.setMinimumHeight(200)
+
+ self.view = self.canvas.central_widget.add_view()
+ self.view.camera = "panzoom"
+ self.view.camera.set_range(
+ x=(0, self.display_window_seconds),
+ y=(-5, 5),
)
- font = QFont()
- font.setFamily("Monospace")
- font.setFixedPitch(True)
- for label in [
- self.heartbeat_label,
- self.samples_label,
- self.rate_label,
- self.plotter_latency_label,
- self.error_label,
- self.saturation_label,
- self.acq_status_label,
- ]:
- label.setFont(font)
-
- # Add separators between status items
- self.status_layout.addWidget(self.heartbeat_label)
- self.status_layout.addWidget(QLabel(" | "))
- self.status_layout.addWidget(self.error_label)
- self.status_layout.addWidget(QLabel(" | "))
- self.status_layout.addWidget(self.saturation_label)
- self.status_layout.addWidget(QLabel(" | "))
- self.status_layout.addWidget(self.samples_label)
- self.status_layout.addWidget(QLabel(" | "))
- self.status_layout.addWidget(self.plotter_latency_label)
- self.status_layout.addWidget(QLabel(" | "))
- self.status_layout.addWidget(self.rate_label)
- self.status_layout.addWidget(QLabel(" | "))
- self.status_layout.addWidget(self.acq_status_label)
- self.status_layout.addStretch()
- self.status_container = QWidget()
- self.status_container.setLayout(self.status_layout)
- layout.addWidget(self.status_container)
-
- # Plot widget
- self.plot_widget = pg.PlotWidget()
- self.plot_widget.setLabel("left", "Voltage", "mV")
- self.plot_widget.setLabel("bottom", "Time", "s")
- self.plot_widget.showGrid(x=True, y=True)
- self.plot_widget.setXRange(0, self.display_window_seconds, padding=0)
-
- # Apply Y-axis limits if provided
- has_fixed_y_limits = self.y_min is not None and self.y_max is not None
- if has_fixed_y_limits:
- self.plot_widget.setYRange(self.y_min, self.y_max, padding=0)
- self.plot_widget.disableAutoRange(axis='y')
-
- # Plot curve
- self.curve = self.plot_widget.plot(pen="y", width=1)
-
- layout.addWidget(self.plot_widget)
-
- # Performance optimization
- self.plot_widget.setDownsampling(mode="peak")
- self.plot_widget.setClipToView(True)
-
- def check_file_exists(self) -> None:
- """Checks if the HDF5 file exists and attempts to read metadata.
-
- Only reads metadata if it hasn't been read yet for this file.
- """
- if self.metadata_read:
- return
-
- try:
- with h5py.File(self.hdf5_path, "r") as f:
- if "adc_counts" in f:
- self.acq_status_label.setText(
- '<span style="color: orange">Reading metadata...</span>'
- )
- self.read_metadata(f)
- else:
- self.acq_status_label.setText(
- '<span style="color: orange">Waiting for data...</span>'
- )
- except (FileNotFoundError, OSError):
- self.acq_status_label.setText(
- '<span style="color: orange">Waiting for file...</span>'
+
+ # Graticule rendered first (order=0) so it appears behind traces
+ self.graticule = scene.Line(
+ pos=np.zeros((0, 2), dtype=np.float32),
+ color=(0.25, 0.25, 0.25, 1.0),
+ parent=self.view.scene,
+ connect="segments",
+ width=1,
+ )
+ self.graticule.order = 0
+ self.graticule_center = scene.Line(
+ pos=np.zeros((0, 2), dtype=np.float32),
+ color=(0.4, 0.4, 0.4, 1.0),
+ parent=self.view.scene,
+ connect="segments",
+ width=1.5,
+ )
+ self.graticule_center.order = 0
+ self.border = scene.Line(
+ pos=np.zeros((5, 2), dtype=np.float32),
+ color=(0.7, 0.7, 0.7, 1.0),
+ parent=self.view.scene,
+ width=1.5,
+ )
+ self.border.order = 0
+ self._update_graticule()
+
+ # Signal traces rendered after (order=1) so they appear on top
+ self.lines: Dict[int, Any] = {}
+ for ch in [0, 1]:
+ self.lines[ch] = scene.Line(
+ pos=np.zeros((2, 2), dtype=np.float32),
+ color=self._channel_colors[ch],
+ parent=self.view.scene,
+ width=1,
)
+ self.lines[ch].order = 1
+ self.lines[ch].visible = ch in self.channels
+
+ self.canvas.native.setSizePolicy(
+ self.canvas.native.sizePolicy().Policy.Expanding,
+ self.canvas.native.sizePolicy().Policy.Expanding,
+ )
+ self.canvas.native.setFocusPolicy(Qt.FocusPolicy.NoFocus)
+
+ layout.addWidget(self.canvas.native, stretch=1)
- def read_metadata(self, hdf5_file: h5py.File) -> None:
- """Reads metadata attributes from the root of an open HDF5 file.
+ def set_v_div_settings(
+ self, v_div_volts: Dict[int, float], y_pos_divs: Dict[int, float]
+ ) -> None:
+ """Set V/div and Y-position settings for oscilloscope-style scaling.
Parameters
----------
- hdf5_file : h5py.File
- An open h5py.File object.
+ v_div_volts : Dict[int, float]
+ Mapping of channel index to volts per division.
+ y_pos_divs : Dict[int, float]
+ Mapping of channel index to Y-position in divisions from centre.
"""
- if self.metadata_read:
- logger.debug("Metadata already read, skipping")
- return
-
- try:
- # Metadata is stored as root-level attributes
- base_sample_interval_ns = hdf5_file.attrs["sample_interval_ns"]
- self.hardware_downsample_ratio = hdf5_file.attrs.get(
- "hardware_downsample_ratio", 1
- )
- self.sample_interval_ns = (
- base_sample_interval_ns * self.hardware_downsample_ratio
+ old_v_div = self.v_div_volts
+ old_y_pos = self.y_pos_divs
+ self.v_div_volts = v_div_volts.copy()
+ self.y_pos_divs = y_pos_divs.copy()
+
+ if old_v_div != self.v_div_volts or old_y_pos != self.y_pos_divs:
+ self._cache_dirty = True
+
+ logger.debug("V/div settings: {}, Y-pos: {}", v_div_volts, y_pos_divs)
+
+ def is_channel_clipped(self, channel: int) -> int:
+ """Check if a channel's data exceeds the current display range.
+
+ This checks for display clipping, where signal data goes beyond the
+ visible Y-axis range set by V/div and Y-pos settings. It does NOT
+ indicate hardware ADC saturation.
+
+ Returns
+ -------
+ int
+ 1 if clipping above display range, -1 if clipping below range, 0 if OK.
+ """
+ if channel not in self._display_cache:
+ return 0
+
+ divisions = self._display_cache[channel]
+ if len(divisions) == 0:
+ return 0
+
+        # The cache stores data in divisions, and the visible range is fixed
+        # at ±5 divisions (see setup_ui), so compare directly against ±5.
+        sample_size = min(10000, len(divisions))
+        sample_divisions = divisions[:sample_size]
+
+        if np.any(sample_divisions > 5.0):
+            return 1
+        if np.any(sample_divisions < -5.0):
+            return -1
+        return 0
+
+ def set_v_div_data(self, channel: int, v_div_volts: float) -> None:
+ """Store V/div for a channel to be used in data conversion.
+
+ Parameters
+ ----------
+ channel : int
+ Channel index.
+ v_div_volts : float
+ Volts per division for this channel.
+ """
+ self.v_div_volts[channel] = v_div_volts
+
+ def _read_and_decimate_new_data(self) -> bool:
+ """Read new data since last frame and update display cache.
+
+ This method implements incremental reading from the ring buffer,
+ using the DataPipeline for processing.
+
+ Returns
+ -------
+ bool
+ True if new data was processed and cache updated.
+ False if no new data available.
+
+ Raises
+ ------
+ RuntimeError
+ If ring buffer position has been overwritten (overrun).
+ """
+ if self.ring_buffer is None or self.pipeline is None:
+ return False
+
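+        # Snapshot the writer position first: this frame reads up to that
+        # index, and the device is told our progress so it can flag overflows.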
+ snapshot = self.ring_buffer.get_snapshot()
+
+ self._notify_device_of_read_position(snapshot)
+
+ window_samples = self.pipeline.time_to_samples(self.display_window_seconds)
+ self._current_plot_decimation = self.pipeline.calculate_decimation(
+ window_samples
+ )
+
+ if self.update_count % 20 == 0:
+ logger.debug(
+ f"Buffer: snapshot={snapshot}, write_idx={self.ring_buffer.write_idx}, "
+ f"capacity={self.ring_buffer.capacity}, "
+ f"last_pos={self._last_read_position}, "
+ f"cache_ready={self._cache_ready}, "
+ f"adaptive_decimation={self._current_plot_decimation}"
)
- self.voltage_range_v = hdf5_file.attrs["voltage_range_v"]
- self.max_adc_val = hdf5_file.attrs["max_adc"]
- self.downsample_mode = hdf5_file.attrs.get("downsample_mode", "average")
- self.analog_offset_v = hdf5_file.attrs.get("analog_offset_v", 0.0)
-
- self.metadata_read = True
-
- logger.info(
- f"Plotter read metadata: voltage_range_v={self.voltage_range_v}V, "
- f"max_adc={self.max_adc_val}, downsample_mode={self.downsample_mode}, "
- f"analog_offset_v={self.analog_offset_v}V"
+ if self.pipeline.should_invalidate_cache(self._current_plot_decimation):
+ self.pipeline.invalidate_cache()
+ self._cache_ready = False
+ self._cached_time_axis = None
+ self._time_axis_float32 = None
+ self._display_cache.clear()
+ logger.debug(
+ f"Decimation factor changed, cache invalidated. "
+ f"New decimation: {self._current_plot_decimation}"
)
- # Update rate label with configured sample rate
- configured_rate_sps = 1e9 / self.sample_interval_ns
- self.rate_label.setText(
- f"Rate: ... : {self._format_rate_sps(configured_rate_sps)}"
+ if not self._cache_ready:
+ if window_samples <= 0:
+ return False
+
+ data = self.ring_buffer.read_last(window_samples)
+ if len(data) == 0:
+ return False
+
+ self._initialise_cache_from_data(data, self._current_plot_decimation)
+ self._last_read_position = snapshot
+ self._cache_ready = True
+ return True
+
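+        # If the writer has lapped our last read position, the unread samples
+        # were overwritten and an incremental read would silently drop data,
+        # so treat this as a fatal overrun.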
+ if not self.ring_buffer.is_valid_range(self._last_read_position):
+ raise RuntimeError(
+ f"Plotter overrun: last read position {self._last_read_position} "
+ f"overwritten (buffer capacity {self.ring_buffer.capacity} samples). "
+ "The plotter cannot keep up with the data rate."
)
- except KeyError as e:
- logger.debug(f"Metadata not fully available yet: {e}. Will retry.")
-
- def _update_heartbeat(self) -> None:
- """Update UI heartbeat to show the UI thread is alive."""
- self.heartbeat_index = (self.heartbeat_index + 1) % len(self.heartbeat_chars)
- self.heartbeat_label.setText(
- f"UI: {self.heartbeat_chars[self.heartbeat_index]}"
- )
- def _handle_new_data(
- self, dataset: h5py.Dataset, start_index: int, current_size: int
- ) -> None:
- """Process a new window of data."""
- self.data_change_count += 1
- self.last_displayed_size = current_size
- self.last_data_timestamp = time.time()
- self.acq_status_label.setText('<span style="color: green">Active</span>')
-
- # Read only the most recent data window
- data_window = dataset[start_index:current_size]
- self.file_read_count += 1
-
- # Check for ADC saturation
- # The Picoscope driver scales data to 16-bit, so saturation occurs at the 16-bit limit.
+ new_data = self.ring_buffer.read_since(self._last_read_position)
+ if len(new_data) == 0:
+ return False
+
FIXED_MAX_ADC = 32767
- self.is_saturated = np.any(data_window >= FIXED_MAX_ADC) or np.any(
- data_window <= -FIXED_MAX_ADC
- )
+ if np.any(np.abs(new_data) >= FIXED_MAX_ADC):
+ self.is_saturated = True
- logger.debug(
- f"Update {self.update_count}: Reading window of {len(data_window):,} samples from index {start_index:,}"
- )
+ self.pipeline._last_decimation = self._current_plot_decimation
- # Update the display with this complete window (only when data changes)
- self.update_display(data_window)
-
- def _handle_buffer_reset(self, new_size: int) -> None:
- """Handle detection of a buffer reset (file size decreased)."""
- self.buffer_reset_count += 1
- self.total_samples_processed += self.last_file_size
- logger.debug(f"Buffer reset detected: size {self.last_file_size:,} → {new_size:,} (reset #{self.buffer_reset_count})")
-
- # Reset rate calculation to avoid confusion
- self.rate_check_start_time = None
- self.rate_check_start_samples = 0
- self.last_displayed_size = 0
-
- def _handle_stale_data(self) -> None:
- """Handle a file check where no new data is found."""
- self.stale_update_count += 1
- self.acq_status_label.setText(
- '<span style="color: orange">Acquiring... </span>'
+ geometry = self.pipeline.get_window_display_geometry(
+ self.display_window_seconds,
+ self._current_plot_decimation,
)
- # Log if we're frequently updating with no new data
- if self.stale_update_count % 10 == 0:
- logger.debug(
- f"File check #{self.update_count} with no new data (stale checks: {self.stale_update_count})"
- )
+ max_cache_points = geometry.max_display_values
- def _update_status_labels(self, current_size: int) -> None:
- """Update the various status labels in the UI."""
- # Update samples label (include total if buffer resets occurred)
- samples_text = self.format_sample_count(current_size)
- if self.buffer_reset_count > 0:
- total_processed = self.total_samples_processed + current_size
- total_text = self.format_sample_count(total_processed)
- self.samples_label.setText(f"Samples: {samples_text} (total: {total_text})")
- else:
- self.samples_label.setText(f"Samples: {samples_text}")
-
- # Color-code latency
- latency_color = (
- "green"
- if self.display_latency_ms < 100
- else "orange"
- if self.display_latency_ms < 500
- else "red"
- )
- self.plotter_latency_label.setText(
- f'<span style="color: {latency_color}">Plotter Latency: {self.display_latency_ms:.0f}ms</span>'
- )
+ for ch in self.channels:
+ if ch >= new_data.shape[1]:
+ continue
- # Error counter
- total_errors = self.conversion_error_count + self.file_error_count
- error_color = (
- "green" if total_errors == 0 else "orange" if total_errors < 10 else "red"
- )
- self.error_label.setText(
- f'<span style="color: {error_color}">Errors: {total_errors}</span>'
- )
+ ch_samples = new_data[:, ch]
- # Saturation status
- if self.voltage_range_v is None:
- self.saturation_label.setText("Saturation: -")
- elif self.is_saturated:
- self.saturation_label.setText(
- '<span style="color: red">Saturation: CLIPPING</span>'
- )
- else:
- self.saturation_label.setText(
- '<span style="color: green">Saturation: OK</span>'
+ voltage_data, has_data = self.pipeline.process_channel_data(
+ ch_samples, ch, self._current_plot_decimation
)
- def _update_rate_label(self, current_size: int) -> None:
- """Check and update acquisition rate status."""
- if not self.rate_check_start_time:
- return
+ if not has_data:
+ continue
- elapsed_time = time.perf_counter() - self.rate_check_start_time
- if elapsed_time > 1.0: # Check only after 1s for stability
- points_per_timestep = 2 if self.downsample_mode == "aggregate" else 1
- samples_acquired = current_size - self.rate_check_start_samples
- timesteps_acquired = samples_acquired / points_per_timestep
- actual_rate_sps = timesteps_acquired / elapsed_time
+ v_div = self.v_div_volts.get(ch, 0.1)
+ y_pos = self.y_pos_divs.get(ch, 0.0)
- configured_rate_sps = 1e9 / self.sample_interval_ns
- rate_ratio = actual_rate_sps / configured_rate_sps
+ voltage_volts = voltage_data / 1000.0
+ divisions = y_pos + voltage_volts / v_div
- configured_rate_str = self._format_rate_sps(configured_rate_sps)
- actual_rate_str = self._format_rate_sps(actual_rate_sps)
+ if ch in self._display_cache:
+ combined = np.concatenate([self._display_cache[ch], divisions])
+ else:
+ combined = divisions
- rate_text = f"Rate: {actual_rate_str} : {configured_rate_str}"
- if rate_ratio < 0.95:
- self.rate_label.setText(f'<span style="color: red">{rate_text}</span>')
+ if len(combined) > max_cache_points:
+ self._display_cache[ch] = combined[-max_cache_points:]
else:
- self.rate_label.setText(rate_text)
+ self._display_cache[ch] = combined
- def _process_data_from_file(self, f: h5py.File) -> None:
- """Read and process data from an open HDF5 file."""
- if "adc_counts" not in f:
- return
+ self._last_read_position = snapshot
- dataset = f["adc_counts"]
- current_size = dataset.shape[0]
+ return True
- if current_size == 0:
- return
+ def _notify_device_of_read_position(self, read_idx: int) -> None:
+ """Notify the device of our last read position for overflow detection.
- # Read metadata if not already done
- if not self.metadata_read:
- self.read_metadata(f)
- # If metadata still not available, wait for next update
- if not self.metadata_read:
- return
+ Uses the callback registered at construction, avoiding implicit
+ coupling through ring buffer attributes.
- # Start the rate check timer on the first data point
- if self.rate_check_start_time is None:
- self.rate_check_start_time = time.perf_counter()
- self.rate_check_start_samples = current_size
+ Parameters
+ ----------
+ read_idx : int
+ The sample index we last read up to.
+ """
+ if self._on_read_position is not None:
+ try:
+ self._on_read_position(read_idx)
+            except Exception as exc:
+                logger.debug("Failed to notify device of read position: {}", exc)
- # Dynamically calculate the number of timesteps for the display window
- display_window_timesteps = int(
- self.display_window_seconds / (self.sample_interval_ns * 1e-9)
- )
+ def _initialise_cache_from_data(
+ self, data: np.ndarray, decimation_factor: int
+ ) -> None:
+ """Initialise display cache from full window of data."""
+ self._display_cache.clear()
+ self.is_saturated = False
- # In aggregate mode, each timestep has two points (min/max)
- display_window_points = display_window_timesteps
- if self.downsample_mode == "aggregate":
- display_window_points *= 2
-
- # Calculate where to start reading to get the last window of points
- start_index = max(0, current_size - display_window_points)
- self.data_start_sample = start_index
-
- # Detect buffer resets (file size decreased)
- if current_size < self.last_file_size:
- self._handle_buffer_reset(current_size)
-
- # Track data freshness and changes
- if current_size > self.last_displayed_size:
- self._handle_new_data(dataset, start_index, current_size)
- else:
- self._handle_stale_data()
-
- self.last_file_size = current_size
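+        # Drop any partial decimation group carried over from the previous
+        # stream so the rebuilt cache starts on a clean decimation boundary.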
+ if self.pipeline is not None:
+ self.pipeline._remainder.clear()
+ self.pipeline._last_decimation = decimation_factor
+
+ FIXED_MAX_ADC = 32767
+ if np.any(np.abs(data) >= FIXED_MAX_ADC):
+ self.is_saturated = True
+
+ for ch in self.channels:
+ if ch >= data.shape[1]:
+ continue
+
+ ch_data = data[:, ch]
+
+ if self.pipeline is not None:
+ voltage_data, has_data = self.pipeline.process_channel_data(
+ ch_data, ch, decimation_factor
+ )
+ else:
+ has_data = False
+ voltage_data = np.array([], dtype=float)
+
+ if not has_data:
+ self._display_cache[ch] = np.array([], dtype=float)
+ continue
- self._update_status_labels(current_size)
- self._update_rate_label(current_size)
+ v_div = self.v_div_volts.get(ch, 0.1)
+ y_pos = self.y_pos_divs.get(ch, 0.0)
- def update_from_file(self) -> None:
- """Timer-driven function to read data from the HDF5 file and update the plot."""
+ voltage_volts = voltage_data / 1000.0
+ divisions = y_pos + voltage_volts / v_div
+
+ self._display_cache[ch] = divisions
+
+ self._cache_ready = True
+ self._cache_dirty = False
+
+ def update_from_buffer(self) -> None:
+ """Timer-driven function to read from ring buffer and update the plot."""
self.update_count += 1
- self._update_heartbeat()
+
+ if self.ring_buffer is None or self.pipeline is None:
+ return
try:
- with h5py.File(self.hdf5_path, "r") as f:
- self._process_data_from_file(f)
- except (FileNotFoundError, OSError):
- self.acq_status_label.setText(
- '<span style="color: orange">Waiting for file...</span>'
+ if self.acquisition_rate is None:
+ return
+ window_samples = self.pipeline.time_to_samples(self.display_window_seconds)
+ self._current_plot_decimation = self.pipeline.calculate_decimation(
+ window_samples
)
- except Exception as e:
- self.file_error_count += 1
- logger.error(f"Update {self.update_count}: Error reading file - {e}")
- self.acq_status_label.setText('<span style="color: red">File error!</span>')
- def update_display(self, data_window: np.ndarray) -> None:
- """Processes and displays a new window of data.
+ if self._cache_dirty:
+ logger.debug("Cache dirty, rebuilding positions")
+ cache_ready_before = self._cache_ready
+ self._cache_ready = False
+ snapshot = self.ring_buffer.get_snapshot()
+ self._notify_device_of_read_position(snapshot)
+
+ data = self.ring_buffer.read_last(window_samples)
+ if len(data) > 0:
+ self._initialise_cache_from_data(
+ data, self._current_plot_decimation
+ )
+ self._last_read_position = snapshot
+ self._cache_ready = cache_ready_before
+
+ has_new_data = self._read_and_decimate_new_data()
+ if not has_new_data:
+ return
+
+ self.last_data_timestamp = time.time()
- This involves decimation, voltage conversion, and updating the plot curve.
+ if self.rate_check_start_time is None:
+ self.rate_check_start_time = time.perf_counter()
+ self.rate_check_start_samples = self.ring_buffer.write_idx
+
+ if not self._display_cache:
+ return
+
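+            # Cast the time axis to float32 once per frame; the GPU vertex
+            # data is float32, so this avoids repeated per-channel conversion.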
+ geometry = self.pipeline.get_window_display_geometry(
+ self.display_window_seconds,
+ self._current_plot_decimation,
+ )
+ time_axis_float32 = geometry.time_axis.astype(np.float32)
+ n_time_points = len(time_axis_float32)
+
+ if n_time_points == 0:
+ return
+
+ for ch in self.channels:
+ if ch not in self._display_cache:
+ continue
+
+ div_data = self._display_cache[ch]
+ n_data_points = len(div_data)
+
+ if n_data_points == 0:
+ continue
+
+ n_points_to_plot = min(n_data_points, n_time_points)
+ if n_points_to_plot <= 0:
+ continue
+
+ if n_data_points != n_time_points and self.update_count % 100 == 0:
+ logger.debug(
+ "Data/time axis mismatch: channel {} has {} data points, "
+ "time axis has {} points (geometry.max_display_values={}), plotting {}",
+ ch,
+ n_data_points,
+ n_time_points,
+ geometry.max_display_values,
+ n_points_to_plot,
+ )
+
+ pos_buffer = self._pos_buffers.get(ch)
+ if pos_buffer is None or len(pos_buffer) < n_time_points:
+ self._allocate_display_buffers(max(n_time_points * 2, 20000))
+ pos_buffer = self._pos_buffers[ch]
+
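+            # Pre-fill with NaN: VisPy breaks the line at NaN vertices, so
+            # when the cache holds fewer points than the time axis the trace
+            # renders right-aligned without a spurious connecting segment.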
+ pos_buffer[:n_time_points, :] = np.nan
+ start_idx = n_time_points - n_points_to_plot
+ pos_buffer[start_idx:n_time_points, 0] = time_axis_float32[
+ start_idx:n_time_points
+ ]
+ pos_buffer[start_idx:n_time_points, 1] = div_data[-n_points_to_plot:]
+
+ self.lines[ch].set_data(pos=pos_buffer[:n_time_points])
+
+ for line_ch in self.lines:
+ if line_ch not in self.channels:
+ self.lines[line_ch].set_data(pos=np.zeros((2, 2), dtype=np.float32))
+
+ except RuntimeError as e:
+ logger.exception("Plotter overrun detected: {}", e)
+ raise
+
+ except Exception as e:
+ logger.exception("Error updating from buffer: {}", e)
+
+ def update_display(
+ self, channel_data: Dict[int, np.ndarray], adaptive_decimation: int = 1
+ ) -> None:
+ """Process and display a new window of data for all channels.
Parameters
----------
- data_window : np.ndarray
- A NumPy array containing the raw ADC counts for display.
+ channel_data : Dict[int, np.ndarray]
+ Mapping of channel index to raw ADC data array.
+ adaptive_decimation : int
+ Adaptive decimation factor calculated based on window size.
"""
- if len(data_window) == 0:
+ if not channel_data or self.pipeline is None:
return
- self.display_data = data_window
self.display_update_count += 1
- logger.debug(
- f"Display update {self.display_update_count}: "
- f"Displaying window of {len(self.display_data):,} samples, "
- f"starting at sample {self.data_start_sample:,}"
+ geometry = self.pipeline.get_window_display_geometry(
+ self.display_window_seconds,
+ adaptive_decimation,
)
+ time_axis_float32 = geometry.time_axis.astype(np.float32)
- # Apply Numba-optimized decimation for display
- if self.decimation_factor > 1:
- decimated_data = min_max_decimate_numba(
- self.display_data, self.decimation_factor
+ for ch in self.channels:
+ raw_data = channel_data.get(ch)
+ if raw_data is None or len(raw_data) == 0:
+ continue
+
+ voltage_data, has_data = self.pipeline.process_channel_data(
+ raw_data, ch, adaptive_decimation
)
- else:
- decimated_data = self.display_data
- # Debug: Log ADC values and metadata
- logger.debug(f"ADC range: {decimated_data.min()} to {decimated_data.max()}")
- logger.debug(f"voltage_range_v: {self.voltage_range_v}")
+ if not has_data:
+ continue
- # Convert to voltage if we have calibration data
- if self.voltage_range_v is not None and self.max_adc_val is not None:
- try:
- voltage_data = adc_to_mV(
- decimated_data, self.voltage_range_v, self.max_adc_val
- )
- if self.analog_offset_v != 0.0:
- voltage_data += self.analog_offset_v * 1000
- logger.debug(
- f"Voltage conversion successful, range: {voltage_data.min():.1f} to {voltage_data.max():.1f} mV"
- )
- except Exception as e:
- self.conversion_error_count += 1
- logger.warning(f"Voltage conversion failed: {e}, using raw ADC values")
- voltage_data = decimated_data.astype(float)
- else:
- logger.warning(
- "Missing calibration data (voltage_range_v or max_adc_val), using raw ADC values"
- )
- voltage_data = decimated_data.astype(float)
+ div_data = self._convert_voltage_to_divisions(voltage_data, ch)
- # Create time axis for the current window
- time_axis = self.create_time_axis(len(voltage_data))
+ n_points = len(time_axis_float32)
- logger.debug(
- f"Display update {self.display_update_count}: "
- f"Decimated to {len(voltage_data):,} points, "
- f"time range: {time_axis[0]:.3f}s to {time_axis[-1]:.3f}s"
- )
+ pos_buffer = self._pos_buffers.get(ch)
+ if pos_buffer is None or len(pos_buffer) < n_points:
+ self._allocate_display_buffers(max(n_points * 2, 20000))
+ pos_buffer = self._pos_buffers[ch]
+
+ pos_buffer[:n_points, 0] = time_axis_float32[:n_points]
+ pos_buffer[:n_points, 1] = div_data[:n_points]
+ self.lines[ch].set_data(pos=pos_buffer[:n_points])
- # Calculate display latency
if self.last_data_timestamp:
self.display_latency_ms = (time.time() - self.last_data_timestamp) * 1000
- # Update plot
- self.curve.setData(time_axis, voltage_data)
+ if len(geometry.time_axis) > 0:
+ self.view.camera.set_range(
+ x=(0, self.display_window_seconds),
+ y=(-5, 5),
+ )
- # Update the X-axis range to match the new time axis, creating a "snapshot" effect.
- self.plot_widget.setXRange(time_axis[0], time_axis[-1], padding=0)
+ def _convert_voltage_to_divisions(
+ self, voltage_data: np.ndarray, channel: int
+ ) -> np.ndarray:
+ """Convert voltage data to oscilloscope divisions.
- # Auto-scale the Y-axis occasionally, but only if fixed limits are not set
- has_fixed_y_limits = self.y_min is not None and self.y_max is not None
- if not has_fixed_y_limits and self.display_update_count % 5 == 1:
- self.plot_widget.enableAutoRange(axis="y")
+ Parameters
+ ----------
+ voltage_data : np.ndarray
+ Voltage data in millivolts.
+ channel : int
+ Channel index.
+
+ Returns
+ -------
+ np.ndarray
+ Data in oscilloscope divisions.
+ """
+ v_div = self.v_div_volts.get(channel, 0.1)
+ y_pos = self.y_pos_divs.get(channel, 0.0)
+
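+        # Worked example (hypothetical values): with v_div = 0.1 V/div and
+        # y_pos = 1.0 div, a 50 mV sample maps to 1.0 + 0.05 / 0.1 = 1.5 div.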
+ voltage_volts = voltage_data / 1000.0
+ return y_pos + voltage_volts / v_div
def _format_rate_sps(self, rate_sps: float) -> str:
- """Formats a sample rate in Samples/sec into a human-readable string."""
+ """Format a sample rate in Samples/sec into a human-readable string.
+
+ Parameters
+ ----------
+ rate_sps : float
+ Sample rate in samples per second.
+
+ Returns
+ -------
+ str
+ Formatted rate string (e.g., "62.50 MS/s").
+ """
if rate_sps >= 1e9:
return f"{rate_sps / 1e9:.2f} GS/s"
if rate_sps >= 1e6:
@@ -664,17 +887,17 @@ class HDF5LivePlotter(QWidget):
return f"{rate_sps:.2f} S/s"
def format_sample_count(self, count: int) -> str:
- """Formats a large integer count into a human-readable string with units.
+ """Format a large integer count into a human-readable string with units.
Parameters
----------
count : int
- The integer number to format.
+ Sample count.
Returns
-------
str
- A formatted string (e.g., "1.23M", "2.34G").
+ Formatted count (e.g., "1.23M").
"""
if count >= 1_000_000_000:
return f"{count / 1_000_000_000:.2f}G"
@@ -682,125 +905,58 @@ class HDF5LivePlotter(QWidget):
return f"{count / 1_000_000:.2f}M"
if count >= 1_000:
return f"{count / 1_000:.2f}K"
- else:
- return str(count)
-
- def create_time_axis(self, n_samples: int) -> np.ndarray:
- """Creates a time axis for the displayed data window.
+ return str(count)
- The time axis is absolute, based on the data window's start position in
- the overall acquisition. It accounts for the `aggregate` downsample mode,
- where the data stream consists of interleaved min/max pairs.
-
- For min-max decimated data, it generates pairs of time coordinates to
- draw vertical lines for each min-max pair. For non-decimated data, it
- generates a linearly spaced time axis.
+ def set_channel_visible(self, channel_index: int, visible: bool) -> None:
+ """Toggle visibility of a channel's plot curve and update processing.
Parameters
----------
- n_samples : int
- The number of points for the time axis. This should be
- the number of points *after* decimation.
-
- Returns
- -------
- np.ndarray
- A NumPy array representing the time axis in seconds.
+ channel_index : int
+ Channel index (0 for A, 1 for B).
+ visible : bool
+ True to show, False to hide.
"""
- if n_samples == 0:
- return np.array([])
-
- time_per_timestep = self.sample_interval_ns * 1e-9
- points_per_timestep = 2 if self.downsample_mode == "aggregate" else 1
- time_per_point = time_per_timestep / points_per_timestep
-
- start_time = (self.data_start_sample / points_per_timestep) * time_per_timestep
-
- if self.decimation_factor > 1:
- # For min-max, create pairs of time points for vertical lines
- num_pairs = n_samples // 2
- time_step_between_groups = self.decimation_factor * time_per_point
- group_times = start_time + np.arange(num_pairs) * time_step_between_groups
- return np.repeat(group_times, 2)
- else:
- # For non-decimated data
- if self.downsample_mode == "aggregate":
- # Data is already min/max pairs from hardware. Create vertical lines.
- num_pairs = n_samples // 2
- time_step_between_pairs = time_per_timestep
- pair_times = start_time + np.arange(num_pairs) * time_step_between_pairs
- return np.repeat(pair_times, 2)
- else:
- # For linear (non-aggregate) data, create a simple time axis
- duration = (
- (n_samples - 1) * time_per_point if n_samples > 1 else 0
- )
- end_time = start_time + duration
- return np.linspace(start_time, end_time, n_samples)
-
- def save_screenshot(self) -> None:
- """Save a screenshot of the current plot."""
- try:
- import pyqtgraph.exporters
-
- timestamp = time.strftime("%Y%m%d_%H%M%S")
- filename = f"plot_screenshot_{timestamp}.png"
-
- # Export the plot widget as an image
- exporter = pg.exporters.ImageExporter(self.plot_widget.plotItem)
- exporter.export(filename)
-
- logger.info(f"Screenshot saved: {filename}")
+ if channel_index not in self.lines:
+ return
- # Temporarily stop updates to show the message for 2 seconds
- self.timer.stop()
- self.acq_status_label.setText(
- f'<span style="color: blue">Screenshot saved: {filename}</span>'
+ self.lines[channel_index].visible = visible
+
+ if visible and channel_index not in self.channels:
+ self.channels.append(channel_index)
+ self._cache_ready = False
+ logger.debug("Channel {} enabled and added to channels list", channel_index)
+ elif not visible and channel_index in self.channels:
+ self.channels.remove(channel_index)
+ if channel_index in self._display_cache:
+ del self._display_cache[channel_index]
+ self.lines[channel_index].set_data(pos=np.zeros((2, 2), dtype=np.float32))
+ logger.debug(
+ "Channel {} disabled, removed from processing, and cache cleared",
+ channel_index,
)
- def resume_updates():
- self.update_from_file() # Update once immediately
- self.timer.start(self.update_interval_ms)
-
- QTimer.singleShot(2000, resume_updates)
-
- except Exception as e:
- logger.error(f"Failed to save screenshot: {e}")
- # If we failed, restart timer if it was already running
- if not self.timer.isActive():
- self.timer.start(self.update_interval_ms)
-
- def closeEvent(self, event: QCloseEvent) -> None:
- """Handles the window close event.
-
- Stops the plot's internal timer and allows the Qt event loop to exit.
- The main application will handle the graceful shutdown.
- """
+ def closeEvent(self, event: QCloseEvent) -> None: # pyright: ignore
+ """Handle the window close event."""
logger.info("Close event received. Stopping timer.")
self.timer.stop()
event.accept()
- def keyPressEvent(self, event: QKeyEvent) -> None:
- """Handles key presses for application control (e.g., 'Q' to quit)."""
- if event.key() in (Qt.Key_S, Qt.Key_Space, Qt.Key_F12):
- logger.info("Screenshot key pressed. Saving screenshot.")
- self.save_screenshot()
- else:
- super().keyPressEvent(event)
+ def keyPressEvent(self, event: QKeyEvent) -> None: # pyright: ignore
+ """Handle key presses for application control."""
+ super().keyPressEvent(event)
def main() -> None:
- """Standalone HDF5 live plotter."""
- # Create the QApplication instance FIRST.
+ """Standalone live plotter for testing."""
app = QApplication(sys.argv)
- parser = argparse.ArgumentParser(description="Standalone HDF5 live plotter.")
- parser.add_argument("hdf5_path", type=str, help="Path to the HDF5 file.")
+ parser = argparse.ArgumentParser(description="Standalone live plotter.")
parser.add_argument(
"--window",
type=float,
- default=0.5,
- help="Display window in seconds. [default: 0.5]",
+ default=5.0,
+ help="Display window in seconds. [default: 5.0]",
)
parser.add_argument(
"--decimation",
@@ -808,31 +964,31 @@ def main() -> None:
default=150,
help="Decimation factor for plotting. [default: 150]",
)
- parser.add_argument(
- "--y-min",
- type=float,
- help="Minimum Y-axis limit in mV.",
- )
- parser.add_argument(
- "--y-max",
- type=float,
- help="Maximum Y-axis limit in mV.",
- )
args = parser.parse_args()
logger.info("Plotter process starting")
+
+ channels = [0, 1]
+ voltage_ranges = {0: 20.0, 1: 20.0}
+ offsets_v = {0: 0.0, 1: 0.0}
+ resolution = 16
+ downsample_mode = "NONE"
+
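+    # No device in standalone mode: the plotter starts in the waiting state
+    # until a ring buffer is attached via set_ring_buffer().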
try:
- plotter = HDF5LivePlotter(
- hdf5_path=args.hdf5_path,
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=channels,
+ voltage_ranges=voltage_ranges,
+ offsets_v=offsets_v,
+ resolution=resolution,
+ downsample_mode=downsample_mode,
display_window_seconds=args.window,
decimation_factor=args.decimation,
- y_min=args.y_min,
- y_max=args.y_max,
)
plotter.show()
- sys.exit(app.exec_())
+ sys.exit(app.exec())
except Exception as e:
- logger.error(f"Error in plotter process: {e}")
+ logger.exception("Error in plotter process: {}", e)
if __name__ == "__main__":
diff --git a/picostream/main.py b/picostream/main.py
index e7eb709..c6052ba 100644
--- a/picostream/main.py
+++ b/picostream/main.py
@@ -1,525 +1,2374 @@
+"""GUI for PicoStream application using PicoscopeBufferedStream device."""
+
import os
-import re
import sys
-from typing import Any, Dict, Optional
+import threading
+import time
+from datetime import datetime
+from pathlib import Path
+from typing import Callable, Optional
+import labdaemon as ld
from loguru import logger
-from PyQt5.QtCore import QObject, QSettings, QThread, QTimer, pyqtSignal, pyqtSlot, Qt
-from PyQt5.QtGui import QCloseEvent, QFont
-from PyQt5.QtWidgets import (
+from PyQt6.QtCore import QSettings, Qt, QTimer
+from PyQt6.QtGui import QCloseEvent, QFont, QFontMetrics, QIcon, QKeyEvent
+from PyQt6.QtWidgets import (
QApplication,
- QCheckBox,
QComboBox,
QDoubleSpinBox,
QFileDialog,
QFormLayout,
+ QFrame,
+ QGridLayout,
QGroupBox,
QHBoxLayout,
QLabel,
- QLCDNumber,
QLineEdit,
QMainWindow,
+ QMessageBox,
QPushButton,
- QSpinBox,
QVBoxLayout,
QWidget,
)
-from picostream.cli import Streamer, VOLTAGE_RANGE_MAP
-from picostream.dfplot import HDF5LivePlotter
+from picostream.device import PicoscopeBufferedStream
+from picostream.dfplot import LivePlotter
+from picostream.mock_device import MockPicoscopeBufferedStream
+from picostream.zarr_viewer import ZarrViewerWindow
-class StreamerWorker(QObject):
- """Worker to run the data acquisition in a background thread."""
+class LosslessWarningDialog(QMessageBox):
+ """Custom dialog for Lossless mode warning with override option."""
- finished = pyqtSignal()
- error = pyqtSignal(str)
- stopRequested = pyqtSignal()
+ def __init__(
+ self,
+ required_points: int,
+ max_points: int,
+ refresh_rate: str,
+ parent: Optional[QWidget] = None,
+ ) -> None:
+ super().__init__(parent)
+ self.setIcon(QMessageBox.Icon.Warning)
+ self.setWindowTitle("Lossless Mode Warning")
- def __init__(self, settings: Dict[str, Any]) -> None:
- """Initialise the worker."""
- super().__init__()
- self.streamer: Optional[Streamer] = None
- self.settings = settings
+ self.setText(
+ f"Lossless mode requires {required_points:,} points, which exceeds "
+ f"the recommended limit of {max_points:,} points at {refresh_rate}.\n\n"
+ "Displaying this many points may cause performance issues."
+ )
- def run(self) -> None:
- """Run the data acquisition."""
- try:
- self.streamer = Streamer(**self.settings)
- self.streamer.run()
- except Exception as e:
- self.error.emit(str(e))
- self.finished.emit()
+ self.setInformativeText(
+ "You can:\n"
+ "• Use High Quality (~500k points) for smooth performance\n"
+ "• Force Lossless to proceed anyway\n"
+ "• Cancel to return to settings"
+ )
+
+ self._use_high_button = self.addButton(
+ "Use High Quality", QMessageBox.ButtonRole.AcceptRole
+ )
+ self._force_lossless_button = self.addButton(
+ "Force Lossless", QMessageBox.ButtonRole.ApplyRole
+ )
+ self._cancel_button = self.addButton(
+ "Cancel", QMessageBox.ButtonRole.RejectRole
+ )
+
+ self.setDefaultButton(self._use_high_button)
+
+ def user_choice(self) -> str:
+ """Return the user's choice as a string.
+
+ Returns
+ -------
+ str
+ One of: "high", "lossless", "cancel"
+ """
+ clicked = self.clickedButton()
+ if clicked == self._use_high_button:
+ return "high"
+ if clicked == self._force_lossless_button:
+ return "lossless"
+ return "cancel"
+
+
+QUALITY_POINTS = {
+ "Low": 25_000,
+ "Medium": 100_000,
+ "High": 500_000,
+ "Lossless": None, # None indicates no decimation (show every sample)
+}
- @pyqtSlot()
- def stop(self) -> None:
- """Signal the acquisition to stop."""
- if self.streamer:
- self.streamer.shutdown()
- self.finished.emit()
+# Refresh rate options: (interval_ms, max_lossless_points)
+# Limits for VisPy OpenGL rendering - significantly higher than pyqtgraph
+# Targets: 20Hz@2M, 10Hz@5M, 5Hz@10M, 2Hz@20M, 1Hz@50M points
+REFRESH_RATE_MAP = {
+ "20 Hz": (50, 2_000_000), # 2M points at 20Hz
+ "10 Hz": (100, 5_000_000), # 5M points at 10Hz
+ "5 Hz": (200, 10_000_000), # 10M points at 5Hz
+ "2 Hz": (500, 20_000_000), # 20M points at 2Hz
+ "1 Hz": (1000, 50_000_000), # 50M points at 1Hz
+}
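+# Example: REFRESH_RATE_MAP["10 Hz"] yields (100, 5_000_000), i.e. a 100 ms
+# timer interval and a 5M-point ceiling checked by the lossless warning.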
class PicoStreamMainWindow(QMainWindow):
"""The main window for the PicoStream GUI application."""
def __init__(self) -> None:
- """Initialise the main window."""
super().__init__()
self.setWindowTitle("PicoStream")
- self.setGeometry(100, 100, 1200, 600)
+ self.setGeometry(100, 100, 1400, 700)
- self.settings = QSettings("picostream", "PicoStream")
- self.thread: Optional[QThread] = None
- self.worker: Optional[StreamerWorker] = None
+        icon_path = Path(__file__).parent.parent / "assets" / "icons" / "app.ico"
+ if icon_path.exists():
+ self.setWindowIcon(QIcon(str(icon_path)))
+
+ base_font = QApplication.font()
+ base_font.setPointSize(base_font.pointSize() - 1)
+ QApplication.setFont(base_font)
+
+ self.settings = QSettings("qolcode", "PicoStream")
+ self._zarr_viewers: list[ZarrViewerWindow] = []
+ self._is_recording: bool = False
+ self._devices_connected: bool = False
+
+ self._daemon: Optional[ld.LabDaemon] = None
+ self._device: Optional[PicoscopeBufferedStream] = None
+ self._stream_check_timer: Optional[QTimer] = None
+ self._save_poll_timer: Optional[QTimer] = None
+ self._save_start_time: Optional[float] = None
+ self._save_timeout_seconds: float = 15.0
+ self._pending_save_path: Optional[str] = None
central_widget = QWidget()
self.setCentralWidget(central_widget)
main_layout = QHBoxLayout(central_widget)
+ central_widget.setFocusPolicy(Qt.FocusPolicy.StrongFocus)
- # Left panel for settings
settings_panel = QWidget()
- settings_panel.setFixedWidth(350)
+ settings_panel.setFixedWidth(320)
settings_layout = QVBoxLayout(settings_panel)
- form_layout = QFormLayout()
-
- self.sample_rate_input = QDoubleSpinBox()
- self.sample_rate_input.setRange(1, 125)
- self.sample_rate_input.setValue(62.5)
- self.sample_rate_input.setSuffix(" MS/s")
- form_layout.addRow("Sample Rate:", self.sample_rate_input)
-
- self.resolution_input = QComboBox()
- self.resolution_input.addItems(["12", "14", "15", "16", "8"])
- form_layout.addRow("Resolution (bits):", self.resolution_input)
-
- self.voltage_range_input = QComboBox()
- self.voltage_range_input.addItems(VOLTAGE_RANGE_MAP.values())
- form_layout.addRow("Voltage Range:", self.voltage_range_input)
-
- self.output_file_input = QLineEdit()
- self.output_file_input.setText("output.hdf5")
- file_browse_button = QPushButton("Browse...")
- file_browse_button.clicked.connect(self.select_output_file)
- file_layout = QHBoxLayout()
- file_layout.addWidget(self.output_file_input)
- file_layout.addWidget(file_browse_button)
- form_layout.addRow("Output File:", file_layout)
-
- self.hw_downsample_input = QSpinBox()
- self.hw_downsample_input.setRange(1, 1000)
- form_layout.addRow("HW Downsample:", self.hw_downsample_input)
-
- self.downsample_mode_input = QComboBox()
- self.downsample_mode_input.addItems(["average", "aggregate"])
- form_layout.addRow("Downsample Mode:", self.downsample_mode_input)
-
- self.offset_v_input = QDoubleSpinBox()
- self.offset_v_input.setRange(-1.0, 1.0)
- self.offset_v_input.setSingleStep(0.01)
- self.offset_v_input.setDecimals(3)
- self.offset_v_input.setSuffix(" V")
- form_layout.addRow("Offset:", self.offset_v_input)
- self.bandwidth_limiter_input = QComboBox()
- self.bandwidth_limiter_input.addItems(["Full", "20 MHz"])
- form_layout.addRow("Bandwidth:", self.bandwidth_limiter_input)
+ self.toggle_channel_a = QPushButton("Channel A: ON")
+ self.toggle_channel_a.setCheckable(True)
+ self.toggle_channel_a.setChecked(True)
+ self.toggle_channel_a.clicked.connect(self._on_toggle_channel_a)
+ self.toggle_channel_a.setToolTip("Toggle Channel A visibility (1)")
+ self.toggle_channel_a.setStyleSheet(
+ "QPushButton { background-color: #00BFFF; color: white; }"
+ "QPushButton:checked { background-color: #00BFFF; color: white; }"
+ "QPushButton:!checked { background-color: #888888; color: white; }"
+ )
+ self.toggle_channel_b = QPushButton("Channel B: ON")
+ self.toggle_channel_b.setCheckable(True)
+ self.toggle_channel_b.setChecked(True)
+ self.toggle_channel_b.clicked.connect(self._on_toggle_channel_b)
+ self.toggle_channel_b.setToolTip("Toggle Channel B visibility (2)")
+ self.toggle_channel_b.setStyleSheet(
+ "QPushButton { background-color: #FF4444; color: white; }"
+ "QPushButton:checked { background-color: #FF4444; color: white; }"
+ "QPushButton:!checked { background-color: #888888; color: white; }"
+ )
- self.live_only_checkbox = QCheckBox("Live-only (overwrite buffer)")
- self.max_buffer_input = QDoubleSpinBox()
- self.max_buffer_input.setRange(0.1, 60.0)
- self.max_buffer_input.setValue(1.0)
- self.max_buffer_input.setSuffix(" s")
- self.live_only_checkbox.stateChanged.connect(
- lambda state: self.max_buffer_input.setEnabled(state > 0)
+ self.serial_code_input = QLineEdit()
+ self.serial_code_input.setPlaceholderText("Empty (auto) / serial # / MOCK")
+        # Qt style sheets have no ::placeholder pseudo-element; the
+        # placeholder colour comes from the palette's PlaceholderText role.
+        self.serial_code_input.setStyleSheet("QLineEdit { color: black; }")
+ self.serial_code_input.setToolTip(
+ "Picoscope serial code (e.g., AW123/456).\n"
+ "Leave empty to auto-detect connected device.\n"
+ "Enter MOCK for simulated device (testing)."
)
- form_layout.addRow(self.live_only_checkbox, self.max_buffer_input)
+ self.serial_code_input.setFocusPolicy(Qt.FocusPolicy.ClickFocus)
- settings_layout.addLayout(form_layout)
+ serial_layout = QHBoxLayout()
+ serial_layout.addWidget(self.serial_code_input, stretch=1)
+ self.connect_button = QPushButton("Connect Device")
+ self.connect_button.clicked.connect(self.toggle_device_connection)
+ self.connect_button.setToolTip(
+ "Connect to or disconnect from the Picoscope device"
+ )
+ serial_layout.addWidget(self.connect_button)
+ settings_layout.addLayout(serial_layout)
- # Start/Stop buttons
self.start_button = QPushButton("Start Acquisition")
+ self.start_button.clicked.connect(self.start_acquisition)
+ self.start_button.setEnabled(False)
+ self.start_button.setToolTip("Begin data acquisition (Space)")
+
self.stop_button = QPushButton("Stop Acquisition")
self.stop_button.setEnabled(False)
-
+ self.stop_button.clicked.connect(self.stop_acquisition)
+ self.stop_button.setToolTip("Stop active acquisition (Space)")
+
button_layout = QHBoxLayout()
button_layout.addWidget(self.start_button)
button_layout.addWidget(self.stop_button)
settings_layout.addLayout(button_layout)
- # Plot settings group
- plot_settings_group = QGroupBox("Plot Settings")
- plot_settings_form = QFormLayout(plot_settings_group)
-
- self.plot_window_input = QDoubleSpinBox()
- self.plot_window_input.setRange(0.01, 10.0)
- self.plot_window_input.setValue(0.5)
- self.plot_window_input.setSingleStep(0.1)
- self.plot_window_input.setDecimals(2)
- self.plot_window_input.setSuffix(" s")
- plot_settings_form.addRow("Display Window:", self.plot_window_input)
-
- self.y_axis_auto_checkbox = QCheckBox("Auto Y-axis")
- self.y_axis_auto_checkbox.setChecked(True)
- self.y_axis_auto_checkbox.setLayoutDirection(Qt.RightToLeft)
- self.y_min_input = QDoubleSpinBox()
- self.y_min_input.setRange(-100000, 100000)
- self.y_min_input.setValue(-1000)
- self.y_min_input.setSuffix(" mV")
- self.y_min_input.setEnabled(False)
- self.y_max_input = QDoubleSpinBox()
- self.y_max_input.setRange(-100000, 100000)
- self.y_max_input.setValue(1000)
- self.y_max_input.setSuffix(" mV")
- self.y_max_input.setEnabled(False)
-
- self.y_axis_auto_checkbox.stateChanged.connect(
- lambda state: self.y_min_input.setEnabled(state == 0)
- )
- self.y_axis_auto_checkbox.stateChanged.connect(
- lambda state: self.y_max_input.setEnabled(state == 0)
- )
-
- # plot_settings_form.addRow()
- y_limits_layout = QHBoxLayout()
- y_limits_layout.addWidget(self.y_axis_auto_checkbox)
- y_limits_layout.addWidget(QLabel("Min:"))
- y_limits_layout.addWidget(self.y_min_input)
- y_limits_layout.addWidget(QLabel("Max:"))
- y_limits_layout.addWidget(self.y_max_input)
- plot_settings_form.addRow(y_limits_layout)
-
- self.apply_plot_settings_button = QPushButton("Apply Plot Settings")
- plot_settings_form.addRow(self.apply_plot_settings_button)
-
- settings_layout.addWidget(plot_settings_group)
-
- # Status display section using QGroupBox and QFormLayout
+ record_group = QGroupBox("Record Controls")
+ record_form = QFormLayout(record_group)
+
+ self.start_record_button = QPushButton("Start Recording")
+ self.start_record_button.clicked.connect(self._on_start_record)
+ self.start_record_button.setEnabled(False)
+ self.start_record_button.setToolTip("Start recording ring buffer to disk (R)")
+
+ self.finish_keep_button = QPushButton("Stop & Keep")
+ self.finish_keep_button.setStyleSheet(
+ "background-color: #4CAF50; color: white;"
+ )
+ self.finish_keep_button.clicked.connect(self._on_finish_keep)
+ self.finish_keep_button.hide()
+ self.finish_keep_button.setToolTip("Stop recording and keep the file (K)")
+
+ self.finish_discard_button = QPushButton("Stop & Discard")
+ self.finish_discard_button.setStyleSheet(
+ "background-color: #f44336; color: white;"
+ )
+ self.finish_discard_button.clicked.connect(self._on_finish_discard)
+ self.finish_discard_button.hide()
+ self.finish_discard_button.setToolTip(
+ "Stop recording and discard the file (Del)"
+ )
+
+ self.load_file_button = QPushButton("Load File")
+ self.load_file_button.clicked.connect(self._on_load_file)
+ self.load_file_button.setToolTip(
+ "Open directory containing recorded Zarr files (Ctrl+O)"
+ )
+
+ record_buttons_layout = QHBoxLayout()
+ record_buttons_layout.addWidget(self.start_record_button)
+ record_buttons_layout.addWidget(self.finish_keep_button)
+ record_buttons_layout.addWidget(self.finish_discard_button)
+ record_buttons_layout.addStretch()
+ record_buttons_layout.addWidget(self.load_file_button)
+ record_form.addRow(record_buttons_layout)
+
+ # Pre-trigger placed below action buttons
+ self.pre_trigger_input = QComboBox()
+ self.pre_trigger_input.addItems(["1s", "2s", "5s", "10s", "20s"])
+ self.pre_trigger_input.setCurrentText("10s")
+ self.pre_trigger_input.setMinimumWidth(80)
+        self.pre_trigger_input.setToolTip(
+            "Duration of buffered data to keep from before recording starts"
+        )
+ self.pre_trigger_input.currentTextChanged.connect(self._on_pre_trigger_changed)
+
+ pre_trigger_layout = QHBoxLayout()
+ pre_trigger_layout.addStretch()
+ pre_trigger_layout.addWidget(self.pre_trigger_input)
+ record_form.addRow("Pre-trigger:", pre_trigger_layout)
+
+ self.record_directory_label = QLineEdit()
+ self.record_directory_label.setReadOnly(True)
+ self.record_directory_label.setToolTip("Directory where recordings are saved")
+ self.record_directory_label.setStyleSheet(
+ "QLineEdit { background-color: transparent; border: none; }"
+ )
+ self._update_directory_label()
+
+ record_change_button = QPushButton("Change...")
+ record_change_button.clicked.connect(self._select_record_directory)
+ record_location_layout = QHBoxLayout()
+ record_location_layout.addWidget(self.record_directory_label, stretch=1)
+ record_location_layout.addWidget(record_change_button)
+
+ record_form.addRow("Save to:", record_location_layout)
+
+ settings_layout.addWidget(record_group)
+
+ acq_group = QGroupBox("Acquisition Parameters")
+ acq_grid = QGridLayout(acq_group)
+ acq_grid.setHorizontalSpacing(8)
+
+ acq_grid.addWidget(QLabel("Res:"), 0, 0)
+ self.resolution_input = QComboBox()
+ self.resolution_input.addItems(["8", "12", "14", "15", "16"])
+ self.resolution_input.setFixedWidth(60)
+ self.resolution_input.setToolTip(
+ "Resolution in bits - higher bits = higher precision"
+ )
+ self.resolution_input.currentIndexChanged.connect(self._on_resolution_changed)
+ acq_grid.addWidget(self.resolution_input, 0, 1)
+
+ acq_grid.addWidget(QLabel("Rate (p. ch.):"), 0, 2)
+ self.sample_rate_input = QComboBox()
+ self.sample_rate_input.setFixedWidth(90)
+ self.sample_rate_input.setToolTip("Samples per second per channel")
+ self.sample_rate_input.currentTextChanged.connect(self._on_sample_rate_changed)
+ acq_grid.addWidget(self.sample_rate_input, 0, 3)
+ self._init_sample_rate_dropdown()
+
+ acq_grid.addWidget(QLabel("BW:"), 1, 0)
+ self.bandwidth_limiter_input = QComboBox()
+ self.bandwidth_limiter_input.addItems(["Full", "20 MHz"])
+ self.bandwidth_limiter_input.setFixedWidth(60)
+ self.bandwidth_limiter_input.setToolTip("Apply analog bandwidth limiting")
+ acq_grid.addWidget(self.bandwidth_limiter_input, 1, 1)
+
+ acq_grid.addWidget(QLabel("Buffer:"), 1, 2)
+ self.ring_buffer_input = QComboBox()
+ self.ring_buffer_input.addItems(["5s", "10s", "30s", "60s", "120s"])
+ self.ring_buffer_input.setCurrentText("30s")
+ self.ring_buffer_input.setFixedWidth(90)
+ self.ring_buffer_input.currentTextChanged.connect(
+ self._on_buffer_duration_changed
+ )
+ self.ring_buffer_input.setToolTip("Duration of lookback buffer in seconds")
+ acq_grid.addWidget(self.ring_buffer_input, 1, 3)
+
+ acq_grid.addWidget(QLabel("Down:"), 2, 0)
+ self.hw_downsample_input = QComboBox()
+ self.hw_downsample_input.addItems(
+ ["1x", "2x", "5x", "10x", "20x", "50x", "100x"]
+ )
+ self.hw_downsample_input.setFixedWidth(60)
+ self.hw_downsample_input.setToolTip("Hardware downsampling factor")
+ acq_grid.addWidget(self.hw_downsample_input, 2, 1)
+
+ acq_grid.addWidget(QLabel("Mode:"), 2, 2)
+ self.downsample_mode_input = QComboBox()
+ self.downsample_mode_input.addItems(["none", "average", "aggregate"])
+ self.downsample_mode_input.setFixedWidth(90)
+ self.downsample_mode_input.setToolTip(
+ "Downsampling algorithm: none (disabled), average across pairs, or aggregate (min/max)"
+ )
+ acq_grid.addWidget(self.downsample_mode_input, 2, 3)
+
+ settings_layout.addWidget(acq_group)
+
+ channel_settings_group = QGroupBox("Channel Settings")
+ channel_vbox = QVBoxLayout(channel_settings_group)
+ channel_vbox.setSpacing(4)
+
+ voltage_range_values = [
+ "0.01",
+ "0.02",
+ "0.05",
+ "0.1",
+ "0.2",
+ "0.5",
+ "1.0",
+ "2.0",
+ "5.0",
+ "10.0",
+ "20.0",
+ ]
+
+ channel_a_row = QHBoxLayout()
+ channel_a_row.addWidget(self.toggle_channel_a)
+ channel_a_row.addWidget(QLabel("R:"))
+ self.voltage_range_a = QComboBox()
+ self.voltage_range_a.addItems(voltage_range_values)
+ self.voltage_range_a.setFixedWidth(55)
+ self.voltage_range_a.setToolTip("Input voltage range in volts")
+ self.voltage_range_a.currentIndexChanged.connect(
+ self._on_voltage_range_a_changed
+ )
+ channel_a_row.addWidget(self.voltage_range_a)
+ channel_a_row.addWidget(QLabel("O:"))
+ self.offset_a = QDoubleSpinBox()
+ self.offset_a.setRange(-1.0, 1.0)
+ self.offset_a.setSingleStep(0.01)
+ self.offset_a.setDecimals(2)
+ self.offset_a.setSuffix(" V")
+ self.offset_a.setFixedWidth(80)
+ self.offset_a.setToolTip("DC offset in volts")
+ self.offset_a.valueChanged.connect(self._on_offset_a_changed)
+ channel_a_row.addWidget(self.offset_a)
+ channel_vbox.addLayout(channel_a_row)
+
+ channel_b_row = QHBoxLayout()
+ channel_b_row.addWidget(self.toggle_channel_b)
+ channel_b_row.addWidget(QLabel("R:"))
+ self.voltage_range_b = QComboBox()
+ self.voltage_range_b.addItems(voltage_range_values)
+ self.voltage_range_b.setFixedWidth(55)
+ self.voltage_range_b.setToolTip("Input voltage range in volts")
+ self.voltage_range_b.currentIndexChanged.connect(
+ self._on_voltage_range_b_changed
+ )
+ channel_b_row.addWidget(self.voltage_range_b)
+ channel_b_row.addWidget(QLabel("O:"))
+ self.offset_b = QDoubleSpinBox()
+ self.offset_b.setRange(-1.0, 1.0)
+ self.offset_b.setSingleStep(0.01)
+ self.offset_b.setDecimals(2)
+ self.offset_b.setSuffix(" V")
+ self.offset_b.setFixedWidth(80)
+ self.offset_b.setToolTip("DC offset in volts")
+ self.offset_b.valueChanged.connect(self._on_offset_b_changed)
+ channel_b_row.addWidget(self.offset_b)
+ channel_vbox.addLayout(channel_b_row)
+
+ settings_layout.addWidget(channel_settings_group)
+
status_group = QGroupBox("Status")
- status_form_layout = QFormLayout(status_group)
-
- # Create status value labels with monospace font
- font = QFont()
- font.setFamily("Monospace")
- font.setFixedPitch(True)
- font.setPointSize(9)
-
- self.status_heartbeat = QLabel("-")
- self.status_heartbeat.setFont(font)
- status_form_layout.addRow("UI:", self.status_heartbeat)
-
- self.status_errors = self._create_status_label(font)
- status_form_layout.addRow("Errors:", self.status_errors)
-
- self.status_saturation = QLabel("-")
- self.status_saturation.setFont(font)
- status_form_layout.addRow("Saturation:", self.status_saturation)
-
- self.status_samples = QLCDNumber()
- self.status_samples.setDigitCount(10)
- self.status_samples.setSegmentStyle(QLCDNumber.Flat)
- self.status_samples.display("0")
- status_form_layout.addRow("Samples:", self.status_samples)
-
- self.status_latency = QLabel("-")
- self.status_latency.setFont(font)
- status_form_layout.addRow("Latency:", self.status_latency)
-
+ status_form = QFormLayout(status_group)
+ status_form.setSpacing(6)
+ status_form.setContentsMargins(8, 12, 8, 12)
+
+ self.status_state = QLabel("IDLE")
+ self.status_state.setToolTip("Current acquisition state")
+ self.status_state.setFrameStyle(QFrame.Shape.StyledPanel | QFrame.Shadow.Sunken)
+ self.status_state.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.status_state.setMinimumWidth(70)
+ self._update_state_indicator("idle")
+ status_form.addRow("State:", self.status_state)
+
+ self.status_saturation = QLabel("OK")
+ self.status_saturation.setToolTip(
+ "Hardware ADC saturation: signal exceeds voltage range setting"
+ )
+ self.status_saturation.setFrameStyle(
+ QFrame.Shape.StyledPanel | QFrame.Shadow.Sunken
+ )
+ self.status_saturation.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.status_saturation.setMinimumWidth(70)
+ self._update_saturation_indicator(False)
+ status_form.addRow("Saturation:", self.status_saturation)
+
+ self.status_record = QLabel("Idle")
+ self.status_record.setToolTip("Recording status")
+ self.status_record.setFont(QFont("Monospace", 9))
+ self.status_record.setFrameStyle(
+ QFrame.Shape.StyledPanel | QFrame.Shadow.Sunken
+ )
+ self.status_record.setStyleSheet("background-color: #E0E0E0; padding: 2px;")
+ self.status_record.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.status_record.setMinimumWidth(140)
+ status_form.addRow("Record:", self.status_record)
+
+ self.status_samples = QLabel("0")
+ self.status_samples.setToolTip("Total collected samples")
+ self.status_samples.setFont(QFont("Monospace", 10))
+ self.status_samples.setAlignment(Qt.AlignmentFlag.AlignRight)
+ self.status_samples.setMinimumWidth(70)
+ status_form.addRow("Acq. Samples:", self.status_samples)
+
+ self.status_duration = QLabel("-")
+ self.status_duration.setToolTip(
+ "Total acquisition time (Acq. Samples / acquisition rate)"
+ )
+ self.status_duration.setFont(QFont("Monospace", 10))
+ self.status_duration.setAlignment(Qt.AlignmentFlag.AlignRight)
+ self.status_duration.setMinimumWidth(70)
+ status_form.addRow("Acq. Duration:", self.status_duration)
+
self.status_rate = QLabel("-")
- self.status_rate.setFont(font)
- status_form_layout.addRow("Rate:", self.status_rate)
-
- self.status_acquisition = QLabel("Waiting for file...")
- self.status_acquisition.setFont(font)
- self.status_acquisition.setWordWrap(True)
- status_form_layout.addRow("Acquisition:", self.status_acquisition)
-
+ self.status_rate.setToolTip("Hardware sample rate (as reported by device)")
+ self.status_rate.setFont(QFont("Monospace", 10))
+ self.status_rate.setAlignment(Qt.AlignmentFlag.AlignRight)
+ self.status_rate.setMinimumWidth(100)
+ status_form.addRow("HW Rate:", self.status_rate)
+
+ self.status_lag = QLabel("-")
+ self.status_lag.setToolTip(
+ "Producer-consumer lag (how far behind the plotter is)"
+ )
+ self.status_lag.setFont(QFont("Monospace", 10))
+ self.status_lag.setAlignment(Qt.AlignmentFlag.AlignRight)
+ self.status_lag.setMinimumWidth(100)
+ status_form.addRow("Lag:", self.status_lag)
+
+ # New display status indicators
+ self.status_display_points = QLabel("-")
+ self.status_display_points.setToolTip(
+ "Number of points currently displayed in plot"
+ )
+ self.status_display_points.setFont(QFont("Monospace", 10))
+ self.status_display_points.setAlignment(Qt.AlignmentFlag.AlignRight)
+ self.status_display_points.setMinimumWidth(100)
+ status_form.addRow("Display Points:", self.status_display_points)
+
+ self.status_decimation = QLabel("-")
+ self.status_decimation.setToolTip(
+ "Current software decimation factor (1× = lossless)"
+ )
+ self.status_decimation.setFont(QFont("Monospace", 10))
+ self.status_decimation.setAlignment(Qt.AlignmentFlag.AlignRight)
+ self.status_decimation.setMinimumWidth(100)
+ status_form.addRow("Decimation:", self.status_decimation)
+
settings_layout.addWidget(status_group)
settings_layout.addStretch()
- # Right panel for the plot
- self.plotter = HDF5LivePlotter(hdf5_path=self.output_file_input.text())
-
- # Hide the plotter's status bar since we're showing it in the sidebar
- self.plotter.hide_status_bar(hide=True)
+ default_target_points = 20000
+
+ # Will be connected after device is created
+ self._plotter_read_position_callback: Optional[Callable[[int], None]] = None
+
+ self.plotter = LivePlotter(
+ ring_buffer=None,
+ channels=[0, 1],
+ voltage_ranges={
+ 0: float(self.voltage_range_a.currentText()),
+ 1: float(self.voltage_range_b.currentText()),
+ },
+ offsets_v={0: self.offset_a.value(), 1: self.offset_b.value()},
+ resolution=12,
+ downsample_mode="NONE",
+ downsample_ratio=1,
+ target_plot_points=default_target_points,
+ on_read_position=self._on_plotter_read_position,
+ )
+
+ right_panel = QWidget()
+ right_layout = QVBoxLayout(right_panel)
+ right_layout.setContentsMargins(0, 0, 0, 0)
+ right_panel.setFocusPolicy(Qt.FocusPolicy.StrongFocus)
+
+ right_layout.addWidget(self.plotter, 1)
+
+ plot_settings_scale_layout = QHBoxLayout()
+ plot_settings_scale_layout.setContentsMargins(0, 0, 0, 0)
+
+ plot_settings_group = QGroupBox("Plot Settings")
+ plot_layout = QGridLayout(plot_settings_group)
+
+ quality_label = QLabel("Quality:")
+ plot_layout.addWidget(quality_label, 0, 0)
+
+ self.quality_input = QComboBox()
+ self.quality_input.addItems(["Low", "Medium", "High", "Lossless"])
+ self.quality_input.setCurrentText("High")
+ self.quality_input.setToolTip(
+ "Plot detail level - fixed point targets: Low (~25k), Medium (~100k), High (~500k). "
+ "Lossless shows every sample but warns if display may be slow"
+ )
+ self.quality_input.currentTextChanged.connect(self._on_quality_changed)
+ plot_layout.addWidget(self.quality_input, 0, 1)
+
+ plot_layout.addWidget(QLabel("Refresh:"), 1, 0)
+ self.refresh_rate_input = QComboBox()
+ self.refresh_rate_input.addItems(list(REFRESH_RATE_MAP.keys()))
+ self.refresh_rate_input.setCurrentText("20 Hz")
+ self.refresh_rate_input.setToolTip(
+ "Plot refresh rate - lower rates allow more points in Lossless mode"
+ )
+ self.refresh_rate_input.currentTextChanged.connect(
+ self._on_refresh_rate_changed
+ )
+ plot_layout.addWidget(self.refresh_rate_input, 1, 1)
+
+ plot_layout.addWidget(QLabel("Display Window:"), 2, 0)
+ self.plot_window_input = QComboBox()
+ self.plot_window_input.addItems(["1s", "2s", "5s", "10s", "20s", "30s"])
+ self.plot_window_input.setCurrentText("5s")
+ self.plot_window_input.setToolTip(
+ "Width of display time window (used at next acquisition start)"
+ )
+ plot_layout.addWidget(self.plot_window_input, 2, 1)
+
+ plot_settings_scale_layout.addWidget(plot_settings_group, 1)
+
+ channel_a_scale_group = QGroupBox("Channel A")
+ channel_a_form = QFormLayout(channel_a_scale_group)
+
+ self.v_div_a_input = QComboBox()
+ self.v_div_a_input.addItems(
+ [
+ "0.5mV",
+ "1mV",
+ "2mV",
+ "5mV",
+ "10mV",
+ "20mV",
+ "50mV",
+ "100mV",
+ "200mV",
+ "500mV",
+ "1V",
+ "2V",
+ "5V",
+ "10V",
+ "20V",
+ ]
+ )
+ self.v_div_a_input.setCurrentText("100mV")
+ self.v_div_a_input.currentTextChanged.connect(self._on_v_div_a_changed)
+ channel_a_form.addRow("V/div:", self.v_div_a_input)
+
+ self.y_pos_a_input = QDoubleSpinBox()
+ self.y_pos_a_input.setRange(-5, 5)
+ self.y_pos_a_input.setValue(0)
+ self.y_pos_a_input.setSingleStep(0.1)
+ self.y_pos_a_input.setDecimals(1)
+ self.y_pos_a_input.setSuffix(" div")
+ self.y_pos_a_input.setToolTip("Vertical position in divisions from centre")
+ self.y_pos_a_input.valueChanged.connect(self._on_y_pos_a_changed)
+ channel_a_form.addRow("Y-pos:", self.y_pos_a_input)
+
+ self.clip_a_indicator = QLabel("")
+ self.clip_a_indicator.setVisible(False)
+ self.clip_a_indicator.setStyleSheet("font-size: 16px; font-weight: bold;")
+ self.clip_a_indicator.setToolTip(
+ "Data exceeds display range: adjust V/div or Y-pos to view full signal"
+ )
+ channel_a_form.addRow("", self.clip_a_indicator)
+
+ channel_a_scale_group.setStyleSheet(
+ "QGroupBox { color: #00BFFF; font-weight: bold; }"
+ )
+
+ channel_b_scale_group = QGroupBox("Channel B")
+ channel_b_form = QFormLayout(channel_b_scale_group)
+
+ self.v_div_b_input = QComboBox()
+ self.v_div_b_input.addItems(
+ [
+ "0.5mV",
+ "1mV",
+ "2mV",
+ "5mV",
+ "10mV",
+ "20mV",
+ "50mV",
+ "100mV",
+ "200mV",
+ "500mV",
+ "1V",
+ "2V",
+ "5V",
+ "10V",
+ "20V",
+ ]
+ )
+ self.v_div_b_input.setCurrentText("100mV")
+ self.v_div_b_input.currentTextChanged.connect(self._on_v_div_b_changed)
+ channel_b_form.addRow("V/div:", self.v_div_b_input)
+
+ self.y_pos_b_input = QDoubleSpinBox()
+ self.y_pos_b_input.setRange(-5, 5)
+ self.y_pos_b_input.setValue(0)
+ self.y_pos_b_input.setSingleStep(0.1)
+ self.y_pos_b_input.setDecimals(1)
+ self.y_pos_b_input.setSuffix(" div")
+ self.y_pos_b_input.setToolTip("Vertical position in divisions from centre")
+ self.y_pos_b_input.valueChanged.connect(self._on_y_pos_b_changed)
+ channel_b_form.addRow("Y-pos:", self.y_pos_b_input)
+
+ self.clip_b_indicator = QLabel("")
+ self.clip_b_indicator.setVisible(False)
+ self.clip_b_indicator.setStyleSheet("font-size: 16px; font-weight: bold;")
+ self.clip_b_indicator.setToolTip(
+ "Data exceeds display range: adjust V/div or Y-pos to view full signal"
+ )
+ channel_b_form.addRow("", self.clip_b_indicator)
+
+ channel_b_scale_group.setStyleSheet(
+ "QGroupBox { color: #FF4444; font-weight: bold; }"
+ )
+
+ plot_settings_scale_layout.addWidget(channel_a_scale_group, 1)
+ plot_settings_scale_layout.addWidget(channel_b_scale_group, 1)
+
+ right_layout.addLayout(plot_settings_scale_layout)
main_layout.addWidget(settings_panel)
- main_layout.addWidget(self.plotter, 1)
+ main_layout.addWidget(right_panel, 1)
- # Create a timer to sync status from plotter to main window
self.status_sync_timer = QTimer()
- self.status_sync_timer.timeout.connect(self.sync_status_from_plotter)
- # Timer will be started when acquisition begins
-
- # Connect signals
- self.start_button.clicked.connect(self.start_acquisition)
- self.stop_button.clicked.connect(self.stop_acquisition)
- self.apply_plot_settings_button.clicked.connect(self.apply_plot_settings)
+ self.status_sync_timer.timeout.connect(self._sync_status)
self.load_settings()
- def _create_status_label(self, font: QFont) -> QLabel:
- """Create a status label with error styling capability."""
- label = QLabel("0")
- label.setFont(font)
- return label
+ def _on_plotter_read_position(self, read_idx: int) -> None:
+ """Forward plotter read position to device for overflow detection.
+
+ Parameters
+ ----------
+ read_idx : int
+ The sample index the plotter last read up to.
+ """
+ if self._device is not None:
+ self._device.update_plotter_position(read_idx)
+
+ self._apply_clamp_constraints()
+
+ def _on_v_div_a_changed(self) -> None:
+ """Handle V/div change for channel A."""
+ self._update_y_axis_settings()
+
+ def _on_v_div_b_changed(self) -> None:
+ """Handle V/div change for channel B."""
+ self._update_y_axis_settings()
+
+ def _on_y_pos_a_changed(self) -> None:
+ """Handle Y-position change for channel A."""
+ self._update_y_axis_settings()
+
+ def _on_y_pos_b_changed(self) -> None:
+ """Handle Y-position change for channel B."""
+ self._update_y_axis_settings()
+
+ def _apply_plot_settings_to_plotter(self) -> None:
+ """Apply all plot settings from UI to plotter at acquisition start.
+
+ This ensures plotter settings are applied atomically when acquisition
+ starts, avoiding race conditions between UI changes and pipeline creation.
+ """
+ # Apply refresh rate
+ refresh_text = self.refresh_rate_input.currentText()
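+        # REFRESH_RATE_MAP values are (interval_ms, ...) pairs; only the
+        # interval is needed here, so the second element is discarded.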
+ interval_ms, _ = REFRESH_RATE_MAP.get(refresh_text, (50, 50000))
+ self.plotter.update_interval_ms = interval_ms
+
+ # Apply quality (target_plot_points)
+ quality = self.quality_input.currentText()
+ if quality == "Lossless":
+ self.plotter.target_plot_points = None
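+            # None disables decimation entirely: every sample is drawn.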
+ else:
+ target_points = self._get_target_points_for_quality(quality)
+ self.plotter.target_plot_points = target_points
+
+ logger.info(
+ "Applied plot settings: quality='{}', target_points={}",
+ quality,
+ self.plotter.target_plot_points,
+ )
+
+ def _update_y_axis_settings(self) -> None:
+ """Update plotter with V/div and Y-pos settings."""
+ v_div_a = self._parse_v_div(self.v_div_a_input.currentText())
+ v_div_b = self._parse_v_div(self.v_div_b_input.currentText())
+ y_pos_a = self.y_pos_a_input.value()
+ y_pos_b = self.y_pos_b_input.value()
+
+ self.plotter.set_v_div_settings(
+ {0: v_div_a, 1: v_div_b},
+ {0: y_pos_a, 1: y_pos_b},
+ )
+
+ def _parse_v_div(self, text: str) -> float:
+ """Parse V/div text to value in volts.
+
+ Parameters
+ ----------
+ text : str
+ Text like "100mV", "1V", "2V".
+
+ Returns
+ -------
+ float
+ Value in volts.
+ """
+ text = text.strip().upper()
+ if text.endswith("MV"):
+ return float(text.rstrip("MV")) / 1000.0
+ if text.endswith("V"):
+ return float(text.rstrip("V"))
+ return float(text)
+
+ def _update_state_indicator(self, state: str) -> None:
+ """Update the state indicator colour based on current state."""
+ state_upper = state.upper()
+ if state_upper == "LIVE":
+ self.status_state.setText("LIVE")
+ self.status_state.setStyleSheet(
+ "background-color: #4CAF50; color: white; font-weight: bold;"
+ )
+ elif state_upper == "RECORDING":
+ self.status_state.setText("REC")
+ self.status_state.setStyleSheet(
+ "background-color: #FF0000; color: white; font-weight: bold;"
+ )
+ elif state_upper == "ERROR":
+ self.status_state.setText("ERROR")
+ self.status_state.setStyleSheet(
+ "background-color: #f44336; color: white; font-weight: bold;"
+ )
+ elif state_upper == "CONNECTED":
+ self.status_state.setText("CONN")
+ self.status_state.setStyleSheet(
+ "background-color: #2196F3; color: white; font-weight: bold;"
+ )
+ else:
+ self.status_state.setText("IDLE")
+ self.status_state.setStyleSheet(
+ "background-color: #9E9E9E; color: white; font-weight: bold;"
+ )
+
+ def _update_saturation_indicator(self, is_saturated: bool) -> None:
+ """Update the hardware saturation indicator.
+
+ Hardware saturation indicates ADC has clipped to max/min values because
+ the input signal exceeded the voltage range setting. This is distinct from
+ display clipping, which only indicates data exceeds the current view range.
+ """
+ if is_saturated:
+ self.status_saturation.setText("CLIPPING")
+ self.status_saturation.setStyleSheet(
+ "background-color: #f44336; color: white; font-weight: bold;"
+ )
+ else:
+ self.status_saturation.setText("OK")
+ self.status_saturation.setStyleSheet(
+ "background-color: #4CAF50; color: white; font-weight: bold;"
+ )
+
+ def _update_clip_indicators(self) -> None:
+ """Update per-channel display clipping indicators.
+
+ Display clipping indicates when signal data exceeds the current Y-axis
+ view range (V/div and Y-pos settings), not hardware ADC saturation.
+ """
+ clip_a = self.plotter.is_channel_clipped(0)
+ clip_b = self.plotter.is_channel_clipped(1)
+
+ if clip_a == 1:
+ self.clip_a_indicator.setText("DATA IS ▼")
+ self.clip_a_indicator.setVisible(True)
+ elif clip_a == -1:
+ self.clip_a_indicator.setText("DATA IS ▲")
+ self.clip_a_indicator.setVisible(True)
+ else:
+ self.clip_a_indicator.setText("")
+ self.clip_a_indicator.setVisible(False)
+
+ if clip_b == 1:
+ self.clip_b_indicator.setText("DATA IS ▼")
+ self.clip_b_indicator.setVisible(True)
+ elif clip_b == -1:
+ self.clip_b_indicator.setText("DATA IS ▲")
+ self.clip_b_indicator.setVisible(True)
+ else:
+ self.clip_b_indicator.setText("")
+ self.clip_b_indicator.setVisible(False)
+
+ def _parse_buffer_duration(self) -> float:
+ """Parse buffer duration from QComboBox text (e.g., '30s' -> 30.0)."""
+ buffer_text = self.ring_buffer_input.currentText()
+ return float(buffer_text.rstrip("s")) if buffer_text else 30.0
+
+ def _parse_downsample_ratio(self) -> int:
+ """Parse downsample ratio from QComboBox text (e.g., '10x' -> 10)."""
+ downsample_text = self.hw_downsample_input.currentText()
+ return int(downsample_text.rstrip("x")) if downsample_text else 1
+
+ def _apply_clamp_constraints(self) -> None:
+ """Clamp plot window and pre-trigger to buffer duration."""
+ buffer_duration = self._parse_buffer_duration()
+
+ # For combo boxes, we disable items that exceed the buffer duration
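+        # (QComboBox uses a QStandardItemModel by default, so individual items
+        # can be disabled; e.g. a 10 s buffer greys out "20s" and "30s".)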
+ plot_window_items = ["1s", "2s", "5s", "10s", "20s", "30s"]
+ for i, item in enumerate(plot_window_items):
+ duration_val = float(item.rstrip("s"))
+ is_enabled = duration_val <= buffer_duration
+ self.plot_window_input.model().item(i).setEnabled(is_enabled)
+
+ pre_trigger_items = ["1s", "2s", "5s", "10s", "20s"]
+ for i, item in enumerate(pre_trigger_items):
+ duration_val = float(item.rstrip("s"))
+ is_enabled = duration_val <= buffer_duration
+ self.pre_trigger_input.model().item(i).setEnabled(is_enabled)
+
+ # If current selection is now disabled, clamp to largest valid value
+ current_plot_window = self.plot_window_input.currentText()
+ current_plot_val = float(current_plot_window.rstrip("s"))
+ if current_plot_val > buffer_duration:
+ # Find largest valid value
+ for item in reversed(plot_window_items):
+ if float(item.rstrip("s")) <= buffer_duration:
+ self.plot_window_input.setCurrentText(item)
+ break
+
+ current_pre_trigger = self.pre_trigger_input.currentText()
+ current_pre_val = float(current_pre_trigger.rstrip("s"))
+ if current_pre_val > buffer_duration:
+ for item in reversed(pre_trigger_items):
+ if float(item.rstrip("s")) <= buffer_duration:
+ self.pre_trigger_input.setCurrentText(item)
+ break
+
+ def _is_16bit_mode(self) -> bool:
+ """Check if current resolution is 16-bit (single channel only).
+
+ Returns
+ -------
+ bool
+ True if 16-bit resolution is selected, False otherwise.
+ """
+ return int(self.resolution_input.currentText()) == 16
+
+ def _on_toggle_channel_a(self) -> None:
+ is_checked = self.toggle_channel_a.isChecked()
+ self.toggle_channel_a.setText(
+ "Channel A: ON" if is_checked else "Channel A: OFF"
+ )
+
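+        # 16-bit resolution supports only a single active channel, so enabling
+        # one channel force-disables the other (see _is_16bit_mode).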
+ if is_checked and self._is_16bit_mode() and self.toggle_channel_b.isChecked():
+ self.toggle_channel_b.setChecked(False)
+ self._on_toggle_channel_b()
+
+ self._init_sample_rate_dropdown()
+
+ def _on_toggle_channel_b(self) -> None:
+ is_checked = self.toggle_channel_b.isChecked()
+ self.toggle_channel_b.setText(
+ "Channel B: ON" if is_checked else "Channel B: OFF"
+ )
+
+ if is_checked and self._is_16bit_mode() and self.toggle_channel_a.isChecked():
+ self.toggle_channel_a.setChecked(False)
+ self._on_toggle_channel_a()
+
+ self._init_sample_rate_dropdown()
+
+ def _update_directory_label(self) -> None:
+ """Update the directory label to show the current save location."""
+ directory = self._get_record_directory()
+ display_path = str(directory)
+ home = str(Path.home())
+ if display_path.startswith(home):
+ display_path = "~" + display_path[len(home) :]
+
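+        # Elide from the left so the most specific (trailing) path components
+        # remain visible in the narrow label.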
+ metrics = QFontMetrics(self.record_directory_label.font())
+ display_path = metrics.elidedText(
+ display_path,
+ Qt.TextElideMode.ElideLeft,
+ self.record_directory_label.width(),
+ )
+ self.record_directory_label.setText(display_path)
+ self.record_directory_label.setCursorPosition(0)
+ self.record_directory_label.setToolTip(str(directory))
+
+ def _get_record_directory(self) -> Path:
+ """Get the configured record directory, creating it if needed.
+
+ Returns
+ -------
+ Path
+ The directory path where recordings are saved.
+ """
+ directory = self.settings.value(
+ "record_directory", str(Path.home() / "picostream")
+ )
+ path = Path(directory)
+ path.mkdir(parents=True, exist_ok=True)
+ return path
+
+ def _select_record_directory(self) -> None:
+ """Open file dialog to select record directory."""
+ default_dir = str(self._get_record_directory())
+
+ path = QFileDialog.getExistingDirectory(
+ self,
+ "Select Record Directory",
+ default_dir,
+ QFileDialog.Option.ShowDirsOnly,
+ )
+
+ if path:
+ self.settings.setValue("record_directory", path)
+ self._update_directory_label()
+ logger.info("Record directory changed to: {}", path)
+
+ def _get_next_sequence_number(self, directory: str) -> int:
+ """Get the next sequence number for today's recordings in the directory.
+
+ Parameters
+ ----------
+ directory : str
+ The directory to scan for existing recordings.
+
+ Returns
+ -------
+ int
+ The next sequence number (1-based) for today.
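+
+        Notes
+        -----
+        Recordings are named ``record_YYYYMMDD_SSS_HHMMSS.zarr`` (see
+        _get_record_path); the third underscore-separated field is the
+        sequence number parsed here.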
+ """
+ today_str = datetime.now().strftime("%Y%m%d")
+ pattern = f"record_{today_str}_"
+ max_seq = 0
+
+ if not os.path.exists(directory):
+ return 1
+
+ for entry in os.listdir(directory):
+ if entry.startswith(pattern) and entry.endswith(".zarr"):
+ parts = entry.replace(".zarr", "").split("_")
+ if len(parts) >= 3:
+ try:
+ seq = int(parts[2])
+ max_seq = max(max_seq, seq)
+ except ValueError:
+ pass
+
+ return max_seq + 1
+
+ def _get_record_path(self) -> str:
+ """Generate a record path in the configured directory.
+
+ Returns
+ -------
+ str
+ Full path to the new recording file.
+ """
+ directory = self._get_record_directory()
+
+ today_str = datetime.now().strftime("%Y%m%d")
+ time_str = datetime.now().strftime("%H%M%S")
+ seq_num = self._get_next_sequence_number(str(directory))
+
+ filename = f"record_{today_str}_{seq_num:03d}_{time_str}.zarr"
+ return str(directory / filename)
+
+ def _on_start_record(self) -> None:
+ """Handle Start Recording button click."""
+ if self._device is None:
+ QMessageBox.warning(
+ self, "Record Error", "Device not available. Is acquisition running?"
+ )
+ return
+
+ record_path = self._get_record_path()
+ filename = os.path.basename(record_path)
+ logger.info("Starting recording: {}", filename)
+ self.status_record.setText(f"Starting: {filename}")
+
+ pre_trigger_text = self.pre_trigger_input.currentText()
+ lookback_seconds = float(pre_trigger_text.rstrip("s"))
+
+ # Validate pre-trigger doesn't exceed buffer duration
+ buffer_duration = self._parse_buffer_duration()
+ if lookback_seconds > buffer_duration:
+ QMessageBox.warning(
+ self,
+ "Invalid Pre-trigger",
+ f"Pre-trigger ({lookback_seconds}s) exceeds buffer duration ({buffer_duration}s).\n\n"
+ f"Please reduce pre-trigger or increase buffer duration.",
+ )
+ return
+
+ success = self._device.start_save(lookback_seconds, record_path)
+ if success:
+ self._is_recording = True
+ logger.info("Started recording to: {}", record_path)
+ self._apply_ui_state("recording")
+ else:
+ QMessageBox.warning(
+ self, "Record Error", "Failed to start recording. Check device status."
+ )
+
+ def _on_finish_keep(self) -> None:
+ """Handle Stop & Keep button click."""
+ if not self._device:
+ return
+
+ try:
+ self._update_record_status("stopping")
+ self._apply_ui_state("stopping")
+
+ self._pending_save_path = self._device.stop_save(keep=True)
+ self._save_start_time = time.time()
+
+ self._save_poll_timer = QTimer(self)
+ self._save_poll_timer.timeout.connect(self._check_save_done_keep)
+ self._save_poll_timer.start(100)
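+            # Poll every 100 ms; _check_save_done_keep stops the timer on
+            # completion or once self._save_timeout_seconds has elapsed.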
+ except Exception as e:
+ logger.exception("Error in _on_finish_keep: {}", e)
+ self._update_record_status("error")
+ self._apply_ui_state("live")
+ QMessageBox.warning(self, "Error", f"Error stopping recording:\n{e}")
+
+ def _on_finish_discard(self) -> None:
+ """Handle Stop & Discard button click."""
+ if not self._device:
+ return
+
+ try:
+ self._update_record_status("stopping")
+ self._apply_ui_state("stopping")
+
+ self._device.stop_save(keep=False)
+ self._save_start_time = time.time()
+
+ self._save_poll_timer = QTimer(self)
+ self._save_poll_timer.timeout.connect(self._check_save_done_discard)
+ self._save_poll_timer.start(100)
+ except Exception as e:
+ logger.exception("Error in _on_finish_discard: {}", e)
+ self._update_record_status("error")
+ self._apply_ui_state("live")
+ QMessageBox.warning(self, "Error", f"Error stopping recording:\n{e}")
+
+ def _check_save_done_keep(self) -> None:
+ """Poll for save completion (keep mode)."""
+ try:
+ # Check for timeout
+ if self._save_start_time and (
+ time.time() - self._save_start_time > self._save_timeout_seconds
+ ):
+ self._save_poll_timer.stop()
+ logger.error(
+ "Save operation timed out after {} seconds",
+ self._save_timeout_seconds,
+ )
+ QMessageBox.warning(
+ self,
+ "Save Timeout",
+ f"Save operation timed out after {self._save_timeout_seconds} seconds.\n\n"
+ "The file may be incomplete.",
+ )
+ self._is_recording = False
+ self._update_record_status("error")
+ self._save_poll_timer = None
+ self._apply_ui_state("live")
+ return
+
+ if self._device.is_save_finished():
+ self._save_poll_timer.stop()
+ path = self._pending_save_path
+ if path:
+ QMessageBox.information(
+ self, "Record Complete", f"File saved to:\n{path}"
+ )
+ save_dir = os.path.dirname(path)
+ self.settings.setValue("last_save_directory", save_dir)
+ logger.info("Recording finished and kept: {}", path)
+ else:
+ logger.info("Recording finished (no path returned)")
+
+ self._is_recording = False
+ self._update_record_status("complete")
+ self._save_poll_timer = None
+ self._apply_ui_state("live")
+ except Exception as e:
+ logger.exception("Error in save poll (keep mode): {}", e)
+ if self._save_poll_timer:
+ self._save_poll_timer.stop()
+ self._save_poll_timer = None
+ self._is_recording = False
+ self._update_record_status("error")
+ QMessageBox.warning(
+ self, "Error", f"Error during save completion check:\n{e}"
+ )
+ self._apply_ui_state("error")
+
+ def _check_save_done_discard(self) -> None:
+ """Poll for save completion (discard mode)."""
+ try:
+ # Check for timeout
+ if self._save_start_time and (
+ time.time() - self._save_start_time > self._save_timeout_seconds
+ ):
+ self._save_poll_timer.stop()
+ logger.error(
+ "Save operation timed out after {} seconds",
+ self._save_timeout_seconds,
+ )
+ QMessageBox.warning(
+ self,
+ "Save Timeout",
+ f"Save operation timed out after {self._save_timeout_seconds} seconds.\n\n"
+ "Operation cancelled.",
+ )
+ self._is_recording = False
+ self._update_record_status("error")
+ self._save_poll_timer = None
+ self._apply_ui_state("live")
+ return
+
+ if self._device.is_save_finished():
+ self._save_poll_timer.stop()
+ self._is_recording = False
+ self._update_record_status("idle")
+ self._save_poll_timer = None
+ self._apply_ui_state("live")
+ logger.info("Recording discarded")
+ except Exception as e:
+ logger.exception("Error in save poll (discard mode): {}", e)
+ if self._save_poll_timer:
+ self._save_poll_timer.stop()
+ self._save_poll_timer = None
+ self._is_recording = False
+ self._update_record_status("error")
+ QMessageBox.warning(
+ self, "Error", f"Error during save completion check:\n{e}"
+ )
+ self._apply_ui_state("error")
+
+ def _update_record_status(self, state: str) -> None:
+ """Update the record status indicator.
+
+ Parameters
+ ----------
+ state : str
+            One of: "idle", "stopping", "error", "complete". Unrecognised
+            states leave the indicator unchanged.
+ """
+ if state == "stopping":
+ self.status_record.setText("Stopping...")
+ self.status_record.setStyleSheet(
+ "background-color: #FF9800; color: white; font-weight: bold;"
+ )
+ elif state == "idle":
+ self.status_record.setText("Idle")
+ self.status_record.setStyleSheet("background-color: #E0E0E0; color: black;")
+ elif state == "complete":
+ self.status_record.setText("Saved")
+ self.status_record.setStyleSheet(
+ "background-color: #4CAF50; color: white; font-weight: bold;"
+ )
+ elif state == "error":
+ self.status_record.setText("Error")
+ self.status_record.setStyleSheet(
+ "background-color: #f44336; color: white; font-weight: bold;"
+ )
+
+ def _apply_ui_state(self, state: str) -> None:
+ """Apply a complete, consistent UI state.
+
+ Centralises all UI state changes to prevent partial/inconsistent states.
+ Valid states: idle, connected, live, recording, stopping, error.
+
+ Parameters
+ ----------
+ state : str
+ The target UI state.
+ """
+ state = state.lower()
+ logger.debug("Applying UI state: {}", state)
- def _set_status_error_style(self, label: QLabel, is_error: bool) -> None:
- """Apply error styling (red background) to a status label."""
- if is_error:
- label.setStyleSheet("background-color: #ffcccc; padding: 2px;")
+ if state == "idle":
+ self.connect_button.setEnabled(True)
+ self.connect_button.setText("Connect Device")
+ self.start_button.setEnabled(False)
+ self.stop_button.setEnabled(False)
+ self.start_record_button.setEnabled(False)
+ self._show_record_buttons(recording=False)
+ self._set_hardware_inputs_enabled(True)
+ self.quality_input.setEnabled(True)
+ self.refresh_rate_input.setEnabled(True)
+ self.plot_window_input.setEnabled(True)
+ self._update_state_indicator("idle")
+ elif state == "connected":
+ self.connect_button.setEnabled(True)
+ self.connect_button.setText("Disconnect Device")
+ self.start_button.setEnabled(True)
+ self.stop_button.setEnabled(False)
+ self.start_record_button.setEnabled(False)
+ self._show_record_buttons(recording=False)
+ self.serial_code_input.setEnabled(False)
+ self.resolution_input.setEnabled(False)
+ self.toggle_channel_a.setEnabled(True)
+ self.toggle_channel_b.setEnabled(True)
+ self.voltage_range_a.setEnabled(True)
+ self.voltage_range_b.setEnabled(True)
+ self.offset_a.setEnabled(True)
+ self.offset_b.setEnabled(True)
+ self._set_acq_params_enabled(True)
+ self.quality_input.setEnabled(True)
+ self.refresh_rate_input.setEnabled(True)
+ self.plot_window_input.setEnabled(True)
+ self._update_state_indicator("connected")
+ elif state == "live":
+ self.connect_button.setEnabled(False)
+ self.connect_button.setText("Disconnect Device")
+ self.start_button.setEnabled(False)
+ self.stop_button.setEnabled(True)
+ self.start_record_button.setEnabled(True)
+ self._show_record_buttons(recording=False)
+ self._set_hardware_inputs_enabled(False)
+ self.quality_input.setEnabled(False)
+ self.refresh_rate_input.setEnabled(False)
+ self.plot_window_input.setEnabled(False)
+ self._update_state_indicator("live")
+ elif state == "recording":
+ self.connect_button.setEnabled(False)
+ self.connect_button.setText("Disconnect Device")
+ self.start_button.setEnabled(False)
+ self.stop_button.setEnabled(False)
+ self.finish_keep_button.setEnabled(True)
+ self.finish_discard_button.setEnabled(True)
+ self._show_record_buttons(recording=True)
+ self._set_hardware_inputs_enabled(False)
+ self.quality_input.setEnabled(False)
+ self.refresh_rate_input.setEnabled(False)
+ self.plot_window_input.setEnabled(False)
+ self._update_state_indicator("recording")
+ elif state == "stopping":
+ self.connect_button.setEnabled(False)
+ self.connect_button.setText("Disconnect Device")
+ self.start_button.setEnabled(False)
+ self.stop_button.setEnabled(False)
+ self.finish_keep_button.setEnabled(False)
+ self.finish_discard_button.setEnabled(False)
+ self.quality_input.setEnabled(False)
+ self.refresh_rate_input.setEnabled(False)
+ self.plot_window_input.setEnabled(False)
+ self._update_state_indicator("recording")
+ elif state == "error":
+ self.connect_button.setEnabled(True)
+ if self._devices_connected:
+ self.connect_button.setText("Disconnect Device")
+ else:
+ self.connect_button.setText("Connect Device")
+ self.start_button.setEnabled(False)
+ self.stop_button.setEnabled(False)
+ self.start_record_button.setEnabled(False)
+ self._show_record_buttons(recording=False)
+ self.serial_code_input.setEnabled(not self._devices_connected)
+ self.resolution_input.setEnabled(not self._devices_connected)
+ self.toggle_channel_a.setEnabled(True)
+ self.toggle_channel_b.setEnabled(True)
+ self.voltage_range_a.setEnabled(True)
+ self.voltage_range_b.setEnabled(True)
+ self.offset_a.setEnabled(True)
+ self.offset_b.setEnabled(True)
+ self._set_acq_params_enabled(True)
+ self.quality_input.setEnabled(True)
+ self.refresh_rate_input.setEnabled(True)
+ self.plot_window_input.setEnabled(True)
+ self._update_state_indicator("error")
else:
- label.setStyleSheet("")
-
- def select_output_file(self) -> None:
- """Open a dialog to select the output HDF5 file."""
- file_name, _ = QFileDialog.getSaveFileName(
- self, "Select Output File", self.output_file_input.text(), "HDF5 Files (*.hdf5)"
- )
- if file_name:
- self.output_file_input.setText(file_name)
-
- def apply_plot_settings(self) -> None:
- """Apply plot settings to the live plotter."""
- window_seconds = self.plot_window_input.value()
- self.plotter.set_display_window(window_seconds)
-
- if self.y_axis_auto_checkbox.isChecked():
- self.plotter.set_y_limits(None, None)
+ logger.warning("Unknown UI state: {}", state)
+
+ def _show_record_buttons(self, recording: bool) -> None:
+ """Toggle between Start Recording and Stop buttons.
+
+ Parameters
+ ----------
+ recording : bool
+ True to show keep/discard buttons, False to show start button.
+ """
+ if recording:
+ self.start_record_button.hide()
+ self.finish_keep_button.show()
+ self.finish_discard_button.show()
+ else:
+ self.start_record_button.show()
+ self.finish_keep_button.hide()
+ self.finish_discard_button.hide()
+
+ def _on_load_file(self) -> None:
+ """Handle Load File button click - open Zarr viewer."""
+ default_dir = self._get_record_directory()
+ last_dir = self.settings.value("last_save_directory", str(default_dir))
+
+ path = QFileDialog.getExistingDirectory(
+ self,
+ "Open Zarr File",
+ last_dir,
+ QFileDialog.Option.ShowDirsOnly,
+ )
+
+ if not path:
+ return
+
+ try:
+ viewer = ZarrViewerWindow(path)
+ viewer.show()
+ self._zarr_viewers.append(viewer)
+ viewer.destroyed.connect(
+ lambda: (
+ self._zarr_viewers.remove(viewer)
+ if viewer in self._zarr_viewers
+ else None
+ )
+ )
+ logger.info("Opened Zarr viewer for: {}", path)
+ except Exception as e:
+ logger.exception("Failed to open Zarr file: {}", e)
+ QMessageBox.critical(self, "Error", f"Failed to open file:\n{e}")
+
+ def _set_acq_params_enabled(self, enabled: bool) -> None:
+ """Enable or disable acquisition parameters (sample rate, downsample, etc).
+
+ Parameters
+ ----------
+ enabled : bool
+ True to enable acq params, False to disable them.
+ """
+ self.sample_rate_input.setEnabled(enabled)
+ self.hw_downsample_input.setEnabled(enabled)
+ self.downsample_mode_input.setEnabled(enabled)
+ self.bandwidth_limiter_input.setEnabled(enabled)
+ self.ring_buffer_input.setEnabled(enabled)
+
+ def _set_hardware_inputs_enabled(self, enabled: bool) -> None:
+ """Enable or disable all hardware setting input fields.
+
+ Parameters
+ ----------
+ enabled : bool
+ True to enable all inputs, False to disable all inputs.
+ """
+ self.serial_code_input.setEnabled(enabled)
+ self.resolution_input.setEnabled(enabled)
+ self.toggle_channel_a.setEnabled(enabled)
+ self.toggle_channel_b.setEnabled(enabled)
+ self.voltage_range_a.setEnabled(enabled)
+ self.voltage_range_b.setEnabled(enabled)
+ self.offset_a.setEnabled(enabled)
+ self.offset_b.setEnabled(enabled)
+ self._set_acq_params_enabled(enabled)
+
+ def toggle_device_connection(self) -> None:
+ """Connect or disconnect from the hardware device."""
+ if self._devices_connected:
+ self._disconnect_device()
else:
- y_min = self.y_min_input.value()
- y_max = self.y_max_input.value()
- self.plotter.set_y_limits(y_min, y_max)
+ self._connect_device()
+
+ def _connect_device(self) -> None:
+ """Connect to the Picoscope device."""
+ logger.debug(
+ f"[THREAD] _connect_device called from thread: {threading.current_thread().name}"
+ )
+
+ serial_code = self.serial_code_input.text().strip()
+ is_mock = serial_code.upper() == "MOCK"
+
+ if is_mock:
+ logger.warning("Using MOCK device (test mode)")
+
+ try:
+ if is_mock:
+ self._daemon = None
+ self._device = MockPicoscopeBufferedStream(
+ device_id="picoscope_stream",
+ serial_code="MOCK",
+ resolution=int(self.resolution_input.currentText()),
+ )
+ self._device.connect()
+ else:
+ logger.debug("[THREAD] Creating LabDaemon...")
+ self._daemon = ld.LabDaemon()
+ self._daemon.register_plugins(
+ devices={"PicoscopeBufferedStream": PicoscopeBufferedStream},
+ tasks={},
+ )
+
+ self._device = self._daemon.add_device(
+ device_id="picoscope_stream",
+ device_type="PicoscopeBufferedStream",
+ resolution=int(self.resolution_input.currentText()),
+ )
+
+ if serial_code:
+ self._device.set_connection_id(serial_code)
+
+ self._daemon.connect_device(self._device, timeout=10.0)
+
+ self._devices_connected = True
+ logger.info("Device connected successfully")
+ self._apply_ui_state("connected")
+
+ except ld.DeviceTimeoutError as e:
+ logger.exception("Device connection timed out")
+ QMessageBox.critical(
+ self,
+ "Connection Timeout",
+ f"Device connection timed out:\n\n{e}\n\nCheck hardware connections and power.",
+ )
+ self._cleanup_device()
+ except Exception as e:
+ logger.exception("Failed to connect to device")
+ QMessageBox.critical(
+ self,
+ "Connection Error",
+ f"Failed to connect to device:\n\n{e}",
+ )
+ self._cleanup_device()
+
+ def _disconnect_device(self) -> None:
+ """Disconnect from the Picoscope device."""
+ logger.debug(
+ f"[THREAD] _disconnect_device called from thread: {threading.current_thread().name}"
+ )
+
+ try:
+ if self._stream_check_timer:
+ self._stream_check_timer.stop()
+ self._stream_check_timer = None
+
+ if hasattr(self, "status_sync_timer"):
+ self.status_sync_timer.stop()
+
+ if self._device and self._device.is_saving:
+ try:
+ self._device.stop_save(keep=True)
+ except Exception:
+ logger.exception("Error stopping save during disconnect")
+ finally:
+ self._is_recording = False
+
+ if self._device and self._device.is_streaming():
+ try:
+ self._device.stop_streaming()
+ except Exception:
+ logger.exception("Error stopping streaming during disconnect")
+
+ self.plotter.stop_updates()
+ self.plotter.set_ring_buffer(None)
+
+ self._cleanup_device()
+
+ logger.info("Device disconnected")
+
+ finally:
+ self._is_recording = False
+ self._devices_connected = False
+ self._apply_ui_state("idle")
def start_acquisition(self) -> None:
- """Start the background data acquisition."""
- self.start_button.setEnabled(False)
- self.stop_button.setEnabled(True)
+ """Start data acquisition (requires device to be connected)."""
+ logger.debug(
+ f"[THREAD] start_acquisition called from thread: {threading.current_thread().name}"
+ )
- output_file = self.output_file_input.text()
+ if not self._devices_connected or self._device is None:
+ QMessageBox.warning(
+ self,
+ "Not Connected",
+ "Please connect to device before starting acquisition.",
+ )
+ return
+
+ # Validate Lossless mode if selected
+ if not self._validate_lossless_at_start():
+ return
- # Stop plotter updates and clear all state BEFORE changing file path
self.plotter.stop_updates()
-
- # Clear the plotter's display and reset all internal state
- self.plotter.reset_for_new_file()
-
- # Ensure status bar remains hidden after reset
- self.plotter.hide_status_bar(hide=True)
-
- # Remove old file if it exists
- if os.path.exists(output_file):
- try:
- os.remove(output_file)
- logger.info(f"Removed existing file: {output_file}")
- except OSError as e:
- self.on_acquisition_error(f"Failed to remove old file: {e}")
- self.on_acquisition_finished()
- return
+ self.plotter.clear_curves()
+ self.plotter.set_ring_buffer(None)
+
+ self.plotter.rate_check_start_time = None
+ self.plotter.rate_check_start_samples = 0
+
+ enabled_channels = []
+ voltage_ranges = []
+ offsets_v = []
+
+ if self.toggle_channel_a.isChecked():
+ enabled_channels.append(0)
+ voltage_ranges.append(float(self.voltage_range_a.currentText()))
+ offsets_v.append(self.offset_a.value())
+
+ if self.toggle_channel_b.isChecked():
+ enabled_channels.append(1)
+ voltage_ranges.append(float(self.voltage_range_b.currentText()))
+ offsets_v.append(self.offset_b.value())
+
+ if not enabled_channels:
+ QMessageBox.warning(
+ self, "Configuration Error", "At least one channel must be enabled."
+ )
+ return
+
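+        # rstrip strips a trailing character *set* (" ", "M", "S", "/", "s"),
+        # which safely removes the unit suffix here because the numeric part
+        # contains only digits and a dot.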
+ sample_rate_msps = float(self.sample_rate_input.currentText().rstrip(" MS/s"))
+ sample_rate_hz = sample_rate_msps * 1e6
+
+ downsample_mode = self.downsample_mode_input.currentText()
+ downsample_ratio = self._parse_downsample_ratio()
+
+ self.plotter.set_channel_visible(0, 0 in enabled_channels)
+ self.plotter.set_channel_visible(1, 1 in enabled_channels)
+
+ try:
+ buffer_duration_s = self._parse_buffer_duration()
+
+ self._device.configure_streaming(
+ sample_rate=sample_rate_hz,
+ channels=enabled_channels,
+ voltage_ranges=voltage_ranges,
+ buffer_duration_s=buffer_duration_s,
+ downsample_ratio=downsample_ratio,
+ downsample_mode=downsample_mode,
+ offsets_v=offsets_v,
+ bandwidth_limiter=self.bandwidth_limiter_input.currentText(),
+ )
+
+ self._device.start_streaming()
+
+ # Apply plot settings from UI (quality, refresh rate, window)
+ self._apply_plot_settings_to_plotter()
+
+ # Update plotter with actual acquisition settings before connecting
+ self.plotter.voltage_ranges = self._device._streaming_voltage_ranges.copy()
+ self.plotter.offsets_v = self._device._streaming_offsets_v.copy()
+
+ self.plotter.set_ring_buffer(self._device.ring_buffer)
+
+ window_text = self.plot_window_input.currentText()
+ window_seconds = float(window_text.rstrip("s"))
+ self.plotter.set_display_window(window_seconds)
+
+ except Exception as e:
+ logger.exception("Failed to start acquisition")
+ self._show_acquisition_error(str(e))
+ return
- # Now set the new file path (plotter is stopped and cleared)
- self.plotter.set_hdf5_path(output_file)
-
- # Apply plot settings before starting
- self.apply_plot_settings()
-
- settings = {
- "sample_rate_msps": self.sample_rate_input.value(),
- "resolution_bits": int(self.resolution_input.currentText()),
- "channel_range_str": self.voltage_range_input.currentText(),
- "output_file": output_file,
- "hardware_downsample": self.hw_downsample_input.value(),
- "downsample_mode": self.downsample_mode_input.currentText(),
- "offset_v": self.offset_v_input.value(),
- "max_buffer_seconds": self.max_buffer_input.value()
- if self.live_only_checkbox.isChecked()
- else None,
- "enable_live_plot": False,
- "is_gui_mode": True,
- "bandwidth_limiter": self.bandwidth_limiter_input.currentText().lower(),
- }
-
- self.thread = QThread()
- self.worker = StreamerWorker(settings)
- self.worker.moveToThread(self.thread)
- self.worker.stopRequested.connect(self.worker.stop, Qt.QueuedConnection)
-
- self.thread.started.connect(self.worker.run)
- self.worker.finished.connect(self.thread.quit)
- self.worker.finished.connect(self.worker.deleteLater)
- self.thread.finished.connect(self.thread.deleteLater)
- self.worker.finished.connect(self.on_acquisition_finished)
- self.worker.error.connect(self.on_acquisition_error)
-
- self.thread.start()
-
- # Start plotter updates - it will wait for the file to exist with valid data
self.plotter.start_updates()
-
- # Start status sync timer
- self.status_sync_timer.start(100) # Update every 100ms
- def stop_acquisition(self) -> None:
- """Stop the background data acquisition."""
+ self.status_sync_timer.start(100)
+
+ self._stream_check_timer = QTimer()
+ self._stream_check_timer.timeout.connect(self._check_stream_status)
+ self._stream_check_timer.start(500)
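+        # Watchdog: poll twice a second so the UI notices if streaming stops
+        # unexpectedly (see _check_stream_status).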
+
+ self._apply_ui_state("live")
+ logger.info("Acquisition started successfully")
+
+ def _check_stream_status(self) -> None:
+ """Check if streaming is still active, update UI if stopped."""
+ if self._device is None:
+ return
+
+ if not self._device.is_streaming():
+ # Check if streaming stopped due to an error
+ error = self._device.get_streaming_error()
+ if error:
+ logger.error("Streaming stopped with error: {}", error)
+ self._show_streaming_error(error)
+ else:
+ self._on_streaming_stopped()
+
+ def _show_streaming_error(self, error: str) -> None:
+ """Handle streaming stopped due to an error.
+
+ Parameters
+ ----------
+ error : str
+ The error description from the device.
+ """
+ logger.debug(
+ f"[THREAD] _show_streaming_error called from thread: {threading.current_thread().name}"
+ )
+
+ if self._stream_check_timer:
+ self._stream_check_timer.stop()
+ self._stream_check_timer = None
+
self.plotter.stop_updates()
- # Stop the status sync timer to save CPU
- if hasattr(self, 'status_sync_timer'):
+ self.plotter.set_ring_buffer(None)
+
+ if hasattr(self, "status_sync_timer"):
self.status_sync_timer.stop()
- if self.worker:
- self.worker.stopRequested.emit()
- self.stop_button.setEnabled(False)
- def on_acquisition_finished(self) -> None:
- """Handle acquisition completion (both success and failure)."""
- logger.info("Acquisition finished.")
+ self._is_recording = False
+
+ error_display = error.replace("PICO_", "").replace("_", " ").title()
+ self.status_record.setText(f"Error: {error_display[:25]}")
+
+ logger.info("Streaming stopped due to error: {}", error)
+ self._apply_ui_state("error")
+
+ def _on_streaming_stopped(self) -> None:
+ """Handle streaming stopped (either normally or via error)."""
+ logger.debug(
+ f"[THREAD] _on_streaming_stopped called from thread: {threading.current_thread().name}"
+ )
+
+ if self._stream_check_timer:
+ self._stream_check_timer.stop()
+ self._stream_check_timer = None
+
self.plotter.stop_updates()
- self.start_button.setEnabled(True)
- self.stop_button.setEnabled(False)
- self.thread = None
- self.worker = None
+ self.plotter.set_ring_buffer(None)
+
+ if hasattr(self, "status_sync_timer"):
+ self.status_sync_timer.stop()
+
+ self._is_recording = False
+
+ logger.info("Acquisition stopped")
+ self._apply_ui_state("connected")
+
+ def stop_acquisition(self) -> None:
+ """Stop data acquisition (device remains connected)."""
+ logger.debug(
+ f"[THREAD] stop_acquisition called from thread: {threading.current_thread().name}"
+ )
+
+ try:
+ if self._stream_check_timer:
+ self._stream_check_timer.stop()
+ self._stream_check_timer = None
+
+ self.plotter.stop_updates()
+ self.plotter.set_ring_buffer(None)
+
+ if hasattr(self, "status_sync_timer"):
+ self.status_sync_timer.stop()
+
+ if self._device and self._device.is_saving:
+ try:
+ self._device.stop_save(keep=True)
+ except Exception:
+ logger.exception("Error stopping save during acquisition stop")
+ finally:
+ self._is_recording = False
+
+ if self._device:
+ try:
+ self._device.stop_streaming()
+ except Exception:
+ logger.exception("Error stopping streaming")
+
+ logger.info("Acquisition stopped by user")
+
+ except Exception as e:
+ logger.exception("Unexpected error in stop_acquisition: {}", e)
+ finally:
+ self._is_recording = False
+ self._apply_ui_state("connected")
+
+ def _cleanup_device(self) -> None:
+ """Clean up the LabDaemon device and daemon."""
+ if self._device and self._device.is_saving:
+ logger.debug("[THREAD] Stopping save...")
+ self._device.stop_save(keep=True)
+ time.sleep(0.1)
+
+ if self._daemon:
+ all_threads_before = [t.name for t in threading.enumerate()]
+ logger.debug(
+ f"[THREAD] Threads before daemon.shutdown(): {all_threads_before}"
+ )
+ time.sleep(0.1)
+ try:
+ logger.debug("[THREAD] Calling daemon.shutdown()...")
+ self._daemon.shutdown(timeout=2.0)
+ logger.debug("[THREAD] daemon.shutdown() completed")
+ except Exception:
+ logger.exception("Error during daemon shutdown")
+ finally:
+ self._daemon = None
+ self._device = None
+ all_threads_after = [t.name for t in threading.enumerate()]
+ logger.debug(
+ f"[THREAD] Threads after daemon.shutdown(): {all_threads_after}"
+ )
+ elif self._device:
+ try:
+ logger.debug("[THREAD] Disconnecting mock device...")
+ self._device.disconnect()
+ except Exception:
+ logger.exception("Error during mock device disconnect")
+ finally:
+ self._device = None
- def on_acquisition_error(self, err_msg: str) -> None:
- """Handle acquisition error."""
- logger.error(f"Acquisition error: {err_msg}")
+ def _show_acquisition_error(self, err_msg: str) -> None:
+ """Show error dialog for acquisition failure."""
+ logger.error("Acquisition error: {}", err_msg)
+ self._update_state_indicator("error")
+ self.status_record.setText(f"Error: {err_msg[:30]}")
- def sync_status_from_plotter(self) -> None:
- """Sync status from the plotter to the main window status labels."""
+ if "PICO_NOT_FOUND" in err_msg:
+ QMessageBox.critical(
+ self,
+ "Device Not Found",
+ "Picoscope device not found.\n\n"
+ "Please check that:\n"
+ "• The device is connected via USB\n"
+ "• No other software is using the device\n"
+ "• The device is powered on",
+ )
+ else:
+ QMessageBox.critical(
+ self, "Acquisition Error", f"Failed to start acquisition:\n{err_msg}"
+ )
+
+ def _sync_status(self) -> None:
+ """Sync status from plotter and device."""
try:
- # Check if plotter exists and has been initialised
- if not hasattr(self, 'plotter') or not self.plotter:
+ if not hasattr(self, "plotter") or not self.plotter:
return
-
- # Helper function to safely strip HTML tags
- def strip_html(text: str) -> str:
- if not text:
- return text
- # Remove all HTML tags
- return re.sub(r'<[^>]+>', '', text)
-
- # Safely get text from label or return empty string
- def safe_get_text(label_name: str) -> str:
- if hasattr(self.plotter, label_name):
- label = getattr(self.plotter, label_name)
- if hasattr(label, 'text'):
- return strip_html(label.text())
- return ""
-
- # Update all status labels
- self.status_heartbeat.setText(safe_get_text('heartbeat_label'))
-
- # Update samples with QLCDNumber
- samples_text = safe_get_text('samples_label')
- try:
- samples_int = int(samples_text)
- self.status_samples.display(samples_int)
- except (ValueError, TypeError):
- self.status_samples.display(0)
-
- self.status_rate.setText(safe_get_text('rate_label'))
- self.status_latency.setText(safe_get_text('plotter_latency_label'))
-
- # Update errors with colour coding
- errors_text = safe_get_text('error_label')
- self.status_errors.setText(errors_text)
- has_errors = errors_text != "0" and errors_text != "-"
- self._set_status_error_style(self.status_errors, has_errors)
-
- self.status_saturation.setText(safe_get_text('saturation_label'))
- self.status_acquisition.setText(safe_get_text('acq_status_label'))
-
+
+ if self.plotter.ring_buffer is not None:
+ sample_count = self.plotter.ring_buffer.write_idx
+ self.status_samples.setText(
+ self.plotter.format_sample_count(sample_count)
+ )
+
+ # Calculate total acquisition duration from samples and rate
+ if self.plotter.acquisition_rate is None:
+ msg = (
+ "acquisition_rate required for duration calculation but not set"
+ )
+ raise RuntimeError(msg)
+
+ acq_duration_s = self.plotter.acquisition_rate.samples_to_seconds(
+ sample_count
+ )
+ self.status_duration.setText(f"{acq_duration_s:.1f}s")
+
+ # Only show rate after we have enough samples and time for stable reading
+ if self.plotter.rate_check_start_time is not None:
+ elapsed = time.perf_counter() - self.plotter.rate_check_start_time
+ if elapsed >= 0.5: # Wait 500ms before showing rate
+ rate_text = self._format_rate_for_display()
+ self.status_rate.setText(rate_text)
+ else:
+ self.status_rate.setText("-")
+
+ # Show lag if pipeline is available
+ if (
+ self.plotter.pipeline is not None
+ and self.plotter._last_read_position > 0
+ ):
+ lag_text = self.plotter.pipeline.format_lag(
+ sample_count, self.plotter._last_read_position
+ )
+ self.status_lag.setText(lag_text)
+ else:
+ self.status_lag.setText("-")
+
+ self._update_saturation_indicator(self.plotter.is_saturated)
+ self._update_clip_indicators()
+
+ # Update display status indicators
+ if self.plotter._display_cache:
+ total_display_points = sum(
+ len(data) for data in self.plotter._display_cache.values()
+ )
+ # Average per channel for display
+ n_channels = len(self.plotter._display_cache)
+ avg_points = total_display_points // n_channels if n_channels > 0 else 0
+
+ is_lossless = self.plotter.target_plot_points is None
+ if is_lossless:
+ self.status_display_points.setText(f"{avg_points:,} (Lossless)")
+ else:
+ self.status_display_points.setText(f"{avg_points:,}")
+
+ decimation = self.plotter._current_plot_decimation
+ self.status_decimation.setText(f"{decimation}×")
+ else:
+ self.status_display_points.setText("-")
+ self.status_decimation.setText("-")
+
+ if self._device:
+ save_status = self._device.get_save_status()
+ state = save_status.get("state", "idle")
+
+ if state == "saving":
+ total_s = save_status.get("total_seconds", 0)
+ pre_s = save_status.get("pre_trigger_seconds", 0)
+ self.status_record.setText(
+ f"Rec: {total_s:.1f}s ({pre_s:.1f}s pre)"
+ )
+ self._update_state_indicator("recording")
+ elif state == "error":
+ error_msg = save_status.get("error", "Unknown error")
+ self.status_record.setText(f"Error: {error_msg[:20]}")
+ self._update_state_indicator("error")
+ else:
+ self.status_record.setText("Idle")
+
except Exception as e:
- # Don't let errors in status sync crash the app
- logger.debug(f"Error syncing status from plotter: {e}")
- pass
+ logger.debug("Error syncing status: {}", e)
+
+ def _format_rate_for_display(self) -> str:
+ """Format the current acquisition rate for display.
+
+ Returns
+ -------
+ str
+            Formatted rate string (e.g., "62.5 MS/s").
+
+ Raises
+ ------
+ RuntimeError
+ If acquisition_rate is not available.
+ """
+ if self.plotter.rate_check_start_time is None:
+ return "-"
+
+ elapsed = time.perf_counter() - self.plotter.rate_check_start_time
+ if elapsed < 0.5:
+ return "-"
+
+ if self.plotter.acquisition_rate is None:
+ msg = "acquisition_rate required for rate display but not set"
+ raise RuntimeError(msg)
+
+ per_channel_rate_hz = self.plotter.acquisition_rate.per_channel_rate_hz
+ rate_msps = per_channel_rate_hz / 1e6
+ return f"{rate_msps:.1f} MS/s"
+
+ def keyPressEvent(self, event: QKeyEvent) -> None: # pyright: ignore
+ """Handle keyboard shortcuts for common operations."""
+ key = event.key()
+ modifiers = event.modifiers()
+
+ is_ctrl_pressed = modifiers & Qt.KeyboardModifier.ControlModifier
+
+ if key == Qt.Key.Key_Space:
+ self._handle_space_shortcut()
+ elif key == Qt.Key.Key_R:
+ self._handle_r_shortcut()
+ elif key == Qt.Key.Key_K:
+ self._handle_k_shortcut()
+ elif key == Qt.Key.Key_Delete:
+ self._handle_delete_shortcut()
+ elif key == Qt.Key.Key_1:
+ self.toggle_channel_a.click()
+ elif key == Qt.Key.Key_2:
+ self.toggle_channel_b.click()
+ elif key == Qt.Key.Key_O and is_ctrl_pressed:
+ self._on_load_file()
+ else:
+ super().keyPressEvent(event)
+
+ def _handle_space_shortcut(self) -> None:
+ """Toggle start/stop acquisition based on current state."""
+ if self.start_button.isEnabled() and self._devices_connected:
+ self.start_acquisition()
+ elif self.stop_button.isEnabled():
+ self.stop_acquisition()
+
+ def _handle_r_shortcut(self) -> None:
+ """Start recording if acquisition is active and not already recording."""
+ if (
+ self.start_record_button.isEnabled()
+ and self.start_record_button.isVisible()
+ ):
+ self._on_start_record()
+
+ def _handle_k_shortcut(self) -> None:
+ """Stop recording and keep the file."""
+ if self.finish_keep_button.isVisible():
+ self._on_finish_keep()
- def closeEvent(self, event: QCloseEvent) -> None:
+ def _handle_delete_shortcut(self) -> None:
+ """Stop recording and discard the file."""
+ if self.finish_discard_button.isVisible():
+ self._on_finish_discard()
+
+ def closeEvent(self, event: QCloseEvent) -> None: # pyright: ignore
"""Handle window close event."""
+ logger.debug(
+ f"[THREAD] closeEvent called from thread: {threading.current_thread().name}"
+ )
+ all_threads = [t.name for t in threading.enumerate()]
+ logger.debug("[THREAD] Active threads in closeEvent: {}", all_threads)
self.save_settings()
-
- # Stop the status sync timer first
- if hasattr(self, 'status_sync_timer'):
+
+ if hasattr(self, "status_sync_timer"):
self.status_sync_timer.stop()
- self.status_sync_timer = None
-
- # Stop plotter updates
- if hasattr(self, 'plotter') and self.plotter:
+
+ if hasattr(self, "plotter") and self.plotter:
self.plotter.stop_updates()
-
- if self.thread and self.thread.isRunning():
- self.stop_acquisition()
- self.thread.wait()
+
+ if self._stream_check_timer:
+ self._stream_check_timer.stop()
+
+ # Stop save poll timer if active
+ if self._save_poll_timer and self._save_poll_timer.isActive():
+ self._save_poll_timer.stop()
+ logger.debug("Stopped save poll timer during close")
+
+ self._cleanup_device()
+
+ for viewer in self._zarr_viewers[:]:
+ try:
+ viewer.close()
+ except Exception:
+ pass
+ self._zarr_viewers.clear()
+
+ all_threads_final = [t.name for t in threading.enumerate()]
+ logger.debug(
+ f"[THREAD] closeEvent completed. Active threads: {all_threads_final}"
+ )
event.accept()
def save_settings(self) -> None:
- """Save current settings."""
- self.settings.setValue("sample_rate", self.sample_rate_input.value())
+ """Save current settings to QSettings."""
+ self.settings.setValue("channel_a_enabled", self.toggle_channel_a.isChecked())
+ self.settings.setValue("channel_b_enabled", self.toggle_channel_b.isChecked())
+ self.settings.setValue("sample_rate", self.sample_rate_input.currentText())
self.settings.setValue("resolution", self.resolution_input.currentText())
- self.settings.setValue("voltage_range", self.voltage_range_input.currentText())
- self.settings.setValue("output_file", self.output_file_input.text())
- self.settings.setValue("hw_downsample", self.hw_downsample_input.value())
- self.settings.setValue("downsample_mode", self.downsample_mode_input.currentText())
- self.settings.setValue("offset_v", self.offset_v_input.value())
- self.settings.setValue("live_only_mode", self.live_only_checkbox.isChecked())
- self.settings.setValue("max_buffer_seconds", self.max_buffer_input.value())
- self.settings.setValue("plot_window", self.plot_window_input.value())
- self.settings.setValue("y_axis_auto", self.y_axis_auto_checkbox.isChecked())
-
- # Only save Y-axis limits if they're actually being used (not in auto mode)
- if not self.y_axis_auto_checkbox.isChecked():
- self.settings.setValue("y_min", self.y_min_input.value())
- self.settings.setValue("y_max", self.y_max_input.value())
-
- self.settings.setValue("bandwidth_limiter", self.bandwidth_limiter_input.currentText())
+ self.settings.setValue("voltage_range_a", self.voltage_range_a.currentText())
+ self.settings.setValue("voltage_range_b", self.voltage_range_b.currentText())
+ self.settings.setValue("offset_a", self.offset_a.value())
+ self.settings.setValue("offset_b", self.offset_b.value())
+ self.settings.setValue("hw_downsample", self.hw_downsample_input.currentText())
+ self.settings.setValue(
+ "downsample_mode", self.downsample_mode_input.currentText()
+ )
+ self.settings.setValue("plot_window", self.plot_window_input.currentText())
+ self.settings.setValue(
+ "ring_buffer_duration", self.ring_buffer_input.currentText()
+ )
+ self.settings.setValue(
+ "pre_trigger_seconds", self.pre_trigger_input.currentText()
+ )
+
+ self.settings.setValue(
+ "bandwidth_limiter", self.bandwidth_limiter_input.currentText()
+ )
+
+ self.settings.setValue("quality", self.quality_input.currentText())
+ self.settings.setValue("refresh_rate", self.refresh_rate_input.currentText())
+
+ self.settings.setValue("serial_code", self.serial_code_input.text())
+
+ self.settings.setValue("v_div_a", self.v_div_a_input.currentText())
+ self.settings.setValue("v_div_b", self.v_div_b_input.currentText())
+ self.settings.setValue("y_pos_a", self.y_pos_a_input.value())
+ self.settings.setValue("y_pos_b", self.y_pos_b_input.value())
def load_settings(self) -> None:
- """Load settings."""
- self.sample_rate_input.setValue(self.settings.value("sample_rate", 62.5, type=float))
+ """Load settings from QSettings."""
+ channel_a_enabled = self.settings.value("channel_a_enabled", True, type=bool)
+ channel_b_enabled = self.settings.value("channel_b_enabled", True, type=bool)
+
+ # Enforce 16-bit single-channel constraint on loaded settings
+ resolution = int(self.settings.value("resolution", "12"))
+ if resolution == 16 and channel_a_enabled and channel_b_enabled:
+ channel_b_enabled = False
+
+ self.toggle_channel_a.setChecked(channel_a_enabled)
+ self.toggle_channel_b.setChecked(channel_b_enabled)
+ self._on_toggle_channel_a()
+ self._on_toggle_channel_b()
+
+ # Initialize sample rate dropdown based on resolution/channels first
+ self._init_sample_rate_dropdown()
+
+ sample_rate_val = self.settings.value("sample_rate", "62.5 MS/s")
+ if isinstance(sample_rate_val, float):
+ sample_rate_val = f"{sample_rate_val:.1f} MS/s"
+ elif not isinstance(sample_rate_val, str):
+ sample_rate_val = "62.5 MS/s"
+ # Try to set; if invalid, default to 62.5 MS/s
+ idx = self.sample_rate_input.findText(sample_rate_val)
+ if idx >= 0:
+ self.sample_rate_input.setCurrentIndex(idx)
+ else:
+ self.sample_rate_input.setCurrentText("62.5 MS/s")
+
self.resolution_input.setCurrentText(self.settings.value("resolution", "12"))
- self.voltage_range_input.setCurrentText(
- self.settings.value("voltage_range", "PS5000A_20V")
+ self.voltage_range_a.setCurrentText(
+ str(self.settings.value("voltage_range_a", "20.0"))
)
- self.output_file_input.setText(self.settings.value("output_file", "output.hdf5"))
- self.hw_downsample_input.setValue(self.settings.value("hw_downsample", 1, type=int))
+ self.voltage_range_b.setCurrentText(
+ str(self.settings.value("voltage_range_b", "20.0"))
+ )
+ self.offset_a.setValue(self.settings.value("offset_a", 0.0, type=float))
+ self.offset_b.setValue(self.settings.value("offset_b", 0.0, type=float))
+
+ downsample_val = self.settings.value("hw_downsample", "1x")
+ if isinstance(downsample_val, int):
+ downsample_val = f"{downsample_val}x"
+ elif not isinstance(downsample_val, str):
+ downsample_val = "1x"
+ self.hw_downsample_input.setCurrentText(downsample_val)
+
self.downsample_mode_input.setCurrentText(
- self.settings.value("downsample_mode", "average")
- )
- self.offset_v_input.setValue(self.settings.value("offset_v", 0.0, type=float))
- live_only = self.settings.value("live_only_mode", False, type=bool)
- self.live_only_checkbox.setChecked(live_only)
- self.max_buffer_input.setValue(
- self.settings.value("max_buffer_seconds", 1.0, type=float)
- )
- self.max_buffer_input.setEnabled(live_only)
-
- self.plot_window_input.setValue(self.settings.value("plot_window", 0.5, type=float))
- y_axis_auto = self.settings.value("y_axis_auto", True, type=bool)
- self.y_axis_auto_checkbox.setChecked(y_axis_auto)
-
- # Only load Y-axis limits if they exist in settings (were saved in manual mode)
- # Otherwise use sensible defaults
- if self.settings.contains("y_min") and self.settings.contains("y_max"):
- self.y_min_input.setValue(self.settings.value("y_min", type=float))
- self.y_max_input.setValue(self.settings.value("y_max", type=float))
- else:
- # Use sensible defaults if no saved values
- self.y_min_input.setValue(-1000.0)
- self.y_max_input.setValue(1000.0)
-
- self.y_min_input.setEnabled(not y_axis_auto)
- self.y_max_input.setEnabled(not y_axis_auto)
-
+ self.settings.value("downsample_mode", "none")
+ )
+
+ plot_window_val = self.settings.value("plot_window", "5s")
+ if isinstance(plot_window_val, float):
+ plot_window_val = f"{plot_window_val:.0f}s"
+ elif not isinstance(plot_window_val, str):
+ plot_window_val = "5s"
+ self.plot_window_input.setCurrentText(plot_window_val)
+
+ buffer_val = self.settings.value("ring_buffer_duration", "30s")
+ if isinstance(buffer_val, (int, float)):
+ buffer_val = f"{int(buffer_val)}s"
+ elif not isinstance(buffer_val, str):
+ buffer_val = "30s"
+ self.ring_buffer_input.setCurrentText(buffer_val)
+
+ pre_trigger_val = self.settings.value("pre_trigger_seconds", "10s")
+ if isinstance(pre_trigger_val, float):
+ pre_trigger_val = f"{pre_trigger_val:.0f}s"
+ elif not isinstance(pre_trigger_val, str):
+ pre_trigger_val = "10s"
+ self.pre_trigger_input.setCurrentText(pre_trigger_val)
+
self.bandwidth_limiter_input.setCurrentText(
self.settings.value("bandwidth_limiter", "Full")
)
+ self.quality_input.setCurrentText(self.settings.value("quality", "High"))
+ self.refresh_rate_input.setCurrentText(
+ self.settings.value("refresh_rate", "20 Hz")
+ )
+
+ self.serial_code_input.setText(self.settings.value("serial_code", ""))
+
+ self.v_div_a_input.setCurrentText(self.settings.value("v_div_a", "100mV"))
+ self.v_div_b_input.setCurrentText(self.settings.value("v_div_b", "100mV"))
+ self.y_pos_a_input.setValue(self.settings.value("y_pos_a", 0.0, type=float))
+ self.y_pos_b_input.setValue(self.settings.value("y_pos_b", 0.0, type=float))
+
+ def _init_sample_rate_dropdown(self) -> None:
+ """Initialize the sample rate dropdown based on current resolution and channels."""
+ resolution = int(self.resolution_input.currentText())
+ n_channels = (1 if self.toggle_channel_a.isChecked() else 0) + (
+ 1 if self.toggle_channel_b.isChecked() else 0
+ )
+ if n_channels == 0:
+ n_channels = 1 # Default to 1 channel if none selected yet
+
+ rates = self._get_available_rates(resolution, n_channels)
+ current_text = self.sample_rate_input.currentText()
+
+ self.sample_rate_input.clear()
+ for rate_text, _rate_val in rates:
+ self.sample_rate_input.addItem(rate_text)
+
+ # Try to restore previous selection, or default to first available
+ idx = self.sample_rate_input.findText(current_text)
+ if idx >= 0:
+ self.sample_rate_input.setCurrentIndex(idx)
+ else:
+ # Default to 62.5 MS/s or first available
+ idx_625 = self.sample_rate_input.findText("62.5 MS/s")
+ if idx_625 >= 0:
+ self.sample_rate_input.setCurrentIndex(idx_625)
+ elif self.sample_rate_input.count() > 0:
+ self.sample_rate_input.setCurrentIndex(0)
+
+ def _get_available_rates(
+ self, resolution: int, n_channels: int
+ ) -> list[tuple[str, float]]:
+ """Get available sample rates for given resolution and channel count.
+
+ Parameters
+ ----------
+ resolution : int
+ Bit resolution (8, 12, 14, 15, 16).
+ n_channels : int
+ Number of enabled channels.
+
+ Returns
+ -------
+ list[tuple[str, float]]
+ List of (display_text, rate_msps) tuples.
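+
+        Examples
+        --------
+        A minimal sketch; ``win`` stands for any main-window instance:
+
+        >>> win._get_available_rates(12, 2)[0]
+        ('31.25 MS/s', 31.25)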
+ """
+ max_total_rate = {
+ 8: 125.0,
+ 12: 62.5,
+ 14: 31.25,
+ 15: 15.625,
+ 16: 15.625,
+ }.get(resolution, 62.5)
+
+ # Per-channel max rate
+ max_per_channel = (
+ max_total_rate / n_channels if n_channels > 0 else max_total_rate
+ )
+
+ # All candidate rates (MS/s)
+ all_rates = [
+ 125.0,
+ 100.0,
+ 62.5,
+ 50.0,
+ 31.25,
+ 20.0,
+ 15.625,
+ 10.0,
+ 5.0,
+ 2.0,
+ 1.0,
+ 0.5,
+ 0.2,
+ 0.1,
+ ]
+
+ rates = []
+ for rate in all_rates:
+ if rate <= max_per_channel:
+ rates.append(
+ (f"{rate:.1f} MS/s" if rate == int(rate) else f"{rate} MS/s", rate)
+ )
+
+ return rates
+
+ def _on_resolution_changed(self) -> None:
+ """Update sample rate dropdown when resolution changes.
+
+ In 16-bit mode, the PicoScope 5000 series only supports one analog
+ channel. If both channels were enabled, Channel B is turned off.
+ """
+ if self._is_16bit_mode():
+ if self.toggle_channel_a.isChecked() and self.toggle_channel_b.isChecked():
+ self.toggle_channel_b.setChecked(False)
+ self._on_toggle_channel_b()
+
+ tooltip_suffix = " (16-bit: single channel only)"
+ else:
+ tooltip_suffix = ""
+
+ self.toggle_channel_a.setToolTip(
+ f"Toggle Channel A visibility (1){tooltip_suffix}"
+ )
+ self.toggle_channel_b.setToolTip(
+ f"Toggle Channel B visibility (2){tooltip_suffix}"
+ )
+
+ self._init_sample_rate_dropdown()
+
+ def _on_sample_rate_changed(self) -> None:
+ """Hook for future use when sample rate changes."""
+ pass
+
+ def _on_buffer_duration_changed(self) -> None:
+ self._apply_clamp_constraints()
+
+ def _on_pre_trigger_changed(self) -> None:
+ """Hook for future use when pre-trigger changes."""
+ pass
+
+ def _on_voltage_range_a_changed(self) -> None:
+ """Hook for future use when channel A voltage range changes."""
+ pass
+
+ def _on_offset_a_changed(self) -> None:
+ """Hook for future use when channel A offset changes."""
+ pass
+
+ def _on_voltage_range_b_changed(self) -> None:
+ """Hook for future use when channel B voltage range changes."""
+ pass
+
+ def _on_offset_b_changed(self) -> None:
+ """Hook for future use when channel B offset changes."""
+ pass
+
+ def _get_target_points_for_quality(self, quality: str) -> Optional[int]:
+ """Get target plot points for a quality setting.
+
+ Parameters
+ ----------
+ quality : str
+ Quality setting name.
+
+ Returns
+ -------
+ Optional[int]
+ Target points for the quality setting, or None for Lossless.
+ """
+ return QUALITY_POINTS.get(quality)
+
+ def _on_quality_changed(self) -> None:
+ """Handle quality dropdown change."""
+ quality = self.quality_input.currentText()
+ logger.info("Quality set to {} (will apply on next acquisition)", quality)
+
+ def _on_refresh_rate_changed(self) -> None:
+ """Handle refresh rate dropdown change."""
+ refresh_text = self.refresh_rate_input.currentText()
+ interval_ms, _ = REFRESH_RATE_MAP.get(refresh_text, (50, 50000))
+
+ # Update plotter timer interval only if actively plotting
+ if self.plotter.timer.isActive():
+ self.plotter.update_interval_ms = interval_ms
+ self.plotter.timer.stop()
+ self.plotter.timer.start(interval_ms)
+
+ logger.info(
+ "Refresh rate set to {} ({}ms, will apply on next acquisition)",
+ refresh_text,
+ interval_ms,
+ )
+
+ def _validate_lossless_at_start(self) -> bool:
+ """Validate Lossless mode when acquisition starts.
+
+ Shows a warning dialog if Lossless would exceed recommended limits.
+ Allows user to switch to High Quality, force Lossless anyway, or cancel.
+
+ Returns
+ -------
+ bool
+ True if acquisition should proceed, False to abort.
+ """
+ quality = self.quality_input.currentText()
+ if quality != "Lossless":
+ return True
+
+ window_text = self.plot_window_input.currentText()
+ window_seconds = float(window_text.rstrip("s"))
+
+ refresh_text = self.refresh_rate_input.currentText()
+ _, max_points = REFRESH_RATE_MAP.get(refresh_text, (50, 50_000_000))
+
+ # Calculate required points using current acquisition parameters
+ sample_rate_text = self.sample_rate_input.currentText()
+        # split() avoids rstrip's char-set pitfall when parsing e.g. "62.5 MS/s"
+        sample_rate_msps = float(sample_rate_text.split()[0])
+ hw_downsample = self._parse_downsample_ratio()
+
+ from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+
+ mode_text = self.downsample_mode_input.currentText().upper()
+ try:
+ downsample_mode = DownsampleMode[mode_text]
+ except KeyError:
+ downsample_mode = DownsampleMode.NONE
+
+ num_channels = (1 if self.toggle_channel_a.isChecked() else 0) + (
+ 1 if self.toggle_channel_b.isChecked() else 0
+ )
+ if num_channels == 0:
+ num_channels = 1
+
+ temp_rate = AcquisitionRate(
+ hardware_rate_hz=sample_rate_msps * 1e6 * num_channels,
+ num_channels=num_channels,
+ downsample_ratio=hw_downsample,
+ downsample_mode=downsample_mode,
+ )
+ required_points = temp_rate.seconds_to_samples(window_seconds)
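+        # Illustrative figures (assuming seconds_to_samples uses the per-channel
+        # storage rate): 62.5 MS/s, one channel, no downsampling, 5 s window
+        # -> 312_500_000 required points, far beyond any refresh-rate budget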
+
+ if required_points <= max_points:
+ return True
+
+ dialog = LosslessWarningDialog(required_points, max_points, refresh_text, self)
+ QTimer.singleShot(30000, dialog.reject)
+ dialog.exec()
+ choice = dialog.user_choice()
+
+ if choice == "cancel":
+ return False
+
+ if choice == "high":
+ self.quality_input.blockSignals(True)
+ self.quality_input.setCurrentText("High")
+ self.quality_input.blockSignals(False)
+ logger.info("User switched to High Quality from Lossless warning")
+
+ # choice == "lossless" - proceed anyway
+ return True
+
+
+def _is_running_in_terminal() -> bool:
+ """Check if the application is running with an attached terminal."""
+ return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
+
def main() -> None:
"""GUI entry point."""
logger.remove()
- logger.add(sys.stderr, level="INFO")
+
+ log_dir = Path.home() / ".picostream" / "logs"
+ log_dir.mkdir(parents=True, exist_ok=True)
+ log_file = log_dir / "picostream.log"
+
+ logger.add(
+ str(log_file),
+ level="INFO",
+ rotation="10 MB",
+ retention=5,
+ enqueue=True,
+ )
+
+ if _is_running_in_terminal():
+ logger.add(sys.stdout, level="INFO", enqueue=True)
+
app = QApplication(sys.argv)
main_win = PicoStreamMainWindow()
main_win.show()
- sys.exit(app.exec_())
+ sys.exit(app.exec())
if __name__ == "__main__":
diff --git a/picostream/mock_device.py b/picostream/mock_device.py
new file mode 100644
index 0000000..7fd1759
--- /dev/null
+++ b/picostream/mock_device.py
@@ -0,0 +1,1477 @@
+"""
+Mock Picoscope DAQ Device
+
+Simulates the Picoscope DAQ for testing without hardware.
+Implements the same interface as the real device.
+"""
+
+import os
+import random
+import threading
+import time
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+import numpy as np
+from loguru import logger
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+
+if TYPE_CHECKING:
+ pass
+
+
+class MockPicoscope:
+ """
+ Mock implementation of Picoscope DAQ.
+
+ Simulates data acquisition and streaming functionality.
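+
+    Example (a minimal sketch, no hardware required):
+
+    >>> daq = MockPicoscope("daq0")
+    >>> daq.connect()
+    >>> daq.is_connected()
+    True
+    >>> daq.disconnect()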
+ """
+
+ # Valid voltage ranges matching the real Picoscope 5000a series
+ VOLTAGE_RANGES = {
+ 0.01: "PS5000A_10MV",
+ 0.02: "PS5000A_20MV",
+ 0.05: "PS5000A_50MV",
+ 0.1: "PS5000A_100MV",
+ 0.2: "PS5000A_200MV",
+ 0.5: "PS5000A_500MV",
+ 1.0: "PS5000A_1V",
+ 2.0: "PS5000A_2V",
+ 5.0: "PS5000A_5V",
+ 10.0: "PS5000A_10V",
+ 20.0: "PS5000A_20V",
+ }
+
+ # Max ADC value matches the real driver (16-bit signed max)
+ MAX_ADC = 32767
+
+ def __init__(self, device_id: str, serial_code: Optional[str] = None):
+ """
+ Initialize mock DAQ.
+
+ Parameters
+ ----------
+ device_id : str
+ Device identifier.
+ serial_code : str, optional
+ Serial code (ignored in mock).
+ """
+ self.device_id = device_id
+
+ # Internal state
+ self._connected = False
+ self._streaming = False
+ self._streaming_callback = None
+ self._streaming_thread = None
+ self._stop_streaming = threading.Event()
+
+ # Streaming configuration
+ self._sample_rate = 100000
+ self._voltage_range = 1.0
+ self._num_samples = 4096
+ self._downsample_ratio = 1
+ self._downsample_mode = "NONE"
+
+ # Armed acquisition state
+ self._armed = False
+ self._armed_config = {}
+
+ # Mock device info
+ self._serial = serial_code or "MOCK123456"
+ self._model = "Mock Picoscope 4000"
+
+ logger.info("Mock DAQ {} initialized (serial: {})", device_id, self._serial)
+
+ def get_connection_id(self) -> Optional[str]:
+ """
+ Get the serial code.
+
+ Returns
+ -------
+ Optional[str]
+ The serial code, or None if auto-detect is used.
+ """
+ return self._serial
+
+ def set_connection_id(self, connection_id: Optional[str]) -> None:
+ """
+ Set the serial code for connection.
+
+ Pass None or empty string to use auto-detect.
+
+ Parameters
+ ----------
+ connection_id : Optional[str]
+ Serial code, or None/empty for auto-detect.
+
+ Raises
+ ------
+ ValueError
+ If connection_id is not a string or None.
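+
+        Examples
+        --------
+        Illustrative use on a mock instance ``daq``:
+
+        >>> daq.set_connection_id("ABC123")
+        >>> daq.get_connection_id()
+        'ABC123'
+        >>> daq.set_connection_id("")  # empty string -> auto-detect
+        >>> daq.get_connection_id() is None
+        True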
+ """
+ if connection_id is None:
+ self._serial = None
+ elif isinstance(connection_id, str):
+ self._serial = connection_id if connection_id else None
+ else:
+ raise ValueError("connection_id must be a string or None")
+
+ def is_connected(self) -> bool:
+ """Check if device is connected."""
+ return self._connected
+
+ def connect(self) -> None:
+ """Connect to the mock DAQ."""
+ logger.info("Mock DAQ {}: Connecting...", self.device_id)
+ time.sleep(0.2) # Simulate connection time
+ self._connected = True
+ logger.info("Mock DAQ {}: Connected (serial: {})", self.device_id, self._serial)
+
+ def disconnect(self) -> None:
+ """Disconnect from the mock DAQ."""
+ logger.info("Mock DAQ {}: Disconnecting...", self.device_id)
+ if self._streaming:
+ self.stop_streaming()
+ self._connected = False
+ self._armed = False
+ logger.info("Mock DAQ {}: Disconnected", self.device_id)
+
+ def configure_streaming(
+ self,
+ sample_rate: float,
+ voltage_range: float,
+ num_samples_per_block: int,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ ) -> None:
+ """
+ Configure streaming parameters.
+
+ Parameters
+ ----------
+ sample_rate : float
+ Sample rate in Hz.
+ voltage_range : float
+ Voltage range in Volts.
+ num_samples_per_block : int
+ Number of samples per block.
+ downsample_ratio : int
+ Downsampling ratio.
+ downsample_mode : str
+ Downsampling mode ("NONE", "AVERAGE", "DECIMATE").
+ """
+ if not self._connected:
+ raise RuntimeError("DAQ not connected")
+
+ # Validate parameters
+ if not 1000 <= sample_rate <= 125000000:
+ raise ValueError("Sample rate must be between 1kHz and 125MHz")
+ if voltage_range not in self.VOLTAGE_RANGES:
+ valid_ranges = sorted(self.VOLTAGE_RANGES.keys())
+ raise ValueError(
+ f"Invalid voltage range: {voltage_range}V. "
+ f"Valid ranges: {valid_ranges}V"
+ )
+ if not 256 <= num_samples_per_block <= 10000000:
+ raise ValueError("Samples per block must be between 256 and 10M")
+ if downsample_ratio < 1:
+ raise ValueError("Downsample ratio must be >= 1")
+ if downsample_mode not in ["NONE", "AVERAGE", "DECIMATE"]:
+ raise ValueError("Invalid downsample mode")
+
+ self._sample_rate = sample_rate
+ self._voltage_range = voltage_range
+ self._num_samples = num_samples_per_block
+ self._downsample_ratio = downsample_ratio
+ self._downsample_mode = downsample_mode
+
+ logger.info(
+ "Mock DAQ {}: Configured streaming: {:.0f}kHz, "
+ "{}V, {} samples/block, "
+ "downsample={} ({})",
+ self.device_id,
+ sample_rate / 1000,
+ voltage_range,
+ num_samples_per_block,
+ downsample_ratio,
+ downsample_mode,
+ )
+
+ def start_streaming(self, callback) -> None:
+ """
+ Start streaming data.
+
+ Parameters
+ ----------
+ callback : callable
+ Callback function for streaming data.
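+            Receives each generated block as a list of floats.
+
+        Examples
+        --------
+        A minimal sketch, assuming ``daq`` is connected:
+
+        >>> chunks = []
+        >>> daq.start_streaming(chunks.append)
+        >>> daq.stop_streaming()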
+ """
+ if not self._connected:
+ logger.error(
+ "Mock DAQ {}: Cannot start streaming - DAQ not connected",
+ self.device_id,
+ )
+ raise RuntimeError("DAQ not connected")
+
+ if self._streaming:
+ logger.warning(
+ "Mock DAQ {}: Already streaming, ignoring start request", self.device_id
+ )
+ return
+
+ self._streaming = True
+ self._streaming_callback = callback
+ self._stop_streaming.clear()
+
+ self._streaming_thread = threading.Thread(
+ target=self._streaming_loop, daemon=True, name=f"MockDAQ-{self.device_id}"
+ )
+ self._streaming_thread.start()
+ logger.info("Mock DAQ {}: Started streaming successfully", self.device_id)
+
+ def stop_streaming(self, timeout: Optional[float] = None) -> None:
+ """
+ Stop streaming data.
+
+ Parameters
+ ----------
+ timeout : float, optional
+ Timeout for stopping.
+ """
+ if not self._streaming:
+ logger.info("Mock DAQ {}: Not streaming, nothing to stop", self.device_id)
+ return
+
+ logger.info("Mock DAQ {}: Stopping streaming...", self.device_id)
+ self._streaming = False # Set flag first to prevent new data generation
+ self._stop_streaming.set()
+
+ if self._streaming_thread:
+ self._streaming_thread.join(timeout=timeout or 5.0)
+ if self._streaming_thread.is_alive():
+ logger.warning(
+ "Mock DAQ {}: Streaming thread did not stop in time", self.device_id
+ )
+ else:
+ logger.info(
+ "Mock DAQ {}: Streaming thread stopped successfully", self.device_id
+ )
+
+ self._streaming_callback = None
+ self._streaming_thread = None
+ logger.info("Mock DAQ {}: Stopped streaming", self.device_id)
+
+ def is_streaming(self) -> bool:
+ """Check if streaming is active."""
+ return self._streaming
+
+ def _streaming_loop(self) -> None:
+ """Generate mock streaming data."""
+ block_time = self._num_samples / self._sample_rate
+
+ iteration = 0
+ # Use smaller wait intervals to allow faster response to stop requests
+ wait_interval = min(
+ 0.05, block_time / 20
+ ) # Check stop event more frequently (50ms max)
+ elapsed_time = 0
+
+ while not self._stop_streaming.wait(wait_interval):
+ elapsed_time += wait_interval
+
+ # Generate data every 100ms regardless of sample rate for more responsive streaming
+ if elapsed_time >= 0.1:
+ try:
+ # Generate mock voltage data
+ voltage_data = self._generate_mock_data(self._num_samples)
+
+ # Apply downsampling if configured
+ if self._downsample_ratio > 1:
+ voltage_data = self._apply_downsampling(voltage_data)
+
+ # Convert to list for JSON serialization
+ data_list = voltage_data.tolist()
+
+ # Send data via callback
+ if self._streaming_callback:
+ try:
+ self._streaming_callback(data_list)
+ except Exception:
+ logger.exception(
+ "Mock DAQ {}: Callback failed",
+ self.device_id,
+ )
+ else:
+ logger.warning(
+ "Mock DAQ {}: NO CALLBACK REGISTERED at iteration {}",
+ self.device_id,
+ iteration,
+ )
+
+ iteration += 1
+ elapsed_time = 0 # Reset elapsed time after generating data
+
+ except Exception:
+ logger.exception(
+ "Mock DAQ {}: Error in streaming loop", self.device_id
+ )
+ break
+
+ logger.info(
+ "Mock DAQ {}: _streaming_loop() exiting after {} iterations",
+ self.device_id,
+ iteration,
+ )
+
+ def _generate_mock_data(self, num_samples: int) -> np.ndarray:
+ """Generate realistic mock voltage data with optimized performance."""
+ # Time vector
+ block_time = num_samples / self._sample_rate
+ t = np.linspace(0, block_time, num_samples)
+
+ # Simplified signal generation for better performance
+        # Main frequency component (2 Hz)
+        frequency = 2
+ # Use fixed amplitude in volts (not a percentage of range)
+ # This simulates a real signal with known amplitude
+ amplitude = 0.5 # 500mV signal
+ signal = amplitude * np.sin(2 * np.pi * frequency * t)
+
+ # Reduced harmonics for performance
+ signal += amplitude * 0.1 * np.sin(2 * np.pi * frequency * 3 * t)
+
+ # Simplified noise
+ noise_level = self._voltage_range * 0.02
+ noise = np.random.normal(0, noise_level, num_samples)
+ voltage_data = signal + noise
+
+ # Reduce spike frequency to improve performance
+ if random.random() < 0.005: # 0.5% chance of spike (reduced from 1%)
+ spike_pos = random.randint(0, num_samples - 1)
+ voltage_data[spike_pos] += random.uniform(-0.5, 0.5) * self._voltage_range
+
+ return voltage_data
+
+ def _apply_downsampling(self, data: np.ndarray) -> np.ndarray:
+ """Apply downsampling to data."""
+ if self._downsample_ratio <= 1:
+ return data
+
+ new_length = len(data) // self._downsample_ratio
+
+ if self._downsample_mode == "AVERAGE":
+ # Reshape and average
+ truncated = data[: new_length * self._downsample_ratio]
+ return truncated.reshape(new_length, self._downsample_ratio).mean(axis=1)
+ elif self._downsample_mode == "DECIMATE":
+ # Simple decimation
+ return data[:: self._downsample_ratio]
+ else: # NONE
+ return data
+
+ def arm_acquisition(
+ self,
+ sample_rate: float,
+ num_samples: int,
+ channels: List[int],
+ voltage_ranges: List[float],
+ trigger_source: Optional[str] = None,
+ trigger_threshold_v: float = 0.5,
+ trigger_direction: str = "RISING",
+ trigger_delay_s: float = 0.0,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ **kwargs,
+ ) -> None:
+ """
+ Arm the DAQ for hardware-triggered acquisition.
+
+ Parameters
+ ----------
+ sample_rate : float
+ Sample rate in Hz.
+ num_samples : int
+ Number of samples to acquire.
+ channels : list of int
+ Channels to acquire.
+ voltage_ranges : list of float
+ Voltage range for each channel.
+ trigger_source : str, optional
+ Trigger source (e.g., "EXT" for external).
+ trigger_threshold_v : float, optional
+ Trigger threshold in volts.
+ trigger_direction : str, optional
+ Trigger direction ("RISING" or "FALLING").
+ trigger_delay_s : float, optional
+ Trigger delay in seconds.
+ downsample_ratio : int, optional
+ Downsampling ratio.
+ downsample_mode : str, optional
+ Downsampling mode ("NONE", "AVERAGE", "DECIMATE").
+ """
+ if not self._connected:
+ raise RuntimeError("DAQ not connected")
+
+ # Validate parameters
+ if not 1000 <= sample_rate <= 125000000:
+ raise ValueError("Sample rate must be between 1kHz and 125MHz")
+ if not 256 <= num_samples <= 10000000:
+ raise ValueError("Number of samples must be between 256 and 10M")
+ if not channels:
+ raise ValueError("At least one channel must be specified")
+ if len(channels) != len(voltage_ranges):
+ raise ValueError("Number of channels must match number of voltage ranges")
+ for v_range in voltage_ranges:
+ if v_range not in self.VOLTAGE_RANGES:
+ valid_ranges = sorted(self.VOLTAGE_RANGES.keys())
+ raise ValueError(
+ f"Invalid voltage range: {v_range}V. Valid ranges: {valid_ranges}V"
+ )
+
+ # Store configuration for later use in wait_for_acquisition
+ self._armed_config = {
+ "sample_rate": sample_rate,
+ "num_samples": num_samples,
+ "channels": channels,
+ "voltage_ranges": voltage_ranges,
+ "trigger_source": trigger_source,
+ "trigger_threshold_v": trigger_threshold_v,
+ "trigger_direction": trigger_direction,
+ "trigger_delay_s": trigger_delay_s,
+ "downsample_ratio": downsample_ratio,
+ "downsample_mode": downsample_mode,
+ **kwargs, # Store any additional kwargs passed from the task
+ }
+
+ self._armed = True
+
+ logger.info(
+ "Mock DAQ {}: Armed acquisition: {:.0f}kHz, "
+ "{} samples, channels {}, "
+ "trigger={} ({})",
+ self.device_id,
+ sample_rate / 1000,
+ num_samples,
+ channels,
+ trigger_source,
+ trigger_direction,
+ )
+
+ def wait_for_acquisition(
+ self, timeout_s: float = 10.0, cancel_event: Optional[threading.Event] = None
+ ) -> Dict[str, Any]:
+ """
+ Wait for a hardware-triggered acquisition to complete.
+
+ Parameters
+ ----------
+ timeout_s : float, optional
+ Timeout in seconds.
+ cancel_event : threading.Event, optional
+ Event to signal cancellation.
+
+ Returns
+ -------
+ dict
+ Acquisition results with keys: data_v, data_adc, metadata.
+
+ Raises
+ ------
+ RuntimeError
+ If DAQ is not armed or not connected.
+ TimeoutError
+ If acquisition does not complete within timeout.
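+
+        Examples
+        --------
+        A minimal sketch, assuming ``daq`` is connected and armed:
+
+        >>> result = daq.wait_for_acquisition(timeout_s=5.0)
+        >>> sorted(result)
+        ['data_adc', 'data_v', 'metadata']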
+ """
+ if not self._connected:
+ raise RuntimeError("DAQ not connected")
+
+ if not self._armed:
+ raise RuntimeError("DAQ not armed. Call arm_acquisition() first.")
+
+ logger.info(
+ "Mock DAQ {}: Waiting for hardware-triggered acquisition (timeout={}s)",
+ self.device_id,
+ timeout_s,
+ )
+
+ # Simulate waiting for trigger and acquisition
+ # In real hardware, this would block until the trigger fires
+ # For mock, we simulate the acquisition time
+ config = self._armed_config
+ num_samples = config["num_samples"]
+ sample_rate = config["sample_rate"]
+ channels = config["channels"]
+ voltage_ranges = config["voltage_ranges"]
+
+ # Estimate acquisition time
+ acquisition_time = num_samples / sample_rate
+
+ # Simulate trigger delay + acquisition with cancellation support
+        start_time = time.time()
+        while time.time() - start_time < acquisition_time:
+            if cancel_event and cancel_event.is_set():
+                logger.info("Mock DAQ {}: Acquisition cancelled", self.device_id)
+                self._armed = False
+                raise RuntimeError("Acquisition was cancelled")
+
+            # Check timeout inside the loop so we abort at timeout_s even when
+            # the simulated acquisition would take longer
+            if time.time() - start_time > timeout_s:
+                self._armed = False
+                raise TimeoutError(
+                    f"Acquisition did not complete within {timeout_s}s"
+                )
+
+            # Sleep in small increments to allow cancellation checks
+            time.sleep(min(0.01, acquisition_time / 10))
+
+ # --- Generate time-domain signal that will appear as a Gaussian when mapped to wavelength ---
+ # The HardwareSweepTask will map time_data_np to wavelengths.
+ # We need to generate a Gaussian in the time domain such that its peak corresponds
+ # to the desired wavelength (1554nm) when that mapping occurs.
+
+ # Assume a typical sweep range for the mock to generate a feature within
+ # This is for the mock's internal calculation, not for the task's actual sweep.
+ mock_sweep_start_wl = 1550.0
+ mock_sweep_stop_wl = 1560.0
+ mock_sweep_speed_nm_s = 20.0
+
+ time_data_np = (
+ np.arange(num_samples) / sample_rate
+ ) # Time axis for the acquired data
+
+ # Calculate the wavelength at each time point if the sweep started at mock_sweep_start_wl
+ # and swept at mock_sweep_speed_nm_s
+ mock_wavelengths_at_time = (
+ mock_sweep_start_wl
+ + np.sign(mock_sweep_stop_wl - mock_sweep_start_wl)
+ * mock_sweep_speed_nm_s
+ * time_data_np
+ )
+
+ # Define Gaussian parameters in terms of wavelength
+ peak_wavelength = 1554.0 # nm
+ peak_amplitude = -1.0 # V (negative peak)
+ fwhm = 2.0 # nm (Full Width at Half Maximum)
+ sigma = fwhm / (
+ 2 * np.sqrt(2 * np.log(2))
+ ) # Convert FWHM to standard deviation
+
+ # Generate a baseline (off-resonance) voltage
+ baseline_voltage = 0.0 # V
+
+ # Calculate Gaussian profile based on the mock_wavelengths_at_time
+ gaussian_profile = peak_amplitude * np.exp(
+ -((mock_wavelengths_at_time - peak_wavelength) ** 2) / (2 * sigma**2)
+ )
+
+ # Combine baseline and Gaussian
+ mock_spectrum_v = baseline_voltage + gaussian_profile
+
+ # Add some noise to the spectrum
+ noise_level = (
+ 0.01 * voltage_ranges[0]
+ ) # 1% of the first channel's voltage range
+ mock_spectrum_v += np.random.normal(0, noise_level, num_samples)
+
+ # Ensure data stays within voltage range (clipping)
+ mock_spectrum_v = np.clip(
+ mock_spectrum_v, -voltage_ranges[0], voltage_ranges[0]
+ )
+
+ # Generate mock data for each channel
+ data_v = {}
+ data_adc = {}
+
+ for ch_idx, ch in enumerate(channels):
+ voltage_range = voltage_ranges[ch_idx]
+
+ # For the sweep, we'll use the generated mock_spectrum_v for all channels
+ voltage_data = mock_spectrum_v
+
+ # Convert to ADC counts (matching real device: 16-bit signed, max = 32767)
+ adc_data = (voltage_data / voltage_range * self.MAX_ADC).astype(np.int16)
+
+ data_v[ch] = voltage_data.tolist()
+ data_adc[ch] = adc_data.tolist()
+
+ # Create time axis (as a list for JSON serialization)
+ time_data_list = time_data_np.tolist()
+
+ # Mark as no longer armed
+ self._armed = False
+
+ logger.info(
+ "Mock DAQ {}: Acquisition complete: {} samples at {:.0f}kHz",
+ self.device_id,
+ num_samples,
+ sample_rate / 1000,
+ )
+
+ return {
+ "data_v": data_v,
+ "data_adc": data_adc,
+ "metadata": {
+ "sample_rate": sample_rate,
+ "num_samples": num_samples,
+ "channels": channels,
+ "voltage_ranges": voltage_ranges,
+ "time_data": time_data_list,
+ "timestamp": time.time(),
+ },
+ }
+
+ def cancel_acquisition(self) -> None:
+ """
+ Cancel an in-progress or armed acquisition.
+ """
+ if not self._connected:
+ raise RuntimeError("DAQ not connected")
+
+ if not self._armed:
+ logger.debug("Mock DAQ {}: No acquisition to cancel", self.device_id)
+ return
+
+ self._armed = False
+ self._armed_config = {}
+ logger.info("Mock DAQ {}: Acquisition cancelled", self.device_id)
+
+ def reset_device_state(self) -> None:
+ """
+ Reset the device to a clean, idle state.
+ """
+ if not self._connected:
+ raise RuntimeError("DAQ not connected")
+
+ self._armed = False
+ self._armed_config = {}
+ logger.info("Mock DAQ {}: Device state reset", self.device_id)
+
+ def acquire_block(
+ self,
+ sample_rate: float,
+ num_samples: int,
+ voltage_range: float,
+ channels: Optional[List[int]] = None,
+ ) -> Dict[str, Any]:
+ """
+ Acquire a single block of data.
+
+ Parameters
+ ----------
+ sample_rate : float
+ Sample rate in Hz.
+ num_samples : int
+ Number of samples to acquire.
+ voltage_range : float
+ Voltage range in Volts.
+ channels : list of int, optional
+ Channels to acquire (default: [0]).
+
+ Returns
+ -------
+ dict
+ Acquisition results.
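+
+        Examples
+        --------
+        A minimal sketch, assuming ``daq`` is connected; the values are
+        illustrative:
+
+        >>> block = daq.acquire_block(1e6, 4096, 1.0, channels=[0, 1])
+        >>> len(block["data_v"][0])
+        4096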
+ """
+ if not self._connected:
+ raise RuntimeError("DAQ not connected")
+
+ channels = channels or [0]
+
+ # Validate parameters
+ if not 1000 <= sample_rate <= 125000000:
+ raise ValueError("Sample rate must be between 1kHz and 125MHz")
+ if not 256 <= num_samples <= 10000000:
+ raise ValueError("Number of samples must be between 256 and 10M")
+ if voltage_range not in self.VOLTAGE_RANGES:
+ valid_ranges = sorted(self.VOLTAGE_RANGES.keys())
+ raise ValueError(
+ f"Invalid voltage range: {voltage_range}V. "
+ f"Valid ranges: {valid_ranges}V"
+ )
+
+ logger.info(
+ "Mock DAQ {}: Acquiring {} samples at {:.0f}kHz on channels {}",
+ self.device_id,
+ num_samples,
+ sample_rate / 1000,
+ channels,
+ )
+
+ # Simulate acquisition time
+ acquisition_time = num_samples / sample_rate
+ time.sleep(min(acquisition_time, 0.5)) # Cap at 0.5s for testing
+
+ # Generate mock data for each channel
+ results = {}
+ t = np.linspace(0, acquisition_time, num_samples)
+
+ for ch in channels:
+            # Different frequency per channel for variety (2 Hz, 3 Hz, ...)
+            frequency = 2 + ch
+ # Fixed 400mV signal amplitude (not percentage of range)
+ signal_amplitude = 0.4
+ noise_level = voltage_range * 0.01
+
+ signal = signal_amplitude * np.sin(2 * np.pi * frequency * t)
+ noise = np.random.normal(0, noise_level, num_samples)
+ voltage_data = signal + noise
+
+ results[ch] = voltage_data.tolist()
+
+ # Return in format expected by HardwareSweepTask
+ return {
+ "data_v": results,
+ "sample_rate_hz": sample_rate,
+ "num_samples": num_samples,
+ "voltage_range_v": voltage_range,
+ "channels": channels,
+ "timestamps": t.tolist(),
+ }
+
+ def get_device_info(self) -> Dict[str, Any]:
+ """Get device information."""
+ return {
+ "serial": self._serial,
+ "model": self._model,
+ "connected": self._connected,
+ "streaming": self._streaming,
+ }
+
+ def get_state(self) -> dict:
+ """Get device state."""
+ return {
+ "connected": self._connected,
+ "streaming": self._streaming,
+ "armed": self._armed,
+ "sample_rate": self._sample_rate,
+ "voltage_range": self._voltage_range,
+ "num_samples": self._num_samples,
+ "downsample_ratio": self._downsample_ratio,
+ "downsample_mode": self._downsample_mode,
+ }
+
+
+class MockPicoscopeBufferedStream:
+ """
+ Mock implementation of PicoscopeBufferedStream for testing PicoStream GUI.
+
+ Simulates multi-channel buffered streaming with a ring buffer, save functionality,
+    and realistic synthetic data (two sine waves + Gaussian noise).
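+
+    Example (a minimal sketch; the rate and ranges are illustrative):
+
+    >>> dev = MockPicoscopeBufferedStream("dev0", resolution=12)
+    >>> dev.connect()
+    >>> dev.configure_streaming(1e6, channels=[0, 1], voltage_ranges=[1.0, 1.0])
+    >>> dev.start_streaming()
+    >>> dev.is_streaming()
+    True
+    >>> dev.stop_streaming()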
+ """
+
+ # Same voltage ranges as real device
+ VOLTAGE_RANGES = {
+ 0.01: "PS5000A_10MV",
+ 0.02: "PS5000A_20MV",
+ 0.05: "PS5000A_50MV",
+ 0.1: "PS5000A_100MV",
+ 0.2: "PS5000A_200MV",
+ 0.5: "PS5000A_500MV",
+ 1.0: "PS5000A_1V",
+ 2.0: "PS5000A_2V",
+ 5.0: "PS5000A_5V",
+ 10.0: "PS5000A_10V",
+ 20.0: "PS5000A_20V",
+ }
+
+ def __init__(
+ self,
+ device_id: str,
+ serial_code: Optional[str] = None,
+ resolution: int = 12,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize mock buffered stream device."""
+ self.device_id = device_id
+ self.serial_code = serial_code or "MOCK_BUFF_STREAM"
+ self.resolution = resolution
+
+ # Connection state
+ self._is_connected = False
+ self._chandle = None # Mock handle for compatibility
+ self._max_adc_value = None
+
+ # Streaming state
+ self._streaming_configured = False
+ self._streaming_channels: List[int] = []
+ self._enabled_channels_set: set[int] = set()
+ self._streaming_voltage_ranges: Dict[int, float] = {}
+ self._streaming_offsets_v: Dict[int, float] = {}
+ self._streaming_sample_rate: Optional[float] = None
+ self._streaming_sample_interval_ns: int = 16
+ self._streaming_downsample_ratio: int = 1
+ self._streaming_downsample_mode: str = "NONE"
+ self._streaming_bandwidth_limiter: str = "FULL"
+ self._streaming_buffer_duration_s: float = 30.0
+
+ # Ring buffer
+ self._ring_buffer: Optional[Any] = None
+
+ # Producer thread
+ self._producer_thread: Optional[threading.Thread] = None
+ self._stop_streaming_event = threading.Event()
+
+ # Save state
+ self._is_saving = False
+ self._save_thread: Optional[threading.Thread] = None
+ self._save_stop_event: Optional[threading.Event] = None
+ self._save_output_path: Optional[str] = None
+ self._save_samples_written = 0
+ self._save_pre_trigger_samples = 0
+ self._save_error: Optional[Exception] = None
+ self._keep_file_on_stop: bool = True
+
+ # Signal generation state (for continuous phases)
+ self._phase_ch0 = 0.0
+ self._phase_ch1 = 0.0
+ self._sample_count = 0
+
+ logger.info("MockPicoscopeBufferedStream {} initialized", device_id)
+
+ def get_connection_id(self) -> Optional[str]:
+ """Get the serial code."""
+ return self.serial_code
+
+ def set_connection_id(self, connection_id: Optional[str]) -> None:
+ """Set the serial code for connection."""
+ if connection_id is None:
+ self.serial_code = None
+ elif isinstance(connection_id, str):
+ self.serial_code = connection_id if connection_id else None
+ else:
+ raise ValueError("connection_id must be a string or None")
+
+ def connect(self) -> None:
+ """Connect to the mock device."""
+ logger.info("MockPicoscopeBufferedStream {}: Connecting...", self.device_id)
+ time.sleep(0.1) # Simulate connection time
+ self._is_connected = True
+        # Matches the real driver's 16-bit signed max (same for all resolutions)
+        self._max_adc_value = 32767
+ logger.info(
+ "MockPicoscopeBufferedStream {}: Connected (serial: {}, resolution: {}-bit)",
+ self.device_id,
+ self.serial_code,
+ self.resolution,
+ )
+
+ def disconnect(self) -> None:
+ """Disconnect from the mock device."""
+ logger.info("MockPicoscopeBufferedStream {}: Disconnecting...", self.device_id)
+ if self.is_streaming():
+ self.stop_streaming()
+ self._is_connected = False
+ self._streaming_configured = False
+ self._ring_buffer = None
+ logger.info("MockPicoscopeBufferedStream {}: Disconnected", self.device_id)
+
+ def configure_streaming(
+ self,
+ sample_rate: float,
+ channels: List[int],
+ voltage_ranges: List[float],
+ buffer_duration_s: float = 30.0,
+ downsample_ratio: int = 1,
+ downsample_mode: str = "NONE",
+ offsets_v: Optional[List[float]] = None,
+ bandwidth_limiter: str = "FULL",
+ ) -> None:
+ """Configure buffered streaming with ring buffer."""
+ if not self._is_connected:
+ raise RuntimeError("Device not connected")
+
+ from picostream.ring_buffer import RingBuffer
+
+ # Validate parameters
+ if len(channels) != len(voltage_ranges):
+ raise ValueError("Length of channels and voltage_ranges must match")
+
+ # Validate voltage ranges against supported values
+ for v_range in voltage_ranges:
+ if v_range not in self.VOLTAGE_RANGES:
+ valid_ranges = sorted(self.VOLTAGE_RANGES.keys())
+ raise ValueError(
+ f"Invalid voltage range: {v_range}V. Valid ranges: {valid_ranges}V"
+ )
+
+ # Clamp sample rate to sustainable level for mock device
+ max_mock_rate = 1e6 # 1 MS/s is plenty for testing
+ if sample_rate > max_mock_rate:
+ logger.warning(
+ "Mock device cannot sustain {:.1f} MS/s, clamping to {:.1f} MS/s for testing",
+ sample_rate / 1e6,
+ max_mock_rate / 1e6,
+ )
+ sample_rate = max_mock_rate
+
+ # Store configuration
+ self._streaming_channels = list(channels)
+
+ # For mock device, always generate data for both channels
+ # (real oscilloscopes always have data, you just choose which to display)
+ self._enabled_channels_set = {0, 1}
+
+ # Ensure both channels have voltage ranges defined
+ self._streaming_voltage_ranges = {}
+ for ch, v_range in zip(channels, voltage_ranges, strict=True):
+ self._streaming_voltage_ranges[ch] = v_range
+
+ # Ensure channel 0 and 1 always have voltage ranges defined (use defaults if not provided)
+ if 0 not in self._streaming_voltage_ranges:
+ self._streaming_voltage_ranges[0] = (
+ voltage_ranges[0] if voltage_ranges else 20.0
+ )
+ if 1 not in self._streaming_voltage_ranges:
+ self._streaming_voltage_ranges[1] = (
+ voltage_ranges[-1]
+ if len(voltage_ranges) > 1
+ else (voltage_ranges[0] if voltage_ranges else 20.0)
+ )
+ self._streaming_sample_rate = sample_rate
+ self._streaming_buffer_duration_s = buffer_duration_s
+
+ if offsets_v is None:
+ offsets_v = [0.0] * len(channels)
+ self._streaming_offsets_v = {
+ ch: offset for ch, offset in zip(channels, offsets_v, strict=True)
+ }
+
+ if downsample_mode.upper() == "NONE":
+ downsample_ratio = 1
+ self._streaming_downsample_ratio = downsample_ratio
+ self._streaming_downsample_mode = downsample_mode.upper()
+ self._streaming_bandwidth_limiter = bandwidth_limiter.upper()
+
+ # Create ring buffer (always 2 channels)
+ # Calculate total hardware rate (per-channel rate × num_channels)
+ total_hardware_rate_hz = sample_rate * len(self._streaming_channels)
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=total_hardware_rate_hz,
+ num_channels=len(self._streaming_channels),
+ downsample_ratio=downsample_ratio,
+ downsample_mode=DownsampleMode(downsample_mode.upper()),
+ )
+
+ self._ring_buffer = RingBuffer(
+ duration_s=buffer_duration_s,
+ sample_rate=acquisition_rate.storage_rate_hz,
+ num_channels=2,
+ )
+
+ # Store acquisition rate for later use
+ self._acquisition_rate = acquisition_rate
+
+ self._streaming_configured = True
+ self._phase_ch0 = 0.0
+ self._phase_ch1 = 0.0
+ self._sample_count = 0
+
+ enabled_names = ["A" if ch == 0 else "B" for ch in sorted(channels)]
+ logger.info(
+ "MockPicoscopeBufferedStream {}: Configured streaming channels=[{}], rate={:.2f} Hz, buffer={}s, downsample={}x ({})",
+ self.device_id,
+ ", ".join(enabled_names),
+ sample_rate,
+ buffer_duration_s,
+ downsample_ratio,
+ downsample_mode,
+ )
+
+ def start_streaming(self, callback: Optional[Callable] = None) -> None:
+ """Start buffered streaming to ring buffer."""
+ if not self._is_connected:
+ raise RuntimeError("Device not connected")
+ if not self._streaming_configured:
+ raise RuntimeError(
+ "Streaming not configured. Call configure_streaming() first"
+ )
+ if self.is_streaming():
+ logger.warning("Streaming already active on {}", self.device_id)
+ return
+
+ if self._streaming_sample_rate is None:
+ raise RuntimeError("Sample rate not configured")
+
+ # Use existing acquisition_rate from configure_streaming
+ if not hasattr(self, "_acquisition_rate"):
+ raise RuntimeError(
+ "configure_streaming() must be called before start_streaming()"
+ )
+
+ # Set acquisition rate on ring buffer
+ if self._ring_buffer is not None:
+ self._ring_buffer.set_acquisition_rate(self._acquisition_rate)
+ logger.info(
+ "Mock ring buffer acquisition rate set: {}",
+ self._acquisition_rate,
+ )
+
+ self._stop_streaming_event.clear()
+ self._producer_thread = threading.Thread(
+ target=self._producer_loop,
+ daemon=True,
+ name=f"mock-producer-{self.device_id}",
+ )
+ self._producer_thread.start()
+ logger.info("MockPicoscopeBufferedStream {}: Streaming started", self.device_id)
+
+ def stop_streaming(self, timeout: float = 5.0) -> None:
+ """Stop buffered streaming."""
+ if not self.is_streaming():
+ logger.info("Streaming not active on {}", self.device_id)
+ return
+
+ logger.info(
+ "MockPicoscopeBufferedStream {}: Stopping streaming...", self.device_id
+ )
+ self._stop_streaming_event.set()
+
+ if self._is_saving:
+ self.stop_save(keep=True)
+
+ if self._producer_thread and self._producer_thread.is_alive():
+ self._producer_thread.join(timeout)
+ if self._producer_thread.is_alive():
+ logger.warning("Producer thread did not stop in time")
+
+ self._producer_thread = None
+ logger.info("MockPicoscopeBufferedStream {}: Streaming stopped", self.device_id)
+
+ def is_streaming(self) -> bool:
+ """Return True if streaming thread is active."""
+ return self._producer_thread is not None and self._producer_thread.is_alive()
+
+ @property
+ def ring_buffer(self) -> Optional[Any]:
+ """Return the ring buffer for plotter access."""
+ return self._ring_buffer
+
+ def update_plotter_position(self, read_idx: int) -> None:
+ """Update the plotter's last read position for overflow detection.
+
+ This mock implementation stores the value but doesn't use it
+ for actual overflow detection.
+
+ Parameters
+ ----------
+ read_idx : int
+ The sample index that the plotter last read up to.
+ """
+ logger.debug("Mock update_plotter_position called with read_idx={}", read_idx)
+
+ def get_streaming_error(self) -> Optional[str]:
+ """Get the last streaming error, if any.
+
+ Returns
+ -------
+ Optional[str]
+ Error description if streaming stopped due to an error,
+ None otherwise.
+ """
+ return None
+
+ @property
+ def is_saving(self) -> bool:
+ """Return True if currently saving to file."""
+ return self._is_saving
+
+ def start_save(self, lookback_seconds: float, output_path: str) -> bool:
+ """Start saving data to a Zarr file."""
+ if self._is_saving:
+ logger.warning("Save already in progress")
+ return False
+ if self._ring_buffer is None:
+ logger.error("Cannot start save: ring buffer not configured")
+ return False
+ if not hasattr(self, "_acquisition_rate"):
+ logger.error("Cannot start save: acquisition rate not configured")
+ return False
+
+ self._save_output_path = output_path
+
+        # Use acquisition_rate for consistent rate calculations
+ lookback_samples = self._acquisition_rate.seconds_to_samples(lookback_seconds)
+ available_samples = min(lookback_samples, self._ring_buffer.write_idx)
+
+ self._save_stop_event = threading.Event()
+ self._save_samples_written = 0
+ self._save_pre_trigger_samples = 0
+ self._save_error = None
+
+        # Use AcquisitionRate from ring buffer if available; fall back to the
+        # rate from configure_streaming() so storage_rate_hz below is never None
+        acquisition_rate: Optional[AcquisitionRate] = None
+        if self._ring_buffer is not None:
+            acquisition_rate = self._ring_buffer.get_acquisition_rate()
+        if acquisition_rate is None:
+            acquisition_rate = self._acquisition_rate
+
+ # Build metadata with AcquisitionRate fields for consistent rate semantics
+ # Use the acquisition_rate that was already calculated
+ total_hardware_rate_hz = self._acquisition_rate.hardware_rate_hz
+ metadata: Dict[str, Any] = {
+ # AcquisitionRate fields - stored explicitly for reconstruction
+ "hardware_rate_hz": total_hardware_rate_hz,
+ "num_channels": 2,
+ "downsample_ratio": self._streaming_downsample_ratio,
+ "downsample_mode": self._streaming_downsample_mode,
+ # Legacy field for backward compatibility
+ "sample_rate_hz": acquisition_rate.storage_rate_hz,
+ "channels": [0, 1],
+ "enabled_channels": self._streaming_channels,
+ "voltage_ranges": [
+ self._streaming_voltage_ranges.get(ch, 20.0) for ch in [0, 1]
+ ],
+ "voltage_ranges_units": "V",
+ "offsets_v": [self._streaming_offsets_v.get(ch, 0.0) for ch in [0, 1]],
+ "resolution": self.resolution,
+ "resolution_units": "bits",
+ "max_adc": self._max_adc_value if self._max_adc_value else 32767,
+ "pre_trigger_seconds": lookback_seconds,
+ "buffer_duration_s": self._streaming_buffer_duration_s,
+ "device_id": self.device_id,
+ "serial_code": self.serial_code,
+ "coupling": "DC",
+ "bandwidth_limiter": self._streaming_bandwidth_limiter,
+ }
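+        # Sketch: a reader could rebuild the rate object from these fields
+        # (the stored downsample_mode string may need mapping back to the
+        # DownsampleMode enum):
+        #   AcquisitionRate(hardware_rate_hz=metadata["hardware_rate_hz"],
+        #                   num_channels=metadata["num_channels"],
+        #                   downsample_ratio=metadata["downsample_ratio"],
+        #                   downsample_mode=...)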
+
+ self._save_thread = threading.Thread(
+ target=self._save_worker,
+ args=(output_path, available_samples, metadata, acquisition_rate),
+ daemon=True,
+ name=f"mock-save-{self.device_id}",
+ )
+ self._save_thread.start()
+ self._is_saving = True
+
+ logger.info(
+ "Mock save started: path={}, lookback={:.1f}s ({} samples)",
+ output_path,
+ lookback_seconds,
+ available_samples,
+ )
+ return True
+
+ def stop_save(self, keep: bool = True) -> Optional[str]:
+        """Stop the current save operation.
+
+        Parameters
+        ----------
+        keep : bool
+            If False, the partially written Zarr store is deleted from disk.
+
+        Returns
+        -------
+        Optional[str]
+            The output path when keep is True, otherwise None.
+        """
+ if not self._is_saving:
+ logger.info("No save in progress")
+ return None
+
+ self._keep_file_on_stop = keep
+ logger.info("Stopping mock save (keep={})", keep)
+ self._save_stop_event.set()
+
+ if self._save_thread and self._save_thread.is_alive():
+ self._save_thread.join(timeout=5.0)
+
+ path = self._save_output_path
+
+ if not keep and path and os.path.exists(path):
+ try:
+ import shutil
+
+ shutil.rmtree(path)
+ logger.info("Discarded mock save file: {}", path)
+ except Exception as e:
+ logger.exception("Error discarding save file: {}", e)
+
+ self._is_saving = False
+ self._save_thread = None
+ self._save_stop_event = None
+
+ return path if keep else None
+
+ def is_save_finished(self) -> bool:
+ """Check if the save operation has finished.
+
+ After this returns True, you can safely use the saved file.
+
+ Returns
+ -------
+ bool
+ True if save is complete and the file is ready, False otherwise.
+ """
+ if not self._is_saving:
+ return True
+
+ if self._save_thread and self._save_thread.is_alive():
+ return False
+
+ try:
+ if not self._keep_file_on_stop and self._save_output_path:
+ if os.path.exists(self._save_output_path):
+ import shutil
+
+ shutil.rmtree(self._save_output_path)
+ self._is_saving = False
+ self._save_thread = None
+ self._save_stop_event = None
+ except Exception:
+ logger.exception("Error during save cleanup")
+
+ return True
+
+ def get_save_status(self) -> Dict[str, Any]:
+ """Get the current save status."""
+ if not self._is_saving and self._save_error is None:
+ return {"state": "idle"}
+ if self._save_error is not None:
+ return {"state": "error", "error": str(self._save_error)}
+
+ # Use acquisition_rate for consistent rate semantics
+ if not hasattr(self, "_acquisition_rate"):
+ return {"state": "saving", "total_seconds": 0}
+
+ total_seconds = self._acquisition_rate.samples_to_seconds(
+ self._save_samples_written
+ )
+ pre_trigger_seconds = self._acquisition_rate.samples_to_seconds(
+ self._save_pre_trigger_samples
+ )
+ post_trigger_seconds = total_seconds - pre_trigger_seconds
+
+ return {
+ "state": "saving",
+ "total_seconds": total_seconds,
+ "pre_trigger_seconds": pre_trigger_seconds,
+ "post_trigger_seconds": post_trigger_seconds,
+ }
+
+ def _producer_loop(self) -> None:
+ """Producer thread: generate synthetic data and write to ring buffer."""
+ logger.debug("Mock producer loop started for {}", self.device_id)
+
+ if self._streaming_sample_rate is None:
+ logger.error("Sample rate not set")
+ return
+
+ # CRITICAL: Always generate at target per-channel rate
+ # Downsampling happens AFTER by grouping samples in _apply_hardware_downsample()
+ generation_rate = self._streaming_sample_rate
+
+ # Fixed chunk duration for stable timing
+ chunk_duration_s = 0.01
+ samples_per_chunk = max(1, int(generation_rate * chunk_duration_s))
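+        # e.g. at a 1 MS/s per-channel generation rate this is
+        # 1_000_000 * 0.01 = 10_000 samples per 10 ms chunk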
+
+ # Signal parameters (different frequencies for each channel)
+ freq_ch0 = 1.0
+ freq_ch1 = 2.5
+
+ while not self._stop_streaming_event.is_set():
+ start_time = time.time()
+
+            # Time array for this chunk (continuous across chunks).
+            # endpoint=False keeps the spacing exactly 1/generation_rate and
+            # avoids duplicating the boundary point at the start of the next chunk.
+            t_chunk = np.linspace(
+                self._sample_count / generation_rate,
+                (self._sample_count + samples_per_chunk) / generation_rate,
+                samples_per_chunk,
+                endpoint=False,
+            )
+ max_adc = self._max_adc_value if self._max_adc_value else 32767
+
+ # Channel A - fixed 500mV signal amplitude (not percentage of range)
+ if 0 in self._enabled_channels_set:
+ signal_amplitude = 0.5 # 500mV fixed amplitude
+ signal_ch0 = signal_amplitude * np.sin(2 * np.pi * freq_ch0 * t_chunk)
+ signal_ch0 += np.random.normal(
+ 0, signal_amplitude * 0.05, samples_per_chunk
+ )
+                voltage_range_ch0 = self._streaming_voltage_ranges.get(0, 1.0)
+                # Clip before casting so signals larger than the configured
+                # range saturate (as real hardware does) instead of wrapping
+                # in int16
+                adc_ch0 = np.clip(
+                    signal_ch0 / voltage_range_ch0 * max_adc, -max_adc, max_adc
+                ).astype(np.int16)
+ else:
+ adc_ch0 = np.zeros(samples_per_chunk, dtype=np.int16)
+
+ # Channel B - fixed 500mV signal amplitude (not percentage of range)
+ if 1 in self._enabled_channels_set:
+ signal_amplitude = 0.5 # 500mV fixed amplitude
+ signal_ch1 = signal_amplitude * np.sin(2 * np.pi * freq_ch1 * t_chunk)
+ signal_ch1 += np.random.normal(
+ 0, signal_amplitude * 0.05, samples_per_chunk
+ )
+                voltage_range_ch1 = self._streaming_voltage_ranges.get(1, 1.0)
+                # Clip before casting, as for channel A above
+                adc_ch1 = np.clip(
+                    signal_ch1 / voltage_range_ch1 * max_adc, -max_adc, max_adc
+                ).astype(np.int16)
+ else:
+ adc_ch1 = np.zeros(samples_per_chunk, dtype=np.int16)
+
+ self._sample_count += samples_per_chunk
+
+ # Stack into output array (always 2 columns)
+ raw_data = np.column_stack([adc_ch0, adc_ch1])
+
+ # Apply hardware downsampling to the data before writing to ring buffer
+ processed_data = self._apply_hardware_downsample(raw_data)
+
+ # Write to ring buffer
+ if self._ring_buffer is not None:
+ self._ring_buffer.write(processed_data)
+
+ # Sleep to maintain real-time rate
+ elapsed = time.time() - start_time
+ sleep_time = chunk_duration_s - elapsed
+ if sleep_time > 0:
+ time.sleep(sleep_time)
+ elif sleep_time < -0.01:
+ logger.warning("Mock producer running behind by {:.3f}s", -sleep_time)
+
+ logger.debug("Mock producer loop finished for {}", self.device_id)
+
+ def _apply_hardware_downsample(self, data: np.ndarray) -> np.ndarray:
+ """Apply hardware downsampling to raw data before writing to ring buffer.
+
+ This simulates the Picoscope hardware downsampling behaviour:
+ - NONE: Data passes through unchanged
+ - AVERAGE: Average blocks of samples
+ - AGGREGATE: Output min/max pairs for each block
+ - DECIMATE: Sample every Nth value
+
+ Parameters
+ ----------
+ data : np.ndarray
+ Raw ADC data array, shape (n_samples, num_channels).
+
+ Returns
+ -------
+ np.ndarray
+ Downsampled data array, always int16 dtype.
+ For AGGREGATE mode, returns interleaved min/max pairs producing 2 samples per block.
+
+        Notes
+        -----
+        Samples beyond the last full block in a chunk are dropped, and chunks
+        shorter than one block pass through unchanged, so the effective ratio
+        is only approximate at chunk boundaries.
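+
+        Examples
+        --------
+        Illustrative sketch with toy values: ratio 2 in AGGREGATE mode turns
+        four rows into two interleaved (min, max) pairs per channel::
+
+            input:  [[5, 10], [1, 40], [2, 20], [3, 30]]
+            # block 0 (rows 0-1): min [1, 10], max [5, 40]
+            # block 1 (rows 2-3): min [2, 20], max [3, 30]
+            output: [[1, 10], [5, 40], [2, 20], [3, 30]]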
+ """
+ downsample_ratio = self._streaming_downsample_ratio
+ downsample_mode = self._streaming_downsample_mode
+
+ if downsample_ratio <= 1 or downsample_mode == "NONE":
+ return data
+
+ num_samples = data.shape[0]
+ num_full_groups = num_samples // downsample_ratio
+
+        if num_full_groups == 0:
+            # Chunk shorter than one full block: pass through unchanged rather
+            # than drop it (see Notes above)
+            return data
+
+ if downsample_mode == "AVERAGE":
+ # Reshape to (groups, samples_per_group, channels) and average
+ truncated = data[: num_full_groups * downsample_ratio]
+ reshaped = truncated.reshape(
+ num_full_groups, downsample_ratio, data.shape[1]
+ )
+ return reshaped.mean(axis=1).astype(data.dtype)
+
+ if downsample_mode == "DECIMATE":
+ # Simple decimation: take every Nth sample
+ return data[::downsample_ratio]
+
+ if downsample_mode == "AGGREGATE":
+ # For each block, find min and max across all samples
+ truncated = data[: num_full_groups * downsample_ratio]
+ reshaped = truncated.reshape(
+ num_full_groups, downsample_ratio, data.shape[1]
+ )
+
+ # Vectorized min/max calculation
+ min_vals = reshaped.min(axis=1).astype(data.dtype)
+ max_vals = reshaped.max(axis=1).astype(data.dtype)
+
+ # Interleave min/max: [min0, max0, min1, max1, ...]
+ num_channels = data.shape[1]
+ output = np.empty((num_full_groups * 2, num_channels), dtype=data.dtype)
+ output[0::2] = min_vals
+ output[1::2] = max_vals
+
+ return output
+
+ return data
+
+ def _save_worker(
+ self,
+ path: str,
+ lookback_samples: int,
+ metadata: Dict[str, Any],
+ acquisition_rate: "AcquisitionRate",
+ ) -> None:
+ """Worker thread for saving data to Zarr.
+
+ This runs in a separate thread and continuously drains data
+ from the ring buffer to a Zarr file until stop_save is called.
+ Always saves 2 columns (A and B), with zeros for disabled channels.
+
+ Parameters
+ ----------
+ path : str
+ Output path for the Zarr file.
+ lookback_samples : int
+ Number of samples to include before the trigger point.
+ metadata : Dict[str, Any]
+ Metadata dictionary to store in the Zarr file.
+ acquisition_rate : AcquisitionRate
+ Acquisition rate object with all rate information.
+
+        Notes
+        -----
+        Any exception during saving is logged and stored in _save_error so
+        get_save_status() can report it; the worker thread then exits without
+        re-raising.
+ """
+ try:
+ from picostream.zarr_writer import ZarrStreamWriter
+
+ writer = ZarrStreamWriter(
+ path=path,
+ acquisition_rate=acquisition_rate,
+ num_channels=2,
+ compression=None,
+ )
+
+ trigger_pos = self._ring_buffer.get_snapshot()
+ pre_trigger_start = trigger_pos - lookback_samples
+ if pre_trigger_start < 0:
+ pre_trigger_start = 0
+
+ pre_trigger_data = self._ring_buffer.read_range(
+ pre_trigger_start, trigger_pos
+ )
+
+ if pre_trigger_data.shape[0] > 0:
+ writer.append(pre_trigger_data)
+ self._save_pre_trigger_samples = pre_trigger_data.shape[0]
+
+ last_pos = trigger_pos
+
+ while not self._save_stop_event.is_set():
+ data = self._ring_buffer.read_since(last_pos)
+
+ if data.shape[0] > 0:
+ if not self._ring_buffer.is_valid_range(last_pos):
+ valid_start = (
+ self._ring_buffer.write_idx - self._ring_buffer.capacity
+ )
+ if valid_start < 0:
+ valid_start = 0
+ data = self._ring_buffer.read_since(valid_start)
+ last_pos = valid_start
+
+                    writer.append(data)
+                    # Advance by the samples actually written; taking a fresh
+                    # snapshot here could skip samples that arrived between the
+                    # read and the append
+                    last_pos += data.shape[0]
+ self._save_samples_written = writer.written
+ else:
+ time.sleep(0.001)
+
+ writer.close(metadata)
+ logger.info(
+ "Mock save completed: {} samples written to {}",
+ writer.written,
+ path,
+ )
+
+ except Exception as e:
+ logger.exception("Mock save worker error: {}", e)
+ self._save_error = e
diff --git a/picostream/pico.py b/picostream/pico.py
deleted file mode 100644
index 6b7eada..0000000
--- a/picostream/pico.py
+++ /dev/null
@@ -1,535 +0,0 @@
-from __future__ import annotations
-
-import ctypes
-import queue
-import threading
-import time
-from typing import Any, Dict, List, Optional
-
-import numpy as np
-from loguru import logger
-
-from picosdk.functions import PICO_STATUS
-from picosdk.errors import CannotFindPicoSDKError
-
-try:
- from picosdk.ps5000a import ps5000a as ps
- from picosdk.PicoDeviceEnums import picoEnum as enums
-
- # Add PS5000A bandwidth limiter enum if not present
- if not hasattr(ps, 'PS5000A_BANDWIDTH_LIMITER'):
- ps.PS5000A_BANDWIDTH_LIMITER = {
- 'PS5000A_BW_FULL': 0,
- 'PS5000A_BW_20MHZ': 20000000,
- }
-except CannotFindPicoSDKError:
- logger.critical("PICOSDK IMPORT FAILED: CANNOT FIND SDK LIB - download from pico website")
- ps = None
- enums = None
-
-
-def check_status(status: int, function_name: str) -> None:
- """Check the status returned by a Picoscope SDK call and raise on error.
-
- Args:
- status: The status code returned by the SDK function.
- function_name: The name of the function that was called.
-
- Raises:
- Exception: If the status is not PICO_OK.
- """
- if status != PICO_STATUS["PICO_OK"]:
- # Find the string name of the error code
- error_name = next(
- (k for k, v in PICO_STATUS.items() if v == status), "PICO_UNKNOWN_ERROR"
- )
- raise RuntimeError(
- f"{function_name} failed with status {status} ({error_name})"
- )
-
-
-class PicoDevice:
- """A class to manage a Picoscope 5000a series device for data streaming.
-
- This class handles device configuration, buffer management, and the data
- capture loop. It acts as the "producer" in a producer-consumer pattern.
-
- Data Formats:
- - average mode: Single stream of averaged ADC values
- - aggregate mode: Interleaved stream [min1, max1, min2, max2, ...]
- where each pair represents one downsampled timestep
-
- Buffer Management:
- - SDK buffers: Direct hardware interface (bufferA, bufferB for aggregate)
- - Application buffers: Larger buffers for efficient file writing
- - Interleaved buffer: Temporary buffer for aggregate mode processing
- """
-
- def __init__(
- self,
- handle: int,
- resolution: str,
- pico_buffer_size: int,
- pico_num_buffers: int,
- comp_buffer_size: int,
- data_queue: queue.Queue[int],
- empty_queue: queue.Queue[int],
- data_buffers: List[np.ndarray],
- shutdown_event: threading.Event,
- downsample_mode: str = "average",
- ) -> None:
- """Initializes the PicoDevice and opens a connection to the hardware.
-
- Args:
- handle: The device handle provided by the SDK.
- resolution: The desired resolution, e.g., "PS5000A_DR_16BIT".
- pico_buffer_size: The size of each buffer allocated within the SDK.
- pico_num_buffers: The number of buffers for the SDK to use.
- comp_buffer_size: The size of the application-side buffers for writing to disk.
- data_queue: Queue to send indices of full buffers to the consumer.
- empty_queue: Queue to receive indices of empty buffers from the consumer.
- data_buffers: A list of pre-allocated numpy arrays for data transfer.
- shutdown_event: A threading.Event to signal shutdown.
- """
- # --- Device and Resolution ---
- self.handle: ctypes.c_int16 = ctypes.c_int16(handle)
- self.resolution: str = resolution
- res_enum = ps.PS5000A_DEVICE_RESOLUTION[resolution]
-
- # --- Picoscope SDK Buffer Configuration ---
- self.pico_buffer_size: int = pico_buffer_size
- self.pico_num_buffers: int = pico_num_buffers
- self.total_samples: int = self.pico_buffer_size * self.pico_num_buffers
-
- # --- Application Buffer (for file writing) ---
- self.comp_buffer_size: int = comp_buffer_size
-
- # --- Internal Data Buffer (receives data from SDK) ---
- self.downsample_mode = downsample_mode
- self.bufferA: np.ndarray = np.zeros(shape=self.pico_buffer_size, dtype=np.int16)
- self.bufferB: Optional[np.ndarray] = None
- self.interleaved_buffer: Optional[np.ndarray] = None
- if self.downsample_mode == "aggregate":
- self.bufferB = np.zeros(shape=self.pico_buffer_size, dtype=np.int16)
- self.interleaved_buffer = np.zeros(
- shape=self.pico_buffer_size * 2, dtype=np.int16
- )
-
- # --- Ctypes and Callback ---
- self.callbackFuncPtr = ps.StreamingReadyType(self.streaming_callback)
- self.max_adc: ctypes.c_int16 = ctypes.c_int16()
-
- # --- Channel Configuration ---
- self.channel_range: Optional[int] = None
- self.voltage_range_v: Optional[float] = None
- self.channel_a_coupling: Optional[str] = None
- self.analog_offset_v: float = 0.0
- self.channel_a_range_str: Optional[str] = None
- self.bandwidth_limiter: Optional[str] = None
-
- # --- Threading and Queues ---
- self.shutdown_event: threading.Event = shutdown_event
- self.data_queue: queue.Queue[int] = data_queue
- self.empty_queue: queue.Queue[int] = empty_queue
- self.data_buffers: List[np.ndarray] = data_buffers
-
- # --- Buffer Management State ---
- self.buf_idx: int = self.empty_queue.get()
- self.buf_used: int = 0
- self.buf_free: int = self.comp_buffer_size
-
- # --- Streaming Configuration ---
- self.streaming_configured: bool = False
- self.sample_int: Optional[ctypes.c_int32] = None
- self.sample_unit: Optional[int] = None
- self.ratio: Optional[int] = None
- self.pre_trig_samples: Optional[int] = None
- self.down_sample_ratio: Optional[int] = None
- self.auto_stop: Optional[int] = None
- self.auto_stop_stream: Optional[int] = None
-
- # --- Status and Performance Metrics ---
- self.captured_samples: int = 0
- self.empty_pro_queue_count: int = 0
- self.overvoltage_count: int = 0
- self.hardware_buffer_overflow_count: int = 0
- self.callback_durations: List[float] = []
-
- # --- Aggregate Mode Performance Tracking ---
- self.interleave_durations: List[float] = []
-
- # --- Open device connection ---
- status = ps.ps5000aOpenUnit(ctypes.byref(self.handle), None, res_enum)
- check_status(status, "ps5000aOpenUnit")
-
- status = ps.ps5000aMaximumValue(self.handle, ctypes.byref(self.max_adc))
- check_status(status, "ps5000aMaximumValue")
-
- logger.info(f"Device opened - Resolution: {self.resolution}, max_adc queried: {self.max_adc.value}")
-
- def set_channel(
- self, chan: str, en: int, coup: str, voltage_range_str: str, offset: float
- ) -> None:
- """Configure a channel on the Picoscope.
-
- Args:
- chan: The channel identifier string, e.g., "PS5000A_CHANNEL_A".
- en: Whether the channel is enabled (1) or disabled (0).
- coup: The coupling type string, e.g., "PS5000A_DC".
- voltage_range_str: The voltage range string, e.g., "PS5000A_20V".
- offset: The analog voltage offset in Volts.
- """
- channel_range_enum = ps.PS5000A_RANGE[voltage_range_str]
- self.channel_range = channel_range_enum
-
- # Store the actual voltage range for metadata and conversion
- range_to_voltage = {
- "PS5000A_10MV": 0.01,
- "PS5000A_20MV": 0.02,
- "PS5000A_50MV": 0.05,
- "PS5000A_100MV": 0.1,
- "PS5000A_200MV": 0.2,
- "PS5000A_500MV": 0.5,
- "PS5000A_1V": 1.0,
- "PS5000A_2V": 2.0,
- "PS5000A_5V": 5.0,
- "PS5000A_10V": 10.0,
- "PS5000A_20V": 20.0,
- }
- self.voltage_range_v = range_to_voltage.get(voltage_range_str)
-
- if chan == "PS5000A_CHANNEL_A" and en:
- self.channel_a_coupling = coup
- self.channel_a_range_str = voltage_range_str
- self.analog_offset_v = offset
-
- channel_enum = ps.PS5000A_CHANNEL[chan]
- coupling_enum = ps.PS5000A_COUPLING[coup]
- status = ps.ps5000aSetChannel(
- self.handle, channel_enum, en, coupling_enum, channel_range_enum, offset
- )
- check_status(status, f"ps5000aSetChannel ({chan})")
- logger.info(
- f"Channel configured - Range: '{voltage_range_str}' (enum: {channel_range_enum}), "
- f"voltage_range_v: {self.voltage_range_v}V, max_adc: {self.max_adc.value}"
- )
-
- def set_bandwidth_filter(self, channel: str, bandwidth: str) -> None:
- """Set the bandwidth filter for a channel to reduce noise.
-
- Args:
- channel: The channel identifier string, e.g., "PS5000A_CHANNEL_A".
- bandwidth: The bandwidth limiter string, e.g., "PS5000A_BW_FULL" or "PS5000A_BW_20MHZ".
- """
- channel_enum = ps.PS5000A_CHANNEL[channel]
- bandwidth_enum = ps.PS5000A_BANDWIDTH_LIMITER[bandwidth]
-
- if channel == "PS5000A_CHANNEL_A":
- self.bandwidth_limiter = bandwidth
-
- status = ps.ps5000aSetBandwidthFilter(
- self.handle,
- channel_enum,
- bandwidth_enum
- )
- check_status(status, f"ps5000aSetBandwidthFilter ({channel}, {bandwidth})")
- logger.info(f"Bandwidth filter set - Channel: {channel}, Bandwidth: {bandwidth}")
-
- def set_data_buffer(self, chan: str, segment: int, rat: str) -> None:
- """Set up the data buffer for a specific channel for streaming.
-
- Args:
- chan: The channel identifier string, e.g., "PS5000A_CHANNEL_A".
- segment: The memory segment to use (0 for streaming).
- rat: The ratio mode string, e.g., "PS5000A_RATIO_MODE_NONE".
- """
- channel_enum = ps.PS5000A_CHANNEL[chan]
- ratio_enum = ps.PS5000A_RATIO_MODE[rat]
-
- buffer_min_ptr = None
- if self.bufferB is not None:
- buffer_min_ptr = self.bufferB.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
-
- status = ps.ps5000aSetDataBuffers(
- self.handle,
- channel_enum,
- self.bufferA.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
- buffer_min_ptr,
- self.pico_buffer_size,
- segment,
- ratio_enum,
- )
- check_status(status, f"ps5000aSetDataBuffers ({chan})")
-
- def configure_streaming_var(
- self,
- samp_int: int,
- samp_unit: str,
- pre_trig_samp: int,
- down_samp_rat: int,
- rat: str,
- auto_stop: int,
- auto_stop_stream: int,
- ) -> None:
- """Store streaming parameters before starting the capture.
-
- Args:
- samp_int: The desired sample interval in `samp_unit` units.
- samp_unit: The time unit string, e.g., "PS5000A_NS".
- pre_trig_samp: The number of pre-trigger samples.
- down_samp_rat: The downsampling ratio.
- rat: The ratio mode string, e.g., "PS5000A_RATIO_MODE_NONE".
- auto_stop: Whether to stop the capture automatically (1) or not (0).
- auto_stop_stream: Deprecated, not used.
- """
- self.sample_int = ctypes.c_int32(samp_int)
- self.sample_unit = ps.PS5000A_TIME_UNITS[samp_unit]
- self.ratio = ps.PS5000A_RATIO_MODE[rat]
- self.pre_trig_samples = pre_trig_samp
- self.down_sample_ratio = down_samp_rat
- self.auto_stop = auto_stop
- self.auto_stop_stream = auto_stop_stream
-
- def get_metadata(self, acquisition_start_time_utc: str, picostream_version: str, acquisition_command: str, was_live_mode: bool) -> Dict[str, Any]:
- """Return comprehensive acquisition metadata."""
- metadata = {
- "resolution": self.resolution,
- "sample_interval_ns": self.sample_int.value if self.sample_int else None,
- "voltage_range_v": self.voltage_range_v,
- "channel_a_coupling": self.channel_a_coupling,
- "channel_a_range": self.channel_a_range_str,
- "downsample_mode": self.downsample_mode,
- "analog_offset_v": self.analog_offset_v,
- "hardware_downsample_ratio": self.down_sample_ratio,
- "max_adc": self.max_adc.value,
- "data_format_version": "2.0",
- "acquisition_start_time_utc": acquisition_start_time_utc,
- "picostream_version": picostream_version,
- "acquisition_command": acquisition_command,
- "was_live_mode": was_live_mode,
- "bandwidth_limiter": self.bandwidth_limiter,
- }
-
- if self.downsample_mode == "aggregate":
- metadata.update(
- {
- "aggregate_format": "interleaved_min_max",
- "aggregate_description": "Data format: [min1, max1, min2, max2, ...]",
- }
- )
-
- return metadata
-
- def run_streaming(self) -> None:
- """Starts the Picoscope streaming capture."""
- status = ps.ps5000aRunStreaming(
- self.handle,
- ctypes.byref(self.sample_int),
- self.sample_unit,
- self.pre_trig_samples,
- self.total_samples,
- self.auto_stop,
- self.down_sample_ratio,
- self.ratio,
- self.pico_buffer_size,
- )
- check_status(status, "ps5000aRunStreaming")
- self.streaming_configured = True
- logger.info(
- f"Streaming configured. Actual sample interval: {self.sample_int.value} ns"
- )
-
- def streaming_callback(
- self,
- _handle: int,
- noOfSamples: int,
- startIndex: int,
- overvoltage_flags: int,
- _triggerAt: int,
- _triggered: int,
- _autoStop: int,
- _param: int,
- ) -> None:
- """Callback function executed by the SDK when new streaming data is available.
-
- This function is the heart of the producer. It copies data from the
- Picoscope's internal buffer into the application's shared buffer pool.
- When an application buffer is full, its index is placed on the data_queue
- for the consumer.
-
- Note: This function is called from a thread created by the Picoscope SDK.
- It must be fast and thread-safe.
- """
- if overvoltage_flags:
- self.overvoltage_count += 1
- logger.warning(
- "Picoscope ADC over-range detected (saturation/clipping)."
- )
-
- # Stop processing if a shutdown is requested.
- if self.shutdown_event.is_set():
- return
-
- callback_start_time = time.perf_counter()
- if noOfSamples > 0:
- self.captured_samples += noOfSamples
-
- source_buffer = self.bufferA
- samples_to_process = noOfSamples
-
- # In aggregate mode, interleave the min/max buffers into one
- if (
- self.downsample_mode == "aggregate"
- and self.bufferB is not None
- and self.interleaved_buffer is not None
- ):
- interleave_start = time.perf_counter()
-
- # The SDK provides min/max data in separate buffers (B/A).
- # We interleave them into a single [min, max, min, max, ...]
- # stream for the consumer.
- total_interleaved_samples = noOfSamples * 2
-
- # Use numpy's more efficient interleaving
- np.stack(
- [
- self.bufferB[startIndex : startIndex + noOfSamples],
- self.bufferA[startIndex : startIndex + noOfSamples],
- ],
- axis=1,
- out=self.interleaved_buffer[:total_interleaved_samples].reshape(
- -1, 2
- ),
- )
-
- source_buffer = self.interleaved_buffer
- samples_to_process = total_interleaved_samples
- # After interleaving, the source index is always 0
- source_index = 0
-
- # Track interleaving performance
- self.interleave_durations.append(
- (time.perf_counter() - interleave_start) * 1000
- )
- else:
- source_index = startIndex
-
- # This loop copies data from the source_buffer into
- # our larger, shared application buffers (self.data_buffers).
- while samples_to_process > 0:
- # Determine how much space is left in the current application buffer.
- self.buf_free = self.comp_buffer_size - self.buf_used
- copy_size = min(samples_to_process, self.buf_free)
-
- # Copy the data slice.
- self.data_buffers[self.buf_idx][
- self.buf_used : self.buf_used + copy_size
- ] = source_buffer[source_index : source_index + copy_size]
-
- # Update pointers and remaining sample counts.
- self.buf_used += copy_size
- samples_to_process -= copy_size
- source_index += copy_size
-
- # If the current application buffer is full...
- if self.buf_used == self.comp_buffer_size:
- # ...send its index to the consumer.
- self.data_queue.put(self.buf_idx)
- try:
- # ...and get a new empty buffer from the consumer.
- self.buf_idx = self.empty_queue.get_nowait()
- self.buf_used = 0
- except queue.Empty:
- # This is a critical failure. The consumer is not keeping up.
- self.empty_pro_queue_count += 1
- logger.critical(
- "Producer queue is empty. Consumer cannot keep up. "
- "Shutting down to prevent data loss."
- )
- self.shutdown_event.set()
- return # Exit immediately.
-
- duration_ms = (time.perf_counter() - callback_start_time) * 1000
- self.callback_durations.append(duration_ms)
-
- def run_capture(self) -> None:
- """The main capture loop for the producer thread."""
- if not self.streaming_configured:
- self.run_streaming()
-
- # This loop polls the SDK for new data, which triggers the callback.
- while not self.shutdown_event.is_set():
- status = ps.ps5000aGetStreamingLatestValues(
- self.handle, self.callbackFuncPtr, None
- )
-
- if status == PICO_STATUS["PICO_BUFFER_STALL"]:
- self.hardware_buffer_overflow_count += 1
- logger.critical(
- "Picoscope hardware buffer overflow occurred. Data was lost."
- )
- elif status not in [
- PICO_STATUS["PICO_OK"],
- PICO_STATUS["PICO_NO_SAMPLES_AVAILABLE"],
- PICO_STATUS["PICO_BUSY"],
- PICO_STATUS["PICO_DATA_NOT_AVAILABLE"],
- PICO_STATUS["PICO_DRIVER_FUNCTION"],
- ]:
- check_status(status, "ps5000aGetStreamingLatestValues")
-
- # Yield the GIL to other threads.
- time.sleep(0.001)
-
- # --- Shutdown and reporting ---
- logger.info(
- f"Producer couldn't obtain an empty queue {self.empty_pro_queue_count} times."
- )
- if self.hardware_buffer_overflow_count > 0:
- logger.critical(
- f"Picoscope hardware buffer overflowed {self.hardware_buffer_overflow_count} times. "
- "This indicates the application could not process data fast enough from the driver."
- )
- if self.overvoltage_count > 0:
- logger.warning(
- f"Picoscope ADC over-ranged (clipped) {self.overvoltage_count} times."
- )
- if self.callback_durations:
- logger.info("--- Callback Performance ---")
- logger.info(f"Total callbacks: {len(self.callback_durations)}")
- logger.info(f"Min duration: {min(self.callback_durations):.2f} ms")
- logger.info(f"Max duration: {max(self.callback_durations):.2f} ms")
- logger.info(f"Avg duration: {np.mean(self.callback_durations):.2f} ms")
- logger.info("--------------------------")
-
- # Report aggregate mode performance if applicable
- if self.downsample_mode == "aggregate" and self.interleave_durations:
- logger.info("--- Aggregate Mode Performance ---")
- logger.info(
- f"Total interleave operations: {len(self.interleave_durations)}"
- )
- logger.info(
- f"Min interleave duration: {min(self.interleave_durations):.3f} ms"
- )
- logger.info(
- f"Max interleave duration: {max(self.interleave_durations):.3f} ms"
- )
- logger.info(
- f"Avg interleave duration: {np.mean(self.interleave_durations):.3f} ms"
- )
- logger.info("----------------------------------")
-
- def close_device(self) -> None:
- """Stops the Picoscope and closes the connection."""
- # Check if handle is valid before trying to close.
- if self.handle.value > 0:
- status_stop = ps.ps5000aStop(self.handle)
- if status_stop != PICO_STATUS["PICO_OK"]:
- logger.warning(f"ps5000aStop failed with status {status_stop}")
-
- status_close = ps.ps5000aCloseUnit(self.handle)
- if status_close != PICO_STATUS["PICO_OK"]:
- logger.warning(f"ps5000aCloseUnit failed with status {status_close}")
-
- # Invalidate handle to prevent reuse.
- self.handle.value = 0
- logger.info("Picoscope connection closed.")
diff --git a/picostream/reader.py b/picostream/reader.py
deleted file mode 100644
index aa3e425..0000000
--- a/picostream/reader.py
+++ /dev/null
@@ -1,259 +0,0 @@
-from __future__ import annotations
-
-from typing import Generator, Tuple
-
-import h5py
-import numpy as np
-
-from .conversion_utils import adc_to_mV, min_max_decimate_numba
-
-
-class PicoStreamReader:
- """A helper class to read and interpret data from picostream HDF5 files.
-
- This class provides a simple context-manager interface to open HDF5 files,
- read metadata, and retrieve blocks of data as time and voltage arrays.
-
- It correctly handles the `analog_offset_v` if it is present in the file's
- metadata.
-
- Example:
- >>> with PicoStreamReader("output.hdf5") as reader:
- ... print(f"Sample rate: {1e9 / reader.sample_interval_ns:.2f} S/s")
- ... # Iterate through the whole file
- ... for times, voltages in reader.get_block_iter(chunk_size=1_000_000):
- ... # process data
- ... pass
- ...
- ... # Or iterate with 20x decimation (mean)
- ... for times, voltages in reader.get_block_iter(
- ... chunk_size=1_000_000, decimation_factor=20, decimation_mode='mean'
- ... ):
- ... # process decimated data
- ... pass
- ...
- ... # Or get blocks sequentially
- ... reader.reset() # Reset internal counter
- ... while True:
- ... block = reader.get_next_block(chunk_size=500_000)
- ... if block is None:
- ... break
- ... times, voltages = block
- ... # process data
- """
-
- def __init__(self, hdf5_path: str):
- """Initializes the PicoStreamReader.
-
- Args:
- hdf5_path: Path to the HDF5 file.
- """
- self.hdf5_path = hdf5_path
- self._file: h5py.File | None = None
- self.dset: h5py.Dataset | None = None
- self._current_pos: int = 0
-
- # Metadata attributes, populated in __enter__
- self.sample_interval_ns: float = 0.0
- self.voltage_range_v: float = 0.0
- self.max_adc_val: int = 0
- self.downsample_mode: str = "average"
- self.hardware_downsample_ratio: int = 1
- self.analog_offset_v: float = 0.0
-
- def __enter__(self) -> PicoStreamReader:
- """Opens the HDF5 file and reads metadata."""
- self._file = h5py.File(self.hdf5_path, "r")
- self.dset = self._file["adc_counts"]
-
- # Read metadata from file attributes
- attrs = self._file.attrs
- base_sample_interval_ns = attrs["sample_interval_ns"]
- self.hardware_downsample_ratio = attrs.get("hardware_downsample_ratio", 1)
- self.sample_interval_ns = (
- base_sample_interval_ns * self.hardware_downsample_ratio
- )
- self.voltage_range_v = attrs["voltage_range_v"]
- if "max_adc" in attrs:
- self.max_adc_val = attrs["max_adc"]
- elif "resolution" in attrs:
- # Fallback for older files: calculate max_adc from resolution string
- res_str = attrs["resolution"] # e.g., "PS5000A_DR_16BIT"
- try:
- # Extract bit depth (e.g., 16) from the string
- res_int = int(res_str.split("_")[-1].replace("BIT", ""))
- self.max_adc_val = (2 ** (res_int - 1)) - 1
- except (ValueError, IndexError):
- raise KeyError(
- f"Could not parse 'resolution' attribute to determine max_adc: {res_str}"
- )
- else:
- raise KeyError(
- "HDF5 file is missing required 'max_adc' or 'resolution' attribute."
- )
- self.downsample_mode = attrs.get("downsample_mode", "average")
- self.analog_offset_v = attrs.get("analog_offset_v", 0.0)
-
- self.reset()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """Closes the HDF5 file."""
- if self._file:
- self._file.close()
-
- @property
- def num_samples(self) -> int:
- """Total number of samples in the dataset."""
- if self.dset:
- return self.dset.shape[0]
- return 0
-
- def reset(self) -> None:
- """Resets the internal position counter for get_next_block."""
- self._current_pos = 0
-
- def get_next_block(
- self,
- chunk_size: int,
- decimation_factor: int = 1,
- decimation_mode: str = "mean",
- ) -> Tuple[np.ndarray, np.ndarray] | None:
- """Retrieves the next block of data.
-
- Args:
- chunk_size: The maximum number of raw samples to retrieve.
- decimation_factor: The factor by which to decimate the data.
- decimation_mode: The decimation method ('mean' or 'min_max').
-
- Returns:
- A (times, voltages) tuple, or None if no more data is available.
- """
- if self._current_pos >= self.num_samples:
- return None
-
- size = min(chunk_size, self.num_samples - self._current_pos)
- block = self.get_block(
- size=size,
- start=self._current_pos,
- decimation_factor=decimation_factor,
- decimation_mode=decimation_mode,
- )
- self._current_pos += size
- return block
-
- def get_block_iter(
- self,
- chunk_size: int = 1_000_000,
- decimation_factor: int = 1,
- decimation_mode: str = "mean",
- ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
- """Yields data blocks as (times, voltages) tuples for the entire dataset.
-
- Args:
- chunk_size: The size of each raw data chunk to yield.
- decimation_factor: The factor by which to decimate the data.
- decimation_mode: The decimation method ('mean' or 'min_max').
- """
- num_samples = self.num_samples
- for start_idx in range(0, num_samples, chunk_size):
- size = min(chunk_size, num_samples - start_idx)
- yield self.get_block(
- size=size,
- start=start_idx,
- decimation_factor=decimation_factor,
- decimation_mode=decimation_mode,
- )
-
- def get_block(
- self,
- size: int,
- start: int = 0,
- decimation_factor: int = 1,
- decimation_mode: str = "mean",
- ) -> Tuple[np.ndarray, np.ndarray]:
- """Retrieves a specific block of data and converts it to time and voltage.
-
- Optionally, it can decimate the data using either averaging or a
- min-max technique.
-
- Args:
- size: The number of raw samples to retrieve from the file.
- start: The starting sample index.
- decimation_factor: The factor by which to decimate the data.
- decimation_mode: The decimation method ('mean' or 'min_max').
-
- Returns:
- A tuple containing:
- - times (np.ndarray): The time axis in seconds.
- - voltages (np.ndarray): The voltage data in millivolts.
- """
- if not self.dset:
- raise RuntimeError("File not open. Use this class as a context manager.")
-
- adc_data = self.dset[start : start + size]
-
- if adc_data.size == 0:
- return np.array([]), np.array([])
-
- # --- Decimation ---
- decimated_adc_data = adc_data
- final_decimation_factor = 1
-
- if decimation_factor > 1 and adc_data.size >= decimation_factor:
- final_decimation_factor = decimation_factor
-
- # Truncate data to be a multiple of the decimation factor
- num_windows = adc_data.size // decimation_factor
- truncated_len = num_windows * decimation_factor
- adc_data_to_process = adc_data[:truncated_len]
-
- if decimation_mode == "min_max":
- decimated_adc_data = min_max_decimate_numba(
- adc_data_to_process, decimation_factor
- )
- elif decimation_mode == "mean":
- means = adc_data_to_process.reshape(
- -1, decimation_factor
- ).mean(axis=1)
- decimated_adc_data = np.repeat(means, 2)
- else:
- raise ValueError(
- f"Unknown decimation_mode: '{decimation_mode}'. Use 'mean' or 'min_max'."
- )
-
- # --- Voltage Conversion ---
- voltages_mv = adc_to_mV(decimated_adc_data, self.voltage_range_v, self.max_adc_val)
- if self.analog_offset_v != 0.0:
- voltages_mv += self.analog_offset_v * 1000
-
- # --- Time Axis Creation ---
- points_per_timestep = 2 if self.downsample_mode == "aggregate" else 1
- time_per_timestep = self.sample_interval_ns * 1e-9
- time_per_raw_point = time_per_timestep / points_per_timestep
-
- start_time = (start / points_per_timestep) * time_per_timestep
-
- num_output_samples = voltages_mv.size
-
- if final_decimation_factor > 1:
- time_step_between_windows = time_per_raw_point * final_decimation_factor
-
- # Both 'mean' and 'min_max' now produce pairs of points.
- # The time axis is the start time of each decimation window, repeated.
- num_pairs = num_output_samples // 2
- window_start_times = (
- start_time + np.arange(num_pairs) * time_step_between_windows
- )
- times = np.repeat(window_start_times, 2)
- else:
- # Original time axis logic for non-decimated data
- if self.downsample_mode == "aggregate":
- num_pairs = num_output_samples // 2
- pair_times = start_time + np.arange(num_pairs) * time_per_timestep
- times = np.repeat(pair_times, 2)
- else:
- end_time = start_time + (num_output_samples - 1) * time_per_raw_point
- times = np.linspace(start_time, end_time, num_output_samples)
-
- return times, voltages_mv
diff --git a/picostream/ring_buffer.py b/picostream/ring_buffer.py
new file mode 100644
index 0000000..678f6da
--- /dev/null
+++ b/picostream/ring_buffer.py
@@ -0,0 +1,331 @@
+"""Ring buffer for high-throughput streaming data with lock-free reads/writes.
+
+This module provides a fixed-size circular buffer implemented as a NumPy array.
+It is designed for single-producer, multiple-reader scenarios where the producer
+writes continuously and readers access historical or recent data.
+
+The implementation relies on CPython's GIL atomicity for the write index update
+and NumPy's GIL-releasing memcpy for data transfer, providing efficient
+lock-free operation suitable for high-speed data acquisition.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+from loguru import logger
+
+if TYPE_CHECKING:
+ from picostream.acquisition_rate import AcquisitionRate
+
+
+class RingBuffer:
+ """A fixed-size ring buffer for multi-channel streaming data.
+
+ This buffer stores int16 samples for multiple channels in a circular fashion.
+ It supports a single writer and multiple concurrent readers. The write index
+ is updated atomically (under CPython's GIL), and NumPy slice operations
+ release the GIL during the actual memory copy.
+
+ Parameters
+ ----------
+ duration_s : float
+ Buffer duration in seconds. Determines capacity.
+ sample_rate : float
+ Expected rate at which samples will arrive at the buffer in Hz.
+ This should be the effective rate after any hardware downsampling.
+ num_channels : int, optional
+ Number of channels. Default is 2.
+
+ Attributes
+ ----------
+ capacity : int
+ Total samples per channel that fit in the buffer.
+ buffer : np.ndarray
+ The underlying storage array, shape (capacity, num_channels), dtype int16.
+ write_idx : int
+ Monotonically increasing write position. Use modulo capacity for indexing.
+ num_channels : int
+ Number of channels.
+ storage_rate_hz : float
+ Rate at which samples arrive at the buffer (post-downsampling if applicable).
+
+ Examples
+ --------
+ >>> buffer = RingBuffer(duration_s=30.0, sample_rate=62.5e6, num_channels=2)
+ >>> data = np.random.randint(-1000, 1000, size=(1000, 2), dtype=np.int16)
+ >>> buffer.write(data)
+ >>> recent = buffer.read_last(500)
+ """
+
+ def __init__(
+ self,
+ duration_s: float,
+ sample_rate: float,
+ num_channels: int = 2,
+ ) -> None:
+ """Initialise the ring buffer with the specified dimensions."""
+ self.duration_s = duration_s
+ self.num_channels = num_channels
+ self._acquisition_rate: Optional[AcquisitionRate] = None
+
+ # Initial capacity based on configured rate for validation
+ self._initial_sample_rate = sample_rate
+
+ self.capacity = int(duration_s * self._initial_sample_rate)
+ if self.capacity <= 0:
+ raise ValueError(
+ f"Buffer capacity must be positive, got {self.capacity} "
+ f"(duration_s={duration_s}, sample_rate={sample_rate})"
+ )
+
+ self.buffer = np.zeros((self.capacity, num_channels), dtype=np.int16)
+ self.write_idx = 0
+
+ logger.debug(
+ "RingBuffer initialised: capacity={} samples, "
+ "{:.2f} seconds, {} channels, {:.2f} MB",
+ self.capacity,
+ duration_s,
+ num_channels,
+ self.buffer.nbytes / (1024 * 1024),
+ )
+
+ def write(self, data: np.ndarray) -> None:
+ """Write data to the ring buffer.
+
+ Data is written at the current write position, wrapping around
+ to overwrite old data when the buffer is full. The write index
+ is updated atomically under CPython's GIL.
+
+ Parameters
+ ----------
+ data : np.ndarray
+ Array of shape (n_samples, num_channels) with dtype int16.
+
+ Raises
+ ------
+ ValueError
+ If data shape does not match num_channels or dtype is not int16.
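+
+        Examples
+        --------
+        Toy sketch with a 4-sample capacity (duration 4 s at 1 Hz):
+
+        >>> rb = RingBuffer(duration_s=4.0, sample_rate=1.0, num_channels=1)
+        >>> rb.write(np.arange(3, dtype=np.int16).reshape(-1, 1))
+        >>> rb.write(np.arange(3, 6, dtype=np.int16).reshape(-1, 1))
+        >>> rb.read_last(4).ravel()  # samples 0 and 1 were overwritten
+        array([2, 3, 4, 5], dtype=int16)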
+ """
+ if data.dtype != np.int16:
+ raise ValueError(f"Data must be int16, got {data.dtype}")
+
+ if data.ndim != 2 or data.shape[1] != self.num_channels:
+ raise ValueError(
+ f"Data shape must be (n_samples, {self.num_channels}), got {data.shape}"
+ )
+
+ n_samples = data.shape[0]
+ if n_samples == 0:
+ return
+
+ # If data exceeds capacity, only keep the last 'capacity' samples
+ if n_samples > self.capacity:
+ data = data[-self.capacity :]
+ n_samples = self.capacity
+
+ start_idx = self.write_idx % self.capacity
+ end_idx = (self.write_idx + n_samples) % self.capacity
+
+ if start_idx < end_idx:
+ # No wrap: single write
+ self.buffer[start_idx:end_idx] = data
+ else:
+ # Wrap-around: split into tail + head
+ tail_size = self.capacity - start_idx
+ self.buffer[start_idx:] = data[:tail_size]
+ self.buffer[:end_idx] = data[tail_size:]
+
+ self.write_idx += n_samples
+
+ def get_snapshot(self) -> int:
+ """Return the current write index position.
+
+ Returns
+ -------
+ int
+ The current write position (monotonically increasing).
+ """
+ return self.write_idx
+
+ def read_last(self, n_samples: int) -> np.ndarray:
+ """Read the most recent n_samples from the buffer.
+
+ Parameters
+ ----------
+ n_samples : int
+ Number of samples to read. Will be clamped to available data.
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (m, num_channels) where m <= n_samples,
+ containing the most recent data. Returns a copy, not a view.
+ """
+ available = min(n_samples, self.write_idx, self.capacity)
+ if available <= 0:
+ return np.empty((0, self.num_channels), dtype=np.int16)
+
+ return self.read_range(self.write_idx - available, self.write_idx)
+
+ def read_range(self, start_idx: int, end_idx: int) -> np.ndarray:
+ """Read a specific range of samples from the buffer.
+
+ Handles wrap-around automatically. Always returns a copy.
+
+ Parameters
+ ----------
+ start_idx : int
+ Starting sample index (inclusive).
+ end_idx : int
+ Ending sample index (exclusive).
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (n, num_channels) containing the requested range.
+ Returns empty array if end_idx <= start_idx.
+
+        Notes
+        -----
+        Requests spanning more than capacity samples are clamped to the most
+        recent capacity samples and a warning is logged; no exception is
+        raised.
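+
+        Examples
+        --------
+        Indices are monotonic sample counts, not wrapped buffer offsets; a
+        sketch of reading the 100 newest samples:
+
+        >>> head = buffer.get_snapshot()
+        >>> chunk = buffer.read_range(max(head - 100, 0), head)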
+ """
+ n_requested = end_idx - start_idx
+
+ if n_requested <= 0:
+ return np.empty((0, self.num_channels), dtype=np.int16)
+
+ if n_requested > self.capacity:
+ logger.warning(
+ "Requested {} samples exceeds buffer capacity {}, clamping",
+ n_requested,
+ self.capacity,
+ )
+ n_requested = self.capacity
+ start_idx = end_idx - n_requested
+
+ start_mod = start_idx % self.capacity
+ end_mod = end_idx % self.capacity
+
+ if start_mod < end_mod:
+ # No wrap
+ return self.buffer[start_mod:end_mod].copy()
+ else:
+ # Wrap-around
+ tail = self.buffer[start_mod:]
+ head = self.buffer[:end_mod]
+ return np.concatenate([tail, head], axis=0)
+
+ def read_since(self, position: int) -> np.ndarray:
+ """Read all data written since the given position.
+
+ Convenience method equivalent to read_range(position, get_snapshot()).
+
+ Parameters
+ ----------
+ position : int
+ Previous write position from get_snapshot().
+
+ Returns
+ -------
+ np.ndarray
+ Array of new data written since position. May be empty.
+ """
+ return self.read_range(position, self.write_idx)
+
+ def is_valid_range(self, start_idx: int) -> bool:
+ """Check if a start index is still within the valid buffer range.
+
+ Used by save threads to verify data hasn't been overwritten.
+
+ Note
+ ----
+ This check is best-effort and does not prevent a race condition:
+ the producer may overwrite the data after this check returns True
+        but before the caller reads the data. Callers should re-check the
+        range after reading, since read_range() silently clamps stale
+        requests to the newest capacity samples rather than raising.
+
+ Parameters
+ ----------
+ start_idx : int
+ The sample index to check.
+
+ Returns
+ -------
+ bool
+ True if start_idx is still in the buffer (producer hasn't
+ overwritten it), False otherwise.
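+
+        Examples
+        --------
+        Sketch of the intended reader pattern:
+
+        >>> pos = buffer.get_snapshot()
+        >>> # ... producer keeps writing ...
+        >>> if not buffer.is_valid_range(pos):
+        ...     # resync to the oldest sample still held by the buffer
+        ...     pos = max(buffer.write_idx - buffer.capacity, 0)
+        >>> data = buffer.read_since(pos)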
+ """
+ return (self.write_idx - start_idx) <= self.capacity
+
+ def get_utilisation(self) -> float:
+ """Return the fraction of buffer currently filled with data.
+
+ Returns
+ -------
+ float
+ Fraction between 0.0 and 1.0.
+ """
+ filled = min(self.write_idx, self.capacity)
+ return filled / self.capacity
+
+ @property
+ def storage_rate_hz(self) -> float:
+ """Return the rate at which samples arrive at the buffer.
+
+ This is the effective sample rate after any hardware downsampling.
+ For example, with 10x hardware downsampling, this will be 1/10th
+ of the hardware ADC rate.
+
+ Returns
+ -------
+ float
+ The storage rate from acquisition_rate.
+
+ Raises
+ ------
+ RuntimeError
+ If acquisition_rate has not been set.
+ """
+ if self._acquisition_rate is None:
+ msg = "storage_rate_hz called before acquisition_rate was set"
+ raise RuntimeError(msg)
+ return self._acquisition_rate.storage_rate_hz
+
+ def set_acquisition_rate(self, rate: AcquisitionRate) -> None:
+ """Set the acquisition rate object for this buffer.
+
+ This is the single source of truth for all rate calculations.
+ Must be called before using the buffer for time-based calculations.
+
+ Parameters
+ ----------
+ rate : AcquisitionRate
+ The acquisition rate object with all rate information.
+ """
+ self._acquisition_rate = rate
+ logger.info("RingBuffer acquisition rate set: {}", rate)
+
+ def get_acquisition_rate(self) -> Optional[AcquisitionRate]:
+ """Get the acquisition rate object for this buffer.
+
+ Returns
+ -------
+ Optional[AcquisitionRate]
+ The acquisition rate, or None if not yet set.
+ """
+ return self._acquisition_rate
+
+ def get_duration_available(self) -> float:
+ """Return how many seconds of data are currently in the buffer.
+
+ Returns
+ -------
+ float
+ Duration in seconds of available data.
+ """
+ samples = min(self.write_idx, self.capacity)
+ return samples / self.storage_rate_hz
diff --git a/picostream/test_buffered_stream.py b/picostream/test_buffered_stream.py
new file mode 100644
index 0000000..48821b2
--- /dev/null
+++ b/picostream/test_buffered_stream.py
@@ -0,0 +1,344 @@
+"""Unit tests for PicoscopeBufferedStream with ring buffer and Zarr save."""
+
+import os
+import shutil
+import tempfile
+import time
+from typing import Generator
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+
+class MockPicoSDK:
+ """Mock PicoSDK for testing."""
+
+ PICO_OK = 0
+ PICO_BUSY = 1
+ PICO_NO_SAMPLES_AVAILABLE = 2
+ PICO_BUFFER_STALL = 3
+
+ PS5000A_CHANNEL = {
+ "PS5000A_CHANNEL_A": 0,
+ "PS5000A_CHANNEL_B": 1,
+ "PS5000A_CHANNEL_C": 2,
+ "PS5000A_CHANNEL_D": 3,
+ }
+ PS5000A_RANGE = {
+ "PS5000A_10MV": 0,
+ "PS5000A_100MV": 1,
+ "PS5000A_1V": 2,
+ "PS5000A_20V": 3,
+ }
+ PS5000A_COUPLING = {"PS5000A_DC": 1}
+ PS5000A_RATIO_MODE = {"PS5000A_RATIO_MODE_NONE": 0}
+ PS5000A_TIME_UNITS = {"PS5000A_NS": 0}
+ PS5000A_DEVICE_RESOLUTION = {"PS5000A_DR_12BIT": 0}
+
+ def __init__(self):
+ self.StreamingReadyType = MagicMock()
+ self._streaming_callback = None
+
+ def ps5000aOpenUnit(self, handle, serial, resolution):
+ handle.value = 1
+ return self.PICO_OK
+
+ def ps5000aCloseUnit(self, handle):
+ return self.PICO_OK
+
+ def ps5000aMaximumValue(self, handle, max_adc):
+ max_adc.value = 32767
+ return self.PICO_OK
+
+ def ps5000aSetChannel(self, handle, channel, enabled, coupling, range_val, offset):
+ return self.PICO_OK
+
+ def ps5000aSetSimpleTrigger(
+ self, handle, enabled, source, threshold, direction, delay, auto
+ ):
+ return self.PICO_OK
+
+ def ps5000aGetTimebase2(
+ self, handle, timebase, num_samples, interval, max_samples, segment
+ ):
+ interval.value = 16.0
+ return self.PICO_OK
+
+ def ps5000aSetDataBuffers(
+ self, handle, channel, buffer_max, buffer_min, size, segment, mode
+ ):
+ return self.PICO_OK
+
+ def ps5000aRunStreaming(
+ self,
+ handle,
+ interval,
+ unit,
+ pre,
+ max_post,
+ auto_stop,
+ down_ratio,
+ mode,
+ overview,
+ ):
+ return self.PICO_OK
+
+ def ps5000aStop(self, handle):
+ return self.PICO_OK
+
+ def ps5000aGetStreamingLatestValues(self, handle, callback, param):
+ return self.PICO_NO_SAMPLES_AVAILABLE
+
+ def ps5000aMemorySegments(self, handle, n_segments, max_samples):
+ # ctypes.c_int32/c_int64 passed in, set .value
+ if hasattr(max_samples, "value"):
+ max_samples.value = 100000
+ return self.PICO_OK
+
+ def ps5000aSetNoOfCaptures(self, handle, n_captures):
+ return self.PICO_OK
+
+
+@pytest.fixture
+def mock_picosdk() -> Generator[MockPicoSDK, None, None]:
+ """Provide a mock PicoSDK."""
+ mock = MockPicoSDK()
+ with patch.dict(
+ "sys.modules",
+ {
+ "picosdk": MagicMock(),
+ "picosdk.ps5000a": MagicMock(),
+ "picosdk.functions": MagicMock(),
+ },
+ ):
+ with patch("picostream.device.ps", mock):
+ with patch("picostream.device.adc2mV", MagicMock()):
+ with patch("picostream.device.assert_pico_ok", lambda x: None):
+ with patch(
+ "picostream.device.PICO_STATUS",
+ {
+ "PICO_OK": 0,
+ "PICO_BUSY": 1,
+ "PICO_NO_SAMPLES_AVAILABLE": 2,
+ "PICO_BUFFER_STALL": 3,
+ },
+ ):
+ yield mock
+
+
+@pytest.fixture
+def temp_zarr_path() -> Generator[str, None, None]:
+ """Provide a temporary path for Zarr files."""
+ temp_dir = tempfile.mkdtemp()
+ zarr_path = os.path.join(temp_dir, "test.zarr")
+ yield zarr_path
+ shutil.rmtree(temp_dir, ignore_errors=True)
+
+
+class TestPicoscopeBufferedStream:
+ """Test cases for PicoscopeBufferedStream."""
+
+ def test_configure_streaming_creates_ring_buffer(
+ self, mock_picosdk: MockPicoSDK
+ ) -> None:
+ """Test that configure_streaming creates a ring buffer."""
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ device.configure_streaming(
+ sample_rate=1_000_000.0,
+ channels=[0, 1],
+ voltage_ranges=[1.0, 1.0],
+ buffer_duration_s=1.0,
+ )
+
+ assert device._ring_buffer is not None
+ assert device._ring_buffer.num_channels == 2
+ # Note: acquisition_rate is set during start_streaming, not configure_streaming
+ # storage_rate_hz will fail fast unless acquisition_rate is set first
+
+ def test_configure_streaming_validates_channels(
+ self, mock_picosdk: MockPicoSDK
+ ) -> None:
+ """Test that configure_streaming validates channel parameters.
+
+ Empty channels list is invalid - at least one channel must be enabled
+ for acquisition. Hardware rate becomes 0 with 0 channels.
+ """
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ # Empty channels list should raise DeviceConfigurationError with helpful message
+ with pytest.raises(Exception, match="At least one channel must be enabled"):
+ device.configure_streaming(
+ sample_rate=1_000_000.0,
+ channels=[],
+ voltage_ranges=[],
+ buffer_duration_s=1.0,
+ )
+
+ def test_start_save_lifecycle(
+ self, mock_picosdk: MockPicoSDK, temp_zarr_path: str
+ ) -> None:
+ """Test the start_save / stop_save lifecycle."""
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ device.configure_streaming(
+ sample_rate=100_000.0,
+ channels=[0, 1],
+ voltage_ranges=[1.0, 1.0],
+ buffer_duration_s=2.0,
+ )
+
+ assert device.start_save(lookback_seconds=0.5, output_path=temp_zarr_path)
+ assert device._is_saving
+
+ time.sleep(0.1)
+
+ path = device.stop_save(keep=True)
+ assert path == temp_zarr_path
+
+ # Wait for save to complete (non-blocking API)
+ deadline = time.time() + 5.0
+ while not device.is_save_finished() and time.time() < deadline:
+ time.sleep(0.01)
+
+ assert not device._is_saving
+
+ def test_stop_save_discard_deletes_file(
+ self, mock_picosdk: MockPicoSDK, temp_zarr_path: str
+ ) -> None:
+ """Test that stop_save(keep=False) deletes the Zarr directory."""
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ device.configure_streaming(
+ sample_rate=100_000.0,
+ channels=[0, 1],
+ voltage_ranges=[1.0, 1.0],
+ buffer_duration_s=2.0,
+ )
+
+ device.start_save(lookback_seconds=0.5, output_path=temp_zarr_path)
+ time.sleep(0.1)
+
+ device.stop_save(keep=False)
+
+ # Wait for save to complete (non-blocking API)
+ deadline = time.time() + 5.0
+ while not device.is_save_finished() and time.time() < deadline:
+ time.sleep(0.01)
+
+ assert not os.path.exists(temp_zarr_path)
+
+ def test_get_save_status_idle(self, mock_picosdk: MockPicoSDK) -> None:
+ """Test get_save_status returns idle when not saving."""
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ device.configure_streaming(
+ sample_rate=100_000.0,
+ channels=[0, 1],
+ voltage_ranges=[1.0, 1.0],
+ buffer_duration_s=1.0,
+ )
+
+ status = device.get_save_status()
+ assert status["state"] == "idle"
+
+ def test_ring_buffer_property(self, mock_picosdk: MockPicoSDK) -> None:
+ """Test that ring_buffer property returns the buffer."""
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ assert device.ring_buffer is None
+
+ device.configure_streaming(
+ sample_rate=1_000_000.0,
+ channels=[0, 1],
+ voltage_ranges=[1.0, 1.0],
+ buffer_duration_s=1.0,
+ )
+
+ assert device.ring_buffer is not None
+
+ def test_save_worker_captures_trigger_position_in_main_thread(
+ self, mock_picosdk: MockPicoSDK, temp_zarr_path: str
+ ) -> None:
+ """Test that trigger position is captured in start_save, not in worker.
+
+ This verifies the race condition fix where pre-trigger data could be
+ overwritten if the worker thread was delayed in starting.
+ """
+ from picostream.device import PicoscopeBufferedStream
+
+ device = PicoscopeBufferedStream("test_device", resolution=12)
+ device._chandle = MagicMock()
+ device._is_connected = True
+ device._max_adc_value = MagicMock()
+ device._max_adc_value.value = 32767
+
+ device.configure_streaming(
+ sample_rate=100_000.0,
+ channels=[0, 1],
+ voltage_ranges=[1.0, 1.0],
+ buffer_duration_s=2.0,
+ )
+
+ # Write some data to the ring buffer first
+ import numpy as np
+
+ test_data = np.ones((1000, 2), dtype=np.int16) * 100
+ device._ring_buffer.write(test_data)
+
+ # Capture the expected trigger position before starting save
+ expected_trigger_pos = device._ring_buffer.get_snapshot()
+
+ # Start save - this should capture the trigger position in main thread
+ device.start_save(lookback_seconds=0.5, output_path=temp_zarr_path)
+
+ # Verify the save worker received the pre-captured position
+ # The worker args should include the trigger_pos as second argument
+ assert device._save_thread is not None
+
+ # Stop the save
+ device.stop_save(keep=True)
+ deadline = time.time() + 5.0
+ while not device.is_save_finished() and time.time() < deadline:
+ time.sleep(0.01)
+
+        # The key assertion: trigger_pos was captured BEFORE the thread spawned.
+        # Under the old bug it was captured inside the worker, potentially after
+        # more data had been written (see the standalone sketch below)
+ assert expected_trigger_pos > 0 # Data was written before save
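+
+
+# Standalone sketch (not part of the device API) of the ordering the test
+# above guards; names here are illustrative.
+def test_snapshot_before_spawn_sketch() -> None:
+    """Thread args bind at construction, so a pre-spawn capture is stable.
+
+    Pure-threading illustration: a position captured before the worker
+    starts cannot be moved by writes that happen afterwards.
+    """
+    import threading
+
+    received: list[int] = []
+    position = 1000  # captured in the "main thread", as start_save does
+
+    worker = threading.Thread(target=received.append, args=(position,))
+    position = 2000  # simulates more data arriving after the capture
+    worker.start()
+    worker.join()
+
+    assert received == [1000]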
diff --git a/picostream/test_data_pipeline.py b/picostream/test_data_pipeline.py
new file mode 100644
index 0000000..0550999
--- /dev/null
+++ b/picostream/test_data_pipeline.py
@@ -0,0 +1,496 @@
+"""Unit tests for DataPipeline class."""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.data_pipeline import DataPipeline, PipelineConfig
+
+
+@pytest.fixture
+def base_config():
+ """Provide a base pipeline configuration for testing."""
+ return PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+
+
+@pytest.fixture
+def base_acquisition_rate():
+ """Provide a base acquisition rate for testing."""
+ return AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+
+@pytest.fixture
+def aggregate_acquisition_rate():
+ """Provide an acquisition rate with AGGREGATE mode and downsampling."""
+ return AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.AGGREGATE,
+ )
+
+
+@pytest.fixture
+def base_pipeline(base_config, base_acquisition_rate):
+ """Provide a base pipeline for testing."""
+ return DataPipeline(base_config, base_acquisition_rate)
+
+
+@pytest.fixture
+def aggregate_pipeline(base_config, aggregate_acquisition_rate):
+ """Provide a pipeline with AGGREGATE mode."""
+ return DataPipeline(base_config, aggregate_acquisition_rate)
+
+
+class TestPipelineConfig:
+ """Tests for PipelineConfig dataclass."""
+
+ def test_config_immutable(self, base_config):
+ """Test that config is frozen and immutable."""
+ with pytest.raises(AttributeError):
+ base_config.resolution = 12
+
+
+class TestDataPipelineInit:
+ """Tests for DataPipeline initialisation."""
+
+ def test_init_with_base_config(self, base_pipeline):
+ """Test pipeline initialisation."""
+ assert base_pipeline.acquisition_rate.downsample_mode == DownsampleMode.NONE
+ assert base_pipeline._last_decimation == 1
+ assert len(base_pipeline._remainder) == 0
+
+ def test_actual_rate_defaults_to_requested(
+ self, base_pipeline, base_acquisition_rate
+ ):
+ """Test that actual rate comes from acquisition_rate."""
+ assert base_pipeline.acquisition_rate.per_channel_rate_hz == 62.5e6
+
+ def test_actual_rate_uses_hardware_rate_when_set(self, base_acquisition_rate):
+ """Test that actual rate uses hardware rate when available."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=100e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ assert acquisition_rate.per_channel_rate_hz == 50.0e6
+
+
+class TestRateCalculations:
+ """Tests for rate calculation properties."""
+
+ def test_per_channel_rate_no_downsample(self, base_acquisition_rate):
+ """Test per channel rate without hardware downsampling."""
+ # 125 MS/s total / 2 channels = 62.5 MS/s per channel
+ assert base_acquisition_rate.per_channel_rate_hz == 62.5e6
+
+ def test_per_channel_rate_with_downsample(self):
+ """Test per channel rate with hardware downsampling."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ # 125 MS/s total / 2 channels / 10 = 6.25 MS/s per channel
+ assert acquisition_rate.per_channel_rate_hz == 6.25e6
+
+ def test_storage_rate_normal_mode(self, base_acquisition_rate):
+ """Test storage rate in normal mode."""
+        # NONE mode: per_channel_rate (62.5 MS/s) * 2 channels = full hardware rate
+ assert base_acquisition_rate.storage_rate_hz == 125e6
+
+ def test_storage_rate_aggregate_mode(self, aggregate_acquisition_rate):
+ """Test storage rate in AGGREGATE mode (2x samples)."""
+ # AGGREGATE mode doubles the storage rate (125e6 / 10) * 2 = 25e6
+ assert aggregate_acquisition_rate.storage_rate_hz == 25e6
+
+
+class TestTimeAxis:
+ """Tests for time axis generation."""
+
+ def test_create_time_axis_empty(self, base_pipeline):
+ """Test time axis creation with empty window duration."""
+ time_axis = base_pipeline.create_time_axis_for_window(0, 1)
+ assert len(time_axis) == 0
+
+ def test_create_time_axis_normal_mode_lossless(self, base_pipeline):
+ """Test time axis in normal mode without decimation."""
+ # 5-second window, decimation=1 (lossless mode)
+ # 5s * 62.5 MS/s per channel / 2 = 156250000 time points
+ time_axis = base_pipeline.create_time_axis_for_window(5.0, 1)
+
+ # No decimation means no min/max pairs
+ n_time_points = int(base_pipeline.time_to_samples(5.0) // 1)
+ assert len(time_axis) == n_time_points
+ assert time_axis[0] == 0.0
+ assert abs(time_axis[-1] - 5.0) < 1e-12
+
+ def test_create_time_axis_normal_mode_with_decimation(self, base_pipeline):
+ """Test time axis in normal mode with decimation."""
+ # 5-second window, decimation=100
+ # Software decimation produces min/max pairs
+ time_axis = base_pipeline.create_time_axis_for_window(5.0, 100)
+
+ # With decimation > 1 and target_plot_points set, we get min/max pairs
+ n_time_points = int(base_pipeline.time_to_samples(5.0) // 100)
+ assert len(time_axis) == n_time_points * 2 # Doubled for min/max
+ assert abs(time_axis[-1] - 5.0) < 1e-12
+ # First 2 values should be same (first time point repeated for min/max)
+ assert time_axis[0] == time_axis[1]
+
+ def test_create_time_axis_normal_mode_lossless_no_double(
+ self, base_acquisition_rate
+ ):
+ """Test lossless mode doesn't double time axis even with target_plot_points."""
+ # Create config with target_plot_points=None for true lossless
+ from picostream.data_pipeline import PipelineConfig
+
+ config_lossless = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=None, # Lossless
+ )
+ pipeline_lossless = DataPipeline(config_lossless, base_acquisition_rate)
+
+ time_axis = pipeline_lossless.create_time_axis_for_window(5.0, 1)
+ n_time_points = int(pipeline_lossless.time_to_samples(5.0) // 1)
+
+ # Lossless mode should NOT double the time axis
+ assert len(time_axis) == n_time_points
+        if len(time_axis) > 1:
+            assert time_axis[0] != time_axis[1]
+
+ def test_create_time_axis_aggregate_mode(self, aggregate_pipeline):
+ """Test time axis in AGGREGATE mode with fixed window duration."""
+ # 5-second window in AGGREGATE mode
+ # AGGREGATE mode: time_to_samples() returns samples already accounting for min/max pairs.
+ # With 10x downsample: display_rate_per_channel = 6.25e6 samples/s
+ # 5s window = 31.25e6 time points * 2 for min/max = 62.5M storage samples
+ # Each time point produces 2 time axis values (repeated), so 62.5M total values.
+ time_axis = aggregate_pipeline.create_time_axis_for_window(5.0, 1)
+
+ # time_to_samples returns storage samples (already doubled for AGGREGATE)
+ window_samples = aggregate_pipeline.time_to_samples(5.0)
+ # Each time point is a min/max pair (2 storage samples), doubled for display
+ expected_length = int(window_samples // 2 * 2)
+ assert len(time_axis) == expected_length
+ # Should span exactly 5 seconds
+ assert abs(time_axis[-1] - 5.0) < 1e-12
+ # First 2 values should be same (first time point repeated for min/max)
+ assert time_axis[0] == time_axis[1]
+
+ def test_create_time_axis_with_data_duration(self, base_pipeline):
+ """Test right-aligned time axis when data_duration < window_duration."""
+ # 5-second window, but only 3 seconds of data
+ time_axis = base_pipeline.create_time_axis_for_window(5.0, 1, data_duration=3.0)
+
+ # Should start at 5.0 - 3.0 = 2.0
+ assert time_axis[0] == 2.0
+ assert abs(time_axis[-1] - 5.0) < 1e-12
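+
+    # Sketch: the min/max pair layout assumed by the assertions above, shown
+    # with plain numpy (independent of the pipeline implementation).
+    def test_minmax_time_axis_shape_sketch(self):
+        """Each decimated timestamp is repeated once for its min/max pair."""
+        t = np.linspace(0.0, 5.0, 4)
+        doubled = np.repeat(t, 2)  # [t0, t0, t1, t1, ...]
+        assert doubled[0] == doubled[1]
+        assert doubled[-1] == 5.0
+        assert len(doubled) == 2 * len(t)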
+
+
+class TestDecimation:
+ """Tests for decimation calculations."""
+
+ def test_calculate_decimation_no_need(self, base_pipeline):
+ """Test decimation factor when not needed."""
+ # Window with fewer samples than target
+ assert base_pipeline.calculate_decimation(10000) == 1
+
+ def test_calculate_decimation_needed(self, base_pipeline):
+ """Test decimation factor calculation."""
+ # 400k samples / 20k target = 20x decimation
+ assert base_pipeline.calculate_decimation(400000) == 20
+
+ def test_should_invalidate_cache_same_decimation(self, base_pipeline):
+ """Test cache invalidation detection - same decimation."""
+ base_pipeline._last_decimation = 10
+ assert not base_pipeline.should_invalidate_cache(10)
+
+ def test_should_invalidate_cache_different(self, base_pipeline):
+ """Test cache invalidation detection - different decimation."""
+ base_pipeline._last_decimation = 10
+ assert base_pipeline.should_invalidate_cache(20)
+
+
+class TestChannelProcessing:
+ """Tests for channel data processing."""
+
+ def test_process_empty_data(self, base_pipeline):
+ """Test processing empty data."""
+ data, has_data = base_pipeline.process_channel_data(
+ np.array([], dtype=np.int16), 0, 1
+ )
+ assert len(data) == 0
+ assert not has_data
+
+ def test_process_small_data_no_decimation(self, base_pipeline):
+ """Test processing small amount of data."""
+ raw_data = np.array([1000, 2000, 3000], dtype=np.int16)
+ data, has_data = base_pipeline.process_channel_data(raw_data, 0, 1)
+
+ assert has_data
+ assert len(data) == 3
+ # Check conversion happened (mV range is +/-20V = 40000mV total)
+ # 1000 counts / 32767 * 20000mV = expected mV
+ expected_mv = 1000 * 20000.0 / 32767
+ assert abs(data[0] - expected_mv) < 0.1
+
+ def test_process_with_decimation(self, base_pipeline):
+ """Test processing with decimation."""
+ # 100 samples, decimation factor 10
+ raw_data = np.arange(100, dtype=np.int16)
+ data, has_data = base_pipeline.process_channel_data(raw_data, 0, 10)
+
+ assert has_data
+ # 100 samples / 10 = 10 groups, each produces 2 values (min/max)
+ assert len(data) == 20
+
+ def test_process_with_remainder(self, base_pipeline):
+ """Test that remainder is handled correctly."""
+ # First chunk: 15 samples with decimation 10
+ raw_data1 = np.arange(15, dtype=np.int16)
+ data1, has_data1 = base_pipeline.process_channel_data(raw_data1, 0, 10)
+
+ # Should decimate 10 samples (1 group), keep 5 as remainder
+ # 1 group produces 2 values (min/max pair)
+ assert has_data1
+ assert len(data1) == 2 # 10 samples / 10 = 1 group = 2 values
+
+ # Second chunk: 15 more samples
+ raw_data2 = np.arange(15, 30, dtype=np.int16)
+ data2, has_data2 = base_pipeline.process_channel_data(raw_data2, 0, 10)
+
+ # 5 remainder + 15 new = 20, decimate 10 (2 groups), keep 0 remainder
+ # 2 groups produce 4 values
+ assert has_data2
+ assert len(data2) == 4
+
+ def test_process_with_offset(self, base_config, base_acquisition_rate):
+ """Test that offset is applied correctly."""
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.5, 1: 0.0}, # 0.5V offset on channel 0
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, base_acquisition_rate)
+
+ raw_data = np.array([0], dtype=np.int16)
+ data, has_data = pipeline.process_channel_data(raw_data, 0, 1)
+
+ assert has_data
+ # 0 ADC counts -> 0 mV + 0.5V offset = 500mV
+ assert abs(data[0] - 500.0) < 0.1
+
+ def test_channel_independence(self, base_pipeline):
+ """Test that channels maintain independent remainders."""
+ # Process different amounts for each channel
+ raw_data = np.arange(15, dtype=np.int16)
+ data0, _ = base_pipeline.process_channel_data(raw_data, 0, 10)
+ data1, _ = base_pipeline.process_channel_data(raw_data, 1, 10)
+
+ # Both should have remainders
+ assert len(base_pipeline._remainder[0]) > 0
+ assert len(base_pipeline._remainder[1]) > 0
+
+
+class TestLagMetrics:
+ """Tests for lag metrics calculation."""
+
+ def test_lag_metrics_zero(self, base_pipeline):
+ """Test lag calculation when consumer caught up."""
+ metrics = base_pipeline.get_lag_metrics(1000, 1000)
+ assert metrics.samples == 0
+ assert metrics.milliseconds == 0.0
+
+ def test_lag_metrics_positive(self, base_pipeline):
+ """Test lag calculation with positive lag."""
+ # 125 MS/s storage rate
+        # 625 samples lag at 125 MS/s = 0.005 ms
+ metrics = base_pipeline.get_lag_metrics(100625, 100000)
+ assert metrics.samples == 625
+ assert abs(metrics.milliseconds - 0.005) < 0.001
+
+ def test_lag_formatting(self, base_pipeline):
+ """Test lag formatting."""
+ lag_str = base_pipeline.format_lag(100625, 100000)
+ assert "ms" in lag_str
+ # Now uses unit prefixes (S, kS, MS) instead of "samples"
+ assert " S" in lag_str or " MS" in lag_str or " kS" in lag_str
+
+ def test_lag_formatting_unit_prefixes(self, base_pipeline):
+ """Test lag formatting with different unit prefixes."""
+ # Small lag - samples shown as-is
+ small_lag = base_pipeline.format_lag(1000100, 1000000)
+ assert "100 S" in small_lag
+
+ # Medium lag - kilosamples
+ medium_lag = base_pipeline.format_lag(1100000, 1000000)
+ assert "kS" in medium_lag
+
+ # Large lag - megasamples (typical case)
+ large_lag = base_pipeline.format_lag(5000000, 1000000)
+ assert "MS" in large_lag
+ assert "4 MS" in large_lag or "4.0 MS" in large_lag
+
+
+class TestConversions:
+ """Tests for sample/time conversions."""
+
+ def test_samples_to_time_normal(self, base_acquisition_rate):
+ """Test samples to time conversion in normal mode."""
+ # 62500000 samples at 125 MS/s total / 2 = 62.5 MS/s per channel = 1 second
+ duration = base_acquisition_rate.samples_to_seconds(62500000)
+ assert abs(duration - 1.0) < 0.001
+
+ def test_samples_to_time_aggregate(self, aggregate_acquisition_rate):
+ """Test samples to time conversion in AGGREGATE mode."""
+ # 125000000 samples (62500000 min/max pairs) at 6.25 MS/s = 10 seconds
+ duration = aggregate_acquisition_rate.samples_to_seconds(125000000)
+ assert abs(duration - 10.0) < 0.001
+
+ def test_time_to_samples(self, base_acquisition_rate):
+ """Test time to samples conversion."""
+ # 1 second at 62.5 MS/s = 62500000 samples
+ samples = base_acquisition_rate.seconds_to_samples(1.0)
+ assert samples == 62500000
+
+ def test_time_to_samples_aggregate(self, aggregate_acquisition_rate):
+ """Test time to samples returns buffer samples in AGGREGATE mode."""
+ # AGGREGATE mode stores min/max pairs, so returns 2 * time points
+ samples = aggregate_acquisition_rate.seconds_to_samples(1.0)
+ assert samples == 12500000 # 6.25e6 time points per second * 2 for min/max
+
+ def test_get_display_duration_no_decimation(self, base_pipeline):
+ """Test display duration with no decimation."""
+ # 1000 display samples (500 pairs) at decimation 1 = 500 original samples
+ # 500 / 62.5e6 = 8 microseconds
+ duration = base_pipeline.get_display_duration(1000, 1)
+ expected = 500 / 62.5e6
+ assert abs(duration - expected) < 1e-12
+
+ def test_get_display_duration_with_decimation(self, base_pipeline):
+ """Test display duration with decimation."""
+ # 2000 display samples (1000 pairs) at decimation 10 = 10000 original samples
+ # 10000 / 62.5e6 = 160 microseconds
+ duration = base_pipeline.get_display_duration(2000, 10)
+ expected = 10000 / 62.5e6
+ assert abs(duration - expected) < 1e-12
+
+ def test_get_display_duration_aggregate_mode(self, aggregate_pipeline):
+ """Test display duration in AGGREGATE mode."""
+ # Should use per_channel_rate (6.25 MS/s with downsample=10), not storage_rate
+ # 1000 display samples (500 pairs) at decimation 1 = 500 original samples
+ # 500 / 6.25e6 = 80 microseconds
+ duration = aggregate_pipeline.get_display_duration(1000, 1)
+ expected = 500 / 6.25e6
+ assert abs(duration - expected) < 1e-12
+
+ def test_get_display_duration_aggregate_mode_uses_display_rate(
+ self, aggregate_pipeline
+ ):
+ """Test that AGGREGATE mode uses display_rate_per_channel."""
+ duration = aggregate_pipeline.get_display_duration(1000, 1)
+ # 1000 samples = 500 time points * 1 decimation = 500 original samples
+ # 500 / 6.25e6 = 80 microseconds
+ expected = 500 / 6.25e6
+ assert abs(duration - expected) < 1e-12
+
+ def test_get_display_duration_zero_samples(self, base_pipeline):
+ """Test display duration with zero samples."""
+ assert base_pipeline.get_display_duration(0, 1) == 0.0
+
+ def test_get_display_duration_zero_decimation(self, base_pipeline):
+ """Test display duration with zero decimation."""
+ assert base_pipeline.get_display_duration(100, 0) == 0.0
+
+
+class TestCacheManagement:
+ """Tests for cache invalidation."""
+
+ def test_invalidate_cache_clears_remainders(self, base_pipeline):
+ """Test that cache invalidation clears remainders."""
+ # Create some remainders
+ raw_data = np.arange(15, dtype=np.int16)
+ base_pipeline.process_channel_data(raw_data, 0, 10)
+ assert len(base_pipeline._remainder) > 0
+
+ # Invalidate cache
+ base_pipeline.invalidate_cache()
+ assert len(base_pipeline._remainder) == 0
+ assert base_pipeline._last_decimation == 1
+
+ def test_max_cache_points(self, base_pipeline):
+ """Test max cache points calculation."""
+ # 100000 samples, decimation 10
+ max_points = base_pipeline.get_max_cache_points(100000, 10)
+ # (100000 // 10) * 2 = 20000 (display values, i.e., min/max pairs)
+ assert max_points == 20000
+
+ def test_max_time_points(self, base_pipeline):
+ """Test max time points calculation."""
+ # 100000 samples, decimation 10
+ max_points = base_pipeline.get_max_time_points(100000, 10)
+ # 100000 // 10 = 10000 (time points, not min/max pairs)
+ assert max_points == 10000
+
+
+class TestModeChecks:
+ """Tests for mode checking methods."""
+
+ def test_is_aggregate_mode_true(self, aggregate_pipeline):
+ """Test AGGREGATE mode detection."""
+ assert aggregate_pipeline.is_aggregate_mode()
+
+ def test_is_aggregate_mode_false(self, base_pipeline):
+ """Test non-AGGREGATE mode detection."""
+ assert not base_pipeline.is_aggregate_mode()
+
+
+class TestEdgeCases:
+ """Tests for edge cases and error handling."""
+
+ def test_negative_lag_clamped(self, base_pipeline):
+ """Test that negative lag is clamped to zero."""
+ # Consumer ahead of producer (shouldn't happen, but handle gracefully)
+ metrics = base_pipeline.get_lag_metrics(100, 200)
+ assert metrics.samples == 0
+
+ def test_process_with_missing_voltage_range(
+ self, base_config, base_acquisition_rate
+ ):
+ """Test processing when voltage range not configured."""
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={}, # Empty ranges
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, base_acquisition_rate)
+
+ raw_data = np.array([1000, 2000], dtype=np.int16)
+ data, has_data = pipeline.process_channel_data(raw_data, 0, 1)
+
+ # Should still produce data, just as floats
+ assert has_data
+ assert len(data) == 2
+ assert data.dtype == np.float64
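+
+
+# Reference arithmetic (sketch) for the ADC -> mV conversion assumed throughout
+# TestChannelProcessing: linear scaling of counts / max_adc * full-scale mV.
+def test_adc_to_mv_reference_sketch():
+    """Mirrors the expected values used above, independent of DataPipeline."""
+    max_adc = 32767
+    full_scale_mv = 20.0 * 1000.0  # +/-20 V range expressed in mV
+    counts = 1000
+    expected_mv = counts * full_scale_mv / max_adc
+    assert abs(expected_mv - 610.35) < 0.01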
diff --git a/picostream/test_live_plotter.py b/picostream/test_live_plotter.py
new file mode 100644
index 0000000..a2f83d3
--- /dev/null
+++ b/picostream/test_live_plotter.py
@@ -0,0 +1,437 @@
+"""Tests for the LivePlotter class."""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.ring_buffer import RingBuffer
+
+pytest.importorskip("PyQt6.QtWidgets", reason="PyQt6 not available")
+
+from picostream.dfplot import LivePlotter
+
+
+@pytest.fixture(scope="session")
+def qapp():
+ """Provide a Qt application instance for the test session."""
+ from PyQt6.QtWidgets import QApplication
+
+ app = QApplication.instance()
+ if app is None:
+ app = QApplication([])
+ return app
+
+
+@pytest.fixture
+def sample_metadata():
+ """Provide sample metadata for plotter tests."""
+ return {
+ "sample_rate": 62.5e6,
+ "channels": [0, 1],
+ "voltage_ranges": {0: 20.0, 1: 20.0},
+ "offsets_v": {0: 0.0, 1: 0.0},
+ "resolution": 16,
+ "downsample_mode": "NONE",
+ }
+
+
+@pytest.fixture
+def populated_buffer():
+ """Create a ring buffer with test data."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+
+ data = np.zeros((500, 2), dtype=np.int16)
+ data[:, 0] = np.sin(np.linspace(0, 4 * np.pi, 500)) * 10000
+ data[:, 1] = np.cos(np.linspace(0, 4 * np.pi, 500)) * 8000
+
+ buffer.write(data)
+ return buffer
+
+
+@pytest.fixture
+def acquisition_rate_none():
+ """Create an AcquisitionRate for NONE mode."""
+ return AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+
+@pytest.fixture
+def acquisition_rate_aggregate():
+ """Create an AcquisitionRate for AGGREGATE mode."""
+ return AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.AGGREGATE,
+ )
+
+
+class TestLivePlotterBasic:
+ """Basic functionality tests."""
+
+ def test_init_with_buffer(self, qapp, sample_metadata, populated_buffer):
+ """Test plotter initialisation with a ring buffer."""
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ display_window_seconds=0.5,
+ )
+
+ assert plotter.ring_buffer is populated_buffer
+ # sample_rate is only available after acquisition_rate is set
+ assert plotter.channels == [0, 1]
+ assert plotter.resolution == 16
+
+ plotter.close()
+
+ def test_init_without_buffer(self, qapp, sample_metadata):
+ """Test plotter initialisation with None buffer (waiting state)."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ assert plotter.ring_buffer is None
+
+ plotter.close()
+
+ def test_set_ring_buffer(self, qapp, sample_metadata, populated_buffer):
+ """Test updating the ring buffer reference."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ assert plotter.ring_buffer is None
+
+ plotter.set_ring_buffer(populated_buffer)
+ assert plotter.ring_buffer is populated_buffer
+ assert plotter.last_position == 0
+
+ plotter.close()
+
+ def test_set_ring_buffer_to_none(self, qapp, sample_metadata, populated_buffer):
+ """Test clearing the ring buffer reference."""
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ plotter.set_ring_buffer(None)
+ assert plotter.ring_buffer is None
+
+ plotter.close()
+
+
+class TestLivePlotterDisplay:
+ """Display functionality tests."""
+
+ def test_pipeline_time_axis(
+ self, qapp, sample_metadata, populated_buffer, acquisition_rate_none
+ ):
+ """Test time axis creation via DataPipeline for requested sample count."""
+ # Set acquisition rate on buffer so plotter can initialize pipeline
+ populated_buffer.set_acquisition_rate(acquisition_rate_none)
+
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode="NONE",
+ display_window_seconds=1.0,
+ decimation_factor=1,
+ )
+
+ # Notify plotter that buffer has acquisition rate
+ plotter._update_pipeline_sample_rate()
+
+        # Request time axis for the 1.0s window with decimation=100
+ assert plotter.pipeline is not None
+ time_axis = plotter.pipeline.create_time_axis_for_window(1.0, 100)
+
+ assert len(time_axis) > 0
+ assert time_axis[-1] > time_axis[0] # Monotonic
+ # Time axis is relative from 0 to window_duration
+ assert abs(time_axis[-1] - 1.0) < 0.01 # Window is 1.0s
+
+ plotter.close()
+
+ def test_pipeline_time_axis_normal_mode(
+ self, qapp, sample_metadata, populated_buffer, acquisition_rate_none
+ ):
+ """Test time axis via DataPipeline in normal (non-aggregate) mode."""
+ populated_buffer.set_acquisition_rate(acquisition_rate_none)
+
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode="NONE",
+ display_window_seconds=5.0,
+ decimation_factor=100,
+ )
+
+ plotter._update_pipeline_sample_rate()
+
+ assert plotter.pipeline is not None
+ time_axis = plotter.pipeline.create_time_axis_for_window(5.0, 100)
+
+ assert len(time_axis) > 0
+ # Normal mode: linear from 0 to window_duration
+ assert abs(time_axis[-1] - 5.0) < 0.01
+
+ plotter.close()
+
+ def test_pipeline_time_axis_aggregate_mode(
+ self, qapp, sample_metadata, populated_buffer, acquisition_rate_aggregate
+ ):
+ """Test time axis via DataPipeline in AGGREGATE mode.
+
+ In AGGREGATE mode, n_display_points is doubled (min/max pairs).
+ """
+ populated_buffer.set_acquisition_rate(acquisition_rate_aggregate)
+
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode="AGGREGATE", # Use AGGREGATE mode
+ display_window_seconds=5.0,
+ decimation_factor=50,
+ )
+
+ plotter._update_pipeline_sample_rate()
+
+ # In AGGREGATE mode:
+ # - Each time point is repeated twice for min/max pairs
+ assert plotter.pipeline is not None
+ time_axis = plotter.pipeline.create_time_axis_for_window(5.0, 50)
+
+ # AGGREGATE mode: should have time axis with min/max pairs
+ assert len(time_axis) > 0
+        # Length should be a multiple of 2 due to min/max pairs
+ assert len(time_axis) % 2 == 0
+
+ plotter.close()
+
+ def test_format_sample_count(self, qapp, sample_metadata):
+ """Test sample count formatting."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ assert plotter.format_sample_count(500) == "500"
+ assert plotter.format_sample_count(1500) == "1.50K"
+ assert plotter.format_sample_count(2500000) == "2.50M"
+ assert plotter.format_sample_count(1500000000) == "1.50G"
+
+ plotter.close()
+
+ def test_format_rate_sps(self, qapp, sample_metadata):
+ """Test sample rate formatting."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ assert plotter._format_rate_sps(500) == "500.00 S/s"
+ assert plotter._format_rate_sps(1500) == "1.50 kS/s"
+ assert plotter._format_rate_sps(2500000) == "2.50 MS/s"
+ assert plotter._format_rate_sps(1500000000) == "1.50 GS/s"
+
+ plotter.close()
+
+
+class TestLivePlotterChannelVisibility:
+ """Channel visibility tests."""
+
+ def test_set_channel_visible(self, qapp, sample_metadata, populated_buffer):
+ """Test toggling channel visibility."""
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ assert plotter.lines[0].visible
+
+ plotter.set_channel_visible(0, False)
+ assert not plotter.lines[0].visible
+
+ plotter.set_channel_visible(0, True)
+ assert plotter.lines[0].visible
+
+ plotter.close()
+
+
+class TestLivePlotterYLimits:
+ """Y-axis limit tests for division-based scaling."""
+
+ def test_y_axis_division_range(self, qapp, sample_metadata, populated_buffer):
+ """Test that Y-axis is set to division range (-5 to +5)."""
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ # Verify Y-axis range is set to divisions (-5 to +5) via camera
+ rect = plotter.view.camera.get_state()["rect"]
+ y_bottom, y_top = rect.bottom, rect.top
+ # Camera may add small padding, check approximate range
+ assert abs(y_bottom - (-5)) < 1
+ assert abs(y_top - 5) < 1
+
+ plotter.close()
+
+ def test_v_div_settings(self, qapp, sample_metadata):
+ """Test setting V/div and Y-position settings."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ # Set V/div settings for both channels
+ v_div_settings = {0: 0.1, 1: 0.1} # 100mV/div
+ y_pos_settings = {0: 0.0, 1: 0.0}
+
+ plotter.set_v_div_settings(v_div_settings, y_pos_settings)
+
+ assert plotter.v_div_volts[0] == 0.1
+ assert plotter.v_div_volts[1] == 0.1
+ assert plotter.y_pos_divs[0] == 0.0
+ assert plotter.y_pos_divs[1] == 0.0
+
+ plotter.close()
+
+ def test_channel_clip_detection(self, qapp, sample_metadata):
+ """Test channel clipping detection."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=[0, 1],
+ voltage_ranges={0: 5.0, 1: 10.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ resolution=12,
+ downsample_mode="NONE",
+ )
+
+ # Set V/div = 1V, Y-pos = 0 (range is -5 to +5 divisions)
+ plotter.v_div_volts[0] = 1.0
+ plotter.y_pos_divs[0] = 0.0
+
+ # No clipping within range (±2 divisions = ±2V)
+ plotter._display_cache[0] = np.array([-2.0, 2.0]) # ±2 divisions
+ assert plotter.is_channel_clipped(0) == 0
+
+ # Clipping above range (6 divisions exceeds +5)
+ plotter._display_cache[0] = np.array([0.0, 6.0])
+ assert plotter.is_channel_clipped(0) == 1
+
+ # Clipping below range (-6 divisions exceeds -5)
+ plotter._display_cache[0] = np.array([-6.0, 0.0])
+ assert plotter.is_channel_clipped(0) == -1
+
+ plotter.close()
+
+
+class TestLivePlotterPipeline:
+ """DataPipeline integration tests."""
+
+ def test_pipeline_initialised(
+ self, qapp, sample_metadata, populated_buffer, acquisition_rate_none
+ ):
+ """Test that DataPipeline is initialised when acquisition rate is set."""
+ populated_buffer.set_acquisition_rate(acquisition_rate_none)
+
+ plotter = LivePlotter(
+ ring_buffer=populated_buffer,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ plotter._update_pipeline_sample_rate()
+
+ assert plotter.pipeline is not None
+ assert plotter.acquisition_rate.num_channels == 2
+ assert plotter.acquisition_rate.downsample_mode == DownsampleMode.NONE
+
+ plotter.close()
+
+ def test_set_acquisition_rate(self, qapp, sample_metadata):
+ """Test setting acquisition rate on plotter."""
+ plotter = LivePlotter(
+ ring_buffer=None,
+ channels=sample_metadata["channels"],
+ voltage_ranges=sample_metadata["voltage_ranges"],
+ offsets_v=sample_metadata["offsets_v"],
+ resolution=sample_metadata["resolution"],
+ downsample_mode=sample_metadata["downsample_mode"],
+ )
+
+ rate = AcquisitionRate(
+ hardware_rate_hz=100e6,
+ num_channels=1,
+ downsample_ratio=2,
+ downsample_mode=DownsampleMode.AGGREGATE,
+ )
+
+ plotter.set_acquisition_rate(rate)
+
+ assert plotter.pipeline is not None
+ assert plotter.acquisition_rate.num_channels == 1
+ assert plotter.resolution == 16 # Remains from initialisation
+ assert plotter.acquisition_rate.downsample_mode == DownsampleMode.AGGREGATE
+ assert plotter.acquisition_rate.downsample_ratio == 2
+ assert plotter.voltage_ranges[0] == 20.0
+
+ plotter.close()
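+
+
+# Sketch of the SI-prefix thresholds the formatting tests above rely on,
+# written as plain arithmetic so it runs without a LivePlotter instance.
+def test_si_prefix_reference_sketch():
+    """Values divide by 1e3/1e6/1e9 at the K/M/G boundaries."""
+
+    def fmt(n):
+        for divisor, suffix in ((1e9, "G"), (1e6, "M"), (1e3, "K")):
+            if n >= divisor:
+                return f"{n / divisor:.2f}{suffix}"
+        return f"{n:.0f}"
+
+    assert fmt(500) == "500"
+    assert fmt(1500) == "1.50K"
+    assert fmt(2500000) == "2.50M"
+    assert fmt(1500000000) == "1.50G"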
diff --git a/picostream/test_max_adc_fix.py b/picostream/test_max_adc_fix.py
new file mode 100644
index 0000000..731d2c1
--- /dev/null
+++ b/picostream/test_max_adc_fix.py
@@ -0,0 +1,73 @@
+"""Test that max_adc is calculated correctly from resolution."""
+
+import sys
+from pathlib import Path
+
+# Put the repository root on sys.path so `picostream` imports when run directly
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+def test_max_adc_calculation():
+ """Test max_adc calculation for various resolutions."""
+ print("Testing max_adc calculation from resolution:")
+ print("-" * 50)
+
+ for res in [8, 12, 14, 15, 16]:
+ max_adc = (1 << (res - 1)) - 1
+ print(f" {res}-bit: max_adc = {max_adc}")
+
+ # Verify against what we expect
+ assert (1 << (12 - 1)) - 1 == 2047, "12-bit max_adc should be 2047"
+ assert (1 << (14 - 1)) - 1 == 8191, "14-bit max_adc should be 8191"
+ assert (1 << (16 - 1)) - 1 == 32767, "16-bit max_adc should be 32767"
+
+ print("\n✓ All assertions passed")
+
+
+def test_voltage_conversion():
+ """Test that voltage conversion works correctly with calculated max_adc."""
+ import numpy as np
+
+ from picostream.conversion_utils import adc_to_mV
+
+ print("\nTesting voltage conversion with corrected max_adc:")
+ print("-" * 50)
+
+ # Simulate 1V signal on 5V range, 14-bit
+ resolution = 14
+ max_adc = (1 << (resolution - 1)) - 1 # 8191
+ voltage_range_v = 5.0
+
+ # ADC counts for 1V (centered at 0, so ±0.5V)
+ true_voltage_v = 0.5 # Peak of sine wave
+ adc_counts = int((true_voltage_v / voltage_range_v) * max_adc)
+
+ print(f" Resolution: {resolution}-bit")
+ print(f" Using max_adc: {max_adc}")
+ print(f" Voltage range: ±{voltage_range_v}V")
+ print(f" True voltage: {true_voltage_v}V")
+ print(f" ADC counts: {adc_counts}")
+
+ # Convert back
+ voltage_mv = adc_to_mV(
+ np.array([adc_counts], dtype=np.int16), voltage_range_v, max_adc
+ )
+
+ print(f" Converted voltage: {voltage_mv[0]:.2f} mV = {voltage_mv[0] / 1000:.4f} V")
+ print(f" Error: {abs(voltage_mv[0] / 1000 - true_voltage_v) * 100:.2f}%")
+
+ # Compare with WRONG max_adc (32767)
+ wrong_max_adc = 32767
+ voltage_mv_wrong = adc_to_mV(
+ np.array([adc_counts], dtype=np.int16), voltage_range_v, wrong_max_adc
+ )
+
+ print(f"\n With WRONG max_adc={wrong_max_adc}:")
+ print(
+ f" Converted voltage: {voltage_mv_wrong[0]:.2f} mV = {voltage_mv_wrong[0] / 1000:.4f} V"
+ )
+ print(f" Error factor: {voltage_mv[0] / voltage_mv_wrong[0]:.2f}x")
+
+
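+# Reference check (sketch) tying the bit arithmetic above to numpy's own
+# integer bounds; independent of picostream.
+def test_signed_range_reference_sketch():
+    """(1 << (res - 1)) - 1 matches numpy's signed int16 maximum at 16 bits."""
+    import numpy as np
+
+    assert (1 << 15) - 1 == np.iinfo(np.int16).max == 32767
+
+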
+if __name__ == "__main__":
+ test_max_adc_calculation()
+ test_voltage_conversion()
diff --git a/picostream/test_rate_contract.py b/picostream/test_rate_contract.py
new file mode 100644
index 0000000..2e65f39
--- /dev/null
+++ b/picostream/test_rate_contract.py
@@ -0,0 +1,263 @@
+"""Integration tests for rate handling across components.
+
+These tests verify that rate information flows correctly from the device
+through the ring buffer to the data pipeline, ensuring time calculations
+are accurate.
+"""
+
+import numpy as np
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.data_pipeline import DataPipeline, PipelineConfig
+from picostream.ring_buffer import RingBuffer
+
+
+class TestRateContractDeviceToPipeline:
+ """Tests for rate contract between device, ring buffer, and pipeline.
+
+ These tests catch issues like double-counting channels or misapplying
+ downsampling ratios across component boundaries.
+ """
+
+ def _simulate_device_setting_storage_rate(
+ self,
+ ring_buffer: RingBuffer,
+ hardware_rate_hz: float,
+ downsample_ratio: int,
+ downsample_mode: str,
+ n_channels: int,
+ ) -> AcquisitionRate:
+ """Simulate how the device sets AcquisitionRate in start_streaming().
+
+ This replicates the logic in PicoscopeBufferedStream.start_streaming()
+ to ensure tests reflect actual device behavior.
+
+ Returns
+ -------
+ AcquisitionRate
+ The acquisition rate object that gets set on the ring buffer.
+ """
+ rate = AcquisitionRate(
+ hardware_rate_hz=hardware_rate_hz,
+ num_channels=n_channels,
+ downsample_ratio=downsample_ratio,
+ downsample_mode=DownsampleMode(downsample_mode),
+ )
+ ring_buffer.set_acquisition_rate(rate)
+ return rate
+
+ def test_single_channel_no_downsample_time_calculation(self):
+ """1 channel, no downsampling: 1 second of data = 1 second displayed."""
+ # Setup ring buffer as device would
+ rb = RingBuffer(duration_s=10.0, sample_rate=62.5e6, num_channels=1)
+
+ # Device configures the actual storage rate
+ rate = self._simulate_device_setting_storage_rate(
+ rb,
+ hardware_rate_hz=62.5e6,
+ downsample_ratio=1,
+ downsample_mode="NONE",
+ n_channels=1,
+ )
+
+ # Pipeline config and acquisition rate
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0},
+ offsets_v={0: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Write 1 second of data to ring buffer
+ one_second_samples = int(rate.storage_rate_hz) # 62.5M samples
+ data = np.random.randint(
+ -1000, 1000, size=(one_second_samples, 1), dtype=np.int16
+ )
+ rb.write(data)
+
+ # Read it back and verify time calculation
+ read_data = rb.read_last(one_second_samples)
+ duration = pipeline.samples_to_time(len(read_data))
+
+ # Should be exactly 1 second (within floating point tolerance)
+ assert abs(duration - 1.0) < 0.001, f"Expected ~1.0s, got {duration}s"
+
+ def test_two_channel_no_downsample_time_calculation(self):
+ """2 channels, no downsampling: 1 second of data = 1 second displayed."""
+ rb = RingBuffer(duration_s=10.0, sample_rate=125e6, num_channels=2)
+
+ # Device: 2 channels at 62.5 MS/s each = 125 MS/s total
+ rate = self._simulate_device_setting_storage_rate(
+ rb,
+ hardware_rate_hz=125e6, # 62.5e6 * 2
+ downsample_ratio=1,
+ downsample_mode="NONE",
+ n_channels=2,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Write 1 second of data
+ one_second_samples = int(rate.storage_rate_hz) # 125M samples total
+ data = np.random.randint(
+ -1000, 1000, size=(one_second_samples, 2), dtype=np.int16
+ )
+ rb.write(data)
+
+ # Verify: per_channel_rate should be 62.5 MS/s
+ assert pipeline.acquisition_rate.per_channel_rate_hz == 62.5e6
+
+ # Time for 62.5M samples on one channel = 1 second
+ duration = pipeline.samples_to_time(62500000)
+ assert abs(duration - 1.0) < 0.001, f"Expected ~1.0s, got {duration}s"
+
+ def test_ten_x_downsample_time_calculation(self):
+ """10x downsampling: 1 second of real time = 1 second displayed."""
+ rb = RingBuffer(duration_s=10.0, sample_rate=12.5e6, num_channels=2)
+
+ # Device: 125 MS/s hardware / 10x downsample = 12.5 MS/s storage
+ rate = self._simulate_device_setting_storage_rate(
+ rb,
+ hardware_rate_hz=125e6,
+ downsample_ratio=10,
+ downsample_mode="DECIMATE",
+ n_channels=2,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Write 1 second of downsampled data
+ one_second_samples = int(rate.storage_rate_hz) # 12.5M samples
+ data = np.random.randint(
+ -1000, 1000, size=(one_second_samples, 2), dtype=np.int16
+ )
+ rb.write(data)
+
+ # Per-channel rate should be 6.25 MS/s
+ assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6
+
+ # Time for 6.25M samples (1 channel) = 1 second
+ duration = pipeline.samples_to_time(6250000)
+ assert abs(duration - 1.0) < 0.001, f"Expected ~1.0s, got {duration}s"
+
+ def test_aggregate_mode_time_calculation(self):
+ """AGGREGATE mode: 1 second of real time = 1 second displayed."""
+ rb = RingBuffer(duration_s=10.0, sample_rate=25e6, num_channels=2)
+
+ # Device: 125 MS/s hardware / 10x * 2 (min/max) = 25 MS/s storage
+ rate = self._simulate_device_setting_storage_rate(
+ rb,
+ hardware_rate_hz=125e6,
+ downsample_ratio=10,
+ downsample_mode="AGGREGATE",
+ n_channels=2,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Write 1 second of AGGREGATE data
+ one_second_samples = int(
+ rate.storage_rate_hz
+ ) # 25M samples (12.5M min/max pairs)
+ data = np.random.randint(
+ -1000, 1000, size=(one_second_samples, 2), dtype=np.int16
+ )
+ rb.write(data)
+
+ # Storage rate is 25 MS/s (12.5M pairs/sec * 2 values/pair)
+ # Per-channel rate is 6.25 MS/s (125M / 2 channels / 10 downsample)
+ assert rate.storage_rate_hz == 25e6
+ assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6
+
+ # In AGGREGATE mode:
+ # - 25e6 samples = 12.5e6 min/max pairs
+ # - samples_to_time: n_time_points = 25e6 // 2 = 12.5e6 pairs
+ # - duration = 12.5e6 pairs / 6.25e6 per_channel_rate = 2 seconds
+ duration = pipeline.samples_to_time(25000000)
+ assert abs(duration - 2.0) < 0.001, f"Expected ~2.0s, got {duration}s"
+
+
+class TestRateContractInvariants:
+ """Invariants that must hold regardless of configuration."""
+
+ def test_storage_rate_is_total_not_per_channel(self):
+ """storage_rate_hz must always represent total rate, not per-channel."""
+ # Single channel
+ rate1 = AcquisitionRate(
+ hardware_rate_hz=10e6,
+ num_channels=1,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ assert rate1.storage_rate_hz == 10e6
+
+ # Two channels - rate doubles (more samples per unit time)
+ rate2 = AcquisitionRate(
+ hardware_rate_hz=20e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ assert rate2.storage_rate_hz == 20e6 # 10 MS/s per channel * 2 channels
+
+ def test_pipeline_divides_by_channel_count(self):
+ """Pipeline must divide total rate by channel count for per-channel rate."""
+ # Test 2-channel setup
+ rate2 = AcquisitionRate(
+ hardware_rate_hz=100e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate2)
+
+ # Per-channel rate should be 50 MS/s (100 / 2 channels)
+ assert pipeline.acquisition_rate.per_channel_rate_hz == 50e6
+
+ # Test 1-channel setup
+ rate1 = AcquisitionRate(
+ hardware_rate_hz=50e6,
+ num_channels=1,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ config1 = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0},
+ offsets_v={0: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline1 = DataPipeline(config1, rate1)
+
+ assert pipeline1.acquisition_rate.per_channel_rate_hz == 50e6
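+
+
+# Reference formulas (sketch) for the rate contract; assumes AcquisitionRate
+# computes its properties equivalently to this plain arithmetic.
+def test_rate_formula_reference_sketch():
+    """Per-channel and storage rates derived from hardware parameters."""
+
+    def expected(hardware_hz, channels, ratio, aggregate):
+        per_channel = hardware_hz / channels / ratio
+        pair_factor = 2 if aggregate and ratio > 1 else 1
+        return per_channel, hardware_hz / ratio * pair_factor
+
+    assert expected(125e6, 2, 1, False) == (62.5e6, 125e6)
+    assert expected(125e6, 2, 10, False) == (6.25e6, 12.5e6)
+    assert expected(125e6, 2, 10, True) == (6.25e6, 25e6)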
diff --git a/picostream/test_rate_invariants.py b/picostream/test_rate_invariants.py
new file mode 100644
index 0000000..f9bbbd3
--- /dev/null
+++ b/picostream/test_rate_invariants.py
@@ -0,0 +1,329 @@
+"""Rate invariant tests - verify rate calculations remain consistent across system.
+
+These tests catch bugs where rates get multiplied/divided incorrectly at
+system boundaries (device → ring buffer → pipeline → display).
+"""
+
+from __future__ import annotations
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.data_pipeline import DataPipeline, PipelineConfig
+
+
+class TestDeviceToRingBufferRateInvariant:
+ """Verify storage_rate_hz is calculated correctly from hardware parameters."""
+
+ def test_no_downsampling_no_aggregate(self) -> None:
+ """Storage rate = hardware rate when no downsampling/AGGREGATE."""
+ hardware_rate_hz = 125e6 # 125 MS/s total (62.5 per channel, 2 channels)
+
+ rate = AcquisitionRate(
+ hardware_rate_hz=hardware_rate_hz,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ expected_per_channel = hardware_rate_hz / 2 # 2 channels
+ assert rate.storage_rate_hz == 125e6
+ assert rate.per_channel_rate_hz == expected_per_channel == 62.5e6
+
+ def test_with_downsampling_no_aggregate(self) -> None:
+ """Per-channel rate = hardware rate / (channels * downsample_ratio)."""
+ hardware_rate_hz = 125e6
+ downsample_ratio = 10
+
+ rate = AcquisitionRate(
+ hardware_rate_hz=hardware_rate_hz,
+ num_channels=2,
+ downsample_ratio=downsample_ratio,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ assert rate.storage_rate_hz == 12.5e6
+ assert rate.per_channel_rate_hz == 6.25e6
+
+    def test_with_aggregate_unity_ratio_does_not_double(self) -> None:
+ """AGGREGATE mode at 1x downsampling does not double storage rate.
+
+ When downsample_ratio=1, AGGREGATE mode has no effect - data passes
+ through unchanged. Doubling only occurs when actually downsampling.
+ """
+ hardware_rate_hz = 125e6
+
+ rate = AcquisitionRate(
+ hardware_rate_hz=hardware_rate_hz,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.AGGREGATE,
+ )
+
+ # At 1x downsampling, AGGREGATE does not apply - rate is unchanged
+ assert rate.storage_rate_hz == 125e6
+
+ def test_with_downsample_and_aggregate(self) -> None:
+ """Downsampling and AGGREGATE combined."""
+ hardware_rate_hz = 125e6
+ downsample_ratio = 10
+
+ rate = AcquisitionRate(
+ hardware_rate_hz=hardware_rate_hz,
+ num_channels=2,
+ downsample_ratio=downsample_ratio,
+ downsample_mode=DownsampleMode.AGGREGATE,
+ )
+
+ assert rate.storage_rate_hz == 25e6
+ assert rate.per_channel_rate_hz == 6.25e6
+
+
+class TestRingBufferToPlotterRatePropagation:
+ """Verify plotter passes storage_rate_hz to pipeline WITHOUT multiplying by channels.
+
+ This is the specific bug that was fixed: plotter was doing
+ pipeline.set_actual_hardware_rate(storage_rate_hz * len(channels))
+ when it should just pass storage_rate_hz directly.
+ """
+
+    def test_pipeline_receives_total_rate(self) -> None:
+ """Pipeline should receive total rate, not rate * channel_count."""
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Pipeline internally divides by channel count to get per-channel rate
+ per_channel_rate = pipeline.acquisition_rate.per_channel_rate_hz
+
+ # Should be 6.25 MS/s per channel
+ assert per_channel_rate == 6.25e6
+
+ def test_storage_rate_from_ring_buffer_used_directly(self) -> None:
+ """Plotter should use ring_buffer.storage_rate_hz as-is for pipeline."""
+ # Create rate representing device setting after streaming starts
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ # Storage rate should be 12.5e6 (125e6 / 10)
+ assert rate.storage_rate_hz == 12.5e6
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Pipeline should report 6.25 MS/s per channel
+ assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6
+
+
+class TestEndToEndTimeAxisAccuracy:
+ """Verify time axis calculations produce correct durations after downsampling.
+
+ This tests the symptom: "time intervals display 10x too large".
+ """
+
+ def test_time_axis_accuracy_no_downsampling(self) -> None:
+ """1 second of data should display as 1.0s without downsampling."""
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6, # 62.5 per channel, 2 channels
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Create time axis for 1 second display window with decimation=1
+ # Signature: (window_duration: float, decimation: int)
+ time_axis = pipeline.create_time_axis_for_window(1.0, 1)
+
+ # Should span from 0 to 1.0 seconds
+ assert time_axis[0] == 0.0
+ assert abs(time_axis[-1] - 1.0) < 1e-9
+
+ def test_time_axis_accuracy_with_10x_downsampling(self) -> None:
+ """1 second of data should display as 1.0s with 10x downsampling.
+
+ This is the critical test for the bug: with 10x downsampling,
+ if rate was multiplied by channels, time would show as 10x too large.
+ """
+ hardware_rate_hz = 125e6 # 125 MS/s total (62.5 per channel)
+ downsample_ratio = 10
+
+ rate = AcquisitionRate(
+ hardware_rate_hz=hardware_rate_hz,
+ num_channels=2,
+ downsample_ratio=downsample_ratio,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Per-channel effective rate after downsampling
+ effective_rate_per_channel = pipeline.acquisition_rate.per_channel_rate_hz
+ assert effective_rate_per_channel == 6.25e6
+
+ # Create time axis for 1 second window with decimation=1
+ # Signature: (window_duration: float, decimation: int)
+ time_axis = pipeline.create_time_axis_for_window(1.0, 1)
+
+ # Should span from 0 to 1.0 seconds
+ assert time_axis[0] == 0.0
+ assert abs(time_axis[-1] - 1.0) < 1e-9
+
+ def test_time_axis_accuracy_aggregate_mode(self) -> None:
+ """Time axis should be accurate in AGGREGATE mode with downsampling."""
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.AGGREGATE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # AGGREGATE mode: time_to_samples returns storage samples (already doubled)
+ # 1s window = 6.25e6 time points per channel * 2 for min/max = 12.5M storage samples
+ # Each time point produces 2 time axis values (repeated), so 12.5M total values
+ time_axis = pipeline.create_time_axis_for_window(1.0, 1)
+
+ # Should still span 0 to 1.0 seconds
+ assert time_axis[0] == 0.0
+ assert abs(time_axis[-1] - 1.0) < 1e-9
+        # time_to_samples already returns doubled storage samples; // 2 * 2 floors to an even count
+ window_samples = pipeline.time_to_samples(1.0)
+ expected_length = int(window_samples // 2 * 2)
+ assert len(time_axis) == expected_length
+
+
+class TestDisplayDurationCalculation:
+ """Verify display duration calculations account for decimation correctly."""
+
+ def test_display_duration_no_decimation_no_downsample(self) -> None:
+ """Display duration should match actual data duration.
+
+ Note: get_display_duration accounts for software decimation
+ (min/max pairs) and hardware AGGREGATE mode.
+ """
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # 100 display samples = 50 min/max pairs (software decimation), at decimation 1
+ # = 50 original samples at 62.5 MS/s per_channel_rate = 0.8 microseconds
+ display_duration = pipeline.get_display_duration(100, decimation=1)
+ # n_display_samples // 2 = 50 pairs * decimation 1 = 50 original samples
+ # 50 / 62.5e6 = 8e-7 seconds
+ expected = 50 / 62.5e6
+ assert abs(display_duration - expected) < 1e-9
+
+ def test_display_duration_with_decimation_and_downsampling(self) -> None:
+ """Display duration should account for both software and hardware decimation."""
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # 200 display samples = 100 min/max pairs with software decimation 1
+ # = 100 original samples at 6.25 MS/s per_channel_rate = 16 microseconds
+ display_duration = pipeline.get_display_duration(200, decimation=1)
+ expected = 100 / 6.25e6
+ assert abs(display_duration - expected) < 1e-12
+
+ def test_display_duration_10x_downsampling_sanity_check(self) -> None:
+ """Sanity check: verify rate calculations with 10x downsampling.
+
+ With 10x downsampling, 125 MS/s hardware becomes 12.5 MS/s storage.
+ Per channel effective rate is 6.25 MS/s.
+ """
+ rate = AcquisitionRate(
+ hardware_rate_hz=125e6,
+ num_channels=2,
+ downsample_ratio=10,
+ downsample_mode=DownsampleMode.NONE,
+ )
+
+ config = PipelineConfig(
+ resolution=16,
+ voltage_ranges={0: 20.0, 1: 20.0},
+ offsets_v={0: 0.0, 1: 0.0},
+ max_adc_value=32767,
+ target_plot_points=20000,
+ )
+ pipeline = DataPipeline(config, rate)
+
+ # Verify effective rate is correct
+ assert pipeline.acquisition_rate.per_channel_rate_hz == 6.25e6
+
+ # 1000 display samples = 500 min/max pairs (software decimation)
+ # at decimation 10 = 5000 original samples
+ # at 6.25 MS/s = 5000 / 6.25e6 = 0.8 milliseconds
+ display_duration = pipeline.get_display_duration(1000, decimation=10)
+ expected_duration = (500 * 10) / 6.25e6 # 0.0008 seconds
+ assert abs(display_duration - expected_duration) < 1e-9
+
+ # Verify the value is reasonable (not orders of magnitude off)
+ assert 0.0001 < display_duration < 0.01 # Between 0.1ms and 10ms
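+
+
+# Reference arithmetic (sketch) for get_display_duration, mirroring the inline
+# comments above rather than the pipeline internals.
+def test_display_duration_reference_sketch() -> None:
+    """Display samples halve into min/max pairs, then scale by decimation."""
+    per_channel_rate = 6.25e6  # 125 MS/s / 2 channels / 10x downsample
+    n_display_samples = 1000
+    decimation = 10
+    original_samples = (n_display_samples // 2) * decimation
+    assert abs(original_samples / per_channel_rate - 0.0008) < 1e-12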
diff --git a/picostream/test_ring_buffer.py b/picostream/test_ring_buffer.py
new file mode 100644
index 0000000..c622e8f
--- /dev/null
+++ b/picostream/test_ring_buffer.py
@@ -0,0 +1,508 @@
+"""Tests for the RingBuffer class."""
+
+from __future__ import annotations
+
+import threading
+import time
+
+import numpy as np
+import pytest
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.ring_buffer import RingBuffer
+
+
+class TestRingBufferBasic:
+ """Basic functionality tests."""
+
+ def test_init(self) -> None:
+ """Test buffer initialisation with valid parameters."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ buffer.set_acquisition_rate(acquisition_rate)
+
+ assert buffer.capacity == 1000
+ assert buffer.num_channels == 2
+ assert buffer.storage_rate_hz == 1000.0
+ assert buffer.write_idx == 0
+ assert buffer.buffer.shape == (1000, 2)
+ assert buffer.buffer.dtype == np.int16
+
+ def test_init_invalid_capacity(self) -> None:
+ """Test initialisation with invalid parameters raises ValueError."""
+ with pytest.raises(ValueError, match="Buffer capacity must be positive"):
+ RingBuffer(duration_s=0.0, sample_rate=1000.0)
+
+ with pytest.raises(ValueError, match="Buffer capacity must be positive"):
+ RingBuffer(duration_s=1.0, sample_rate=0.0)
+
+ def test_get_snapshot(self) -> None:
+ """Test snapshot returns current write position."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+ assert buffer.get_snapshot() == 0
+
+ data = np.ones((100, 2), dtype=np.int16)
+ buffer.write(data)
+ assert buffer.get_snapshot() == 100
+
+ def test_get_utilisation(self) -> None:
+ """Test utilisation calculation."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ assert buffer.get_utilisation() == 0.0
+
+ data = np.ones((50, 2), dtype=np.int16)
+ buffer.write(data)
+ assert buffer.get_utilisation() == 0.5
+
+ buffer.write(np.ones((100, 2), dtype=np.int16))
+ assert buffer.get_utilisation() == 1.0
+
+ def test_get_duration_available(self) -> None:
+ """Test duration available calculation."""
+ buffer = RingBuffer(duration_s=10.0, sample_rate=1000.0, num_channels=2)
+
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ buffer.set_acquisition_rate(acquisition_rate)
+
+ assert buffer.get_duration_available() == 0.0
+
+ data = np.ones((5000, 2), dtype=np.int16)
+ buffer.write(data)
+ assert buffer.get_duration_available() == 5.0
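+
+    # Sketch of the wrap-around contract the read tests below rely on; assumes
+    # standard ring semantics (oldest samples overwritten once full).
+    def test_wraparound_keeps_most_recent_sketch(self) -> None:
+        """Writing past capacity retains only the most recent samples."""
+        buffer = RingBuffer(duration_s=1.0, sample_rate=10.0, num_channels=1)
+        buffer.write(np.arange(0, 8, dtype=np.int16).reshape(8, 1))
+        buffer.write(np.arange(8, 16, dtype=np.int16).reshape(8, 1))
+
+        result = buffer.read_last(10)
+        np.testing.assert_array_equal(result[:, 0], np.arange(6, 16, dtype=np.int16))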
+
+
+class TestRingBufferWrite:
+ """Tests for the write method."""
+
+ def test_write_single(self) -> None:
+ """Test writing a single batch of data."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+ data = np.arange(200, dtype=np.int16).reshape(100, 2)
+
+ buffer.write(data)
+ assert buffer.write_idx == 100
+
+ result = buffer.read_last(100)
+ np.testing.assert_array_equal(result, data)
+
+ def test_write_wrong_dtype(self) -> None:
+ """Test writing wrong dtype raises ValueError."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+ data = np.ones((100, 2), dtype=np.float32)
+
+ with pytest.raises(ValueError, match="Data must be int16"):
+ buffer.write(data)
+
+ def test_write_wrong_shape(self) -> None:
+ """Test writing wrong shape raises ValueError."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+
+ # Wrong number of channels
+ data = np.ones((100, 3), dtype=np.int16)
+ with pytest.raises(ValueError, match="Data shape must be"):
+ buffer.write(data)
+
+ # 1D array
+ data = np.ones(100, dtype=np.int16)
+ with pytest.raises(ValueError, match="Data shape must be"):
+ buffer.write(data)
+
+ def test_write_empty(self) -> None:
+ """Test writing empty array is a no-op."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=1000.0, num_channels=2)
+ data = np.zeros((0, 2), dtype=np.int16)
+
+ buffer.write(data)
+ assert buffer.write_idx == 0
+
+
+class TestRingBufferReadLast:
+ """Tests for the read_last method."""
+
+ def test_read_last_basic(self) -> None:
+ """Test reading most recent data."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ data1 = np.ones((50, 2), dtype=np.int16) * 1
+ data2 = np.ones((50, 2), dtype=np.int16) * 2
+ buffer.write(data1)
+ buffer.write(data2)
+
+ result = buffer.read_last(50)
+ expected = np.ones((50, 2), dtype=np.int16) * 2
+ np.testing.assert_array_equal(result, expected)
+
+ def test_read_last_zero(self) -> None:
+ """Test reading zero samples returns empty array."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ buffer.write(np.ones((50, 2), dtype=np.int16))
+
+ result = buffer.read_last(0)
+ assert result.shape == (0, 2)
+
+ def test_read_last_more_than_available(self) -> None:
+ """Test reading more than available clamps to available."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ data = np.arange(100, dtype=np.int16).reshape(50, 2)
+ buffer.write(data)
+
+ result = buffer.read_last(1000)
+ assert result.shape == (50, 2)
+ np.testing.assert_array_equal(result, data)
+
+ def test_read_last_before_write(self) -> None:
+ """Test reading from empty buffer returns empty."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ result = buffer.read_last(50)
+ assert result.shape == (0, 2)
+
+
+class TestRingBufferReadRange:
+ """Tests for the read_range method."""
+
+ def test_read_range_no_wrap(self) -> None:
+ """Test reading a range without wrap-around."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ data = np.arange(200, dtype=np.int16).reshape(100, 2)
+ buffer.write(data)
+
+ result = buffer.read_range(0, 50)
+ expected = data[:50]
+ np.testing.assert_array_equal(result, expected)
+
+ def test_read_range_empty(self) -> None:
+ """Test reading empty range returns empty array."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ buffer.write(np.ones((50, 2), dtype=np.int16))
+
+ result = buffer.read_range(50, 50)
+ assert result.shape == (0, 2)
+
+ result = buffer.read_range(60, 50)
+ assert result.shape == (0, 2)
+
+ def test_read_range_clamped(self) -> None:
+ """Test reading more than capacity clamps to capacity."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ data = np.arange(200, dtype=np.int16).reshape(100, 2)
+ buffer.write(data)
+
+ result = buffer.read_range(0, 200)
+ assert result.shape == (100, 2)
+
+
+class TestRingBufferWrapAround:
+ """Tests for wrap-around behavior."""
+
+ def test_write_wraparound(self) -> None:
+ """Test writing past capacity overwrites old data."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ data1 = np.ones((100, 2), dtype=np.int16) * 1
+ buffer.write(data1)
+ assert buffer.write_idx == 100
+
+ data2 = np.ones((50, 2), dtype=np.int16) * 2
+ buffer.write(data2)
+ assert buffer.write_idx == 150
+
+ # Read last 100 samples - should have 50 from data1 (positions 50-99)
+ # and 50 from data2 (positions 100-149)
+ result = buffer.read_last(100)
+ assert result[0, 0] == 1 # From data1 (earlier in the read)
+ assert result[50, 0] == 2 # From data2 (later in the read)
+
+ def test_write_exact_capacity(self) -> None:
+ """Test writing exactly capacity samples."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ data = np.arange(200, dtype=np.int16).reshape(100, 2)
+ buffer.write(data)
+
+ result = buffer.read_last(100)
+ np.testing.assert_array_equal(result, data)
+
+ def test_write_past_capacity(self) -> None:
+ """Test writing past capacity overwrites from beginning."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ data1 = np.arange(200, dtype=np.int16).reshape(100, 2)
+ buffer.write(data1)
+
+ data2 = np.arange(200, 400, dtype=np.int16).reshape(100, 2)
+ buffer.write(data2)
+
+ # Buffer should contain only data2
+ result = buffer.read_last(100)
+ np.testing.assert_array_equal(result, data2)
+
+ def test_read_range_wraparound(self) -> None:
+ """Test reading a range that spans the wrap point."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ # Fill buffer completely
+ data1 = np.arange(200, dtype=np.int16).reshape(100, 2)
+ buffer.write(data1)
+ assert buffer.write_idx == 100
+
+ # Write more to cause wrap (only last 100 samples kept)
+ data2 = np.arange(200, 400, dtype=np.int16).reshape(100, 2)
+ buffer.write(data2)
+ assert buffer.write_idx == 200
+
+ # Buffer now contains data2 at positions 100-199 (write_idx range)
+ # but stored at buffer indices 0-99
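+        # Sketch of the assumed physical layout: slot = logical_pos % capacity,
+        # so logical positions 100-199 occupy slots 0-99 after one full wrap.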
+
+ # Read from position 150 to 200 (last 50 samples)
+ result = buffer.read_range(150, 200)
+ expected = data2[50:100] # Last 50 samples of data2
+ np.testing.assert_array_equal(result, expected)
+
+ def test_read_last_wraparound(self) -> None:
+ """Test read_last with wraparound."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ data1 = np.ones((80, 2), dtype=np.int16) * 1
+ data2 = np.ones((50, 2), dtype=np.int16) * 2
+ buffer.write(data1)
+ buffer.write(data2)
+
+ # Read last 60 samples - should span wrap point
+ result = buffer.read_last(60)
+ assert result.shape == (60, 2)
+ assert result[-1, 0] == 2
+ assert result[0, 0] == 1
+
+
+class TestRingBufferReadSince:
+ """Tests for the read_since method."""
+
+ def test_read_since_basic(self) -> None:
+ """Test reading data since a position."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ pos1 = buffer.get_snapshot()
+ data1 = np.ones((50, 2), dtype=np.int16) * 1
+ buffer.write(data1)
+
+ result = buffer.read_since(pos1)
+ assert result.shape == (50, 2)
+ np.testing.assert_array_equal(result, data1)
+
+ def test_read_since_empty(self) -> None:
+ """Test reading since current position returns empty."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ buffer.write(np.ones((50, 2), dtype=np.int16))
+
+ pos = buffer.get_snapshot()
+ result = buffer.read_since(pos)
+ assert result.shape == (0, 2)
+
+
+class TestRingBufferIsValidRange:
+ """Tests for the is_valid_range method."""
+
+ def test_is_valid_range_true(self) -> None:
+ """Test valid range returns True."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ buffer.write(np.ones((50, 2), dtype=np.int16))
+
+ assert buffer.is_valid_range(0) is True
+ assert buffer.is_valid_range(25) is True
+ assert buffer.is_valid_range(49) is True
+
+ def test_is_valid_range_false(self) -> None:
+ """Test overwritten range returns False."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ data1 = np.ones((100, 2), dtype=np.int16)
+ buffer.write(data1)
+
+ data2 = np.ones((60, 2), dtype=np.int16) * 2
+ buffer.write(data2)
+
+ # Position 0 has been overwritten
+ assert buffer.is_valid_range(0) is False
+ assert buffer.is_valid_range(50) is False
+
+ # Position 60 (start of data2) should be valid
+ assert buffer.is_valid_range(60) is True
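+        # Inferred validity window: positions in [write_idx - capacity,
+        # write_idx) remain readable; here that is [60, 160).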
+
+
+class TestRingBufferMultiChannel:
+ """Tests for multi-channel data integrity."""
+
+ def test_multi_channel_aligned(self) -> None:
+ """Test channel data stays aligned through wrap-around."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+
+ # Write data where each sample has a specific pattern
+ data = np.zeros((150, 2), dtype=np.int16)
+ data[:, 0] = np.arange(150, dtype=np.int16) # Channel 0: 0, 1, 2, ...
+ data[:, 1] = np.arange(
+ 300, 450, dtype=np.int16
+ ) # Channel 1: 300, 301, 302, ...
+
+ buffer.write(data)
+
+ # Read last 50 samples - should be the tail of original data
+ result = buffer.read_last(50)
+
+ # Channel 0 should have consecutive values
+ expected_ch0 = np.arange(100, 150, dtype=np.int16)
+ np.testing.assert_array_equal(result[:, 0], expected_ch0)
+
+ # Channel 1 should have the corresponding values
+ expected_ch1 = np.arange(400, 450, dtype=np.int16)
+ np.testing.assert_array_equal(result[:, 1], expected_ch1)
+
+ def test_single_channel(self) -> None:
+ """Test with single channel."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=1)
+ data = np.arange(100, dtype=np.int16).reshape(100, 1)
+
+ buffer.write(data)
+ result = buffer.read_last(50)
+
+ expected = np.arange(50, 100, dtype=np.int16).reshape(50, 1)
+ np.testing.assert_array_equal(result, expected)
+
+ def test_four_channels(self) -> None:
+ """Test with four channels."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=4)
+ data = np.arange(400, dtype=np.int16).reshape(100, 4)
+
+ buffer.write(data)
+ result = buffer.read_last(50)
+
+ expected = data[50:]
+ np.testing.assert_array_equal(result, expected)
+
+
+class TestRingBufferReturnsCopy:
+ """Tests that read methods return copies, not views."""
+
+ def test_read_last_returns_copy(self) -> None:
+ """Test read_last returns a copy."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ data = np.ones((50, 2), dtype=np.int16)
+ buffer.write(data)
+
+ result = buffer.read_last(50)
+ result[0, 0] = 999
+
+ # Original buffer should be unchanged
+ result2 = buffer.read_last(50)
+ assert result2[0, 0] == 1
+
+ def test_read_range_returns_copy(self) -> None:
+ """Test read_range returns a copy."""
+ buffer = RingBuffer(duration_s=1.0, sample_rate=100.0, num_channels=2)
+ data = np.ones((50, 2), dtype=np.int16)
+ buffer.write(data)
+
+ result = buffer.read_range(0, 50)
+ result[0, 0] = 999
+
+ result2 = buffer.read_range(0, 50)
+ assert result2[0, 0] == 1
+
+
+class TestRingBufferThreadSafety:
+ """Tests for thread safety under load."""
+
+ def test_concurrent_write_read(self) -> None:
+ """Test concurrent writing and reading.
+
+ Runs writer and reader threads concurrently to verify no crashes
+ or obvious data corruption.
+ """
+ buffer = RingBuffer(duration_s=0.5, sample_rate=10000.0, num_channels=2)
+ errors: list[str] = []
+ stop_event = threading.Event()
+
+ def writer() -> None:
+ """Writer thread."""
+ counter = 0
+ while not stop_event.is_set():
+ data = np.full((100, 2), counter, dtype=np.int16)
+ try:
+ buffer.write(data)
+ except Exception as e:
+ errors.append(f"Writer error: {e}")
+ break
+ counter = (counter + 1) % 1000
+
+ def reader() -> None:
+ """Reader thread."""
+ while not stop_event.is_set():
+ try:
+ _ = buffer.read_last(1000)
+ except Exception as e:
+ errors.append(f"Reader error: {e}")
+ break
+ time.sleep(0.001)
+
+ writer_thread = threading.Thread(target=writer)
+ reader_thread = threading.Thread(target=reader)
+
+ writer_thread.start()
+ reader_thread.start()
+
+ time.sleep(0.5)
+ stop_event.set()
+
+ writer_thread.join(timeout=2.0)
+ reader_thread.join(timeout=2.0)
+
+ assert not errors, f"Thread errors occurred: {errors}"
+
+ def test_data_integrity_under_load(self) -> None:
+ """Test data integrity with sequential pattern under thread load.
+
+        A writer thread writes sequential values; the main thread verifies
+        it can read consistent data.
+ """
+ buffer = RingBuffer(duration_s=0.1, sample_rate=10000.0, num_channels=2)
+ stop_event = threading.Event()
+
+ def writer() -> None:
+ """Writer with sequential pattern."""
+ counter = 0
+ while not stop_event.is_set():
+ data = np.zeros((50, 2), dtype=np.int16)
+ data[:, 0] = counter
+ data[:, 1] = counter + 1
+ buffer.write(data)
+ counter += 1
+ if counter > 10000:
+ counter = 0
+
+ writer_thread = threading.Thread(target=writer)
+ writer_thread.start()
+
+ time.sleep(0.05)
+
+ # Multiple reads to verify consistency
+ for _ in range(100):
+ result = buffer.read_last(100)
+ if len(result) > 0:
+                # All values in a column should be the same within one write
+                # batch; a 100-sample read spans at most two 50-sample
+                # batches, so at most two distinct values appear.
+ assert (
+ np.all(result[:, 0] == result[0, 0]) or len(set(result[:, 0])) <= 2
+ )
+ time.sleep(0.001)
+
+ stop_event.set()
+ writer_thread.join(timeout=2.0)
diff --git a/picostream/test_zarr_reader.py b/picostream/test_zarr_reader.py
new file mode 100644
index 0000000..aabf360
--- /dev/null
+++ b/picostream/test_zarr_reader.py
@@ -0,0 +1,331 @@
+"""Tests for the PicoZarrReader class."""
+
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import pytest
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.zarr_reader import PicoZarrReader
+from picostream.zarr_writer import ZarrStreamWriter
+
+
+@pytest.fixture
+def temp_zarr_path():
+ """Provide a temporary path for Zarr files."""
+ path = tempfile.mkdtemp(suffix=".zarr")
+ yield path
+ if os.path.exists(path):
+ shutil.rmtree(path)
+
+
+@pytest.fixture
+def sample_metadata():
+ """Provide sample metadata for Zarr tests."""
+ return {
+ "hardware_rate_hz": 2000.0,
+ "channels": [0, 1],
+ "voltage_ranges": [20.0, 20.0],
+ "offsets_v": [0.0, 0.0],
+ "resolution": 16,
+ "downsample_ratio": 1,
+ "downsample_mode": "NONE",
+ "pre_trigger_seconds": 5.0,
+ }
+
+
+@pytest.fixture
+def populated_zarr_file(temp_zarr_path, sample_metadata):
+ """Create a Zarr file with test data."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=sample_metadata["hardware_rate_hz"],
+ num_channels=len(sample_metadata["channels"]),
+ downsample_ratio=sample_metadata["downsample_ratio"],
+ downsample_mode=DownsampleMode(sample_metadata["downsample_mode"]),
+ )
+ writer = ZarrStreamWriter(
+ path=temp_zarr_path,
+ acquisition_rate=acquisition_rate,
+ num_channels=len(sample_metadata["channels"]),
+ )
+
+ data = np.zeros((5000, 2), dtype=np.int16)
+ data[:, 0] = np.sin(np.linspace(0, 4 * np.pi, 5000)) * 10000
+ data[:, 1] = np.cos(np.linspace(0, 4 * np.pi, 5000)) * 8000
+
+ writer.append(data)
+ writer.close(sample_metadata)
+
+ return temp_zarr_path
+
+
+class TestPicoZarrReaderBasic:
+ """Basic functionality tests."""
+
+ def test_init(self, populated_zarr_file, sample_metadata):
+ """Test reader initialisation."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader.path.name == os.path.basename(populated_zarr_file)
+ assert reader.sample_rate == 2000.0
+ assert reader.num_channels == 2
+ assert reader.total_samples == 5000
+ assert reader.duration_s == 2.5
+ assert reader.channels == [0, 1]
+ assert reader.resolution == 16
+
+ def test_init_nonexistent_path(self):
+ """Test initialisation with non-existent path raises FileNotFoundError."""
+ with pytest.raises(FileNotFoundError):
+ PicoZarrReader("/nonexistent/path/file.zarr")
+
+ def test_voltage_ranges(self, populated_zarr_file):
+ """Test voltage ranges are loaded correctly."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader.voltage_ranges == {0: 20.0, 1: 20.0}
+ assert reader.offsets_v == {0: 0.0, 1: 0.0}
+
+ def test_max_adc_calculation(self, populated_zarr_file):
+ """Test max ADC value is calculated from resolution."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ expected_max = (1 << (16 - 1)) - 1
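+        # i.e. 2**15 - 1 = 32767 for signed 16-bit ADC samples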
+ assert reader.max_adc == expected_max
+
+
+class TestPicoZarrReaderRaw:
+ """Raw data reading tests."""
+
+ def test_get_raw_basic(self, populated_zarr_file):
+ """Test reading raw data."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(0.0, 1.0)
+
+ assert data.shape == (1000, 2)
+ assert data.dtype == np.int16
+
+ def test_get_raw_empty_range(self, populated_zarr_file):
+ """Test reading empty range returns empty array."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(0.0, 0.0)
+ assert data.shape == (0, 2)
+
+ data = reader.get_raw(1.0, 0.5)
+ assert data.shape == (0, 2)
+
+ def test_get_raw_out_of_bounds(self, populated_zarr_file):
+ """Test reading beyond file bounds clamps correctly."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(4.0, 10.0)
+
+ assert data.shape[0] == 1000
+
+ def test_get_raw_returns_copy(self, populated_zarr_file):
+ """Test that get_raw returns a copy, not a view."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data1 = reader.get_raw(0.0, 1.0)
+ data1[0, 0] = 999
+
+ data2 = reader.get_raw(0.0, 1.0)
+ assert data2[0, 0] != 999
+
+
+class TestPicoZarrReaderVoltage:
+ """Voltage conversion tests."""
+
+ def test_get_voltage_basic(self, populated_zarr_file):
+ """Test reading voltage data."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ voltage = reader.get_voltage(0.0, 1.0)
+
+ assert voltage.shape == (1000, 2)
+ assert voltage.dtype == np.float64
+
+ def test_get_voltage_empty(self, populated_zarr_file):
+ """Test reading empty range returns empty float array."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ voltage = reader.get_voltage(0.0, 0.0)
+ assert voltage.shape == (0, 2)
+ assert voltage.dtype == np.float64
+
+ def test_get_voltage_conversion(self, populated_zarr_file):
+ """Test voltage conversion is correct."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ raw = reader.get_raw(0.0, 1.0)
+ voltage = reader.get_voltage(0.0, 1.0)
+
+ max_raw = np.max(np.abs(raw))
+ max_voltage = np.max(np.abs(voltage))
+
+ expected_ratio = 20.0 * 1000 / reader.max_adc
+ actual_ratio = max_voltage / max_raw if max_raw > 0 else 0
+
+ assert abs(actual_ratio - expected_ratio) < 0.1
+
+
+class TestPicoZarrReaderTimeAxis:
+ """Time axis tests."""
+
+ def test_get_time_axis_basic(self, populated_zarr_file):
+ """Test time axis creation."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ time_axis = reader.get_time_axis(0.0, 1.0)
+
+ assert len(time_axis) == 1000
+ assert time_axis[0] == 0.0
+ assert abs(time_axis[-1] - 1.0) < 0.01
+
+ def test_get_time_axis_empty(self, populated_zarr_file):
+ """Test empty time axis."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ time_axis = reader.get_time_axis(0.0, 0.0)
+ assert len(time_axis) == 0
+
+ def test_get_time_axis_matches_data(self, populated_zarr_file):
+ """Test time axis length matches data length."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ raw = reader.get_raw(0.5, 1.5)
+ time_axis = reader.get_time_axis(0.5, 1.5)
+
+ assert len(time_axis) == raw.shape[0]
+
+
+class TestPicoZarrReaderMetadata:
+ """Metadata tests."""
+
+ def test_get_metadata(self, populated_zarr_file, sample_metadata):
+ """Test metadata retrieval."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ metadata = reader.get_metadata()
+
+ assert metadata["sample_rate_hz"] == 2000.0
+ assert metadata["channels"] == [0, 1]
+ assert metadata["total_samples"] == 5000
+
+ def test_pre_trigger_time(self, populated_zarr_file):
+ """Test pre-trigger time retrieval."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader.get_pre_trigger_time() == 5.0
+
+ def test_get_channel_name(self, populated_zarr_file):
+ """Test channel name generation."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader.get_channel_name(0) == "A"
+ assert reader.get_channel_name(1) == "B"
+
+
+class TestPicoZarrReaderTimeConversion:
+ """Time/sample conversion tests."""
+
+ def test_time_to_sample(self, populated_zarr_file):
+ """Test time to sample conversion."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader._time_to_sample(0.0) == 0
+ assert reader._time_to_sample(1.0) == 1000
+ assert reader._time_to_sample(5.0) == 5000
+
+ def test_time_to_sample_clamping(self, populated_zarr_file):
+ """Test time to sample clamping."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader._time_to_sample(-1.0) == 0
+ assert reader._time_to_sample(10.0) == 5000
+
+ def test_sample_to_time(self, populated_zarr_file):
+ """Test sample to time conversion."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ assert reader._sample_to_time(0) == 0.0
+ assert reader._sample_to_time(1000) == 1.0
+ assert reader._sample_to_time(5000) == 5.0
+
+
+class TestPicoZarrReaderClose:
+ """Close/cleanup tests."""
+
+ def test_close(self, populated_zarr_file):
+ """Test reader close."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ reader.close()
+
+ with pytest.raises(ValueError, match="Reader has been closed"):
+ reader.get_raw(0.0, 1.0)
+
+ def test_context_manager(self, populated_zarr_file):
+ """Test context manager usage."""
+ with PicoZarrReader(populated_zarr_file) as reader:
+ assert reader.total_samples == 5000
+
+ with pytest.raises(ValueError, match="Reader has been closed"):
+ reader.get_raw(0.0, 1.0)
+
+
+class TestPicoZarrReaderPartialReads:
+ """Partial read tests."""
+
+ def test_partial_read_start(self, populated_zarr_file):
+ """Test reading from start of file."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(0.0, 0.5)
+
+ assert data.shape[0] == 500
+
+ def test_partial_read_end(self, populated_zarr_file):
+ """Test reading to end of file."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(4.5, 5.0)
+
+ assert data.shape[0] == 500
+
+ def test_partial_read_middle(self, populated_zarr_file):
+ """Test reading from middle of file."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(2.0, 3.0)
+
+ assert data.shape[0] == 1000
+
+ def test_arbitrary_time_range(self, populated_zarr_file):
+ """Test arbitrary time range."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ data = reader.get_raw(1.234, 3.567)
+
+ expected_samples = int((3.567 - 1.234) * 1000)
+ assert abs(data.shape[0] - expected_samples) <= 1
+
+
+class TestPicoZarrReaderRepresentation:
+ """String representation tests."""
+
+ def test_repr(self, populated_zarr_file):
+ """Test string representation."""
+ reader = PicoZarrReader(populated_zarr_file)
+
+ repr_str = repr(reader)
+
+ assert "PicoZarrReader" in repr_str
+ assert "2.50s" in repr_str
+ assert "2 ch" in repr_str
diff --git a/picostream/test_zarr_viewer.py b/picostream/test_zarr_viewer.py
new file mode 100644
index 0000000..10b5683
--- /dev/null
+++ b/picostream/test_zarr_viewer.py
@@ -0,0 +1,223 @@
+"""Tests for the ZarrViewerWindow class."""
+
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import pytest
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.zarr_writer import ZarrStreamWriter
+
+pytest.importorskip("PyQt6.QtWidgets", reason="PyQt6 not available")
+
+from PyQt6.QtWidgets import QApplication
+
+from picostream.zarr_viewer import ZarrViewerWindow
+
+
+@pytest.fixture(scope="session")
+def qapp():
+ """Provide a Qt application instance for the test session."""
+ app = QApplication.instance()
+ if app is None:
+ app = QApplication([])
+ return app
+
+
+@pytest.fixture
+def temp_zarr_path():
+ """Provide a temporary path for Zarr files."""
+ path = tempfile.mkdtemp(suffix=".zarr")
+ yield path
+ if os.path.exists(path):
+ shutil.rmtree(path)
+
+
+@pytest.fixture
+def sample_metadata():
+ """Provide sample metadata for Zarr tests."""
+ return {
+ "hardware_rate_hz": 1000.0,
+ "channels": [0, 1],
+ "voltage_ranges": [20.0, 20.0],
+ "offsets_v": [0.0, 0.0],
+ "resolution": 16,
+ "downsample_ratio": 1,
+ "downsample_mode": "NONE",
+ "pre_trigger_seconds": 5.0,
+ }
+
+
+@pytest.fixture
+def populated_zarr_file(temp_zarr_path, sample_metadata):
+ """Create a Zarr file with test data."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=sample_metadata["hardware_rate_hz"],
+ num_channels=2,
+ downsample_ratio=sample_metadata["downsample_ratio"],
+ downsample_mode=DownsampleMode(sample_metadata["downsample_mode"]),
+ )
+ writer = ZarrStreamWriter(
+ path=temp_zarr_path,
+ acquisition_rate=acquisition_rate,
+ num_channels=len(sample_metadata["channels"]),
+ )
+
+ data = np.zeros((10000, 2), dtype=np.int16)
+ data[:, 0] = np.sin(np.linspace(0, 8 * np.pi, 10000)) * 10000
+ data[:, 1] = np.cos(np.linspace(0, 8 * np.pi, 10000)) * 8000
+
+ writer.append(data)
+ writer.close(sample_metadata)
+
+ return temp_zarr_path
+
+
+class TestZarrViewerWindowBasic:
+ """Basic viewer window tests."""
+
+ def test_window_creation(self, qapp, populated_zarr_file):
+ """Test viewer window can be created."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ assert viewer.reader is not None
+ assert viewer.reader.duration_s == 10.0
+ # Window duration is clamped to valid range
+ assert (
+ viewer.MIN_WINDOW_DURATION
+ <= viewer.window_duration_s
+ <= min(viewer.MAX_WINDOW_DURATION, viewer.reader.duration_s)
+ )
+
+ viewer.close()
+
+ def test_window_title(self, qapp, populated_zarr_file):
+ """Test window title is set correctly."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ title = viewer.windowTitle()
+ assert "PicoStream Viewer" in title
+ assert "10.0s" in title
+
+ viewer.close()
+
+
+class TestZarrViewerWindowControls:
+ """Control interaction tests."""
+
+ def test_window_duration_change(self, qapp, populated_zarr_file):
+ """Test window duration can be changed."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ # Get initial duration
+ initial_duration = viewer.window_duration_s
+ assert (
+ viewer.MIN_WINDOW_DURATION
+ <= initial_duration
+ <= min(viewer.MAX_WINDOW_DURATION, viewer.reader.duration_s)
+ )
+
+ viewer.window_duration_s = 3.0
+ assert viewer.window_duration_s == 3.0
+
+ viewer.close()
+
+ def test_scroll_position(self, qapp, populated_zarr_file):
+ """Test scroll position."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ assert viewer.scroll_position_s == 0.0
+
+ viewer.scroll_position_s = 3.0
+ assert viewer.scroll_position_s == 3.0
+
+ viewer.close()
+
+ def test_clamp_scroll_position(self, qapp, populated_zarr_file):
+ """Test scroll position clamping."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ viewer.window_duration_s = 5.0
+ viewer.scroll_position_s = 20.0
+ viewer._clamp_scroll_position()
+
+ max_pos = 10.0 - 5.0
+ assert viewer.scroll_position_s <= max_pos
+
+ viewer.close()
+
+
+class TestZarrViewerWindowChannelVisibility:
+ """Channel visibility tests."""
+
+ def test_channel_toggling(self, qapp, populated_zarr_file):
+ """Test channel visibility toggling."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ assert viewer.curves[0].isVisible()
+
+ viewer._on_channel_toggled(0, False)
+ assert not viewer.curves[0].isVisible()
+
+ viewer._on_channel_toggled(0, True)
+ assert viewer.curves[0].isVisible()
+
+ viewer.close()
+
+
+class TestZarrViewerWindowTimeConversion:
+ """Time/sample conversion tests."""
+
+ def test_time_to_sample(self, qapp, populated_zarr_file):
+ """Test time to sample conversion."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ sample_idx = viewer.reader._time_to_sample(1.0)
+ assert sample_idx == 500
+
+ sample_idx = viewer.reader._time_to_sample(5.0)
+ assert sample_idx == 2500
+
+ viewer.close()
+
+
+class TestZarrViewerWindowMetadata:
+ """Metadata display tests."""
+
+ def test_metadata_loaded(self, qapp, populated_zarr_file):
+ """Test metadata is loaded correctly."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ assert viewer.reader.sample_rate == 1000.0
+ assert viewer.reader.num_channels == 2
+ assert viewer.reader.channels == [0, 1]
+ assert viewer.reader.pre_trigger_seconds == 5.0
+
+ viewer.close()
+
+
+class TestZarrViewerWindowConstants:
+ """Constants tests."""
+
+ def test_window_limits(self, qapp, populated_zarr_file):
+ """Test window duration limits."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ assert viewer.MIN_WINDOW_DURATION == 0.5
+ assert viewer.MAX_WINDOW_DURATION == 30.0
+ assert viewer.MAX_ADC_VALUE == 32767
+
+ viewer.close()
+
+ def test_lossless_mode(self, qapp, populated_zarr_file):
+ """Test viewer operates in lossless mode (no decimation)."""
+ viewer = ZarrViewerWindow(populated_zarr_file)
+
+ assert viewer.pipeline is not None
+ assert viewer.pipeline.config.target_plot_points is None
+
+ viewer.close()
diff --git a/picostream/test_zarr_writer.py b/picostream/test_zarr_writer.py
new file mode 100644
index 0000000..74e578a
--- /dev/null
+++ b/picostream/test_zarr_writer.py
@@ -0,0 +1,310 @@
+"""Unit tests for ZarrStreamWriter."""
+
+import os
+import shutil
+import tempfile
+from typing import Generator
+
+import numpy as np
+import pytest
+import zarr
+
+from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+from picostream.zarr_writer import ZarrStreamWriter
+
+
+@pytest.fixture
+def temp_zarr_path() -> Generator[str, None, None]:
+ """Provide a temporary path for Zarr files."""
+ temp_dir = tempfile.mkdtemp()
+ zarr_path = os.path.join(temp_dir, "test.zarr")
+ yield zarr_path
+ shutil.rmtree(temp_dir, ignore_errors=True)
+
+
+class TestZarrStreamWriter:
+ """Test cases for ZarrStreamWriter."""
+
+ def test_basic_write_read_back(self, temp_zarr_path: str) -> None:
+ """Test basic writing and reading back data."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ expected_data = np.array([[100, 200], [300, 400], [500, 600]], dtype=np.int16)
+ writer.append(expected_data)
+
+ metadata = {
+ "channels": [0, 1],
+ "voltage_ranges": [20.0, 20.0],
+ }
+ writer.close(metadata)
+
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ np.testing.assert_array_equal(stored_data, expected_data)
+
+ def test_batching_many_small_appends(self, temp_zarr_path: str) -> None:
+ """Test that many small appends are batched correctly."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=100.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ expected_data = []
+ for i in range(50):
+ small_chunk = np.array([[i, i + 1], [i + 2, i + 3]], dtype=np.int16)
+ writer.append(small_chunk)
+ expected_data.append(small_chunk)
+
+ writer.close({})
+
+ expected_concatenated = np.concatenate(expected_data, axis=0)
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ np.testing.assert_array_equal(stored_data, expected_concatenated)
+
+ def test_chunk_alignment(self, temp_zarr_path: str) -> None:
+ """Test that chunks align with 1-second boundaries."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ assert writer._chunk_size == 1000
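+        # One chunk per second of data: 1000.0 S/s × 1 s = 1000 samples per
+        # channel (illustrative reading of the chunk-sizing policy).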
+
+ data = np.ones((2500, 2), dtype=np.int16)
+ writer.append(data)
+ writer.close({})
+
+ root = zarr.open(temp_zarr_path, mode="r")
+ data_array = root["data"]
+
+ assert data_array.chunks == (1000, 2)
+
+ def test_metadata_round_trip(self, temp_zarr_path: str) -> None:
+ """Test that metadata is correctly written and readable."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=50000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ test_data = np.ones((100, 2), dtype=np.int16)
+ writer.append(test_data)
+
+ expected_metadata = {
+ "channels": [0, 1],
+ "voltage_ranges": [20.0, 20.0],
+ "offsets_v": [0.0, 0.0],
+ "resolution": 12,
+ "max_adc": 32767,
+ "pre_trigger_seconds": 10.0,
+ "compression": None,
+ }
+ writer.close(expected_metadata)
+
+ root = zarr.open(temp_zarr_path, mode="r")
+
+ assert list(root.attrs["channels"]) == [0, 1]
+ assert list(root.attrs["voltage_ranges"]) == [20.0, 20.0]
+ assert root.attrs["resolution"] == 12
+ assert root.attrs["max_adc"] == 32767
+ assert root.attrs["format_version"] == "2.0"
+ assert "total_samples" in root.attrs
+ assert "duration_s" in root.attrs
+ assert "start_time_iso" in root.attrs
+
+ def test_empty_file(self, temp_zarr_path: str) -> None:
+ """Test closing without writing any data."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+ writer.close({})
+
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ assert stored_data.shape == (0, 2)
+ assert root.attrs["total_samples"] == 0
+
+ def test_large_write(self, temp_zarr_path: str) -> None:
+ """Test writing a larger amount of data."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=10000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ large_data = np.random.randint(-10000, 10000, size=(60000, 2), dtype=np.int16)
+ writer.append(large_data)
+ writer.close({})
+
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ np.testing.assert_array_equal(stored_data, large_data)
+ assert root.attrs["total_samples"] == 60000
+
+ def test_invalid_data_dtype(self, temp_zarr_path: str) -> None:
+ """Test that non-int16 data raises an error."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2)
+
+ invalid_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+
+ with pytest.raises(ValueError, match="Data must be int16"):
+ writer.append(invalid_data)
+
+ def test_invalid_data_shape(self, temp_zarr_path: str) -> None:
+ """Test that data with wrong shape raises an error."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2)
+
+ wrong_channels = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)
+
+ with pytest.raises(ValueError, match="Data shape must be"):
+ writer.append(wrong_channels)
+
+ def test_context_manager(self, temp_zarr_path: str) -> None:
+ """Test using the writer as a context manager."""
+ data = np.array([[100, 200], [300, 400]], dtype=np.int16)
+
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ with ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2) as writer:
+ writer.append(data)
+
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ np.testing.assert_array_equal(stored_data, data)
+
+ def test_discard(self, temp_zarr_path: str) -> None:
+ """Test discarding a file."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2)
+ data = np.array([[1, 2], [3, 4]], dtype=np.int16)
+ writer.append(data)
+ writer.discard()
+
+ assert not os.path.exists(temp_zarr_path)
+
+ def test_flush_partial_chunk(self, temp_zarr_path: str) -> None:
+ """Test that flush writes partial chunks."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ data = np.ones((100, 2), dtype=np.int16)
+ writer.append(data)
+
+ writer.flush()
+ writer.close({})
+
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ np.testing.assert_array_equal(stored_data, data)
+
+ def test_multiple_flush_calls(self, temp_zarr_path: str) -> None:
+ """Test multiple flush calls work correctly."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ num_channels = 2
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, num_channels)
+
+ data1 = np.ones((100, 2), dtype=np.int16)
+ writer.append(data1)
+ writer.flush()
+
+ data2 = np.ones((50, 2), dtype=np.int16) * 2
+ writer.append(data2)
+ writer.flush()
+
+ expected = np.concatenate([data1, data2], axis=0)
+
+ writer.close({})
+ root = zarr.open(temp_zarr_path, mode="r")
+ stored_data = root["data"][:]
+
+ np.testing.assert_array_equal(stored_data, expected)
+
+ def test_written_counter(self, temp_zarr_path: str) -> None:
+ """Test that the written counter is accurate."""
+ acquisition_rate = AcquisitionRate(
+ hardware_rate_hz=1000.0,
+ num_channels=2,
+ downsample_ratio=1,
+ downsample_mode=DownsampleMode.NONE,
+ )
+ writer = ZarrStreamWriter(temp_zarr_path, acquisition_rate, 2)
+
+ assert writer.written == 0
+
+ data1 = np.ones((100, 2), dtype=np.int16)
+ writer.append(data1)
+ writer.flush()
+ assert writer.written == 100
+
+ data2 = np.ones((200, 2), dtype=np.int16)
+ writer.append(data2)
+ writer.flush()
+ assert writer.written == 300
+
+ writer.close({})
+ assert writer.written == 300
diff --git a/picostream/zarr_reader.py b/picostream/zarr_reader.py
new file mode 100644
index 0000000..7c8b1af
--- /dev/null
+++ b/picostream/zarr_reader.py
@@ -0,0 +1,405 @@
+"""Zarr file reader for post-hoc analysis of saved PicoStream data.
+
+This module provides a reader for Zarr files created by ZarrStreamWriter,
+enabling post-acquisition analysis with on-demand disk access. Data is not
+loaded entirely into memory; instead, Zarr's chunk-level access is used to
+read only the requested time ranges.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+import numpy as np
+import zarr
+from loguru import logger
+
+from picostream.conversion_utils import adc_to_mV
+
+if TYPE_CHECKING:
+ from picostream.acquisition_rate import AcquisitionRate
+
+
+class PicoZarrReader:
+ """Reader for saved PicoStream Zarr files.
+
+ Provides access to ADC data and metadata from Zarr files with on-demand
+ reading. Voltage conversion and time axis construction are handled
+ automatically.
+
+ Parameters
+ ----------
+ path : str
+ Path to the Zarr directory store.
+
+ Attributes
+ ----------
+ path : Path
+ Path to the Zarr file.
+ root : zarr.Group
+ Root Zarr group.
+ data : zarr.Array
+ Data array containing int16 ADC samples.
+ sample_rate : float
+ Sample rate in Hz.
+ num_channels : int
+ Number of channels in the data.
+ total_samples : int
+ Total number of samples per channel.
+ duration_s : float
+ Total duration in seconds.
+ channels : List[int]
+ List of channel indices.
+ voltage_ranges : Dict[int, float]
+ Voltage range in Volts for each channel.
+ offsets_v : Dict[int, float]
+ Voltage offset in Volts for each channel.
+ resolution : int
+ ADC resolution in bits.
+ max_adc : int
+ Maximum ADC value.
+ downsample_ratio : int
+ Downsample ratio if applied.
+ downsample_mode : str
+ Downsample mode (e.g., "NONE", "AGGREGATE").
+ pre_trigger_seconds : float
+ Pre-trigger capture duration in seconds.
+ format_version : str
+ File format version string.
+ compression : Optional[str]
+ Compression algorithm used, if any.
+ start_time_iso : str
+ ISO 8601 timestamp when capture started.
+ buffer_duration_s : float
+ Ring buffer duration in seconds.
+ device_id : str
+ Device identifier.
+ serial_code : Optional[str]
+ Device serial code, if available.
+ coupling : str
+ Input coupling mode (e.g., "DC", "AC").
+ bandwidth_limiter : str
+ Bandwidth limiter setting (e.g., "FULL", "20MHZ").
+
+ Examples
+ --------
+ >>> reader = PicoZarrReader("save_20240115_143022.zarr")
+ >>> print(f"Duration: {reader.duration_s}s, Channels: {reader.channels}")
+ >>> raw = reader.get_raw(0, 1.0) # Get first second of raw data
+ >>> voltage = reader.get_voltage(0, 1.0) # Get first second in mV
+ >>> reader.close()
+ """
+
+ def __init__(self, path: str) -> None:
+ """Initialise the reader and load metadata from the Zarr file."""
+ self.path = Path(path)
+
+ if not self.path.exists():
+ raise FileNotFoundError(f"Zarr file not found: {path}")
+
+ self.root = zarr.open(str(self.path), mode="r")
+ self.data: zarr.Array = self.root["data"] # type: ignore
+
+ attrs = self.root.attrs
+
+ self.num_channels: int = int(self.data.shape[1])
+ self.total_samples: int = int(self.data.shape[0])
+
+ # Reconstruct AcquisitionRate from metadata
+ hardware_rate = attrs.get("hardware_rate_hz") # type: ignore
+ ds_ratio = attrs.get("downsample_ratio", 1) # type: ignore
+ ds_mode = attrs.get("downsample_mode", "NONE") # type: ignore
+
+ if hardware_rate is None:
+ raise ValueError(
+ "Missing required 'hardware_rate_hz' metadata field. "
+ "This file may be from an old format version."
+ )
+
+ from picostream.acquisition_rate import AcquisitionRate, DownsampleMode
+
+ num_channels_meta = attrs.get("num_channels", self.num_channels)
+ assert isinstance(num_channels_meta, (int, float))
+ assert isinstance(ds_ratio, (int, float, str))
+ downsample_ratio_value: int = int(float(ds_ratio))
+
+ self.acquisition_rate: Optional[AcquisitionRate] = AcquisitionRate(
+ hardware_rate_hz=float(hardware_rate), # type: ignore
+ num_channels=int(num_channels_meta),
+ downsample_ratio=downsample_ratio_value,
+ downsample_mode=DownsampleMode(str(ds_mode)), # type: ignore
+ )
+ # Use acquisition_rate for all rate calculations
+ self.sample_rate: float = self.acquisition_rate.storage_rate_hz
+
+ if self.sample_rate > 0:
+ self.duration_s: float = self.total_samples / self.sample_rate
+ else:
+ self.duration_s = 0.0
+
+ channels_attr = attrs.get("channels", []) # type: ignore
+ self.channels: List[int] = [int(ch) for ch in channels_attr] # type: ignore
+
+ voltage_ranges_attr = attrs.get("voltage_ranges", []) # type: ignore
+ self.voltage_ranges: Dict[int, float] = {
+ ch: float(vr)
+ for ch, vr in zip(self.channels, voltage_ranges_attr, strict=True) # type: ignore
+ }
+
+ offsets_attr = attrs.get("offsets_v", []) # type: ignore
+ self.offsets_v: Dict[int, float] = {
+ ch: float(off)
+ for ch, off in zip(self.channels, offsets_attr, strict=True) # type: ignore
+ }
+
+ self.resolution: int = int(attrs.get("resolution", 16)) # type: ignore
+ # Default to 32767 (PicoScope 5000a max) if not specified
+ self.max_adc: int = int(attrs.get("max_adc", 32767)) # type: ignore
+        self.downsample_ratio: int = downsample_ratio_value
+ self.downsample_mode: str = str(ds_mode) # type: ignore
+ self.pre_trigger_seconds: float = float(attrs.get("pre_trigger_seconds", 0.0)) # type: ignore
+ self.format_version: str = str(attrs.get("format_version", "unknown")) # type: ignore
+ self.compression: Optional[str] = attrs.get("compression") # type: ignore
+ self.start_time_iso: str = str(attrs.get("start_time_iso", "")) # type: ignore
+ self.buffer_duration_s: float = float(attrs.get("buffer_duration_s", 0.0)) # type: ignore
+ self.device_id: str = str(attrs.get("device_id", "")) # type: ignore
+ self.serial_code: Optional[str] = attrs.get("serial_code") # type: ignore
+ self.coupling: str = str(attrs.get("coupling", "DC")) # type: ignore
+ self.bandwidth_limiter: str = str(attrs.get("bandwidth_limiter", "FULL")) # type: ignore
+
+ logger.info(
+ "PicoZarrReader opened: {}, {} samples, {:.2f}s, {} channels",
+ self.path.name,
+ self.total_samples,
+ self.duration_s,
+ self.num_channels,
+ )
+
+ def _time_to_sample(self, time_s: float) -> int:
+ """Convert time in seconds to sample index.
+
+ Parameters
+ ----------
+ time_s : float
+ Time in seconds from start of capture.
+
+ Returns
+ -------
+ int
+ Sample index, clamped to valid range.
+ """
+ sample_idx = self.acquisition_rate.seconds_to_samples(time_s)
+ return max(0, min(sample_idx, self.total_samples))
+
+ def _sample_to_time(self, sample_idx: int) -> float:
+ """Convert sample index to time in seconds.
+
+ Parameters
+ ----------
+ sample_idx : int
+ Sample index.
+
+ Returns
+ -------
+ float
+ Time in seconds.
+ """
+ return self.acquisition_rate.samples_to_seconds(sample_idx)
+
+ def get_raw(self, start_s: float, end_s: float) -> np.ndarray:
+ """Read raw ADC data from the specified time range.
+
+ Reads int16 ADC counts from disk using Zarr's chunk-level access.
+ Returns a copy, not a view into the Zarr array.
+
+ Parameters
+ ----------
+ start_s : float
+ Start time in seconds from capture start.
+ end_s : float
+ End time in seconds from capture start (exclusive).
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (n_samples, num_channels) with dtype int16.
+ Returns empty array if start_s >= end_s or out of bounds.
+
+ Raises
+ ------
+ ValueError
+ If requesting data from a closed reader.
+ """
+ if not hasattr(self, "data"):
+ raise ValueError("Reader has been closed")
+
+ start_sample = self._time_to_sample(start_s)
+ end_sample = self._time_to_sample(end_s)
+
+ if start_sample >= end_sample:
+ return np.empty((0, self.num_channels), dtype=np.int16)
+
+ if start_sample >= self.total_samples:
+ return np.empty((0, self.num_channels), dtype=np.int16)
+
+ end_sample = min(end_sample, self.total_samples)
+
+ data = self.data[start_sample:end_sample]
+
+ return np.array(data, dtype=np.int16, copy=True)
+
+ def get_voltage(self, start_s: float, end_s: float) -> np.ndarray:
+ """Read voltage data (in mV) from the specified time range.
+
+ Applies ADC-to-mV conversion using stored voltage ranges and offsets.
+
+ Parameters
+ ----------
+ start_s : float
+ Start time in seconds from capture start.
+ end_s : float
+ End time in seconds from capture start (exclusive).
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (n_samples, num_channels) with dtype float64,
+ containing voltage values in millivolts.
+ """
+ raw_data = self.get_raw(start_s, end_s)
+
+ if raw_data.size == 0:
+ return np.empty((0, self.num_channels), dtype=np.float64)
+
+ voltage_data = np.empty_like(raw_data, dtype=np.float64)
+
+ for i, ch in enumerate(self.channels):
+ if i < raw_data.shape[1]:
+ voltage_range = self.voltage_ranges.get(ch, 1.0)
+ offset = self.offsets_v.get(ch, 0.0)
+
+ converted = adc_to_mV(raw_data[:, i], voltage_range, self.max_adc)
+ voltage_data[:, i] = converted + (offset * 1000.0)
+
+ return voltage_data
+
+ def get_time_axis(self, start_s: float, end_s: float) -> np.ndarray:
+ """Create a time axis for the specified time range.
+
+ Parameters
+ ----------
+ start_s : float
+ Start time in seconds.
+ end_s : float
+ End time in seconds.
+
+ Returns
+ -------
+ np.ndarray
+ Array of time values in seconds, aligned with get_raw/get_voltage.
+ """
+ start_sample = self._time_to_sample(start_s)
+ end_sample = self._time_to_sample(end_s)
+
+ n_samples = max(0, end_sample - start_sample)
+ if n_samples == 0:
+ return np.array([], dtype=np.float64)
+
+ start_time = self._sample_to_time(start_sample)
+ end_time = self._sample_to_time(end_sample)
+
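+        # endpoint=False keeps the spacing at one sample period: e.g. 0.0 s to
+        # 1.0 s with 1000 samples yields 0.000, 0.001, ..., 0.999.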
+ return np.linspace(start_time, end_time, n_samples, endpoint=False)
+
+ def get_metadata(self) -> Dict:
+ """Return all metadata as a dictionary.
+
+ Returns
+ -------
+ Dict
+ Dictionary containing all file metadata.
+ """
+ return {
+ "path": str(self.path),
+ "sample_rate_hz": self.sample_rate,
+ "num_channels": self.num_channels,
+ "total_samples": self.total_samples,
+ "duration_s": self.duration_s,
+ "channels": self.channels,
+ "voltage_ranges": self.voltage_ranges,
+ "offsets_v": self.offsets_v,
+ "resolution": self.resolution,
+ "max_adc": self.max_adc,
+ "downsample_ratio": self.downsample_ratio,
+ "downsample_mode": self.downsample_mode,
+ "pre_trigger_seconds": self.pre_trigger_seconds,
+ "format_version": self.format_version,
+ "compression": self.compression,
+ "start_time_iso": self.start_time_iso,
+ "buffer_duration_s": self.buffer_duration_s,
+ "device_id": self.device_id,
+ "serial_code": self.serial_code,
+ "coupling": self.coupling,
+ "bandwidth_limiter": self.bandwidth_limiter,
+ }
+
+ def get_pre_trigger_time(self) -> float:
+ """Get the time at which the save trigger occurred.
+
+ The trigger point is the start of post-trigger data, which
+ is at pre_trigger_seconds from the beginning of the file.
+
+ Returns
+ -------
+ float
+ Time in seconds of the trigger point.
+ """
+ return self.pre_trigger_seconds
+
+ def get_channel_name(self, channel_index: int) -> str:
+ """Get the display name for a channel index.
+
+ Parameters
+ ----------
+ channel_index : int
+ Channel index (0, 1, etc.).
+
+ Returns
+ -------
+ str
+ Channel name (e.g., "A", "B", "Ch 0").
+ """
+ if channel_index < len(self.channels):
+ ch = self.channels[channel_index]
+ return chr(ord("A") + ch)
+ return f"Ch {channel_index}"
+
+ def close(self) -> None:
+ """Close the reader and release resources.
+
+ After closing, the reader can no longer be used to access data.
+ """
+ if hasattr(self, "root"):
+ del self.root
+ del self.data
+ logger.info("PicoZarrReader closed: {}", self.path.name)
+
+ def __enter__(self) -> PicoZarrReader:
+ """Context manager entry."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ """Context manager exit."""
+ self.close()
+
+ def __repr__(self) -> str:
+ """String representation of the reader."""
+ return (
+ f"PicoZarrReader("
+ f"path='{self.path.name}', "
+ f"{self.duration_s:.2f}s, "
+ f"{self.sample_rate / 1e6:.2f} MS/s, "
+ f"{len(self.channels)} ch)"
+ )
diff --git a/picostream/zarr_viewer.py b/picostream/zarr_viewer.py
new file mode 100644
index 0000000..0dc7ecf
--- /dev/null
+++ b/picostream/zarr_viewer.py
@@ -0,0 +1,553 @@
+"""Standalone viewer window for saved PicoStream Zarr files.
+
+This module provides a Qt-based viewer for analysing saved Zarr captures
+independently from the live acquisition system. Multiple viewer windows
+can be open simultaneously.
+"""
+
+from __future__ import annotations
+
+import sys
+from typing import Dict, Optional
+
+import numpy as np
+import pyqtgraph as pg
+from loguru import logger
+from PyQt6.QtCore import QSettings, Qt
+from PyQt6.QtGui import QCloseEvent, QFont
+from PyQt6.QtWidgets import (
+ QApplication,
+ QDoubleSpinBox,
+ QFileDialog,
+ QHBoxLayout,
+ QLabel,
+ QMainWindow,
+ QPushButton,
+ QScrollBar,
+ QVBoxLayout,
+ QWidget,
+)
+
+from picostream.data_pipeline import DataPipeline, PipelineConfig
+from picostream.zarr_reader import PicoZarrReader
+
+
+class ZarrViewerWindow(QMainWindow):
+ """Independent viewer window for saved Zarr files.
+
+ Displays saved capture data with a scrollable timeline and configurable
+ window duration. Uses on-demand disk access via Zarr for memory efficiency.
+
+ Parameters
+ ----------
+ path : str
+ Path to the Zarr file to display.
+ parent : Optional[QWidget]
+ Parent widget, if any.
+
+ Attributes
+ ----------
+ reader : PicoZarrReader
+ The Zarr reader instance.
+ window_duration_s : float
+ Current visible window duration in seconds.
+ scroll_position_s : float
+ Current scroll position (start of visible window) in seconds.
+ """
+
+ MIN_WINDOW_DURATION: float = 0.5
+ MAX_WINDOW_DURATION: float = 30.0
+
+ # Fixed max ADC value for PicoScope 5000a series (consistent with dfplot)
+ MAX_ADC_VALUE: int = 32767
+
+ def __init__(self, path: str, parent: Optional[QWidget] = None) -> None:
+ """Initialise the viewer window and load the Zarr file."""
+ super().__init__(parent)
+
+ self.reader = PicoZarrReader(path)
+
+ # Load window duration from QSettings (default 2.0s)
+ settings = QSettings("lab", "picostream")
+ default_window = settings.value("zarr_viewer/window_duration", 2.0, float)
+ self.window_duration_s: float = max(
+ self.MIN_WINDOW_DURATION,
+ min(default_window, self.reader.duration_s),
+ )
+ self.scroll_position_s: float = 0.0
+
+ # Scale factors: 1.0 = no scaling, auto-calculated from voltage ranges
+ self.scale_factors: Dict[int, float] = {}
+ self._calculate_scale_factors()
+
+ # Initialise DataPipeline for consistent processing with dfplot
+ self._initialise_pipeline()
+
+ self._setup_ui()
+ self._update_title()
+ self._update_scrollbar_range()
+ self._refresh_plot()
+
+ logger.info(
+ "ZarrViewerWindow opened: {}, duration={:.2f}s",
+ path,
+ self.reader.duration_s,
+ )
+
+ def _setup_ui(self) -> None:
+ """Set up the window UI components."""
+ self.setMinimumSize(800, 600)
+
+ central_widget = QWidget()
+ self.setCentralWidget(central_widget)
+
+ main_layout = QVBoxLayout(central_widget)
+ main_layout.setContentsMargins(10, 10, 10, 10)
+
+ # Info panel
+ info_layout = QHBoxLayout()
+
+ self.title_label = QLabel()
+ self.title_label.setFont(QFont("Monospace"))
+ info_layout.addWidget(self.title_label)
+
+ info_layout.addStretch()
+
+ # Metadata display
+ meta_text = (
+ f"SR: {self.reader.sample_rate / 1e6:.2f} MS/s | "
+ f"Ch: {', '.join(self.reader.get_channel_name(ch) for ch in self.reader.channels)} | "
+ f"Pre-trigger: {self.reader.pre_trigger_seconds:.1f}s | "
+ f"Resolution: {self.reader.resolution}-bit"
+ )
+ self.meta_label = QLabel(meta_text)
+ self.meta_label.setFont(QFont("Monospace", 9))
+ info_layout.addWidget(self.meta_label)
+
+ main_layout.addLayout(info_layout)
+
+ # Plot area
+ self.plot_widget = pg.PlotWidget()
+ self.plot_widget.setLabel("left", "Voltage", "mV")
+ self.plot_widget.setLabel("bottom", "Time", "s")
+ self.plot_widget.showGrid(x=True, y=True)
+ self.plot_widget.setDownsampling(mode="peak")
+ self.plot_widget.setClipToView(True)
+ self.plot_widget.showAxis("right", False)
+
+ self.viewBox = self.plot_widget.getViewBox()
+
+ # Channel colours matching dfplot (RGBA as hex)
+ self.channel_pens = {
+ 0: {"color": "#00BFFF"}, # Cyan (0.0, 0.75, 1.0) → #00BFFF
+ 1: {"color": "#FF4444"}, # Red (1.0, 0.27, 0.27) → #FF4444
+ }
+
+ self.curves: Dict[int, pg.PlotDataItem] = {}
+ self.curves[0] = pg.PlotDataItem(
+ pen=self.channel_pens[0]["color"], width=1, name="Ch A"
+ )
+ self.curves[1] = pg.PlotDataItem(
+ pen=self.channel_pens[1]["color"], width=1, name="Ch B"
+ )
+
+ self.viewBox.addItem(self.curves[0])
+ self.viewBox.addItem(self.curves[1])
+
+ # Show/hide curves based on available channels
+ for ch_idx in self.curves:
+ self.curves[ch_idx].setVisible(ch_idx in self.reader.channels)
+
+ # Add legend
+ self.legend = pg.LegendItem(offset=(50, 30))
+ self.legend.setParentItem(self.plot_widget.graphicsItem())
+ self._update_legend()
+
+ main_layout.addWidget(self.plot_widget, stretch=1)
+
+ # Controls panel
+ controls_layout = QHBoxLayout()
+
+ # Window duration control
+ controls_layout.addWidget(QLabel("Window:"))
+ self.window_spin = QDoubleSpinBox()
+ self.window_spin.setRange(
+ self.MIN_WINDOW_DURATION,
+ min(self.MAX_WINDOW_DURATION, self.reader.duration_s),
+ )
+ self.window_spin.setSingleStep(0.1)
+ self.window_spin.setDecimals(2)
+ self.window_spin.setValue(self.window_duration_s)
+ self.window_spin.setSuffix(" s")
+ self.window_spin.valueChanged.connect(self._on_window_changed)
+ controls_layout.addWidget(self.window_spin)
+
+ controls_layout.addSpacing(20)
+
+ # Channel visibility toggles
+ controls_layout.addWidget(QLabel("Channels:"))
+ self.channel_buttons: Dict[int, QPushButton] = {}
+ for ch in self.reader.channels:
+ btn = QPushButton(f"Ch {chr(ord('A') + ch)}")
+ btn.setCheckable(True)
+ btn.setChecked(True)
+ btn.clicked.connect(
+ lambda checked, ch=ch: self._on_channel_toggled(ch, checked)
+ )
+ self.channel_buttons[ch] = btn
+ controls_layout.addWidget(btn)
+
+ controls_layout.addStretch()
+
+ # Info labels
+ self.position_label = QLabel("0.0s")
+ self.position_label.setFont(QFont("Monospace"))
+ controls_layout.addWidget(self.position_label)
+
+ controls_layout.addWidget(QLabel("/"))
+
+ self.duration_label = QLabel(f"{self.reader.duration_s:.1f}s")
+ self.duration_label.setFont(QFont("Monospace"))
+ controls_layout.addWidget(self.duration_label)
+
+ main_layout.addLayout(controls_layout)
+
+ # Horizontal scrollbar
+ self.scrollbar = QScrollBar(Qt.Orientation.Horizontal)
+ self.scrollbar.valueChanged.connect(self._on_scroll_changed)
+ main_layout.addWidget(self.scrollbar)
+
+ def _update_title(self) -> None:
+ """Update the window title with file info."""
+ filename = self.reader.path.name
+ self.setWindowTitle(
+ f"PicoStream Viewer — {filename} ({self.reader.duration_s:.1f}s)"
+ )
+ self.title_label.setText(f"<b>{filename}</b>")
+
+    def _update_axis_visibility(self) -> None:
+        """Placeholder; axis visibility is configured once in _setup_ui."""
+
+ def _calculate_scale_factors(self) -> None:
+ """Calculate automatic scale factors based on voltage ranges."""
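+        # Channel 0 is the reference (scale 1.0); channel 1 is scaled by the
+        # ratio of voltage ranges so both traces plot at comparable amplitude,
+        # e.g. ±5 V (ch 0) against ±1 V (ch 1) gives scale_factors[1] = 5.0.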
+ range_0 = self.reader.voltage_ranges[0]
+ self.scale_factors[0] = 1.0
+ if 1 in self.reader.voltage_ranges:
+ range_1 = self.reader.voltage_ranges[1]
+ if range_1 > 0:
+ self.scale_factors[1] = range_0 / range_1
+ else:
+ self.scale_factors[1] = 1.0
+
+ def _initialise_pipeline(self) -> None:
+ """Initialise the DataPipeline for consistent processing with dfplot."""
+ if self.reader.acquisition_rate is None:
+ logger.warning("No acquisition_rate available, pipeline not initialised")
+ self.pipeline = None
+ return
+
+ config = PipelineConfig(
+ resolution=self.reader.resolution,
+ voltage_ranges=self.reader.voltage_ranges,
+ offsets_v=self.reader.offsets_v,
+ max_adc_value=self.MAX_ADC_VALUE,
+ target_plot_points=None, # Lossless: show all samples
+ )
+ self.pipeline = DataPipeline(config, self.reader.acquisition_rate)
+ logger.debug("DataPipeline initialised for ZarrViewer")
+
+ def set_scale_factors(self, scale_factors: Dict[int, float]) -> None:
+ """Set explicit scale factors for channels."""
+ self.scale_factors = scale_factors.copy()
+ self._update_legend()
+ logger.info("Scale factors set: {}", scale_factors)
+
+ def _update_legend(self) -> None:
+ """Update the legend with current channel ranges and scale factors."""
+ self.legend.clear()
+ for ch in sorted(self.curves.keys()):
+ if ch in self.reader.channels:
+ range_v = self.reader.voltage_ranges.get(ch, 0.0)
+ scale = self._get_scale_factor(ch)
+ is_auto = ch not in self.scale_factors
+ ch_name = "A" if ch == 0 else "B"
+ if is_auto:
+ label = f"Ch {ch_name}: ±{range_v}V (auto×{scale:.2f})"
+ else:
+ label = f"Ch {ch_name}: ±{range_v}V (×{scale:.2f})"
+ self.legend.addItem(self.curves[ch], label)
+
+ def _get_scale_factor(self, ch: int) -> float:
+ """Get scale factor for a channel."""
+ return self.scale_factors.get(ch, 1.0)
+
+ def _update_scrollbar_range(self) -> None:
+ """Update scrollbar range based on file duration and window size."""
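+        # The scrollbar operates in integer milliseconds: a page step is one
+        # full window, a single step a tenth of a window.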
+ max_scroll = max(0, self.reader.duration_s - self.window_duration_s)
+
+ self.scrollbar.setRange(0, int(max_scroll * 1000))
+ self.scrollbar.setPageStep(int(self.window_duration_s * 1000))
+ self.scrollbar.setSingleStep(int(self.window_duration_s * 100))
+
+ def _on_window_changed(self, value: float) -> None:
+ """Handle window duration spinner change.
+
+ Parameters
+ ----------
+ value : float
+ New window duration in seconds.
+ """
+ self.window_duration_s = value
+
+ # Save to QSettings for persistence
+ settings = QSettings("lab", "picostream")
+ settings.setValue("zarr_viewer/window_duration", value)
+
+ self._update_scrollbar_range()
+ self._clamp_scroll_position()
+ self._refresh_plot()
+
+ def _on_scroll_changed(self, value: int) -> None:
+ """Handle scrollbar position change.
+
+ Parameters
+ ----------
+ value : int
+ Scroll position in milliseconds.
+ """
+ self.scroll_position_s = value / 1000.0
+ self._refresh_plot()
+
+ def _clamp_scroll_position(self) -> None:
+ """Ensure scroll position is within valid bounds."""
+ max_pos = max(0, self.reader.duration_s - self.window_duration_s)
+ self.scroll_position_s = max(0, min(self.scroll_position_s, max_pos))
+ self.scrollbar.setValue(int(self.scroll_position_s * 1000))
+
+ def _on_channel_toggled(self, channel: int, visible: bool) -> None:
+ """Toggle channel visibility.
+
+ Parameters
+ ----------
+ channel : int
+ Channel index.
+ visible : bool
+ True to show, False to hide.
+ """
+ if channel in self.curves:
+ self.curves[channel].setVisible(visible)
+ logger.debug("Channel {} visibility set to {}", channel, visible)
+
+ def _refresh_plot(self) -> None:
+ """Refresh the plot with current scroll position and window."""
+ start_s = self.scroll_position_s
+ end_s = min(start_s + self.window_duration_s, self.reader.duration_s)
+
+ if start_s >= end_s:
+ return
+
+ if self.pipeline is None:
+ logger.warning("Cannot refresh plot: pipeline not initialised")
+ return
+
+ try:
+ raw_data = self.reader.get_raw(start_s, end_s)
+
+ if raw_data.size == 0:
+ return
+
+ self.position_label.setText(f"{start_s:.2f}s")
+
+ # Lossless mode: decimation is always 1, no software min-max
+ decimation = 1
+
+ # Create time axis starting at scroll position (absolute time, not right-aligned)
+ # Unlike live plotter, Zarr viewer shows data at its actual time position.
+ # Note: We use reader.get_time_axis() instead of DataPipeline because:
+ # 1. The pipeline's get_window_display_geometry() is designed for live streaming
+ # with right-aligned relative windows, while the Zarr viewer needs absolute time.
+ # 2. The reader already has access to AcquisitionRate for consistent conversions.
+ time_axis = self.reader.get_time_axis(start_s, end_s)
+
+ # Handle AGGREGATE mode hardware min-max pairs only
+ is_aggregate = (
+ self.reader.acquisition_rate.downsample_mode.name == "AGGREGATE"
+ and self.reader.acquisition_rate.downsample_ratio > 1
+ )
+
+ if is_aggregate and len(time_axis) > 1:
+ # AGGREGATE mode: data already has min-max pairs interleaved
+ # Each pair represents one time point
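+                # e.g. a 6-sample axis [t0..t5] becomes [t0, t0, t2, t2, t4, t4]
+                # so each min/max pair shares one timestamp.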
+ n_pairs = len(time_axis) // 2
+ if n_pairs > 0:
+ pair_times = time_axis[::2][:n_pairs]
+ time_axis = np.repeat(pair_times, 2)
+
+ for i, ch in enumerate(self.reader.channels):
+ if i >= raw_data.shape[1]:
+ continue
+
+ ch_data = raw_data[:, i]
+
+ # Use DataPipeline for processing (consistent with dfplot)
+ voltage_data, has_data = self.pipeline.process_channel_data(
+ ch_data, ch, decimation
+ )
+
+ if not has_data:
+ continue
+
+ # Apply scale factor
+ voltage_data = voltage_data * self._get_scale_factor(ch)
+
+ # Create display time axis matching data length
+ n_points = len(voltage_data)
+ if n_points > 0 and len(time_axis) > 0:
+ display_time = time_axis[:n_points]
+ self.curves[ch].setData(display_time, voltage_data)
+
+ if len(time_axis) > 0:
+ self.viewBox.setXRange(time_axis[0], time_axis[-1], padding=0)
+
+ self.viewBox.enableAutoRange(axis="y")
+
+ except Exception as e:
+ logger.exception("Error refreshing plot: {}", e)
+
+ def keyPressEvent(self, event) -> None: # pyright: ignore
+ """Handle keyboard navigation.
+
+        Left/Right arrows scroll by 10% of the window, Page Up/Down by a
+        full window, and Home/End jump to the start/end of the file.
+ """
+ step = self.window_duration_s * 0.1
+
+ if event.key() == Qt.Key.Key_Left:
+ self.scroll_position_s = max(0, self.scroll_position_s - step)
+ self._clamp_scroll_position()
+ self._refresh_plot()
+ elif event.key() == Qt.Key.Key_Right:
+ max_pos = self.reader.duration_s - self.window_duration_s
+ self.scroll_position_s = min(max_pos, self.scroll_position_s + step)
+ self._clamp_scroll_position()
+ self._refresh_plot()
+ elif event.key() == Qt.Key.Key_Home:
+ self.scroll_position_s = 0
+ self._clamp_scroll_position()
+ self._refresh_plot()
+ elif event.key() == Qt.Key.Key_End:
+ self.scroll_position_s = self.reader.duration_s - self.window_duration_s
+ self._clamp_scroll_position()
+ self._refresh_plot()
+ elif event.key() == Qt.Key.Key_PageUp:
+ self.scroll_position_s = max(
+ 0, self.scroll_position_s - self.window_duration_s
+ )
+ self._clamp_scroll_position()
+ self._refresh_plot()
+ elif event.key() == Qt.Key.Key_PageDown:
+ max_pos = self.reader.duration_s - self.window_duration_s
+ self.scroll_position_s = min(
+ max_pos, self.scroll_position_s + self.window_duration_s
+ )
+ self._clamp_scroll_position()
+ self._refresh_plot()
+ else:
+ super().keyPressEvent(event)
+
+ def closeEvent(self, event: QCloseEvent) -> None: # pyright: ignore
+ """Handle window close event."""
+ if hasattr(self, "reader"):
+ logger.info("ZarrViewerWindow closing: {}", self.reader.path.name)
+ else:
+ logger.info("ZarrViewerWindow closing")
+ self._cleanup_resources()
+ event.accept()
+
+ def _cleanup_resources(self) -> None:
+ """Clean up resources including Zarr file handles.
+
+ Forces immediate release of file handles by explicitly clearing
+ references and invoking garbage collection.
+ """
+ try:
+ # Clear pipeline first
+ if hasattr(self, "pipeline") and self.pipeline is not None:
+ self.pipeline.invalidate_cache()
+ del self.pipeline
+
+ # Close the reader
+ if hasattr(self, "reader"):
+ self.reader.close()
+ del self.reader
+
+ # Force garbage collection to release file handles promptly
+ import gc
+
+ gc.collect()
+
+ logger.debug("ZarrViewerWindow resources cleaned up")
+ except Exception as e:
+ logger.exception("Error during resource cleanup: {}", e)
+
+
+def open_zarr_viewer_dialog(
+ parent: Optional[QWidget] = None,
+) -> Optional[ZarrViewerWindow]:
+ """Open a file dialog and create a ZarrViewerWindow.
+
+ Parameters
+ ----------
+ parent : Optional[QWidget]
+ Parent widget for the file dialog.
+
+ Returns
+ -------
+ Optional[ZarrViewerWindow]
+ The opened viewer window, or None if cancelled.
+ """
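+    # Zarr directory stores are folders on disk, so a directory picker
+    # (rather than a file picker) is the appropriate dialog here.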
+ path = QFileDialog.getExistingDirectory(
+ parent,
+ "Open Zarr File",
+ "",
+ QFileDialog.Option.ShowDirsOnly,
+ )
+
+ if not path:
+ return None
+
+ try:
+ return ZarrViewerWindow(path)
+ except Exception as e:
+ logger.exception("Failed to open Zarr file: {}", e)
+ return None
+
+
+def main() -> None:
+ """Standalone Zarr viewer entry point."""
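+    # Typical invocations (paths are examples):
+    #   python -m picostream.zarr_viewer capture.zarr   # open directly
+    #   python -m picostream.zarr_viewer                # open a file dialog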
+ app = QApplication(sys.argv)
+
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Standalone Zarr file viewer.")
+ parser.add_argument("path", nargs="?", help="Path to Zarr file to open")
+ args = parser.parse_args()
+
+ if args.path:
+ try:
+ viewer = ZarrViewerWindow(args.path)
+ viewer.show()
+ sys.exit(app.exec())
+ except Exception as e:
+ logger.exception("Failed to open Zarr file: {}", e)
+ sys.exit(1)
+ else:
+ viewer = open_zarr_viewer_dialog()
+ if viewer:
+ viewer.show()
+ sys.exit(app.exec())
+ else:
+ sys.exit(0)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/picostream/zarr_writer.py b/picostream/zarr_writer.py
new file mode 100644
index 0000000..bd7d092
--- /dev/null
+++ b/picostream/zarr_writer.py
@@ -0,0 +1,250 @@
+"""Zarr-based streaming data writer for PicoStream.
+
+This module provides a high-performance writer for saving streaming ADC data
+to Zarr format with configurable chunking and optional compression.
+"""
+
+from __future__ import annotations
+
+import os
+import shutil
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+import numpy as np
+import zarr
+from loguru import logger
+
+if TYPE_CHECKING:
+ from picostream.acquisition_rate import AcquisitionRate
+
+
+class ZarrStreamWriter:
+ """Streaming data writer using Zarr format.
+
+ This writer appends int16 ADC data to a Zarr directory store with
+ configurable chunk sizes. It batches small appends internally and
+ flushes on chunk boundaries to optimise disk I/O.
+
+ Parameters
+ ----------
+ path : str
+ Output path for the Zarr directory store.
+ acquisition_rate : AcquisitionRate
+ Acquisition rate object with all rate information.
+ num_channels : int
+ Number of channels in the data.
+ compression : Optional[str], default None
+ Compression algorithm to use. Options: None, "lz4".
+
+ Attributes
+ ----------
+ root : zarr.Group
+ The root Zarr group.
+ data : zarr.Array
+ The main data array containing int16 ADC samples.
+ written : int
+ Number of samples written so far.
+
+ Examples
+ --------
+ >>> writer = ZarrStreamWriter("output.zarr", acquisition_rate, 2)
+ >>> data = np.random.randint(-1000, 1000, size=(1000, 2), dtype=np.int16)
+ >>> writer.append(data)
+ >>> writer.close({"sample_rate_hz": 62500000.0})
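+
+    The writer also works as a context manager: a clean exit calls
+    ``close({})``, an exception calls ``discard()``. A minimal sketch,
+    reusing ``acquisition_rate`` and ``data`` from above:
+
+    >>> with ZarrStreamWriter("output2.zarr", acquisition_rate, 2) as w:
+    ...     w.append(data)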
+ """
+
+ def __init__(
+ self,
+ path: str,
+ acquisition_rate: AcquisitionRate,
+ num_channels: int,
+ compression: Optional[str] = None,
+ ) -> None:
+ """Initialise the Zarr stream writer."""
+ self.final_path = path
+ # Use .incomplete suffix during writing to identify incomplete files
+ self.path = path + ".incomplete"
+ self.acquisition_rate = acquisition_rate
+ self.num_channels = num_channels
+ self.compression = compression
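+        # Within this writer, ``compression`` is only recorded in the file
+        # metadata at close(); no explicit compressor is passed to the
+        # array-creation call below.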
+
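+        # Chunk holds ~0.2 s of samples at the storage rate, floored at
+        # 1000 samples; appends are buffered until a full chunk is pending.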
+ chunk_size = max(1000, int(acquisition_rate.storage_rate_hz * 0.2))
+ self._chunk_size = chunk_size
+
+ self.root = zarr.open(self.path, mode="w")
+
+ self.data = self.root.create_array(
+ "data",
+ shape=(0, num_channels),
+ chunks=(chunk_size, num_channels),
+ dtype=np.int16,
+ )
+
+ self.written = 0
+ self._pending: List[np.ndarray] = []
+ self._pending_samples: int = 0
+
+ logger.debug(
+ "ZarrStreamWriter initialised: path={}, chunk_size={}, "
+ "num_channels={}, compression={}",
+ path,
+ chunk_size,
+ num_channels,
+ compression,
+ )
+
+ def append(self, data: np.ndarray) -> None:
+ """Append data to the Zarr store.
+
+ Data is buffered internally and flushed to disk when the buffer
+ reaches the chunk size boundary.
+
+ Parameters
+ ----------
+ data : np.ndarray
+ Array of shape (n_samples, num_channels) with dtype int16.
+
+ Raises
+ ------
+ ValueError
+ If data shape or dtype is incorrect.
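+
+        Examples
+        --------
+        A minimal sketch, assuming the two-channel ``writer`` from the
+        class docstring:
+
+        >>> block = np.zeros((500, 2), dtype=np.int16)
+        >>> writer.append(block)  # buffered; written at the next chunk boundary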
+ """
+ if data.dtype != np.int16:
+ raise ValueError(f"Data must be int16, got {data.dtype}")
+
+ if data.ndim != 2 or data.shape[1] != self.num_channels:
+ raise ValueError(
+ f"Data shape must be (n_samples, {self.num_channels}), got {data.shape}"
+ )
+
+ n_samples = data.shape[0]
+ if n_samples == 0:
+ return
+
+ self._pending.append(data)
+ self._pending_samples += n_samples
+
+ if self._pending_samples >= self._chunk_size:
+ self._flush()
+
+ def _flush(self) -> None:
+ """Flush pending data to the Zarr array.
+
+ This method concatenates all pending data and writes it to the
+ Zarr array, resizing as needed. It then clears the pending buffer.
+ """
+ if not self._pending:
+ return
+
+ concatenated = np.concatenate(self._pending, axis=0)
+ n_samples = concatenated.shape[0]
+
+ new_size = self.written + n_samples
+ self.data.resize((new_size, self.num_channels))
+ self.data[self.written : new_size] = concatenated
+
+ self.written += n_samples
+ self._pending = []
+ self._pending_samples = 0
+
+ logger.debug(
+ "ZarrStreamWriter flushed {} samples, total written: {}",
+ n_samples,
+ self.written,
+ )
+
+ @property
+ def total_samples(self) -> int:
+ """Total samples including pending (not yet flushed) data.
+
+ Returns
+ -------
+ int
+ Total samples written to disk plus samples buffered in memory.
+ """
+ return self.written + self._pending_samples
+
+ def flush(self) -> None:
+ """Force-flush any pending data to disk.
+
+ This should be called before closing to ensure all data is written.
+ """
+ self._flush()
+ logger.debug("ZarrStreamWriter flush complete")
+
+ def close(self, metadata: Dict[str, Any]) -> None:
+ """Close the writer and finalise the Zarr file.
+
+ This flushes any pending data and writes metadata to the root
+ attributes. On successful close, the .incomplete suffix is removed
+ to mark the file as complete.
+
+ Parameters
+ ----------
+ metadata : Dict[str, Any]
+ Dictionary of metadata to store in the Zarr attributes.
+ Common keys: hardware_rate_hz, num_channels, downsample_ratio,
+ downsample_mode, pre_trigger_seconds, compression.
+
+ Notes
+ -----
+ Files with the .incomplete suffix were not successfully closed and
+ may contain partial or corrupted data. These can be safely deleted.
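+
+        A sketch of sweeping up leftovers, assuming captures are named
+        ``*.zarr`` under a ``data_dir`` directory:
+
+        >>> import glob, shutil
+        >>> for p in glob.glob("data_dir/*.zarr.incomplete"):
+        ...     shutil.rmtree(p)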
+ """
+ try:
+ self._flush()
+
+ metadata["total_samples"] = self.written
+ storage_rate = self.acquisition_rate.storage_rate_hz
+ if storage_rate > 0:
+ metadata["duration_s"] = self.written / storage_rate
+
+ metadata["format_version"] = "2.0"
+ metadata["compression"] = self.compression
+
+ if "start_time_iso" not in metadata:
+ metadata["start_time_iso"] = datetime.now(timezone.utc).isoformat()
+
+ for key, value in metadata.items():
+ self.root.attrs[key] = value
+
+ # Rename to final path to mark as complete
+ os.rename(self.path, self.final_path)
+
+ logger.info(
+ "ZarrStreamWriter closed: {} with {} samples",
+ self.final_path,
+ self.written,
+ )
+ except Exception as e:
+ logger.exception("Error closing ZarrStreamWriter: {}", e)
+ raise
+
+ def discard(self) -> None:
+ """Discard the Zarr file and close the writer.
+
+ This deletes the entire Zarr directory and should be called
+ when the user chooses to discard a capture.
+ """
+ try:
+ self._pending = []
+ self._pending_samples = 0
+
+ if os.path.exists(self.path):
+ shutil.rmtree(self.path)
+ logger.info("ZarrStreamWriter discarded: {}", self.path)
+ except Exception:
+ logger.exception("Error discarding Zarr file")
+ raise
+
+ def __enter__(self) -> ZarrStreamWriter:
+ """Context manager entry."""
+ return self
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ """Context manager exit."""
+ if exc_type is None:
+ self.close({})
+ else:
+ self.discard()
diff --git a/pyproject.toml b/pyproject.toml
index a835307..486d2f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,23 +1,73 @@
-[build-system]
-requires = ["setuptools"]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "picostream"
-version = "0.2.0"
-dependencies = [
- "numpy",
- "loguru",
- "picosdk",
- "h5py",
- "pyqtgraph",
- "numba",
- "click",
-]
-
-
-[tool.setuptools]
-packages = ["picostream"] # List the package names directly
-
-[project.scripts]
-picostream = "picostream.main:main"
+[project]
+name = "picostream"
+version = "1.0.0"
+description = "High-speed dual-channel data acquisition for PicoScope 5000a series"
+readme = "README.md"
+requires-python = ">=3.10"
+dependencies = [
+ "labdaemon>=1.0.0",
+ "numpy",
+ "numba",
+ "PyQt6>=6.0.0",
+ "pyqtgraph>=0.13.0",
+ "loguru>=0.6.0",
+ "zarr>=2.16.0",
+ "vispy>=0.14.0",
+ "picosdk",
+]
+
+[dependency-groups]
+dev = [
+ "pyinstaller",
+ "pytest>=7.0.0",
+ "pytest-cov",
+ "ruff>=0.1.0",
+ "pyright>=1.1.0",
+ "snakeviz>=2.2.0",
+]
+
+[tool.pyright]
+# Type checking rules to disable for FFI libraries and external tools
+# reportAttributeAccessIssue: ctypes/FFI library attributes (picosdk SDK)
+# reportOptionalSubscript: ctypes functions that can return None
+# reportOptionalMemberAccess: FFI methods with optional returns
+# reportCallIssue: calling possibly-None FFI function results
+# reportOptionalOperand: arithmetic on possibly-None FFI results
+reportAttributeAccessIssue = false
+reportOptionalSubscript = false
+reportOptionalMemberAccess = false
+reportCallIssue = false
+reportOptionalOperand = false
+
+exclude = [
+ "picostream/test_buffered_stream.py",
+ "picostream/test_live_plotter.py",
+ "picostream/test_zarr_writer.py",
+ "picostream/tests/",
+ "picostream/v0_backup/",
+]
+
+[tool.coverage.run]
+source = ["picostream"]
+omit = [
+ "picostream/tests/",
+ "picostream/v0_backup/**",
+]
+
+[tool.coverage.report]
+exclude_lines = [
+ "pragma: no cover",
+ "def __repr__",
+ "if TYPE_CHECKING:",
+ "raise NotImplementedError",
+ "if __name__ == .__main__.:",
+]
+
+[tool.setuptools]
+packages = ["picostream"]
+
+[project.scripts]
+picostream = "picostream.main:main"
+
+[project.urls]
+Repository = "https://github.com/yourusername/picostream"