From d78b82c9e3b5ff2f24b3bb9a6805637061cfaf4c Mon Sep 17 00:00:00 2001 From: Tom Brennan Date: Fri, 3 Oct 2025 08:58:45 -0400 Subject: [PATCH] bugfixes, etc --- src/pojagi_dsp/channel/__init__.py | 338 +--------------- src/pojagi_dsp/channel/ecg/__init__.py | 17 +- .../ecg/generator/wavetable/__init__.py | 283 +------------ .../channel/ecg/generator/wavetable/sinus.py | 378 ++++++++++-------- .../ecg/generator/wavetable/synthesizer.py | 165 ++++++++ .../ecg/generator/wavetable/wavetable.py | 182 +++++++++ src/pojagi_dsp/channel/generator/sine.py | 11 +- src/pojagi_dsp/channel/signal.py | 361 +++++++++++++++++ 8 files changed, 928 insertions(+), 807 deletions(-) create mode 100644 src/pojagi_dsp/channel/ecg/generator/wavetable/synthesizer.py create mode 100644 src/pojagi_dsp/channel/ecg/generator/wavetable/wavetable.py create mode 100644 src/pojagi_dsp/channel/signal.py diff --git a/src/pojagi_dsp/channel/__init__.py b/src/pojagi_dsp/channel/__init__.py index 11ec81b..ccaa1fb 100644 --- a/src/pojagi_dsp/channel/__init__.py +++ b/src/pojagi_dsp/channel/__init__.py @@ -1,337 +1 @@ -import abc -import copy -import datetime -import inspect -from itertools import islice -import logging -import math -import operator -from collections.abc import Iterable -from functools import reduce -import types -from typing import (Any, Callable, Generator, Generic, Iterator, Optional, Type, TypeVar, - Union) - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -class IllegalStateException(ValueError): - ... - - -def coerce_channels(x: Any) -> Iterator["ASignal"]: - if isinstance(x, ASignal): - yield x - else: - if callable(x): - if isinstance(x, Type): - yield x() - else: - yield SignalFunction(x) - elif isinstance(x, Iterable): # and not isinstance(x, str): - for it in (coerce_channels(y) for y in x): - for channel in it: - yield channel - else: - yield Constantly(x) - - -class ASignalMeta(abc.ABCMeta): - def __or__(self, other: Any) -> "Filter": - """ - Allows `|` composition starting from an uninitialized class. - See doc for `__or__` below in `ASignal`. - """ - return self() | coerce_channels(other) - - def __radd__(self, other): return self() + other - def __add__(self, other): return self() + other - def __rmul__(self, other): return self() * other - def __mul__(self, other): return self() * other - - -class ASignal(Generic[T], metaclass=ASignalMeta): - def __init__(self, srate: Optional[float] = None): - self._srate = srate - self._cursor: Optional[Iterator[T]] = None - - @property - def srate(self): - if self._srate is None: - raise IllegalStateException( - f"{self.__class__}: `srate` is None." - ) - return self._srate - - @srate.setter - def srate(self, val: float): self._srate = val - - def __iter__(self): - self._cursor = self.samples() - return self - - def __next__(self): return next(self.cursor) - - @abc.abstractmethod - def samples(self) -> Iterator[T]: ... - - @property - def cursor(self): - """ - An `Iterator` representing the current pipeline in progress. - """ - if self._cursor is None: - # this can only happen once - self._cursor = self.samples() - return self._cursor - - def __getstate__(self): - """ - `_cursor` is a generator, and generators aren't picklable. - """ - state = self.__dict__.copy() - if state.get("_cursor"): - del state["_cursor"] - return state - - def stream(self): - while True: - try: - yield next(self.cursor) - except StopIteration: - self = iter(self) - - def of_duration(self, duration: datetime.timedelta): - """ - Returns an `Iterator` of samples for a particular duration expressed - as a `datetime.timedelta` - :param:`duration` - `datetime.timedelta` representing the duration - """ - return islice( - self.stream(), - 0, - math.floor(self.srate * duration.total_seconds()), - ) - - def __or__( - left, - right: Union["Filter", Callable, Iterable], - ) -> "Filter": - """ - Allows composition of filter pipelines with `|` operator. - - e.g., - ``` - myFooGenerator - | BarFilter - | baz_filter_func - | (lambda reader: (x for x in reader)) - ``` - """ - if isinstance(right, SignalFunction): - return left | FilterFunction(fn=right._fn, name=right.Function) - - if not isinstance(right, ASignal): - return reduce(operator.or_, (left, *coerce_channels(right))) - - if not isinstance(right, Filter): - raise ValueError( - f"Right side must be a `{Filter.__name__}`; " - f"received: {type(right)}", - ) - - filter: Filter = right - while getattr(filter, "_reader", None) is not None: - # Assuming this is a filter pipeline, we want the last node's - # reader to be whatever's on the left side of this operation. - filter = filter.reader - - if hasattr(filter, "_reader"): - # We hit the "bottom" and found a filter. - filter.reader = left - else: - # We hit the "bottom" and found a non-filter/generator. - raise ValueError( - f"{right.__class__.__name__}: filter pipeline already has a " - "generator." - ) - - # Will often be `None` unless `left` is a generator. - right.srate = left._srate - - return right - - def __radd__(right, left): return right.__add__(left) - def __add__(left, right): return left._operator_impl(operator.add, right) - def __rmul__(right, left): return right.__mul__(left) - def __mul__(left, right): return left._operator_impl(operator.mul, right) - # FIXME: other operators? Also, shouldn't `*` mean convolve instead? - - def _operator_impl(left, operator: Callable[..., T], right: Any): - channels = list(coerce_channels(right)) - for channel in channels: - if channel._srate is None: - channel.srate = left._srate - return Reduce(operator, left, *channels, srate=left._srate) - - def __repr__(self): - members = {} - for k in [k for k in dir(self) - if not k.startswith("_") - and not k in {"stream", "reader", "cursor", "wave", }]: - try: - v = getattr(self, k) - if not inspect.isroutine(v): - members[k] = v - except IllegalStateException as e: - members[k] = None - - return ( - f"{self.__class__.__name__}" - f"""({ - f", ".join( - f"{k}={v}" - for k, v in members.items() - ) - })""" - ) - - -S = TypeVar("S", bound=ASignal) - - -class Reduce(ASignal, Generic[S, T]): - def __init__( - self, - # FIXME: typing https://stackoverflow.com/a/67814270 - fn: Callable[..., T], - *streams: S, - srate: Optional[float] = None, - stateful=False, - ): - super().__init__(srate) - self._fn = fn - self.fn = fn.__name__ - self.streams = [] - for stream in streams: - if stateful: - self.streams.append(stream) - continue - - stream_ = ( - copy.deepcopy(stream) - if not isinstance(stream, types.GeneratorType) - else stream - ) - stream_.srate = srate - self.streams.append(stream_) - - @property - def srate(self): return ASignal.srate.fget(self) - - @srate.setter - def srate(self, val: float): - ASignal.srate.fset(self, val) - for stream in self.streams: - if isinstance(stream, ASignal): - stream.srate = val - - def samples(self): return ( - reduce(self._fn, args) - for args in zip(*self.streams) - ) - - -class Filter(ASignal, Generic[S]): - - def __init__( - self, - reader: Optional[S] = None, - srate: Optional[float] = None, - ): - super().__init__(srate) - self.reader: Optional[S] = reader - - @property - def reader(self) -> S: - """ - The input stream this filter reads. - """ - if not self._reader: - raise IllegalStateException( - f"{self.__class__}: `reader` is None." - ) - return self._reader - - @reader.setter - def reader(self, val: S): - self._reader = val - if val is not None and self._srate is None: - self.srate = val._srate - - @property - def srate(self): return ASignal.srate.fget(self) - - @srate.setter - def srate(self, val: float): - ASignal.srate.fset(self, val) - child = getattr(self, "_reader", None) - previous_srate = val - while child is not None: - # Since `srate` is optional at initialization, but required in - # general, we make our best attempt to normalize it for the - # filter pipeline, which should be consistent for most - # applications, by applying it to all children. - if child._srate is None: - child.srate = previous_srate - child: Optional[ASignal] = getattr(child, "_reader", None) - if isinstance(child, ASignal) and child._srate is not None: - previous_srate = child._srate - - def samples(self) -> Iterator[T]: return self.reader.samples() - - def __repr__(self): - return ( - f"{self._reader} | {super().__repr__()}" - ) - - -class FilterFunction(Filter, Generic[T, S]): - def __init__( - self, - fn: Callable[[S], Iterator[T]], - name: Optional[str] = None, - reader: Optional[S] = None, - srate: Optional[float] = None, - ): - super().__init__(reader, srate) - self._fn = fn - self.Function = name if name else fn.__name__ - - def samples(self): return self._fn(self.reader) - - -class SignalFunction(ASignal, Generic[T]): - def __init__( - self, - fn: Callable[[int], Iterator[T]], - name: Optional[str] = None, - srate: Optional[float] = None, - ): - super().__init__(srate) - self._fn = fn - self.Function = name if name else fn.__name__ - - def samples(self) -> Iterator[T]: return self._fn(self.srate) - - -class Constantly(ASignal, Generic[T]): - def __init__(self, constant: T, srate: float = 0.0): - super().__init__(srate) - self.constant = constant - - def samples(self) -> Iterator[T]: - while True: - yield self.constant +from .signal import * diff --git a/src/pojagi_dsp/channel/ecg/__init__.py b/src/pojagi_dsp/channel/ecg/__init__.py index 034c1f6..1a43f40 100644 --- a/src/pojagi_dsp/channel/ecg/__init__.py +++ b/src/pojagi_dsp/channel/ecg/__init__.py @@ -68,13 +68,13 @@ class Segments: class AECGChannel(ASignal[Number]): @property - @abc.abstractproperty + @abc.abstractmethod def heart_rate(self) -> float: """Frequency of impulses/waves in bpm.""" ... @property - @abc.abstractproperty + @abc.abstractmethod def wavelength(self) -> int: """ The number of samples in a complete impulse/wave cycle. @@ -83,16 +83,3 @@ class AECGChannel(ASignal[Number]): heart rate converted from bpm to Hz. """ ... - - @property - @abc.abstractproperty - def segments(self) -> Segments: - """The analytical segments of the impulse/wave.""" - ... - - @property - def wave(self) -> Iterator[Number]: - """ - Returns an iterator over a single ECG impulse/wave. - """ - return itertools.islice(self, 0, self.wavelength) diff --git a/src/pojagi_dsp/channel/ecg/generator/wavetable/__init__.py b/src/pojagi_dsp/channel/ecg/generator/wavetable/__init__.py index 7f79d87..395281f 100644 --- a/src/pojagi_dsp/channel/ecg/generator/wavetable/__init__.py +++ b/src/pojagi_dsp/channel/ecg/generator/wavetable/__init__.py @@ -1,279 +1,4 @@ -import dataclasses -import logging -import math -from numbers import Number -from typing import Dict, List, Optional, Tuple - -import numpy as np -from scipy.interpolate import CubicSpline - -from pojagi_dsp.channel.ecg import Segments -from pojagi_dsp.channel.ecg.generator import AECGSynthesizer - -logger = logging.getLogger(__name__) - - -class ECGWaveTable: - """ - This type of wavetable is designed around the P and R. By - convention, R will always be equal to 1, and the baseline (P) will always - be 0. (That doesn't mean, however, that the other values can't cross these - boundaries. E.g., Q and S are often negative.) - """ - - def __init__( - self, - data: List[Number], - segments: Segments, - bottom: Optional[Number] = None, - top: Optional[Number] = None, - table_length: int = 1 << 11, # 2048 - ): - """ - The table size is increased to `table_length` upon initialization - using linear interpolation (via `numpy.interp`). - """ - - if len(data) == table_length: - self.data = np.array(data) - else: - # We generate a larger table for use with linear interpolation, - # trading time (CPU) for memory (table size). - # - # Here we use cubic spline interpolation instead of linear for - # wavetable construction, since it usually only happens once at - # startup, and should provide a much better quality table from - # limited data, making it possible to work with small, manually - # composed tables that we JIT convert to the larger table. - ### FIXME: use the sinc function instead: - ### f(x) = sin(x)/x where x =/= 0 and f(x) = 1 if x = 0 - ### you have to apply this scaled to each sample in the table and - ### then add all of the resulting signals together. - ### I think this is the same as summing the dft of each impulse - ### as if the impulse is a member of a larger table. - cs = CubicSpline( - range(len(data)), - data, - bc_type="natural", - ) - self.data = np.array( - [cs(x) for x in np.linspace(0, len(data), table_length)] - ) - - # Scale the declared segments to the table_length - self.segments = Segments( - **{ - k: int(v / (len(data) / table_length)) if v else v - for k, v in ( - (f.name, getattr(segments, f.name)) - for f in dataclasses.fields(segments) - ) - }, - ) - - # NOTE: these are not the data min/max, but the normal min/max, either - # provided as kwargs, or derived from P and R segment starts, by - # convention. - bottom = bottom if bottom is not None else data[segments.P] - top = top if top is not None else data[segments.R] - - if not (0 == bottom and 1 == top): - # Normalize between 0 and 1: - self.data = (self.data - bottom) / (top - bottom) - - def __getitem__(self, k): - return self.data[k] - - def __len__(self): - return len(self.data) # O(1) - - def linear_interpolation( - self, - index: float, - floor: Optional[int] = None, - ceiling: Optional[int] = None, - ) -> float: - """ - Handles the situation where the floor would produce duplicate values, - which makes the waveform chunky with aliasing; instead, we obtain a - value weighted between the floor/ceiling, trading time (CPU) for - memory (table size). - """ - dl = len(self.data) - floor = floor if floor is not None else math.floor(index) % dl - ceiling = ceiling if ceiling is not None else (floor + 1) % dl - - # e.g., a. 124.75 - 124 == 0.75 - # b. 123 - 123 == 0 (no weight goes to ceiling) - ceiling_weight = index - floor - # e.g., a. 1 - 0.75 == 0.25 - # b. 1 - 0 == 1 (all weight goes to floor) - floor_weight = 1 - ceiling_weight - - return self[floor] * floor_weight + self[ceiling] * ceiling_weight - - def merge( - self, - other: "ECGWaveTable", - weight: float, - ): - self_weight = 1 - weight - return ECGWaveTable( - data=(self.data * self_weight + other.data * weight), - segments=self.segments.merge(other.segments, weight), - top=1, - bottom=0, - ) - - -class ECGWaveTableSynthesizer(AECGSynthesizer): - def __init__( - self, - /, - tables: Dict[Tuple[float, float], ECGWaveTable], - heart_rate: int, - srate: Optional[float] = None, - ): - super().__init__(heart_rate, srate) - self.inc: float = 0.0 - self.tables = tables - self._segments: Segments = None - self.phase = 0.0 - self.q_lock = False - - def samples(self): - # phase: float = 0.0 - inc: float = None - idx: int = 0 - heart_rate = self.heart_rate - - self._calibrate() - - while idx < self.wavelength: - phase = self.phase - floor = math.floor(phase) - - yield self.table.linear_interpolation(phase, floor=floor) - - if ( - heart_rate < 60 - and self.table.segments.T_P <= floor < self.table.segments.P - ): - inc = self.brady_inc - logger.info(["brady", inc, self.inc]) - self.q_lock = True - elif ( - heart_rate > 60 - and self.table.segments.S_T <= floor < self.table.segments.Q - ): - # FIXME: this is probably only good below a certain - # `heart_rate` threshold, because at some high frequency, even - # the QRS complex will not have enough room to complete. - inc = self.tachy_inc - self.q_lock = True - else: - inc = None - if self.q_lock: - phase = self.table.segments.Q - self.q_lock = False - logger.info(f"\n{[ - self.table.linear_interpolation(phase, floor=floor), - inc, - self.inc, - self.table.segments.Q, - phase, - self.table.segments.S_T - ]}") - - phase += inc if inc is not None else self.inc - phase %= len(self.table) - self.phase = phase - idx += 1 - - @AECGSynthesizer.heart_rate.setter - def heart_rate(self, val): - AECGSynthesizer.heart_rate.fset(self, val) - - def _calibrate(self): - heart_rate = self.heart_rate - - table_matches = { - k: v for k, v in self.tables.items() if k[0] <= heart_rate < k[1] - } - - if not table_matches: - raise ValueError( - f"No table found corresponding to heart rate: {heart_rate}." - ) - - # Since we may have more than two tables that match, we loop - # through all the matches, applying them in key order. - keys = iter(sorted(table_matches)) - key = next(keys) - table = table_matches[key] - - for next_key in keys: - next_table = table_matches[next_key] - - if next_key[1] < key[1]: - # `next_key` is fully contained within `key` - floor, ceiling = next_key - next_weight = (heart_rate - floor) / (ceiling - floor) - weight = 1 - next_weight - - if (heart_rate - floor) > ((ceiling - floor) / 2): - # Weights form an "X" shape; i.e., crossfade to 50% - # and back. - weight, next_weight = next_weight, weight - else: - floor = next_key[0] # i.e., the bottom of the top - ceiling = key[1] # i.e., the top of the bottom - next_weight = (heart_rate - floor) / (ceiling - floor) - - table = table.merge(next_table, next_weight) - key = next_key - - self.table = table - - # ECG Tables are designed for 1Hz, and as a default, we don't want to - # stretch anything; hence, no reference to `self.heart_rate` here, - # instead constant 60: - self.inc = len(self.table) / (self.srate * (60 / 60)) - - # Stretch only the T_P segment to compensate, rather than - # stretching the whole wave. - table_segment_length = self.table.segments.P - self.table.segments.T_P - self.brady_inc = self.stretch_inc(table_segment_length) - - # Preserve QRS-J-point; compress Jp-Q to compensate. - table_segment_length = self.table.segments.Q - self.table.segments.S_T - self.tachy_inc = self.stretch_inc(table_segment_length) - - def stretch_inc(self, table_segment_length): - # Get the missing samples by subtracting the number of samples - # contributed by the 1Hz table default, minus the segment we - # want to stretch. - tmp_wavelength = ( - self.wavelength - (len(self.table) - table_segment_length) / self.inc - ) - - return table_segment_length / tmp_wavelength - - @property - def segments(self): - if self._segments: - return self._segments - - table_length = len(self.table) - table_segments = self.table.segments - # FIXME: this is a lie since we stretch T_P, etc. - self._segments = Segments( - **{ - k: math.floor(v * self.srate / table_length) if v else v - for k, v in [ - (f.name, getattr(table_segments, f.name)) - for f in dataclasses.fields(table_segments) - ] - } - ) - return self._segments +from pojagi_dsp.channel.ecg.generator.wavetable.synthesizer import ( + ECGWaveTableSynthesizer, +) +from pojagi_dsp.channel.ecg.generator.wavetable.wavetable import ECGWaveTable diff --git a/src/pojagi_dsp/channel/ecg/generator/wavetable/sinus.py b/src/pojagi_dsp/channel/ecg/generator/wavetable/sinus.py index 2d4582c..3182c73 100644 --- a/src/pojagi_dsp/channel/ecg/generator/wavetable/sinus.py +++ b/src/pojagi_dsp/channel/ecg/generator/wavetable/sinus.py @@ -1,156 +1,165 @@ import math import numpy as np -from numbers import Number -import random -from typing import List from pojagi_dsp.channel.ecg import Segments from pojagi_dsp.channel.ecg.generator.wavetable import ECGWaveTable # NOTE: larger table required for avoiding aliasing at different srates than 125Hz -sinus_data = np.array([ - # R-S: 0 - 2000, - 1822, - 374, - # S-Jp: 3 - -474, - -271, - -28, - 18, - 66, - # Jp-T: 9 - 63, - 73, - 91, - 101, - 101, - 101, - 116, - 124, - 124, - # T: 17 - 141, - 171, - 186, - 196, - 229, - 265, - 297, - 327, - 363, - 406, - 446, - 475, - 493, - 508, - 526, - 533, - 518, - 475, - 403, - 327, - 272, - 222, - 174, - 138, - 109, - 88, - 73, - 66, - 69, - 69, - 66, - 73, - 81, - 76, - 73, - 76, - 76, - 66, - 58, - 58, - 63, - 63, - 41, - 26, - 26, - 18, - 8, - 8, - 8, - # U: 66 -- not found - # T-P: 66 - 2, - 3, - 2, - 2, - 2, - -1, - 2, - 2, - 2, - -1, - 0, - -1, - -1, - 3, - 2, - 1, - 3, - 2, - 1, - 0, - 1, - # P: 87 - 0, - 3, - 11, - 11, - 0, - 8, - 18, - 18, - 18, - 15, - 8, - 18, - 26, - 26, - 26, - 8, - 32, - 61, - 116, - 164, - 182, - 159, - 131, - 116, - 116, - 109, - 91, - 73, - 58, - 55, - 58, - 63, - 69, - # P-R: 120 - 48, - -14, - # Q-R: 122 - -40, - 131, - 931, -]) # len == 125 +sinus_data = np.array( + [ + # R-S: 0 + 2000, + 1822, + 374, + # S-Jp: 3 + -474, + -271, + -28, + 18, + 66, + # Jp-T: 9 + 63, + 73, + 91, + 101, + 101, + 101, + 116, + 124, + 124, + # T: 17 + 141, + 171, + 186, + 196, + 229, + 265, + 297, + 327, + 363, + 406, + 446, + 475, + 493, + 508, + 526, + 533, + 518, + 475, + 403, + 327, + 272, + 222, + 174, + 138, + 109, + 88, + 73, + 66, + 69, + 69, + 66, + 73, + 81, + 76, + 73, + 76, + 76, + 66, + 58, + 58, + 63, + 63, + 41, + 26, + 26, + 18, + 8, + 8, + 8, + # U: 66 -- not found + # T-P: 66 + 2, + 3, + 2, + 2, + 2, + -1, + 2, + 2, + 2, + -1, + 0, + -1, + -1, + 3, + 2, + 1, + 3, + 2, + 1, + 0, + 1, + # P: 87 + 0, + 3, + 11, + 11, + 0, + 8, + 18, + 18, + 18, + 15, + 8, + 18, + 26, + 26, + 26, + 8, + 32, + 61, + 116, + 164, + 182, + 159, + 131, + 116, + 116, + 109, + 91, + 73, + 58, + 55, + 58, + 63, + 69, + # P-R: 120 + 48, + -14, + # Q-R: 122 + -40, + 131, + 931, + ] +) # len == 125 -pr_idx = 120 -t_idx = 18 -t_pr_idx_diff = pr_idx - t_idx -t_pr = np.arange(-math.floor(t_pr_idx_diff / 2), math.ceil(t_pr_idx_diff / 2)) -t_pr_curve: np.ndarray = t_pr**2 * -0.2 -t_pr_curve = (t_pr_curve - t_pr_curve[0]) + 141 +def parabolic_curve(t_idx, pr_idx): + """Calculates a smooth, physiologically mimetic curve for the + T-P-R segment of the ECG waveform. + """ + # Compute the difference in indices to determine the segment length + t_pr_idx_diff = pr_idx - t_idx + + # Generate a symmetric range of values centered around zero for the segment + t_pr = np.arange(-math.floor(t_pr_idx_diff / 2), math.ceil(t_pr_idx_diff / 2)) + + # Apply a parabolic transformation to create a smooth transition + t_pr_curve: np.ndarray = t_pr**2 * -0.25 + + # Normalize the curve so it starts at the T wave amplitude (141) + return t_pr_curve - t_pr_curve[0] + tachycardia = np.array( [ # 58-107 flat @@ -174,8 +183,9 @@ tachycardia = np.array( 116, 124, 124, - *t_pr_curve, - # P-R: 119 + # T: 17 + *parabolic_curve(17, 119) + 141, + # P-R: 120 124, 48, -14, @@ -188,43 +198,67 @@ tachycardia = np.array( def SinusWaveTable(): + segments = Segments( + S=3, + S_T=9, + T=17, + T_P=66, + P=87, + P_R=120, + Q=122, + ) + return ECGWaveTable( data=sinus_data, - segments=Segments( - S=3, - S_T=9, - T=17, - T_P=66, - P=87, - P_R=120, - Q=122, - ), + segments=segments, ) def TachycardiaWaveTable(): + segments = Segments( + S=3, + S_T=8, + T=17, + T_P=66, + P=87, + P_R=119, + Q=122, + ) + return ECGWaveTable( data=tachycardia, - segments=Segments( - S=3, - S_T=8, - T=17, - T_P=66, - P=87, - P_R=119, - Q=122, - ), - # Tachy is weaker than sinus, so we inflate the range here, which - # effectively attenuates the signal by 1/3 (i.e., it is 2/3 the - # original). + segments=segments, + # Tachy is weaker than sinus, so we inflate the range here by 3/2, + # which effectively attenuates the signal by 1/3 (i.e., it is 2/3 of + # the amplitude of the data definition). top=2000 * (3 / 2), bottom=0, ) -impulse_data = np.array([1, *([0] * 124)]) + +def FastTachycardiaWaveTable(): + segments = Segments( + S=3, + S_T=8, + T=17, + T_P=66, + P=87, + P_R=119, + Q=122, + ) + + return ECGWaveTable( + data=np.arange(-50, 51) ** 11 / 1e19, + segments=segments, + tachy_compress=("R", "R"), + top=2, + bottom=0, + ) + + def ImpulseWaveTable(): return ECGWaveTable( - data=impulse_data, + data=np.array([1, *([0] * 124)]), segments=Segments( S=3, S_T=9, @@ -241,8 +275,8 @@ def ImpulseWaveTable(): if __name__ == "__main__": from matplotlib import pyplot as plt - - samples = np.tile(SinusWaveTable().data, 3) + + samples = np.tile(FastTachycardiaWaveTable().data, 3) plt.plot(range(len(samples)), samples) plt.show() diff --git a/src/pojagi_dsp/channel/ecg/generator/wavetable/synthesizer.py b/src/pojagi_dsp/channel/ecg/generator/wavetable/synthesizer.py new file mode 100644 index 0000000..5d8feb3 --- /dev/null +++ b/src/pojagi_dsp/channel/ecg/generator/wavetable/synthesizer.py @@ -0,0 +1,165 @@ +import numpy as np +from pojagi_dsp.channel.ecg.generator import AECGSynthesizer +from pojagi_dsp.channel.ecg.generator.wavetable.wavetable import ( + ECGWaveTable, +) + + +class ECGWaveTableSynthesizer(AECGSynthesizer): + def __init__( + self, + /, + tables: dict[tuple[float, float], ECGWaveTable], + heart_rate: int, + srate: float | None = None, + ): + super().__init__(heart_rate, srate) + self.inc: float = 0.0 + self.tables = tables + self.table: ECGWaveTable | None = None + self.phase: float = 0.0 + + self.brady_start: int = 0 + self.brady_end: int = 0 + self.brady_inc: float = 0.0 + + self.tachy_start: int = 0 + self.tachy_end: int = 0 + self.tachy_inc: float = 0.0 + + def samples(self): + inc: float = None + idx: int = 0 + heart_rate = self.heart_rate + + self._calibrate() + + print("inside samples", hex(id(self))) + + while idx < self.wavelength: + phase = self.phase + floor = np.floor(phase) + + yield self.table.linear_interpolation( + phase, floor=floor + ) + + if heart_rate < 60 and ( + self.brady_start <= phase < self.brady_end + ): + inc = self.brady_inc + if phase + inc > self.brady_end: + inc = self.brady_end - phase + elif heart_rate > 60 and ( + self.tachy_start <= phase < self.tachy_end + ): + # FIXME: this might only good below a certain `heart_rate` + # threshold, because at some high frequency, even the QRS + # complex will not have enough room to complete. + inc = self.tachy_inc + if phase + inc > self.tachy_end: + inc = self.tachy_end - phase + else: + inc = None + + phase += inc if inc is not None else self.inc + if phase > len(self.table): + phase = 0.0 + + self.phase = phase + idx += 1 + + @AECGSynthesizer.heart_rate.setter + def heart_rate(self, val): + AECGSynthesizer.heart_rate.fset(self, val) + print("setter", hex(id(self))) + + def _calibrate(self): + heart_rate = self.heart_rate + + table_matches = { + k: v + for k, v in self.tables.items() + if k[0] <= heart_rate < k[1] + } + + if not table_matches: + raise ValueError( + f"No table found corresponding to heart rate: {heart_rate}." + ) + + # Since we may have more than two tables that match, we loop + # through all the matches, applying them in key order. + keys = iter(sorted(table_matches)) + key = next(keys) + table = table_matches[key] + + for next_key in keys: + next_table = table_matches[next_key] + + if next_key[1] < key[1]: + # `next_key` is fully contained within `key` + floor, ceiling = next_key + next_weight = (heart_rate - floor) / ( + ceiling - floor + ) + weight = 1 - next_weight + + if (heart_rate - floor) > ( + (ceiling - floor) / 2 + ): + # Weights form an "X" shape; i.e., crossfade to 50% + # and back. + weight, next_weight = next_weight, weight + else: + floor = next_key[ + 0 + ] # i.e., the bottom of the top + ceiling = key[1] # i.e., the top of the bottom + next_weight = (heart_rate - floor) / ( + ceiling - floor + ) + + table = table.merge(next_table, next_weight) + key = next_key + + self.table = table + + # ECG Tables are designed for 1Hz, and as a default, we don't want to + # stretch anything; hence, no reference to `self.heart_rate` here, + # instead constant 60: + self.inc = len(self.table) / (self.srate * (60 / 60)) + + self.brady_start, self.brady_end = ( + getattr(self.table.segments, x) + for x in self.table.brady_stretch + ) + if self.table.brady_stretch == "R": + self.brady_end == len(self.table) - 1 + + # Stretch only the T_P segment to compensate, rather than + # stretching the whole wave. + table_segment_length = self.brady_end - self.brady_start + self.brady_inc = self.stretch_inc(table_segment_length) + + self.tachy_start, self.tachy_end = ( + getattr(self.table.segments, x) + for x in self.table.tachy_compress + ) + if self.table.tachy_compress[1] == "R": + self.tachy_end = len(self.table) - 1 + + # Preserve QRS-J-point; compress Jp-Q to compensate. + table_segment_length = self.tachy_end - self.tachy_start + self.tachy_inc = self.stretch_inc(table_segment_length) + + def stretch_inc(self, table_segment_length: int) -> float: + # Get the missing samples by subtracting the number of samples + # contributed by the 1Hz table default, minus the segment we + # want to stretch. + tmp_wavelength = ( + self.wavelength + - (len(self.table) - table_segment_length) / self.inc + ) + + return table_segment_length / tmp_wavelength diff --git a/src/pojagi_dsp/channel/ecg/generator/wavetable/wavetable.py b/src/pojagi_dsp/channel/ecg/generator/wavetable/wavetable.py new file mode 100644 index 0000000..2bfb023 --- /dev/null +++ b/src/pojagi_dsp/channel/ecg/generator/wavetable/wavetable.py @@ -0,0 +1,182 @@ +import dataclasses +from numbers import Number + +import numpy as np +from scipy.interpolate import CubicSpline + +from pojagi_dsp.channel.ecg import Segments + + +class ECGWaveTable: + """ + This type of wavetable is designed around the P and R. By + convention, R will always be equal to 1, and the baseline (P) will always + be 0. (That doesn't mean, however, that the other values can't cross these + boundaries. E.g., Q and S are often negative.) + """ + + def __init__( + self, + /, + data: list[Number], + segments: Segments, + brady_stretch: tuple[str, str] | None = None, + tachy_compress: tuple[str, str] | None = None, + bottom: Number | None = None, + top: Number | None = None, + table_length: int = (1 << 10) * 2, + ): + """ + Initialize an ECGWaveTable. + + Args: + data (List[Number]): The raw waveform data points for one cardiac cycle. + segments (Segments): Segment indices (e.g., P, Q, R, S, T) marking key features in the waveform. + brady_stretch (tuple[str, str] | None): Segment interval to stretch for bradycardia (slow heart rate). + tachy_compress (tuple[str, str] | None): Segment interval to compress for tachycardia (fast heart rate). + bottom (Optional[Number]): Value to use as the baseline (P segment) for normalization. If None, uses data[segments.P]. + top (Optional[Number]): Value to use as the peak (R segment) for normalization. If None, uses data[segments.R]. + table_length (int): Number of samples in the expanded wavetable (default: 2048). + """ + + if len(data) == table_length: + self.data = np.array(data) + else: + # We generate a larger table for use with linear interpolation, + # trading time (CPU) for memory (table size). + # + # Here we use cubic spline interpolation instead of linear for + # wavetable construction, since it usually only happens once at + # startup, and should provide a much better quality table from + # limited data, making it possible to work with small, manually + # composed tables that we JIT convert to the larger table. + ### FIXME: use the sinc function instead: + ### f(x) = sin(x)/x where x =/= 0 and f(x) = 1 if x = 0 + ### you have to apply this scaled to each sample in the table and + ### then add all of the resulting signals together. + ### I think this is the same as summing the dft of each impulse + ### as if the impulse is a member of a larger table. + cs = CubicSpline( + range(len(data)), + data, + bc_type="natural", + ) + + self.data = np.array( + [ + cs(x) + for x in np.linspace( + 0, len(data), table_length + ) + ] + ) + + # Scale the declared segments to the table_length + self.segments = Segments( + **{ + k: ( + int(v / (len(data) / table_length)) + if v + else v + ) + for k, v in ( + (f.name, getattr(segments, f.name)) + for f in dataclasses.fields(segments) + ) + }, + ) + + # NOTE: these are not the data min/max, but the normal min/max, either + # provided as kwargs, or derived from P and R segment starts, by + # convention. + bottom = ( + bottom if bottom is not None else data[segments.P] + ) + top = top if top is not None else data[segments.R] + + if not (0 == bottom and 1 == top): + # Normalize between 0 and 1: + self.data = (self.data - bottom) / (top - bottom) + + self.brady_stretch = ( + brady_stretch + if brady_stretch is not None + else ("S_T", "P") + ) + + self.tachy_compress = ( + tachy_compress + if tachy_compress is not None + else ("S_T", "Q") + ) + + def __getitem__(self, k): + return self.data[k] + + def __len__(self): + return len(self.data) # O(1) + + def linear_interpolation( + self, + index: float, + floor: Number | None = None, + ceiling: Number | None = None, + ) -> float: + """ + Returns a smoothly interpolated value from the wavetable at a fractional index. + + Instead of returning discrete values (which can cause aliasing and a "chunky" waveform), + this method linearly interpolates between the nearest lower (floor) and upper (ceiling) + indices, weighted by their distance from the requested index. This improves waveform + smoothness and reduces artifacts when sampling at arbitrary positions. + + Args: + index (float): The fractional index to sample. + floor (Optional[Number]): Override for the lower index (default: floor of index). + ceiling (Optional[Number]): Override for the upper index (default: floor + 1). + + Returns: + float: The interpolated value at the given index. + """ + dl = len(self.data) + floor = ( + floor if floor is not None else np.floor(index) % dl + ) + ceiling = ( + ceiling if ceiling is not None else (floor + 1) % dl + ) + + # e.g., a. 124.75 - 124 == 0.75 + # b. 123 - 123 == 0 (no weight goes to ceiling) + ceiling_weight = index - floor + # e.g., a. 1 - 0.75 == 0.25 + # b. 1 - 0 == 1 (all weight goes to floor) + floor_weight = 1 - ceiling_weight + + return ( + self[int(floor)] * floor_weight + + self[int(ceiling)] * ceiling_weight + ) + + def merge( + self, + other: "ECGWaveTable", + weight: float, + ): + self_weight = 1 - weight + return ECGWaveTable( + data=(self.data * self_weight + other.data * weight), + segments=self.segments.merge(other.segments, weight), + brady_stretch=( + other.brady_stretch + if self_weight < 0.5 + else self.brady_stretch + ), + tachy_compress=( + other.tachy_compress + if self_weight < 0.5 + else self.tachy_compress + ), + top=1, + bottom=0, + ) diff --git a/src/pojagi_dsp/channel/generator/sine.py b/src/pojagi_dsp/channel/generator/sine.py index 6705fda..de9f5fc 100644 --- a/src/pojagi_dsp/channel/generator/sine.py +++ b/src/pojagi_dsp/channel/generator/sine.py @@ -12,14 +12,15 @@ class SineWave(ASignal[float]): self, hz: float, phase: float = 0.0, # radians - srate: Optional[float] = None + srate: Optional[float] = None, ): super().__init__(srate) self.hz = hz self.phase = phase @property - def wavelength(self): return self.srate/self.hz + def wavelength(self): + return self.srate / self.hz def samples(self): """An iterator over one period.""" @@ -27,7 +28,7 @@ class SineWave(ASignal[float]): self.phase %= _2_PI while self.phase < _2_PI: - inc = (_2_PI * self.hz)/self.srate + inc = (_2_PI * self.hz) / self.srate yield math.sin(self.phase) self.phase += inc @@ -44,7 +45,9 @@ if __name__ == "__main__": # for _ in range(10): # values += list(sine) - for y in sine.of_duration(datetime.timedelta(milliseconds=10)): + for y in sine.of_duration( + datetime.timedelta(milliseconds=10) + ): values.append(y) plt.plot(range(len(values)), values) diff --git a/src/pojagi_dsp/channel/signal.py b/src/pojagi_dsp/channel/signal.py new file mode 100644 index 0000000..06f523e --- /dev/null +++ b/src/pojagi_dsp/channel/signal.py @@ -0,0 +1,361 @@ +import abc +import copy +import datetime +import inspect +import logging +import math +import operator +import types +from collections.abc import Iterable +from functools import reduce +from itertools import islice +from typing import Any, Callable, Generic, Iterator, Optional, Type, TypeVar, Union + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +class IllegalStateException(ValueError): ... + + +def coerce_channels(x: Any) -> Iterator["ASignal"]: + if isinstance(x, ASignal): + yield x + else: + if callable(x): + if isinstance(x, Type): + yield x() + else: + yield SignalFunction(x) + elif isinstance(x, Iterable): # and not isinstance(x, str): + for it in (coerce_channels(y) for y in x): + for channel in it: + yield channel + else: + yield Constantly(x) + + +class ASignalMeta(abc.ABCMeta): + def __or__(self, other: Any) -> "Filter": + """ + Allows `|` composition starting from an uninitialized class. + See doc for `__or__` below in `ASignal`. + """ + return self() | coerce_channels(other) + + def __radd__(self, other): + return self() + other + + def __add__(self, other): + return self() + other + + def __rmul__(self, other): + return self() * other + + def __mul__(self, other): + return self() * other + + +class ASignal(Generic[T], metaclass=ASignalMeta): + def __init__(self, srate: Optional[float] = None): + self._srate = srate + self._cursor: Optional[Iterator[T]] = None + + @property + def srate(self): + if self._srate is None: + raise IllegalStateException(f"{self.__class__}: `srate` is None.") + return self._srate + + @srate.setter + def srate(self, val: float): + self._srate = val + + def __iter__(self): + self._cursor = self.samples() + return self + + def __next__(self): + return next(self.cursor) + + @abc.abstractmethod + def samples(self) -> Iterator[T]: ... + + @property + def cursor(self): + """ + An `Iterator` representing the current pipeline in progress. + """ + if self._cursor is None: + # this can only happen once + self._cursor = self.samples() + return self._cursor + + def __getstate__(self): + """ + `_cursor` is a generator, and generators aren't picklable. + """ + state = self.__dict__.copy() + if state.get("_cursor"): + del state["_cursor"] + return state + + def stream(self): + while True: + try: + yield next(self.cursor) + except StopIteration: + self = iter(self) + + def of_duration(self, duration: datetime.timedelta): + """ + Returns an `Iterator` of samples for a particular duration expressed + as a `datetime.timedelta` + :param:`duration` - `datetime.timedelta` representing the duration + """ + return islice( + self.stream(), + 0, + math.floor(self.srate * duration.total_seconds()), + ) + + def __or__( + left, + right: Union["Filter", Callable, Iterable], + ) -> "Filter": + """ + Allows composition of filter pipelines with `|` operator. + + e.g., + ``` + myFooGenerator + | BarFilter + | baz_filter_func + | (lambda reader: (x for x in reader)) + ``` + """ + if isinstance(right, SignalFunction): + return left | FilterFunction(fn=right._fn, name=right.Function) + + if not isinstance(right, ASignal): + return reduce(operator.or_, (left, *coerce_channels(right))) + + if not isinstance(right, Filter): + raise ValueError( + f"Right side must be a `{Filter.__name__}`; " + f"received: {type(right)}", + ) + + filter: Filter = right + while getattr(filter, "_reader", None) is not None: + # Assuming this is a filter pipeline, we want the last node's + # reader to be whatever's on the left side of this operation. + filter = filter.reader + + if hasattr(filter, "_reader"): + # We hit the "bottom" and found a filter. + filter.reader = left + else: + # We hit the "bottom" and found a non-filter/generator. + raise ValueError( + f"{right.__class__.__name__}: filter pipeline already has a " + "generator." + ) + + # Will often be `None` unless `left` is a generator. + right.srate = left._srate + + return right + + def __radd__(right, left): + return right.__add__(left) + + def __add__(left, right): + return left._operator_impl(operator.add, right) + + def __rmul__(right, left): + return right.__mul__(left) + + def __mul__(left, right): + return left._operator_impl(operator.mul, right) + + # FIXME: other operators? Also, shouldn't `*` mean convolve instead? + + def _operator_impl(left, operator: Callable[..., T], right: Any): + channels = list(coerce_channels(right)) + for channel in channels: + if channel._srate is None: + channel.srate = left._srate + return Reduce(operator, left, *channels, srate=left._srate) + + def __repr__(self): + members = {} + for k in [ + k + for k in dir(self) + if not k.startswith("_") + and not k + in { + "stream", + "reader", + "cursor", + "wave", + } + ]: + try: + v = getattr(self, k) + if not inspect.isroutine(v): + members[k] = v + except IllegalStateException as e: + members[k] = None + + return ( + f"{self.__class__.__name__}" + f"""({ + f", ".join( + f"{k}={v}" + for k, v in members.items() + ) + })""" + ) + + +S = TypeVar("S", bound=ASignal) + + +class Reduce(ASignal, Generic[S, T]): + def __init__( + self, + # FIXME: typing https://stackoverflow.com/a/67814270 + fn: Callable[..., T], + *streams: S, + srate: Optional[float] = None, + stateful=True, + ): + super().__init__(srate) + self._fn = fn + self.fn = fn.__name__ + self.streams = [] + for stream in streams: + if stateful: + self.streams.append(stream) + continue + + stream_ = ( + copy.deepcopy(stream) + if not isinstance(stream, types.GeneratorType) + else stream + ) + stream_.srate = srate + self.streams.append(stream_) + + @property + def srate(self): + return ASignal.srate.fget(self) + + @srate.setter + def srate(self, val: float): + ASignal.srate.fset(self, val) + for stream in self.streams: + if isinstance(stream, ASignal): + stream.srate = val + + def samples(self): + return (reduce(self._fn, args) for args in zip(*self.streams)) + + +class Filter(ASignal, Generic[S]): + + def __init__( + self, + reader: Optional[S] = None, + srate: Optional[float] = None, + ): + super().__init__(srate) + self.reader: Optional[S] = reader + + @property + def reader(self) -> S: + """ + The input stream this filter reads. + """ + if not self._reader: + raise IllegalStateException(f"{self.__class__}: `reader` is None.") + return self._reader + + @reader.setter + def reader(self, val: S): + self._reader = val + if val is not None and self._srate is None: + self.srate = val._srate + + @property + def srate(self): + return ASignal.srate.fget(self) + + @srate.setter + def srate(self, val: float): + ASignal.srate.fset(self, val) + child = getattr(self, "_reader", None) + previous_srate = val + while child is not None: + # Since `srate` is optional at initialization, but required in + # general, we make our best attempt to normalize it for the + # filter pipeline, which should be consistent for most + # applications, by applying it to all children. + if child._srate is None: + child.srate = previous_srate + child: Optional[ASignal] = getattr(child, "_reader", None) + if isinstance(child, ASignal) and child._srate is not None: + previous_srate = child._srate + + def samples(self) -> Iterator[T]: + """The below is a default implementation, but this is meant to be + overrided. + """ + return self.reader.samples() + + def __repr__(self): + return f"{self._reader} | {super().__repr__()}" + + +class FilterFunction(Filter, Generic[T, S]): + def __init__( + self, + fn: Callable[[S], Iterator[T]], + name: Optional[str] = None, + reader: Optional[S] = None, + srate: Optional[float] = None, + ): + super().__init__(reader, srate) + self._fn = fn + self.Function = name if name else fn.__name__ + + def samples(self): + return self._fn(self.reader) + + +class SignalFunction(ASignal, Generic[T]): + def __init__( + self, + fn: Callable[[int], Iterator[T]], + name: Optional[str] = None, + srate: Optional[float] = None, + ): + super().__init__(srate) + self._fn = fn + self.Function = name if name else fn.__name__ + + def samples(self) -> Iterator[T]: + return self._fn(self.srate) + + +class Constantly(ASignal, Generic[T]): + def __init__(self, constant: T, srate: float = 0.0): + super().__init__(srate) + self.constant = constant + + def samples(self) -> Iterator[T]: + while True: + yield self.constant