diff --git a/changelog.d/7219.misc b/changelog.d/7219.misc
index 4af5da8646..dbf7a530be 100644
--- a/changelog.d/7219.misc
+++ b/changelog.d/7219.misc
@@ -1 +1 @@
-Add typing information to federation server code.
+Add typing annotations in `synapse.federation`.
diff --git a/changelog.d/7281.misc b/changelog.d/7281.misc
new file mode 100644
index 0000000000..86ad511e19
--- /dev/null
+++ b/changelog.d/7281.misc
@@ -0,0 +1 @@
+Add MultiWriterIdGenerator to support multiple concurrent writers of streams.
diff --git a/changelog.d/7374.misc b/changelog.d/7374.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7374.misc
@@ -0,0 +1 @@
+Move the logic for catching up on replication streams to the worker.
diff --git a/changelog.d/7382.misc b/changelog.d/7382.misc
new file mode 100644
index 0000000000..dbf7a530be
--- /dev/null
+++ b/changelog.d/7382.misc
@@ -0,0 +1 @@
+Add typing annotations in `synapse.federation`.
diff --git a/changelog.d/7396.misc b/changelog.d/7396.misc
new file mode 100644
index 0000000000..290d2befc7
--- /dev/null
+++ b/changelog.d/7396.misc
@@ -0,0 +1 @@
+Convert the room handler to async/await.
diff --git a/changelog.d/7428.misc b/changelog.d/7428.misc
new file mode 100644
index 0000000000..db5ff76ded
--- /dev/null
+++ b/changelog.d/7428.misc
@@ -0,0 +1 @@
+Improve performance of `get_e2e_cross_signing_key`.
diff --git a/changelog.d/7429.misc b/changelog.d/7429.misc
new file mode 100644
index 0000000000..3c25cd9917
--- /dev/null
+++ b/changelog.d/7429.misc
@@ -0,0 +1 @@
+Improve performance of `mark_as_sent_devices_by_remote`.
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
+from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> list[user_id]
+ # Pending presence map user_id -> UserPresenceState
+ self.presence_map = {} # type: Dict[str, UserPresenceState]
+
+ # Stream position -> list[user_id]
+ self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = SortedDict()
+ self.presence_destinations = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, List[str]]]
+
+ # (destination, key) -> EDU
+ self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
- self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
+ # stream position -> (destination, key)
+ self.keyed_edu_changed = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, tuple]]
- self.edus = SortedDict() # stream position -> Edu
+ self.edus = SortedDict() # type: SortedDict[int, Edu]
+ # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
- self.pos_time = SortedDict()
+
+ # map from stream ID to the time that stream entry was generated, so that we
+ # can clear out entries after a while
+ self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
- to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
- for edu_key in to_del:
+ keys_to_del = [
+ edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+ ]
+ for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
- from_token (int)
- to_token(int)
- limit (int)
- federation_ack (int): Optional. The position where the worker is
- explicitly acknowledged it has handled. Allows us to drop
- data from before that point
+ instance_name: the name of the current process
+ from_token: the previous stream token: the starting point for fetching the
+ updates
+ to_token: the new stream token: the point to get updates up to
+ target_row_count: a target for the number of rows to be returned.
+
+ Returns: a triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of `(token, row)` entries.
+ * `new_last_token` is the new position in the stream.
+ * `limited` is whether there are more updates to fetch.
"""
- # TODO: Handle limit.
+ # TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = []
-
- # There should be only one reader, so lets delete everything its
- # acknowledged its seen.
- if federation_ack:
- self._clear_queue_before_pos(federation_ack)
+ rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
- return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+ return (
+ [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+ to_token,
+ False,
+ )
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-TypeToRow = {
- Row.TypeId: Row
- for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+ PresenceRow,
+ PresenceDestinationsRow,
+ KeyedEduRow,
+ EduRow,
+) # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
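The `get_replication_rows` change above aligns the send queue with the generic replication update-function contract: instead of a bare list of `(pos, type, data)` triples it now returns a `(updates, new_last_token, limited)` triplet, where each update is a `(token, (TypeId, data))` pair. A minimal sketch of how a consumer might page through that contract — the `fetch_all_updates` helper and the `stream` object are hypothetical, for illustration only:

```python
# Hypothetical consumer loop for the (updates, new_last_token, limited)
# contract introduced above.
async def fetch_all_updates(stream, instance_name, from_token, current_token):
    updates = []
    while from_token < current_token:
        rows, from_token, limited = await stream.get_replication_rows(
            instance_name, from_token, current_token, 100
        )
        updates.extend(rows)  # each row is a (token, (TypeId, data)) pair
        if not limited:
            break  # the source has no more rows up to current_token
    return updates
```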
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self) -> int:
+ @staticmethod
+ def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
+ @staticmethod
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
- return []
+ return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+if TYPE_CHECKING:
+ import synapse.server
+
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
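The `TYPE_CHECKING` guard introduced above is the standard idiom for breaking an import cycle: the `import synapse.server` line is only executed by static type checkers such as mypy, never at runtime, so the module can still be referenced in annotations via a string ("forward reference"). A generic sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by the type checker; never imported at runtime, which
    # avoids the circular import.
    import synapse.server


class PerDestinationQueue:
    # The annotation is a string, so it is not evaluated at runtime either.
    def __init__(self, hs: "synapse.server.HomeServer"):
        self._hs = hs
```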
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import TYPE_CHECKING, List
from canonicaljson import json
-import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def upgrade_room(
+ async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
- ret = yield self._upgrade_response_cache.wrap(
+ ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content)
- @defer.inlineCallbacks
- def _generate_room_id(
+ async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
- yield self.store.store_room(
+ await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
- users = yield self.store.get_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
- filtered = yield (filter_evts([event]))
+ filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- results = yield self.store.get_events_around(
+ results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
- results["events_before"] = yield filter_evts(results["events_before"])
- results["events_after"] = yield filter_evts(results["events_after"])
+ results["events_before"] = await filter_evts(results["events_before"])
+ results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.state_store.get_state_for_events(
+ state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
- results["state"] = yield filter_evts(state_events)
+ results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(
+ async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
- to_key = yield self.get_current_key()
+ to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- room_events = yield self.store.get_membership_changes_for_user(
+ room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
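All of the room-handler changes above are the same mechanical translation from Twisted's generator-based coroutines to native ones: drop `@defer.inlineCallbacks`, mark the function `async def`, and replace each `yield` with `await`. A before/after sketch with illustrative names (`store` and its method are stand-ins, not real Synapse APIs):

```python
from twisted.internet import defer

# Before: a Twisted generator-based coroutine, in the style being removed.
@defer.inlineCallbacks
def get_room_name_old(store, room_id):
    event = yield store.get_current_state_event(room_id, "m.room.name")
    return event.content.get("name") if event else None

# After: the native coroutine equivalent, as used throughout this patch.
async def get_room_name(store, room_id):
    event = await store.get_current_state_event(room_id, "m.room.name")
    return event.content.get("name") if event else None
```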
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..b690abedad 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
- # hase been disabled on the master.
+ # has been disabled on the master.
continue
self.streams.append(stream(hs))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..084604e8b0 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -104,7 +104,8 @@ class Stream(object):
implemented by subclasses.
current_token_function is called to get the current token of the underlying
- stream.
+ stream. It is only meaningful on the process that is the source of the
+ replication stream (i.e. usually the master).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..b0505b8a2c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import Stream, make_http_update_function
class FederationStream(Stream):
@@ -35,21 +35,33 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
- )
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: query the master process over HTTP
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+ # any updates, so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
+ def _stub_current_token():
+ # dummy current-token method for use on workers
+ return 0
+
+ @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
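Whichever branch the constructor takes, the chosen update function has to satisfy the same shape (cf. `UpdateFunction` in `synapse/replication/tcp/streams/_base.py`). The stub used on non-sender workers is just the degenerate case, as this sketch of the contract shows:

```python
from typing import Any, List, Tuple

# Every update function returns (updates, new_last_token, limited). The stub
# reports no rows and jumps straight to the requested token, so a POSITION
# from the master is effectively a no-op on such workers.
async def stub_update_function(
    instance_name: str, from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, Any]], int, bool]:
    return [], upto_token, False
```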
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 03f5141e6c..fe6d6ecfe0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
+ "drop_device_lists_outbound_last_success_non_unique_idx"
+)
+
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
+ # poked users.
sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ SELECT user_id, coalesce(max(o.stream_id), 0)
FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+ self.db.simple_upsert_many_txn(
+ txn=txn,
+ table="device_lists_outbound_last_success",
+ key_names=("destination", "user_id"),
+ key_values=((destination, user_id) for user_id, _ in rows),
+ value_names=("stream_id",),
+ value_values=((stream_id,) for _, stream_id in rows),
)
# Delete all sent outbound pokes
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
+ # create a unique index on device_lists_outbound_last_success
+ self.db.updates.register_background_index_update(
+ "device_lists_outbound_last_success_unique_idx",
+ index_name="device_lists_outbound_last_success_unique_idx",
+ table="device_lists_outbound_last_success",
+ columns=["destination", "user_id"],
+ unique=True,
+ )
+
+ # once that completes, we can remove the old non-unique index.
+ self.db.updates.register_background_update_handler(
+ BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
+ self._drop_device_lists_outbound_last_success_non_unique_idx,
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
+ async def _drop_device_lists_outbound_last_success_non_unique_idx(
+ self, progress, batch_size
+ ):
+ def f(txn):
+ txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
+
+ await self.db.runInteraction(
+ "drop_device_lists_outbound_last_success_non_unique_idx", f,
+ )
+ await self.db.updates._end_background_update(
+ BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
+ )
+ return 1
+
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
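Switching to `simple_upsert_many_txn` is what makes the new unique index necessary: the native Postgres upsert path relies on `ON CONFLICT`, which requires a unique constraint or index on the conflict columns. Until the background index build completes, the table stays in `UNIQUE_INDEX_BACKGROUND_UPDATES` (see the `database.py` hunk below), so the slower emulated path is used in the meantime. Roughly the per-row statement the native path is assumed to issue (via `executemany`; column order simplified):

```python
# Hedged sketch of the native-upsert statement behind the call above; the
# ON CONFLICT target is exactly the new unique index (destination, user_id).
NATIVE_UPSERT_SQL = """
    INSERT INTO device_lists_outbound_last_success
        (destination, user_id, stream_id)
    VALUES (?, ?, ?)
    ON CONFLICT (destination, user_id)
    DO UPDATE SET stream_id = EXCLUDED.stream_id
"""
```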
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7ef..20698bfd16 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
- """Returns a user's cross-signing key.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
- for a master key, 'self_signing' for a self-signing key, or
- 'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
- the key will be included in the result
-
- Returns:
- dict of the key data or None if not found
- """
- sql = (
- "SELECT keydata "
- " FROM e2e_cross_signing_keys "
- " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
- )
- txn.execute(sql, (user_id, key_type))
- row = txn.fetchone()
- if not row:
- return None
- key = json.loads(row[0])
-
- device_id = None
- for k in key["keys"].values():
- device_id = k
-
- if from_user_id is not None:
- sql = (
- "SELECT key_id, signature "
- " FROM e2e_cross_signing_signatures "
- " WHERE user_id = ? "
- " AND target_user_id = ? "
- " AND target_device_id = ? "
- )
- txn.execute(sql, (from_user_id, user_id, device_id))
- row = txn.fetchone()
- if row:
- key.setdefault("signatures", {}).setdefault(from_user_id, {})[
- row[0]
- ] = row[1]
-
- return key
-
+ @defer.inlineCallbacks
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
- return self.db.runInteraction(
- "get_e2e_cross_signing_key",
- self._get_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- from_user_id,
- )
+ res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ user_keys = res.get(user_id)
+ if not user_keys:
+ return None
+ return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
result = {}
- batch_size = 100
- chunks = [
- user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
- ]
- for user_chunk in chunks:
- sql = """
+ for user_chunk in batch_iter(user_ids, 100):
+ clause, params = make_in_list_sql_clause(
+ txn.database_engine, "k.user_id", user_chunk
+ )
+ sql = (
+ """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
- WHERE k.user_id IN (%s)
- """ % (
- ",".join("?" for u in user_chunk),
+ WHERE
+ """
+ + clause
)
- query_params = []
- query_params.extend(user_chunk)
- txn.execute(sql, query_params)
+ txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k
devices[(user_id, device_id)] = key_type
- device_list = list(devices)
-
- # split into batches
- batch_size = 100
- chunks = [
- device_list[i : i + batch_size]
- for i in range(0, len(device_list), batch_size)
- ]
- for user_chunk in chunks:
+ for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s)
""" % (
" OR ".join(
- "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ "(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
- for item in devices:
+ for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
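`batch_iter` replaces the hand-rolled list slicing, and `make_in_list_sql_clause` builds the `WHERE` fragment plus its parameters (a plain `user_id IN (?, ...)` on SQLite; on Postgres it can use the flatter `user_id = ANY(?)` form). A minimal re-implementation of the chunking helper — the real one lives in `synapse.util.iterutils` — to show what the loops above iterate over:

```python
from itertools import islice

def batch_iter(iterable, size):
    """Yield successive tuples of up to `size` items from any iterable."""
    sourceiter = iter(iterable)
    # iter(callable, sentinel) keeps calling the lambda until it returns ().
    return iter(lambda: tuple(islice(sourceiter, size)), ())

user_ids = ["@a:example.com", "@b:example.com", "@c:example.com"]
for chunk in batch_iter(user_ids, 2):
    print(chunk)  # ('@a:...', '@b:...') then ('@c:...',)
```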
diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
new file mode 100644
index 0000000000..d5e6deb878
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will create a unique index on
+-- device_lists_outbound_last_success
+INSERT into background_updates (ordering, update_name, progress_json)
+ VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
+
+-- once that completes, we can drop the old index.
+INSERT into background_updates (ordering, update_name, progress_json, depends_on)
+ VALUES (
+ 5804,
+ 'drop_device_lists_outbound_last_success_non_unique_idx',
+ '{}',
+ 'device_lists_outbound_last_success_unique_idx'
+ );
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a7cd97b0b0..2b635d6ca0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
"device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
"event_search": "event_search_event_id_idx",
+ "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
}
@@ -889,20 +891,24 @@ class Database(object):
txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names.
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
)
def simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Iterable[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names.
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
def simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times, using batching where possible.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names.
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames.extend(key_names)
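For reference, the emulated fallback simply decomposes the batch into one ordinary upsert per row. A simplified free-function sketch (the real method delegates to `simple_upsert_txn_emulated`; the `simple_upsert` parameter stands in for it here):

```python
def upsert_many_emulated(
    txn, table, key_names, key_values, value_names, value_values, simple_upsert
):
    # With no value columns, pair each key row with an empty value row so the
    # zip() below still yields one entry per key row.
    if not value_names:
        value_values = [() for _ in key_values]
    for keyv, valv in zip(key_values, value_values):
        keys = dict(zip(key_names, keyv))
        vals = dict(zip(value_names, valv))
        simple_upsert(txn, table, keys, vals)  # one lock-taking upsert per row
```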
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
+from typing import Dict, Optional, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+ """An ID generator that tracks a stream that can have multiple writers.
+
+ Uses a Postgres sequence to coordinate ID assignment, but positions of other
+ writers will only get updated when `advance` is called (by replication).
+
+ Note: Only works with Postgres.
+
+ Args:
+ db_conn: A database connection, used to load the current positions on startup.
+ db: The Database object, used to run queries once the generator is in use.
+ instance_name: The name of this instance.
+ table: Database table associated with the stream.
+ instance_column: Column that stores the instance name of the writer of each row.
+ id_column: Column that stores the stream ID.
+ sequence_name: The name of the postgres sequence used to generate new
+ IDs.
+ """
+
+ def __init__(
+ self,
+ db_conn,
+ db: Database,
+ instance_name: str,
+ table: str,
+ instance_column: str,
+ id_column: str,
+ sequence_name: str,
+ ):
+ self._db = db
+ self._instance_name = instance_name
+ self._sequence_name = sequence_name
+
+ # We lock as some functions may be called from DB threads.
+ self._lock = threading.Lock()
+
+ self._current_positions = self._load_current_ids(
+ db_conn, table, instance_column, id_column
+ )
+
+ # Set of local IDs that we're still processing. The current position
+ # should be less than the minimum of this set (if not empty).
+ self._unfinished_ids = set() # type: Set[int]
+
+ def _load_current_ids(
+ self, db_conn, table: str, instance_column: str, id_column: str
+ ) -> Dict[str, int]:
+ sql = """
+ SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ GROUP BY %(instance)s
+ """ % {
+ "instance": instance_column,
+ "id": id_column,
+ "table": table,
+ }
+
+ cur = db_conn.cursor()
+ cur.execute(sql)
+
+ # `cur` is an iterable over returned rows, which are 2-tuples.
+ current_positions = dict(cur)
+
+ cur.close()
+
+ return current_positions
+
+ def _load_next_id_txn(self, txn):
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ (next_id,) = txn.fetchone()
+ return next_id
+
+ async def get_next(self):
+ """
+ Usage:
+ with await stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+ # Assert the fetched ID is actually greater than what we currently
+ # believe the ID to be. If not, then the sequence and table have got
+ # out of sync somehow.
+ assert self.get_current_token() < next_id
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ self._mark_id_as_finished(next_id)
+
+ return manager()
+
+ def get_next_txn(self, txn: LoggingTransaction):
+ """
+ Usage:
+
+ stream_id = stream_id_gen.get_next_txn(txn)
+ # ... persist event ...
+ """
+
+ next_id = self._load_next_id_txn(txn)
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ txn.call_after(self._mark_id_as_finished, next_id)
+ txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+ return next_id
+
+ def _mark_id_as_finished(self, next_id: int):
+ """The ID has finished being processed so we should advance the
+ current poistion if possible.
+ """
+
+ with self._lock:
+ self._unfinished_ids.discard(next_id)
+
+ # Figure out if it's safe to advance the position by checking there
+ # aren't any lower allocated IDs that are yet to finish.
+ if all(c > next_id for c in self._unfinished_ids):
+ curr = self._current_positions.get(self._instance_name, 0)
+ self._current_positions[self._instance_name] = max(curr, next_id)
+
+ def get_current_token(self, instance_name: Optional[str] = None) -> int:
+ """Gets the current position of a named writer (defaults to the current
+ instance).
+
+ Returns 0 if we don't have a position for the named writer (likely due
+ to it being a new writer).
+ """
+
+ if instance_name is None:
+ instance_name = self._instance_name
+
+ with self._lock:
+ return self._current_positions.get(instance_name, 0)
+
+ def get_positions(self) -> Dict[str, int]:
+ """Get a copy of the current positon map.
+ """
+
+ with self._lock:
+ return dict(self._current_positions)
+
+ def advance(self, instance_name: str, new_id: int):
+ """Advance the postion of the named writer to the given ID, if greater
+ than existing entry.
+ """
+
+ with self._lock:
+ self._current_positions[instance_name] = max(
+ new_id, self._current_positions.get(instance_name, 0)
+ )
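Putting it together, a writer is expected to use the generator roughly like this (a hedged usage sketch; the `foobar` table matches the tests below and `db.simple_insert` is used loosely):

```python
# The sequence guarantees two writers never hand out the same ID; the
# context manager ensures this writer's public position only advances once
# every lower ID it allocated has finished persisting.
async def persist_row(db, id_gen, data):
    with await id_gen.get_next() as stream_id:
        # Until we leave this block, get_current_token() still reports the
        # old position, so readers never observe a gap in the stream.
        await db.simple_insert(
            table="foobar",
            values={"stream_id": stream_id, "instance_name": "master", "data": data},
        )
```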
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 7b56d2028d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ servlets = [
+ streams.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
- config = self.default_config()
- config["worker_app"] = "synapse.app.generic_worker"
- config["worker_replication_host"] = "testserv"
- config["worker_replication_http_port"] = "8765"
-
self.reactor.lookups["testserv"] = "1.2.3.4"
-
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
- config=config,
+ config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs)
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..eea4565da3
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication.tcp.streams._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index d25a7b194e..125c63dab5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,7 +15,6 @@
from mock import Mock
from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- servlets = [
- streams.register_servlets,
- ]
-
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
+ def _insert_rows(self, instance_name: str, number: int):
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID and check that we only see the position
+ # advance after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+ # Try allocating a new ID and check that we only see the position
+ # advance after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ # Try allocating a new ID and check that we only see the position
+ # advance after the transaction completes.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tox.ini b/tox.ini
index eccc44e436..ad4ed8299e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -181,11 +181,7 @@ commands = mypy \
synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \
- synapse/federation/federation_base.py \
- synapse/federation/federation_client.py \
- synapse/federation/federation_server.py \
- synapse/federation/sender \
- synapse/federation/transport \
+ synapse/federation \
synapse/handlers/auth.py \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
@@ -203,6 +199,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
+ synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
|