diff options
Diffstat (limited to '')
60 files changed, 1836 insertions, 673 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 49673ccce4..c40a32abd6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,40 @@ +Changes in synapse v0.17.2 (2016-09-08) +======================================= + +This release contains security bug fixes. Please upgrade. + + +No changes since v0.17.2 + + +Changes in synapse v0.17.2-rc1 (2016-09-05) +=========================================== + +Features: + +* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050, + #1062, #1066) + + +Changes: + +* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063, + #1068) +* Don't notify for online to online presence transitions. (PR #1054) +* Occasionally persist unpersisted presence updates (PR #1055) +* Allow application services to have an optional 'url' (PR #1056) +* Clean up old sent transactions from DB (PR #1059) + + +Bug fixes: + +* Fix None check in backfill (PR #1043) +* Fix membership changes to be idempotent (PR #1067) +* Fix bug in get_pdu where it would sometimes return events with incorrect + signature + + + Changes in synapse v0.17.1 (2016-08-24) ======================================= diff --git a/README.rst b/README.rst index 172dd4dfa0..f1ccc8dc45 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ Installing prerequisites on Raspbian:: sudo pip install --upgrade ndg-httpsclient sudo pip install --upgrade virtualenv +Installing prerequisites on openSUSE:: + + sudo zypper in -t pattern devel_basis + sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \ + python-devel libffi-devel libopenssl-devel libjpeg62-devel + To install the synapse homeserver run:: virtualenv -p python2.7 ~/.synapse @@ -199,6 +205,21 @@ run (e.g. ``~/.synapse``), and:: source ./bin/activate synctl start +Security Note +============= + +Matrix serves raw user generated data in some APIs - specifically the content +repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid +Whilst we have tried to mitigate against possible XSS attacks (e.g. +https://github.com/matrix-org/synapse/pull/1021) we recommend running +matrix homeservers on a dedicated domain name, to limit any malicious user generated +content served to web browsers a matrix API from being able to attack webapps hosted +on the same domain. This is particularly true of sharing a matrix webclient and +server on the same domain. + +See https://github.com/vector-im/vector-web/issues/1977 and +https://developer.github.com/changes/2014-04-25-user-content-security for more details. + Using PostgreSQL ================ @@ -215,9 +236,6 @@ The advantages of Postgres include: pointing at the same DB master, as well as enabling DB replication in synapse itself. -The only disadvantage is that the code is relatively new as of April 2015 and -may have a few regressions relative to SQLite. - For information on how to install and use PostgreSQL, please see `docs/postgres.rst <docs/postgres.rst>`_. diff --git a/synapse/__init__.py b/synapse/__init__.py index 43bf78f885..523deaa5ff 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.17.1" +__version__ = "0.17.2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0db26fcfd7..dcda40863f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -52,7 +52,7 @@ class Auth(object): self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 # Docs for these currently lives at - # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst + # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # In addition, we have type == delete_pusher which grants access only to # delete pushers. self._KNOWN_CAVEAT_PREFIXES = set([ @@ -63,6 +63,17 @@ class Auth(object): "user_id = ", ]) + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=False) + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. @@ -267,21 +278,17 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): - curr_state = yield self.state.get_current_state(room_id) - - for event in curr_state.values(): - if event.type == EventTypes.Member: - try: - if get_domain_from_id(event.state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", event.state_key) - continue + with Measure(self.clock, "check_host_in_room"): + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - if event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + entry = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - defer.returnValue(False) + ret = yield self.store.is_host_joined( + room_id, host, entry.state_group, entry.state + ) + defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): key = (EventTypes.Member, event.user_id, ) @@ -847,7 +854,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = self.compute_auth_events(builder, context.current_state) + auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -855,30 +862,32 @@ class Auth(object): builder.auth_events = auth_events_entries - def compute_auth_events(self, event, current_state): + @defer.inlineCallbacks + def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - return [] + defer.returnValue([]) auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = current_state.get(key) + power_level_event_id = current_state_ids.get(key) - if power_level_event: - auth_ids.append(power_level_event.event_id) + if power_level_event_id: + auth_ids.append(power_level_event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = current_state.get(key) + join_rule_event_id = current_state_ids.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = current_state.get(key) + member_event_id = current_state_ids.get(key) key = (EventTypes.Create, "", ) - create_event = current_state.get(key) - if create_event: - auth_ids.append(create_event.event_id) + create_event_id = current_state_ids.get(key) + if create_event_id: + auth_ids.append(create_event_id) - if join_rule_event: + if join_rule_event_id: + join_rule_event = yield self.store.get_event(join_rule_event_id) join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False else: @@ -887,15 +896,21 @@ class Auth(object): if event.type == EventTypes.Member: e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event: - auth_ids.append(join_rule_event.event_id) + if join_rule_event_id: + auth_ids.append(join_rule_event_id) if e_type == Membership.JOIN: - if member_event and not is_public: - auth_ids.append(member_event.event_id) + if member_event_id and not is_public: + auth_ids.append(member_event_id) else: - if member_event: - auth_ids.append(member_event.event_id) + if member_event_id: + auth_ids.append(member_event_id) + + if for_verification: + key = (EventTypes.Member, event.state_key, ) + existing_event_id = current_state_ids.get(key) + if existing_event_id: + auth_ids.append(existing_event_id) if e_type == Membership.INVITE: if "third_party_invite" in event.content: @@ -903,14 +918,15 @@ class Auth(object): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - third_party_invite = current_state.get(key) - if third_party_invite: - auth_ids.append(third_party_invite.event_id) - elif member_event: + third_party_invite_id = current_state_ids.get(key) + if third_party_invite_id: + auth_ids.append(third_party_invite_id) + elif member_event_id: + member_event = yield self.store.get_event(member_event_id) if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - return auth_ids + defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8cf4d6169c..a8123cddcb 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -85,3 +85,8 @@ class RoomCreationPreset(object): PRIVATE_CHAT = "private_chat" PUBLIC_CHAT = "public_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + + +class ThirdPartyEntityKind(object): + USER = "user" + LOCATION = "location" diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 0fd9b7f244..91a33a3402 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" -APP_SERVICE_PREFIX = "/_matrix/appservice/v1" diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index e3173533e2..07d3d047c6 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -36,6 +36,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.filtering import SlavedFilteringStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.presence import SlavedPresenceStore +from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.server import HomeServer from synapse.storage.client_ips import ClientIpStore from synapse.storage.engines import create_engine @@ -72,6 +73,7 @@ class SynchrotronSlavedStore( SlavedRegistrationStore, SlavedFilteringStore, SlavedPresenceStore, + SlavedDeviceInboxStore, BaseSlavedStore, ClientIpStore, # After BaseSlavedStore because the constructor is different ): @@ -397,6 +399,9 @@ class SynchrotronServer(HomeServer): notify_from_stream( result, "typing", "typing_key", room="room_id" ) + notify_from_stream( + result, "to_device", "to_device_key", user="user_id" + ) while True: try: diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index bde9b51b2e..126a10efb7 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -88,6 +88,8 @@ class ApplicationService(object): self.sender = sender self.namespaces = self._check_namespaces(namespaces) self.id = id + + # .protocols is a publicly visible field if protocols: self.protocols = set(protocols) else: diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 066127b666..cc4af23962 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,10 +14,11 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.constants import ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event -from synapse.types import ThirdPartyEntityKind +from synapse.util.caches.response_cache import ResponseCache import logging import urllib @@ -25,6 +26,12 @@ import urllib logger = logging.getLogger(__name__) +HOUR_IN_MS = 60 * 60 * 1000 + + +APP_SERVICE_PREFIX = "/_matrix/app/unstable" + + def _is_valid_3pe_result(r, field): if not isinstance(r, dict): return False @@ -56,8 +63,12 @@ class ApplicationServiceApi(SimpleHttpClient): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() + self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS) + @defer.inlineCallbacks def query_user(self, service, user_id): + if service.url is None: + defer.returnValue(False) uri = service.url + ("/users/%s" % urllib.quote(user_id)) response = None try: @@ -77,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def query_alias(self, service, alias): + if service.url is None: + defer.returnValue(False) uri = service.url + ("/rooms/%s" % urllib.quote(alias)) response = None try: @@ -97,16 +110,22 @@ class ApplicationServiceApi(SimpleHttpClient): @defer.inlineCallbacks def query_3pe(self, service, kind, protocol, fields): if kind == ThirdPartyEntityKind.USER: - uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: - uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) required_field = "alias" else: raise ValueError( "Unrecognised 'kind' argument %r to query_3pe()", kind ) + if service.url is None: + defer.returnValue([]) + uri = "%s%s/thirdparty/%s/%s" % ( + service.url, + APP_SERVICE_PREFIX, + kind, + urllib.quote(protocol) + ) try: response = yield self.get_json(uri, fields) if not isinstance(response, list): @@ -131,8 +150,34 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_3pe to %s threw exception %s", uri, ex) defer.returnValue([]) + def get_3pe_protocol(self, service, protocol): + if service.url is None: + defer.returnValue({}) + + @defer.inlineCallbacks + def _get(): + uri = "%s%s/thirdparty/protocol/%s" % ( + service.url, + APP_SERVICE_PREFIX, + urllib.quote(protocol) + ) + try: + defer.returnValue((yield self.get_json(uri, {}))) + except Exception as ex: + logger.warning("query_3pe_protocol to %s threw exception %s", + uri, ex) + defer.returnValue({}) + + key = (service.id, protocol) + return self.protocol_meta_cache.get(key) or ( + self.protocol_meta_cache.set(key, _get()) + ) + @defer.inlineCallbacks def push_bulk(self, service, events, txn_id=None): + if service.url is None: + defer.returnValue(True) + events = self._serialize(events) if txn_id is None: diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index dfe43b0b4c..d7537e8d44 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -86,7 +86,7 @@ def load_appservices(hostname, config_files): def _load_appservice(hostname, as_info, config_filename): required_string_fields = [ - "id", "url", "as_token", "hs_token", "sender_localpart" + "id", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: if not isinstance(as_info.get(field), basestring): @@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename): field, config_filename, )) + # 'url' must either be a string or explicitly null, not missing + # to avoid accidentally turning off push for ASes. + if (not isinstance(as_info.get("url"), basestring) and + as_info.get("url", "") is not None): + raise KeyError( + "Required string field or explicit null: 'url' (%s)" % (config_filename,) + ) + localpart = as_info["sender_localpart"] if urllib.quote(localpart) != localpart: raise ValueError( @@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename): for p in protocols: if not isinstance(p, str): raise KeyError("Bad value for 'protocols' item") + + if as_info["url"] is None: + logger.info( + "(%s) Explicitly empty 'url' provided. This application service" + " will not receive events or queries.", + config_filename, + ) return ApplicationService( token=as_info["as_token"], url=as_info["url"], diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 13154b1723..bcb8f33a58 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -99,7 +99,7 @@ class EventBase(object): return d - def get(self, key, default): + def get(self, key, default=None): return self._event_dict.get(key, default) def get_internal_metadata_dict(self): diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 8a475417a6..e895b1c450 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,9 +15,9 @@ class EventContext(object): - - def __init__(self, current_state=None): - self.current_state = current_state + def __init__(self): + self.current_state_ids = None + self.prev_state_ids = None self.state_group = None self.rejected = False self.push_actions = [] diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f2b3aceb49..627acc6a4f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.events import FrozenEvent +from synapse.types import get_domain_from_id import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -63,6 +64,7 @@ class FederationClient(FederationBase): self._clock.looping_call( self._clear_tried_cache, 60 * 1000, ) + self.state = hs.get_state_handler() def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -267,7 +269,7 @@ class FederationClient(FederationBase): pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) - pdu = None + signed_pdu = None for destination in destinations: now = self._clock.time_msec() last_attempt = pdu_attempts.get(destination, 0) @@ -297,7 +299,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - pdu = yield self._check_sigs_and_hashes([pdu])[0] + signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] break @@ -320,10 +322,10 @@ class FederationClient(FederationBase): ) continue - if self._get_pdu_cache is not None and pdu: - self._get_pdu_cache[event_id] = pdu + if self._get_pdu_cache is not None and signed_pdu: + self._get_pdu_cache[event_id] = signed_pdu - defer.returnValue(pdu) + defer.returnValue(signed_pdu) @defer.inlineCallbacks @log_function @@ -811,7 +813,8 @@ class FederationClient(FederationBase): if len(signed_events) >= limit: defer.returnValue(signed_events) - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) servers = set(servers) servers.discard(self.server_name) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index aba19639c7..5621655098 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -223,16 +223,14 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") - pdus = yield self.handler.get_state_for_pdu( + state_ids = yield self.handler.get_state_ids_for_pdu( room_id, event_id, ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) defer.returnValue((200, { - "pdu_ids": [pdu.event_id for pdu in pdus], - "auth_chain_ids": [pdu.event_id for pdu in auth_chain], + "pdu_ids": state_ids, + "auth_chain_ids": auth_chain_ids, })) @defer.inlineCallbacks diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 11081a0cd5..e58735294e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -65,33 +65,21 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)), ) - def is_host_in_room(self, current_state): - room_members = [ - (state_key, event.membership) - for ((event_type, state_key), event) in current_state.items() - if event_type == EventTypes.Member - ] - if len(room_members) == 0: - # Have we just created the room, and is this about to be the very - # first member event? - create_event = current_state.get(("m.room.create", "")) - if create_event: - return True - for (state_key, membership) in room_members: - if ( - self.hs.is_mine_id(state_key) - and membership == Membership.JOIN - ): - return True - return False - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, current_state): + def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": + if context: + current_state = yield self.store.get_events( + context.current_state_ids.values() + ) + current_state = current_state.values() + else: + current_state = yield self.store.get_current_state(event.room_id) + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) @defer.inlineCallbacks diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 306686a384..b440280b74 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -176,6 +176,16 @@ class ApplicationServicesHandler(object): defer.returnValue(ret) @defer.inlineCallbacks + def get_3pe_protocols(self): + services = yield self.store.get_app_services() + protocols = {} + for s in services: + for p in s.protocols: + protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) + + defer.returnValue(protocols) + + @defer.inlineCallbacks def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4bea7f2b19..14352985e2 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,7 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, get_domain_from_id import logging import string @@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) if not servers: raise SynapseError(400, "Failed to get server list") @@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND ) - extra_servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + extra_servers = set(get_domain_from_id(u) for u in users) servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 3a3a1257d3..d3685fb12a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.store.get_users_in_room(event.room_id) + users = yield self.state.get_current_user_in_room(event.room_id) states = yield presence_handler.get_states( users, as_event=True, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 01a761715b..dc90a5dde4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError from synapse.util.logcontext import ( PreserveLoggingContext, preserve_fn, preserve_context_over_deferred ) +from synapse.util.metrics import measure_func from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -100,6 +101,9 @@ class FederationHandler(BaseHandler): def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + auth_chain and state are None if we already have the necessary state + and prev_events in the db """ event = pdu @@ -117,12 +121,21 @@ class FederationHandler(BaseHandler): # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work - is_in_room = yield self.auth.check_host_in_room( - event.room_id, - self.server_name - ) - if not is_in_room and not event.internal_metadata.is_outlier(): - logger.debug("Got event for room we're not in.") + # If state and auth_chain are None, then we don't need to do this check + # as we already know we have enough state in the DB to handle this + # event. + if state and auth_chain and not event.internal_metadata.is_outlier(): + is_in_room = yield self.auth.check_host_in_room( + event.room_id, + self.server_name + ) + else: + is_in_room = True + if not is_in_room: + logger.info( + "Got event for room we're not in: %r %r", + event.room_id, event.event_id + ) try: event_stream_id, max_stream_id = yield self._persist_auth_tree( @@ -217,17 +230,28 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + newly_joined = True + prev_state_id = context.prev_state_ids.get( + (event.type, event.state_key) + ) + if prev_state_id: + prev_state = yield self.store.get_event( + prev_state_id, allow_none=True, + ) + if prev_state and prev_state.membership == Membership.JOIN: + newly_joined = False + + if newly_joined: user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) + @measure_func("_filter_events_for_server") @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): - event_to_state = yield self.store.get_state_for_events( + event_to_state_ids = yield self.store.get_state_ids_for_events( frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -235,6 +259,30 @@ class FederationHandler(BaseHandler): ) ) + # We only want to pull out member events that correspond to the + # server's domain. + + def check_match(id): + try: + return server_name == get_domain_from_id(id) + except: + return False + + event_map = yield self.store.get_events([ + e_id for key_to_eid in event_to_state_ids.values() + for key, e_id in key_to_eid + if key[0] != EventTypes.Member or check_match(key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in key_to_eid.items() + if inner_e_id in event_map + } + for e_id, key_to_eid in event_to_state_ids.items() + } + def redact_disallowed(event, state): if not state: return event @@ -377,7 +425,9 @@ class FederationHandler(BaseHandler): )).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results if a}) required_auth.update( - a_id for event in results for a_id, _ in event.auth_events if event + a_id + for event in results if event + for a_id, _ in event.auth_events ) missing_auth = required_auth - set(auth_events) @@ -560,6 +610,18 @@ class FederationHandler(BaseHandler): ])) states = dict(zip(event_ids, [s[1] for s in states])) + state_map = yield self.store.get_events( + [e_id for ids in states.values() for e_id in ids], + get_prev_content=False + ) + states = { + key: { + k: state_map[e_id] + for k, e_id in state_dict.items() + if e_id in state_map + } for key, state_dict in states.items() + } + for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) @@ -722,7 +784,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - self.auth.check(event, auth_events=context.current_state, do_sig_check=False) + yield self.auth.check_from_context(event, context, do_sig_check=False) defer.returnValue(event) @@ -770,18 +832,11 @@ class FederationHandler(BaseHandler): new_pdu = event - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(s.state_key)) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - + message_handler = self.hs.get_handlers().message_handler + destinations = yield message_handler.get_joined_hosts_for_room_from_state( + context + ) + destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -792,13 +847,15 @@ class FederationHandler(BaseHandler): self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = [e.event_id for e in context.current_state.values()] + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) + state = yield self.store.get_events(context.prev_state_ids.values()) + defer.returnValue({ - "state": context.current_state.values(), + "state": state.values(), "auth_chain": auth_chain, }) @@ -954,7 +1011,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - self.auth.check(event, auth_events=context.current_state, do_sig_check=False) + yield self.auth.check_from_context(event, context, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e @@ -998,18 +1055,11 @@ class FederationHandler(BaseHandler): new_pdu = event - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.LEAVE: - destinations.add(get_domain_from_id(s.state_key)) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - + message_handler = self.hs.get_handlers().message_handler + destinations = yield message_handler.get_joined_hosts_for_room_from_state( + context + ) + destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -1024,6 +1074,8 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def get_state_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ yield run_on_reactor() state_groups = yield self.store.get_state_groups( @@ -1065,6 +1117,34 @@ class FederationHandler(BaseHandler): defer.returnValue([]) @defer.inlineCallbacks + def get_state_ids_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ + yield run_on_reactor() + + state_groups = yield self.store.get_state_groups_ids( + room_id, [event_id] + ) + + if state_groups: + _, state = state_groups.items().pop() + results = state + + event = yield self.store.get_event(event_id) + if event and event.is_state(): + # Get previous state + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + if prev_id != event.event_id: + results[(event.type, event.state_key)] = prev_id + else: + del results[(event.type, event.state_key)] + + defer.returnValue(results.values()) + else: + defer.returnValue([]) + + @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, pdu_list, limit): in_room = yield self.auth.check_host_in_room(room_id, origin) @@ -1294,7 +1374,13 @@ class FederationHandler(BaseHandler): ) if not auth_events: - auth_events = context.current_state + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -1320,8 +1406,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR if event.type == EventTypes.GuestAccess: - full_context = yield self.store.get_current_state(room_id=event.room_id) - yield self.maybe_kick_guest_users(event, full_context) + yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1389,6 +1474,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1492,8 +1582,14 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1514,8 +1610,8 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. - auth_ids = self.auth.compute_auth_events( - event, context.current_state + auth_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1571,8 +1667,14 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) @@ -1758,12 +1860,12 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, context.current_state) + yield self.auth.check_from_context(event, context) except AuthError as e: logger.warn("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, auth_events=context.current_state) + yield self._check_signature(event, context) member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(None, event, context) else: @@ -1789,11 +1891,11 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + self.auth.check_from_context(event, context) except AuthError as e: logger.warn("Denying third party invite %r because %s", event, e) raise e - yield self._check_signature(event, auth_events=context.current_state) + yield self._check_signature(event, context) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. @@ -1807,7 +1909,12 @@ class FederationHandler(BaseHandler): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - original_invite = context.current_state.get(key) + original_invite = None + original_invite_id = context.prev_state_ids.get(key) + if original_invite_id: + original_invite = yield self.store.get_event( + original_invite_id, allow_none=True + ) if not original_invite: logger.info( "Could not find invite event for third_party_invite - " @@ -1824,13 +1931,13 @@ class FederationHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def _check_signature(self, event, auth_events): + def _check_signature(self, event, context): """ Checks that the signature in the event is consistent with its invite. Args: event (Event): The m.room.member event to check - auth_events (dict<(event type, state_key), event>): + context (EventContext): Raises: AuthError: if signature didn't match any keys, or key has been @@ -1841,10 +1948,14 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event = auth_events.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) + invite_event = None + if invite_event_id: + invite_event = yield self.store.get_event(invite_event_id, allow_none=True) + if not invite_event: raise AuthError(403, "Could not find invite") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4c3cd9d12e..3577db0595 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.metrics import measure_func +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -248,7 +249,7 @@ class MessageHandler(BaseHandler): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = self.deduplicate_state_event(event, context) + prev_state = yield self.deduplicate_state_event(event, context) if prev_state is not None: defer.returnValue(prev_state) @@ -263,6 +264,7 @@ class MessageHandler(BaseHandler): presence = self.hs.get_presence_handler() yield presence.bump_presence_active_time(user) + @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ Checks whether event is in the latest resolved state in context. @@ -270,13 +272,17 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event = context.current_state.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if not prev_event: + return + if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: - return prev_event - return None + defer.returnValue(prev_event) + return @defer.inlineCallbacks def create_and_send_nonmember_event( @@ -802,8 +808,8 @@ class MessageHandler(BaseHandler): event = builder.build() logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state, + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, ) defer.returnValue( @@ -826,12 +832,12 @@ class MessageHandler(BaseHandler): self.ratelimit(requester) try: - self.auth.check(event, auth_events=context.current_state) + yield self.auth.check_from_context(event, context) except AuthError as err: logger.warn("Denying new event %r because %s", event, err) raise err - yield self.maybe_kick_guest_users(event, context.current_state.values()) + yield self.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Check the alias is acually valid (at this time at least) @@ -859,6 +865,15 @@ class MessageHandler(BaseHandler): e.sender == event.sender ) + state_to_include_ids = [ + e_id + for k, e_id in context.current_state_ids.items() + if k[0] in self.hs.config.room_invite_state_types + or k[0] == EventTypes.Member and k[1] == event.sender + ] + + state_to_include = yield self.store.get_events(state_to_include_ids) + event.unsigned["invite_room_state"] = [ { "type": e.type, @@ -866,9 +881,7 @@ class MessageHandler(BaseHandler): "content": e.content, "sender": e.sender, } - for k, e in context.current_state.items() - if e.type in self.hs.config.room_invite_state_types - or is_inviter_member_event(e) + for e in state_to_include.values() ] invitee = UserID.from_string(event.state_key) @@ -890,7 +903,14 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Redaction: - if self.auth.check_redaction(event, auth_events=context.current_state): + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + if self.auth.check_redaction(event, auth_events=auth_events): original_event = yield self.store.get_event( event.redacts, check_redacted=False, @@ -904,7 +924,7 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.current_state: + if event.type == EventTypes.Create and context.prev_state_ids: raise AuthError( 403, "Changing the room create event is forbidden", @@ -925,16 +945,7 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - destinations = set() - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(s.state_key)) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) + destinations = yield self.get_joined_hosts_for_room_from_state(context) @defer.inlineCallbacks def _notify(): @@ -952,3 +963,39 @@ class MessageHandler(BaseHandler): preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) + + def get_joined_hosts_for_room_from_state(self, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts_for_room_from_state( + state_group, context.current_state_ids + ) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids, + cache_context): + + # Don't bother getting state for people on the same HS + current_state = yield self.store.get_events([ + e_id for key, e_id in current_state_ids.items() + if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1]) + ]) + + destinations = set() + for e in current_state.itervalues(): + try: + if e.type == EventTypes.Member: + if e.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(e.state_key)) + except SynapseError: + logger.warn( + "Failed to get destination from event %s", e.event_id + ) + + defer.returnValue(destinations) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6a1fe76c88..cf82a2336e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -88,6 +88,8 @@ class PresenceHandler(object): self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() + self.state = hs.get_state_handler() + self.federation.register_edu_handler( "m.presence", self.incoming_presence ) @@ -189,6 +191,13 @@ class PresenceHandler(object): 5000, ) + self.clock.call_later( + 60, + self.clock.looping_call, + self._persist_unpersisted_changes, + 60 * 1000, + ) + metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) @defer.inlineCallbacks @@ -215,6 +224,27 @@ class PresenceHandler(object): logger.info("Finished _on_shutdown") @defer.inlineCallbacks + def _persist_unpersisted_changes(self): + """We periodically persist the unpersisted changes, as otherwise they + may stack up and slow down shutdown times. + """ + logger.info( + "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", + len(self.unpersisted_users_changes) + ) + + unpersisted = self.unpersisted_users_changes + self.unpersisted_users_changes = set() + + if unpersisted: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in unpersisted + ]) + + logger.info("Finished _persist_unpersisted_changes") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state @@ -532,7 +562,9 @@ class PresenceHandler(object): if not local_states: continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + hosts = set(get_domain_from_id(u) for u in users) + for host in hosts: hosts_to_states.setdefault(host, []).extend(local_states) @@ -725,13 +757,13 @@ class PresenceHandler(object): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. + user_ids = yield self.state.get_current_user_in_room(room_id) if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = yield self.store.get_joined_hosts_for_room(room_id) + hosts = set(get_domain_from_id(u) for u in user_ids) self._push_to_remotes({host: (state,) for host in hosts}) else: - user_ids = yield self.store.get_users_in_room(room_id) user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) @@ -918,7 +950,12 @@ def should_notify(old_state, new_state): if new_state.currently_active != old_state.currently_active: return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Only notify about last active bumps if we're not currently acive + if not (old_state.currently_active and new_state.currently_active): + return True + + elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. return True @@ -955,6 +992,7 @@ class PresenceEventSource(object): self.get_presence_handler = hs.get_presence_handler self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -1017,7 +1055,7 @@ class PresenceEventSource(object): user_ids_to_check = set() for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) user_ids_to_check.update(users) user_ids_to_check.update(friends) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e62722d78d..726f7308d2 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from ._base import BaseHandler from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import get_domain_from_id import logging @@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler): "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() + self.state = hs.get_state_handler() @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, @@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler): event_ids = receipt["event_ids"] data = receipt["data"] - remotedomains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) remotedomains = remotedomains.copy() remotedomains.discard(self.server_name) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8b17632fdc..ba49075a20 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,12 @@ class RoomMemberHandler(BaseHandler): prev_event_ids=prev_event_ids, ) + # Check if this event matches the previous membership event for the user. + duplicate = yield msg_handler.deduplicate_state_event(event, context) + if duplicate is not None: + # Discard the new event since this membership change is a no-op. + return + yield msg_handler.handle_new_client_event( requester, event, @@ -93,20 +99,26 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event = context.current_state.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: yield user_joined_room(self.distributor, target, room_id) elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target, room_id) + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) @defer.inlineCallbacks def remote_join(self, remote_room_hosts, room_id, user, content): @@ -195,29 +207,32 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts = [] latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - current_state = yield self.state_handler.get_current_state( + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) - old_state = current_state.get((EventTypes.Member, target.to_string())) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE - ) + old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + if old_state_id: + old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned" + " (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was banned" % (action,), + errcode=Codes.BAD_STATE + ) - is_host_in_room = self.is_host_in_room(current_state) + is_host_in_room = yield self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(current_state): + if requester.is_guest and not self._can_guest_join(current_state_ids): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") @@ -326,15 +341,17 @@ class RoomMemberHandler(BaseHandler): requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) + prev_event = yield message_handler.deduplicate_state_event(event, context) if prev_event is not None: return if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") + if requester.is_guest: + guest_can_join = yield self._can_guest_join(context.prev_state_ids) + if not guest_can_join: + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") yield message_handler.handle_new_client_event( requester, @@ -344,27 +361,39 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), + prev_member_event_id = context.prev_state_ids.get( + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: yield user_joined_room(self.distributor, target_user, room_id) elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) - def _can_guest_join(self, current_state): + @defer.inlineCallbacks + def _can_guest_join(self, current_state_ids): """ Returns whether a guest can join a room based on its current state. """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( + guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + if not guest_access_id: + defer.returnValue(False) + + guest_access = yield self.store.get_event(guest_access_id) + + defer.returnValue( guest_access and guest_access.content and "guest_access" in guest_access.content @@ -683,3 +712,24 @@ class RoomMemberHandler(BaseHandler): if membership: yield self.store.forget(user_id, room_id) + + @defer.inlineCallbacks + def _is_host_in_room(self, current_state_ids): + # Have we just created the room, and is this about to be the very + # first member event? + create_event_id = current_state_ids.get(("m.room.create", "")) + if len(current_state_ids) == 1 and create_event_id: + defer.returnValue(self.hs.is_mine_id(create_event_id)) + + for (etype, state_key), event_id in current_state_ids.items(): + if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if not event: + continue + + if event.membership == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8dfd02e7b..b5962f4f5a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [ "filter_collection", "is_guest", "request_key", + "device_id", ]) @@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. ])): __slots__ = [] @@ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.joined or self.invited or self.archived or - self.account_data + self.account_data or + self.to_device ) @@ -139,6 +142,7 @@ class SyncHandler(object): self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) + self.state = hs.get_state_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -355,11 +359,11 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state = yield self.store.get_state_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event(event.event_id) if event.is_state(): - state = state.copy() - state[(event.type, event.state_key)] = event - defer.returnValue(state) + state_ids = state_ids.copy() + state_ids[(event.type, event.state_key)] = event.event_id + defer.returnValue(state_ids) @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): @@ -412,57 +416,61 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state = yield self.store.get_state_for_event( + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) else: - current_state = yield self.get_state_at( + current_state_ids = yield self.get_state_at( room_id, stream_position=now_token ) - state = current_state + state_ids = current_state_ids timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, - timeline_start=state, + timeline_start=state_ids, previous={}, - current=current_state, + current=current_state_ids, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state_at_timeline_start = yield self.store.get_state_for_event( + state_at_timeline_start = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, - current=current_state, + current=current_state_ids, ) else: - state = {} + state_ids = {} + + state = {} + if state_ids: + state = yield self.store.get_events(state_ids.values()) defer.returnValue({ (e.type, e.state_key): e @@ -527,16 +535,58 @@ class SyncHandler(object): sync_result_builder, newly_joined_rooms, newly_joined_users ) + yield self._generate_sync_entry_for_to_device(sync_result_builder) + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, joined=sync_result_builder.joined, invited=sync_result_builder.invited, archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, next_batch=sync_result_builder.now_token, )) @defer.inlineCallbacks + def _generate_sync_entry_for_to_device(self, sync_result_builder): + """Generates the portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + user_id = sync_result_builder.sync_config.user.to_string() + device_id = sync_result_builder.sync_config.device_id + now_token = sync_result_builder.now_token + since_stream_id = 0 + if sync_result_builder.since_token is not None: + since_stream_id = int(sync_result_builder.since_token.to_device_key) + + if since_stream_id != int(now_token.to_device_key): + # We only delete messages when a new message comes in, but that's + # fine so long as we delete them at some point. + + logger.debug("Deleting messages up to %d", since_stream_id) + yield self.store.delete_messages_for_device( + user_id, device_id, since_stream_id + ) + + logger.debug("Getting messages up to %d", now_token.to_device_key) + messages, stream_id = yield self.store.get_new_messages_for_device( + user_id, device_id, since_stream_id, now_token.to_device_key + ) + logger.debug("Got messages up to %d: %r", stream_id, messages) + sync_result_builder.now_token = now_token.copy_and_replace( + "to_device_key", stream_id + ) + sync_result_builder.to_device = messages + else: + sync_result_builder.to_device = [] + + @defer.inlineCallbacks def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. @@ -626,7 +676,7 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_users) for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) @@ -766,8 +816,13 @@ class SyncHandler(object): # the last sync (even if we have since left). This is to make sure # we do send down the room, and with full state, where necessary if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) @@ -1059,27 +1114,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): Returns: dict """ - event_id_to_state = { - e.event_id: e - for e in itertools.chain( - timeline_contains.values(), - previous.values(), - timeline_start.values(), - current.values(), + event_id_to_key = { + e: key + for key, e in itertools.chain( + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(e.event_id for e in current.values()) - tc_ids = set(e.event_id for e in timeline_contains.values()) - p_ids = set(e.event_id for e in previous.values()) - ts_ids = set(e.event_id for e in timeline_start.values()) + c_ids = set(e for e in current.values()) + tc_ids = set(e for e in timeline_contains.values()) + p_ids = set(e for e in previous.values()) + ts_ids = set(e for e in timeline_start.values()) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - evs = (event_id_to_state[e] for e in state_ids) return { - (e.type, e.state_key): e - for e in evs + event_id_to_key[e]: e for e in state_ids } @@ -1103,6 +1156,7 @@ class SyncResultBuilder(object): self.joined = [] self.invited = [] self.archived = [] + self.device = [] class RoomSyncResultBuilder(object): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 46181984c0..0b530b9034 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,7 +20,7 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, ) from synapse.util.metrics import Measure -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id import logging @@ -42,6 +42,7 @@ class TypingHandler(object): self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() self.clock = hs.get_clock() @@ -166,7 +167,8 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_update(self, room_id, user_id, typing): - domains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) deferreds = [] for domain in domains: @@ -199,7 +201,8 @@ class TypingHandler(object): # Check that the string is a valid user id UserID.from_string(user_id) - domains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) if self.server_name in domains: self._push_update_local( diff --git a/synapse/notifier.py b/synapse/notifier.py index b86648f5e4..48653ae843 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -423,7 +423,8 @@ class Notifier(object): def _is_world_readable(self, room_id): state = yield self.state_handler.get_current_state( room_id, - EventTypes.RoomHistoryVisibility + EventTypes.RoomHistoryVisibility, + "", ) if state and "history_visibility" in state.content: defer.returnValue(state.content["history_visibility"] == "world_readable") diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index ed2ccc4dfb..3f75d3f921 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,12 +40,12 @@ class ActionGenerator: def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "evaluator_for_event"): bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context.state_group, context.current_state + event, self.hs, self.store, context ) with Measure(self.clock, "action_for_event_by_user"): actions_by_user = yield bulk_evaluator.action_for_event_by_user( - event, context.current_state + event, context ) context.push_actions = [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 004eded61f..f1bbe57dcb 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,8 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent -from synapse.api.constants import EventTypes, Membership -from synapse.visibility import filter_events_for_clients +from synapse.api.constants import EventTypes +from synapse.visibility import filter_events_for_clients_context logger = logging.getLogger(__name__) @@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks -def evaluator_for_event(event, hs, store, state_group, current_state): +def evaluator_for_event(event, hs, store, context): rules_by_user = yield store.bulk_get_push_rules_for_room( - event.room_id, state_group, current_state + event, context ) # if this event is an invite event, we may need to run rules for the user @@ -72,7 +72,7 @@ class BulkPushRuleEvaluator: self.store = store @defer.inlineCallbacks - def action_for_event_by_user(self, event, current_state): + def action_for_event_by_user(self, event, context): actions_by_user = {} # None of these users can be peeking since this list of users comes @@ -82,27 +82,25 @@ class BulkPushRuleEvaluator: (u, False) for u in self.rules_by_user.keys() ] - filtered_by_user = yield filter_events_for_clients( - self.store, user_tuples, [event], {event.event_id: current_state} + filtered_by_user = yield filter_events_for_clients_context( + self.store, user_tuples, [event], {event.event_id: context} ) - room_members = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN + room_members = yield self.store.get_joined_users_from_context( + event, context ) evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) condition_cache = {} - display_names = {} - for ev in current_state.values(): - nm = ev.content.get("displayname", None) - if nm and ev.type == EventTypes.Member: - display_names[ev.state_key] = nm - for uid, rules in self.rules_by_user.items(): - display_name = display_names.get(uid, None) + display_name = None + member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) + if member_ev_id: + member_ev = yield self.store.get_event(member_ev_id, allow_none=True) + if member_ev: + display_name = member_ev.content.get("displayname", None) filtered = filtered_by_user[uid] if len(filtered) == 0: diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index feedb075e2..c0f8176e3d 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -245,7 +245,7 @@ class HttpPusher(object): @defer.inlineCallbacks def _build_notification_dict(self, event, tweaks, badge): ctx = yield push_tools.get_context_for_event( - self.state_handler, event, self.user_id + self.store, self.state_handler, event, self.user_id ) d = { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 1028731bc9..2cafcfd8f5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -22,7 +22,7 @@ from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from synapse.util.async import concurrently_execute -from synapse.util.presentable_names import ( +from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event, descriptor_from_member_events ) from synapse.types import UserID @@ -139,7 +139,7 @@ class Mailer(object): @defer.inlineCallbacks def _fetch_room_state(room_id): - room_state = yield self.state_handler.get_current_state(room_id) + room_state = yield self.state_handler.get_current_state_ids(room_id) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email @@ -159,11 +159,12 @@ class Mailer(object): ) rooms.append(roomvars) - reason['room_name'] = calculate_room_name( - state_by_room[reason['room_id']], user_id, fallback_to_members=True + reason['room_name'] = yield calculate_room_name( + self.store, state_by_room[reason['room_id']], user_id, + fallback_to_members=True ) - summary_text = self.make_summary_text( + summary_text = yield self.make_summary_text( notifs_by_room, state_by_room, notif_events, user_id, reason ) @@ -203,12 +204,15 @@ class Mailer(object): ) @defer.inlineCallbacks - def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state): - my_member_event = room_state[("m.room.member", user_id)] + def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids): + my_member_event_id = room_state_ids[("m.room.member", user_id)] + my_member_event = yield self.store.get_event(my_member_event_id) is_invite = my_member_event.content["membership"] == "invite" + room_name = yield calculate_room_name(self.store, room_state_ids, user_id) + room_vars = { - "title": calculate_room_name(room_state, user_id), + "title": room_name, "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], "invite": is_invite, @@ -218,7 +222,7 @@ class Mailer(object): if not is_invite: for n in notifs: notifvars = yield self.get_notif_vars( - n, user_id, notif_events[n['event_id']], room_state + n, user_id, notif_events[n['event_id']], room_state_ids ) # merge overlapping notifs together. @@ -243,7 +247,7 @@ class Mailer(object): defer.returnValue(room_vars) @defer.inlineCallbacks - def get_notif_vars(self, notif, user_id, notif_event, room_state): + def get_notif_vars(self, notif, user_id, notif_event, room_state_ids): results = yield self.store.get_events_around( notif['room_id'], notif['event_id'], before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER @@ -261,17 +265,19 @@ class Mailer(object): the_events.append(notif_event) for event in the_events: - messagevars = self.get_message_vars(notif, event, room_state) + messagevars = yield self.get_message_vars(notif, event, room_state_ids) if messagevars is not None: ret['messages'].append(messagevars) defer.returnValue(ret) - def get_message_vars(self, notif, event, room_state): + @defer.inlineCallbacks + def get_message_vars(self, notif, event, room_state_ids): if event.type != EventTypes.Message: - return None + return - sender_state_event = room_state[("m.room.member", event.sender)] + sender_state_event_id = room_state_ids[("m.room.member", event.sender)] + sender_state_event = yield self.store.get_event(sender_state_event_id) sender_name = name_from_member_event(sender_state_event) sender_avatar_url = sender_state_event.content.get("avatar_url") @@ -299,7 +305,7 @@ class Mailer(object): if "body" in event.content: ret["body_text_plain"] = event.content["body"] - return ret + defer.returnValue(ret) def add_text_message_vars(self, messagevars, event): msgformat = event.content.get("format") @@ -321,6 +327,7 @@ class Mailer(object): return messagevars + @defer.inlineCallbacks def make_summary_text(self, notifs_by_room, state_by_room, notif_events, user_id, reason): if len(notifs_by_room) == 1: @@ -330,8 +337,8 @@ class Mailer(object): # If the room has some kind of name, use it, but we don't # want the generated-from-names one here otherwise we'll # end up with, "new message from Bob in the Bob room" - room_name = calculate_room_name( - state_by_room[room_id], user_id, fallback_to_members=False + room_name = yield calculate_room_name( + self.store, state_by_room[room_id], user_id, fallback_to_members=False ) my_member_event = state_by_room[room_id][("m.room.member", user_id)] @@ -342,16 +349,16 @@ class Mailer(object): inviter_name = name_from_member_event(inviter_member_event) if room_name is None: - return INVITE_FROM_PERSON % { + defer.returnValue(INVITE_FROM_PERSON % { "person": inviter_name, "app": self.app_name - } + }) else: - return INVITE_FROM_PERSON_TO_ROOM % { + defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % { "person": inviter_name, "room": room_name, "app": self.app_name, - } + }) sender_name = None if len(notifs_by_room[room_id]) == 1: @@ -362,24 +369,24 @@ class Mailer(object): sender_name = name_from_member_event(state_event) if sender_name is not None and room_name is not None: - return MESSAGE_FROM_PERSON_IN_ROOM % { + defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % { "person": sender_name, "room": room_name, "app": self.app_name, - } + }) elif sender_name is not None: - return MESSAGE_FROM_PERSON % { + defer.returnValue(MESSAGE_FROM_PERSON % { "person": sender_name, "app": self.app_name, - } + }) else: # There's more than one notification for this room, so just # say there are several if room_name is not None: - return MESSAGES_IN_ROOM % { + defer.returnValue(MESSAGES_IN_ROOM % { "room": room_name, "app": self.app_name, - } + }) else: # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -388,22 +395,22 @@ class Mailer(object): for n in notifs_by_room[room_id] ])) - return MESSAGES_FROM_PERSON % { + defer.returnValue(MESSAGES_FROM_PERSON % { "person": descriptor_from_member_events([ state_by_room[room_id][("m.room.member", s)] for s in sender_ids ]), "app": self.app_name, - } + }) else: # Stuff's happened in multiple different rooms # ...but we still refer to the 'reason' room which triggered the mail if reason['room_name'] is not None: - return MESSAGES_IN_ROOM_AND_OTHERS % { + defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % { "room": reason['room_name'], "app": self.app_name, - } + }) else: # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" @@ -412,13 +419,13 @@ class Mailer(object): for n in notifs_by_room[reason['room_id']] ])) - return MESSAGES_FROM_PERSON_AND_OTHERS % { + defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % { "person": descriptor_from_member_events([ state_by_room[reason['room_id']][("m.room.member", s)] for s in sender_ids ]), "app": self.app_name, - } + }) def make_room_link(self, room_id): # need /beta for Universal Links to work on iOS diff --git a/synapse/util/presentable_names.py b/synapse/push/presentable_names.py index f68676e9e7..277da3cd35 100644 --- a/synapse/util/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + import re import logging @@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -def calculate_room_name(room_state, user_id, fallback_to_members=True, +@defer.inlineCallbacks +def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True, fallback_to_single_member=True): """ Works out a user-facing name for the given room as per Matrix @@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True, (string or None) A human readable name for the room. """ # does it have a name? - if ("m.room.name", "") in room_state: - m_room_name = room_state[("m.room.name", "")] - if m_room_name.content and m_room_name.content["name"]: - return m_room_name.content["name"] + if ("m.room.name", "") in room_state_ids: + m_room_name = yield store.get_event( + room_state_ids[("m.room.name", "")], allow_none=True + ) + if m_room_name and m_room_name.content and m_room_name.content["name"]: + defer.returnValue(m_room_name.content["name"]) # does it have a canonical alias? - if ("m.room.canonical_alias", "") in room_state: - canon_alias = room_state[("m.room.canonical_alias", "")] + if ("m.room.canonical_alias", "") in room_state_ids: + canon_alias = yield store.get_event( + room_state_ids[("m.room.canonical_alias", "")], allow_none=True + ) if ( - canon_alias.content and canon_alias.content["alias"] and + canon_alias and canon_alias.content and canon_alias.content["alias"] and _looks_like_an_alias(canon_alias.content["alias"]) ): - return canon_alias.content["alias"] + defer.returnValue(canon_alias.content["alias"]) # at this point we're going to need to search the state by all state keys # for an event type, so rearrange the data structure - room_state_bytype = _state_as_two_level_dict(room_state) + room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) # right then, any aliases at all? - if "m.room.aliases" in room_state_bytype: - m_room_aliases = room_state_bytype["m.room.aliases"] - if len(m_room_aliases.values()) > 0: - first_alias_event = m_room_aliases.values()[0] - if first_alias_event.content and first_alias_event.content["aliases"]: - the_aliases = first_alias_event.content["aliases"] + if "m.room.aliases" in room_state_bytype_ids: + m_room_aliases = room_state_bytype_ids["m.room.aliases"] + for alias_id in m_room_aliases.values(): + alias_event = yield store.get_event( + alias_id, allow_none=True + ) + if alias_event and alias_event.content.get("aliases"): + the_aliases = alias_event.content["aliases"] if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): - return the_aliases[0] + defer.returnValue(the_aliases[0]) if not fallback_to_members: - return None + defer.returnValue(None) my_member_event = None - if ("m.room.member", user_id) in room_state: - my_member_event = room_state[("m.room.member", user_id)] + if ("m.room.member", user_id) in room_state_ids: + my_member_event = yield store.get_event( + room_state_ids[("m.room.member", user_id)], allow_none=True + ) if ( my_member_event is not None and my_member_event.content['membership'] == "invite" ): - if ("m.room.member", my_member_event.sender) in room_state: - inviter_member_event = room_state[("m.room.member", my_member_event.sender)] - if fallback_to_single_member: - return "Invite from %s" % (name_from_member_event(inviter_member_event),) - else: - return None + if ("m.room.member", my_member_event.sender) in room_state_ids: + inviter_member_event = yield store.get_event( + room_state_ids[("m.room.member", my_member_event.sender)], + allow_none=True, + ) + if inviter_member_event: + if fallback_to_single_member: + defer.returnValue( + "Invite from %s" % ( + name_from_member_event(inviter_member_event), + ) + ) + else: + return else: - return "Room Invite" + defer.returnValue("Room Invite") # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. - if "m.room.member" in room_state_bytype: + if "m.room.member" in room_state_bytype_ids: + member_events = yield store.get_events( + room_state_bytype_ids["m.room.member"].values() + ) all_members = [ - ev for ev in room_state_bytype["m.room.member"].values() + ev for ev in member_events.values() if ev.content['membership'] == "join" or ev.content['membership'] == "invite" ] # Sort the member events oldest-first so the we name people in the @@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True, # self-chat, peeked room with 1 participant, # or inbound invite, or outbound 3PID invite. if all_members[0].sender == user_id: - if "m.room.third_party_invite" in room_state_bytype: + if "m.room.third_party_invite" in room_state_bytype_ids: third_party_invites = ( - room_state_bytype["m.room.third_party_invite"].values() + room_state_bytype_ids["m.room.third_party_invite"].values() ) if len(third_party_invites) > 0: @@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True, # return "Inviting %s" % ( # descriptor_from_member_events(third_party_invites) # ) - return "Inviting email address" + defer.returnValue("Inviting email address") else: - return ALL_ALONE + defer.returnValue(ALL_ALONE) else: - return name_from_member_event(all_members[0]) + defer.returnValue(name_from_member_event(all_members[0])) else: - return ALL_ALONE + defer.returnValue(ALL_ALONE) elif len(other_members) == 1 and not fallback_to_single_member: - return None + return else: - return descriptor_from_member_events(other_members) + defer.returnValue(descriptor_from_member_events(other_members)) def descriptor_from_member_events(member_events): diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index becb8ef1ae..b47bf1f92b 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -14,7 +14,7 @@ # limitations under the License. from twisted.internet import defer -from synapse.util.presentable_names import ( +from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event ) from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @@ -49,21 +49,22 @@ def get_badge_count(store, user_id): @defer.inlineCallbacks -def get_context_for_event(state_handler, ev, user_id): +def get_context_for_event(store, state_handler, ev, user_id): ctx = {} - room_state = yield state_handler.get_current_state(ev.room_id) + room_state_ids = yield state_handler.get_current_state_ids(ev.room_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = calculate_room_name( - room_state, user_id, fallback_to_single_member=False + name = yield calculate_room_name( + store, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx['name'] = name - sender_state_event = room_state[("m.room.member", ev.sender)] + sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] + sender_state_event = yield store.get_event(sender_state_event_id) ctx['sender_display_name'] = name_from_member_event(sender_state_event) defer.returnValue(ctx) diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 84993b33b3..1ed9034bcb 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -40,8 +40,8 @@ STREAM_NAMES = ( ("backfill",), ("push_rules",), ("pushers",), - ("state",), ("caches",), + ("to_device",), ) @@ -130,7 +130,6 @@ class ReplicationResource(Resource): backfill_token = yield self.store.get_current_backfill_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() pushers_token = self.store.get_pushers_stream_token() - state_token = self.store.get_state_stream_token() caches_token = self.store.get_cache_stream_token() defer.returnValue(_ReplicationToken( @@ -142,8 +141,9 @@ class ReplicationResource(Resource): backfill_token, push_rules_token, pushers_token, - state_token, + 0, # State stream is no longer a thing caches_token, + int(stream_token.to_device_key), )) @request_handler() @@ -191,8 +191,8 @@ class ReplicationResource(Resource): yield self.receipts(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams) - yield self.state(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams) + yield self.to_device(writer, current_token, limit, request_streams) self.streams(writer, current_token, request_streams) logger.info("Replicated %d rows", writer.total) @@ -366,25 +366,6 @@ class ReplicationResource(Resource): )) @defer.inlineCallbacks - def state(self, writer, current_token, limit, request_streams): - current_position = current_token.state - - state = request_streams.get("state") - - if state is not None: - state_groups, state_group_state = ( - yield self.store.get_all_new_state_groups( - state, current_position, limit - ) - ) - writer.write_header_and_rows("state_groups", state_groups, ( - "position", "room_id", "event_id" - )) - writer.write_header_and_rows("state_group_state", state_group_state, ( - "position", "type", "state_key", "event_id" - )) - - @defer.inlineCallbacks def caches(self, writer, current_token, limit, request_streams): current_position = current_token.caches @@ -398,6 +379,20 @@ class ReplicationResource(Resource): "position", "cache_func", "keys", "invalidation_ts" )) + @defer.inlineCallbacks + def to_device(self, writer, current_token, limit, request_streams): + current_position = current_token.to_device + + to_device = request_streams.get("to_device") + + if to_device is not None: + to_device_rows = yield self.store.get_all_new_device_messages( + to_device, current_position, limit + ) + writer.write_header_and_rows("to_device", to_device_rows, ( + "position", "user_id", "device_id", "message_json" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -426,7 +421,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules", "pushers", "state", "caches", + "push_rules", "pushers", "state", "caches", "to_device", ))): __slots__ = [] diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py new file mode 100644 index 0000000000..64d8eb2af1 --- /dev/null +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseSlavedStore +from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore + + +class SlavedDeviceInboxStore(BaseSlavedStore): + def __init__(self, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + self._device_inbox_id_gen = SlavedIdTracker( + db_conn, "device_inbox", "stream_id", + ) + + get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ + get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ + delete_messages_for_device = DataStore.delete_messages_for_device.__func__ + + def stream_positions(self): + result = super(SlavedDeviceInboxStore, self).stream_positions() + result["to_device"] = self._device_inbox_id_gen.get_current_token() + return result + + def process_replication(self, result): + stream = result.get("to_device") + if stream: + self._device_inbox_id_gen.advance(int(stream["position"])) + + return super(SlavedDeviceInboxStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index f4f31f2d27..cbebd5b2f7 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -120,10 +120,21 @@ class SlavedEventStore(BaseSlavedStore): get_state_for_event = DataStore.get_state_for_event.__func__ get_state_for_events = DataStore.get_state_for_events.__func__ get_state_groups = DataStore.get_state_groups.__func__ + get_state_groups_ids = DataStore.get_state_groups_ids.__func__ + get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ + get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ + get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ + get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ + _get_joined_users_from_context = ( + RoomMemberStore.__dict__["_get_joined_users_from_context"] + ) + get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ get_room_events_stream_for_rooms = ( DataStore.get_room_events_stream_for_rooms.__func__ ) + is_host_joined = DataStore.is_host_joined.__func__ + _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"] get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ _set_before_and_after = staticmethod(DataStore._set_before_and_after) @@ -211,7 +222,6 @@ class SlavedEventStore(BaseSlavedStore): self._get_current_state_for_key.invalidate_all() self.get_rooms_for_user.invalidate_all() self.get_users_in_room.invalidate((event.room_id,)) - # self.get_joined_hosts_for_room.invalidate((event.room_id,)) self._invalidate_get_event_cache(event.event_id) @@ -235,7 +245,6 @@ class SlavedEventStore(BaseSlavedStore): if event.type == EventTypes.Member: self.get_rooms_for_user.invalidate((event.state_key,)) - # self.get_joined_hosts_for_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,)) self._membership_stream_cache.entity_has_changed( event.state_key, event.internal_metadata.stream_ordering diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 326780405e..f9f5a3e077 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -49,6 +49,7 @@ from synapse.rest.client.v2_alpha import ( notifications, devices, thirdparty, + sendtodevice, ) from synapse.http.server import JsonResource @@ -96,3 +97,4 @@ class ClientRestResource(JsonResource): notifications.register_servlets(hs, client_resource) devices.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource) + sendtodevice.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py new file mode 100644 index 0000000000..9c10a99acf --- /dev/null +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer +from synapse.http.servlet import parse_json_object_from_request + +from synapse.http import servlet +from synapse.rest.client.v1.transactions import HttpTransactionStore +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class SendToDeviceRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns( + "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", + releases=[], v2_alpha=False + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SendToDeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.txns = HttpTransactionStore() + + @defer.inlineCallbacks + def on_PUT(self, request, message_type, txn_id): + try: + defer.returnValue( + self.txns.get_client_transaction(request, txn_id) + ) + except KeyError: + pass + + requester = yield self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + + # TODO: Prod the notifier to wake up sync streams. + # TODO: Implement replication for the messages. + # TODO: Send the messages to remote servers if needed. + + local_messages = {} + for user_id, by_device in content["messages"].items(): + if self.is_mine_id(user_id): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": requester.user.to_string(), + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + + stream_id = yield self.store.add_messages_to_device_inbox(local_messages) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + response = (200, {}) + self.txns.store_client_transaction(request, txn_id, response) + defer.returnValue(response) + + +def register_servlets(hs, http_server): + SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index b11acdbea7..6fc63715aa 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -97,6 +97,7 @@ class SyncRestServlet(RestServlet): request, allow_guest=True ) user = requester.user + device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") @@ -109,12 +110,12 @@ class SyncRestServlet(RestServlet): logger.info( "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r" % ( - user, timeout, since, set_presence, filter_id + " set_presence=%r, filter_id=%r, device_id=%r" % ( + user, timeout, since, set_presence, filter_id, device_id ) ) - request_key = (user, timeout, since, filter_id, full_state) + request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id: if filter_id.startswith('{'): @@ -136,6 +137,7 @@ class SyncRestServlet(RestServlet): filter_collection=filter, is_guest=requester.is_guest, request_key=request_key, + device_id=device_id, ) if since is not None: @@ -173,6 +175,7 @@ class SyncRestServlet(RestServlet): response_content = { "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, "presence": self.encode_presence( sync_result.presence, time_now ), diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 9abca3a8ad..4f6f1a7e17 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -18,15 +18,32 @@ import logging from twisted.internet import defer +from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet -from synapse.types import ThirdPartyEntityKind from ._base import client_v2_patterns logger = logging.getLogger(__name__) +class ThirdPartyProtocolsServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=()) + + def __init__(self, hs): + super(ThirdPartyProtocolsServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + yield self.auth.get_user_by_req(request) + + protocols = yield self.appservice_handler.get_3pe_protocols() + defer.returnValue((200, protocols)) + + class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", + PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", releases=()) def __init__(self, hs): @@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", + PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$", releases=()) def __init__(self, hs): @@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet): def register_servlets(hs, http_server): + ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server) diff --git a/synapse/state.py b/synapse/state.py index ef1bc470be..cd792afed1 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.auth import AuthEventTypes from synapse.events.snapshot import EventContext +from synapse.util.async import Linearizer from collections import namedtuple @@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR) EVICTION_TIMEOUT_SECONDS = 60 * 60 +_NEXT_STATE_ID = 1 + + +def _gen_state_id(): + global _NEXT_STATE_ID + s = "X%d" % (_NEXT_STATE_ID,) + _NEXT_STATE_ID += 1 + return s + + class _StateCacheEntry(object): - def __init__(self, state, state_group, ts): + __slots__ = ["state", "state_group", "state_id"] + + def __init__(self, state, state_group): self.state = state self.state_group = state_group + # The `state_id` is a unique ID we generate that can be used as ID for + # this collection of state. Usually this would be the same as the + # state group, but on worker instances we can't generate a new state + # group each time we resolve state, so we generate a separate one that + # isn't persisted and is used solely for caches. + # `state_id` is either a state_group (and so an int) or a string. This + # ensures we don't accidentally persist a state_id as a stateg_group + if state_group: + self.state_id = state_group + else: + self.state_id = _gen_state_id() + class StateHandler(object): """ Responsible for doing state conflict resolution. @@ -60,6 +85,7 @@ class StateHandler(object): # dict of set of event_ids -> _StateCacheEntry. self._state_cache = None + self.resolve_linearizer = Linearizer() def start_caching(self): logger.debug("start_caching") @@ -93,8 +119,32 @@ class StateHandler(object): if not latest_event_ids: latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - res = yield self.resolve_state_groups(room_id, latest_event_ids) - state = res[1] + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state + + if event_type: + event_id = state.get((event_type, state_key)) + event = None + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + defer.returnValue(event) + return + + state_map = yield self.store.get_events(state.values(), get_prev_content=False) + state = { + key: state_map[e_id] for key, e_id in state.items() if e_id in state_map + } + + defer.returnValue(state) + + @defer.inlineCallbacks + def get_current_state_ids(self, room_id, event_type=None, state_key="", + latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + + ret = yield self.resolve_state_groups(room_id, latest_event_ids) + state = ret.state if event_type: defer.returnValue(state.get((event_type, state_key))) @@ -103,6 +153,15 @@ class StateHandler(object): defer.returnValue(state) @defer.inlineCallbacks + def get_current_user_in_room(self, room_id): + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + entry = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_users = yield self.store.get_joined_users_from_state( + room_id, entry.state_id, entry.state + ) + defer.returnValue(joined_users) + + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """ Fills out the context with the `current state` of the graph. The `current state` here is defined to be the state of the event graph @@ -123,54 +182,75 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state = { - (s.type, s.state_key): s for s in old_state + context.prev_state_ids = { + (s.type, s.state_key): s.event_id for s in old_state } + if event.is_state(): + context.current_state_events = dict(context.prev_state_ids) + key = (event.type, event.state_key) + context.current_state_events[key] = event.event_id + else: + context.current_state_events = context.prev_state_ids else: - context.current_state = {} + context.current_state_ids = {} + context.prev_state_ids = {} context.prev_state_events = [] - context.state_group = None + context.state_group = self.store.get_next_state_group() defer.returnValue(context) if old_state: - context.current_state = { - (s.type, s.state_key): s for s in old_state + context.prev_state_ids = { + (s.type, s.state_key): s.event_id for s in old_state } - context.state_group = None + context.state_group = self.store.get_next_state_group() if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state: - replaces = context.current_state[key] - if replaces.event_id != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces.event_id + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] + if replaces != event.event_id: # Paranoia check + event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids context.prev_state_events = [] defer.returnValue(context) if event.is_state(): - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], event_type=event.type, state_key=event.state_key, ) else: - ret = yield self.resolve_state_groups( + entry = yield self.resolve_state_groups( event.room_id, [e for e, _ in event.prev_events], ) - group, curr_state, prev_state = ret + curr_state = entry.state - context.current_state = curr_state - context.state_group = group if not event.is_state() else None + context.prev_state_ids = curr_state + if event.is_state(): + context.state_group = self.store.get_next_state_group() + else: + if entry.state_group is None: + entry.state_group = self.store.get_next_state_group() + entry.state_id = entry.state_group + context.state_group = entry.state_group if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state: - replaces = context.current_state[key] - event.unsigned["replaces_state"] = replaces.event_id + if key in context.prev_state_ids: + replaces = context.prev_state_ids[key] + event.unsigned["replaces_state"] = replaces + context.current_state_ids = dict(context.prev_state_ids) + context.current_state_ids[key] = event.event_id + else: + context.current_state_ids = context.prev_state_ids - context.prev_state_events = prev_state + context.prev_state_events = [] defer.returnValue(context) @defer.inlineCallbacks @@ -187,72 +267,88 @@ class StateHandler(object): """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = yield self.store.get_state_groups( + state_groups_ids = yield self.store.get_state_groups_ids( room_id, event_ids ) logger.debug( "resolve_state_groups state_groups %s", - state_groups.keys() + state_groups_ids.keys() ) - group_names = frozenset(state_groups.keys()) + group_names = frozenset(state_groups_ids.keys()) if len(group_names) == 1: - name, state_list = state_groups.items().pop() - state = { - (e.type, e.state_key): e - for e in state_list - } - prev_state = state.get((event_type, state_key), None) - if prev_state: - prev_state = prev_state.event_id - prev_states = [prev_state] - else: - prev_states = [] + name, state_list = state_groups_ids.items().pop() - defer.returnValue((name, state, prev_states)) + defer.returnValue(_StateCacheEntry( + state=state_list, + state_group=name, + )) - if self._state_cache is not None: - cache = self._state_cache.get(group_names, None) - if cache: - cache.ts = self.clock.time_msec() + with (yield self.resolve_linearizer.queue(group_names)): + if self._state_cache is not None: + cache = self._state_cache.get(group_names, None) + if cache: + defer.returnValue(cache) - event_dict = yield self.store.get_events(cache.state.values()) - state = {(e.type, e.state_key): e for e in event_dict.values()} + logger.info( + "Resolving state for %s with %d groups", room_id, len(state_groups_ids) + ) - prev_state = state.get((event_type, state_key), None) - if prev_state: - prev_state = prev_state.event_id - prev_states = [prev_state] - else: - prev_states = [] - defer.returnValue( - (cache.state_group, state, prev_states) - ) + state = {} + for st in state_groups_ids.values(): + for key, e_id in st.items(): + state.setdefault(key, set()).add(e_id) - logger.info("Resolving state for %s with %d groups", room_id, len(state_groups)) + conflicted_state = { + k: list(v) + for k, v in state.items() + if len(v) > 1 + } - new_state, prev_states = self._resolve_events( - state_groups.values(), event_type, state_key - ) + if conflicted_state: + logger.info("Resolving conflicted state for %r", room_id) + state_map = yield self.store.get_events( + [e_id for st in state_groups_ids.values() for e_id in st.values()], + get_prev_content=False + ) + state_sets = [ + [state_map[e_id] for key, e_id in st.items() if e_id in state_map] + for st in state_groups_ids.values() + ] + new_state, _ = self._resolve_events( + state_sets, event_type, state_key + ) + new_state = { + key: e.event_id for key, e in new_state.items() + } + else: + new_state = { + key: e_ids.pop() for key, e_ids in state.items() + } - state_group = None - new_state_event_ids = frozenset(e.event_id for e in new_state.values()) - for sg, events in state_groups.items(): - if new_state_event_ids == frozenset(e.event_id for e in events): - state_group = sg - break + state_group = None + new_state_event_ids = frozenset(new_state.values()) + for sg, events in state_groups_ids.items(): + if new_state_event_ids == frozenset(e_id for e_id in events): + state_group = sg + break + if state_group is None: + # Worker instances don't have access to this method, but we want + # to set the state_group on the main instance to increase cache + # hits. + if hasattr(self.store, "get_next_state_group"): + state_group = self.store.get_next_state_group() - if self._state_cache is not None: cache = _StateCacheEntry( - state={key: event.event_id for key, event in new_state.items()}, + state=new_state, state_group=state_group, - ts=self.clock.time_msec() ) - self._state_cache[group_names] = cache + if self._state_cache is not None: + self._state_cache[group_names] = cache - defer.returnValue((state_group, new_state, prev_states)) + defer.returnValue(cache) def resolve_events(self, state_sets, event): logger.info( diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7efc5bfeef..6c32773f25 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,6 +36,7 @@ from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore from .event_push_actions import EventPushActionsStore +from .deviceinbox import DeviceInboxStore from .state import StateStore from .signatures import SignatureStore @@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore, OpenIdStore, ClientIpStore, DeviceStore, + DeviceInboxStore, ): def __init__(self, db_conn, hs): @@ -108,9 +110,12 @@ class DataStore(RoomMemberStore, RoomStore, self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_inbox", "stream_id" + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py new file mode 100644 index 0000000000..68116b0394 --- /dev/null +++ b/synapse/storage/deviceinbox.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import ujson + +from twisted.internet import defer + +from ._base import SQLBaseStore + + +logger = logging.getLogger(__name__) + + +class DeviceInboxStore(SQLBaseStore): + + @defer.inlineCallbacks + def add_messages_to_device_inbox(self, messages_by_user_then_device): + """ + Args: + messages_by_user_and_device(dict): + Dictionary of user_id to device_id to message. + Returns: + A deferred stream_id that resolves when the messages have been + inserted. + """ + + def select_devices_txn(txn, user_id, devices): + if not devices: + return [] + sql = ( + "SELECT user_id, device_id FROM devices" + " WHERE user_id = ? AND device_id IN (" + + ",".join("?" * len(devices)) + + ")" + ) + # TODO: Maybe this needs to be done in batches if there are + # too many local devices for a given user. + args = [user_id] + devices + txn.execute(sql, args) + return [tuple(row) for row in txn.fetchall()] + + def add_messages_to_device_inbox_txn(txn, stream_id): + local_users_and_devices = set() + for user_id, messages_by_device in messages_by_user_then_device.items(): + local_users_and_devices.update( + select_devices_txn(txn, user_id, messages_by_device.keys()) + ) + + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in messages_by_user_then_device.items(): + for device_id, message in messages_by_device.items(): + message_json = ujson.dumps(message) + # Only insert into the local inbox if the device exists on + # this server + if (user_id, device_id) in local_users_and_devices: + rows.append((user_id, device_id, stream_id, message_json)) + + txn.executemany(sql, rows) + + with self._device_inbox_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_messages_to_device_inbox", + add_messages_to_device_inbox_txn, + stream_id + ) + + defer.returnValue(self._device_inbox_id_gen.get_current_token()) + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + user_id, device_id, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn.fetchall(): + stream_pos = row[0] + messages.append(ujson.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn, + ) + + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + + return self.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + sql = ( + "SELECT stream_id FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY stream_id" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_pos, current_pos, limit)) + stream_ids = txn.fetchall() + if not stream_ids: + return [] + max_stream_id_in_limit = stream_ids[-1] + + sql = ( + "SELECT stream_id, user_id, device_id, message_json" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + txn.execute(sql, (last_pos, max_stream_id_in_limit)) + return txn.fetchall() + + return self.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 57e5005285..1a7d4c5199 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore): len(events_and_contexts) ) - state_group_id_manager = self._state_groups_id_gen.get_next_mult( - len(events_and_contexts) - ) with stream_ordering_manager as stream_orderings: - with state_group_id_manager as state_group_ids: - for (event, context), stream, state_group_id in zip( - events_and_contexts, stream_orderings, state_group_ids - ): - event.internal_metadata.stream_ordering = stream - # Assign a state group_id in case a new id is needed for - # this context. In theory we only need to assign this - # for contexts that have current_state and aren't outliers - # but that make the code more complicated. Assigning an ID - # per event only causes the state_group_ids to grow as fast - # as the stream_ordering so in practise shouldn't be a problem. - context.new_state_group_id = state_group_id - - chunks = [ - events_and_contexts[x:x + 100] - for x in xrange(0, len(events_and_contexts), 100) - ] + for (event, context), stream, in zip( + events_and_contexts, stream_orderings + ): + event.internal_metadata.stream_ordering = stream - for chunk in chunks: - # We can't easily parallelize these since different chunks - # might contain the same event. :( - yield self.runInteraction( - "persist_events", - self._persist_events_txn, - events_and_contexts=chunk, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc_by(len(chunk)) + chunks = [ + events_and_contexts[x:x + 100] + for x in xrange(0, len(events_and_contexts), 100) + ] + + for chunk in chunks: + # We can't easily parallelize these since different chunks + # might contain the same event. :( + yield self.runInteraction( + "persist_events", + self._persist_events_txn, + events_and_contexts=chunk, + backfilled=backfilled, + delete_existing=delete_existing, + ) + persist_event_counter.inc_by(len(chunk)) @_retry_on_integrity_error @defer.inlineCallbacks @@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore): delete_existing=False): try: with self._stream_id_gen.get_next() as stream_ordering: - with self._state_groups_id_gen.get_next() as state_group_id: - event.internal_metadata.stream_ordering = stream_ordering - context.new_state_group_id = state_group_id - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() + event.internal_metadata.stream_ordering = stream_ordering + yield self.runInteraction( + "persist_event", + self._persist_event_txn, + event=event, + context=context, + current_state=current_state, + backfilled=backfilled, + delete_existing=delete_existing, + ) + persist_event_counter.inc() except _RollbackButIsFineException: pass @@ -393,7 +380,6 @@ class EventsStore(SQLBaseStore): txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) # Add an entry to the current_state_resets table to record the point # where we clobbered the current state @@ -529,7 +515,7 @@ class EventsStore(SQLBaseStore): # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering - state_group_id = context.state_group or context.new_state_group_id + state_group_id = context.state_group self._simple_insert_txn( txn, table="ex_outlier_stream", diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 78334a98cf..49721656b6 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -16,7 +16,6 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.push.baserules import list_with_base_rules -from synapse.api.constants import EventTypes, Membership from twisted.internet import defer import logging @@ -124,7 +123,8 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) - def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): + def bulk_get_push_rules_for_room(self, event, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -132,11 +132,13 @@ class PushRuleStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) + return self._bulk_get_push_rules_for_room( + event.room_id, state_group, context.current_state_ids, event=event + ) @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, - cache_context): + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, + cache_context, event=None): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. @@ -147,12 +149,15 @@ class PushRuleStore(SQLBaseStore): # their unread countss are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've # sent a read receipt into the room. - local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and self.hs.is_mine_id(e.state_key) + + users_in_room = yield self._get_joined_users_from_context( + room_id, state_group, current_state_ids, + on_invalidate=cache_context.invalidate, + event=event, ) + local_users_in_room = set(u for u in users_in_room if self.hs.is_mine_id(u)) + # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index ccc3811e84..9747a04a9a 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue([ev for res in results.values() for ev in res]) - @cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True) + @cachedInlineCallbacks(num_args=3, tree=True) def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): """Get receipts for a single room for sending to clients. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index a422ddf633..6ab10db328 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,7 +20,7 @@ from collections import namedtuple from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.api.constants import Membership +from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging @@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore): for event in events: txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, @@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore): return results - @cachedInlineCallbacks(max_entries=5000) - def get_joined_hosts_for_room(self, room_id): - user_ids = yield self.get_users_in_room(room_id) - defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids)) - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" where_values = [room_id] @@ -325,7 +319,8 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3) def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at event_id. + """Returns whether user_id has elected to discard history for room_id at + event_id. event_id must be a membership event.""" def f(txn): @@ -358,3 +353,98 @@ class RoomMemberStore(SQLBaseStore): }, desc="who_forgot" ) + + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + event.room_id, state_group, context.current_state_ids, event=event, + ) + + def get_joined_users_from_state(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids, + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context, event=None): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=['user_id'], + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=1000, + desc="_get_joined_users_from_context", + ) + + users_in_room = set(row["user_id"] for row in rows) + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room.add(event.state_key) + + defer.returnValue(users_in_room) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._is_host_joined( + room_id, host, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/schema/delta/34/device_inbox.sql new file mode 100644 index 0000000000..e68844c74a --- /dev/null +++ b/synapse/storage/schema/delta/34/device_inbox.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_inbox ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + stream_id BIGINT NOT NULL, + message_json TEXT NOT NULL -- {"type":, "sender":, "content",} +); + +CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); +CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); diff --git a/synapse/storage/schema/delta/34/sent_txn_purge.py b/synapse/storage/schema/delta/34/sent_txn_purge.py new file mode 100644 index 0000000000..81948e3431 --- /dev/null +++ b/synapse/storage/schema/delta/34/sent_txn_purge.py @@ -0,0 +1,32 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + +import logging + +logger = logging.getLogger(__name__) + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute("TRUNCATE sent_transactions") + else: + cur.execute("DELETE FROM sent_transactions") + + cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)") + + +def run_upgrade(cur, database_engine, *args, **kwargs): + pass diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0e8fa93e1f..ec551b0b4f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -44,11 +44,7 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - The return value is a dict mapping group names to lists of events. - """ + def get_state_groups_ids(self, room_id, event_ids): if not event_ids: defer.returnValue({}) @@ -59,36 +55,64 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups) + defer.returnValue(group_to_state) + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + The return value is a dict mapping group names to lists of events. + """ + if not event_ids: + defer.returnValue({}) + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False + ) + defer.returnValue({ - group: state_map.values() - for group, state_map in group_to_state.items() + group: [ + state_event_map[v] for v in event_id_map.values() if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() }) + def _have_persisted_state_group_txn(self, txn, state_group): + txn.execute( + "SELECT count(*) FROM state_groups WHERE id = ?", + (state_group,) + ) + row = txn.fetchone() + return row and row[0] + def _store_mult_state_groups_txn(self, txn, events_and_contexts): state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): continue - if context.current_state is None: - continue - - if context.state_group is not None: - state_groups[event.event_id] = context.state_group + if context.current_state_ids is None: continue - state_events = dict(context.current_state) + state_groups[event.event_id] = context.state_group - if event.is_state(): - state_events[(event.type, event.state_key)] = event + if self._have_persisted_state_group_txn(txn, context.state_group): + logger.info("Already persisted state_group: %r", context.state_group) + continue - state_group = context.new_state_group_id + state_event_ids = dict(context.current_state_ids) self._simple_insert_txn( txn, table="state_groups", values={ - "id": state_group, + "id": context.state_group, "room_id": event.room_id, "event_id": event.event_id, }, @@ -99,16 +123,15 @@ class StateStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": state_group, - "room_id": state.room_id, - "type": state.type, - "state_key": state.state_key, - "event_id": state.event_id, + "state_group": context.state_group, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, } - for state in state_events.values() + for key, state_id in state_event_ids.items() ], ) - state_groups[event.event_id] = state_group self._simple_insert_many_txn( txn, @@ -248,6 +271,31 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups, types) + state_event_map = yield self.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, types): + event_to_groups = yield self._get_state_group_for_events( + event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + event_to_state = { event_id: group_to_state[group] for event_id, group in event_to_groups.items() @@ -272,6 +320,23 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, types=None): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], types) + defer.returnValue(state_map[event_id]) + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( @@ -428,20 +493,13 @@ class StateStore(SQLBaseStore): full=(types is None), ) - state_events = yield self._get_events( - [ev_id for sd in results.values() for ev_id in sd.values()], - get_prev_content=False - ) - - state_events = {e.event_id: e for e in state_events} - # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. for group, state_dict in results.items(): results[group] = { - key: state_events[event_id] + key: event_id for key, event_id in state_dict.items() - if event_id and event_id in state_events + if event_id } defer.returnValue(results) @@ -473,5 +531,5 @@ class StateStore(SQLBaseStore): "get_all_new_state_groups", get_all_new_state_groups_txn ) - def get_state_stream_token(self): - return self._state_groups_id_gen.get_current_token() + def get_next_state_group(self): + return self._state_groups_id_gen.get_next() diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 58d4de4f1d..5055c04b24 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -245,7 +245,7 @@ class TransactionStore(SQLBaseStore): return self.cursor_to_dict(txn) - @cached() + @cached(max_entries=10000) def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore): def _cleanup_transactions(self): now = self._clock.time_msec() month_ago = now - 30 * 24 * 60 * 60 * 1000 + six_hours_ago = now - 6 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,)) return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index d4c0bb6732..6bf21d6f5e 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -43,6 +43,7 @@ class EventSources(object): @defer.inlineCallbacks def get_current_token(self, direction='f'): push_rules_key, _ = self.store.get_push_rules_stream_token() + to_device_key = self.store.get_to_device_stream_token() token = StreamToken( room_key=( @@ -61,5 +62,6 @@ class EventSources(object): yield self.sources["account_data"].get_current_key() ), push_rules_key=push_rules_key, + to_device_key=to_device_key, ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index fd17ecbbe0..9d64e8c4de 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -154,6 +154,7 @@ class StreamToken( "receipt_key", "account_data_key", "push_rules_key", + "to_device_key", )) ): _SEPARATOR = "_" @@ -190,6 +191,7 @@ class StreamToken( or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.push_rules_key) < int(self.push_rules_key)) + or (int(other.to_device_key) < int(self.to_device_key)) ) def copy_and_advance(self, key, new_value): @@ -269,10 +271,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) - - -# Some arbitrary constants used for internal API enumerations. Don't rely on -# exact values; always pass or compare symbolically -class ThirdPartyEntityKind(object): - USER = 'user' - LOCATION = 'location' diff --git a/synapse/visibility.py b/synapse/visibility.py index cc12c0a23d..199b16d827 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -181,6 +181,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): @defer.inlineCallbacks +def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context): + user_ids = set(u[0] for u in user_tuples) + event_id_to_state = {} + for event_id, context in event_id_to_context.items(): + state = yield store.get_events([ + e_id + for key, e_id in context.current_state_ids.iteritems() + if key == (EventTypes.RoomHistoryVisibility, "") + or (key[0] == EventTypes.Member and key[1] in user_ids) + ]) + event_id_to_state[event_id] = state + + res = yield filter_events_for_clients( + store, user_tuples, events, event_id_to_state + ) + defer.returnValue(res) + + +@defer.inlineCallbacks def filter_events_for_client(store, user_id, events, is_peeking=False): """ Check which events a user is allowed to see diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index b531ba8540..d9e8f634ae 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase): ), ], any_order=True) + def test_online_to_online_last_active_noop(self): + wheel_timer = Mock() + user_id = "@foo:bar" + now = 5000000 + + prev_state = UserPresenceState.default(user_id) + prev_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10, + currently_active=True, + ) + + new_state = prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=now, + ) + + state, persist_and_notify, federation_ping = handle_update( + prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now + ) + + self.assertFalse(persist_and_notify) + self.assertTrue(federation_ping) + self.assertTrue(state.currently_active) + self.assertEquals(new_state.state, state.state) + self.assertEquals(new_state.status_msg, state.status_msg) + self.assertEquals(state.last_federation_update_ts, now) + + self.assertEquals(wheel_timer.insert.call_count, 3) + wheel_timer.insert.assert_has_calls([ + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + IDLE_TIMER + ), + call( + now=now, + obj=user_id, + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT + ), + call( + now=now, + obj=user_id, + then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY + ), + ], any_order=True) + def test_online_to_online_last_active(self): wheel_timer = Mock() user_id = "@foo:bar" diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ab9899b7d5..b2957eef9f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event = mock_notifier.on_new_event self.auth = Mock(spec=[]) + self.state_handler = Mock() hs = yield setup_test_homeserver( "test", @@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "set_received_txn_response", "get_destination_retry_timings", ]), + state_handler=self.state_handler, handlers=None, notifier=mock_notifier, resource_for_client=Mock(), @@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase): return set(member.domain for member in self.room_members) self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room + def get_current_user_in_room(room_id): + return set(str(u) for u in self.room_members) + self.state_handler.get_current_user_in_room = get_current_user_in_room + self.auth.check_joined_room = check_joined_room # Some local users to test with diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index f33e6f60fb..44e859b5d1 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -305,7 +305,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.event_id += 1 - context = EventContext(current_state=state) + if state is not None: + state_ids = { + key: e.event_id for key, e in state.items() + } + else: + state_ids = None + + context = EventContext() + context.current_state_ids = state_ids + context.prev_state_ids = state_ids context.push_actions = push_actions ordering = None diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index e70ac6f14d..b69832cc1b 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body, {}) @defer.inlineCallbacks - def test_events_and_state(self): - get = self.get(events="-1", state="-1", timeout="0") + def test_events(self): + get = self.get(events="-1", timeout="0") yield self.hs.get_handlers().room_creation_handler.create_room( synapse.types.create_requester(self.user), {} ) @@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase): self.assertEquals(body["events"]["field_names"], [ "position", "internal", "json", "state_group" ]) - self.assertEquals(body["state_groups"]["field_names"], [ - "position", "room_id", "event_id" - ]) - self.assertEquals(body["state_group_state"]["field_names"], [ - "position", "type", "state_key", "event_id" - ]) @defer.inlineCallbacks def test_presence(self): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 8853cbb5fc..4fe99ebc0b 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_topo_token_is_accepted(self): - token = "t1-0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) @@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): - token = "s0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0" (code, response) = yield self.mock_resource.trigger_get( "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 27b2b3d123..1be7d932f6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase): ) )] ) - - @defer.inlineCallbacks - def test_room_hosts(self): - yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) - - self.assertEquals( - {"test"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should still have just one host after second join from it - yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN) - - self.assertEquals( - {"test"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should now have two hosts after join from other host - yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN) - - self.assertEquals( - {"test", "elsewhere"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should still have both hosts - yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE) - - self.assertEquals( - {"test", "elsewhere"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should have only one host after other leaves - yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE) - - self.assertEquals( - {"test"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) diff --git a/tests/test_state.py b/tests/test_state.py index 1a11bbcee0..6454f994e3 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -67,9 +67,11 @@ class StateGroupStore(object): self._event_to_state_group = {} self._group_to_state = {} + self._event_id_to_event = {} + self._next_group = 1 - def get_state_groups(self, room_id, event_ids): + def get_state_groups_ids(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) @@ -79,22 +81,23 @@ class StateGroupStore(object): return defer.succeed(groups) def store_state_groups(self, event, context): - if context.current_state is None: + if context.current_state_ids is None: return - state_events = context.current_state - - if event.is_state(): - state_events[(event.type, event.state_key)] = event + state_events = dict(context.current_state_ids) - state_group = context.state_group - if not state_group: - state_group = self._next_group - self._next_group += 1 + self._group_to_state[context.state_group] = state_events + self._event_to_state_group[event.event_id] = context.state_group - self._group_to_state[state_group] = state_events.values() + def get_events(self, event_ids, **kwargs): + return { + e_id: self._event_id_to_event[e_id] for e_id in event_ids + if e_id in self._event_id_to_event + } - self._event_to_state_group[event.event_id] = state_group + def register_events(self, events): + for e in events: + self._event_id_to_event[e.event_id] = e class DictObj(dict): @@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase): def setUp(self): self.store = Mock( spec_set=[ - "get_state_groups", + "get_state_groups_ids", "add_event_hashes", + "get_events", + "get_next_state_group", ] ) hs = Mock(spec_set=[ @@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + self.store.get_next_state_group.side_effect = Mock + self.state = StateHandler(hs) self.event_id = 0 @@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids context_store = {} @@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state)) + self.assertEqual(2, len(context_store["D"].prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e.event_id for e in context_store["D"].current_state.values()} + {e_id for e_id in context_store["D"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e.event_id for e in context_store["E"].current_state.values()} + {e for e in context_store["E"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase): graph = Graph(nodes, edges) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e.event_id for e in context_store["D"].current_state.values()} + {e for e in context_store["D"].prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( - set(old_state), set(context.current_state.values()) + set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( - set(old_state), - set(context.current_state.values()) + set(e.event_id for e in old_state), set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) - @defer.inlineCallbacks def test_trivial_annotate_message(self): event = create_event(type="test_message", name="event") @@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase): group_name = "group_name_1" - self.store.get_state_groups.return_value = { - group_name: old_state, + self.store.get_state_groups_ids.return_value = { + group_name: {(e.type, e.state_key): e.event_id for e in old_state}, } context = yield self.state.compute_event_context(event) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in context.current_state.values()]) + set(context.current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase): group_name = "group_name_1" - self.store.get_state_groups.return_value = { - group_name: old_state, + self.store.get_state_groups_ids.return_value = { + group_name: {(e.type, e.state_key): e.event_id for e in old_state}, } context = yield self.state.compute_event_context(event) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in context.current_state.values()]) + set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): @@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 6) + self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): @@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 6) + self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): @@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) + self.assertEqual( + old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + ) # Reverse the depth to make sure we are actually using the depths # during state resolution. @@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=1), ] + store.register_events(old_state_1) + store.register_events(old_state_2) + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) + self.assertEqual( + old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + ) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" group_name_2 = "group_name_2" - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, + self.store.get_state_groups_ids.return_value = { + group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, + group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, } return self.state.compute_event_context(event) |