diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..38facaf --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": false +} \ No newline at end of file diff --git a/Makefile b/Makefile index 065dbf2..982c221 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,3 @@ -# Makefile for medtrace-synth - VENV := .venv PYTHON := $(VENV)/bin/python PIP := $(PYTHON) -m pip diff --git a/README.md b/README.md index c954277..a6e0d05 100644 --- a/README.md +++ b/README.md @@ -1 +1,22 @@ -# Back-end \ No newline at end of file +# Websockets server + +## Quick start + +This project uses [GNU make](https://www.gnu.org/software/make/) to build and run. When available, type `make` and hit enter to see what is available: + +``` +➜ make +Targets: + venv - Create virtualenv in .venv + install - Install deps and this package + run - Run the server via 'python -m medtrace_synth' + install-dev - Install deps (and this package) in dev mode + dev - Run using PYTHONPATH=src (no install) + build - Build sdist and wheel into dist/ + clean - Remove build artifacts + nuke - Clean artifacts and remove .venv +``` + +Try `make run` to download all dependencies and run the server. + +> Note that running `make dev` will start the server and watch the `src` directory, but you also will need to have the `pojagi-dsp` project locally, and the `POJAGI_DSP_PATH` environment variable exported to point to the top level of that project's directory. diff --git a/src/medtrace_synth/__main__.py b/src/medtrace_synth/__main__.py index c09027a..f3c3b36 100644 --- a/src/medtrace_synth/__main__.py +++ b/src/medtrace_synth/__main__.py @@ -1,14 +1,25 @@ import asyncio +import json import logging -import time +import random +import signal +from numbers import Number +from typing import Iterable import websockets -from pojagi_dsp.channel.ecg.generator.wavetable import ECGWaveTableSynthesizer +from pojagi_dsp.channel.ecg.generator.wavetable import ( + ECGWaveTableSynthesizer, +) from pojagi_dsp.channel.ecg.generator.wavetable.sinus import ( - SinusWaveTable, TachycardiaWaveTable) + FastTachycardiaWaveTable, + SinusWaveTable, + TachycardiaWaveTable, +) +from pojagi_dsp.channel.generator.sine import SineWave +from pojagi_dsp.channel.signal import Constantly, Filter +from websockets import Data, WebSocketServerProtocol -if __name__ != "__main__": - raise ImportWarning("This script is not intended to be imported.") +PORT = 7890 logging.basicConfig( level=logging.INFO, @@ -16,52 +27,160 @@ logging.basicConfig( ) log = logging.getLogger(__name__) -PORT = 7890 + +async def listen_for_messages( + websocket: websockets.WebSocketServerProtocol, + ecg: ECGWaveTableSynthesizer, +): + try: + async for message in websocket: + try: + packet: dict = json.loads(message) + if new_rate := packet.get("heartRate"): + try: + if 30 <= new_rate <= 300: + ecg.heart_rate = new_rate + log.info( + f"Heart rate updated to {new_rate}" + ) + else: + log.warning( + f"Invalid heart rate: {new_rate}" + ) + except ValueError: + log.warning( + f"Non-integer message received: {message}" + ) + except Exception as e: + log.warning(f"Uncaught exception: {e}") + except websockets.exceptions.ConnectionClosed: + log.info("Client disconnected (listener)") -async def consumer_handler(websocket: websockets.WebSocketServerProtocol): - async for message in websocket: - print(f"message received: {message}") +def randomize(g: Iterable[Number]): + return (x + ((random.uniform(-0.5, 0.5)) * 0.4) for x in g) -async def producer_handler(websocket: websockets.WebSocketServerProtocol): - srate = 50 +class Noise(Filter): + def __init__( + self, coefficient: float, reader=None, srate=None + ): + super().__init__(reader, srate) + self.coef = coefficient + def samples(self): + return ( + x + ((random.uniform(-0.5, 0.5)) * self.coef) + for x in self.reader + ) + + +class NoiseOscillator(SineWave): + def __init__( + self, + hz, + hz_variance=None, + amp=1.0, + amp_variance=None, + phase=0, + srate=None, + ): + super().__init__(hz, phase, srate) + self.base_hz = hz + self.hz_variance = hz_variance or 0.0 + self.amp = amp + self.amp_variance = amp_variance or 0.0 + self.base_amp = amp + + def randomize_frequency(self): + return ( + random.random() - 0.5 + ) * self.hz_variance + self.base_hz + + def randomize_amplitude(self): + return ( + random.random() - 0.5 + ) * self.amp_variance + self.base_amp + + def samples(self): + prev = self.hz + next = abs(self.randomize_frequency()) + inc = (next - prev) / self.srate + amp = self.randomize_amplitude() + + for x in super().samples(): + yield x * amp + self.hz += inc + + +async def handler( + websocket: WebSocketServerProtocol, + path: str, +) -> None: + log.info(f"New connection. Path: {path}") + srate = 100 + heart_rate = 90 ecg = ECGWaveTableSynthesizer( tables={ - (0, 90): SinusWaveTable(), - (70, 300): TachycardiaWaveTable(), + (0, 160): SinusWaveTable(), + (70, 290): TachycardiaWaveTable(), + (170, 300): FastTachycardiaWaveTable(), }, - heart_rate=60, + heart_rate=heart_rate, srate=srate, ) - while True: - try: - message = next(ecg) + output_signal = ( + Constantly(0, srate=srate) + + ecg + # + SineWave(hz=1) * 1 + # + SineWave(hz=0.4) * 1.15 + # + SineWave(hz=0.3) * 1.15 + # + NoiseOscillator( + # hz=2, + # hz_variance=20, + # amp=1.0, + # amp_variance=0.0, + # srate=srate, + # ) + ) + stream = output_signal.stream() + + listener_task = asyncio.create_task( + listen_for_messages(websocket, ecg) + ) + + await websocket.send(json.dumps({"srate": srate})) + + try: + while True: + message = next(stream) await websocket.send(str(message)) - await asyncio.sleep(1/srate) - except websockets.exceptions.ConnectionClosed as e: - print("A client just disconnected") - break + await asyncio.sleep(1 / srate) + except websockets.exceptions.ConnectionClosed: + log.info("Client disconnected (sender)") + finally: + listener_task.cancel() + log.info("Connection handler finished") -async def handler(websocket, path): - while True: - print(f"New connection. Path: {path}") - consumer_task = asyncio.create_task(consumer_handler(websocket)) - producer_task = asyncio.create_task(producer_handler(websocket)) - done, pending = await asyncio.wait( - [consumer_task, producer_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() +async def main() -> None: + stop = asyncio.Future() -# Start the server -start_server = websockets.serve(handler, "0.0.0.0", PORT) -asyncio.get_event_loop().run_until_complete(start_server) + def shutdown(): + log.info("Received exit signal, shutting down...") + stop.set_result(None) -try: - asyncio.get_event_loop().run_forever() -except KeyboardInterrupt: - log.info("exiting...") + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, shutdown) + + async with websockets.serve(handler, "0.0.0.0", PORT): + log.info(f"WebSocket server started on port {PORT}") + await stop + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except Exception as e: + log.error(f"Server error: {e}")