diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 0db419ea57..daacc34cea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest,
is_appservice_ghost,
should_issue_refresh_token,
+ auth_provider_id,
+ auth_provider_session_id,
):
"""
Args:
@@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
"should_issue_refresh_token": should_issue_refresh_token,
+ "auth_provider_id": auth_provider_id,
+ "auth_provider_session_id": auth_provider_session_id,
}
async def _handle_request(self, request, user_id):
@@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
should_issue_refresh_token = content["should_issue_refresh_token"]
+ auth_provider_id = content["auth_provider_id"]
+ auth_provider_session_id = content["auth_provider_session_id"]
res = await self.registration_handler.register_device_inner(
user_id,
@@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest,
is_appservice_ghost=is_appservice_ghost,
should_issue_refresh_token=should_issue_refresh_token,
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
)
return 200, res
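
The login.py hunks above thread two new fields, auth_provider_id and auth_provider_session_id, from the worker's serialized payload through to register_device_inner on the main process. Below is a minimal sketch of that round trip, assuming a plain dict payload and a stand-in FakeRegistrationHandler in place of Synapse's real registration handler and HTTP plumbing; all of those stand-in names are illustrative only.

# Minimal sketch of the payload round trip added above. Field names match the
# diff; FakeRegistrationHandler and the dict-based "request" are stand-ins
# for Synapse's real handler and HTTP machinery, used here only to illustrate
# how the two SSO-related fields survive the replication hop.
import asyncio
from typing import Any, Dict, Optional


class FakeRegistrationHandler:
    async def register_device_inner(
        self,
        user_id: str,
        device_id: Optional[str],
        initial_display_name: Optional[str],
        is_guest: bool,
        is_appservice_ghost: bool = False,
        should_issue_refresh_token: bool = False,
        auth_provider_id: Optional[str] = None,
        auth_provider_session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        # The real handler registers the device; here we just echo the
        # SSO-related fields to show they reached the callee.
        return {
            "device_id": device_id,
            "auth_provider_id": auth_provider_id,
            "auth_provider_session_id": auth_provider_session_id,
        }


def serialize_payload(**kwargs: Any) -> Dict[str, Any]:
    # Mirrors _serialize_payload: every argument becomes a JSON-able dict entry.
    return dict(kwargs)


async def handle_request(content: Dict[str, Any], user_id: str) -> Dict[str, Any]:
    handler = FakeRegistrationHandler()
    return await handler.register_device_inner(
        user_id,
        content["device_id"],
        content["initial_display_name"],
        content["is_guest"],
        is_appservice_ghost=content["is_appservice_ghost"],
        should_issue_refresh_token=content["should_issue_refresh_token"],
        auth_provider_id=content["auth_provider_id"],
        auth_provider_session_id=content["auth_provider_session_id"],
    )


payload = serialize_payload(
    device_id="ABCDEF",
    initial_display_name="my phone",
    is_guest=False,
    is_appservice_ghost=False,
    should_issue_refresh_token=True,
    auth_provider_id="oidc-example",
    auth_provider_session_id="session-token-123",
)
print(asyncio.run(handle_request(payload, "@alice:example.com")))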
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 8c1bf9227a..fa132d10b4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
from typing import List, Optional, Tuple
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+ """Tracks the "current" stream ID of a stream with a single writer.
+
+ See `AbstractStreamIdTracker` for more details.
+
+ Note that this class does not work correctly when there are multiple
+ writers.
+ """
+
def __init__(
self,
db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ class SlavedIdTracker:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:
- """
-
- Returns:
- int
- """
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
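
The code kept in this file is the whole single-writer contract: advance moves the tracked position forward (or backward for negative-step streams) and get_current_token_for_writer simply delegates to get_current_token. A self-contained sketch of that behaviour follows, assuming an in-memory starting position instead of the database-backed _load_current_id; the simplified constructor and class name are assumptions for illustration.

# Standalone sketch of the single-writer tracking semantics shown above, with
# an in-memory starting position instead of the database-backed
# _load_current_id. Method names mirror the diff.
class InMemorySlavedIdTracker:
    def __init__(self, current: int, step: int = 1) -> None:
        self.step = step
        self._current = current

    def advance(self, instance_name: str, new_id: int) -> None:
        # Same update rule as in the diff: move forward for positive-step
        # streams, backward for negative-step (backfill-style) streams.
        self._current = (max if self.step > 0 else min)(self._current, new_id)

    def get_current_token(self) -> int:
        return self._current

    def get_current_token_for_writer(self, instance_name: str) -> int:
        # With a single writer, every writer sees the same position.
        return self.get_current_token()


tracker = InMemorySlavedIdTracker(current=10)
tracker.advance("master", 15)
tracker.advance("master", 12)  # stale updates never move the token backwards
assert tracker.get_current_token() == 15
assert tracker.get_current_token_for_writer("master") == 15

backfill = InMemorySlavedIdTracker(current=-10, step=-1)
backfill.advance("master", -20)
assert backfill.get_current_token() == -20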
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 4d5f862862..7541e21de9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- # We assert this for the benefit of mypy
- assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
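
The isinstance assertion existed only to convince mypy that the stream ID generator exposed an advance method; with SlavedIdTracker now declared as an AbstractStreamIdTracker subclass, the attribute's declared type already carries that method and the assert is redundant. The sketch below illustrates the pattern with made-up StreamIdTracker and PushRuleCache stand-ins, not Synapse's real class hierarchy.

# Illustrative-only sketch of why the assert could go: once the attribute's
# declared type has `advance`, mypy accepts the call directly.
import abc


class StreamIdTracker(abc.ABC):
    @abc.abstractmethod
    def advance(self, instance_name: str, new_id: int) -> None:
        ...

    @abc.abstractmethod
    def get_current_token(self) -> int:
        ...


class SingleWriterTracker(StreamIdTracker):
    def __init__(self, current: int) -> None:
        self._current = current

    def advance(self, instance_name: str, new_id: int) -> None:
        self._current = max(self._current, new_id)

    def get_current_token(self) -> int:
        return self._current


class PushRuleCache:
    # Declaring the attribute with the abstract type is what lets callers use
    # `advance` without an isinstance assertion for mypy's benefit.
    _stream_id_gen: StreamIdTracker

    def __init__(self, tracker: StreamIdTracker) -> None:
        self._stream_id_gen = tracker

    def process_replication_row(self, instance_name: str, token: int) -> None:
        self._stream_id_gen.advance(instance_name, token)


cache = PushRuleCache(SingleWriterTracker(current=0))
cache.process_replication_row("master", 7)
assert cache._stream_id_gen.get_current_token() == 7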
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index a030e9299e..a390cfcb74 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
import attr
@@ -157,7 +157,7 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table
- event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+ event_rows = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
)
@@ -191,7 +191,7 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
- ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+ ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
)
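
The events.py hunks drop the blanket List[Tuple] annotations on the assumption that the storage methods now carry precise return annotations, so mypy can infer a more specific type at the call site. A small illustration of that clean-up follows; fetch_rows is a placeholder, not the real storage API.

# Sketch of the annotation clean-up: once the callee is annotated, the local
# variable's type is inferred, so `rows: List[Tuple] = ...` adds nothing and
# can even hide the more precise element type. `fetch_rows` is a placeholder.
from typing import List, Tuple


def fetch_rows(from_token: int, to_token: int) -> List[Tuple[int, str, bool]]:
    # Stand-in for a storage method that now has a precise return annotation.
    return [(i, f"$event{i}", False) for i in range(from_token, to_token)]


def build_updates(from_token: int, to_token: int) -> List[Tuple[int, str]]:
    # No local annotation needed: mypy infers List[Tuple[int, str, bool]] here,
    # which is more informative than the old blanket List[Tuple].
    event_rows = fetch_rows(from_token, to_token)
    return [(stream_id, event_id) for stream_id, event_id, _outlier in event_rows]


print(build_updates(1, 4))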