diff options
Diffstat (limited to 'synapse/storage')
25 files changed, 1496 insertions, 499 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 23b4a8d76d..53c685c173 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .directory import DirectoryStore +from .e2e_room_keys import EndToEndRoomKeyStore from .end_to_end_keys import EndToEndKeyStore from .engines import PostgresEngine from .event_federation import EventFederationStore @@ -77,6 +78,7 @@ class DataStore(RoomMemberStore, RoomStore, ApplicationServiceTransactionStore, ReceiptsStore, EndToEndKeyStore, + EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 08dffd774f..d9d0255d0b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,9 +17,10 @@ import sys import threading import time -from six import iteritems, iterkeys, itervalues -from six.moves import intern, range +from six import PY2, iteritems, iterkeys, itervalues +from six.moves import builtins, intern, range +from canonicaljson import json from prometheus_client import Histogram from twisted.internet import defer @@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception): something went wrong. """ pass + + +def db_to_json(db_content): + """ + Take some data from a database row and return a JSON-decoded object. + + Args: + db_content (memoryview|buffer|bytes|bytearray|unicode) + """ + # psycopg2 on Python 3 returns memoryview objects, which we need to + # cast to bytes to decode + if isinstance(db_content, memoryview): + db_content = db_content.tobytes() + + # psycopg2 on Python 2 returns buffer objects, which we need to cast to + # bytes to decode + if PY2 and isinstance(db_content, builtins.buffer): + db_content = bytes(db_content) + + # Decode it to a Unicode string before feeding it to json.loads, so we + # consistenty get a Unicode-containing object out. + if isinstance(db_content, (bytes, bytearray)): + db_content = db_content.decode('utf8') + + try: + return json.loads(db_content) + except Exception: + logging.warning("Tried to decode '%r' as JSON and failed", db_content) + raise diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 8fc678fa67..9ad17b7c25 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -119,21 +119,25 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): for entry in iteritems(to_update): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry - self._simple_upsert_txn( - txn, - table="user_ips", - keyvalues={ - "user_id": user_id, - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": last_seen, - }, - lock=False, - ) + try: + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + except Exception as e: + # Failed to upsert, log and continue + logger.error("Failed to insert client IP %r: %r", entry, e) @defer.inlineCallbacks def get_last_client_ip_by_device(self, user_id, device_id): diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 73646da025..e06b0bc56d 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore): local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} - devices = messages_by_device.keys() + devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. sql = ( diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index c0943ecf91..d10ff9e4b9 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -24,7 +24,7 @@ from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, SQLBaseStore +from ._base import Cache, SQLBaseStore, db_to_json logger = logging.getLogger(__name__) @@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore): if device is not None: key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name @@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore): retcol="content", desc="_get_cached_user_device", ) - defer.returnValue(json.loads(content)) + defer.returnValue(db_to_json(content)) @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): @@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore): desc="_get_cached_devices_for_user", ) defer.returnValue({ - device["device_id"]: json.loads(device["content"]) + device["device_id"]: db_to_json(device["content"]) for device in devices }) @@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore): key_json = device.get("key_json", None) if key_json: - result["keys"] = json.loads(key_json) + result["keys"] = db_to_json(key_json) device_display_name = device.get("device_display_name", None) if device_display_name: result["device_display_name"] = device_display_name diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 808194236a..61a029a53c 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -75,7 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): }, retcol="creator", desc="get_room_alias_creator", - allow_none=True ) @cached(max_entries=5000) @@ -91,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): - """ Creates an associatin between a room alias and room_id/servers + """ Creates an association between a room alias and room_id/servers Args: room_alias (RoomAlias) diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py new file mode 100644 index 0000000000..f25ded2295 --- /dev/null +++ b/synapse/storage/e2e_room_keys.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.internet import defer + +from synapse.api.errors import StoreError + +from ._base import SQLBaseStore + + +class EndToEndRoomKeyStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_e2e_room_key(self, user_id, version, room_id, session_id): + """Get the encrypted E2E room key for a given session from a given + backup version of room_keys. We only store the 'best' room key for a given + session at a given time, as determined by the handler. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): the ID of the room whose keys we're querying. + This is a bit redundant as it's implied by the session_id, but + we include for consistency with the rest of the API. + session_id(str): the session whose room_key we're querying. + + Returns: + A deferred dict giving the session_data and message metadata for + this room key. + """ + + row = yield self._simple_select_one( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + }, + retcols=( + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_key", + ) + + row["session_data"] = json.loads(row["session_data"]) + + defer.returnValue(row) + + @defer.inlineCallbacks + def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces or inserts the encrypted E2E room key for a given session in + a given backup + + Args: + user_id(str): the user whose backup we're setting + version(str): the version ID of the backup we're updating + room_id(str): the ID of the room whose keys we're setting + session_id(str): the session whose room_key we're setting + room_key(dict): the room_key being set + Raises: + StoreError + """ + + yield self._simple_upsert( + table="e2e_room_keys", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "session_id": session_id, + }, + values={ + "version": version, + "first_message_index": room_key['first_message_index'], + "forwarded_count": room_key['forwarded_count'], + "is_verified": room_key['is_verified'], + "session_data": json.dumps(room_key['session_data']), + }, + lock=False, + ) + + @defer.inlineCallbacks + def get_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): + """Bulk get the E2E room keys for a given backup, optionally filtered to a given + room, or a given session. + + Args: + user_id(str): the user whose backup we're querying + version(str): the version ID of the backup for the set of keys we're querying + room_id(str): Optional. the ID of the room whose keys we're querying, if any. + If not specified, we return the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we return all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred list of dicts giving the session_data and message metadata for + these room keys. + """ + + keyvalues = { + "user_id": user_id, + "version": version, + } + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id + + rows = yield self._simple_select_list( + table="e2e_room_keys", + keyvalues=keyvalues, + retcols=( + "user_id", + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + desc="get_e2e_room_keys", + ) + + sessions = {'rooms': {}} + for row in rows: + room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}}) + room_entry['sessions'][row['session_id']] = { + "first_message_index": row["first_message_index"], + "forwarded_count": row["forwarded_count"], + "is_verified": row["is_verified"], + "session_data": json.loads(row["session_data"]), + } + + defer.returnValue(sessions) + + @defer.inlineCallbacks + def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): + """Bulk delete the E2E room keys for a given backup, optionally filtered to a given + room or a given session. + + Args: + user_id(str): the user whose backup we're deleting from + version(str): the version ID of the backup for the set of keys we're deleting + room_id(str): Optional. the ID of the room whose keys we're deleting, if any. + If not specified, we delete the keys for all the rooms in the backup. + session_id(str): Optional. the session whose room_key we're querying, if any. + If specified, we also require the room_id to be specified. + If not specified, we delete all the keys in this version of + the backup (or for the specified room) + + Returns: + A deferred of the deletion transaction + """ + + keyvalues = { + "user_id": user_id, + "version": version, + } + if room_id: + keyvalues['room_id'] = room_id + if session_id: + keyvalues['session_id'] = session_id + + yield self._simple_delete( + table="e2e_room_keys", + keyvalues=keyvalues, + desc="delete_e2e_room_keys", + ) + + @staticmethod + def _get_current_version(txn, user_id): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions " + "WHERE user_id=? AND deleted=0", + (user_id,) + ) + row = txn.fetchone() + if not row: + raise StoreError(404, 'No current backup version') + return row[0] + + def get_e2e_room_keys_version_info(self, user_id, version=None): + """Get info metadata about a version of our room_keys backup. + + Args: + user_id(str): the user whose backup we're querying + version(str): Optional. the version ID of the backup we're querying about + If missing, we return the information about the current version. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present + Returns: + A deferred dict giving the info metadata for this backup version + """ + + def _get_e2e_room_keys_version_info_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + result = self._simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=( + "version", + "algorithm", + "auth_data", + ), + ) + result["auth_data"] = json.loads(result["auth_data"]) + return result + + return self.runInteraction( + "get_e2e_room_keys_version_info", + _get_e2e_room_keys_version_info_txn + ) + + def create_e2e_room_keys_version(self, user_id, info): + """Atomically creates a new version of this user's e2e_room_keys store + with the given version info. + + Args: + user_id(str): the user whose backup we're creating a version + info(dict): the info about the backup version to be created + + Returns: + A deferred string for the newly created version ID + """ + + def _create_e2e_room_keys_version_txn(txn): + txn.execute( + "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", + (user_id,) + ) + current_version = txn.fetchone()[0] + if current_version is None: + current_version = '0' + + new_version = str(int(current_version) + 1) + + self._simple_insert_txn( + txn, + table="e2e_room_keys_versions", + values={ + "user_id": user_id, + "version": new_version, + "algorithm": info["algorithm"], + "auth_data": json.dumps(info["auth_data"]), + }, + ) + + return new_version + + return self.runInteraction( + "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn + ) + + def delete_e2e_room_keys_version(self, user_id, version=None): + """Delete a given backup version of the user's room keys. + Doesn't delete their actual key data. + + Args: + user_id(str): the user whose backup version we're deleting + version(str): Optional. the version ID of the backup version we're deleting + If missing, we delete the current backup version info. + Raises: + StoreError: with code 404 if there are no e2e_room_keys_versions present, + or if the version requested doesn't exist. + """ + + def _delete_e2e_room_keys_version_txn(txn): + if version is None: + this_version = self._get_current_version(txn, user_id) + else: + this_version = version + + return self._simple_update_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + }, + updatevalues={ + "deleted": 1, + } + ) + + return self.runInteraction( + "delete_e2e_room_keys_version", + _delete_e2e_room_keys_version_txn + ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 523b4360c3..1f1721e820 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,13 +14,13 @@ # limitations under the License. from six import iteritems -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.util.caches.descriptors import cached -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class EndToEndKeyStore(SQLBaseStore): @@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore): for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): - device_info["keys"] = json.loads(device_info.pop("key_json")) + device_info["keys"] = db_to_json(device_info.pop("key_json")) defer.returnValue(results) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 8a0386c1a4..42225f8a2a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -41,13 +41,18 @@ class PostgresEngine(object): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) + + # Set the bytea output to escape, vs the default of hex + cursor = db_conn.cursor() + cursor.execute("SET bytea_output TO escape") + # Asynchronous commit, don't wait for the server to call fsync before # ending the transaction. # https://www.postgresql.org/docs/current/static/wal-async-commit.html if not self.synchronous_commit: - cursor = db_conn.cursor() cursor.execute("SET synchronous_commit TO OFF") - cursor.close() + + cursor.close() def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 24345b20a6..3faca2a042 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -376,33 +376,25 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, - limit, min_depth): + limit): ids = yield self.runInteraction( "get_missing_events", self._get_missing_events, - room_id, earliest_events, latest_events, limit, min_depth + room_id, earliest_events, latest_events, limit, ) - events = yield self._get_events(ids) - - events = sorted( - [ev for ev in events if ev.depth >= min_depth], - key=lambda e: e.depth, - ) - - defer.returnValue(events[:limit]) + defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, - limit, min_depth): - - earliest_events = set(earliest_events) - front = set(latest_events) - earliest_events + limit): - event_results = set() + seen_events = set(earliest_events) + front = set(latest_events) - seen_events + event_results = [] query = ( "SELECT prev_event_id FROM event_edges " - "WHERE event_id = ? AND is_state = ? " + "WHERE room_id = ? AND event_id = ? AND is_state = ? " "LIMIT ?" ) @@ -411,18 +403,20 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, for event_id in front: txn.execute( query, - (event_id, False, limit - len(event_results)) + (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn: - new_front.add(e_id) + new_results = set(t[0] for t in txn) - seen_events - new_front -= earliest_events - new_front -= event_results + new_front |= new_results + seen_events |= new_results + event_results.extend(new_results) front = new_front - event_results |= new_front + # we built the list working backwards from latest_events; we now need to + # reverse it so that the events are approximately chronological. + event_results.reverse() return event_results diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 025a7fb6d9..8881b009df 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,7 @@ import logging from collections import OrderedDict, deque, namedtuple from functools import wraps -from six import iteritems +from six import iteritems, text_type from six.moves import range from canonicaljson import json @@ -34,10 +34,12 @@ from synapse.api.errors import SynapseError from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.state import StateResolutionStore from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.events_worker import EventsWorkerStore from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.util import batch_iter from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.frozenutils import frozendict_json_encoder @@ -386,12 +388,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) for room_id, ev_ctx_rm in iteritems(events_by_room): - # Work out new extremities by recursively adding and removing - # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( room_id ) - new_latest_event_ids = yield self._calculate_new_extremeties( + new_latest_event_ids = yield self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids ) @@ -400,6 +400,12 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # No change in extremities, so no change in state continue + # there should always be at least one forward extremity. + # (except during the initial persistence of the send_join + # results, in which case there will be no existing + # extremities, so we'll `continue` above and skip this bit.) + assert new_latest_event_ids, "No forward extremities left!" + new_forward_extremeties[room_id] = new_latest_event_ids len_1 = ( @@ -517,44 +523,79 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): - """Calculates the new forward extremeties for a room given events to + def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids): + """Calculates the new forward extremities for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - new_latest_event_ids = set(latest_event_ids) - # First, add all the new events to the list - new_latest_event_ids.update( - event.event_id for event, ctx in event_contexts + + # we're only interested in new events which aren't outliers and which aren't + # being rejected. + new_events = [ + event for event, ctx in event_contexts if not event.internal_metadata.is_outlier() and not ctx.rejected + ] + + # start with the existing forward extremities + result = set(latest_event_ids) + + # add all the new events to the list + result.update( + event.event_id for event in new_events ) - # Now remove all events that are referenced by the to-be-added events - new_latest_event_ids.difference_update( + + # Now remove all events which are prev_events of any of the new events + result.difference_update( e_id - for event, ctx in event_contexts + for event in new_events for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() and not ctx.rejected ) - # And finally remove any events that are referenced by previously added - # events. - rows = yield self._simple_select_many_batch( - table="event_edges", - column="prev_event_id", - iterable=list(new_latest_event_ids), - retcols=["prev_event_id"], - keyvalues={ - "is_state": False, - }, - desc="_calculate_new_extremeties", - ) + # Finally, remove any events which are prev_events of any existing events. + existing_prevs = yield self._get_events_which_are_prevs(result) + result.difference_update(existing_prevs) - new_latest_event_ids.difference_update( - row["prev_event_id"] for row in rows - ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _get_events_which_are_prevs(self, event_ids): + """Filter the supplied list of event_ids to get those which are prev_events of + existing (non-outlier/rejected) events. + + Args: + event_ids (Iterable[str]): event ids to filter + + Returns: + Deferred[List[str]]: filtered event ids + """ + results = [] + + def _get_events(txn, batch): + sql = """ + SELECT prev_event_id + FROM event_edges + INNER JOIN events USING (event_id) + LEFT JOIN rejections USING (event_id) + WHERE + prev_event_id IN (%s) + AND NOT events.outlier + AND rejections.event_id IS NULL + """ % ( + ",".join("?" for _ in batch), + ) + + txn.execute(sql, batch) + results.extend(r[0] for r in txn) + + for chunk in batch_iter(event_ids, 100): + yield self.runInteraction( + "_get_events_which_are_prevs", + _get_events, + chunk, + ) - defer.returnValue(new_latest_event_ids) + defer.returnValue(results) @defer.inlineCallbacks def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, @@ -586,10 +627,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore the new current state is only returned if we've already calculated it. """ - - if not new_latest_event_ids: - return - # map from state_group to ((type, key) -> event_id) state map state_groups_map = {} @@ -695,19 +732,17 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Ok, we need to defer to the state handler to resolve our state sets. - def get_events(ev_ids): - return self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ) - state_groups = { sg: state_groups_map[sg] for sg in new_state_groups } events_map = {ev.event_id: ev for ev, _ in events_context} + room_version = yield self.get_room_version(room_id) + logger.debug("calling resolve_state_groups from preserve_events") res = yield self._state_resolution_handler.resolve_state_groups( - room_id, state_groups, events_map, get_events + room_id, room_version, state_groups, events_map, + state_res_store=StateResolutionStore(self) ) defer.returnValue((res.state, None)) @@ -816,6 +851,27 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) + # We want to store event_auth mappings for rejected events, as they're + # used in state res v2. + # This is only necessary if the rejected event appears in an accepted + # event's auth chain, but its easier for now just to store them (and + # it doesn't take much storage compared to storing the entire event + # anyway). + self._simple_insert_many_txn( + txn, + table="event_auth", + values=[ + { + "event_id": event.event_id, + "room_id": event.room_id, + "auth_id": auth_id, + } + for event, _ in events_and_contexts + for auth_id, _ in event.auth_events + if event.is_state() + ], + ) + # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. events_and_contexts = self._store_rejected_events_txn( @@ -928,6 +984,10 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ) self._invalidate_cache_and_stream( + txn, self.get_room_summary, (room_id,) + ) + + self._invalidate_cache_and_stream( txn, self.get_current_state_ids, (room_id,) ) @@ -1218,7 +1278,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore "sender": event.sender, "contains_url": ( "url" in event.content - and isinstance(event.content["url"], basestring) + and isinstance(event.content["url"], text_type) ), } for event, _ in events_and_contexts @@ -1287,21 +1347,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore txn, event.room_id, event.redacts ) - self._simple_insert_many_txn( - txn, - table="event_auth", - values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } - for event, _ in events_and_contexts - for auth_id, _ in event.auth_events - if event.is_state() - ], - ) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -1527,7 +1572,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore contains_url = "url" in content if contains_url: - contains_url &= isinstance(content["url"], basestring) + contains_url &= isinstance(content["url"], text_type) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -1884,20 +1929,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore ")" ) - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)", - ) - - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute( - "CREATE INDEX events_to_purge_id" - " ON events_to_purge(event_id)", - ) - # First ensure that we're not about to delete all the forward extremeties txn.execute( "SELECT e.event_id, e.depth FROM events as e " @@ -1908,9 +1939,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore (room_id,) ) rows = txn.fetchall() - max_depth = max(row[0] for row in rows) + max_depth = max(row[1] for row in rows) - if max_depth <= token.topological: + if max_depth < token.topological: # We need to ensure we don't delete all the events from the database # otherwise we wouldn't be able to send any events (due to not # having any backwards extremeties) @@ -1924,19 +1955,45 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore should_delete_params = () if not delete_local_events: should_delete_expr += " AND event_id NOT LIKE ?" - should_delete_params += ("%:" + self.hs.hostname, ) + + # We include the parameter twice since we use the expression twice + should_delete_params += ( + "%:" + self.hs.hostname, + "%:" + self.hs.hostname, + ) should_delete_params += (room_id, token.topological) + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. txn.execute( "INSERT INTO events_to_purge" " SELECT event_id, %s" " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE e.room_id = ? AND topological_ordering < ?" % ( + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" + % ( + should_delete_expr, should_delete_expr, ), should_delete_params, ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)", + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute( + "CREATE INDEX events_to_purge_id" + " ON events_to_purge(event_id)", + ) + txn.execute( "SELECT event_id, should_delete FROM events_to_purge" ) @@ -2032,7 +2089,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore for sg in remaining_state_groups: logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [sg], types=None + txn, [sg], ) curr_state = curr_state[sg] diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 59822178ff..a8326f5296 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import itertools import logging from collections import namedtuple @@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - event_id_lists = zip(*event_list)[0] + event_id_lists = list(zip(*event_list))[0] event_ids = [ item for sublist in event_id_lists for item in sublist ] @@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore): logger.exception("do_fetch") # We only want to resolve deferreds from the main thread - def fire(evs): + def fire(evs, exc): for _, d in evs: if not d.called: with PreserveLoggingContext(): - d.errback(e) + d.errback(exc) with PreserveLoggingContext(): - self.hs.get_reactor().callFromThread(fire, event_list) + self.hs.get_reactor().callFromThread(fire, event_list, e) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 2d5896c5b4..6ddcc909bf 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.util.caches.descriptors import cachedInlineCallbacks -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json class FilteringStore(SQLBaseStore): @@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore): desc="get_user_filter", ) - defer.returnValue(json.loads(bytes(def_json).decode("utf-8"))) + defer.returnValue(db_to_json(def_json)) def add_user_filter(self, user_localpart, user_filter): def_json = encode_canonical_json(user_filter) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index f547977600..8af17921e3 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview @@ -134,6 +134,7 @@ class KeyStore(SQLBaseStore): """ key_id = "%s:%s" % (verify_key.alg, verify_key.version) + # XXX fix this to not need a lock (#3819) def _txn(txn): self._simple_upsert_txn( txn, diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py index 06f9a75a97..cf4104dc2e 100644 --- a/synapse/storage/monthly_active_users.py +++ b/synapse/storage/monthly_active_users.py @@ -33,20 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs self.reserved_users = () + # Do not add more reserved users than the total allowable number + self._initialise_reserved_users( + dbconn.cursor(), + hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value], + ) - @defer.inlineCallbacks - def initialise_reserved_users(self, threepids): - # TODO Why can't I do this in init? - store = self.hs.get_datastore() + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ reserved_user_list = [] - # Do not add more reserved users than the total allowable number - for tp in threepids[:self.hs.config.max_mau_value]: - user_id = yield store.get_user_id_by_threepid( + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn( + txn, tp["medium"], tp["address"] ) if user_id: - yield self.upsert_monthly_active_user(user_id) + self.upsert_monthly_active_user_txn(txn, user_id) reserved_user_list.append(user_id) else: logger.warning( @@ -56,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): @defer.inlineCallbacks def reap_monthly_active_users(self): - """ - Cleans out monthly active user table to ensure that no stale + """Cleans out monthly active user table to ensure that no stale entries exist. Returns: @@ -147,16 +155,63 @@ class MonthlyActiveUsersStore(SQLBaseStore): return count return self.runInteraction("count_users", _count_users) + @defer.inlineCallbacks + def get_registered_reserved_users_count(self): + """Of the reserved threepids defined in config, how many are associated + with registered users? + + Returns: + Defered[int]: Number of real reserved users + """ + count = 0 + for tp in self.hs.config.mau_limits_reserved_threepids: + user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + count = count + 1 + defer.returnValue(count) + + @defer.inlineCallbacks def upsert_monthly_active_user(self, user_id): + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: + user_id (str): user to add/update """ - Updates or inserts monthly active user member - Arguments: - user_id (str): user to add/update - Deferred[bool]: True if a new entry was created, False if an - existing one was updated. + is_insert = yield self.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, + user_id + ) + + if is_insert: + self.user_last_seen_monthly_active.invalidate((user_id,)) + self.get_monthly_active_count.invalidate(()) + + def upsert_monthly_active_user_txn(self, txn, user_id): + """Updates or inserts monthly active user member + + Note that, after calling this method, it will generally be necessary + to invalidate the caches on user_last_seen_monthly_active and + get_monthly_active_count. We can't do that here, because we are running + in a database thread rather than the main thread, and we can't call + txn.call_after because txn may not be a LoggingTransaction. + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: + bool: True if a new entry was created, False if an + existing one was updated. """ - is_insert = self._simple_upsert( - desc="upsert_monthly_active_user", + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more + is_insert = self._simple_upsert_txn( + txn, table="monthly_active_users", keyvalues={ "user_id": user_id, @@ -164,11 +219,9 @@ class MonthlyActiveUsersStore(SQLBaseStore): values={ "timestamp": int(self._clock.time_msec()), }, - lock=False, ) - if is_insert: - self.user_last_seen_monthly_active.invalidate((user_id,)) - self.get_monthly_active_count.invalidate(()) + + return is_insert @cached(num_args=1) def user_last_seen_monthly_active(self, user_id): @@ -199,7 +252,16 @@ class MonthlyActiveUsersStore(SQLBaseStore): Args: user_id(str): the user_id to query """ + if self.hs.config.limit_usage_by_mau: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return + is_trial = yield self.is_trial_user(user_id) + if is_trial: + return + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index e6848c70a0..10133f0a4a 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -123,8 +123,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - -class ProfileStore(ProfileWorkerStore): def set_profile_displayname(self, user_localpart, new_displayname, batchnum): return self._simple_upsert( table="profiles", @@ -165,6 +163,8 @@ class ProfileStore(ProfileWorkerStore): lock=False # we can do this because user_id has a unique index ) + +class ProfileStore(ProfileWorkerStore): def add_remote_profile_cache(self, user_id, displayname, avatar_url): """Ensure we are caching the remote user's profiles. diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8443bd4c1b..2743b52bad 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import types + +import six from canonicaljson import encode_canonical_json, json @@ -27,6 +28,11 @@ from ._base import SQLBaseStore logger = logging.getLogger(__name__) +if six.PY2: + db_binary_type = six.moves.builtins.buffer +else: + db_binary_type = memoryview + class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): @@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore): dataJson = r['data'] r['data'] = None try: - if isinstance(dataJson, types.BufferType): + if isinstance(dataJson, db_binary_type): dataJson = str(dataJson).decode("UTF8") r['data'] = json.loads(dataJson) except Exception as e: logger.warn( "Invalid JSON in data for pusher %d: %s, %s", - r['id'], dataJson, e.message, + r['id'], dataJson, e.args[0], ) pass - if isinstance(r['pushkey'], types.BufferType): + if isinstance(r['pushkey'], db_binary_type): r['pushkey'] = str(r['pushkey']).decode("UTF8") return rows diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 50706519aa..65061f4c61 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(db_conn, hs) + + self.config = hs.config + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( @@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore): retcols=[ "name", "password_hash", "is_guest", "consent_version", "consent_server_notice_sent", - "appservice_id", + "appservice_id", "creation_ts", ], allow_none=True, desc="get_user_by_id", ) + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + defer.returnValue(False) + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + defer.returnValue(is_trial) + @cached() def get_user_by_access_token(self, token): """Get a user from the given access token. @@ -436,17 +462,44 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def get_user_id_by_threepid(self, medium, address): - ret = yield self._simple_select_one( + """Returns user id from threepid + + Args: + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + Deferred[str|None]: user id or None if no user id/threepid mapping exists + """ + user_id = yield self.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, + medium, address + ) + defer.returnValue(user_id) + + def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ + ret = self._simple_select_one_txn( + txn, "user_threepids", { "medium": medium, "address": address }, - ['user_id'], True, 'get_user_id_by_threepid' + ['user_id'], True ) if ret: - defer.returnValue(ret['user_id']) - defer.returnValue(None) + return ret['user_id'] + return None def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( @@ -529,7 +582,7 @@ class RegistrationStore(RegistrationWorkerStore, def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - regex = re.compile("^@(\d+):") + regex = re.compile(r"^@(\d+):") found = set() diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 3378fc77d1..61013b8919 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user + + Args: + user_id (str) + + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) + + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) + else: + defer.returnValue(None) + class RoomStore(RoomWorkerStore, SearchStore): @@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore): "get_all_new_public_rooms", get_all_new_public_rooms ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): - """Check if there are any overrides for ratelimiting for the given - user - - Args: - user_id (str) - - Returns: - RatelimitOverride if there is an override, else None. If the contents - of RatelimitOverride are None or 0 then ratelimitng has been - disabled for that user entirely. - """ - row = yield self._simple_select_one( - table="ratelimit_override", - keyvalues={"user_id": user_id}, - retcols=("messages_per_second", "burst_count"), - allow_none=True, - desc="get_ratelimit_for_user", - ) - - if row: - defer.returnValue(RatelimitOverride( - messages_per_second=row["messages_per_second"], - burst_count=row["burst_count"], - )) - else: - defer.returnValue(None) - @defer.inlineCallbacks def block_room(self, room_id, user_id): yield self._simple_insert( diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9b4e6d6aa8..0707f9a86a 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -51,6 +51,12 @@ ProfileInfo = namedtuple( "ProfileInfo", ("avatar_url", "display_name") ) +# "members" points to a truncated list of (user_id, event_id) tuples for users of +# a given membership type, suitable for use in calculating heroes for a room. +# "count" points to the total numberr of users of a given membership type. +MemberSummary = namedtuple( + "MemberSummary", ("members", "count") +) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" @@ -82,6 +88,65 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) + @cached(max_entries=100000) + def get_room_summary(self, room_id): + """ Get the details of a room roughly suitable for use by the room + summary extension to /sync. Useful when lazy loading room members. + Args: + room_id (str): The room ID to query + Returns: + Deferred[dict[str, MemberSummary]: + dict of membership states, pointing to a MemberSummary named tuple. + """ + + def _get_room_summary_txn(txn): + # first get counts. + # We do this all in one transaction to keep the cache small. + # FIXME: get rid of this when we have room_stats + sql = """ + SELECT count(*), m.membership FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + GROUP BY m.membership + """ + + txn.execute(sql, (room_id,)) + res = {} + for count, membership in txn: + summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + + # we order by membership and then fairly arbitrarily by event_id so + # heroes are consistent + sql = """ + SELECT m.user_id, m.membership, m.event_id + FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? + ORDER BY + CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC, + m.event_id ASC + LIMIT ? + """ + + # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. + txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) + for user_id, membership, event_id in txn: + summary = res[to_ascii(membership)] + # we will always have a summary for this membership type at this + # point given the summary currently contains the counts. + members = summary.members + members.append((to_ascii(user_id), to_ascii(event_id))) + + return res + + return self.runInteraction("get_room_summary", _get_room_summary_txn) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/schema/delta/51/e2e_room_keys.sql new file mode 100644 index 0000000000..c0e66a697d --- /dev/null +++ b/synapse/storage/schema/delta/51/e2e_room_keys.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- users' optionally backed up encrypted e2e sessions +CREATE TABLE e2e_room_keys ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + version TEXT NOT NULL, + first_message_index INT, + forwarded_count INT, + is_verified BOOLEAN, + session_data TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_idx ON e2e_room_keys(user_id, room_id, session_id); + +-- the metadata for each generation of encrypted e2e session backups +CREATE TABLE e2e_room_keys_versions ( + user_id TEXT NOT NULL, + version TEXT NOT NULL, + algorithm TEXT NOT NULL, + auth_data TEXT NOT NULL, + deleted SMALLINT DEFAULT 0 NOT NULL +); + +CREATE UNIQUE INDEX e2e_room_keys_versions_idx ON e2e_room_keys_versions(user_id, version); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 5623391f6e..158e9dbe7b 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -27,7 +27,7 @@ from ._base import SQLBaseStore # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dd03c4168b..ef65929bb2 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from collections import namedtuple from six import iteritems, itervalues from six.moves import range +import attr + from twisted.internet import defer from synapse.api.constants import EventTypes @@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 +@attr.s(slots=True) +class StateFilter(object): + """A filter used when querying for state. + + Attributes: + types (dict[str, set[str]|None]): Map from type to set of state keys (or + None). This specifies which state_keys for the given type to fetch + from the DB. If None then all events with that type are fetched. If + the set is empty then no events with that type are fetched. + include_others (bool): Whether to fetch events with types that do not + appear in `types`. + """ + + types = attr.ib() + include_others = attr.ib(default=False) + + def __attrs_post_init__(self): + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + self.types = { + k: v for k, v in iteritems(self.types) + if v is not None + } + + @staticmethod + def all(): + """Creates a filter that fetches everything. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=True) + + @staticmethod + def none(): + """Creates a filter that fetches nothing. + + Returns: + StateFilter + """ + return StateFilter(types={}, include_others=False) + + @staticmethod + def from_types(types): + """Creates a filter that only fetches the given types + + Args: + types (Iterable[tuple[str, str|None]]): A list of type and state + keys to fetch. A state_key of None fetches everything for + that type + + Returns: + StateFilter + """ + type_dict = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) + + return StateFilter(types=type_dict) + + @staticmethod + def from_lazy_load_member_list(members): + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members (iterable[str]): Set of user IDs + + Returns: + StateFilter + """ + return StateFilter( + types={EventTypes.Member: set(members)}, + include_others=True, + ) + + def return_expanded(self): + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + StateFilter + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in iteritems(self.types) + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + else: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types={EventTypes.Member: self.types[EventTypes.Member]}, + include_others=True, + ) + + def make_sql_filter_clause(self): + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + tuple[str, list]: The SQL string (may be empty) and arguments. An + empty SQL string is returned when the filter matches everything + (i.e. is "full"). + """ + + where_clause = "" + where_args = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in iteritems(self.types): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % ( + ",".join(["?"] * len(self.types)), + ) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self): + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict): + """Returns the state filtered with by this StateFilter + + Args: + state (dict[tuple[str, str], Any]): The state map to filter + + Returns: + dict[tuple[str, str], Any]: The filtered state map + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in iteritems(state_dict): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self): + """Whether this filter fetches everything or not + + Returns: + bool + """ + return self.include_others and not self.types + + def has_wildcards(self): + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + bool + """ + + return ( + self.include_others + or any( + state_keys is None + for state_keys in itervalues(self.types) + ) + ) + + def concrete_types(self): + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + list[tuple[str,str]] + """ + return [ + (t, s) + for t, state_keys in iteritems(self.types) + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self): + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + tuple[StateFilter, StateFilter]: The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter({EventTypes.Member: state_keys}) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. @@ -60,8 +374,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, db_conn, hs): super(StateGroupWorkerStore, self).__init__(db_conn, hs) + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache") + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache") + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache") ) @defer.inlineCallbacks @@ -117,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, types, filtered_types=None): + def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + Args: room_id (str) - types (list[(Str, (Str|None))]): List of (type, state_key) tuples - which are used to filter the state fetched. `state_key` may be - None, which matches any `state_key` - filtered_types (list[Str]|None): List of types to apply the above filter to. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: - deferred: dict of (type, state_key) -> event + Deferred[dict[tuple[str, str], str]]: Map from type/state_key to + event ID. """ - include_other_types = False if filtered_types is None else True - def _get_filtered_current_state_ids_txn(txn): results = {} - sql = """SELECT type, state_key, event_id FROM current_state_events - WHERE room_id = ? %s""" - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - - if include_other_types: - unique_types = set(filtered_types) - clause_to_args.append( - ( - "AND type <> ? " * len(unique_types), - list(unique_types) - ) - ) - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] - for where_clause, where_args in clause_to_args: - args = [room_id] - args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (intern_string(typ), intern_string(state_key)) - results[key] = event_id + sql = """ + SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """ + + where_clause, where_args = state_filter.make_sql_filter_clause() + + if where_clause: + sql += " AND (%s)" % (where_clause,) + + args = [room_id] + args.extend(where_args) + txn.execute(sql, args) + for row in txn: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[key] = event_id + return results return self.runInteraction( @@ -220,7 +549,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) @defer.inlineCallbacks - def get_state_groups_ids(self, room_id, event_ids): + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ if not event_ids: defer.returnValue({}) @@ -235,7 +574,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_state_ids_for_group(self, state_group): - """Get the state IDs for the given state group + """Get the event IDs of all the state in the given state group Args: state_group (int) @@ -251,7 +590,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids - The return value is a dict mapping group names to lists of events. + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. """ if not event_ids: defer.returnValue({}) @@ -275,18 +616,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): }) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, types): + def _get_state_groups_from_groups(self, groups, state_filter): """Returns the state groups for a given set of groups, filtering on types of state events. Args: groups(list[int]): list of state group IDs to query - types (Iterable[str, str|None]|None): list of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. If None, all types are returned. - + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - dictionary state_group -> (dict of (type, state_key) -> event id) + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ results = {} @@ -294,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): for chunk in chunks: res = yield self.runInteraction( "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, chunk, types, + self._get_state_groups_from_groups_txn, chunk, state_filter, ) results.update(res) defer.returnValue(results) def _get_state_groups_from_groups_txn( - self, txn, groups, types=None, + self, txn, groups, state_filter=StateFilter.all(), ): results = {group: {} for group in groups} - if types is not None: - types = list(set(types)) # deduplicate types list + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): # Temporarily disable sequential scans in this transaction. This is @@ -322,69 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # group for the given type, state_key. # This may return multiple rows per (type, state_key), but last_value # should be the same. - sql = (""" + sql = """ WITH RECURSIVE state(state_group) AS ( VALUES(?::bigint) UNION ALL SELECT prev_state_group FROM state_group_edges e, state s WHERE s.state_group = e.state_group ) - SELECT type, state_key, last_value(event_id) OVER ( + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( PARTITION BY type, state_key ORDER BY state_group ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS event_id FROM state_groups_state WHERE state_group IN ( SELECT state_group FROM state ) - %s - """) - - # Turns out that postgres doesn't like doing a list of OR's and - # is about 1000x slower, so we just issue a query for each specific - # type seperately. - if types is not None: - clause_to_args = [ - ( - "AND type = ? AND state_key = ?", - (etype, state_key) - ) if state_key is not None else ( - "AND type = ?", - (etype,) - ) - for etype, state_key in types - ] - else: - # If types is None we fetch all the state, and so just use an - # empty where clause with no extra args. - clause_to_args = [("", [])] + """ - for where_clause, where_args in clause_to_args: - for group in groups: - args = [group] - args.extend(where_args) + for group in groups: + args = [group] + args.extend(where_args) - txn.execute(sql % (where_clause,), args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: - where_args = [] - where_clauses = [] - wildcard_types = False - if types is not None: - for typ in types: - if typ[1] is None: - where_clauses.append("(type = ?)") - where_args.append(typ[0]) - wildcard_types = True - else: - where_clauses.append("(type = ? AND state_key = ?)") - where_args.extend([typ[0], typ[1]]) - - where_clause = "AND (%s)" % (" OR ".join(where_clauses)) - else: - where_clause = "" + max_entries_returned = state_filter.max_entries_returned() # We don't use WITH RECURSIVE on sqlite3 as there are distributions # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) @@ -398,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # without the right indices (which we can't add until # after we finish deduping state, which requires this func) args = [next_group] - if types: - args.extend(where_args) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? %s" % (where_clause,), + " WHERE state_group = ? " + where_clause, args ) results[group].update( @@ -419,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # wildcards (i.e. Nones) in which case we have to do an exhaustive # search if ( - types is not None and - not wildcard_types and - len(results[group]) == len(types) + max_entries_returned is not None and + len(results[group]) == max_entries_returned ): break @@ -436,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return results @defer.inlineCallbacks - def get_state_for_events(self, event_ids, types, filtered_types=None): + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): """Given a list of event_ids and type tuples, return a list of state - dicts for each event. The state dicts will only have the type/state_keys - that are in the `types` list. + dicts for each event. Args: event_ids (list[string]) - types (list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: deferred: A dict of (event_id) -> (type, state_key) -> [state_events] @@ -459,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) state_event_map = yield self.get_events( [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], @@ -478,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None): + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: event_ids(list(str)): events whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from event_id -> (type, state_key) -> event_id @@ -501,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, types, filtered_types) + group_to_state = yield self._get_state_for_groups(groups, state_filter) event_to_state = { event_id: group_to_state[group] @@ -511,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_for_event(self, event_id, types=None, filtered_types=None): + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, types=None, filtered_types=None): + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): """ Get the state dict corresponding to a particular event Args: event_id(str): event whose state should be returned - types(list[(str, str|None)]|None): List of (type, state_key) tuples - which are used to filter the state fetched. If `state_key` is None, - all events are returned of the given type. - May be None, which matches any key. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types) + state_map = yield self.get_state_ids_for_events([event_id], state_filter) defer.returnValue(state_map[event_id]) @cached(max_entries=50000) @@ -580,179 +865,207 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) - def _get_some_state_from_cache(self, group, types, filtered_types=None): + def _get_state_for_group_using_cache(self, cache, group, state_filter): """Checks if group is in cache. See `_get_state_for_groups` Args: + cache(DictionaryCache): the state group cache to use group(int): The state group to lookup - types(list[str, str|None]): List of 2-tuples of the form - (`type`, `state_key`), where a `state_key` of `None` matches all - state_keys for the `type`. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool indicating if we successfully retrieved all requests state from the cache, if False we need to query the DB for the missing state. """ - is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = cache.get(group) - type_to_key = {} + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all - # tracks whether any of ourrequested types are missing from the cache + # tracks whether any of our requested types are missing from the cache missing_types = False - for typ, state_key in types: - key = (typ, state_key) - - if ( - state_key is None or - (filtered_types is not None and typ not in filtered_types) - ): - type_to_key[typ] = None - # we mark the type as missing from the cache because - # when the cache was populated it might have been done with a - # restricted set of state_keys, so the wildcard will not work - # and the cache may be incomplete. - missing_types = True - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): if key not in state_dict_ids and key not in known_absent: missing_types = True + break - sentinel = object() - - def include(typ, state_key): - valid_state_keys = type_to_key.get(typ, sentinel) - if valid_state_keys is sentinel: - return filtered_types is not None and typ not in filtered_types - if valid_state_keys is None: - return True - if state_key in valid_state_keys: - return True - return False - - got_all = is_all - if not got_all: - # the cache is incomplete. We may still have got all the results we need, if - # we don't have any wildcards in the match list. - if not missing_types and filtered_types is None: - got_all = True - - return { - k: v for k, v in iteritems(state_dict_ids) - if include(k[0], k[1]) - }, got_all - - def _get_all_state_from_cache(self, group): - """Checks if group is in cache. See `_get_state_for_groups` + return state_filter.filter_state(state_dict_ids), not missing_types - Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool - indicating if we successfully retrieved all requests state from the - cache, if False we need to query the DB for the missing state. + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key Args: - group: The state group to lookup + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) """ - is_all, _, state_dict_ids = self._state_group_cache.get(group) - return state_dict_ids, is_all + member_filter, non_member_filter = state_filter.get_member_split() - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, types=None, filtered_types=None): + # Now we look them up in the member and non-member caches + non_member_state, incomplete_groups_nm, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, + state_filter=non_member_filter, + ) + ) + + member_state, incomplete_groups_m, = ( + yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, + state_filter=member_filter, + ) + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + defer.returnValue(state) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), + state_filter=db_state_filter, + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + defer.returnValue(state) + + def _get_state_for_groups_using_cache( + self, groups, cache, state_filter, + ): """Gets the state at each of a list of state groups, optionally - filtering by type/state_key + filtering by type/state_key, querying from a specific cache. Args: groups (iterable[int]): list of state groups for which we want to get the state. - types (None|iterable[(str, None|str)]): - indicates the state type/keys required. If None, the whole - state is fetched and returned. - - Otherwise, each entry should be a `(type, state_key)` tuple to - include in the response. A `state_key` of None is a wildcard - meaning that we require all state with that type. - filtered_types(list[str]|None): Only apply filtering via `types` to this - list of event types. Other types of events are returned unfiltered. - If None, `types` filtering is applied to all events. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. Returns: - Deferred[dict[int, dict[(type, state_key), EventBase]]] - a dictionary mapping from state group to state dictionary. + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. """ - if types: - types = frozenset(types) results = {} - missing_groups = [] - if types is not None: - for group in set(groups): - state_dict_ids, got_all = self._get_some_state_from_cache( - group, types, filtered_types - ) - results[group] = state_dict_ids + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids - if not got_all: - missing_groups.append(group) - else: - for group in set(groups): - state_dict_ids, got_all = self._get_all_state_from_cache( - group - ) + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups - results[group] = state_dict_ids + def _insert_into_cache(self, group_to_state_dict, state_filter, + cache_seq_num_members, cache_seq_num_non_members): + """Inserts results from querying the database into the relevant cache. - if not got_all: - missing_groups.append(group) + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ - if missing_groups: - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) - # the DictionaryCache knows if it has *all* the state, but - # does not know if it has all of the keys of a particular type, - # which makes wildcard lookups expensive unless we have a complete - # cache. Hence, if we are doing a wildcard lookup, populate the - # cache fully so that we can do an efficient lookup next time. + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() - if filtered_types or (types and any(k is None for (t, k) in types)): - types_to_fetch = None - else: - types_to_fetch = types + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() - group_to_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types_to_fetch - ) + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict = results[group] - - # update the result, filtering by `types`. - if types: - for k, v in iteritems(group_state_dict): - (typ, _) = k - if ( - (k in types or (typ, None) in types) or - (filtered_types and typ not in filtered_types) - ): - state_dict[k] = v + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v else: - state_dict.update(group_state_dict) - - # update the cache with all the things we fetched from the - # database. - self._state_group_cache.update( - cache_seq_num, - key=group, - value=group_state_dict, - fetched_keys=types_to_fetch, - ) + state_dict_non_members[k] = v - defer.returnValue(results) + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) def store_state_group(self, event_id, room_id, prev_group, delta_ids, current_state_ids): @@ -847,15 +1160,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ], ) - # Prefill the state group cache with this group. + # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map # is immutable. (If the map wasn't immutable then this prefill could # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_state_ids), + value=dict(current_non_member_state_ids), ) return state_group @@ -1043,12 +1374,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): continue prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group], types=None + txn, [prev_group], ) prev_state = prev_state[prev_group] curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group], types=None + txn, [state_group], ) curr_state = curr_state[state_group] diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 4c296d72c0..d6cfdba519 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -630,7 +630,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): @defer.inlineCallbacks def get_all_new_events_stream(self, from_id, current_id, limit): - """Get all new events""" + """Get all new events + + Returns all events with from_id < stream_ordering <= current_id. + + Args: + from_id (int): the stream_ordering of the last event we processed + current_id (int): the stream_ordering of the most recently processed event + limit (int): the maximum number of events to return + + Returns: + Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + `next_id` is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` events + were found, `current_id`. + """ def get_all_new_events_stream_txn(txn): sql = ( diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 428e7fa36e..d8bf953ec0 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -18,19 +18,19 @@ from collections import namedtuple import six -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.caches.descriptors import cached +from synapse.util.caches.expiringcache import ExpiringCache -from ._base import SQLBaseStore +from ._base import SQLBaseStore, db_to_json # py2 sqlite has buffer hardcoded as only binary type, so we must use it, # despite being deprecated and removed in favor of memoryview if six.PY2: - db_binary_type = buffer + db_binary_type = six.moves.builtins.buffer else: db_binary_type = memoryview @@ -50,6 +50,8 @@ _UpdateTransactionRow = namedtuple( ) ) +SENTINEL = object() + class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. @@ -60,6 +62,12 @@ class TransactionStore(SQLBaseStore): self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) + self._destination_retry_cache = ExpiringCache( + cache_name="get_destination_retry_timings", + clock=self._clock, + expiry_ms=5 * 60 * 1000, + ) + def get_received_txn_response(self, transaction_id, origin): """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response @@ -95,7 +103,8 @@ class TransactionStore(SQLBaseStore): ) if result and result["response_code"]: - return result["response_code"], json.loads(str(result["response_json"])) + return result["response_code"], db_to_json(result["response_json"]) + else: return None @@ -155,7 +164,7 @@ class TransactionStore(SQLBaseStore): """ pass - @cached(max_entries=10000) + @defer.inlineCallbacks def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. @@ -166,10 +175,20 @@ class TransactionStore(SQLBaseStore): None if not retrying Otherwise a dict for the retry scheme """ - return self.runInteraction( + + result = self._destination_retry_cache.get(destination, SENTINEL) + if result is not SENTINEL: + defer.returnValue(result) + + result = yield self.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination) + # We don't hugely care about race conditions between getting and + # invalidating the cache, since we time out fairly quickly anyway. + self._destination_retry_cache[destination] = result + defer.returnValue(result) + def _get_destination_retry_timings(self, txn, destination): result = self._simple_select_one_txn( txn, @@ -197,8 +216,7 @@ class TransactionStore(SQLBaseStore): retry_interval (int) - how long until next retry in ms """ - # XXX: we could chose to not bother persisting this if our cache thinks - # this is a NOOP + self._destination_retry_cache.pop(destination, None) return self.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, @@ -211,10 +229,6 @@ class TransactionStore(SQLBaseStore): retry_last_ts, retry_interval): self.database_engine.lock_table(txn, "destinations") - self._invalidate_cache_and_stream( - txn, self.get_destination_retry_timings, (destination,) - ) - # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. |