diff --git a/MANIFEST.in b/MANIFEST.in
index 156d6f04f7..120ce5b776 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -30,23 +30,24 @@ recursive-include synapse/static *.gif
recursive-include synapse/static *.html
recursive-include synapse/static *.js
-exclude Dockerfile
+exclude .codecov.yml
+exclude .coveragerc
exclude .dockerignore
-exclude test_postgresql.sh
exclude .editorconfig
+exclude Dockerfile
+exclude mypy.ini
exclude sytest-blacklist
+exclude test_postgresql.sh
include pyproject.toml
recursive-include changelog.d *
prune .buildkite
prune .circleci
-prune .codecov.yml
-prune .coveragerc
prune .github
+prune contrib
prune debian
prune demo/etc
prune docker
-prune mypy.ini
prune snap
prune stubs
diff --git a/changelog.d/7219.misc b/changelog.d/7219.misc
index 4af5da8646..dbf7a530be 100644
--- a/changelog.d/7219.misc
+++ b/changelog.d/7219.misc
@@ -1 +1 @@
-Add typing information to federation server code.
+Add typing annotations in `synapse.federation`.
diff --git a/changelog.d/7281.misc b/changelog.d/7281.misc
new file mode 100644
index 0000000000..86ad511e19
--- /dev/null
+++ b/changelog.d/7281.misc
@@ -0,0 +1 @@
+Add `MultiWriterIdGenerator` to support multiple concurrent writers of streams.
diff --git a/changelog.d/7374.misc b/changelog.d/7374.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7374.misc
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
diff --git a/changelog.d/7382.misc b/changelog.d/7382.misc
new file mode 100644
index 0000000000..dbf7a530be
--- /dev/null
+++ b/changelog.d/7382.misc
@@ -0,0 +1 @@
+Add typing annotations in `synapse.federation`.
diff --git a/changelog.d/7393.bugfix b/changelog.d/7393.bugfix
new file mode 100644
index 0000000000..74419af858
--- /dev/null
+++ b/changelog.d/7393.bugfix
@@ -0,0 +1 @@
+Fix bug in `EventContext.deserialize`.
diff --git a/changelog.d/7396.misc b/changelog.d/7396.misc
new file mode 100644
index 0000000000..290d2befc7
--- /dev/null
+++ b/changelog.d/7396.misc
@@ -0,0 +1 @@
+Convert the room handler to async/await.
diff --git a/changelog.d/7401.feature b/changelog.d/7401.feature
new file mode 100644
index 0000000000..ce6140fdd1
--- /dev/null
+++ b/changelog.d/7401.feature
@@ -0,0 +1 @@
+Add support for running replication over Redis when using workers.
diff --git a/changelog.d/7404.misc b/changelog.d/7404.misc
new file mode 100644
index 0000000000..9ac17958cc
--- /dev/null
+++ b/changelog.d/7404.misc
@@ -0,0 +1 @@
+Fix issues with the Python package manifest.
diff --git a/changelog.d/7408.misc b/changelog.d/7408.misc
new file mode 100644
index 0000000000..731f4dcb52
--- /dev/null
+++ b/changelog.d/7408.misc
@@ -0,0 +1 @@
+Clean up some LoggingContext code.
diff --git a/changelog.d/7420.misc b/changelog.d/7420.misc
new file mode 100644
index 0000000000..e834a9163e
--- /dev/null
+++ b/changelog.d/7420.misc
@@ -0,0 +1 @@
+Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request.
\ No newline at end of file
diff --git a/changelog.d/7421.misc b/changelog.d/7421.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7421.misc
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
diff --git a/changelog.d/7423.misc b/changelog.d/7423.misc
new file mode 100644
index 0000000000..eb1767ac13
--- /dev/null
+++ b/changelog.d/7423.misc
@@ -0,0 +1 @@
+Speed up fetching device lists changes when handling `/sync` requests.
diff --git a/changelog.d/7427.feature b/changelog.d/7427.feature
new file mode 100644
index 0000000000..ce6140fdd1
--- /dev/null
+++ b/changelog.d/7427.feature
@@ -0,0 +1 @@
+Add support for running replication over Redis when using workers.
diff --git a/changelog.d/7428.misc b/changelog.d/7428.misc
new file mode 100644
index 0000000000..db5ff76ded
--- /dev/null
+++ b/changelog.d/7428.misc
@@ -0,0 +1 @@
+Improve performance of `get_e2e_cross_signing_key`.
diff --git a/changelog.d/7429.misc b/changelog.d/7429.misc
new file mode 100644
index 0000000000..3c25cd9917
--- /dev/null
+++ b/changelog.d/7429.misc
@@ -0,0 +1 @@
+Improve performance of `mark_as_sent_devices_by_remote`.
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 763d3fb404..cac689d4f3 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -22,7 +22,10 @@ class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
class SubscriberProtocol:
+ password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
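+ # declared so that subclasses (e.g. RedisSubscriber) can call super().connectionMade()/connectionLost()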
+ def connectionMade(self): ...
+ def connectionLost(self, reason): ...
def lazyConnection(
host: str = ...,
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index c5d1eb952b..1ad5ff9410 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -26,16 +26,15 @@ from twisted.internet import defer
import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
-from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes
+from synapse.api.auth_blocking import AuthBlocking
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
InvalidClientTokenError,
MissingClientTokenError,
- ResourceLimitError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.config.server import is_threepid_reserved
from synapse.events import EventBase
from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
@@ -77,7 +76,11 @@ class Auth(object):
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache)
+ self._auth_blocking = AuthBlocking(self.hs)
+
self._account_validity = hs.config.account_validity
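+ # cache these config values here, rather than looking them up on hs.config for every call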
+ self._track_appservice_user_ips = hs.config.track_appservice_user_ips
+ self._macaroon_secret_key = hs.config.macaroon_secret_key
@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
@@ -191,7 +194,7 @@ class Auth(object):
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
- if ip_addr and self.hs.config.track_appservice_user_ips:
+ if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
@@ -454,7 +457,7 @@ class Auth(object):
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
- v.verify(macaroon, self.hs.config.macaroon_secret_key)
+ v.verify(macaroon, self._macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
@@ -663,71 +666,5 @@ class Auth(object):
% (user_id, room_id),
)
- @defer.inlineCallbacks
- def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
- """Checks if the user should be rejected for some external reason,
- such as monthly active user limiting or global disable flag
-
- Args:
- user_id(str|None): If present, checks for presence against existing
- MAU cohort
-
- threepid(dict|None): If present, checks for presence against configured
- reserved threepid. Used in cases where the user is trying register
- with a MAU blocked server, normally they would be rejected but their
- threepid is on the reserved list. user_id and
- threepid should never be set at the same time.
-
- user_type(str|None): If present, is used to decide whether to check against
- certain blocking reasons like MAU.
- """
-
- # Never fail an auth check for the server notices users or support user
- # This can be a problem where event creation is prohibited due to blocking
- if user_id is not None:
- if user_id == self.hs.config.server_notices_mxid:
- return
- if (yield self.store.is_support_user(user_id)):
- return
-
- if self.hs.config.hs_disabled:
- raise ResourceLimitError(
- 403,
- self.hs.config.hs_disabled_message,
- errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- admin_contact=self.hs.config.admin_contact,
- limit_type=LimitBlockingTypes.HS_DISABLED,
- )
- if self.hs.config.limit_usage_by_mau is True:
- assert not (user_id and threepid)
-
- # If the user is already part of the MAU cohort or a trial user
- if user_id:
- timestamp = yield self.store.user_last_seen_monthly_active(user_id)
- if timestamp:
- return
-
- is_trial = yield self.store.is_trial_user(user_id)
- if is_trial:
- return
- elif threepid:
- # If the user does not exist yet, but is signing up with a
- # reserved threepid then pass auth check
- if is_threepid_reserved(
- self.hs.config.mau_limits_reserved_threepids, threepid
- ):
- return
- elif user_type == UserTypes.SUPPORT:
- # If the user does not exist yet and is of type "support",
- # allow registration. Support users are excluded from MAU checks.
- return
- # Else if there is no room in the MAU bucket, bail
- current_mau = yield self.store.get_monthly_active_count()
- if current_mau >= self.hs.config.max_mau_value:
- raise ResourceLimitError(
- 403,
- "Monthly Active User Limit Exceeded",
- admin_contact=self.hs.config.admin_contact,
- errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
- )
+ def check_auth_blocking(self, *args, **kwargs):
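+ # the actual checks now live in AuthBlocking; this thin wrapper is kept for existing callers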
+ return self._auth_blocking.check_auth_blocking(*args, **kwargs)
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
new file mode 100644
index 0000000000..5c499b6b4e
--- /dev/null
+++ b/synapse/api/auth_blocking.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import LimitBlockingTypes, UserTypes
+from synapse.api.errors import Codes, ResourceLimitError
+from synapse.config.server import is_threepid_reserved
+
+logger = logging.getLogger(__name__)
+
+
+class AuthBlocking(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ self._server_notices_mxid = hs.config.server_notices_mxid
+ self._hs_disabled = hs.config.hs_disabled
+ self._hs_disabled_message = hs.config.hs_disabled_message
+ self._admin_contact = hs.config.admin_contact
+ self._max_mau_value = hs.config.max_mau_value
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+
+ @defer.inlineCallbacks
+ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+ """Checks if the user should be rejected for some external reason,
+ such as monthly active user limiting or global disable flag
+
+ Args:
+ user_id(str|None): If present, checks for presence against existing
+ MAU cohort
+
+ threepid(dict|None): If present, checks for presence against configured
+ reserved threepid. Used in cases where the user is trying to register
+ with a MAU blocked server; normally they would be rejected, but their
+ threepid is on the reserved list. user_id and
+ threepid should never be set at the same time.
+
+ user_type(str|None): If present, is used to decide whether to check against
+ certain blocking reasons like MAU.
+ """
+
+ # Never fail an auth check for the server notices user or a support user.
+ # This can otherwise be a problem when event creation is prohibited due to blocking.
+ if user_id is not None:
+ if user_id == self._server_notices_mxid:
+ return
+ if (yield self.store.is_support_user(user_id)):
+ return
+
+ if self._hs_disabled:
+ raise ResourceLimitError(
+ 403,
+ self._hs_disabled_message,
+ errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+ admin_contact=self._admin_contact,
+ limit_type=LimitBlockingTypes.HS_DISABLED,
+ )
+ if self._limit_usage_by_mau is True:
+ assert not (user_id and threepid)
+
+ # If the user is already part of the MAU cohort or a trial user
+ if user_id:
+ timestamp = yield self.store.user_last_seen_monthly_active(user_id)
+ if timestamp:
+ return
+
+ is_trial = yield self.store.is_trial_user(user_id)
+ if is_trial:
+ return
+ elif threepid:
+ # If the user does not exist yet, but is signing up with a
+ # reserved threepid, then pass the auth check
+ if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid):
+ return
+ elif user_type == UserTypes.SUPPORT:
+ # If the user does not exist yet and is of type "support",
+ # allow registration. Support users are excluded from MAU checks.
+ return
+ # Else if there is no room in the MAU bucket, bail
+ current_mau = yield self.store.get_monthly_active_count()
+ if current_mau >= self._max_mau_value:
+ raise ResourceLimitError(
+ 403,
+ "Monthly Active User Limit Exceeded",
+ admin_contact=self._admin_contact,
+ errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+ limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
+ )
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 9ea85e93e6..7c5f620d09 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext):
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self.state_group
)
- if self._prev_state_id and self._event_state_key is not None:
+ if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
key = (self._event_type, self._event_state_key)
- self._prev_state_ids[key] = self._prev_state_id
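+ # this is a state event: prev_state_ids should hold the ID of the state entry
+ # this event replaced, or omit the key entirely if there was none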
+ if self._prev_state_id:
+ self._prev_state_ids[key] = self._prev_state_id
+ else:
+ self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = self._current_state_ids
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
+from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
- self.presence_map = {} # Pending presence map user_id -> UserPresenceState
- self.presence_changed = SortedDict() # Stream position -> list[user_id]
+ # Pending presence map user_id -> UserPresenceState
+ self.presence_map = {} # type: Dict[str, UserPresenceState]
+
+ # Stream position -> list[user_id]
+ self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
- self.presence_destinations = SortedDict()
+ self.presence_destinations = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, List[str]]]
+
+ # (destination, key) -> EDU
+ self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
- self.keyed_edu = {} # (destination, key) -> EDU
- self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
+ # stream position -> (destination, key)
+ self.keyed_edu_changed = (
+ SortedDict()
+ ) # type: SortedDict[int, Tuple[str, tuple]]
- self.edus = SortedDict() # stream position -> Edu
+ self.edus = SortedDict() # type: SortedDict[int, Edu]
+ # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
- self.pos_time = SortedDict()
+
+ # map from stream ID to the time that stream entry was generated, so that we
+ # can clear out entries after a while
+ self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
- to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
- for edu_key in to_del:
+ keys_to_del = [
+ edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+ ]
+ for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
- from_token (int)
- to_token(int)
- limit (int)
- federation_ack (int): Optional. The position where the worker is
- explicitly acknowledged it has handled. Allows us to drop
- data from before that point
+ instance_name: the name of the current process
+ from_token: the previous stream token: the starting point for fetching the
+ updates
+ to_token: the new stream token: the point to get updates up to
+ target_row_count: a target for the number of rows to be returned.
+
+ Returns: a triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of `(token, row)` entries.
+ * `new_last_token` is the new position in stream.
+ * `limited` is whether there are more updates to fetch.
"""
- # TODO: Handle limit.
+ # TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
- rows = []
-
- # There should be only one reader, so lets delete everything its
- # acknowledged its seen.
- if federation_ack:
- self._clear_queue_before_pos(federation_ack)
+ rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
- return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+ return (
+ [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+ to_token,
+ False,
+ )
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-TypeToRow = {
- Row.TypeId: Row
- for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
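+# the row types, pulled out into a variable so that it can be given an explicit type annotation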
+_rowtypes = (
+ PresenceRow,
+ PresenceDestinationsRow,
+ KeyedEduRow,
+ EduRow,
+) # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self) -> int:
+ @staticmethod
+ def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
+ @staticmethod
async def get_replication_rows(
- self, from_token, to_token, limit, federation_ack=None
- ):
+ instance_name: str, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
- return []
+ return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+if TYPE_CHECKING:
+ import synapse.server
+
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import TYPE_CHECKING, List
from canonicaljson import json
-import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def upgrade_room(
+ async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
- ret = yield self._upgrade_response_cache.wrap(
+ ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content)
- @defer.inlineCallbacks
- def _generate_room_id(
+ async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
- yield self.store.store_room(
+ await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
- users = yield self.store.get_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
- filtered = yield (filter_evts([event]))
+ filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- results = yield self.store.get_events_around(
+ results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
- results["events_before"] = yield filter_evts(results["events_before"])
- results["events_after"] = yield filter_evts(results["events_after"])
+ results["events_before"] = await filter_evts(results["events_before"])
+ results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.state_store.get_state_for_events(
+ state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
- results["state"] = yield filter_evts(state_events)
+ results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(
+ async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
- to_key = yield self.get_current_key()
+ to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- room_events = yield self.store.get_membership_changes_for_user(
+ room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4f76b7a743..00718d7f2d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1143,10 +1143,14 @@ class SyncHandler(object):
user_id
)
- tracked_users = set(users_who_share_room)
-
- # Always tell the user about their own devices
- tracked_users.add(user_id)
+ # Always tell the user about their own devices. We check as the user
+ # ID is almost certainly already included (unless they're not in any
+ # rooms) and taking a copy of the set is relatively expensive.
+ if user_id not in users_who_share_room:
+ users_who_share_room = set(users_who_share_room)
+ users_who_share_room.add(user_id)
+
+ tracked_users = users_who_share_room
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index a8f674d13d..856534e91a 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -27,6 +27,7 @@ import inspect
import logging
import threading
import types
+import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal
@@ -287,6 +288,46 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
+ @classmethod
+ def current_context(cls) -> LoggingContextOrSentinel:
+ """Get the current logging context from thread local storage
+
+ This exists for backwards compatibility. ``current_context()`` should be
+ called directly.
+
+ Returns:
+ LoggingContext: the current logging context
+ """
+ warnings.warn(
+ "synapse.logging.context.LoggingContext.current_context() is deprecated "
+ "in favor of synapse.logging.context.current_context().",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return current_context()
+
+ @classmethod
+ def set_current_context(
+ cls, context: LoggingContextOrSentinel
+ ) -> LoggingContextOrSentinel:
+ """Set the current logging context in thread local storage
+
+ This exists for backwards compatibility. ``set_current_context()`` should be
+ called directly.
+
+ Args:
+ context(LoggingContext): The context to activate.
+ Returns:
+ The context that was previously active
+ """
+ warnings.warn(
+ "synapse.logging.context.LoggingContext.set_current_context() is deprecated "
+ "in favor of synapse.logging.context.set_current_context().",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return set_current_context(context)
+
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = set_current_context(self)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 2d1d119c7c..b14a3d9fca 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -81,9 +81,6 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()
- # Set of streams that we've caught up with.
- self._streams_connected = set() # type: Set[str]
-
self._streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
@@ -99,9 +96,13 @@ class ReplicationCommandHandler:
# The factory used to create connections.
self._factory = None # type: Optional[ReconnectingClientFactory]
- # The currently connected connections.
+ # The currently connected connections. (The list of places we need to send
+ # outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
+ # For each connection, the incoming streams that are coming from that connection
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -257,12 +258,14 @@ class ReplicationCommandHandler:
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
- if stream_name not in self._streams_connected:
- # If the stream isn't marked as connected then we haven't seen a
- # `POSITION` command yet, and so we may have missed some rows.
+ # make sure that we've processed a POSITION for this stream *on this
+ # connection*. (A POSITION on another connection is no good, as there
+ # is no guarantee that we have seen all the intermediate updates.)
+ sbc = self._streams_by_connection.get(conn)
+ if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
- logger.warning(
+ logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
@@ -302,21 +305,25 @@ class ReplicationCommandHandler:
# Ignore POSITION that are just our own echoes
return
- stream = self._streams.get(cmd.stream_name)
+ logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
+
+ stream_name = cmd.stream_name
+ stream = self._streams.get(stream_name)
if not stream:
- logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+ logger.error("Got POSITION for unknown stream: %s", stream_name)
return
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
- with await self._position_linearizer.queue(cmd.stream_name):
+ with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
- self._streams_connected.discard(cmd.stream_name)
+ for streams in self._streams_by_connection.values():
+ streams.discard(stream_name)
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
- self._pending_batches.pop(cmd.stream_name, [])
+ self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
current_token = stream.current_token()
@@ -326,6 +333,12 @@ class ReplicationCommandHandler:
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
+ logger.info(
+ "Fetching replication rows for '%s' between %i and %i",
+ stream_name,
+ current_token,
+ cmd.token,
+ )
(
updates,
current_token,
@@ -341,16 +354,18 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
- cmd.stream_name,
+ stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
# We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
+ await self._replication_data_handler.on_position(stream_name, cmd.token)
- self._streams_connected.add(cmd.stream_name)
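+ # record that this stream is now up to date on this connection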
+ self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
@@ -408,6 +423,12 @@ class ReplicationCommandHandler:
def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost.
"""
+ # we no longer need _streams_by_connection for this connection.
+ streams = self._streams_by_connection.pop(connection, None)
+ if streams:
+ logger.info(
+ "Lost replication connection; streams now disconnected: %s", streams
+ )
try:
self._connections.remove(connection)
except ValueError:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 617e860f95..db69f92557 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
import txredisapi
-from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
Command,
@@ -41,8 +41,14 @@ logger = logging.getLogger(__name__)
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
"""Connection to redis subscribed to replication stream.
- Parses incoming messages from redis into replication commands, and passes
- them to `ReplicationCommandHandler`
+ This class fulfils two functions:
+
+ (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
+ connection, parsing *incoming* messages into replication commands, and passing them
+ to `ReplicationCommandHandler`
+
+ (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
constructor, so instead we expect the defined attributes below to be set
@@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Attributes:
handler: The command handler to handle incoming commands.
- stream_name: The *redis* stream name to subscribe to (not anything to
- do with Synapse replication streams).
+ stream_name: The *redis* stream name to subscribe to and publish from
+ (not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
commands.
"""
@@ -61,12 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
outbound_redis_connection = None # type: txredisapi.RedisProtocol
def connectionMade(self):
- logger.info("Connected to redis instance")
- self.subscribe(self.stream_name)
- self.send_command(ReplicateCommand())
-
+ logger.info("Connected to redis")
+ super().connectionMade()
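+ # connectionMade is a synchronous callback, so run the async subscription logic as a background process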
+ run_as_background_process("subscribe-replication", self._send_subscribe)
self.handler.new_connection(self)
+ async def _send_subscribe(self):
+ # it's important to make sure that we only send the REPLICATE command once we
+ # have successfully subscribed to the stream - otherwise we might miss the
+ # POSITION response sent back by the other end.
+ logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
+ await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info(
+ "Successfully subscribed to redis stream, sending REPLICATE command"
+ )
+ await self._async_send_command(ReplicateCommand())
+ logger.info("REPLICATE successfully sent")
+
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
@@ -119,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason):
- logger.info("Lost connection to redis instance")
+ logger.info("Lost connection to redis")
+ super().connectionLost(reason)
self.handler.lost_connection(self)
def send_command(self, cmd: Command):
@@ -128,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Args:
cmd (Command)
"""
+ run_as_background_process("send-cmd", self._async_send_command, cmd)
+
+ async def _async_send_command(self, cmd: Command):
+ """Encode a replication command and send it over our outbound connection"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -138,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
- async def _send():
- with PreserveLoggingContext():
- # Note that we use the other connection as we can't send
- # commands using the subscription connection.
- await self.outbound_redis_connection.publish(
- self.stream_name, encoded_string
- )
-
- run_as_background_process("send-cmd", _send)
+ await make_deferred_yieldable(
+ self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ )
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
@@ -189,5 +205,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
+ p.password = self.password
return p
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..b690abedad 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
- # hase been disabled on the master.
+ # has been disabled on the master.
continue
self.streams.append(stream(hs))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..084604e8b0 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -104,7 +104,8 @@ class Stream(object):
implemented by subclasses.
current_token_function is called to get the current token of the underlying
- stream.
+ stream. It is only meaningful on the process that is the source of the
+ replication stream (ie, usually the master).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..b0505b8a2c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import namedtuple
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import Stream, make_http_update_function
class FederationStream(Stream):
@@ -35,21 +35,33 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- # Not all synapse instances will have a federation sender instance,
- # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
- # so we stub the stream out when that is the case.
- if hs.config.worker_app is None or hs.should_send_federation():
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
- update_function = db_query_to_update_function(
- federation_sender.get_replication_rows
- )
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: Query master process
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
else:
- current_token = lambda: 0
+ # other worker: stub out the update function (we're not interested in
+ # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
+ current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
+ def _stub_current_token():
+ # dummy current-token method for use on workers
+ return 0
+
+ @staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index ee3a2ab031..fe6d6ecfe0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
+ "drop_device_lists_outbound_last_success_non_unique_idx"
+)
+
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
+ # poked users.
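+ # (done as an upsert, relying on the unique (destination, user_id) index added by the
+ # background update elsewhere in this change)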
sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ SELECT user_id, coalesce(max(o.stream_id), 0)
FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+ self.db.simple_upsert_many_txn(
+ txn=txn,
+ table="device_lists_outbound_last_success",
+ key_names=("destination", "user_id"),
+ key_values=((destination, user_id) for user_id, _ in rows),
+ value_names=("stream_id",),
+ value_values=((stream_id,) for _, stream_id in rows),
)
# Delete all sent outbound pokes
@@ -541,8 +536,8 @@ class DeviceWorkerStore(SQLBaseStore):
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = list(
- self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
)
if not to_check:
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
+ # create a unique index on device_lists_outbound_last_success
+ self.db.updates.register_background_index_update(
+ "device_lists_outbound_last_success_unique_idx",
+ index_name="device_lists_outbound_last_success_unique_idx",
+ table="device_lists_outbound_last_success",
+ columns=["destination", "user_id"],
+ unique=True,
+ )
+
+ # once that completes, we can remove the old non-unique index.
+ self.db.updates.register_background_update_handler(
+ BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
+ self._drop_device_lists_outbound_last_success_non_unique_idx,
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
+ async def _drop_device_lists_outbound_last_success_non_unique_idx(
+ self, progress, batch_size
+ ):
+ def f(txn):
+ txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
+
+ await self.db.runInteraction(
+ "drop_device_lists_outbound_last_success_non_unique_idx", f,
+ )
+ await self.db.updates._end_background_update(
+ BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
+ )
+ return 1
+
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7ef..20698bfd16 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
- """Returns a user's cross-signing key.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
- for a master key, 'self_signing' for a self-signing key, or
- 'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
- the key will be included in the result
-
- Returns:
- dict of the key data or None if not found
- """
- sql = (
- "SELECT keydata "
- " FROM e2e_cross_signing_keys "
- " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
- )
- txn.execute(sql, (user_id, key_type))
- row = txn.fetchone()
- if not row:
- return None
- key = json.loads(row[0])
-
- device_id = None
- for k in key["keys"].values():
- device_id = k
-
- if from_user_id is not None:
- sql = (
- "SELECT key_id, signature "
- " FROM e2e_cross_signing_signatures "
- " WHERE user_id = ? "
- " AND target_user_id = ? "
- " AND target_device_id = ? "
- )
- txn.execute(sql, (from_user_id, user_id, device_id))
- row = txn.fetchone()
- if row:
- key.setdefault("signatures", {}).setdefault(from_user_id, {})[
- row[0]
- ] = row[1]
-
- return key
-
+ @defer.inlineCallbacks
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
- return self.db.runInteraction(
- "get_e2e_cross_signing_key",
- self._get_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- from_user_id,
- )
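+ # delegate to the bulk lookup, which can make use of its cache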
+ res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ user_keys = res.get(user_id)
+ if not user_keys:
+ return None
+ return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
result = {}
- batch_size = 100
- chunks = [
- user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
- ]
- for user_chunk in chunks:
- sql = """
+ for user_chunk in batch_iter(user_ids, 100):
+ clause, params = make_in_list_sql_clause(
+ txn.database_engine, "k.user_id", user_chunk
+ )
+ sql = (
+ """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
- WHERE k.user_id IN (%s)
- """ % (
- ",".join("?" for u in user_chunk),
+ WHERE
+ """
+ + clause
)
- query_params = []
- query_params.extend(user_chunk)
- txn.execute(sql, query_params)
+ txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k
devices[(user_id, device_id)] = key_type
- device_list = list(devices)
-
- # split into batches
- batch_size = 100
- chunks = [
- device_list[i : i + batch_size]
- for i in range(0, len(device_list), batch_size)
- ]
- for user_chunk in chunks:
+ for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s)
""" % (
" OR ".join(
- "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ "(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
- for item in devices:
+ for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
new file mode 100644
index 0000000000..d5e6deb878
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will create a unique index on
+-- device_lists_outbound_last_success
+INSERT into background_updates (ordering, update_name, progress_json)
+ VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
+
+-- once that completes, we can drop the old index.
+INSERT into background_updates (ordering, update_name, progress_json, depends_on)
+ VALUES (
+ 5804,
+ 'drop_device_lists_outbound_last_success_non_unique_idx',
+ '{}',
+ 'device_lists_outbound_last_success_unique_idx'
+ );
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a7cd97b0b0..2b635d6ca0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
"device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
"event_search": "event_search_event_id_idx",
+ "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
}
@@ -889,20 +891,24 @@ class Database(object):
txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
)
def simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Iterable[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
def simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times, using batching where possible.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames.extend(key_names)
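For context on the shapes the new annotations describe: `key_names`/`value_names` list the column names once, while `key_values`/`value_values` supply one inner sequence per row, aligned positionally with those names. A minimal sketch of a caller, written as if from inside a store's transaction callback; the table and column names here are invented for illustration and are not part of this change:

    # Sketch only: hypothetical table/columns; `self.db` and `txn` come from the
    # surrounding data-store transaction callback.
    self.db.simple_upsert_many_txn(
        txn,
        table="example_last_success",
        key_names=("destination", "user_id"),
        key_values=[
            ("a.example.com", "@alice:test"),
            ("b.example.com", "@bob:test"),
        ],
        value_names=("last_success_ts",),
        value_values=[(1000,), (2000,)],
    )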
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
+from typing import Dict, Optional, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+ """An ID generator that tracks a stream that can have multiple writers.
+
+ Uses a Postgres sequence to coordinate ID assignment, but positions of other
+ writers will only get updated when `advance` is called (by replication).
+
+ Note: Only works with Postgres.
+
+ Args:
+ db_conn
+ db
+ instance_name: The name of this instance.
+ table: Database table associated with stream.
+ instance_column: Column that stores the row's writer's instance name
+ id_column: Column that stores the stream ID.
+ sequence_name: The name of the postgres sequence used to generate new
+ IDs.
+ """
+
+ def __init__(
+ self,
+ db_conn,
+ db: Database,
+ instance_name: str,
+ table: str,
+ instance_column: str,
+ id_column: str,
+ sequence_name: str,
+ ):
+ self._db = db
+ self._instance_name = instance_name
+ self._sequence_name = sequence_name
+
+ # We lock as some functions may be called from DB threads.
+ self._lock = threading.Lock()
+
+ self._current_positions = self._load_current_ids(
+ db_conn, table, instance_column, id_column
+ )
+
+ # Set of local IDs that we're still processing. The current position
+ # should be less than the minimum of this set (if not empty).
+ self._unfinished_ids = set() # type: Set[int]
+
+ def _load_current_ids(
+ self, db_conn, table: str, instance_column: str, id_column: str
+ ) -> Dict[str, int]:
+ sql = """
+ SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ GROUP BY %(instance)s
+ """ % {
+ "instance": instance_column,
+ "id": id_column,
+ "table": table,
+ }
+
+ cur = db_conn.cursor()
+ cur.execute(sql)
+
+ # `cur` is an iterable over returned rows, which are 2-tuples.
+ current_positions = dict(cur)
+
+ cur.close()
+
+ return current_positions
+
+ def _load_next_id_txn(self, txn):
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ (next_id,) = txn.fetchone()
+ return next_id
+
+ async def get_next(self):
+ """
+ Usage:
+ with await stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+ # Assert the fetched ID is actually greater than what we currently
+ # believe the ID to be. If not, then the sequence and table have got
+ # out of sync somehow.
+ assert self.get_current_token() < next_id
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ self._mark_id_as_finished(next_id)
+
+ return manager()
+
+ def get_next_txn(self, txn: LoggingTransaction):
+ """
+ Usage:
+
+            stream_id = stream_id_gen.get_next_txn(txn)
+ # ... persist event ...
+ """
+
+ next_id = self._load_next_id_txn(txn)
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ txn.call_after(self._mark_id_as_finished, next_id)
+ txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+ return next_id
+
+ def _mark_id_as_finished(self, next_id: int):
+ """The ID has finished being processed so we should advance the
+ current poistion if possible.
+ """
+
+ with self._lock:
+ self._unfinished_ids.discard(next_id)
+
+            # Figure out if it's safe to advance the position by checking there
+ # aren't any lower allocated IDs that are yet to finish.
+ if all(c > next_id for c in self._unfinished_ids):
+ curr = self._current_positions.get(self._instance_name, 0)
+ self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: Optional[str] = None) -> int:
+ """Gets the current position of a named writer (defaults to current
+ instance).
+
+ Returns 0 if we don't have a position for the named writer (likely due
+ to it being a new writer).
+ """
+
+ if instance_name is None:
+ instance_name = self._instance_name
+
+ with self._lock:
+ return self._current_positions.get(instance_name, 0)
+
+ def get_positions(self) -> Dict[str, int]:
+ """Get a copy of the current positon map.
+ """
+
+ with self._lock:
+ return dict(self._current_positions)
+
+ def advance(self, instance_name: str, new_id: int):
+ """Advance the postion of the named writer to the given ID, if greater
+ than existing entry.
+ """
+
+ with self._lock:
+ self._current_positions[instance_name] = max(
+ new_id, self._current_positions.get(instance_name, 0)
+ )
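To make the intended calling pattern of `MultiWriterIdGenerator` concrete, here is a rough usage sketch assembled from the docstrings above. The surrounding objects (`db_conn`, `database`, `persist`) and the table/sequence names are placeholders, not code from this change:

    # Sketch only: placeholder names throughout.
    id_gen = MultiWriterIdGenerator(
        db_conn,
        database,
        instance_name="worker1",
        table="foobar",
        instance_column="instance_name",
        id_column="stream_id",
        sequence_name="foobar_seq",
    )

    async def write_row():
        # Allocates an ID from the Postgres sequence; the local position only
        # advances once the block exits and no lower IDs are still outstanding.
        with await id_gen.get_next() as stream_id:
            await persist(stream_id)  # placeholder for the actual write

    # Inside an existing transaction, `get_next_txn` does the same job and uses
    # the txn's after/exception callbacks to mark the ID as finished:
    #     stream_id = id_gen.get_next_txn(txn)
    #
    # Other writers' positions only move when replication calls
    # `id_gen.advance(<instance_name>, <new_id>)` on the local generator.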
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 38dc3f501e..e54f80d76e 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -14,12 +14,13 @@
# limitations under the License.
import logging
-from typing import Dict, Iterable, List, Mapping, Optional, Set
+from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
from six import integer_types
from sortedcontainers import SortedDict
+from synapse.types import Collection
from synapse.util import caches
logger = logging.getLogger(__name__)
@@ -85,8 +86,8 @@ class StreamChangeCache:
return False
def get_entities_changed(
- self, entities: Iterable[EntityType], stream_pos: int
- ) -> Set[EntityType]:
+ self, entities: Collection[EntityType], stream_pos: int
+ ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
"""
Returns subset of entities that have had new things since the given
position. Entities unknown to the cache will be returned. If the
@@ -94,7 +95,17 @@ class StreamChangeCache:
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
- result = set(changed_entities).intersection(entities)
+ # We now do an intersection, trying to do so in the most efficient
+            # way possible (some of these sets are *large*). First check if the
+            # given iterable is already a set that we can reuse; otherwise we
+            # create a set of the *smaller* of the two iterables and call
+            # `intersection(..)` on it (this can be twice as fast as the reverse).
+ if isinstance(entities, (set, frozenset)):
+ result = entities.intersection(changed_entities)
+ elif len(changed_entities) < len(entities):
+ result = set(changed_entities).intersection(entities)
+ else:
+ result = set(entities).intersection(changed_entities)
self.metrics.inc_hits()
else:
result = set(entities)
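The branching above is purely about intersection cost: reuse `entities` when it is already a set, and otherwise build the temporary set from whichever iterable is smaller rather than materialising the large one. A self-contained snippet to check the effect locally, illustrative only and not part of the change (timings will vary by machine):

    import timeit

    changed = list(range(1_000_000))          # stand-in for a large changed_entities result
    entities = list(range(0, 1_000_000, 50))  # much smaller set of queried entities

    # Build the temporary set from the smaller iterable, then intersect with the larger...
    small_first = timeit.timeit(lambda: set(entities).intersection(changed), number=20)
    # ...versus materialising the million-entry iterable as a set first.
    large_first = timeit.timeit(lambda: set(changed).intersection(entities), number=20)

    print(f"smaller-first: {small_first:.3f}s  larger-first: {large_first:.3f}s")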
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index cc0b10e7f6..0bfb86bf1f 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs)
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.auth._auth_blocking
+
self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
@@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_blocking_mau(self):
- self.hs.config.limit_usage_by_mau = False
- self.hs.config.max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = False
+ self.auth_blocking._max_mau_value = 50
lots_of_users = 100
small_number_of_users = 1
# Ensure no error thrown
yield defer.ensureDeferred(self.auth.check_auth_blocking())
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
@@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
- self.hs.config.max_mau_value = 50
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 50
+ self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
@@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_reserved_threepid(self):
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 1
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
- self.hs.config.mau_limits_reserved_threepids = [threepid]
+ self.auth_blocking._mau_limits_reserved_threepids = [threepid]
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(self.auth.check_auth_blocking())
@@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_hs_disabled(self):
- self.hs.config.hs_disabled = True
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._hs_disabled = True
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
@@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase):
"""
# this should be the default, but we had a bug where the test was doing the wrong
# thing, so let's make it explicit
- self.hs.config.server_notices_mxid = None
+ self.auth_blocking._server_notices_mxid = None
- self.hs.config.hs_disabled = True
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._hs_disabled = True
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
@@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
- self.hs.config.hs_disabled = True
+ self.auth_blocking._hs_disabled = True
user = "@user:server"
- self.hs.config.server_notices_mxid = user
- self.hs.config.hs_disabled_message = "Reason for being disabled"
+ self.auth_blocking._server_notices_mxid = user
+ self.auth_blocking._hs_disabled_message = "Reason for being disabled"
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
new file mode 100644
index 0000000000..640f5f3bce
--- /dev/null
+++ b/tests/events/test_snapshot.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+from tests.test_utils.event_injection import create_event
+
+
+class TestEventContext(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+
+ self.user_id = self.register_user("u1", "pass")
+ self.user_tok = self.login("u1", "pass")
+ self.room_id = self.helper.create_room_as(tok=self.user_tok)
+
+ def test_serialize_deserialize_msg(self):
+ """Test that an EventContext for a message event is the same after
+ serialize/deserialize.
+ """
+
+ event, context = create_event(
+ self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def test_serialize_deserialize_state_no_prev(self):
+ """Test that an EventContext for a state event (with not previous entry)
+ is the same after serialize/deserialize.
+ """
+ event, context = create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ state_key="",
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def test_serialize_deserialize_state_prev(self):
+ """Test that an EventContext for a state event (which replaces a
+ previous entry) is the same after serialize/deserialize.
+ """
+ event, context = create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.room.member",
+ sender=self.user_id,
+ state_key=self.user_id,
+ content={"membership": "leave"},
+ )
+
+ self._check_serialize_deserialize(event, context)
+
+ def _check_serialize_deserialize(self, event, context):
+ serialized = self.get_success(context.serialize(event, self.store))
+
+ d_context = EventContext.deserialize(self.storage, serialized)
+
+ self.assertEqual(context.state_group, d_context.state_group)
+ self.assertEqual(context.rejected, d_context.rejected)
+ self.assertEqual(
+ context.state_group_before_event, d_context.state_group_before_event
+ )
+ self.assertEqual(context.prev_group, d_context.prev_group)
+ self.assertEqual(context.delta_ids, d_context.delta_ids)
+ self.assertEqual(context.app_service, d_context.app_service)
+
+ self.assertEqual(
+ self.get_success(context.get_current_state_ids()),
+ self.get_success(d_context.get_current_state_ids()),
+ )
+ self.assertEqual(
+ self.get_success(context.get_prev_state_ids()),
+ self.get_success(d_context.get_prev_state_ids()),
+ )
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 52c4ac8b11..c01b04e1dc 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
+
# MAU tests
- self.hs.config.max_mau_value = 50
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth()._auth_blocking
+ self.auth_blocking._max_mau_value = 50
+
self.small_number_of_users = 1
self.large_number_of_users = 100
@@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
- self.hs.config.limit_usage_by_mau = False
+ self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
@@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_parity(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
- return_value=defer.succeed(self.hs.config.max_mau_value)
+ return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
- self.hs.config.limit_usage_by_mau = True
+ self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 4cbe9784ed..e178d7765b 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
- def test_wait_for_sync_for_user_auth_blocking(self):
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth()._auth_blocking
+ def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time
- self.hs.config.limit_usage_by_mau = True
- self.hs.config.max_mau_value = 1
+ self.auth_blocking._limit_usage_by_mau = True
+ self.auth_blocking._max_mau_value = 1
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
- self.hs.config.hs_disabled = True
+ self.auth_blocking._hs_disabled = True
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- self.hs.config.hs_disabled = False
+ self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 7b56d2028d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ servlets = [
+ streams.register_servlets,
+ ]
+
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
- config = self.default_config()
- config["worker_app"] = "synapse.app.generic_worker"
- config["worker_replication_host"] = "testserv"
- config["worker_replication_http_port"] = "8765"
-
self.reactor.lookups["testserv"] = "1.2.3.4"
-
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
- config=config,
+ config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs)
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..eea4565da3
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication.tcp.streams._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+ def _get_worker_hs_config(self) -> dict:
+ # enable federation sending on the worker
+ config = super()._get_worker_hs_config()
+ # TODO: make it so we don't need both of these
+ config["send_federation"] = True
+ config["worker_app"] = "synapse.app.federation_sender"
+ return config
+
+ def test_catchup(self):
+ """Basic test of catchup on reconnect
+
+ Makes sure that updates sent while we are offline are received later.
+ """
+ fed_sender = self.hs.get_federation_sender()
+ received_rows = self.test_handler.received_rdata_rows
+
+ fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+ self.reconnect()
+ self.reactor.advance(0)
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual(received_rows, [])
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+ # we should have received an update row
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"a": "b"})
+
+ self.assertEqual(received_rows, [])
+
+ # additional updates should be transferred without an HTTP hit
+ fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+ self.reactor.advance(0)
+ # there should be no http hit
+ self.assertEqual(len(self.reactor.tcpClients), 0)
+ # ... but we should have a row
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows.pop()
+ self.assertEqual(stream_name, "federation")
+ self.assertIsInstance(row, FederationStream.FederationStreamRow)
+ self.assertEqual(row.type, EduRow.TypeId)
+ edurow = EduRow.from_data(row.data)
+ self.assertEqual(edurow.edu.edu_type, "m.test1")
+ self.assertEqual(edurow.edu.origin, self.hs.hostname)
+ self.assertEqual(edurow.edu.destination, "testdest")
+ self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index d25a7b194e..125c63dab5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,7 +15,6 @@
from mock import Mock
from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
- servlets = [
- streams.register_servlets,
- ]
-
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db = self.store.db # type: Database
+
+ self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ )
+
+ return self.get_success(self.db.runWithConnection(_create))
+
+ def _insert_rows(self, instance_name: str, number: int):
+ def _insert(txn):
+ for _ in range(number):
+ txn.execute(
+ "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+ (instance_name,),
+ )
+
+ self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+ def test_empty(self):
+ """Test an ID generator against an empty database gives sensible
+ current positions.
+ """
+
+ id_gen = self._create_id_generator()
+
+ # The table is empty so we expect an empty map for positions
+ self.assertEqual(id_gen.get_positions(), {})
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Allocate a new stream ID and check that we only see the position
+        # advance after we leave the context manager.
+
+ async def _get_next_async():
+ with await id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
+
+ def test_multi_instance(self):
+ """Test that reads and writes from multiple processes are handled
+ correctly.
+ """
+ self._insert_rows("first", 3)
+ self._insert_rows("second", 4)
+
+ first_id_gen = self._create_id_generator("first")
+ second_id_gen = self._create_id_generator("second")
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+ self.assertEqual(first_id_gen.get_current_token("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+        # Allocate a new stream ID and check that we only see the position
+        # advance after we leave the context manager.
+
+ async def _get_next_async():
+ with await first_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(
+ first_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+ # However the ID gen on the second instance won't have seen the update
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+ # ... but calling `get_next` on the second instance should give a unique
+ # stream ID
+
+ async def _get_next_async():
+ with await second_id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 9)
+
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
+
+ self.get_success(_get_next_async())
+
+ self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+ # If the second ID gen gets told about the first, it correctly updates
+ second_id_gen.advance("first", 8)
+ self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+ def test_get_next_txn(self):
+ """Test that the `get_next_txn` function works correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Allocate a new stream ID and check that we only see the position
+        # advance after we leave the context manager.
+
+ def _get_next_txn(txn):
+ stream_id = id_gen.get_next_txn(txn)
+ self.assertEqual(stream_id, 8)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token("master"), 7)
+
+ self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 8})
+ self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 1fbe0d51ff..eb159e3ba5 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
from mock import Mock
+from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.hs.config.limit_usage_by_mau = True
self.hs.config.hs_disabled = False
self.hs.config.max_mau_value = 2
- self.hs.config.mau_trial_days = 0
self.hs.config.server_notices_mxid = "@server:red"
self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room"
+ self.hs.config.mau_trial_days = 0
+
+ # AuthBlocking reads config options during hs creation. Recreate the
+ # hs' copy of AuthBlocking after we've updated config values above
+ self.auth_blocking = AuthBlocking(self.hs)
+ self.hs.get_auth()._auth_blocking = self.auth_blocking
+
return self.hs
def test_simple_deny_mau(self):
@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_trial_users_cant_come_back(self):
+ self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_tracked_but_not_limited(self):
- self.hs.config.max_mau_value = 1 # should not matter
- self.hs.config.limit_usage_by_mau = False
+ self.auth_blocking._max_mau_value = 1 # should not matter
+ self.auth_blocking._limit_usage_by_mau = False
self.hs.config.mau_stats_only = True
# Simply being able to create 2 users indicates that the
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8f6872761a..431e9f8e5e 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
@@ -75,6 +76,23 @@ def inject_event(
"""
test_reactor = hs.get_reactor()
+ event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+
+ d = hs.get_storage().persistence.persist_event(event, context)
+ test_reactor.advance(0)
+ get_awaitable_result(d)
+
+ return event
+
+
+def create_event(
+ hs: synapse.server.HomeServer,
+ room_version: Optional[str] = None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ **kwargs
+) -> Tuple[EventBase, EventContext]:
+ test_reactor = hs.get_reactor()
+
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
@@ -89,8 +107,4 @@ def inject_event(
test_reactor.advance(0)
event, context = get_awaitable_result(d)
- d = hs.get_storage().persistence.persist_event(event, context)
- test_reactor.advance(0)
- get_awaitable_result(d)
-
- return event
+ return event, context
diff --git a/tox.ini b/tox.ini
index eccc44e436..ad4ed8299e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -181,11 +181,7 @@ commands = mypy \
synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \
- synapse/federation/federation_base.py \
- synapse/federation/federation_client.py \
- synapse/federation/federation_server.py \
- synapse/federation/sender \
- synapse/federation/transport \
+ synapse/federation \
synapse/handlers/auth.py \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
@@ -203,6 +199,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
+ synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
|