diff options
59 files changed, 1302 insertions, 712 deletions
diff --git a/changelog.d/4942.bugfix b/changelog.d/4942.bugfix new file mode 100644 index 0000000000..590d80d58f --- /dev/null +++ b/changelog.d/4942.bugfix @@ -0,0 +1 @@ +Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server. diff --git a/changelog.d/4947.feature b/changelog.d/4947.feature new file mode 100644 index 0000000000..b9d27b90f1 --- /dev/null +++ b/changelog.d/4947.feature @@ -0,0 +1 @@ +Add ability for password provider modules to bind email addresses to users upon registration. \ No newline at end of file diff --git a/changelog.d/4949.misc b/changelog.d/4949.misc new file mode 100644 index 0000000000..25c4e05a64 --- /dev/null +++ b/changelog.d/4949.misc @@ -0,0 +1 @@ +Fix/improve some docstrings in the replication code. diff --git a/changelog.d/4953.misc b/changelog.d/4953.misc new file mode 100644 index 0000000000..06a084e6ef --- /dev/null +++ b/changelog.d/4953.misc @@ -0,0 +1,2 @@ +Split synapse.replication.tcp.streams into smaller files. + diff --git a/changelog.d/4954.misc b/changelog.d/4954.misc new file mode 100644 index 0000000000..91f145950d --- /dev/null +++ b/changelog.d/4954.misc @@ -0,0 +1 @@ +Refactor replication row generation/parsing. diff --git a/changelog.d/4955.bugfix b/changelog.d/4955.bugfix new file mode 100644 index 0000000000..e50e67383d --- /dev/null +++ b/changelog.d/4955.bugfix @@ -0,0 +1 @@ +Fix sync bug which made accepting invites unreliable in worker-mode synapses. diff --git a/changelog.d/4959.misc b/changelog.d/4959.misc new file mode 100644 index 0000000000..dd4275501f --- /dev/null +++ b/changelog.d/4959.misc @@ -0,0 +1 @@ +Run `black` to clean up formatting on `synapse/storage/roommember.py` and `synapse/storage/events.py`. \ No newline at end of file diff --git a/changelog.d/4965.misc b/changelog.d/4965.misc new file mode 100644 index 0000000000..284c58b75e --- /dev/null +++ b/changelog.d/4965.misc @@ -0,0 +1 @@ +Remove log line for password via the admin API. diff --git a/changelog.d/4968.misc b/changelog.d/4968.misc new file mode 100644 index 0000000000..7a7b69771c --- /dev/null +++ b/changelog.d/4968.misc @@ -0,0 +1 @@ +Fix typo in TLS filenames in docker/README.md. Also add the '-p' commandline option to the 'docker run' example. Contributed by Jurrie Overgoor. diff --git a/changelog.d/4969.misc b/changelog.d/4969.misc new file mode 100644 index 0000000000..e3a3214e6b --- /dev/null +++ b/changelog.d/4969.misc @@ -0,0 +1,2 @@ +Refactor room version definitions. + diff --git a/docker/README.md b/docker/README.md index 4b98b7fd75..44ade63f27 100644 --- a/docker/README.md +++ b/docker/README.md @@ -31,6 +31,7 @@ docker run \ --mount type=volume,src=synapse-data,dst=/data \ -e SYNAPSE_SERVER_NAME=my.matrix.host \ -e SYNAPSE_REPORT_STATS=yes \ + -p 8448:8448 \ matrixdotorg/synapse:latest ``` @@ -57,8 +58,8 @@ configuration file there. Multiple application services are supported. Synapse requires a valid TLS certificate. You can do one of the following: * Provide your own certificate and key (as - `${DATA_PATH}/${SYNAPSE_SERVER_NAME}.crt` and - `${DATA_PATH}/${SYNAPSE_SERVER_NAME}.key`, or elsewhere by providing an + `${DATA_PATH}/${SYNAPSE_SERVER_NAME}.tls.crt` and + `${DATA_PATH}/${SYNAPSE_SERVER_NAME}.tls.key`, or elsewhere by providing an entire config as `${SYNAPSE_CONFIG_PATH}`). * Use a reverse proxy to terminate incoming TLS, and forward the plain http diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f47c33a074..dc913feeee 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -102,46 +102,6 @@ class ThirdPartyEntityKind(object): LOCATION = "location" -class RoomVersions(object): - V1 = "1" - V2 = "2" - V3 = "3" - STATE_V2_TEST = "state-v2-test" - - -class RoomDisposition(object): - STABLE = "stable" - UNSTABLE = "unstable" - - -# the version we will give rooms which are created on this server -DEFAULT_ROOM_VERSION = RoomVersions.V1 - -# vdh-test-version is a placeholder to get room versioning support working and tested -# until we have a working v2. -KNOWN_ROOM_VERSIONS = { - RoomVersions.V1, - RoomVersions.V2, - RoomVersions.V3, - RoomVersions.STATE_V2_TEST, - RoomVersions.V3, -} - - -class EventFormatVersions(object): - """This is an internal enum for tracking the version of the event format, - independently from the room version. - """ - V1 = 1 - V2 = 2 - - -KNOWN_EVENT_FORMAT_VERSIONS = { - EventFormatVersions.V1, - EventFormatVersions.V2, -} - - ServerNoticeMsgType = "m.server_notice" ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py new file mode 100644 index 0000000000..e77abe1040 --- /dev/null +++ b/synapse/api/room_versions.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import attr + + +class EventFormatVersions(object): + """This is an internal enum for tracking the version of the event format, + independently from the room version. + """ + V1 = 1 # $id:server format + V2 = 2 # MSC1659-style $hash format: introduced for room v3 + + +KNOWN_EVENT_FORMAT_VERSIONS = { + EventFormatVersions.V1, + EventFormatVersions.V2, +} + + +class StateResolutionVersions(object): + """Enum to identify the state resolution algorithms""" + V1 = 1 # room v1 state res + V2 = 2 # MSC1442 state res: room v2 and later + + +class RoomDisposition(object): + STABLE = "stable" + UNSTABLE = "unstable" + + +@attr.s(slots=True, frozen=True) +class RoomVersion(object): + """An object which describes the unique attributes of a room version.""" + + identifier = attr.ib() # str; the identifier for this version + disposition = attr.ib() # str; one of the RoomDispositions + event_format = attr.ib() # int; one of the EventFormatVersions + state_res = attr.ib() # int; one of the StateResolutionVersions + + +class RoomVersions(object): + V1 = RoomVersion( + "1", + RoomDisposition.STABLE, + EventFormatVersions.V1, + StateResolutionVersions.V1, + ) + STATE_V2_TEST = RoomVersion( + "state-v2-test", + RoomDisposition.UNSTABLE, + EventFormatVersions.V1, + StateResolutionVersions.V2, + ) + V2 = RoomVersion( + "2", + RoomDisposition.STABLE, + EventFormatVersions.V1, + StateResolutionVersions.V2, + ) + V3 = RoomVersion( + "3", + RoomDisposition.STABLE, + EventFormatVersions.V2, + StateResolutionVersions.V2, + ) + + +# the version we will give rooms which are created on this server +DEFAULT_ROOM_VERSION = RoomVersions.V1 + + +KNOWN_ROOM_VERSIONS = { + v.identifier: v for v in ( + RoomVersions.V1, + RoomVersions.V2, + RoomVersions.V3, + RoomVersions.STATE_V2_TEST, + ) +} # type: dict[str, RoomVersion] diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 9711a7147c..1d43f2b075 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -38,7 +38,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams import ReceiptsStream +from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.engines import create_engine from synapse.types import ReadReceipt diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 9163b56d86..5388def28a 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -48,6 +48,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.streams.events import EventsStreamEventRow from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet from synapse.rest.client.v1.room import RoomInitialSyncRestServlet @@ -369,7 +370,9 @@ class SyncReplicationHandler(ReplicationClientHandler): # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. for row in rows: - event = yield self.store.get_event(row.event_id) + if row.type != EventsStreamEventRow.TypeId: + continue + event = yield self.store.get_event(row.data.event_id) extra_users = () if event.type == EventTypes.Member: extra_users = (event.state_key,) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index d1ab9512cd..355f5aa71d 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -36,6 +36,10 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamCurrentStateRow, +) from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.engines import create_engine @@ -73,19 +77,18 @@ class UserDirectorySlaveStore( prefilled_cache=curr_state_delta_prefill, ) - self._current_state_delta_pos = events_max - def stream_positions(self): result = super(UserDirectorySlaveStore, self).stream_positions() - result["current_state_deltas"] = self._current_state_delta_pos return result def process_replication_rows(self, stream_name, token, rows): - if stream_name == "current_state_deltas": - self._current_state_delta_pos = token + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(token) for row in rows: + if row.type != EventsStreamCurrentStateRow.TypeId: + continue self._curr_state_delta_stream_cache.entity_has_changed( - row.room_id, token + row.data.room_id, token ) return super(UserDirectorySlaveStore, self).process_replication_rows( stream_name, token, rows @@ -170,7 +173,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler): yield super(UserDirectoryReplicationHandler, self).on_rdata( stream_name, token, rows ) - if stream_name == "current_state_deltas": + if stream_name == EventsStream.NAME: run_in_background(self._notify_directory) @defer.inlineCallbacks diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 8f9e330da5..203490fc36 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -20,15 +20,9 @@ from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 -from synapse.api.constants import ( - KNOWN_ROOM_VERSIONS, - EventFormatVersions, - EventTypes, - JoinRules, - Membership, - RoomVersions, -) +from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, EventSizeError, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.types import UserID, get_domain_from_id logger = logging.getLogger(__name__) @@ -452,16 +446,18 @@ def check_redaction(room_version, event, auth_events): if user_level >= redact_level: return False - if room_version in (RoomVersions.V1, RoomVersions.V2,): + v = KNOWN_ROOM_VERSIONS.get(room_version) + if not v: + raise RuntimeError("Unrecognized room version %r" % (room_version,)) + + if v.event_format == EventFormatVersions.V1: redacter_domain = get_domain_from_id(event.event_id) redactee_domain = get_domain_from_id(event.redacts) if redacter_domain == redactee_domain: return True - elif room_version == RoomVersions.V3: + else: event.internal_metadata.recheck_redaction = True return True - else: - raise RuntimeError("Unrecognized room version %r" % (room_version,)) raise AuthError( 403, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index fafa135182..12056d5be2 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -21,7 +21,7 @@ import six from unpaddedbase64 import encode_base64 -from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -351,18 +351,13 @@ def room_version_to_event_format(room_version): Returns: int """ - if room_version not in KNOWN_ROOM_VERSIONS: + v = KNOWN_ROOM_VERSIONS.get(room_version) + + if not v: # We should have already checked version, so this should not happen raise RuntimeError("Unrecognized room version %s" % (room_version,)) - if room_version in ( - RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST, - ): - return EventFormatVersions.V1 - elif room_version in (RoomVersions.V3,): - return EventFormatVersions.V2 - else: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) + return v.event_format def event_type_from_format_version(format_version): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 06e01be918..fba27177c7 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,21 +17,17 @@ import attr from twisted.internet import defer -from synapse.api.constants import ( +from synapse.api.constants import MAX_DEPTH +from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, KNOWN_ROOM_VERSIONS, - MAX_DEPTH, EventFormatVersions, ) from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.types import EventID from synapse.util.stringutils import random_string -from . import ( - _EventInternalMetadata, - event_type_from_format_version, - room_version_to_event_format, -) +from . import _EventInternalMetadata, event_type_from_format_version @attr.s(slots=True, cmp=False, frozen=True) @@ -170,21 +166,34 @@ class EventBuilderFactory(object): def new(self, room_version, key_values): """Generate an event builder appropriate for the given room version + Deprecated: use for_room_version with a RoomVersion object instead + Args: - room_version (str): Version of the room that we're creating an - event builder for + room_version (str): Version of the room that we're creating an event builder + for key_values (dict): Fields used as the basis of the new event Returns: EventBuilder """ - - # There's currently only the one event version defined - if room_version not in KNOWN_ROOM_VERSIONS: + v = KNOWN_ROOM_VERSIONS.get(room_version) + if not v: raise Exception( "No event format defined for version %r" % (room_version,) ) + return self.for_room_version(v, key_values) + def for_room_version(self, room_version, key_values): + """Generate an event builder appropriate for the given room version + + Args: + room_version (synapse.api.room_versions.RoomVersion): + Version of the room that we're creating an event builder for + key_values (dict): Fields used as the basis of the new event + + Returns: + EventBuilder + """ return EventBuilder( store=self.store, state=self.state, @@ -192,7 +201,7 @@ class EventBuilderFactory(object): clock=self.clock, hostname=self.hostname, signing_key=self.signing_key, - format_version=room_version_to_event_format(room_version), + format_version=room_version.event_format, type=key_values["type"], state_key=key_values.get("state_key"), room_id=key_values["room_id"], @@ -222,7 +231,6 @@ def create_local_event_from_event_dict(clock, hostname, signing_key, FrozenEvent """ - # There's currently only the one event version defined if format_version not in KNOWN_EVENT_FORMAT_VERSIONS: raise Exception( "No event format defined for version %r" % (format_version,) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index a072674b02..514273c792 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -15,8 +15,9 @@ from six import string_types -from synapse.api.constants import EventFormatVersions, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.room_versions import EventFormatVersions from synapse.types import EventID, RoomID, UserID diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a7a2ec4523..dfe6b4aa5c 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -20,8 +20,9 @@ import six from twisted.internet import defer from twisted.internet.defer import DeferredList -from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.crypto.event_signing import check_event_content_hash from synapse.events import event_type_from_format_version from synapse.events.utils import prune_event @@ -274,9 +275,12 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # now let's look for events where the sender's domain is different to the # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain - if room_version in ( - RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST, - ): + # (ie, the room version uses old-style non-hash event IDs). + v = KNOWN_ROOM_VERSIONS.get(room_version) + if not v: + raise RuntimeError("Unrecognized room version %s" % (room_version,)) + + if v.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ p for p in pdus_to_check if p.sender_domain != get_domain_from_id(p.pdu.event_id) @@ -289,10 +293,6 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): for p, d in zip(pdus_to_check_event_id, more_deferreds): p.deferreds.append(d) - elif room_version in (RoomVersions.V3,): - pass # No further checks needed, as event IDs are hashes here - else: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) # replace lists of deferreds with single Deferreds return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 58e04d81ab..f3fc897a0a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -25,12 +25,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import ( - KNOWN_ROOM_VERSIONS, - EventTypes, - Membership, - RoomVersions, -) +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( CodeMessageException, Codes, @@ -38,6 +33,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.room_versions import ( + KNOWN_ROOM_VERSIONS, + EventFormatVersions, + RoomVersions, +) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.util import logcontext, unwrapFirstError @@ -570,7 +570,7 @@ class FederationClient(FederationBase): Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of `(origin, event, event_format)` where origin is the remote homeserver which generated the event, and event_format is one of - `synapse.api.constants.EventFormatVersions`. + `synapse.api.room_versions.EventFormatVersions`. Fails with a ``SynapseError`` if the chosen remote server returns a 300/400 code. @@ -592,7 +592,7 @@ class FederationClient(FederationBase): # Note: If not supplied, the room version may be either v1 or v2, # however either way the event format version will be v1. - room_version = ret.get("room_version", RoomVersions.V1) + room_version = ret.get("room_version", RoomVersions.V1.identifier) event_format = room_version_to_event_format(room_version) pdu_dict = ret.get("event", None) @@ -695,7 +695,9 @@ class FederationClient(FederationBase): room_version = None for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): - room_version = e.content.get("room_version", RoomVersions.V1) + room_version = e.content.get( + "room_version", RoomVersions.V1.identifier + ) break if room_version is None: @@ -802,11 +804,10 @@ class FederationClient(FederationBase): raise err # Otherwise, we assume that the remote server doesn't understand - # the v2 invite API. - - if room_version in (RoomVersions.V1, RoomVersions.V2): - pass # We'll fall through - else: + # the v2 invite API. That's ok provided the room uses old-style event + # IDs. + v = KNOWN_ROOM_VERSIONS.get(room_version) + if v.event_format != EventFormatVersions.V1: raise SynapseError( 400, "User's homeserver does not support this room version", diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 81f3b4b1ff..df60828dba 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -25,7 +25,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -34,6 +34,7 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.crypto.event_signing import compute_event_signature from synapse.events import room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 04d04a4457..0240b339b0 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object): self.is_mine_id = hs.is_mine_id self.presence_map = {} # Pending presence map user_id -> UserPresenceState - self.presence_changed = SortedDict() # Stream position -> user_id + self.presence_changed = SortedDict() # Stream position -> list[user_id] + + # Stores the destinations we need to explicitly send presence to about a + # given user. + # Stream position -> (user_id, destinations) + self.presence_destinations = SortedDict() self.keyed_edu = {} # (destination, key) -> EDU self.keyed_edu_changed = SortedDict() # stream position -> (destination, key) @@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object): for queue_name in [ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "device_messages", "pos_time", + "edus", "device_messages", "pos_time", "presence_destinations", ]: register(queue_name, getattr(self, queue_name)) @@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object): for user_id in uids ) + keys = self.presence_destinations.keys() + i = self.presence_destinations.bisect_left(position_to_delete) + for key in keys[:i]: + del self.presence_destinations[key] + + user_ids.update( + user_id for user_id, _ in self.presence_destinations.values() + ) + to_del = [ user_id for user_id in self.presence_map if user_id not in user_ids ] @@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object): self.notifier.on_new_replication_data() + def send_presence_to_destinations(self, states, destinations): + """As per FederationSender + + Args: + states (list[UserPresenceState]) + destinations (list[str]) + """ + for state in states: + pos = self._next_pos() + self.presence_map.update({state.user_id: state for state in states}) + self.presence_destinations[pos] = (state.user_id, destinations) + + self.notifier.on_new_replication_data() + def send_device_messages(self, destination): """As per FederationSender""" pos = self._next_pos() @@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object): state=self.presence_map[user_id], ))) + # Fetch presence to send to destinations + i = self.presence_destinations.bisect_right(from_token) + j = self.presence_destinations.bisect_right(to_token) + 1 + + for pos, (user_id, dests) in self.presence_destinations.items()[i:j]: + rows.append((pos, PresenceDestinationsRow( + state=self.presence_map[user_id], + destinations=list(dests), + ))) + # Fetch changes keyed edus i = self.keyed_edu_changed.bisect_right(from_token) j = self.keyed_edu_changed.bisect_right(to_token) + 1 @@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( buff.presence.append(self.state) +class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", ( + "state", # UserPresenceState + "destinations", # list[str] +))): + TypeId = "pd" + + @staticmethod + def from_data(data): + return PresenceDestinationsRow( + state=UserPresenceState.from_dict(data["state"]), + destinations=data["dests"], + ) + + def to_data(self): + return { + "state": self.state.as_dict(), + "dests": self.destinations, + } + + def add_to_buffer(self, buff): + buff.presence_destinations.append((self.state, self.destinations)) + + class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( "key", # tuple(str) - the edu key passed to send_edu "edu", # Edu @@ -428,6 +489,7 @@ TypeToRow = { Row.TypeId: Row for Row in ( PresenceRow, + PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow, @@ -437,6 +499,7 @@ TypeToRow = { ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( "presence", # list(UserPresenceState) + "presence_destinations", # list of tuples of UserPresenceState and destinations "keyed_edus", # dict of destination -> { key -> Edu } "edus", # dict of destination -> [Edu] "device_destinations", # set of destinations @@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows): buff = ParsedFederationStreamData( presence=[], + presence_destinations=[], keyed_edus={}, edus={}, device_destinations=set(), @@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows): if buff.presence: transaction_queue.send_presence(buff.presence) + for state, destinations in buff.presence_destinations: + transaction_queue.send_presence_to_destinations( + states=[state], destinations=destinations, + ) + for destination, edu_map in iteritems(buff.keyed_edus): for key, edu in edu_map.items(): transaction_queue.send_edu(edu, key) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 1dc041752b..4f0f939102 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -371,7 +371,7 @@ class FederationSender(object): return # First we queue up the new presence by user ID, so multiple presence - # updates in quick successtion are correctly handled + # updates in quick succession are correctly handled. # We only want to send presence for our own users, so lets always just # filter here just in case. self.pending_presence.update({ @@ -402,6 +402,23 @@ class FederationSender(object): finally: self._processing_pending_presence = False + def send_presence_to_destinations(self, states, destinations): + """Send the given presence states to the given destinations. + + Args: + states (list[UserPresenceState]) + destinations (list[str]) + """ + + if not states or not self.hs.config.use_presence: + # No-op if presence is disabled. + return + + for destination in destinations: + if destination == self.server_name: + continue + self._get_per_destination_queue(destination).send_presence(states) + @measure_func("txnqueue._process_presence") @defer.inlineCallbacks def _process_presence_inner(self, states): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index efb6bdca48..452599e1a1 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -21,8 +21,8 @@ import re from twisted.internet import defer import synapse -from synapse.api.constants import RoomVersions from synapse.api.errors import Codes, FederationDeniedError, SynapseError +from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource @@ -513,7 +513,7 @@ class FederationV1InviteServlet(BaseFederationServlet): # state resolution algorithm, and we don't use that for processing # invites content = yield self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1, + origin, content, room_version=RoomVersions.V1.identifier, ) # V1 federation API is defined to return a content of `[200, {...}]` diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9eaf2d3e18..0684778882 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -29,13 +29,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer -from synapse.api.constants import ( - KNOWN_ROOM_VERSIONS, - EventTypes, - Membership, - RejectedReason, - RoomVersions, -) +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, CodeMessageException, @@ -44,6 +38,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event from synapse.events.validator import EventValidator @@ -1733,7 +1728,9 @@ class FederationHandler(BaseHandler): # invalid, and it would fail auth checks anyway. raise SynapseError(400, "No create event in state") - room_version = create_event.content.get("room_version", RoomVersions.V1) + room_version = create_event.content.get( + "room_version", RoomVersions.V1.identifier, + ) missing_auth_events = set() for e in itertools.chain(auth_events, state, [event]): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9b41c7b205..8bc7a7678a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -30,6 +30,7 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) +from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -603,7 +604,9 @@ class EventCreationHandler(object): """ if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""): - room_version = event.content.get("room_version", RoomVersions.V1) + room_version = event.content.get( + "room_version", RoomVersions.V1.identifier + ) else: room_version = yield self.store.get_room_version(event.room_id) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 37e87fc054..e85c49742d 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -31,9 +31,11 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import PresenceState +import synapse.metrics +from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer @@ -98,6 +100,7 @@ class PresenceHandler(object): self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id + self.server_name = hs.hostname self.clock = hs.get_clock() self.store = hs.get_datastore() self.wheel_timer = WheelTimer() @@ -132,9 +135,6 @@ class PresenceHandler(object): ) ) - distributor = hs.get_distributor() - distributor.observe("user_joined_room", self.user_joined_room) - active_presence = self.store.take_presence_startup_info() # A dictionary of the current state of users. This is prefilled with @@ -220,6 +220,15 @@ class PresenceHandler(object): LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [], lambda: len(self.wheel_timer)) + # Used to handle sending of presence to newly joined users/servers + if hs.config.use_presence: + self.notifier.add_replication_callback(self.notify_new_event) + + # Presence is best effort and quickly heals itself, so lets just always + # stream from the current state when we restart. + self._event_pos = self.store.get_current_events_token() + self._event_processing = False + @defer.inlineCallbacks def _on_shutdown(self): """Gets called when shutting down. This lets us persist any updates that @@ -751,31 +760,6 @@ class PresenceHandler(object): yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks - def user_joined_room(self, user, room_id): - """Called (via the distributor) when a user joins a room. This funciton - sends presence updates to servers, either: - 1. the joining user is a local user and we send their presence to - all servers in the room. - 2. the joining user is a remote user and so we send presence for all - local users in the room. - """ - # We only need to send presence to servers that don't have it yet. We - # don't need to send to local clients here, as that is done as part - # of the event stream/sync. - # TODO: Only send to servers not already in the room. - if self.is_mine(user): - state = yield self.current_state_for_user(user.to_string()) - - self._push_to_remotes([state]) - else: - user_ids = yield self.store.get_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) - - states = yield self.current_state_for_users(user_ids) - - self._push_to_remotes(list(states.values())) - - @defer.inlineCallbacks def get_presence_list(self, observer_user, accepted=None): """Returns the presence for all users in their presence list. """ @@ -945,6 +929,140 @@ class PresenceHandler(object): rows = yield self.store.get_all_presence_updates(last_id, current_id) defer.returnValue(rows) + def notify_new_event(self): + """Called when new events have happened. Handles users and servers + joining rooms and require being sent presence. + """ + + if self._event_processing: + return + + @defer.inlineCallbacks + def _process_presence(): + assert not self._event_processing + + self._event_processing = True + try: + yield self._unsafe_process() + finally: + self._event_processing = False + + run_as_background_process("presence.notify_new_event", _process_presence) + + @defer.inlineCallbacks + def _unsafe_process(self): + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "presence_delta"): + deltas = yield self.store.get_current_state_deltas(self._event_pos) + if not deltas: + return + + yield self._handle_state_delta(deltas) + + self._event_pos = deltas[-1]["stream_id"] + + # Expose current event processing position to prometheus + synapse.metrics.event_processing_positions.labels("presence").set( + self._event_pos + ) + + @defer.inlineCallbacks + def _handle_state_delta(self, deltas): + """Process current state deltas to find new joins that need to be + handled. + """ + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + event_id = delta["event_id"] + prev_event_id = delta["prev_event_id"] + + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + + if typ != EventTypes.Member: + continue + + event = yield self.store.get_event(event_id) + if event.content.get("membership") != Membership.JOIN: + # We only care about joins + continue + + if prev_event_id: + prev_event = yield self.store.get_event(prev_event_id) + if prev_event.content.get("membership") == Membership.JOIN: + # Ignore changes to join events. + continue + + yield self._on_user_joined_room(room_id, state_key) + + @defer.inlineCallbacks + def _on_user_joined_room(self, room_id, user_id): + """Called when we detect a user joining the room via the current state + delta stream. + + Args: + room_id (str) + user_id (str) + + Returns: + Deferred + """ + + if self.is_mine_id(user_id): + # If this is a local user then we need to send their presence + # out to hosts in the room (who don't already have it) + + # TODO: We should be able to filter the hosts down to those that + # haven't previously seen the user + + state = yield self.current_state_for_user(user_id) + hosts = yield self.state.get_current_hosts_in_room(room_id) + + # Filter out ourselves. + hosts = set(host for host in hosts if host != self.server_name) + + self.federation.send_presence_to_destinations( + states=[state], + destinations=hosts, + ) + else: + # A remote user has joined the room, so we need to: + # 1. Check if this is a new server in the room + # 2. If so send any presence they don't already have for + # local users in the room. + + # TODO: We should be able to filter the users down to those that + # the server hasn't previously seen + + # TODO: Check that this is actually a new server joining the + # room. + + user_ids = yield self.state.get_current_user_in_room(room_id) + user_ids = list(filter(self.is_mine_id, user_ids)) + + states = yield self.current_state_for_users(user_ids) + + # Filter out old presence, i.e. offline presence states where + # the user hasn't been active for a week. We can change this + # depending on what we want the UX to be, but at the least we + # should filter out offline presence where the state is just the + # default state. + now = self.clock.time_msec() + states = [ + state for state in states.values() + if state.state != PresenceState.OFFLINE + or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 + or state.status_msg is not None + ] + + if states: + self.federation.send_presence_to_destinations( + states=states, + destinations=[get_domain_from_id(user_id)], + ) + def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 58940e0320..a51d11a257 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -153,6 +153,7 @@ class RegistrationHandler(BaseHandler): user_type=None, default_display_name=None, address=None, + bind_emails=[], ): """Registers a new client on the server. @@ -172,6 +173,7 @@ class RegistrationHandler(BaseHandler): default_display_name (unicode|None): if set, the new user's displayname will be set to this. Defaults to 'localpart'. address (str|None): the IP address used to perform the registration. + bind_emails (List[str]): list of emails to bind to this account. Returns: A tuple of (user_id, access_token). Raises: @@ -261,6 +263,21 @@ class RegistrationHandler(BaseHandler): if not self.hs.config.user_consent_at_registration: yield self._auto_join_rooms(user_id) + # Bind any specified emails to this account + current_time = self.hs.get_clock().time_msec() + for email in bind_emails: + # generate threepid dict + threepid_dict = { + "medium": "email", + "address": email, + "validated_at": current_time, + } + + # Bind email to new account + yield self._register_email_threepid( + user_id, threepid_dict, None, False, + ) + defer.returnValue((user_id, token)) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 67b15697fd..c3dcfec247 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,14 +25,9 @@ from six import iteritems, string_types from twisted.internet import defer -from synapse.api.constants import ( - DEFAULT_ROOM_VERSION, - KNOWN_ROOM_VERSIONS, - EventTypes, - JoinRules, - RoomCreationPreset, -) +from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError +from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS from synapse.storage.state import StateFilter from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.util import stringutils @@ -479,7 +474,7 @@ class RoomCreationHandler(BaseHandler): if ratelimit: yield self.ratelimit(requester) - room_version = config.get("room_version", DEFAULT_ROOM_VERSION) + room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier) if not isinstance(room_version, string_types): raise SynapseError( 400, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 235ce8334e..b3abd1b3c6 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -74,14 +74,14 @@ class ModuleApi(object): return self._auth_handler.check_user_exists(user_id) @defer.inlineCallbacks - def register(self, localpart, displayname=None): + def register(self, localpart, displayname=None, emails=[]): """Registers a new user with given localpart and optional - displayname. + displayname, emails. Args: localpart (str): The localpart of the new user. - displayname (str|None): The displayname of the new user. If None, - the user's displayname will default to `localpart`. + displayname (str|None): The displayname of the new user. + emails (List[str]): Emails to bind to the new user. Returns: Deferred: a 2-tuple of (user_id, access_token) @@ -90,6 +90,7 @@ class ModuleApi(object): reg = self.hs.get_registration_handler() user_id, access_token = yield reg.register( localpart=localpart, default_display_name=displayname, + bind_emails=emails, ) defer.returnValue((user_id, access_token)) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 4830c68f35..c57385d92f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -16,6 +16,7 @@ import logging from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams.events import EventsStreamEventRow from synapse.storage.event_federation import EventFederationWorkerStore from synapse.storage.event_push_actions import EventPushActionsWorkerStore from synapse.storage.events_worker import EventsWorkerStore @@ -79,9 +80,12 @@ class SlavedEventStore(EventFederationWorkerStore, if stream_name == "events": self._stream_id_gen.advance(token) for row in rows: + if row.type != EventsStreamEventRow.TypeId: + continue + data = row.data self.invalidate_caches_for_event( - token, row.event_id, row.room_id, row.type, row.state_key, - row.redacts, + token, data.event_id, data.room_id, data.type, data.state_key, + data.redacts, backfilled=False, ) elif stream_name == "backfill": diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e558f90e1a..206dc3b397 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -103,10 +103,19 @@ class ReplicationClientHandler(object): hs.get_reactor().connectTCP(host, port, self.factory) def on_rdata(self, stream_name, token, rows): - """Called when we get new replication data. By default this just pokes - the slave store. + """Called to handle a batch of replication data with a given stream token. - Can be overriden in subclasses to handle more. + By default this just pokes the slave store. Can be overridden in subclasses to + handle more. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + + Returns: + Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) return self.store.process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 02e5bf6cc8..b51590cf8f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -42,8 +42,8 @@ indicate which side is sending, these are *not* included on the wire:: > POSITION backfill 1 > POSITION caches 1 > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", - "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] + > RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823", + "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]] < PING 1490197675618 > ERROR server stopping * connection closed by server * @@ -605,7 +605,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): inbound_rdata_count.labels(stream_name).inc() try: - row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) + row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( "[%s] Failed to parse RDATA: %r %r", diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 7fc346c7b6..f6a38f5140 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -30,7 +30,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.metrics import Measure, measure_func from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP, FederationStream +from .streams import STREAMS_MAP +from .streams.federation import FederationStream stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py new file mode 100644 index 0000000000..634f636dc9 --- /dev/null +++ b/synapse/replication/tcp/streams/__init__.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines all the valid streams that clients can subscribe to, and the format +of the rows returned by each stream. + +Each stream is defined by the following information: + + stream name: The name of the stream + row type: The type that is used to serialise/deserialse the row + current_token: The function that returns the current token for the stream + update_function: The function that returns a list of updates between two tokens +""" + +from . import _base, events, federation + +STREAMS_MAP = { + stream.NAME: stream + for stream in ( + events.EventsStream, + _base.BackfillStream, + _base.PresenceStream, + _base.TypingStream, + _base.ReceiptsStream, + _base.PushRulesStream, + _base.PushersStream, + _base.CachesStream, + _base.PublicRoomsStream, + _base.DeviceListsStream, + _base.ToDeviceStream, + federation.FederationStream, + _base.TagAccountDataStream, + _base.AccountDataStream, + _base.GroupServerStream, + ) +} diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams/_base.py index e23084baae..8971a6a22e 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams/_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,16 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Defines all the valid streams that clients can subscribe to, and the format -of the rows returned by each stream. -Each stream is defined by the following information: - - stream name: The name of the stream - row type: The type that is used to serialise/deserialse the row - current_token: The function that returns the current token for the stream - update_function: The function that returns a list of updates between two tokens -""" import itertools import logging from collections import namedtuple @@ -34,14 +26,6 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 10000 - -EventStreamRow = namedtuple("EventStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional -)) BackfillStreamRow = namedtuple("BackfillStreamRow", ( "event_id", # str "room_id", # str @@ -96,10 +80,6 @@ DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( "entity", # str )) -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( "user_id", # str "room_id", # str @@ -111,12 +91,6 @@ AccountDataStreamRow = namedtuple("AccountDataStream", ( "data_type", # str "data", # dict )) -CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( - "room_id", # str - "type", # str - "state_key", # str - "event_id", # str, optional -)) GroupsStreamRow = namedtuple("GroupsStreamRow", ( "group_id", # str "user_id", # str @@ -132,9 +106,24 @@ class Stream(object): time it was called up until the point `advance_current_token` was called. """ NAME = None # The name of the stream - ROW_TYPE = None # The type of the row + ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. _LIMITED = True # Whether the update function takes a limit + @classmethod + def parse_row(cls, row): + """Parse a row received over replication + + By default, assumes that the row data is an array object and passes its contents + to the constructor of the ROW_TYPE for this stream. + + Args: + row: row data from the incoming RDATA command, after json decoding + + Returns: + ROW_TYPE object for this stream + """ + return cls.ROW_TYPE(*row) + def __init__(self, hs): # The token from which we last asked for updates self.last_token = self.current_token() @@ -162,8 +151,10 @@ class Stream(object): until the `upto_token` Returns: - (list(ROW_TYPE), int): list of updates plus the token used as an - upper bound of the updates (i.e. the "current token") + Deferred[Tuple[List[Tuple[int, Any]], int]: + Resolves to a pair ``(updates, current_token)``, where ``updates`` is a + list of ``(token, row)`` entries. ``row`` will be json-serialised and + sent over the replication steam. """ updates, current_token = yield self.get_updates_since(self.last_token) self.last_token = current_token @@ -176,8 +167,10 @@ class Stream(object): stream updates Returns: - (list(ROW_TYPE), int): list of updates plus the token used as an - upper bound of the updates (i.e. the "current token") + Deferred[Tuple[List[Tuple[int, Any]], int]: + Resolves to a pair ``(updates, current_token)``, where ``updates`` is a + list of ``(token, row)`` entries. ``row`` will be json-serialised and + sent over the replication steam. """ if from_token in ("NOW", "now"): defer.returnValue(([], self.upto_token)) @@ -202,7 +195,7 @@ class Stream(object): from_token, current_token, ) - updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows] + updates = [(row[0], row[1:]) for row in rows] # check we didn't get more rows than the limit. # doing it like this allows the update_function to be a generator. @@ -232,20 +225,6 @@ class Stream(object): raise NotImplementedError() -class EventsStream(Stream): - """We received a new event, or an event went from being an outlier to not - """ - NAME = "events" - ROW_TYPE = EventStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - self.current_token = store.get_current_events_token - self.update_function = store.get_all_new_forward_event_rows - - super(EventsStream, self).__init__(hs) - - class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -400,22 +379,6 @@ class ToDeviceStream(Stream): super(ToDeviceStream, self).__init__(hs) -class FederationStream(Stream): - """Data to be sent over federation. Only available when master has federation - sending disabled. - """ - NAME = "federation" - ROW_TYPE = FederationStreamRow - - def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows - - super(FederationStream, self).__init__(hs) - - class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ @@ -459,21 +422,6 @@ class AccountDataStream(Stream): defer.returnValue(results) -class CurrentStateDeltaStream(Stream): - """Current state for a room was changed - """ - NAME = "current_state_deltas" - ROW_TYPE = CurrentStateDeltaStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - - self.current_token = store.get_max_current_state_delta_stream_id - self.update_function = store.get_all_updated_current_state_deltas - - super(CurrentStateDeltaStream, self).__init__(hs) - - class GroupServerStream(Stream): NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -485,26 +433,3 @@ class GroupServerStream(Stream): self.update_function = store.get_all_groups_changes super(GroupServerStream, self).__init__(hs) - - -STREAMS_MAP = { - stream.NAME: stream - for stream in ( - EventsStream, - BackfillStream, - PresenceStream, - TypingStream, - ReceiptsStream, - PushRulesStream, - PushersStream, - CachesStream, - PublicRoomsStream, - DeviceListsStream, - ToDeviceStream, - FederationStream, - TagAccountDataStream, - AccountDataStream, - CurrentStateDeltaStream, - GroupServerStream, - ) -} diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py new file mode 100644 index 0000000000..e0f6e29248 --- /dev/null +++ b/synapse/replication/tcp/streams/events.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import heapq + +import attr + +from twisted.internet import defer + +from ._base import Stream + + +"""Handling of the 'events' replication stream + +This stream contains rows of various types. Each row therefore contains a 'type' +identifier before the real data. For example:: + + RDATA events batch ["state", ["!room:id", "m.type", "", "$event:id"]] + RDATA events 12345 ["ev", ["$event:id", "!room:id", "m.type", null, null]] + +An "ev" row is sent for each new event. The fields in the data part are: + + * The new event id + * The room id for the event + * The type of the new event + * The state key of the event, for state events + * The event id of an event which is redacted by this event. + +A "state" row is sent whenever the "current state" in a room changes. The fields in the +data part are: + + * The room id for the state change + * The event type of the state which has changed + * The state_key of the state which has changed + * The event id of the new state + +""" + + +@attr.s(slots=True, frozen=True) +class EventsStreamRow(object): + """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows + data = attr.ib() # BaseEventsStreamRow + + +class BaseEventsStreamRow(object): + """Base class for rows to be sent in the events stream. + + Specifies how to identify, serialize and deserialize the different types. + """ + + TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + + @classmethod + def from_data(cls, data): + """Parse the data from the replication stream into a row. + + By default we just call the constructor with the data list as arguments + + Args: + data: The value of the data object from the replication stream + """ + return cls(*data) + + +@attr.s(slots=True, frozen=True) +class EventsStreamEventRow(BaseEventsStreamRow): + TypeId = "ev" + + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional + + +@attr.s(slots=True, frozen=True) +class EventsStreamCurrentStateRow(BaseEventsStreamRow): + TypeId = "state" + + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str + event_id = attr.ib() # str, optional + + +TypeToRow = { + Row.TypeId: Row + for Row in ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, + ) +} + + +class EventsStream(Stream): + """We received a new event, or an event went from being an outlier to not + """ + NAME = "events" + + def __init__(self, hs): + self._store = hs.get_datastore() + self.current_token = self._store.get_current_events_token + + super(EventsStream, self).__init__(hs) + + @defer.inlineCallbacks + def update_function(self, from_token, current_token, limit=None): + event_rows = yield self._store.get_all_new_forward_event_rows( + from_token, current_token, limit, + ) + event_updates = ( + (row[0], EventsStreamEventRow.TypeId, row[1:]) + for row in event_rows + ) + + state_rows = yield self._store.get_all_updated_current_state_deltas( + from_token, current_token, limit + ) + state_updates = ( + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) + for row in state_rows + ) + + all_updates = heapq.merge(event_updates, state_updates) + + defer.returnValue(all_updates) + + @classmethod + def parse_row(cls, row): + (typ, data) = row + data = TypeToRow[typ].from_data(data) + return EventsStreamRow(typ, data) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py new file mode 100644 index 0000000000..9aa43aa8d2 --- /dev/null +++ b/synapse/replication/tcp/streams/federation.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +from ._base import Stream + +FederationStreamRow = namedtuple("FederationStreamRow", ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow +)) + + +class FederationStream(Stream): + """Data to be sent over federation. Only available when master has federation + sending disabled. + """ + NAME = "federation" + ROW_TYPE = FederationStreamRow + + def __init__(self, hs): + federation_sender = hs.get_federation_sender() + + self.current_token = federation_sender.get_current_token + self.update_function = federation_sender.get_replication_rows + + super(FederationStream, self).__init__(hs) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index e788769639..1a26f5a1a6 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -647,8 +647,6 @@ class ResetPasswordRestServlet(ClientV1RestServlet): assert_params_in_dict(params, ["new_password"]) new_password = params['new_password'] - logger.info("new_password: %r", new_password) - yield self._set_password_handler.set_password( target_user_id, new_password, requester ) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index 373f95126e..a868d06098 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -16,7 +16,7 @@ import logging from twisted.internet import defer -from synapse.api.constants import DEFAULT_ROOM_VERSION, RoomDisposition, RoomVersions +from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet from ._base import client_v2_patterns @@ -48,12 +48,10 @@ class CapabilitiesRestServlet(RestServlet): response = { "capabilities": { "m.room_versions": { - "default": DEFAULT_ROOM_VERSION, + "default": DEFAULT_ROOM_VERSION.identifier, "available": { - RoomVersions.V1: RoomDisposition.STABLE, - RoomVersions.V2: RoomDisposition.STABLE, - RoomVersions.STATE_V2_TEST: RoomDisposition.UNSTABLE, - RoomVersions.V3: RoomDisposition.STABLE, + v.identifier: v.disposition + for v in KNOWN_ROOM_VERSIONS.values() }, }, "m.change_password": {"enabled": change_password}, diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index e6356101fd..3db7ff8d1b 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -17,8 +17,8 @@ import logging from twisted.internet import defer -from synapse.api.constants import KNOWN_ROOM_VERSIONS from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( RestServlet, assert_params_in_dict, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 68058f613c..52347fee34 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -24,7 +24,8 @@ from frozendict import frozendict from twisted.internet import defer -from synapse.api.constants import EventTypes, RoomVersions +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events.snapshot import EventContext from synapse.state import v1, v2 from synapse.util.async_helpers import Linearizer @@ -603,22 +604,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if room_version == RoomVersions.V1: + v = KNOWN_ROOM_VERSIONS[room_version] + if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( state_sets, event_map, state_res_store.get_events, ) - elif room_version in ( - RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3, - ): + else: return v2.resolve_events_with_store( room_version, state_sets, event_map, state_res_store, ) - else: - # This should only happen if we added a version but forgot to add it to - # the list above. - raise Exception( - "No state resolution algorithm defined for version %r" % (room_version,) - ) @attr.s diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 6d3afcae7c..29b4e86cfd 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -21,8 +21,9 @@ from six import iteritems, iterkeys, itervalues from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import EventTypes, RoomVersions +from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersions logger = logging.getLogger(__name__) @@ -275,7 +276,9 @@ def _resolve_auth_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1, event, auth_events, + RoomVersions.V1.identifier, + event, + auth_events, do_sig_check=False, do_size_check=False, ) @@ -291,7 +294,9 @@ def _resolve_normal_events(events, auth_events): try: # The signatures have already been checked at this point event_auth.check( - RoomVersions.V1, event, auth_events, + RoomVersions.V1.identifier, + event, + auth_events, do_sig_check=False, do_size_check=False, ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 428300ea0a..d0668e39c4 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -30,7 +30,6 @@ from twisted.internet import defer import synapse.metrics from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -# these are only included to make the type annotations work from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process @@ -51,8 +50,11 @@ from synapse.util.metrics import Measure logger = logging.getLogger(__name__) persist_event_counter = Counter("synapse_storage_events_persisted_events", "") -event_counter = Counter("synapse_storage_events_persisted_events_sep", "", - ["type", "origin_type", "origin_entity"]) +event_counter = Counter( + "synapse_storage_events_persisted_events_sep", + "", + ["type", "origin_type", "origin_entity"], +) # The number of times we are recalculating the current state state_delta_counter = Counter("synapse_storage_events_state_delta", "") @@ -60,13 +62,15 @@ state_delta_counter = Counter("synapse_storage_events_state_delta", "") # The number of times we are recalculating state when there is only a # single forward extremity state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "") + "synapse_storage_events_state_delta_single_event", "" +) # The number of times we are reculating state when we could have resonably # calculated the delta when we calculated the state for an event we were # persisting. state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "") + "synapse_storage_events_state_delta_reuse_delta", "" +) def encode_json(json_object): @@ -84,9 +88,9 @@ class _EventPeristenceQueue(object): concurrent transaction per room. """ - _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( - "events_and_contexts", "backfilled", "deferred", - )) + _EventPersistQueueItem = namedtuple( + "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred") + ) def __init__(self): self._event_persist_queues = {} @@ -119,11 +123,13 @@ class _EventPeristenceQueue(object): deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - queue.append(self._EventPersistQueueItem( - events_and_contexts=events_and_contexts, - backfilled=backfilled, - deferred=deferred, - )) + queue.append( + self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + deferred=deferred, + ) + ) return deferred.observe() @@ -191,6 +197,7 @@ def _retry_on_integrity_error(func): Args: func: function that returns a Deferred and accepts a `delete_existing` arg """ + @wraps(func) @defer.inlineCallbacks def f(self, *args, **kwargs): @@ -206,8 +213,12 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, - BackgroundUpdateStore): +class EventsStore( + StateGroupWorkerStore, + EventFederationStore, + EventsWorkerStore, + BackgroundUpdateStore, +): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -265,8 +276,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore deferreds = [] for room_id, evs_ctxs in iteritems(partitioned): d = self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, - backfilled=backfilled, + room_id, evs_ctxs, backfilled=backfilled ) deferreds.append(d) @@ -296,8 +306,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore and the stream ordering of the latest persisted event """ deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], - backfilled=backfilled, + event.room_id, [(event, context)], backfilled=backfilled ) self._maybe_start_persisting(event.room_id) @@ -312,16 +321,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore def persisting_queue(item): with Measure(self._clock, "persist_events"): yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, + item.events_and_contexts, backfilled=item.backfilled ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events(self, events_and_contexts, backfilled=False, - delete_existing=False): + def _persist_events( + self, events_and_contexts, backfilled=False, delete_existing=False + ): """Persist events to db Args: @@ -345,13 +354,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore ) with stream_ordering_manager as stream_orderings: - for (event, context), stream, in zip( - events_and_contexts, stream_orderings - ): + for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream chunks = [ - events_and_contexts[x:x + 100] + events_and_contexts[x : x + 100] for x in range(0, len(events_and_contexts), 100) ] @@ -445,12 +452,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore state_delta_reuse_delta_counter.inc() break - logger.info( - "Calculating state delta for room %s", room_id, - ) + logger.info("Calculating state delta for room %s", room_id) with Measure( - self._clock, - "persist_events.get_new_state_after_events", + self._clock, "persist_events.get_new_state_after_events" ): res = yield self._get_new_state_after_events( room_id, @@ -470,11 +474,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore state_delta_for_room[room_id] = ([], delta_ids) elif current_state is not None: with Measure( - self._clock, - "persist_events.calculate_state_delta", + self._clock, "persist_events.calculate_state_delta" ): delta = yield self._calculate_state_delta( - room_id, current_state, + room_id, current_state ) state_delta_for_room[room_id] = delta @@ -498,7 +501,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # backfilled events have negative stream orderings, so we don't # want to set the event_persisted_position to that. synapse.metrics.event_persisted_position.set( - chunk[-1][0].internal_metadata.stream_ordering, + chunk[-1][0].internal_metadata.stream_ordering ) for event, context in chunk: @@ -515,9 +518,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore event_counter.labels(event.type, origin_type, origin_entity).inc() for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill( - (room_id, ), new_state - ) + self.get_current_state_ids.prefill((room_id,), new_state) for room_id, latest_event_ids in iteritems(new_forward_extremeties): self.get_latest_event_ids_in_room.prefill( @@ -535,8 +536,10 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # we're only interested in new events which aren't outliers and which aren't # being rejected. new_events = [ - event for event, ctx in event_contexts - if not event.internal_metadata.is_outlier() and not ctx.rejected + event + for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() + and not ctx.rejected and not event.internal_metadata.is_soft_failed() ] @@ -544,15 +547,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore result = set(latest_event_ids) # add all the new events to the list - result.update( - event.event_id for event in new_events - ) + result.update(event.event_id for event in new_events) # Now remove all events which are prev_events of any of the new events result.difference_update( - e_id - for event in new_events - for e_id in event.prev_event_ids() + e_id for event in new_events for e_id in event.prev_event_ids() ) # Finally, remove any events which are prev_events of any existing events. @@ -592,17 +591,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore results.extend(r[0] for r in txn) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( - "_get_events_which_are_prevs", - _get_events, - chunk, - ) + yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) defer.returnValue(results) @defer.inlineCallbacks - def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, - new_latest_event_ids): + def _get_new_state_after_events( + self, room_id, events_context, old_latest_event_ids, new_latest_event_ids + ): """Calculate the current state dict after adding some new events to a room @@ -642,7 +638,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if not ev.internal_metadata.is_outlier(): raise Exception( "Context for new event %s has no state " - "group" % (ev.event_id, ), + "group" % (ev.event_id,) ) continue @@ -682,9 +678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if missing_event_ids: # Now pull out the state groups for any missing events from DB - event_to_groups = yield self._get_state_group_for_events( - missing_event_ids, - ) + event_to_groups = yield self._get_state_group_for_events(missing_event_ids) event_id_to_state_group.update(event_to_groups) # State groups of old_latest_event_ids @@ -710,9 +704,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore new_state_group = next(iter(new_state_groups)) old_state_group = next(iter(old_state_groups)) - delta_ids = state_group_deltas.get( - (old_state_group, new_state_group,), None - ) + delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) if delta_ids is not None: # We have a delta from the existing to new current state, # so lets just return that. If we happen to already have @@ -735,9 +727,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Ok, we need to defer to the state handler to resolve our state sets. - state_groups = { - sg: state_groups_map[sg] for sg in new_state_groups - } + state_groups = {sg: state_groups_map[sg] for sg in new_state_groups} events_map = {ev.event_id: ev for ev, _ in events_context} @@ -755,8 +745,11 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self) + room_id, + room_version, + state_groups, + events_map, + state_res_store=StateResolutionStore(self), ) defer.returnValue((res.state, None)) @@ -774,22 +767,26 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore """ existing_state = yield self.get_current_state_ids(room_id) - to_delete = [ - key for key in existing_state - if key not in current_state - ] + to_delete = [key for key in existing_state if key not in current_state] to_insert = { - key: ev_id for key, ev_id in iteritems(current_state) + key: ev_id + for key, ev_id in iteritems(current_state) if ev_id != existing_state.get(key) } defer.returnValue((to_delete, to_insert)) @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, state_delta_for_room={}, - new_forward_extremeties={}): + def _persist_events_txn( + self, + txn, + events_and_contexts, + backfilled, + delete_existing=False, + state_delta_for_room={}, + new_forward_extremeties={}, + ): """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -828,20 +825,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Ensure that we don't have the same event twice. events_and_contexts = self._filter_events_and_contexts_for_duplicates( - events_and_contexts, + events_and_contexts ) self._update_room_depths_txn( - txn, - events_and_contexts=events_and_contexts, - backfilled=backfilled, + txn, events_and_contexts=events_and_contexts, backfilled=backfilled ) # _update_outliers_txn filters out any events which have already been # persisted, and returns the filtered list. events_and_contexts = self._update_outliers_txn( - txn, - events_and_contexts=events_and_contexts, + txn, events_and_contexts=events_and_contexts ) # From this point onwards the events are only events that we haven't @@ -852,15 +846,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # for these events so we can reinsert them. # This gets around any problems with some tables already having # entries. - self._delete_existing_rows_txn( - txn, - events_and_contexts=events_and_contexts, - ) + self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts) - self._store_event_txn( - txn, - events_and_contexts=events_and_contexts, - ) + self._store_event_txn(txn, events_and_contexts=events_and_contexts) # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) @@ -889,8 +877,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. events_and_contexts = self._store_rejected_events_txn( - txn, - events_and_contexts=events_and_contexts, + txn, events_and_contexts=events_and_contexts ) # From this point onwards the events are only ones that weren't @@ -920,22 +907,40 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore WHERE room_id = ? AND type = ? AND state_key = ? ) """ - txn.executemany(sql, ( + txn.executemany( + sql, ( - max_stream_order, room_id, etype, state_key, None, - room_id, etype, state_key, - ) - for etype, state_key in to_delete - # We sanity check that we're deleting rather than updating - if (etype, state_key) not in to_insert - )) - txn.executemany(sql, ( + ( + max_stream_order, + room_id, + etype, + state_key, + None, + room_id, + etype, + state_key, + ) + for etype, state_key in to_delete + # We sanity check that we're deleting rather than updating + if (etype, state_key) not in to_insert + ), + ) + txn.executemany( + sql, ( - max_stream_order, room_id, etype, state_key, ev_id, - room_id, etype, state_key, - ) - for (etype, state_key), ev_id in iteritems(to_insert) - )) + ( + max_stream_order, + room_id, + etype, + state_key, + ev_id, + room_id, + etype, + state_key, + ) + for (etype, state_key), ev_id in iteritems(to_insert) + ), + ) # Now we actually update the current_state_events table @@ -964,7 +969,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore txn.call_after( self._curr_state_delta_stream_cache.entity_has_changed, - room_id, max_stream_order, + room_id, + max_stream_order, ) # Invalidate the various caches @@ -982,26 +988,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore self._invalidate_state_caches_and_stream(txn, room_id, members_changed) - def _update_forward_extremities_txn(self, txn, new_forward_extremities, - max_stream_order): + def _update_forward_extremities_txn( + self, txn, new_forward_extremities, max_stream_order + ): for room_id, new_extrem in iteritems(new_forward_extremities): self._simple_delete_txn( - txn, - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - ) - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) + txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) + txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) self._simple_insert_many_txn( txn, table="event_forward_extremities", values=[ - { - "event_id": ev_id, - "room_id": room_id, - } + {"event_id": ev_id, "room_id": room_id} for room_id, new_extrem in iteritems(new_forward_extremities) for ev_id in new_extrem ], @@ -1021,7 +1021,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore } for room_id, new_extrem in iteritems(new_forward_extremities) for event_id in new_extrem - ] + ], ) @classmethod @@ -1065,7 +1065,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore if not backfilled: txn.call_after( self._events_stream_cache.entity_has_changed, - event.room_id, event.internal_metadata.stream_ordering, + event.room_id, + event.internal_metadata.stream_ordering, ) if not event.internal_metadata.is_outlier() and not context.rejected: @@ -1092,16 +1093,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore are already in the events table. """ txn.execute( - "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( - ",".join(["?"] * len(events_and_contexts)), - ), - [event.event_id for event, _ in events_and_contexts] + "SELECT event_id, outlier FROM events WHERE event_id in (%s)" + % (",".join(["?"] * len(events_and_contexts)),), + [event.event_id for event, _ in events_and_contexts], ) - have_persisted = { - event_id: outlier - for event_id, outlier in txn - } + have_persisted = {event_id: outlier for event_id, outlier in txn} to_remove = set() for event, context in events_and_contexts: @@ -1128,18 +1125,12 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.exception("") raise - metadata_json = encode_json( - event.internal_metadata.get_dict() - ) + metadata_json = encode_json(event.internal_metadata.get_dict()) sql = ( - "UPDATE event_json SET internal_metadata = ?" - " WHERE event_id = ?" - ) - txn.execute( - sql, - (metadata_json, event.event_id,) + "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" ) + txn.execute(sql, (metadata_json, event.event_id)) # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. @@ -1152,25 +1143,17 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "event_stream_ordering": stream_order, "event_id": event.event_id, "state_group": state_group_id, - } + }, ) - sql = ( - "UPDATE events SET outlier = ?" - " WHERE event_id = ?" - ) - txn.execute( - sql, - (False, event.event_id,) - ) + sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?" + txn.execute(sql, (False, event.event_id)) # Update the event_backward_extremities table now that this # event isn't an outlier any more. self._update_backward_extremeties(txn, [event]) - return [ - ec for ec in events_and_contexts if ec[0] not in to_remove - ] + return [ec for ec in events_and_contexts if ec[0] not in to_remove] @classmethod def _delete_existing_rows_txn(cls, txn, events_and_contexts): @@ -1181,39 +1164,37 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore logger.info("Deleting existing") for table in ( - "events", - "event_auth", - "event_json", - "event_content_hashes", - "event_destinations", - "event_edge_hashes", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "event_signatures", - "event_to_state_groups", - "guest_access", - "history_visibility", - "local_invites", - "room_names", - "state_events", - "rejections", - "redactions", - "room_memberships", - "topics" + "events", + "event_auth", + "event_json", + "event_content_hashes", + "event_destinations", + "event_edge_hashes", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "event_signatures", + "event_to_state_groups", + "guest_access", + "history_visibility", + "local_invites", + "room_names", + "state_events", + "rejections", + "redactions", + "room_memberships", + "topics", ): txn.executemany( "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts] + [(ev.event_id,) for ev, _ in events_and_contexts], ) - for table in ( - "event_push_actions", - ): + for table in ("event_push_actions",): txn.executemany( "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,), - [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts] + [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts], ) def _store_event_txn(self, txn, events_and_contexts): @@ -1296,17 +1277,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore for event, context in events_and_contexts: if context.rejected: # Insert the event_id into the rejections table - self._store_rejections_txn( - txn, event.event_id, context.rejected - ) + self._store_rejections_txn(txn, event.event_id, context.rejected) to_remove.add(event) - return [ - ec for ec in events_and_contexts if ec[0] not in to_remove - ] + return [ec for ec in events_and_contexts if ec[0] not in to_remove] - def _update_metadata_tables_txn(self, txn, events_and_contexts, - all_events_and_contexts, backfilled): + def _update_metadata_tables_txn( + self, txn, events_and_contexts, all_events_and_contexts, backfilled + ): """Update all the miscellaneous tables for new events Args: @@ -1342,8 +1320,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( - txn, - events=[event for event, _ in events_and_contexts], + txn, events=[event for event, _ in events_and_contexts] ) for event, _ in events_and_contexts: @@ -1401,11 +1378,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore state_values.append(vals) - self._simple_insert_many_txn( - txn, - table="state_events", - values=state_values, - ) + self._simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1416,10 +1389,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore rows = [] N = 200 for i in range(0, len(events_and_contexts), N): - ev_map = { - e[0].event_id: e[0] - for e in events_and_contexts[i:i + N] - } + ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]} if not ev_map: break @@ -1439,14 +1409,14 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: - to_prefill.append(_EventCacheEntry( - event=event, - redacted_event=None, - )) + to_prefill.append( + _EventCacheEntry(event=event, redacted_event=None) + ) def prefill(): for cache_entry in to_prefill: self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) + txn.call_after(prefill) def _store_redaction(self, txn, event): @@ -1454,7 +1424,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore txn.call_after(self._invalidate_get_event_cache, event.redacts) txn.execute( "INSERT INTO redactions (event_id, redacts) VALUES (?,?)", - (event.event_id, event.redacts) + (event.event_id, event.redacts), ) @defer.inlineCallbacks @@ -1465,6 +1435,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore If it has been significantly less or more than one day since the last call to this function, it will return None. """ + def _count_messages(txn): sql = """ SELECT COALESCE(COUNT(*), 0) FROM events @@ -1492,7 +1463,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore AND stream_ordering > ? """ - txn.execute(sql, (like_clause, self.stream_ordering_day_ago,)) + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) count, = txn.fetchone() return count @@ -1557,18 +1528,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore update_rows.append((sender, contains_url, event_id)) - sql = ( - "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - ) + sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index:index + INSERT_CLUMP_SIZE] + clump = update_rows[index : index + INSERT_CLUMP_SIZE] txn.executemany(sql, clump) progress = { "target_min_stream_id_inclusive": target_min_stream_id, "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows) + "rows_inserted": rows_inserted + len(rows), } self._background_update_progress_txn( @@ -1613,10 +1582,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore rows_to_update = [] - chunks = [ - event_ids[i:i + 100] - for i in range(0, len(event_ids), 100) - ] + chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: ev_rows = self._simple_select_many_txn( txn, @@ -1639,18 +1605,16 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore rows_to_update.append((origin_server_ts, event_id)) - sql = ( - "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - ) + sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index:index + INSERT_CLUMP_SIZE] + clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] txn.executemany(sql, clump) progress = { "target_min_stream_id_inclusive": target_min_stream_id, "max_stream_id_exclusive": min_stream_id, - "rows_inserted": rows_inserted + len(rows_to_update) + "rows_inserted": rows_inserted + len(rows_to_update), } self._background_update_progress_txn( @@ -1714,6 +1678,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore new_event_updates.extend(txn) return new_event_updates + return self.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1756,13 +1721,20 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore new_event_updates.extend(txn.fetchall()) return new_event_updates + return self.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @cached(num_args=5, max_entries=10) - def get_all_new_events(self, last_backfill_id, last_forward_id, - current_backfill_id, current_forward_id, limit): + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): """Get all the new events that have arrived at the server either as new events or as backfilled events""" have_backfill_events = last_backfill_id != current_backfill_id @@ -1837,14 +1809,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore backward_ex_outliers = [] return AllNewEventsResult( - new_forward_events, new_backfill_events, - forward_ex_outliers, backward_ex_outliers, + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, ) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) - def purge_history( - self, room_id, token, delete_local_events, - ): + def purge_history(self, room_id, token, delete_local_events): """Deletes room history before a certain point Args: @@ -1860,13 +1833,13 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore return self.runInteraction( "purge_history", - self._purge_history_txn, room_id, token, + self._purge_history_txn, + room_id, + token, delete_local_events, ) - def _purge_history_txn( - self, txn, room_id, token_str, delete_local_events, - ): + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): token = RoomStreamToken.parse(token_str) # Tables that should be pruned: @@ -1913,7 +1886,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "ON e.event_id = f.event_id " "AND e.room_id = f.room_id " "WHERE f.room_id = ?", - (room_id,) + (room_id,), ) rows = txn.fetchall() max_depth = max(row[1] for row in rows) @@ -1934,10 +1907,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore should_delete_expr += " AND event_id NOT LIKE ?" # We include the parameter twice since we use the expression twice - should_delete_params += ( - "%:" + self.hs.hostname, - "%:" + self.hs.hostname, - ) + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) should_delete_params += (room_id, token.topological) @@ -1948,10 +1918,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore " SELECT event_id, %s" " FROM events AS e LEFT JOIN state_events USING (event_id)" " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" - % ( - should_delete_expr, - should_delete_expr, - ), + % (should_delete_expr, should_delete_expr), should_delete_params, ) @@ -1961,23 +1928,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # the should_delete / shouldn't_delete subsets txn.execute( "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)", + " ON events_to_purge(should_delete)" ) # We do joins against events_to_purge for e.g. calculating state # groups to purge, etc., so lets make an index. - txn.execute( - "CREATE INDEX events_to_purge_id" - " ON events_to_purge(event_id)", - ) + txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)") - txn.execute( - "SELECT event_id, should_delete FROM events_to_purge" - ) + txn.execute("SELECT event_id, should_delete FROM events_to_purge") event_rows = txn.fetchall() logger.info( "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), sum(1 for e in event_rows if e[1]), + len(event_rows), + sum(1 for e in event_rows if e[1]), ) logger.info("[purge] Finding new backward extremities") @@ -1989,24 +1952,21 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "SELECT DISTINCT e.event_id FROM events_to_purge AS e" " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL", + " WHERE ep2.event_id IS NULL" ) new_backwards_extrems = txn.fetchall() logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", - (room_id,) + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) ) # Update backward extremeties txn.executemany( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", - [ - (room_id, event_id) for event_id, in new_backwards_extrems - ] + [(room_id, event_id) for event_id, in new_backwards_extrems], ) logger.info("[purge] finding redundant state groups") @@ -2014,28 +1974,25 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # Get all state groups that are referenced by events that are to be # deleted. We then go and check if they are referenced by other events # or state groups, and if not we delete them. - txn.execute(""" + txn.execute( + """ SELECT DISTINCT state_group FROM events_to_purge INNER JOIN event_to_state_groups USING (event_id) - """) + """ + ) referenced_state_groups = set(sg for sg, in txn) logger.info( - "[purge] found %i referenced state groups", - len(referenced_state_groups), + "[purge] found %i referenced state groups", len(referenced_state_groups) ) logger.info("[purge] finding state groups that can be deleted") - state_groups_to_delete, remaining_state_groups = ( - self._find_unreferenced_groups_during_purge( - txn, referenced_state_groups, - ) - ) + _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups) + state_groups_to_delete, remaining_state_groups = _ logger.info( - "[purge] found %i state groups to delete", - len(state_groups_to_delete), + "[purge] found %i state groups to delete", len(state_groups_to_delete) ) logger.info( @@ -2047,25 +2004,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # groups to non delta versions. for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], - ) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] self._simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={ - "state_group": sg, - } + txn, table="state_groups_state", keyvalues={"state_group": sg} ) self._simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={ - "state_group": sg, - } + txn, table="state_group_edges", keyvalues={"state_group": sg} ) self._simple_insert_many_txn( @@ -2099,9 +2046,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore "WHERE event_id IN (SELECT event_id from events_to_purge)" ) for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, ( - event_id, - )) + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) # Delete all remote non-state events for table in ( @@ -2123,21 +2068,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore txn.execute( "DELETE FROM %s WHERE event_id IN (" " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), + ")" % (table,) ) # event_push_actions lacks an index on event_id, and has one on # (room_id, event_id) instead. - for table in ( - "event_push_actions", - ): + for table in ("event_push_actions",): logger.info("[purge] removing events from %s", table) txn.execute( "DELETE FROM %s WHERE room_id = ? AND event_id IN (" " SELECT event_id FROM events_to_purge WHERE should_delete" ")" % (table,), - (room_id, ) + (room_id,), ) # Mark all state and own events as outliers @@ -2162,27 +2105,28 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore # extremities. However, the events in event_backward_extremities # are ones we don't have yet so we need to look at the events that # point to it via event_edges table. - txn.execute(""" + txn.execute( + """ SELECT COALESCE(MIN(depth), 0) FROM event_backward_extremities AS eb INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id INNER JOIN events AS e ON e.event_id = eg.event_id WHERE eb.room_id = ? - """, (room_id,)) + """, + (room_id,), + ) min_depth, = txn.fetchone() logger.info("[purge] updating room_depth to %d", min_depth) txn.execute( "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id,) + (min_depth, room_id), ) # finally, drop the temp table. this will commit the txn in sqlite, # so make sure to keep this actually last. - txn.execute( - "DROP TABLE events_to_purge" - ) + txn.execute("DROP TABLE events_to_purge") logger.info("[purge] done") @@ -2226,7 +2170,9 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore SELECT DISTINCT state_group FROM event_to_state_groups LEFT JOIN events_to_purge AS ep USING (event_id) WHERE state_group IN (%s) AND ep.event_id IS NULL - """ % (",".join("?" for _ in current_search),) + """ % ( + ",".join("?" for _ in current_search), + ) txn.execute(sql, list(current_search)) referenced = set(sg for sg, in txn) @@ -2242,7 +2188,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore column="prev_state_group", iterable=current_search, keyvalues={}, - retcols=("prev_state_group", "state_group",), + retcols=("prev_state_group", "state_group"), ) prevs = set(row["state_group"] for row in rows) @@ -2279,16 +2225,15 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, - allow_none=True + allow_none=True, ) if not res: raise SynapseError(404, "Could not find event %s" % (event_id,)) - defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"]))) - - def get_max_current_state_delta_stream_id(self): - return self._stream_id_gen.get_current_token() + defer.returnValue( + (int(res["topological_ordering"]), int(res["stream_ordering"])) + ) def get_all_updated_current_state_deltas(self, from_token, to_token, limit): def get_all_updated_current_state_deltas_txn(txn): @@ -2300,13 +2245,19 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore """ txn.execute(sql, (from_token, to_token, limit)) return txn.fetchall() + return self.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) -AllNewEventsResult = namedtuple("AllNewEventsResult", [ - "new_forward_events", "new_backfill_events", - "forward_ex_outliers", "backward_ex_outliers", -]) +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 1716be529a..53c8dc3903 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -21,8 +21,9 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.constants import EventFormatVersions, EventTypes +from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError +from synapse.api.room_versions import EventFormatVersions from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401 # these are only included to make the type annotations work from synapse.events.snapshot import EventContext # noqa: F401 diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 592c1bcd33..57df17bcc2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -35,28 +35,22 @@ logger = logging.getLogger(__name__) RoomsForUser = namedtuple( - "RoomsForUser", - ("room_id", "sender", "membership", "event_id", "stream_ordering") + "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering") ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", - ("room_id", "stream_ordering",) + "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") ) # We store this using a namedtuple so that we save about 3x space over using a # dict. -ProfileInfo = namedtuple( - "ProfileInfo", ("avatar_url", "display_name") -) +ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) # "members" points to a truncated list of (user_id, event_id) tuples for users of # a given membership type, suitable for use in calculating heroes for a room. # "count" points to the total numberr of users of a given membership type. -MemberSummary = namedtuple( - "MemberSummary", ("members", "count") -) +MemberSummary = namedtuple("MemberSummary", ("members", "count")) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of all hosts currently in the room """ user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) defer.returnValue(hosts) @@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" ) - txn.execute(sql, (room_id, Membership.JOIN,)) + txn.execute(sql, (room_id, Membership.JOIN)) return [to_ascii(r[0]) for r in txn] + return self.runInteraction("get_users_in_room", f) @cached(max_entries=100000) @@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): A deferred list of RoomsForUser. """ - return self.get_rooms_for_user_where_membership_is( - user_id, [Membership.INVITE] - ) + return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE]) @defer.inlineCallbacks def get_invite_for_user_in_room(self, user_id, room_id): @@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, - user_id, membership_list + user_id, + membership_list, ) - def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, - membership_list): + def _get_rooms_for_user_where_membership_is_txn( + self, txn, user_id, membership_list + ): do_invite = Membership.INVITE in membership_list membership_list = [m for m in membership_list if m != Membership.INVITE] @@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) % (where_clause,) txn.execute(sql, args) - results = [ - RoomsForUser(**r) for r in self.cursor_to_dict(txn) - ] + results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] if do_invite: sql = ( @@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id,)) - results.extend(RoomsForUser( - room_id=r["room_id"], - sender=r["inviter"], - event_id=r["event_id"], - stream_ordering=r["stream_ordering"], - membership=Membership.INVITE, - ) for r in self.cursor_to_dict(txn)) + results.extend( + RoomsForUser( + room_id=r["room_id"], + sender=r["inviter"], + event_id=r["event_id"], + stream_ordering=r["stream_ordering"], + membership=Membership.INVITE, + ) + for r in self.cursor_to_dict(txn) + ) return results @@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): of the most recent join for that user and room. """ rooms = yield self.get_rooms_for_user_where_membership_is( - user_id, membership_list=[Membership.JOIN], + user_id, membership_list=[Membership.JOIN] + ) + defer.returnValue( + frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + ) ) - defer.returnValue(frozenset( - GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) - for r in rooms - )) @defer.inlineCallbacks def get_rooms_for_user(self, user_id, on_invalidate=None): """Returns a set of room_ids the user is currently joined to """ rooms = yield self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate, + user_id, on_invalidate=on_invalidate ) defer.returnValue(frozenset(r.room_id for r in rooms)) @@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of users who share a room with `user_id` """ room_ids = yield self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate, + user_id, on_invalidate=cache_context.invalidate ) user_who_share_room = set() for room_id in room_ids: user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate ) user_who_share_room.update(user_ids) @@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): current_state_ids = yield context.get_current_state_ids(self) result = yield self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, - event=event, - context=context, + event.room_id, state_group, current_state_ids, event=event, context=context ) defer.returnValue(result) @@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry, + room_id, state_group, state_entry.state, context=state_entry ) - @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=100000) - def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context, event=None, context=None): + @cachedInlineCallbacks( + num_args=2, cache_context=True, iterable=True, max_entries=100000 + ) + def _get_joined_users_from_context( + self, + room_id, + state_group, + current_state_ids, + cache_context, + event=None, + context=None, + ): # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. @@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the hit ratio counts. After all, we don't populate the cache if we # miss it here event_map = self._get_events_from_cache( - member_event_ids, - allow_rejected=False, - update_metrics=False, + member_event_ids, allow_rejected=False, update_metrics=False ) missing_member_event_ids = [] @@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): table="room_memberships", column="event_id", iterable=missing_member_event_ids, - retcols=('user_id', 'display_name', 'avatar_url',), - keyvalues={ - "membership": Membership.JOIN, - }, + retcols=('user_id', 'display_name', 'avatar_url'), + keyvalues={"membership": Membership.JOIN}, batch_size=500, desc="_get_joined_users_from_context", ) - users_in_room.update({ - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), - ) - for row in rows - }) + users_in_room.update( + { + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + } + ) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() return self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry, + room_id, state_group, state_entry.state, state_entry=state_entry ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) @@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns whether user_id has elected to discard history for room_id. Returns False if they have since re-joined.""" + def f(txn): sql = ( "SELECT" @@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id, room_id)) rows = txn.fetchall() return rows[0][0] + count = yield self.runInteraction("did_forget_membership", f) defer.returnValue(count == 0) @@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore): "avatar_url": event.content.get("avatar_url", None), } for event in events - ] + ], ) for event in events: txn.call_after( self._membership_stream_cache.entity_has_changed, - event.state_key, event.internal_metadata.stream_ordering + event.state_key, + event.internal_metadata.stream_ordering, ) txn.call_after( self.get_invited_rooms_for_user.invalidate, (event.state_key,) @@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore): "inviter": event.sender, "room_id": event.room_id, "stream_id": event.internal_metadata.stream_ordering, - } + }, ) else: sql = ( @@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore): " AND replaced_by is NULL" ) - txn.execute(sql, ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - )) + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) @defer.inlineCallbacks def locally_reject_invite(self, user_id, room_id): @@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore): ) def f(txn, stream_ordering): - txn.execute(sql, ( - stream_ordering, - True, - room_id, - user_id, - )) + txn.execute(sql, (stream_ordering, True, room_id, user_id)) with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" + def f(txn): sql = ( "UPDATE" @@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore): ) txn.execute(sql, (user_id, room_id)) - self._invalidate_cache_and_stream( - txn, self.did_forget, (user_id, room_id,), - ) + self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) + return self.runInteraction("forget_membership", f) @defer.inlineCallbacks @@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore): INSERT_CLUMP_SIZE = 1000 def add_membership_profile_txn(txn): - sql = (""" + sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events INNER JOIN event_json USING (event_id) @@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore): AND type = 'm.room.member' ORDER BY stream_ordering DESC LIMIT ? - """) + """ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) @@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore): avatar_url = content.get("avatar_url", None) if display_name or avatar_url: - to_update.append(( - display_name, avatar_url, event_id, room_id - )) + to_update.append((display_name, avatar_url, event_id, room_id)) - to_update_sql = (""" + to_update_sql = """ UPDATE room_memberships SET display_name = ?, avatar_url = ? WHERE event_id = ? AND room_id = ? - """) + """ for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index:index + INSERT_CLUMP_SIZE] + clump = to_update[index : index + INSERT_CLUMP_SIZE] txn.executemany(to_update_sql, clump) progress = { @@ -789,7 +790,7 @@ class _JoinedHostsCache(object): self.hosts_to_joined_users.pop(host, None) else: joined_users = yield self.store.get_joined_users_from_state( - self.room_id, state_entry, + self.room_id, state_entry ) self.hosts_to_joined_users = {} diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index fc2b646ba2..94c6080e34 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -16,7 +16,11 @@ from mock import Mock, call -from synapse.api.constants import PresenceState +from signedjson.key import generate_signing_key + +from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.events import room_version_to_event_format +from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( FEDERATION_PING_INTERVAL, FEDERATION_TIMEOUT, @@ -26,7 +30,9 @@ from synapse.handlers.presence import ( handle_timeout, handle_update, ) +from synapse.rest.client.v1 import room from synapse.storage.presence import UserPresenceState +from synapse.types import UserID, get_domain_from_id from tests import unittest @@ -405,3 +411,171 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(state, new_state) + + +class PresenceJoinTestCase(unittest.HomeserverTestCase): + """Tests remote servers get told about presence of users in the room when + they join and when new local users join. + """ + + user_id = "@test:server" + + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver( + "server", http_client=None, + federation_sender=Mock(), + ) + return hs + + def prepare(self, reactor, clock, hs): + self.federation_sender = hs.get_federation_sender() + self.event_builder_factory = hs.get_event_builder_factory() + self.federation_handler = hs.get_handlers().federation_handler + self.presence_handler = hs.get_presence_handler() + + # self.event_builder_for_2 = EventBuilderFactory(hs) + # self.event_builder_for_2.hostname = "test2" + + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.auth = hs.get_auth() + + # We don't actually check signatures in tests, so lets just create a + # random key to use. + self.random_signing_key = generate_signing_key("ver") + + def test_remote_joins(self): + # We advance time to something that isn't 0, as we use 0 as a special + # value. + self.reactor.advance(1000000000000) + + # Create a room with two local users + room_id = self.helper.create_room_as(self.user_id) + self.helper.join(room_id, "@test2:server") + + # Mark test2 as online, test will be offline with a last_active of 0 + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + ) + self.reactor.pump([0]) # Wait for presence updates to be handled + + # + # Test that a new server gets told about existing presence + # + + self.federation_sender.reset_mock() + + # Add a new remote server to the room + self._add_new_user(room_id, "@alice:server2") + + # We shouldn't have sent out any local presence *updates* + self.federation_sender.send_presence.assert_not_called() + + # When new server is joined we send it the local users presence states. + # We expect to only see user @test2:server, as @test:server is offline + # and has a zero last_active_ts + expected_state = self.get_success( + self.presence_handler.current_state_for_user("@test2:server") + ) + self.assertEqual(expected_state.state, PresenceState.ONLINE) + self.federation_sender.send_presence_to_destinations.assert_called_once_with( + destinations=["server2"], states=[expected_state] + ) + + # + # Test that only the new server gets sent presence and not existing servers + # + + self.federation_sender.reset_mock() + self._add_new_user(room_id, "@bob:server3") + + self.federation_sender.send_presence.assert_not_called() + self.federation_sender.send_presence_to_destinations.assert_called_once_with( + destinations=["server3"], states=[expected_state] + ) + + def test_remote_gets_presence_when_local_user_joins(self): + # We advance time to something that isn't 0, as we use 0 as a special + # value. + self.reactor.advance(1000000000000) + + # Create a room with one local users + room_id = self.helper.create_room_as(self.user_id) + + # Mark test as online + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}, + ) + + # Mark test2 as online, test will be offline with a last_active of 0. + # Note we don't join them to the room yet + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + ) + + # Add servers to the room + self._add_new_user(room_id, "@alice:server2") + self._add_new_user(room_id, "@bob:server3") + + self.reactor.pump([0]) # Wait for presence updates to be handled + + # + # Test that when a local join happens remote servers get told about it + # + + self.federation_sender.reset_mock() + + # Join local user to room + self.helper.join(room_id, "@test2:server") + + self.reactor.pump([0]) # Wait for presence updates to be handled + + # We shouldn't have sent out any local presence *updates* + self.federation_sender.send_presence.assert_not_called() + + # We expect to only send test2 presence to server2 and server3 + expected_state = self.get_success( + self.presence_handler.current_state_for_user("@test2:server") + ) + self.assertEqual(expected_state.state, PresenceState.ONLINE) + self.federation_sender.send_presence_to_destinations.assert_called_once_with( + destinations=set(("server2", "server3")), + states=[expected_state] + ) + + def _add_new_user(self, room_id, user_id): + """Add new user to the room by creating an event and poking the federation API. + """ + + hostname = get_domain_from_id(user_id) + + room_version = self.get_success(self.store.get_room_version(room_id)) + + builder = EventBuilder( + state=self.state, + auth=self.auth, + store=self.store, + clock=self.clock, + hostname=hostname, + signing_key=self.random_signing_key, + format_version=room_version_to_event_format(room_version), + room_id=room_id, + type=EventTypes.Member, + sender=user_id, + state_key=user_id, + content={"membership": Membership.JOIN} + ) + + prev_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + + event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) + + # Check that it was successfully persisted. + self.get_success(self.store.get_event(event.event_id)) + self.get_success(self.store.get_event(event.event_id)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 9aa9dfe82e..d5a99f6caa 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.tcp.streams import ReceiptsStreamRow +from synapse.replication.tcp.streams._base import ReceiptsStreamRow from tests.replication.tcp.streams._base import BaseStreamTestCase diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index d3d43970fb..bbfc77e829 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS from synapse.rest.client.v1 import admin, login from synapse.rest.client.v2_alpha import capabilities @@ -52,7 +52,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): for room_version in capabilities['m.room_versions']['available'].keys(): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( - DEFAULT_ROOM_VERSION, capabilities['m.room_versions']['default'] + DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default'] ) def test_get_change_password_capabilities(self): diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 9a5c816927..f448b01326 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -19,7 +19,8 @@ from six.moves import zip import attr -from synapse.api.constants import EventTypes, JoinRules, Membership, RoomVersions +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event from synapse.events import FrozenEvent from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store @@ -539,7 +540,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( - RoomVersions.V2, + RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, state_res_store=TestStateResolutionStore(event_map), @@ -686,7 +687,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( - RoomVersions.V2, + RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, state_res_store=TestStateResolutionStore(self.event_map), diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 3957561b1e..0fc5019e9f 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -18,7 +18,8 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID from tests import unittest @@ -51,7 +52,7 @@ class RedactionTestCase(unittest.TestCase): ): content = {"membership": membership} content.update(extra_content) - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Member, @@ -74,7 +75,7 @@ class RedactionTestCase(unittest.TestCase): def inject_message(self, room, user, body): self.depth += 1 - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Message, @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_redaction(self, room, event_id, user, reason): - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Redaction, diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 7fa2f4fd70..063387863e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -18,7 +18,8 @@ from mock import Mock from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID from tests import unittest @@ -49,7 +50,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Member, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 99cd3e09eb..78e260a7fa 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -17,7 +17,8 @@ import logging from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID @@ -48,7 +49,7 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def inject_state_event(self, room, sender, typ, state_key, content): - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": typ, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 7ee318e4e8..4c8f87e958 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -16,8 +16,8 @@ import unittest from synapse import event_auth -from synapse.api.constants import RoomVersions from synapse.api.errors import AuthError +from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent @@ -37,7 +37,7 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1, _random_state_event(creator), auth_events, + RoomVersions.V1.identifier, _random_state_event(creator), auth_events, do_sig_check=False, ) @@ -45,7 +45,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1, + RoomVersions.V1.identifier, _random_state_event(joiner), auth_events, do_sig_check=False, @@ -74,7 +74,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check, - RoomVersions.V1, + RoomVersions.V1.identifier, _random_state_event(pleb), auth_events, do_sig_check=False, @@ -82,7 +82,7 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1, _random_state_event(king), auth_events, + RoomVersions.V1.identifier, _random_state_event(king), auth_events, do_sig_check=False, ) diff --git a/tests/test_state.py b/tests/test_state.py index e20c33322a..03e4810c2e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -18,7 +18,8 @@ from mock import Mock from twisted.internet import defer from synapse.api.auth import Auth -from synapse.api.constants import EventTypes, Membership, RoomVersions +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.state import StateHandler, StateResolutionHandler @@ -118,7 +119,7 @@ class StateGroupStore(object): self._event_to_state_group[event_id] = state_group def get_room_version(self, room_id): - return RoomVersions.V1 + return RoomVersions.V1.identifier class DictObj(dict): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 455db9f276..3bdb500514 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer from twisted.internet.defer import succeed -from synapse.api.constants import RoomVersions +from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server @@ -124,7 +124,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def inject_visibility(self, user_id, visibility): content = {"history_visibility": visibility} - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": "m.room.history_visibility", @@ -145,7 +145,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): def inject_room_member(self, user_id, membership="join", extra_content={}): content = {"membership": membership} content.update(extra_content) - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": "m.room.member", @@ -167,7 +167,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): def inject_message(self, user_id, content=None): if content is None: content = {"body": "testytest", "msgtype": "m.text"} - builder = self.event_builder_factory.new( + builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { "type": "m.room.message", diff --git a/tests/utils.py b/tests/utils.py index 615b9f8cca..cb75514851 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,8 +27,9 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor -from synapse.api.constants import EventTypes, RoomVersions +from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error +from synapse.api.room_versions import RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer @@ -671,7 +672,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - builder = event_builder_factory.new( + builder = event_builder_factory.for_room_version( RoomVersions.V1, { "type": EventTypes.Create, |