diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/api/constants.py | 9 | ||||
-rw-r--r-- | synapse/app/event_creator.py | 12 | ||||
-rw-r--r-- | synapse/crypto/keyclient.py | 8 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 2 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 6 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 6 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 31 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 1 | ||||
-rw-r--r-- | synapse/handlers/user_directory.py | 4 | ||||
-rw-r--r-- | synapse/http/request_metrics.py | 13 | ||||
-rw-r--r-- | synapse/metrics/background_process_metrics.py | 26 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 2 | ||||
-rw-r--r-- | synapse/server.py | 7 | ||||
-rw-r--r-- | synapse/state/__init__.py (renamed from synapse/state.py) | 318 | ||||
-rw-r--r-- | synapse/state/v1.py | 321 | ||||
-rw-r--r-- | synapse/storage/events.py | 4 | ||||
-rw-r--r-- | synapse/storage/profile.py | 4 | ||||
-rw-r--r-- | synapse/storage/room.py | 58 | ||||
-rw-r--r-- | synapse/storage/state.py | 158 |
19 files changed, 645 insertions, 345 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index b0da506f6d..912bf024bf 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -97,9 +97,14 @@ class ThirdPartyEntityKind(object): LOCATION = "location" +class RoomVersions(object): + V1 = "1" + VDH_TEST = "vdh-test-version" + + # the version we will give rooms which are created on this server -DEFAULT_ROOM_VERSION = "1" +DEFAULT_ROOM_VERSION = RoomVersions.V1 # vdh-test-version is a placeholder to get room versioning support working and tested # until we have a working v2. -KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"} +KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST} diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index 03d39968a8..a34c89fa99 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.rest.client.v1.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) from synapse.rest.client.v1.room import ( JoinRoomAliasServlet, RoomMembershipRestServlet, @@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.engines import create_engine +from synapse.storage.user_directory import UserDirectoryStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.logcontext import LoggingContext from synapse.util.manhole import manhole @@ -62,6 +68,9 @@ logger = logging.getLogger("synapse.app.event_creator") class EventCreatorSlavedStore( + # FIXME(#3714): We need to add UserDirectoryStore as we write directly + # rather than going via the correct worker. + UserDirectoryStore, DirectoryStore, SlavedTransactionStore, SlavedProfileStore, @@ -101,6 +110,9 @@ class EventCreatorServer(HomeServer): RoomMembershipRestServlet(self).register(resource) RoomStateEventRestServlet(self).register(resource) JoinRoomAliasServlet(self).register(resource) + ProfileAvatarURLRestServlet(self).register(resource) + ProfileDisplaynameRestServlet(self).register(resource) + ProfileRestServlet(self).register(resource) resources.update({ "/_matrix/client/r0": resource, "/_matrix/client/unstable": resource, diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c20a32096a..e94400b8e2 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -18,7 +18,9 @@ import logging from canonicaljson import json from twisted.internet import defer, reactor +from twisted.internet.error import ConnectError from twisted.internet.protocol import Factory +from twisted.names.error import DomainError from twisted.web.http import HTTPClient from synapse.http.endpoint import matrix_federation_endpoint @@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1): server_response, server_certificate = yield protocol.remote_key defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: - logger.exception("Error getting key for %r" % (server_name,)) + logger.warn("Error getting key for %r: %s", server_name, e) if e.status.startswith("4"): # Don't retry for 4xx responses. raise IOError("Cannot get key for %r" % server_name) + except (ConnectError, DomainError) as e: + logger.warn("Error getting key for %r: %s", server_name, e) except Exception as e: - logger.exception(e) + logger.exception("Error getting key for %r", server_name) raise IOError("Cannot get key for %r" % server_name) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index b4fbe2c9d5..1054441ca5 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -106,7 +106,7 @@ class TransportLayerClient(object): dest (str) room_id (str) event_tuples (list) - limt (int) + limit (int) Returns: Deferred: Results in a dict received from the remote homeserver. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 77969a4f38..7a993fd1cf 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -261,10 +261,10 @@ class BaseFederationServlet(object): except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: - logger.exception("authenticate_request failed") + logger.warn("authenticate_request failed: missing authentication") raise - except Exception: - logger.exception("authenticate_request failed") + except Exception as e: + logger.warn("authenticate_request failed: %s", e) raise if origin: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3dd107a285..0ebf0fd188 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -291,8 +291,9 @@ class FederationHandler(BaseHandler): ev_ids, get_prev_content=False, check_redacted=False ) + room_version = yield self.store.get_room_version(pdu.room_id) state_map = yield resolve_events_with_factory( - state_groups, {pdu.event_id: pdu}, fetch + room_version, state_groups, {pdu.event_id: pdu}, fetch ) state = (yield self.store.get_events(state_map.values())).values() @@ -1828,7 +1829,10 @@ class FederationHandler(BaseHandler): (d.type, d.state_key): d for d in different_events if d }) + room_version = yield self.store.get_room_version(event.room_id) + new_state = self.state_handler.resolve_events( + room_version, [list(local_view.values()), list(remote_view.values())], event ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 9af2e8f869..75b8b7ce6a 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -32,12 +32,16 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -class ProfileHandler(BaseHandler): - PROFILE_UPDATE_MS = 60 * 1000 - PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 +class BaseProfileHandler(BaseHandler): + """Handles fetching and updating user profile information. + + BaseProfileHandler can be instantiated directly on workers and will + delegate to master when necessary. The master process should use the + subclass MasterProfileHandler + """ def __init__(self, hs): - super(ProfileHandler, self).__init__(hs) + super(BaseProfileHandler, self).__init__(hs) self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -46,11 +50,6 @@ class ProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() - if hs.config.worker_app is None: - self.clock.looping_call( - self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, - ) - @defer.inlineCallbacks def get_profile(self, user_id): target_user = UserID.from_string(user_id) @@ -282,6 +281,20 @@ class ProfileHandler(BaseHandler): room_id, str(e.message) ) + +class MasterProfileHandler(BaseProfileHandler): + PROFILE_UPDATE_MS = 60 * 1000 + PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + + def __init__(self, hs): + super(MasterProfileHandler, self).__init__(hs) + + assert hs.config.worker_app is None + + self.clock.looping_call( + self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS, + ) + def _start_update_remote_profile_cache(self): return run_as_background_process( "Update remote profile", self._update_remote_profile_cache, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fb94b5d7d4..f643619047 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -344,6 +344,7 @@ class RoomMemberHandler(object): latest_event_ids = ( event_id for (event_id, _, _) in prev_events_and_hashes ) + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 37dda64587..d8413d6aa7 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -119,6 +119,8 @@ class UserDirectoryHandler(object): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ + # FIXME(#3714): We should probably do this in the same worker as all + # the other changes. yield self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url, None, ) @@ -127,6 +129,8 @@ class UserDirectoryHandler(object): def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ + # FIXME(#3714): We should probably do this in the same worker as all + # the other changes. yield self.store.remove_from_user_dir(user_id) yield self.store.remove_from_user_in_public_room(user_id) diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 588e280571..72c2654678 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +import threading from prometheus_client.core import Counter, Histogram @@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter( # The set of all in flight requests, set[RequestMetrics] _in_flight_requests = set() +# Protects the _in_flight_requests set from concurrent accesss +_in_flight_requests_lock = threading.Lock() + def _get_in_flight_counts(): """Returns a count of all in flight requests by (method, server_name) @@ -120,7 +124,8 @@ def _get_in_flight_counts(): """ # Cast to a list to prevent it changing while the Prometheus # thread is collecting metrics - reqs = list(_in_flight_requests) + with _in_flight_requests_lock: + reqs = list(_in_flight_requests) for rm in reqs: rm.update_metrics() @@ -154,10 +159,12 @@ class RequestMetrics(object): # to the "in flight" metrics. self._request_stats = self.start_context.get_resource_usage() - _in_flight_requests.add(self) + with _in_flight_requests_lock: + _in_flight_requests.add(self) def stop(self, time_sec, request): - _in_flight_requests.discard(self) + with _in_flight_requests_lock: + _in_flight_requests.discard(self) context = LoggingContext.current_context() diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index ce678d5f75..167167be0a 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading + import six from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily @@ -78,6 +80,9 @@ _background_process_counts = dict() # type: dict[str, int] # of process descriptions that no longer have any active processes. _background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +# A lock that covers the above dicts +_bg_metrics_lock = threading.Lock() + class _Collector(object): """A custom metrics collector for the background process metrics. @@ -92,7 +97,11 @@ class _Collector(object): labels=["name"], ) - for desc, processes in six.iteritems(_background_processes): + # We copy the dict so that it doesn't change from underneath us + with _bg_metrics_lock: + _background_processes_copy = dict(_background_processes) + + for desc, processes in six.iteritems(_background_processes_copy): background_process_in_flight_count.add_metric( (desc,), len(processes), ) @@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs): """ @defer.inlineCallbacks def run(): - count = _background_process_counts.get(desc, 0) - _background_process_counts[desc] = count + 1 + with _bg_metrics_lock: + count = _background_process_counts.get(desc, 0) + _background_process_counts[desc] = count + 1 + _background_process_start_count.labels(desc).inc() with LoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) proc = _BackgroundProcess(desc, context) - _background_processes.setdefault(desc, set()).add(proc) + + with _bg_metrics_lock: + _background_processes.setdefault(desc, set()).add(proc) + try: yield func(*args, **kwargs) finally: proc.update_metrics() - _background_processes[desc].remove(proc) + + with _bg_metrics_lock: + _background_processes[desc].remove(proc) with PreserveLoggingContext(): return run() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fcc1091760..976d98387d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -531,7 +531,7 @@ class RoomEventServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) event = yield self.event_handler.get_event(requester.user, room_id, event_id) time_now = self.clock.time_msec() diff --git a/synapse/server.py b/synapse/server.py index 26228d8c72..a6fbc6ec0c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -56,7 +56,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler from synapse.handlers.presence import PresenceHandler -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.room import RoomContextHandler, RoomCreationHandler @@ -308,7 +308,10 @@ class HomeServer(object): return InitialSyncHandler(self) def build_profile_handler(self): - return ProfileHandler(self) + if self.config.worker_app: + return BaseProfileHandler(self) + else: + return MasterProfileHandler(self) def build_event_creation_handler(self): return EventCreationHandler(self) diff --git a/synapse/state.py b/synapse/state/__init__.py index 0f2bedb694..b34970e4d1 100644 --- a/synapse/state.py +++ b/synapse/state/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import hashlib import logging from collections import namedtuple -from six import iteritems, iterkeys, itervalues +from six import iteritems, itervalues from frozendict import frozendict from twisted.internet import defer -from synapse import event_auth -from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError +from synapse.api.constants import EventTypes, RoomVersions from synapse.events.snapshot import EventContext +from synapse.state import v1 from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -264,6 +262,7 @@ class StateHandler(object): defer.returnValue(context) logger.debug("calling resolve_state_groups from compute_event_context") + entry = yield self.resolve_state_groups_for_events( event.room_id, [e for e, _ in event.prev_events], ) @@ -338,8 +337,11 @@ class StateHandler(object): event, resolves conflicts between them and returns them. Args: - room_id (str): - event_ids (list[str]): + room_id (str) + event_ids (list[str]) + explicit_room_version (str|None): If set uses the the given room + version to choose the resolution algorithm. If None, then + checks the database for room version. Returns: Deferred[_StateCacheEntry]: resolved state @@ -353,7 +355,12 @@ class StateHandler(object): room_id, event_ids ) - if len(state_groups_ids) == 1: + if len(state_groups_ids) == 0: + defer.returnValue(_StateCacheEntry( + state={}, + state_group=None, + )) + elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() prev_group, delta_ids = yield self.store.get_state_group_delta(name) @@ -365,8 +372,11 @@ class StateHandler(object): delta_ids=delta_ids, )) + room_version = yield self.store.get_room_version(room_id) + result = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups_ids, None, self._state_map_factory, + room_id, room_version, state_groups_ids, None, + self._state_map_factory, ) defer.returnValue(result) @@ -375,7 +385,7 @@ class StateHandler(object): ev_ids, get_prev_content=False, check_redacted=False, ) - def resolve_events(self, state_sets, event): + def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -391,7 +401,9 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map(state_set_ids, state_map) + new_state = resolve_events_with_state_map( + room_version, state_set_ids, state_map, + ) new_state = { key: state_map[ev_id] for key, ev_id in iteritems(new_state) @@ -430,7 +442,7 @@ class StateResolutionHandler(object): @defer.inlineCallbacks @log_function def resolve_state_groups( - self, room_id, state_groups_ids, event_map, state_map_factory, + self, room_id, room_version, state_groups_ids, event_map, state_map_factory, ): """Resolves conflicts between a set of state groups @@ -439,6 +451,7 @@ class StateResolutionHandler(object): Args: room_id (str): room we are resolving for (used for logging) + room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group (where 'state' is a map from state key to event id) @@ -492,6 +505,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_factory( + room_version, list(itervalues(state_groups_ids)), event_map=event_map, state_map_factory=state_map_factory, @@ -575,16 +589,10 @@ def _make_state_cache_entry( ) -def _ordered_events(events): - def key_func(e): - return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() - - return sorted(events, key=key_func) - - -def resolve_events_with_state_map(state_sets, state_map): +def resolve_events_with_state_map(room_version, state_sets, state_map): """ Args: + room_version(str): Version of the room state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. state_map(dict): a dict from event_id to event, for all events in @@ -594,75 +602,23 @@ def resolve_events_with_state_map(state_sets, state_map): dict[(str, str), str]: a map from (type, state_key) to event_id. """ - if len(state_sets) == 1: - return state_sets[0] - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - return _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - ) - - -def _seperate(state_sets): - """Takes the state_sets and figures out which keys are conflicted and - which aren't. i.e., which have multiple different event_ids associated - with them in different state sets. - - Args: - state_sets(iterable[dict[(str, str), str]]): - List of dicts of (type, state_key) -> event_id, which are the - different state groups to resolve. - - Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: - - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. - - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. - """ - state_set_iterator = iter(state_sets) - unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} - - for state_set in state_set_iterator: - for key, value in iteritems(state_set): - # Check if there is an unconflicted entry for the state key. - unconflicted_value = unconflicted_state.get(key) - if unconflicted_value is None: - # There isn't an unconflicted entry so check if there is a - # conflicted entry. - ls = conflicted_state.get(key) - if ls is None: - # There wasn't a conflicted entry so haven't seen this key before. - # Therefore it isn't conflicted yet. - unconflicted_state[key] = value - else: - # This key is already conflicted, add our value to the conflict set. - ls.add(value) - elif unconflicted_value != value: - # If the unconflicted value is not the same as our value then we - # have a new conflict. So move the key from the unconflicted_state - # to the conflicted state. - conflicted_state[key] = {value, unconflicted_value} - unconflicted_state.pop(key, None) - - return unconflicted_state, conflicted_state + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_state_map( + state_sets, state_map, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) + ) -@defer.inlineCallbacks -def resolve_events_with_factory(state_sets, event_map, state_map_factory): +def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory): """ Args: + room_version(str): Version of the room + state_sets(list): List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. @@ -682,185 +638,13 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ - if len(state_sets) == 1: - defer.returnValue(state_sets[0]) - - unconflicted_state, conflicted_state = _seperate( - state_sets, - ) - - needed_events = set( - event_id - for event_ids in itervalues(conflicted_state) - for event_id in event_ids - ) - if event_map is not None: - needed_events -= set(iterkeys(event_map)) - - logger.info("Asking for %d conflicted events", len(needed_events)) - - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) - if event_map is not None: - state_map.update(event_map) - - # get the ids of the auth events which allow us to authenticate the - # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id - auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map - ) - - new_needed_events = set(itervalues(auth_events)) - new_needed_events -= needed_events - if event_map is not None: - new_needed_events -= set(iterkeys(event_map)) - - logger.info("Asking for %d auth events", len(new_needed_events)) - - state_map_new = yield state_map_factory(new_needed_events) - state_map.update(state_map_new) - - defer.returnValue(_resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map - )) - - -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): - auth_events = {} - for event_ids in itervalues(conflicted_state): - for event_id in event_ids: - if event_id in state_map: - keys = event_auth.auth_types_for_event(state_map[event_id]) - for key in keys: - if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id - return auth_events - - -def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, - state_map): - conflicted_state = {} - for key, event_ids in iteritems(conflicted_state_ids): - events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] - if len(events) > 1: - conflicted_state[key] = events - elif len(events) == 1: - unconflicted_state_ids[key] = events[0].event_id - - auth_events = { - key: state_map[ev_id] - for key, ev_id in iteritems(auth_event_ids) - if ev_id in state_map - } - - try: - resolved_state = _resolve_state_events( - conflicted_state, auth_events + if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,): + return v1.resolve_events_with_factory( + state_sets, event_map, state_map_factory, + ) + else: + # This should only happen if we added a version but forgot to add it to + # the list above. + raise Exception( + "No state resolution algorithm defined for version %r" % (room_version,) ) - except Exception: - logger.exception("Failed to resolve state") - raise - - new_state = unconflicted_state_ids - for key, event in iteritems(resolved_state): - new_state[key] = event.event_id - - return new_state - - -def _resolve_state_events(conflicted_state, auth_events): - """ This is where we actually decide which of the conflicted state to - use. - - We resolve conflicts in the following order: - 1. power levels - 2. join rules - 3. memberships - 4. other events. - """ - resolved_state = {} - if POWER_KEY in conflicted_state: - events = conflicted_state[POWER_KEY] - logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events( - events, auth_events) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key[0] == EventTypes.JoinRules: - logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key[0] == EventTypes.Member: - logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events( - events, - auth_events - ) - - auth_events.update(resolved_state) - - for key, events in iteritems(conflicted_state): - if key not in resolved_state: - logger.debug("Resolving conflicted state %r:%r", key, events) - resolved_state[key] = _resolve_normal_events( - events, auth_events - ) - - return resolved_state - - -def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] - - auth_keys = set( - key - for event in events - for key in event_auth.auth_types_for_event(event) - ) - - new_auth_events = {} - for key in auth_keys: - auth_event = auth_events.get(key, None) - if auth_event: - new_auth_events[key] = auth_event - - auth_events = new_auth_events - - prev_event = reverse[0] - for event in reverse[1:]: - auth_events[(prev_event.type, prev_event.state_key)] = prev_event - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) - prev_event = event - except AuthError: - return prev_event - - return event - - -def _resolve_normal_events(events, auth_events): - for event in _ordered_events(events): - try: - # The signatures have already been checked at this point - event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) - return event - except AuthError: - pass - - # Use the last event (the one with the least depth) if they all fail - # the auth check. - return event diff --git a/synapse/state/v1.py b/synapse/state/v1.py new file mode 100644 index 0000000000..3a1f7054a1 --- /dev/null +++ b/synapse/state/v1.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import logging + +from six import iteritems, iterkeys, itervalues + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError + +logger = logging.getLogger(__name__) + + +POWER_KEY = (EventTypes.PowerLevels, "") + + +def resolve_events_with_state_map(state_sets, state_map): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + state_map(dict): a dict from event_id to event, for all events in + state_sets. + + Returns + dict[(str, str), str]: + a map from (type, state_key) to event_id. + """ + if len(state_sets) == 1: + return state_sets[0] + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + return _resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + ) + + +@defer.inlineCallbacks +def resolve_events_with_factory(state_sets, event_map, state_map_factory): + """ + Args: + state_sets(list): List of dicts of (type, state_key) -> event_id, + which are the different state groups to resolve. + + event_map(dict[str,FrozenEvent]|None): + a dict from event_id to event, for any events that we happen to + have in flight (eg, those currently being persisted). This will be + used as a starting point fof finding the state we need; any missing + events will be requested via state_map_factory. + + If None, all events will be fetched via state_map_factory. + + state_map_factory(func): will be called + with a list of event_ids that are needed, and should return with + a Deferred of dict of event_id to event. + + Returns + Deferred[dict[(str, str), str]]: + a map from (type, state_key) to event_id. + """ + if len(state_sets) == 1: + defer.returnValue(state_sets[0]) + + unconflicted_state, conflicted_state = _seperate( + state_sets, + ) + + needed_events = set( + event_id + for event_ids in itervalues(conflicted_state) + for event_id in event_ids + ) + if event_map is not None: + needed_events -= set(iterkeys(event_map)) + + logger.info("Asking for %d conflicted events", len(needed_events)) + + # dict[str, FrozenEvent]: a map from state event id to event. Only includes + # the state events which are in conflict (and those in event_map) + state_map = yield state_map_factory(needed_events) + if event_map is not None: + state_map.update(event_map) + + # get the ids of the auth events which allow us to authenticate the + # conflicted state, picking only from the unconflicting state. + # + # dict[(str, str), str]: a map from state key to event id + auth_events = _create_auth_events_from_maps( + unconflicted_state, conflicted_state, state_map + ) + + new_needed_events = set(itervalues(auth_events)) + new_needed_events -= needed_events + if event_map is not None: + new_needed_events -= set(iterkeys(event_map)) + + logger.info("Asking for %d auth events", len(new_needed_events)) + + state_map_new = yield state_map_factory(new_needed_events) + state_map.update(state_map_new) + + defer.returnValue(_resolve_with_state( + unconflicted_state, conflicted_state, auth_events, state_map + )) + + +def _seperate(state_sets): + """Takes the state_sets and figures out which keys are conflicted and + which aren't. i.e., which have multiple different event_ids associated + with them in different state sets. + + Args: + state_sets(iterable[dict[(str, str), str]]): + List of dicts of (type, state_key) -> event_id, which are the + different state groups to resolve. + + Returns: + (dict[(str, str), str], dict[(str, str), set[str]]): + A tuple of (unconflicted_state, conflicted_state), where: + + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. + + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. + """ + state_set_iterator = iter(state_sets) + unconflicted_state = dict(next(state_set_iterator)) + conflicted_state = {} + + for state_set in state_set_iterator: + for key, value in iteritems(state_set): + # Check if there is an unconflicted entry for the state key. + unconflicted_value = unconflicted_state.get(key) + if unconflicted_value is None: + # There isn't an unconflicted entry so check if there is a + # conflicted entry. + ls = conflicted_state.get(key) + if ls is None: + # There wasn't a conflicted entry so haven't seen this key before. + # Therefore it isn't conflicted yet. + unconflicted_state[key] = value + else: + # This key is already conflicted, add our value to the conflict set. + ls.add(value) + elif unconflicted_value != value: + # If the unconflicted value is not the same as our value then we + # have a new conflict. So move the key from the unconflicted_state + # to the conflicted state. + conflicted_state[key] = {value, unconflicted_value} + unconflicted_state.pop(key, None) + + return unconflicted_state, conflicted_state + + +def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): + auth_events = {} + for event_ids in itervalues(conflicted_state): + for event_id in event_ids: + if event_id in state_map: + keys = event_auth.auth_types_for_event(state_map[event_id]) + for key in keys: + if key not in auth_events: + event_id = unconflicted_state.get(key, None) + if event_id: + auth_events[key] = event_id + return auth_events + + +def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids, + state_map): + conflicted_state = {} + for key, event_ids in iteritems(conflicted_state_ids): + events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] + if len(events) > 1: + conflicted_state[key] = events + elif len(events) == 1: + unconflicted_state_ids[key] = events[0].event_id + + auth_events = { + key: state_map[ev_id] + for key, ev_id in iteritems(auth_event_ids) + if ev_id in state_map + } + + try: + resolved_state = _resolve_state_events( + conflicted_state, auth_events + ) + except Exception: + logger.exception("Failed to resolve state") + raise + + new_state = unconflicted_state_ids + for key, event in iteritems(resolved_state): + new_state[key] = event.event_id + + return new_state + + +def _resolve_state_events(conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. + + We resolve conflicts in the following order: + 1. power levels + 2. join rules + 3. memberships + 4. other events. + """ + resolved_state = {} + if POWER_KEY in conflicted_state: + events = conflicted_state[POWER_KEY] + logger.debug("Resolving conflicted power levels %r", events) + resolved_state[POWER_KEY] = _resolve_auth_events( + events, auth_events) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key[0] == EventTypes.JoinRules: + logger.debug("Resolving conflicted join rules %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key[0] == EventTypes.Member: + logger.debug("Resolving conflicted member lists %r", events) + resolved_state[key] = _resolve_auth_events( + events, + auth_events + ) + + auth_events.update(resolved_state) + + for key, events in iteritems(conflicted_state): + if key not in resolved_state: + logger.debug("Resolving conflicted state %r:%r", key, events) + resolved_state[key] = _resolve_normal_events( + events, auth_events + ) + + return resolved_state + + +def _resolve_auth_events(events, auth_events): + reverse = [i for i in reversed(_ordered_events(events))] + + auth_keys = set( + key + for event in events + for key in event_auth.auth_types_for_event(event) + ) + + new_auth_events = {} + for key in auth_keys: + auth_event = auth_events.get(key, None) + if auth_event: + new_auth_events[key] = auth_event + + auth_events = new_auth_events + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + prev_event = event + except AuthError: + return prev_event + + return event + + +def _resolve_normal_events(events, auth_events): + for event in _ordered_events(events): + try: + # The signatures have already been checked at this point + event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event + + +def _ordered_events(events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest() + + return sorted(events, key=key_func) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 025a7fb6d9..f39c8c8461 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore } events_map = {ev.event_id: ev for ev, _ in events_context} + room_version = yield self.get_room_version(room_id) + logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, get_events ) defer.returnValue((res.state, None)) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 60295da254..88b50f33b5 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -71,8 +71,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - -class ProfileStore(ProfileWorkerStore): def create_profile(self, user_localpart): return self._simple_insert( table="profiles", @@ -96,6 +94,8 @@ class ProfileStore(ProfileWorkerStore): desc="set_profile_avatar_url", ) + +class ProfileStore(ProfileWorkerStore): def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3378fc77d1..61013b8919 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) + class RoomStore(RoomWorkerStore, SearchStore): @@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore): "get_all_new_public_rooms", get_all_new_public_rooms ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) - else: - defer.returnValue(None) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd03c4168b..4b971efdba 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, db_conn, hs): super(StateGroupWorkerStore, self).__init__(db_conn, hs) + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache") + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache") ) @defer.inlineCallbacks @@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types): + def _get_state_groups_from_groups(self, groups, types, members=None): """Returns the state groups for a given set of groups, filtering on types of state events. @@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types (Iterable[str, str|None]|None): list of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all state_keys for the `type`. If None, all types are returned. + members (bool|None): If not None, then, in addition to any filtering + implied by types, the results are also filtered to only include + member events (if True), or to exclude member events (if False) Returns: dictionary state_group -> (dict of (type, state_key) -> event id) @@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, + self._get_state_groups_from_groups_txn, chunk, types, members, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, + self, txn, groups, types=None, members=None, ): results = {group: {} for group in groups} @@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): %s """) + if members is True: + sql += " AND type = '%s'" % (EventTypes.Member,) + elif members is False: + sql += " AND type <> '%s'" % (EventTypes.Member,) + # Turns out that postgres doesn't like doing a list of OR's and # is about 1000x slower, so we just issue a query for each specific # type seperately. @@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: where_clause = "" + if members is True: + where_clause += " AND type = '%s'" % EventTypes.Member + elif members is False: + where_clause += " AND type <> '%s'" % EventTypes.Member + # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) for group in groups: @@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types, filtered_types=None): + def _get_some_state_from_cache(self, cache, group, types, filtered_types=None): """Checks if group is in cache. See `_get_state_for_groups` Args: + cache(DictionaryCache): the state group cache to use group(int): The state group to lookup types(list[str, str|None]): List of 2-tuples of the form (`type`, `state_key`), where a `state_key` of `None` matches all @@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = cache.get(group) type_to_key = {} - # tracks whether any of ourrequested types are missing from the cache + # tracks whether any of our requested types are missing from the cache missing_types = False for typ, state_key in types: @@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if include(k[0], k[1]) }, got_all - def _get_all_state_from_cache(self, group): + def _get_all_state_from_cache(self, cache, group): """Checks if group is in cache. See `_get_state_for_groups` Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool @@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cache, if False we need to query the DB for the missing state. Args: + cache(DictionaryCache): the state group cache to use group: The state group to lookup """ - is_all, _, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = cache.get(group) return state_dict_ids, is_all @@ -685,6 +735,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Deferred[dict[int, dict[(type, state_key), EventBase]]] a dictionary mapping from state group to state dictionary. """ + if types is not None: + non_member_types = [t for t in types if t[0] != EventTypes.Member] + + if filtered_types is not None and EventTypes.Member not in filtered_types: + # we want all of the membership events + member_types = None + else: + member_types = [t for t in types if t[0] == EventTypes.Member] + + else: + non_member_types = None + member_types = None + + non_member_state = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, non_member_types, filtered_types, + ) + # XXX: we could skip this entirely if member_types is [] + member_state = yield self._get_state_for_groups_using_cache( + # we set filtered_types=None as member_state only ever contain members. + groups, self._state_group_members_cache, member_types, None, + ) + + state = non_member_state + for group in groups: + state[group].update(member_state[group]) + + defer.returnValue(state) + + @defer.inlineCallbacks + def _get_state_for_groups_using_cache( + self, groups, cache, types=None, filtered_types=None + ): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + types (None|iterable[(str, None|str)]): + indicates the state type/keys required. If None, the whole + state is fetched and returned. + + Otherwise, each entry should be a `(type, state_key)` tuple to + include in the response. A `state_key` of None is a wildcard + meaning that we require all state with that type. + filtered_types(list[str]|None): Only apply filtering via `types` to this + list of event types. Other types of events are returned unfiltered. + If None, `types` filtering is applied to all events. + + Returns: + Deferred[dict[int, dict[(type, state_key), EventBase]]] + a dictionary mapping from state group to state dictionary. + """ if types: types = frozenset(types) results = {} @@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if types is not None: for group in set(groups): state_dict_ids, got_all = self._get_some_state_from_cache( - group, types, filtered_types + cache, group, types, filtered_types ) results[group] = state_dict_ids @@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): else: for group in set(groups): state_dict_ids, got_all = self._get_all_state_from_cache( - group + cache, group ) results[group] = state_dict_ids @@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): missing_groups.append(group) if missing_groups: - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + # Okay, so we have some missing_types, let's fetch them. + cache_seq_num = cache.sequence # the DictionaryCache knows if it has *all* the state, but # does not know if it has all of the keys of a particular type, @@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): types_to_fetch = types group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch + missing_groups, types_to_fetch, cache == self._state_group_members_cache, ) for group, group_state_dict in iteritems(group_to_state_dict): @@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # update the cache with all the things we fetched from the # database. - self._state_group_cache.update( + cache.update( cache_seq_num, key=group, value=group_state_dict, @@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ], ) - # Prefill the state group cache with this group. + # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map # is immutable. (If the map wasn't immutable then this prefill could # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_state_ids), + value=dict(current_non_member_state_ids), ) return state_group |