Diffstat (limited to 'synapse/replication')
-rw-r--r--  synapse/replication/http/login.py                       |  8
-rw-r--r--  synapse/replication/slave/storage/_slaved_id_tracker.py | 22
-rw-r--r--  synapse/replication/slave/storage/push_rule.py          |  4
-rw-r--r--  synapse/replication/tcp/streams/events.py               |  6
4 files changed, 19 insertions, 21 deletions
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index daacc34cea..0db419ea57 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -46,8 +46,6 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest,
         is_appservice_ghost,
         should_issue_refresh_token,
-        auth_provider_id,
-        auth_provider_session_id,
     ):
         """
         Args:
@@ -65,8 +63,6 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "is_guest": is_guest,
             "is_appservice_ghost": is_appservice_ghost,
             "should_issue_refresh_token": should_issue_refresh_token,
-            "auth_provider_id": auth_provider_id,
-            "auth_provider_session_id": auth_provider_session_id,
         }
 
     async def _handle_request(self, request, user_id):
@@ -77,8 +73,6 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest = content["is_guest"]
         is_appservice_ghost = content["is_appservice_ghost"]
         should_issue_refresh_token = content["should_issue_refresh_token"]
-        auth_provider_id = content["auth_provider_id"]
-        auth_provider_session_id = content["auth_provider_session_id"]
 
         res = await self.registration_handler.register_device_inner(
             user_id,
@@ -87,8 +81,6 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
-            auth_provider_id=auth_provider_id,
-            auth_provider_session_id=auth_provider_session_id,
         )
 
         return 200, res
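
The change above has to touch both halves of the servlet: the method that builds the JSON payload and _handle_request, which reads it back, since the two sides must agree on the payload keys. A purely illustrative sketch of that invariant (plain functions, not the real ReplicationEndpoint API; field names taken from the diff):

# Illustrative only: the dict built on the sending side must use exactly the
# keys the worker-side handler reads back, so a field removed from one side
# must be removed from the other as well.
def serialize_payload(device_id, is_guest, should_issue_refresh_token):
    # sender side: each field goes into the JSON body under an agreed name
    return {
        "device_id": device_id,
        "is_guest": is_guest,
        "should_issue_refresh_token": should_issue_refresh_token,
    }

def handle_request(content):
    # receiver side: read back the same keys; a missing key raises KeyError
    return (
        content["device_id"],
        content["is_guest"],
        content["should_issue_refresh_token"],
    )
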
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index fa132d10b4..8c1bf9227a 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,18 +14,10 @@
 from typing import List, Optional, Tuple
 
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
+from synapse.storage.util.id_generators import _load_current_id
 
 
-class SlavedIdTracker(AbstractStreamIdTracker):
-    """Tracks the "current" stream ID of a stream with a single writer.
-
-    See `AbstractStreamIdTracker` for more details.
-
-    Note that this class does not work correctly when there are multiple
-    writers.
-    """
-
+class SlavedIdTracker:
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -44,7 +36,17 @@ class SlavedIdTracker(AbstractStreamIdTracker):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
+        """Returns the current token for this stream.
+
+        Returns:
+            int
+        """
         return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
         return self.get_current_token()
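
For context, a SlavedIdTracker is only ever moved forward by replication traffic and read back when the worker needs its current position. A minimal usage sketch follows; only advance, get_current_token and get_current_token_for_writer come from the code above, the surrounding names and the "master" instance name are assumptions, and construction from the database connection is elided:

# Sketch of how a worker drives the tracker; `tracker` would normally be a
# SlavedIdTracker built from the worker's database config (elided here).
def on_replication_row(tracker, instance_name: str, token: int) -> None:
    # advance() clamps with max()/min() depending on the stream direction,
    # so replaying an old token is a no-op
    tracker.advance(instance_name, token)

def current_position(tracker) -> int:
    # for a single-writer stream the per-writer position is just the
    # stream-wide current token
    assert tracker.get_current_token_for_writer("master") == tracker.get_current_token()
    return tracker.get_current_token()
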
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 7541e21de9..4d5f862862 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -24,6 +25,9 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
+        # We assert this for the benefit of mypy
+        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
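
The assert isinstance(...) added above is the standard way to narrow a type for mypy at runtime. A generic sketch of the pattern (the class names here are made up for illustration, not Synapse's):

from typing import Union

class ReadOnlyTracker:
    def get_current_token(self) -> int:
        return 0

class WritableTracker(ReadOnlyTracker):
    def advance(self, instance_name: str, new_id: int) -> None:
        pass

def process(gen: Union[ReadOnlyTracker, WritableTracker], token: int) -> None:
    # Without the assert, mypy only knows `gen` might be a ReadOnlyTracker,
    # which has no advance(); the assert narrows it to WritableTracker.
    assert isinstance(gen, WritableTracker)
    gen.advance("some_instance", token)
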
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a390cfcb74..a030e9299e 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
 import attr
 
@@ -157,7 +157,7 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows = await self._store.get_all_new_forward_event_rows(
+        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
             instance_name, from_token, current_token, target_row_count
         )
 
@@ -191,7 +191,7 @@ class EventsStream(Stream):
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
             instance_name, from_token, upper_limit
         )
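
The two annotations added above are needed because the storage methods are not yet typed, so their awaited results come back as Any; pinning the variables to List[Tuple] gives mypy something concrete to check when the row lists are merged later in this file. A generic sketch of the idea (the fetch function is a stand-in, not the real store method):

import heapq
from typing import Any, List, Tuple

async def fetch_rows() -> Any:  # stands in for an untyped storage call
    return [(1, "event"), (3, "event")]

async def combined_rows() -> List[Tuple]:
    # Annotating the variable gives mypy a concrete element type even though
    # the awaited call is typed as Any.
    event_rows: List[Tuple] = await fetch_rows()
    ex_outlier_rows: List[Tuple] = [(2, "ex_outlier")]
    # merge the two already-sorted lists by stream id (their first element)
    return list(heapq.merge(event_rows, ex_outlier_rows))
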