Source code for ntfy_api.subscription

  1"""The :class:`NtfySubscription` class used for handling subscriptions.
  2
  3:copyright: (c) 2024 Tanner Corcoran
  4:license: Apache 2.0, see LICENSE for more details.
  5
  6"""
  7
  8import dataclasses
  9import json
 10import logging
 11import sys
 12import threading
 13from types import MappingProxyType, TracebackType
 14from typing import Literal, Union
 15
 16if sys.version_info >= (3, 11):  # pragma: no cover
 17    from typing import Self
 18else:  # pragma: no cover
 19    from typing_extensions import Self
 20
 21from websockets import exceptions as ws_exc
 22from websockets.sync import client as ws_client
 23
 24from .__version__ import *  # noqa: F401,F403
 25from ._internals import URL, ClearableQueue, StrTuple
 26from .creds import Credentials
 27from .filter import Filter
 28from .message import ReceivedMessage, _ReceivedMessage
 29
 30__all__ = ("NtfySubscription",)
 31logger = logging.Logger(__name__)
 32
 33
@dataclasses.dataclass(eq=False, frozen=True)
class NtfySubscription:
    """The class that handles subscriptions.

    :param base_url: The base URL of a ntfy server.
    :param topics: The topics to subscribe to.
    :param credentials: The user credentials, if any.
    :param filter: Optional response filters.
    :param max_queue_size: The maximum size of the message queue. If
        ``<=0``, the queue is unbounded. If the queue is filled, all new
        messages are discarded. Only when the queue has room for
        another message, will messages start being added again. This
        means that, if bounded, some messages may be dropped if the
        frequency of received messages is greater than your
        program's ability to handle those messages.

    """

    base_url: str
    """See the :paramref:`~NtfySubscription.base_url` parameter."""

    topics: StrTuple
    """See the :paramref:`~NtfySubscription.topics` parameter."""

    credentials: Union[Credentials, None] = None
    """See the :paramref:`~NtfySubscription.credentials` parameter."""

    filter: Union[Filter, None] = None
    """See the :paramref:`~NtfySubscription.filter` parameter."""

    max_queue_size: int = 0
    """See the :paramref:`~NtfySubscription.max_queue_size` parameter.

    """

    messages: ClearableQueue[ReceivedMessage] = dataclasses.field(init=False)
    """The message queue.

    This attribute stores received messages. See :class:`queue.Queue`
    for details on how to interact with this attribute.

    """

    # Parsed form of ``base_url``; set in __post_init__.
    _url: URL = dataclasses.field(init=False)
    # Authorization header derived from ``credentials``; set in __post_init__.
    _auth_header: MappingProxyType[str, str] = dataclasses.field(init=False)
    # Active websocket connection, or None when not connected.
    _ws_conn: Union[ws_client.ClientConnection, None] = dataclasses.field(
        default=None, init=False
    )
    # Background consumer thread started by connect().
    _thread: Union[threading.Thread, None] = dataclasses.field(
        default=None, init=False
    )

    def __post_init__(self) -> None:
        """Create message queue, and set URL and credentials."""
        # The dataclass is frozen, so derived fields must be assigned via
        # object.__setattr__.
        # message queue
        object.__setattr__(
            self, "messages", ClearableQueue(self.max_queue_size)
        )

        # url
        object.__setattr__(self, "_url", URL.parse(self.base_url))

        # credentials
        object.__setattr__(
            self,
            "_auth_header",
            (self.credentials or Credentials()).get_header(),
        )

    def __enter__(self) -> Self:
        """Enter the context manager protocol.

        Connects if there is no active connection yet.

        :returns: The `NtfySubscription` instance.
        :rtype: NtfySubscription

        """
        if not self._ws_conn:
            self.connect()
        return self

    def __exit__(
        self,
        exc_type: Union[type[BaseException], None],
        exc_val: Union[BaseException, None],
        exc_tb: Union[TracebackType, None],
    ) -> Literal[False]:
        """Exit the context manager protocol.

        This ensures the client is closed.

        :returns: Always :py:obj:`False`. See :meth:`object.__exit__`
            for more information on what this return value means.
        :rtype: typing.Literal[False]

        """
        self.close()
        return False

    def connect(
        self, connection: Union[ws_client.ClientConnection, None] = None
    ) -> Self:
        """Initiate the websocket connection.

        .. note::
            This also clears :attr:`~NtfySubscription.messages`. If a
            connection is already active, it is closed (and its consumer
            thread joined) before the new one is set up.

        :param connection: The websocket connection to use. If not
            provided, one will be created.
        :type connection: websockets.sync.client.ClientConnection |
            None, optional

        :returns: This :class:`NtfySubscription` instance.
        :rtype: NtfySubscription

        """
        # Reconnecting without tearing down the old connection would leave
        # the previous consumer thread running alongside the new one.
        if self._ws_conn is not None:
            self.close()

        if connection is None:
            connection = ws_client.connect(
                uri=self._url.unparse(
                    endpoint=(",".join(self.topics), "ws"),
                    scheme=("ws", "wss"),
                ),
                additional_headers={
                    **self._auth_header,
                    **(self.filter.serialize() if self.filter else {}),
                },
            )
        object.__setattr__(self, "_ws_conn", connection)
        self.messages.clear()

        # Hand the connection to the thread explicitly so it never has to
        # re-read self._ws_conn, which close() resets concurrently.
        thread = threading.Thread(target=self._thread_fn, args=(connection,))
        object.__setattr__(self, "_thread", thread)
        thread.start()

        return self

    def _thread_fn(
        self, conn: Union[ws_client.ClientConnection, None] = None
    ) -> None:
        """Receive messages and enqueue them until the connection closes.

        Invalid JSON or payloads that cannot be turned into a
        ``_ReceivedMessage`` are logged and skipped. When
        :attr:`messages` is full, new messages are dropped, as documented
        for ``max_queue_size``.

        :param conn: The connection to read from. Defaults to the
            connection active when the thread starts.

        """
        import queue  # only needed to recognize a full message queue

        if conn is None:
            conn = self._ws_conn
        if conn is None:
            return

        while True:
            try:
                raw = conn.recv()
            except ws_exc.ConnectionClosed:
                # Raised both on remote close and after close() is called.
                return

            try:
                data = json.loads(raw)
            except json.JSONDecodeError as e:
                logger.warning(
                    "Failed to process JSON input ('%s'): %r", e, raw
                )
                continue

            try:
                message = _ReceivedMessage.from_json(data)
            except (AttributeError, TypeError, ValueError) as e:
                logger.warning(
                    "Failed to instantiate _ReceivedMessage instance"
                    " ('%s'): %r",
                    e,
                    raw,
                )
                continue

            try:
                # Assumes ClearableQueue.put follows queue.Queue semantics
                # (raises queue.Full when non-blocking and full).
                self.messages.put(message, block=False)
            except queue.Full:
                # Documented behavior: discard new messages while full,
                # instead of letting the exception kill this thread.
                logger.debug("Message queue is full; dropping %r", raw)

    def close(self) -> None:
        """Close the websocket connection, if it exists.

        Also waits for the consumer thread to finish. Closing the
        connection makes its pending ``recv()`` raise, which ends the
        thread's loop.

        """
        conn = self._ws_conn
        if conn is not None:  # pragma: no branch
            conn.close()
            object.__setattr__(self, "_ws_conn", None)
        thread = self._thread
        if thread is not None and thread.is_alive():  # pragma: no branch
            thread.join()