diff options
Diffstat (limited to 'synapse/handlers/receipts.py')
-rw-r--r-- | synapse/handlers/receipts.py | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e850e45e46..a9abdf42e0 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler from synapse.types import JsonDict, ReadReceipt, get_domain_from_id +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class ReceiptsHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.server_name = hs.config.server_name @@ -36,7 +39,7 @@ class ReceiptsHandler(BaseHandler): self.clock = self.hs.get_clock() self.state = hs.get_state_handler() - async def _received_remote_receipt(self, origin, content): + async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS. """ receipts = [] @@ -63,11 +66,11 @@ class ReceiptsHandler(BaseHandler): await self._handle_new_receipts(receipts) - async def _handle_new_receipts(self, receipts): + async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier. """ - min_batch_id = None - max_batch_id = None + min_batch_id = None # type: Optional[int] + max_batch_id = None # type: Optional[int] for receipt in receipts: res = await self.store.insert_receipt( @@ -89,7 +92,8 @@ class ReceiptsHandler(BaseHandler): if max_batch_id is None or max_persisted_id > max_batch_id: max_batch_id = max_persisted_id - if min_batch_id is None: + # Either both of these should be None or neither. + if min_batch_id is None or max_batch_id is None: # no new receipts return False @@ -103,7 +107,9 @@ class ReceiptsHandler(BaseHandler): return True - async def received_client_receipt(self, room_id, receipt_type, user_id, event_id): + async def received_client_receipt( + self, room_id: str, receipt_type: str, user_id: str, event_id: str + ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. """ @@ -123,10 +129,12 @@ class ReceiptsHandler(BaseHandler): class ReceiptEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: List[str], **kwargs + ) -> Tuple[List[JsonDict], int]: from_key = int(from_key) to_key = self.get_current_key() @@ -171,5 +179,5 @@ class ReceiptEventSource: return (events, to_key) - def get_current_key(self, direction="f"): + def get_current_key(self, direction: str = "f") -> int: return self.store.get_max_receipt_stream_id() |