summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-03-29 11:43:20 -0400
committerGitHub <noreply@github.com>2021-03-29 11:43:20 -0400
commitda75d2ea1f2784791399dbeba16be401e2bb37d2 (patch)
tree1ccbaf8cec32418a1cdbaae8b9197f6dc2c0bea1 /synapse/federation
parentUpdate the OIDC sample config (#9695) (diff)
downloadsynapse-da75d2ea1f2784791399dbeba16be401e2bb37d2.tar.xz
Add type hints for the federation sender. (#9681)
Includes an abstract base class which both the FederationSender
and the FederationRemoteSendQueue must implement.
Diffstat (limited to '')
-rw-r--r--synapse/federation/send_queue.py88
-rw-r--r--synapse/federation/sender/__init__.py116
2 files changed, 160 insertions, 44 deletions
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 3e993b428b..0c18c49abb 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,25 +31,39 @@ Events are replicated via a separate events stream.
 
 import logging
 from collections import namedtuple
-from typing import Dict, List, Tuple, Type
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Hashable,
+    Iterable,
+    List,
+    Optional,
+    Sized,
+    Tuple,
+    Type,
+)
 
 from sortedcontainers import SortedDict
 
-from twisted.internet import defer
-
 from synapse.api.presence import UserPresenceState
+from synapse.federation.sender import AbstractFederationSender, FederationSender
 from synapse.metrics import LaterGauge
+from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
 from synapse.util.metrics import Measure
 
 from .units import Edu
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-class FederationRemoteSendQueue:
+class FederationRemoteSendQueue(AbstractFederationSender):
     """A drop in replacement for FederationSender"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
@@ -58,7 +72,7 @@ class FederationRemoteSendQueue:
         # We may have multiple federation sender instances, so we need to track
         # their positions separately.
         self._sender_instances = hs.config.worker.federation_shard_config.instances
-        self._sender_positions = {}
+        self._sender_positions = {}  # type: Dict[str, int]
 
         # Pending presence map user_id -> UserPresenceState
         self.presence_map = {}  # type: Dict[str, UserPresenceState]
@@ -71,7 +85,7 @@ class FederationRemoteSendQueue:
         # Stream position -> (user_id, destinations)
         self.presence_destinations = (
             SortedDict()
-        )  # type: SortedDict[int, Tuple[str, List[str]]]
+        )  # type: SortedDict[int, Tuple[str, Iterable[str]]]
 
         # (destination, key) -> EDU
         self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]
@@ -94,7 +108,7 @@ class FederationRemoteSendQueue:
         # we make a new function, so we need to make a new function so the inner
         # lambda binds to the queue rather than to the name of the queue which
         # changes. ARGH.
-        def register(name, queue):
+        def register(name: str, queue: Sized) -> None:
             LaterGauge(
                 "synapse_federation_send_queue_%s_size" % (queue_name,),
                 "",
@@ -115,13 +129,13 @@ class FederationRemoteSendQueue:
 
         self.clock.looping_call(self._clear_queue, 30 * 1000)
 
-    def _next_pos(self):
+    def _next_pos(self) -> int:
         pos = self.pos
         self.pos += 1
         self.pos_time[self.clock.time_msec()] = pos
         return pos
 
-    def _clear_queue(self):
+    def _clear_queue(self) -> None:
         """Clear the queues for anything older than N minutes"""
 
         FIVE_MINUTES_AGO = 5 * 60 * 1000
@@ -138,7 +152,7 @@ class FederationRemoteSendQueue:
 
         self._clear_queue_before_pos(position_to_delete)
 
-    def _clear_queue_before_pos(self, position_to_delete):
+    def _clear_queue_before_pos(self, position_to_delete: int) -> None:
         """Clear all the queues from before a given position"""
         with Measure(self.clock, "send_queue._clear"):
             # Delete things out of presence maps
@@ -188,13 +202,18 @@ class FederationRemoteSendQueue:
             for key in keys[:i]:
                 del self.edus[key]
 
-    def notify_new_events(self, max_token):
+    def notify_new_events(self, max_token: RoomStreamToken) -> None:
         """As per FederationSender"""
-        # We don't need to replicate this as it gets sent down a different
-        # stream.
-        pass
+        # This should never get called.
+        raise NotImplementedError()
 
-    def build_and_send_edu(self, destination, edu_type, content, key=None):
+    def build_and_send_edu(
+        self,
+        destination: str,
+        edu_type: str,
+        content: JsonDict,
+        key: Optional[Hashable] = None,
+    ) -> None:
         """As per FederationSender"""
         if destination == self.server_name:
             logger.info("Not sending EDU to ourselves")
@@ -218,38 +237,39 @@ class FederationRemoteSendQueue:
 
         self.notifier.on_new_replication_data()
 
-    def send_read_receipt(self, receipt):
+    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
         """As per FederationSender
 
         Args:
-            receipt (synapse.types.ReadReceipt):
+            receipt:
         """
         # nothing to do here: the replication listener will handle it.
-        return defer.succeed(None)
 
-    def send_presence(self, states):
+    def send_presence(self, states: List[UserPresenceState]) -> None:
         """As per FederationSender
 
         Args:
-            states (list(UserPresenceState))
+            states
         """
         pos = self._next_pos()
 
         # We only want to send presence for our own users, so lets always just
         # filter here just in case.
-        local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))
+        local_states = [s for s in states if self.is_mine_id(s.user_id)]
 
         self.presence_map.update({state.user_id: state for state in local_states})
         self.presence_changed[pos] = [state.user_id for state in local_states]
 
         self.notifier.on_new_replication_data()
 
-    def send_presence_to_destinations(self, states, destinations):
+    def send_presence_to_destinations(
+        self, states: Iterable[UserPresenceState], destinations: Iterable[str]
+    ) -> None:
         """As per FederationSender
 
         Args:
-            states (list[UserPresenceState])
-            destinations (list[str])
+            states
+            destinations
         """
         for state in states:
             pos = self._next_pos()
@@ -258,15 +278,18 @@ class FederationRemoteSendQueue:
 
         self.notifier.on_new_replication_data()
 
-    def send_device_messages(self, destination):
+    def send_device_messages(self, destination: str) -> None:
         """As per FederationSender"""
         # We don't need to replicate this as it gets sent down a different
         # stream.
 
-    def get_current_token(self):
+    def wake_destination(self, server: str) -> None:
+        pass
+
+    def get_current_token(self) -> int:
         return self.pos - 1
 
-    def federation_ack(self, instance_name, token):
+    def federation_ack(self, instance_name: str, token: int) -> None:
         if self._sender_instances:
             # If we have configured multiple federation sender instances we need
             # to track their positions separately, and only clear the queue up
@@ -504,13 +527,16 @@ ParsedFederationStreamData = namedtuple(
 )
 
 
-def process_rows_for_federation(transaction_queue, rows):
+def process_rows_for_federation(
+    transaction_queue: FederationSender,
+    rows: List[FederationStream.FederationStreamRow],
+) -> None:
     """Parse a list of rows from the federation stream and put them in the
     transaction queue ready for sending to the relevant homeservers.
 
     Args:
-        transaction_queue (FederationSender)
-        rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
+        transaction_queue
+        rows
     """
 
     # The federation stream contains a bunch of different types of
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 24ebc4b803..8babb1ebbe 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
-import synapse
 import synapse.metrics
 from synapse.api.presence import UserPresenceState
 from synapse.events import EventBase
@@ -40,9 +40,12 @@ from synapse.metrics import (
     events_processed_counter,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import ReadReceipt, RoomStreamToken
+from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
 from synapse.util.metrics import Measure, measure_func
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 sent_pdus_destination_dist_count = Counter(
@@ -65,8 +68,91 @@ CATCH_UP_STARTUP_DELAY_SEC = 15
 CATCH_UP_STARTUP_INTERVAL_SEC = 5
 
 
-class FederationSender:
-    def __init__(self, hs: "synapse.server.HomeServer"):
+class AbstractFederationSender(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def notify_new_events(self, max_token: RoomStreamToken) -> None:
+        """This gets called when we have some new events we might want to
+        send out to other servers.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
+        """Send a RR to any other servers in the room
+
+        Args:
+            receipt: receipt to be sent
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def send_presence(self, states: List[UserPresenceState]) -> None:
+        """Send the new presence states to the appropriate destinations.
+
+        This actually queues up the presence states ready for sending and
+        triggers a background task to process them and send out the transactions.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def send_presence_to_destinations(
+        self, states: Iterable[UserPresenceState], destinations: Iterable[str]
+    ) -> None:
+        """Send the given presence states to the given destinations.
+
+        Args:
+            destinations:
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def build_and_send_edu(
+        self,
+        destination: str,
+        edu_type: str,
+        content: JsonDict,
+        key: Optional[Hashable] = None,
+    ) -> None:
+        """Construct an Edu object, and queue it for sending
+
+        Args:
+            destination: name of server to send to
+            edu_type: type of EDU to send
+            content: content of EDU
+            key: clobbering key for this edu
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def send_device_messages(self, destination: str) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def wake_destination(self, destination: str) -> None:
+        """Called when we want to retry sending transactions to a remote.
+
+        This is mainly useful if the remote server has been down and we think it
+        might have come back.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_current_token(self) -> int:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def federation_ack(self, instance_name: str, token: int) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    async def get_replication_rows(
+        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
+        raise NotImplementedError()
+
+
+class FederationSender(AbstractFederationSender):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.server_name = hs.hostname
 
@@ -432,7 +518,7 @@ class FederationSender:
             queue.flush_read_receipts_for_room(room_id)
 
     @preserve_fn  # the caller should not yield on this
-    async def send_presence(self, states: List[UserPresenceState]):
+    async def send_presence(self, states: List[UserPresenceState]) -> None:
         """Send the new presence states to the appropriate destinations.
 
         This actually queues up the presence states ready for sending and
@@ -494,7 +580,7 @@ class FederationSender:
             self._get_per_destination_queue(destination).send_presence(states)
 
     @measure_func("txnqueue._process_presence")
-    async def _process_presence_inner(self, states: List[UserPresenceState]):
+    async def _process_presence_inner(self, states: List[UserPresenceState]) -> None:
         """Given a list of states populate self.pending_presence_by_dest and
         poke to send a new transaction to each destination
         """
@@ -516,9 +602,9 @@ class FederationSender:
         self,
         destination: str,
         edu_type: str,
-        content: dict,
+        content: JsonDict,
         key: Optional[Hashable] = None,
-    ):
+    ) -> None:
         """Construct an Edu object, and queue it for sending
 
         Args:
@@ -545,7 +631,7 @@ class FederationSender:
 
         self.send_edu(edu, key)
 
-    def send_edu(self, edu: Edu, key: Optional[Hashable]):
+    def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
         """Queue an EDU for sending
 
         Args:
@@ -563,7 +649,7 @@ class FederationSender:
         else:
             queue.send_edu(edu)
 
-    def send_device_messages(self, destination: str):
+    def send_device_messages(self, destination: str) -> None:
         if destination == self.server_name:
             logger.warning("Not sending device update to ourselves")
             return
@@ -575,7 +661,7 @@ class FederationSender:
 
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
-    def wake_destination(self, destination: str):
+    def wake_destination(self, destination: str) -> None:
         """Called when we want to retry sending transactions to a remote.
 
         This is mainly useful if the remote server has been down and we think it
@@ -599,6 +685,10 @@ class FederationSender:
         # to a worker.
         return 0
 
+    def federation_ack(self, instance_name: str, token: int) -> None:
+        # It is not expected that this gets called on FederationSender.
+        raise NotImplementedError()
+
     @staticmethod
     async def get_replication_rows(
         instance_name: str, from_token: int, to_token: int, target_row_count: int
@@ -607,7 +697,7 @@ class FederationSender:
         # to a worker.
         return [], 0, False
 
-    async def _wake_destinations_needing_catchup(self):
+    async def _wake_destinations_needing_catchup(self) -> None:
         """
         Wakes up destinations that need catch-up and are not currently being
         backed off from.