bugfixes, etc

This commit is contained in:
2025-10-03 08:58:45 -04:00
parent cf033c4b4f
commit d78b82c9e3
8 changed files with 928 additions and 807 deletions

View File

@@ -1,337 +1 @@
import abc
import copy
import datetime
import inspect
from itertools import islice
import logging
import math
import operator
from collections.abc import Iterable
from functools import reduce
import types
from typing import (Any, Callable, Generator, Generic, Iterator, Optional, Type, TypeVar,
Union)
logger = logging.getLogger(__name__)
T = TypeVar("T")
class IllegalStateException(ValueError):
...
def coerce_channels(x: Any) -> Iterator["ASignal"]:
if isinstance(x, ASignal):
yield x
else:
if callable(x):
if isinstance(x, Type):
yield x()
else:
yield SignalFunction(x)
elif isinstance(x, Iterable): # and not isinstance(x, str):
for it in (coerce_channels(y) for y in x):
for channel in it:
yield channel
else:
yield Constantly(x)
class ASignalMeta(abc.ABCMeta):
def __or__(self, other: Any) -> "Filter":
"""
Allows `|` composition starting from an uninitialized class.
See doc for `__or__` below in `ASignal`.
"""
return self() | coerce_channels(other)
def __radd__(self, other): return self() + other
def __add__(self, other): return self() + other
def __rmul__(self, other): return self() * other
def __mul__(self, other): return self() * other
class ASignal(Generic[T], metaclass=ASignalMeta):
def __init__(self, srate: Optional[float] = None):
self._srate = srate
self._cursor: Optional[Iterator[T]] = None
@property
def srate(self):
if self._srate is None:
raise IllegalStateException(
f"{self.__class__}: `srate` is None."
)
return self._srate
@srate.setter
def srate(self, val: float): self._srate = val
def __iter__(self):
self._cursor = self.samples()
return self
def __next__(self): return next(self.cursor)
@abc.abstractmethod
def samples(self) -> Iterator[T]: ...
@property
def cursor(self):
"""
An `Iterator` representing the current pipeline in progress.
"""
if self._cursor is None:
# this can only happen once
self._cursor = self.samples()
return self._cursor
def __getstate__(self):
"""
`_cursor` is a generator, and generators aren't picklable.
"""
state = self.__dict__.copy()
if state.get("_cursor"):
del state["_cursor"]
return state
def stream(self):
while True:
try:
yield next(self.cursor)
except StopIteration:
self = iter(self)
def of_duration(self, duration: datetime.timedelta):
"""
Returns an `Iterator` of samples for a particular duration expressed
as a `datetime.timedelta`
:param:`duration` - `datetime.timedelta` representing the duration
"""
return islice(
self.stream(),
0,
math.floor(self.srate * duration.total_seconds()),
)
def __or__(
left,
right: Union["Filter", Callable, Iterable],
) -> "Filter":
"""
Allows composition of filter pipelines with `|` operator.
e.g.,
```
myFooGenerator
| BarFilter
| baz_filter_func
| (lambda reader: (x for x in reader))
```
"""
if isinstance(right, SignalFunction):
return left | FilterFunction(fn=right._fn, name=right.Function)
if not isinstance(right, ASignal):
return reduce(operator.or_, (left, *coerce_channels(right)))
if not isinstance(right, Filter):
raise ValueError(
f"Right side must be a `{Filter.__name__}`; "
f"received: {type(right)}",
)
filter: Filter = right
while getattr(filter, "_reader", None) is not None:
# Assuming this is a filter pipeline, we want the last node's
# reader to be whatever's on the left side of this operation.
filter = filter.reader
if hasattr(filter, "_reader"):
# We hit the "bottom" and found a filter.
filter.reader = left
else:
# We hit the "bottom" and found a non-filter/generator.
raise ValueError(
f"{right.__class__.__name__}: filter pipeline already has a "
"generator."
)
# Will often be `None` unless `left` is a generator.
right.srate = left._srate
return right
def __radd__(right, left): return right.__add__(left)
def __add__(left, right): return left._operator_impl(operator.add, right)
def __rmul__(right, left): return right.__mul__(left)
def __mul__(left, right): return left._operator_impl(operator.mul, right)
# FIXME: other operators? Also, shouldn't `*` mean convolve instead?
def _operator_impl(left, operator: Callable[..., T], right: Any):
channels = list(coerce_channels(right))
for channel in channels:
if channel._srate is None:
channel.srate = left._srate
return Reduce(operator, left, *channels, srate=left._srate)
def __repr__(self):
members = {}
for k in [k for k in dir(self)
if not k.startswith("_")
and not k in {"stream", "reader", "cursor", "wave", }]:
try:
v = getattr(self, k)
if not inspect.isroutine(v):
members[k] = v
except IllegalStateException as e:
members[k] = None
return (
f"{self.__class__.__name__}"
f"""({
f", ".join(
f"{k}={v}"
for k, v in members.items()
)
})"""
)
S = TypeVar("S", bound=ASignal)
class Reduce(ASignal, Generic[S, T]):
def __init__(
self,
# FIXME: typing https://stackoverflow.com/a/67814270
fn: Callable[..., T],
*streams: S,
srate: Optional[float] = None,
stateful=False,
):
super().__init__(srate)
self._fn = fn
self.fn = fn.__name__
self.streams = []
for stream in streams:
if stateful:
self.streams.append(stream)
continue
stream_ = (
copy.deepcopy(stream)
if not isinstance(stream, types.GeneratorType)
else stream
)
stream_.srate = srate
self.streams.append(stream_)
@property
def srate(self): return ASignal.srate.fget(self)
@srate.setter
def srate(self, val: float):
ASignal.srate.fset(self, val)
for stream in self.streams:
if isinstance(stream, ASignal):
stream.srate = val
def samples(self): return (
reduce(self._fn, args)
for args in zip(*self.streams)
)
class Filter(ASignal, Generic[S]):
def __init__(
self,
reader: Optional[S] = None,
srate: Optional[float] = None,
):
super().__init__(srate)
self.reader: Optional[S] = reader
@property
def reader(self) -> S:
"""
The input stream this filter reads.
"""
if not self._reader:
raise IllegalStateException(
f"{self.__class__}: `reader` is None."
)
return self._reader
@reader.setter
def reader(self, val: S):
self._reader = val
if val is not None and self._srate is None:
self.srate = val._srate
@property
def srate(self): return ASignal.srate.fget(self)
@srate.setter
def srate(self, val: float):
ASignal.srate.fset(self, val)
child = getattr(self, "_reader", None)
previous_srate = val
while child is not None:
# Since `srate` is optional at initialization, but required in
# general, we make our best attempt to normalize it for the
# filter pipeline, which should be consistent for most
# applications, by applying it to all children.
if child._srate is None:
child.srate = previous_srate
child: Optional[ASignal] = getattr(child, "_reader", None)
if isinstance(child, ASignal) and child._srate is not None:
previous_srate = child._srate
def samples(self) -> Iterator[T]: return self.reader.samples()
def __repr__(self):
return (
f"{self._reader} | {super().__repr__()}"
)
class FilterFunction(Filter, Generic[T, S]):
def __init__(
self,
fn: Callable[[S], Iterator[T]],
name: Optional[str] = None,
reader: Optional[S] = None,
srate: Optional[float] = None,
):
super().__init__(reader, srate)
self._fn = fn
self.Function = name if name else fn.__name__
def samples(self): return self._fn(self.reader)
class SignalFunction(ASignal, Generic[T]):
def __init__(
self,
fn: Callable[[int], Iterator[T]],
name: Optional[str] = None,
srate: Optional[float] = None,
):
super().__init__(srate)
self._fn = fn
self.Function = name if name else fn.__name__
def samples(self) -> Iterator[T]: return self._fn(self.srate)
class Constantly(ASignal, Generic[T]):
def __init__(self, constant: T, srate: float = 0.0):
super().__init__(srate)
self.constant = constant
def samples(self) -> Iterator[T]:
while True:
yield self.constant
from .signal import *

View File

@@ -68,13 +68,13 @@ class Segments:
class AECGChannel(ASignal[Number]):
@property
@abc.abstractproperty
@abc.abstractmethod
def heart_rate(self) -> float:
"""Frequency of impulses/waves in bpm."""
...
@property
@abc.abstractproperty
@abc.abstractmethod
def wavelength(self) -> int:
"""
The number of samples in a complete impulse/wave cycle.
@@ -83,16 +83,3 @@ class AECGChannel(ASignal[Number]):
heart rate converted from bpm to Hz.
"""
...
@property
@abc.abstractproperty
def segments(self) -> Segments:
"""The analytical segments of the impulse/wave."""
...
@property
def wave(self) -> Iterator[Number]:
"""
Returns an iterator over a single ECG impulse/wave.
"""
return itertools.islice(self, 0, self.wavelength)
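For concreteness, a minimal worked sketch of the `wavelength` relationship described in the docstring above (samples per cycle is the sample rate divided by the heart rate converted from bpm to Hz); the helper name is hypothetical, not part of this change.
```
# Hypothetical helper illustrating the wavelength arithmetic only.
def wavelength_in_samples(srate: float, heart_rate_bpm: float) -> int:
    heart_rate_hz = heart_rate_bpm / 60.0   # bpm -> Hz
    return int(srate / heart_rate_hz)       # samples per complete wave

assert wavelength_in_samples(srate=125.0, heart_rate_bpm=75.0) == 100
assert wavelength_in_samples(srate=125.0, heart_rate_bpm=60.0) == 125
```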

View File

@@ -1,279 +1,4 @@
import dataclasses
import logging
import math
from numbers import Number
from typing import Dict, List, Optional, Tuple
import numpy as np
from scipy.interpolate import CubicSpline
from pojagi_dsp.channel.ecg import Segments
from pojagi_dsp.channel.ecg.generator import AECGSynthesizer
logger = logging.getLogger(__name__)
class ECGWaveTable:
"""
This type of wavetable is designed around the P and R. By
convention, R will always be equal to 1, and the baseline (P) will always
be 0. (That doesn't mean, however, that the other values can't cross these
boundaries. E.g., Q and S are often negative.)
"""
def __init__(
self,
data: List[Number],
segments: Segments,
bottom: Optional[Number] = None,
top: Optional[Number] = None,
table_length: int = 1 << 11, # 2048
):
"""
The table size is increased to `table_length` upon initialization
using linear interpolation (via `numpy.interp`).
"""
if len(data) == table_length:
self.data = np.array(data)
else:
# We generate a larger table for use with linear interpolation,
# trading time (CPU) for memory (table size).
#
# Here we use cubic spline interpolation instead of linear for
# wavetable construction, since it usually only happens once at
# startup, and should provide a much better quality table from
# limited data, making it possible to work with small, manually
# composed tables that we JIT convert to the larger table.
### FIXME: use the sinc function instead:
### f(x) = sin(x)/x where x =/= 0 and f(x) = 1 if x = 0
### you have to apply this scaled to each sample in the table and
### then add all of the resulting signals together.
### I think this is the same as summing the dft of each impulse
### as if the impulse is a member of a larger table.
cs = CubicSpline(
range(len(data)),
data,
bc_type="natural",
)
self.data = np.array(
[cs(x) for x in np.linspace(0, len(data), table_length)]
)
# Scale the declared segments to the table_length
self.segments = Segments(
**{
k: int(v / (len(data) / table_length)) if v else v
for k, v in (
(f.name, getattr(segments, f.name))
for f in dataclasses.fields(segments)
)
},
)
# NOTE: these are not the data min/max, but the normal min/max, either
# provided as kwargs, or derived from P and R segment starts, by
# convention.
bottom = bottom if bottom is not None else data[segments.P]
top = top if top is not None else data[segments.R]
if not (0 == bottom and 1 == top):
# Normalize between 0 and 1:
self.data = (self.data - bottom) / (top - bottom)
def __getitem__(self, k):
return self.data[k]
def __len__(self):
return len(self.data) # O(1)
def linear_interpolation(
self,
index: float,
floor: Optional[int] = None,
ceiling: Optional[int] = None,
) -> float:
"""
Handles the situation where the floor would produce duplicate values,
which makes the waveform chunky with aliasing; instead, we obtain a
value weighted between the floor/ceiling, trading time (CPU) for
memory (table size).
"""
dl = len(self.data)
floor = floor if floor is not None else math.floor(index) % dl
ceiling = ceiling if ceiling is not None else (floor + 1) % dl
# e.g., a. 124.75 - 124 == 0.75
# b. 123 - 123 == 0 (no weight goes to ceiling)
ceiling_weight = index - floor
# e.g., a. 1 - 0.75 == 0.25
# b. 1 - 0 == 1 (all weight goes to floor)
floor_weight = 1 - ceiling_weight
return self[floor] * floor_weight + self[ceiling] * ceiling_weight
def merge(
self,
other: "ECGWaveTable",
weight: float,
):
self_weight = 1 - weight
return ECGWaveTable(
data=(self.data * self_weight + other.data * weight),
segments=self.segments.merge(other.segments, weight),
top=1,
bottom=0,
)
class ECGWaveTableSynthesizer(AECGSynthesizer):
def __init__(
self,
/,
tables: Dict[Tuple[float, float], ECGWaveTable],
heart_rate: int,
srate: Optional[float] = None,
):
super().__init__(heart_rate, srate)
self.inc: float = 0.0
self.tables = tables
self._segments: Segments = None
self.phase = 0.0
self.q_lock = False
def samples(self):
# phase: float = 0.0
inc: float = None
idx: int = 0
heart_rate = self.heart_rate
self._calibrate()
while idx < self.wavelength:
phase = self.phase
floor = math.floor(phase)
yield self.table.linear_interpolation(phase, floor=floor)
if (
heart_rate < 60
and self.table.segments.T_P <= floor < self.table.segments.P
):
inc = self.brady_inc
logger.info(["brady", inc, self.inc])
self.q_lock = True
elif (
heart_rate > 60
and self.table.segments.S_T <= floor < self.table.segments.Q
):
# FIXME: this is probably only good below a certain
# `heart_rate` threshold, because at some high frequency, even
# the QRS complex will not have enough room to complete.
inc = self.tachy_inc
self.q_lock = True
else:
inc = None
if self.q_lock:
phase = self.table.segments.Q
self.q_lock = False
logger.info(f"\n{[
self.table.linear_interpolation(phase, floor=floor),
inc,
self.inc,
self.table.segments.Q,
phase,
self.table.segments.S_T
]}")
phase += inc if inc is not None else self.inc
phase %= len(self.table)
self.phase = phase
idx += 1
@AECGSynthesizer.heart_rate.setter
def heart_rate(self, val):
AECGSynthesizer.heart_rate.fset(self, val)
def _calibrate(self):
heart_rate = self.heart_rate
table_matches = {
k: v for k, v in self.tables.items() if k[0] <= heart_rate < k[1]
}
if not table_matches:
raise ValueError(
f"No table found corresponding to heart rate: {heart_rate}."
)
# Since we may have more than two tables that match, we loop
# through all the matches, applying them in key order.
keys = iter(sorted(table_matches))
key = next(keys)
table = table_matches[key]
for next_key in keys:
next_table = table_matches[next_key]
if next_key[1] < key[1]:
# `next_key` is fully contained within `key`
floor, ceiling = next_key
next_weight = (heart_rate - floor) / (ceiling - floor)
weight = 1 - next_weight
if (heart_rate - floor) > ((ceiling - floor) / 2):
# Weights form an "X" shape; i.e., crossfade to 50%
# and back.
weight, next_weight = next_weight, weight
else:
floor = next_key[0] # i.e., the bottom of the top
ceiling = key[1] # i.e., the top of the bottom
next_weight = (heart_rate - floor) / (ceiling - floor)
table = table.merge(next_table, next_weight)
key = next_key
self.table = table
# ECG Tables are designed for 1Hz, and as a default, we don't want to
# stretch anything; hence, no reference to `self.heart_rate` here,
# instead constant 60:
self.inc = len(self.table) / (self.srate * (60 / 60))
# Stretch only the T_P segment to compensate, rather than
# stretching the whole wave.
table_segment_length = self.table.segments.P - self.table.segments.T_P
self.brady_inc = self.stretch_inc(table_segment_length)
# Preserve QRS-J-point; compress Jp-Q to compensate.
table_segment_length = self.table.segments.Q - self.table.segments.S_T
self.tachy_inc = self.stretch_inc(table_segment_length)
def stretch_inc(self, table_segment_length):
# Get the missing samples by subtracting the number of samples
# contributed by the 1Hz table default, minus the segment we
# want to stretch.
tmp_wavelength = (
self.wavelength - (len(self.table) - table_segment_length) / self.inc
)
return table_segment_length / tmp_wavelength
@property
def segments(self):
if self._segments:
return self._segments
table_length = len(self.table)
table_segments = self.table.segments
# FIXME: this is a lie since we stretch T_P, etc.
self._segments = Segments(
**{
k: math.floor(v * self.srate / table_length) if v else v
for k, v in [
(f.name, getattr(table_segments, f.name))
for f in dataclasses.fields(table_segments)
]
}
)
return self._segments
from pojagi_dsp.channel.ecg.generator.wavetable.synthesizer import (
ECGWaveTableSynthesizer,
)
from pojagi_dsp.channel.ecg.generator.wavetable.wavetable import ECGWaveTable

View File

@@ -1,156 +1,165 @@
import math
import numpy as np
from numbers import Number
import random
from typing import List
from pojagi_dsp.channel.ecg import Segments
from pojagi_dsp.channel.ecg.generator.wavetable import ECGWaveTable
# NOTE: larger table required for avoiding aliasing at different srates than 125Hz
sinus_data = np.array([
# R-S: 0
2000,
1822,
374,
# S-Jp: 3
-474,
-271,
-28,
18,
66,
# Jp-T: 9
63,
73,
91,
101,
101,
101,
116,
124,
124,
# T: 17
141,
171,
186,
196,
229,
265,
297,
327,
363,
406,
446,
475,
493,
508,
526,
533,
518,
475,
403,
327,
272,
222,
174,
138,
109,
88,
73,
66,
69,
69,
66,
73,
81,
76,
73,
76,
76,
66,
58,
58,
63,
63,
41,
26,
26,
18,
8,
8,
8,
# U: 66 -- not found
# T-P: 66
2,
3,
2,
2,
2,
-1,
2,
2,
2,
-1,
0,
-1,
-1,
3,
2,
1,
3,
2,
1,
0,
1,
# P: 87
0,
3,
11,
11,
0,
8,
18,
18,
18,
15,
8,
18,
26,
26,
26,
8,
32,
61,
116,
164,
182,
159,
131,
116,
116,
109,
91,
73,
58,
55,
58,
63,
69,
# P-R: 120
48,
-14,
# Q-R: 122
-40,
131,
931,
]) # len == 125
sinus_data = np.array(
[
# R-S: 0
2000,
1822,
374,
# S-Jp: 3
-474,
-271,
-28,
18,
66,
# Jp-T: 9
63,
73,
91,
101,
101,
101,
116,
124,
124,
# T: 17
141,
171,
186,
196,
229,
265,
297,
327,
363,
406,
446,
475,
493,
508,
526,
533,
518,
475,
403,
327,
272,
222,
174,
138,
109,
88,
73,
66,
69,
69,
66,
73,
81,
76,
73,
76,
76,
66,
58,
58,
63,
63,
41,
26,
26,
18,
8,
8,
8,
# U: 66 -- not found
# T-P: 66
2,
3,
2,
2,
2,
-1,
2,
2,
2,
-1,
0,
-1,
-1,
3,
2,
1,
3,
2,
1,
0,
1,
# P: 87
0,
3,
11,
11,
0,
8,
18,
18,
18,
15,
8,
18,
26,
26,
26,
8,
32,
61,
116,
164,
182,
159,
131,
116,
116,
109,
91,
73,
58,
55,
58,
63,
69,
# P-R: 120
48,
-14,
# Q-R: 122
-40,
131,
931,
]
) # len == 125
pr_idx = 120
t_idx = 18
t_pr_idx_diff = pr_idx - t_idx
t_pr = np.arange(-math.floor(t_pr_idx_diff / 2), math.ceil(t_pr_idx_diff / 2))
t_pr_curve: np.ndarray = t_pr**2 * -0.2
t_pr_curve = (t_pr_curve - t_pr_curve[0]) + 141
def parabolic_curve(t_idx, pr_idx):
"""Calculates a smooth, physiologically mimetic curve for the
T-P-R segment of the ECG waveform.
"""
# Compute the difference in indices to determine the segment length
t_pr_idx_diff = pr_idx - t_idx
# Generate a symmetric range of values centered around zero for the segment
t_pr = np.arange(-math.floor(t_pr_idx_diff / 2), math.ceil(t_pr_idx_diff / 2))
# Apply a parabolic transformation to create a smooth transition
t_pr_curve: np.ndarray = t_pr**2 * -0.25
# Shift the curve so it starts at zero; the caller adds the baseline offset
# (e.g., the T wave amplitude, 141) when splicing it into a table.
return t_pr_curve - t_pr_curve[0]
tachycardia = np.array(
[ # 58-107 flat
@@ -174,8 +183,9 @@ tachycardia = np.array(
116,
124,
124,
*t_pr_curve,
# P-R: 119
# T: 17
*parabolic_curve(17, 119) + 141,
# P-R: 120
124,
48,
-14,
@@ -188,43 +198,67 @@ tachycardia = np.array(
def SinusWaveTable():
segments = Segments(
S=3,
S_T=9,
T=17,
T_P=66,
P=87,
P_R=120,
Q=122,
)
return ECGWaveTable(
data=sinus_data,
segments=Segments(
S=3,
S_T=9,
T=17,
T_P=66,
P=87,
P_R=120,
Q=122,
),
segments=segments,
)
def TachycardiaWaveTable():
segments = Segments(
S=3,
S_T=8,
T=17,
T_P=66,
P=87,
P_R=119,
Q=122,
)
return ECGWaveTable(
data=tachycardia,
segments=Segments(
S=3,
S_T=8,
T=17,
T_P=66,
P=87,
P_R=119,
Q=122,
),
# Tachy is weaker than sinus, so we inflate the range here, which
# effectively attenuates the signal by 1/3 (i.e., it is 2/3 the
# original).
segments=segments,
# Tachy is weaker than sinus, so we inflate the range here by 3/2,
# which effectively attenuates the signal by 1/3 (i.e., it is 2/3 of
# the amplitude of the data definition).
top=2000 * (3 / 2),
bottom=0,
)
impulse_data = np.array([1, *([0] * 124)])
def FastTachycardiaWaveTable():
segments = Segments(
S=3,
S_T=8,
T=17,
T_P=66,
P=87,
P_R=119,
Q=122,
)
return ECGWaveTable(
data=np.arange(-50, 51) ** 11 / 1e19,
segments=segments,
tachy_compress=("R", "R"),
top=2,
bottom=0,
)
def ImpulseWaveTable():
return ECGWaveTable(
data=impulse_data,
data=np.array([1, *([0] * 124)]),
segments=Segments(
S=3,
S_T=9,
@@ -241,8 +275,8 @@ def ImpulseWaveTable():
if __name__ == "__main__":
from matplotlib import pyplot as plt
samples = np.tile(SinusWaveTable().data, 3)
samples = np.tile(FastTachycardiaWaveTable().data, 3)
plt.plot(range(len(samples)), samples)
plt.show()
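As a standalone check of the splice above (`*parabolic_curve(17, 119) + 141`), this sketch recomputes the parabolic T-to-P-R bridge with the same formula and shows that it is anchored at the T wave amplitude; the variable names are illustrative only.
```
import math
import numpy as np

# Recompute the bridge exactly as `parabolic_curve(17, 119)` does, then add
# the 141 offset applied at the splice site in the tachycardia table.
t_idx, pr_idx = 17, 119
diff = pr_idx - t_idx
t_pr = np.arange(-math.floor(diff / 2), math.ceil(diff / 2))
curve = t_pr**2 * -0.25
bridge = (curve - curve[0]) + 141

print(len(bridge))   # 102 samples spanning the T to P-R indices
print(bridge[0])     # 141.0 -- starts at the T wave amplitude
print(bridge.max())  # peak of the parabola near the middle of the segment
```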

View File

@@ -0,0 +1,165 @@
import numpy as np
from pojagi_dsp.channel.ecg.generator import AECGSynthesizer
from pojagi_dsp.channel.ecg.generator.wavetable.wavetable import (
ECGWaveTable,
)
class ECGWaveTableSynthesizer(AECGSynthesizer):
def __init__(
self,
/,
tables: dict[tuple[float, float], ECGWaveTable],
heart_rate: int,
srate: float | None = None,
):
super().__init__(heart_rate, srate)
self.inc: float = 0.0
self.tables = tables
self.table: ECGWaveTable | None = None
self.phase: float = 0.0
self.brady_start: int = 0
self.brady_end: int = 0
self.brady_inc: float = 0.0
self.tachy_start: int = 0
self.tachy_end: int = 0
self.tachy_inc: float = 0.0
def samples(self):
inc: float | None = None
idx: int = 0
heart_rate = self.heart_rate
self._calibrate()
print("inside samples", hex(id(self)))
while idx < self.wavelength:
phase = self.phase
floor = np.floor(phase)
yield self.table.linear_interpolation(
phase, floor=floor
)
if heart_rate < 60 and (
self.brady_start <= phase < self.brady_end
):
inc = self.brady_inc
if phase + inc > self.brady_end:
inc = self.brady_end - phase
elif heart_rate > 60 and (
self.tachy_start <= phase < self.tachy_end
):
# FIXME: this is probably only good below a certain `heart_rate`
# threshold, because at some high frequency, even the QRS
# complex will not have enough room to complete.
inc = self.tachy_inc
if phase + inc > self.tachy_end:
inc = self.tachy_end - phase
else:
inc = None
phase += inc if inc is not None else self.inc
if phase >= len(self.table):
phase = 0.0
self.phase = phase
idx += 1
@AECGSynthesizer.heart_rate.setter
def heart_rate(self, val):
AECGSynthesizer.heart_rate.fset(self, val)
print("setter", hex(id(self)))
def _calibrate(self):
heart_rate = self.heart_rate
table_matches = {
k: v
for k, v in self.tables.items()
if k[0] <= heart_rate < k[1]
}
if not table_matches:
raise ValueError(
f"No table found corresponding to heart rate: {heart_rate}."
)
# Since we may have more than two tables that match, we loop
# through all the matches, applying them in key order.
keys = iter(sorted(table_matches))
key = next(keys)
table = table_matches[key]
for next_key in keys:
next_table = table_matches[next_key]
if next_key[1] < key[1]:
# `next_key` is fully contained within `key`
floor, ceiling = next_key
next_weight = (heart_rate - floor) / (
ceiling - floor
)
weight = 1 - next_weight
if (heart_rate - floor) > (
(ceiling - floor) / 2
):
# Weights form an "X" shape; i.e., crossfade to 50%
# and back.
weight, next_weight = next_weight, weight
else:
floor = next_key[0]  # i.e., the bottom of the top
ceiling = key[1] # i.e., the top of the bottom
next_weight = (heart_rate - floor) / (
ceiling - floor
)
table = table.merge(next_table, next_weight)
key = next_key
self.table = table
# ECG Tables are designed for 1Hz, and as a default, we don't want to
# stretch anything; hence, no reference to `self.heart_rate` here,
# instead constant 60:
self.inc = len(self.table) / (self.srate * (60 / 60))
self.brady_start, self.brady_end = (
getattr(self.table.segments, x)
for x in self.table.brady_stretch
)
if self.table.brady_stretch[1] == "R":
self.brady_end = len(self.table) - 1
# Stretch only the T_P segment to compensate, rather than
# stretching the whole wave.
table_segment_length = self.brady_end - self.brady_start
self.brady_inc = self.stretch_inc(table_segment_length)
self.tachy_start, self.tachy_end = (
getattr(self.table.segments, x)
for x in self.table.tachy_compress
)
if self.table.tachy_compress[1] == "R":
self.tachy_end = len(self.table) - 1
# Preserve QRS-J-point; compress Jp-Q to compensate.
table_segment_length = self.tachy_end - self.tachy_start
self.tachy_inc = self.stretch_inc(table_segment_length)
def stretch_inc(self, table_segment_length: int) -> float:
# Determine how many output samples remain for the stretched segment: from
# the target wavelength, subtract the samples contributed (at the default
# 1Hz increment) by the portion of the table outside that segment.
tmp_wavelength = (
self.wavelength
- (len(self.table) - table_segment_length) / self.inc
)
return table_segment_length / tmp_wavelength
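Two of the calculations above lend themselves to a quick standalone illustration: the crossfade weights `_calibrate` derives when one table's heart-rate range is nested inside another (the "X"-shaped weights), and the `stretch_inc` bookkeeping that spreads the leftover samples of the target wavelength over a single segment. This is a hedged sketch with made-up numbers; the function and variable names are not part of the library.
```
def nested_crossfade_weights(heart_rate, inner):
    # Mirrors the nested-range branch of `_calibrate`: the inner table is
    # faded in toward the middle of its range and back out again.
    floor, ceiling = inner
    next_weight = (heart_rate - floor) / (ceiling - floor)
    weight = 1 - next_weight
    if (heart_rate - floor) > ((ceiling - floor) / 2):
        weight, next_weight = next_weight, weight
    return weight, next_weight  # (outer table, inner table)

# An inner 100-140 bpm table nested inside a wider outer table:
print(nested_crossfade_weights(110, (100, 140)))  # (0.75, 0.25) fading in
print(nested_crossfade_weights(130, (100, 140)))  # (0.75, 0.25) fading back out

# Sketch of the `stretch_inc` idea: at the default 1Hz increment, the table
# outside the stretched segment contributes a fixed number of output samples;
# whatever remains of the target wavelength must come from that segment alone.
table_len, segment_len = 2048, 400
srate = 125.0
wavelength = srate / (40 / 60)                   # assumed: samples per cycle at 40 bpm
default_inc = table_len / srate                  # 1Hz phase increment
remaining = wavelength - (table_len - segment_len) / default_inc
stretch_inc = segment_len / remaining
print(default_inc, stretch_inc)                  # the stretched segment advances more slowly
```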

View File

@@ -0,0 +1,182 @@
import dataclasses
from numbers import Number
import numpy as np
from scipy.interpolate import CubicSpline
from pojagi_dsp.channel.ecg import Segments
class ECGWaveTable:
"""
This type of wavetable is designed around the P and R. By
convention, R will always be equal to 1, and the baseline (P) will always
be 0. (That doesn't mean, however, that the other values can't cross these
boundaries. E.g., Q and S are often negative.)
"""
def __init__(
self,
/,
data: list[Number],
segments: Segments,
brady_stretch: tuple[str, str] | None = None,
tachy_compress: tuple[str, str] | None = None,
bottom: Number | None = None,
top: Number | None = None,
table_length: int = (1 << 10) * 2,
):
"""
Initialize an ECGWaveTable.
Args:
data (list[Number]): The raw waveform data points for one cardiac cycle.
segments (Segments): Segment indices (e.g., P, Q, R, S, T) marking key features in the waveform.
brady_stretch (tuple[str, str] | None): Segment interval to stretch for bradycardia (slow heart rate).
tachy_compress (tuple[str, str] | None): Segment interval to compress for tachycardia (fast heart rate).
bottom (Number | None): Value to use as the baseline (P segment) for normalization. If None, uses data[segments.P].
top (Number | None): Value to use as the peak (R segment) for normalization. If None, uses data[segments.R].
table_length (int): Number of samples in the expanded wavetable (default: 2048).
"""
if len(data) == table_length:
self.data = np.array(data)
else:
# We generate a larger table for use with linear interpolation,
# trading time (CPU) for memory (table size).
#
# Here we use cubic spline interpolation instead of linear for
# wavetable construction, since it usually only happens once at
# startup, and should provide a much better quality table from
# limited data, making it possible to work with small, manually
# composed tables that we JIT convert to the larger table.
### FIXME: use the sinc function instead:
### f(x) = sin(x)/x where x =/= 0 and f(x) = 1 if x = 0
### you have to apply this scaled to each sample in the table and
### then add all of the resulting signals together.
### I think this is the same as summing the dft of each impulse
### as if the impulse is a member of a larger table.
cs = CubicSpline(
range(len(data)),
data,
bc_type="natural",
)
self.data = np.array(
[
cs(x)
for x in np.linspace(
0, len(data), table_length
)
]
)
# Scale the declared segments to the table_length
self.segments = Segments(
**{
k: (
int(v / (len(data) / table_length))
if v
else v
)
for k, v in (
(f.name, getattr(segments, f.name))
for f in dataclasses.fields(segments)
)
},
)
# NOTE: these are not the data min/max, but the normal min/max, either
# provided as kwargs, or derived from P and R segment starts, by
# convention.
bottom = (
bottom if bottom is not None else data[segments.P]
)
top = top if top is not None else data[segments.R]
if not (0 == bottom and 1 == top):
# Normalize between 0 and 1:
self.data = (self.data - bottom) / (top - bottom)
self.brady_stretch = (
brady_stretch
if brady_stretch is not None
else ("S_T", "P")
)
self.tachy_compress = (
tachy_compress
if tachy_compress is not None
else ("S_T", "Q")
)
def __getitem__(self, k):
return self.data[k]
def __len__(self):
return len(self.data) # O(1)
def linear_interpolation(
self,
index: float,
floor: Number | None = None,
ceiling: Number | None = None,
) -> float:
"""
Returns a smoothly interpolated value from the wavetable at a fractional index.
Instead of returning discrete values (which can cause aliasing and a "chunky" waveform),
this method linearly interpolates between the nearest lower (floor) and upper (ceiling)
indices, weighted by their distance from the requested index. This improves waveform
smoothness and reduces artifacts when sampling at arbitrary positions.
Args:
index (float): The fractional index to sample.
floor (Number | None): Override for the lower index (default: floor of index).
ceiling (Number | None): Override for the upper index (default: floor + 1).
Returns:
float: The interpolated value at the given index.
"""
dl = len(self.data)
floor = (
floor if floor is not None else np.floor(index) % dl
)
ceiling = (
ceiling if ceiling is not None else (floor + 1) % dl
)
# e.g., a. 124.75 - 124 == 0.75
# b. 123 - 123 == 0 (no weight goes to ceiling)
ceiling_weight = index - floor
# e.g., a. 1 - 0.75 == 0.25
# b. 1 - 0 == 1 (all weight goes to floor)
floor_weight = 1 - ceiling_weight
return (
self[int(floor)] * floor_weight
+ self[int(ceiling)] * ceiling_weight
)
def merge(
self,
other: "ECGWaveTable",
weight: float,
):
self_weight = 1 - weight
return ECGWaveTable(
data=(self.data * self_weight + other.data * weight),
segments=self.segments.merge(other.segments, weight),
brady_stretch=(
other.brady_stretch
if self_weight < 0.5
else self.brady_stretch
),
tachy_compress=(
other.tachy_compress
if self_weight < 0.5
else self.tachy_compress
),
top=1,
bottom=0,
)
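A brief self-contained sketch of the floor/ceiling weighting that `linear_interpolation` describes above, using a toy four-sample table that is already normalized; the names here are illustrative only.
```
import numpy as np

table = np.array([0.0, 1.0, 0.0, -1.0])  # toy table, already normalized

def lerp(table, index):
    # Weight the two neighbouring samples by the fractional distance of
    # `index` from its floor, wrapping at the table boundary.
    dl = len(table)
    floor = int(np.floor(index)) % dl
    ceiling = (floor + 1) % dl
    ceiling_weight = index - np.floor(index)
    floor_weight = 1 - ceiling_weight
    return table[floor] * floor_weight + table[ceiling] * ceiling_weight

print(lerp(table, 0.75))  # 0.75 -- 25% of table[0] plus 75% of table[1]
print(lerp(table, 2.0))   # 0.0  -- an integer index returns that sample exactly
print(lerp(table, 3.5))   # -0.5 -- wraps: halfway between table[3] and table[0]
```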

View File

@@ -12,14 +12,15 @@ class SineWave(ASignal[float]):
self,
hz: float,
phase: float = 0.0, # radians
srate: Optional[float] = None
srate: Optional[float] = None,
):
super().__init__(srate)
self.hz = hz
self.phase = phase
@property
def wavelength(self): return self.srate/self.hz
def wavelength(self):
return self.srate / self.hz
def samples(self):
"""An iterator over one period."""
@@ -27,7 +28,7 @@ class SineWave(ASignal[float]):
self.phase %= _2_PI
while self.phase < _2_PI:
inc = (_2_PI * self.hz)/self.srate
inc = (_2_PI * self.hz) / self.srate
yield math.sin(self.phase)
self.phase += inc
@@ -44,7 +45,9 @@ if __name__ == "__main__":
# for _ in range(10):
# values += list(sine)
for y in sine.of_duration(datetime.timedelta(milliseconds=10)):
for y in sine.of_duration(
datetime.timedelta(milliseconds=10)
):
values.append(y)
plt.plot(range(len(values)), values)
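For reference, a tiny standalone sketch of the phase-increment arithmetic in `samples` above: advancing the phase by 2π·hz/srate per sample yields roughly srate/hz samples per period.
```
import math

hz, srate = 5.0, 125.0
inc = (2 * math.pi * hz) / srate   # phase advance per sample
period, phase = [], 0.0
while phase < 2 * math.pi:         # one full period, as in `samples`
    period.append(math.sin(phase))
    phase += inc

print(len(period), srate / hz)     # about 25 samples, i.e. one wavelength
```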

View File

@@ -0,0 +1,361 @@
import abc
import copy
import datetime
import inspect
import logging
import math
import operator
import types
from collections.abc import Iterable
from functools import reduce
from itertools import islice
from typing import Any, Callable, Generic, Iterator, Optional, Type, TypeVar, Union
logger = logging.getLogger(__name__)
T = TypeVar("T")
class IllegalStateException(ValueError): ...
def coerce_channels(x: Any) -> Iterator["ASignal"]:
if isinstance(x, ASignal):
yield x
else:
if callable(x):
if isinstance(x, Type):
yield x()
else:
yield SignalFunction(x)
elif isinstance(x, Iterable): # and not isinstance(x, str):
for it in (coerce_channels(y) for y in x):
for channel in it:
yield channel
else:
yield Constantly(x)
class ASignalMeta(abc.ABCMeta):
def __or__(self, other: Any) -> "Filter":
"""
Allows `|` composition starting from an uninitialized class.
See doc for `__or__` below in `ASignal`.
"""
return self() | coerce_channels(other)
def __radd__(self, other):
return self() + other
def __add__(self, other):
return self() + other
def __rmul__(self, other):
return self() * other
def __mul__(self, other):
return self() * other
class ASignal(Generic[T], metaclass=ASignalMeta):
def __init__(self, srate: Optional[float] = None):
self._srate = srate
self._cursor: Optional[Iterator[T]] = None
@property
def srate(self):
if self._srate is None:
raise IllegalStateException(f"{self.__class__}: `srate` is None.")
return self._srate
@srate.setter
def srate(self, val: float):
self._srate = val
def __iter__(self):
self._cursor = self.samples()
return self
def __next__(self):
return next(self.cursor)
@abc.abstractmethod
def samples(self) -> Iterator[T]: ...
@property
def cursor(self):
"""
An `Iterator` representing the current pipeline in progress.
"""
if self._cursor is None:
# this can only happen once
self._cursor = self.samples()
return self._cursor
def __getstate__(self):
"""
`_cursor` is a generator, and generators aren't picklable.
"""
state = self.__dict__.copy()
if state.get("_cursor"):
del state["_cursor"]
return state
def stream(self):
while True:
try:
yield next(self.cursor)
except StopIteration:
self = iter(self)
def of_duration(self, duration: datetime.timedelta):
"""
Returns an `Iterator` of samples for a particular duration expressed
as a `datetime.timedelta`.
:param duration: `datetime.timedelta` representing the duration
"""
return islice(
self.stream(),
0,
math.floor(self.srate * duration.total_seconds()),
)
def __or__(
left,
right: Union["Filter", Callable, Iterable],
) -> "Filter":
"""
Allows composition of filter pipelines with `|` operator.
e.g.,
```
myFooGenerator
| BarFilter
| baz_filter_func
| (lambda reader: (x for x in reader))
```
"""
if isinstance(right, SignalFunction):
return left | FilterFunction(fn=right._fn, name=right.Function)
if not isinstance(right, ASignal):
return reduce(operator.or_, (left, *coerce_channels(right)))
if not isinstance(right, Filter):
raise ValueError(
f"Right side must be a `{Filter.__name__}`; "
f"received: {type(right)}",
)
filter: Filter = right
while getattr(filter, "_reader", None) is not None:
# Assuming this is a filter pipeline, we want the last node's
# reader to be whatever's on the left side of this operation.
filter = filter.reader
if hasattr(filter, "_reader"):
# We hit the "bottom" and found a filter.
filter.reader = left
else:
# We hit the "bottom" and found a non-filter/generator.
raise ValueError(
f"{right.__class__.__name__}: filter pipeline already has a "
"generator."
)
# Will often be `None` unless `left` is a generator.
right.srate = left._srate
return right
def __radd__(right, left):
return right.__add__(left)
def __add__(left, right):
return left._operator_impl(operator.add, right)
def __rmul__(right, left):
return right.__mul__(left)
def __mul__(left, right):
return left._operator_impl(operator.mul, right)
# FIXME: other operators? Also, shouldn't `*` mean convolve instead?
def _operator_impl(left, operator: Callable[..., T], right: Any):
channels = list(coerce_channels(right))
for channel in channels:
if channel._srate is None:
channel.srate = left._srate
return Reduce(operator, left, *channels, srate=left._srate)
def __repr__(self):
members = {}
for k in [
k
for k in dir(self)
if not k.startswith("_")
and k not in {
"stream",
"reader",
"cursor",
"wave",
}
]:
try:
v = getattr(self, k)
if not inspect.isroutine(v):
members[k] = v
except IllegalStateException:
members[k] = None
return (
f"{self.__class__.__name__}"
f"""({
f", ".join(
f"{k}={v}"
for k, v in members.items()
)
})"""
)
S = TypeVar("S", bound=ASignal)
class Reduce(ASignal, Generic[S, T]):
def __init__(
self,
# FIXME: typing https://stackoverflow.com/a/67814270
fn: Callable[..., T],
*streams: S,
srate: Optional[float] = None,
stateful=True,
):
super().__init__(srate)
self._fn = fn
self.fn = fn.__name__
self.streams = []
for stream in streams:
if stateful:
self.streams.append(stream)
continue
stream_ = (
copy.deepcopy(stream)
if not isinstance(stream, types.GeneratorType)
else stream
)
stream_.srate = srate
self.streams.append(stream_)
@property
def srate(self):
return ASignal.srate.fget(self)
@srate.setter
def srate(self, val: float):
ASignal.srate.fset(self, val)
for stream in self.streams:
if isinstance(stream, ASignal):
stream.srate = val
def samples(self):
return (reduce(self._fn, args) for args in zip(*self.streams))
class Filter(ASignal, Generic[S]):
def __init__(
self,
reader: Optional[S] = None,
srate: Optional[float] = None,
):
super().__init__(srate)
self.reader: Optional[S] = reader
@property
def reader(self) -> S:
"""
The input stream this filter reads.
"""
if not self._reader:
raise IllegalStateException(f"{self.__class__}: `reader` is None.")
return self._reader
@reader.setter
def reader(self, val: S):
self._reader = val
if val is not None and self._srate is None:
self.srate = val._srate
@property
def srate(self):
return ASignal.srate.fget(self)
@srate.setter
def srate(self, val: float):
ASignal.srate.fset(self, val)
child = getattr(self, "_reader", None)
previous_srate = val
while child is not None:
# Since `srate` is optional at initialization, but required in
# general, we make our best attempt to normalize it for the
# filter pipeline, which should be consistent for most
# applications, by applying it to all children.
if child._srate is None:
child.srate = previous_srate
child: Optional[ASignal] = getattr(child, "_reader", None)
if isinstance(child, ASignal) and child._srate is not None:
previous_srate = child._srate
def samples(self) -> Iterator[T]:
"""The below is a default implementation, but this is meant to be
overrided.
"""
return self.reader.samples()
def __repr__(self):
return f"{self._reader} | {super().__repr__()}"
class FilterFunction(Filter, Generic[T, S]):
def __init__(
self,
fn: Callable[[S], Iterator[T]],
name: Optional[str] = None,
reader: Optional[S] = None,
srate: Optional[float] = None,
):
super().__init__(reader, srate)
self._fn = fn
self.Function = name if name else fn.__name__
def samples(self):
return self._fn(self.reader)
class SignalFunction(ASignal, Generic[T]):
def __init__(
self,
fn: Callable[[int], Iterator[T]],
name: Optional[str] = None,
srate: Optional[float] = None,
):
super().__init__(srate)
self._fn = fn
self.Function = name if name else fn.__name__
def samples(self) -> Iterator[T]:
return self._fn(self.srate)
class Constantly(ASignal, Generic[T]):
def __init__(self, constant: T, srate: float = 0.0):
super().__init__(srate)
self.constant = constant
def samples(self) -> Iterator[T]:
while True:
yield self.constant
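To close, a hedged usage sketch of the composition operators defined above: `+` mixes signals through `Reduce` (scalars become `Constantly`), and `|` chains filters (a plain callable becomes a `FilterFunction`). The import path and the ramp generator are assumptions for illustration only.
```
import datetime
import math

# Import path assumed for illustration; adjust to wherever this module lives.
from pojagi_dsp.signal import SignalFunction

def unit_ramp(srate):
    # A toy generator-style signal: 0, 1/srate, 2/srate, ...
    n = 0
    while True:
        yield n / srate
        n += 1

sig = SignalFunction(unit_ramp, srate=8.0) + 0.5             # mix in a DC offset
sig = sig | (lambda reader: (math.tanh(x) for x in reader))  # lambda -> FilterFunction
print(list(sig.of_duration(datetime.timedelta(seconds=1))))  # 8 samples at srate=8.0
```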