diff options
-rw-r--r-- | synapse/handlers/federation.py | 26 | ||||
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 16 | ||||
-rw-r--r-- | synapse/replication/slave/storage/account_data.py | 43 | ||||
-rw-r--r-- | synapse/replication/slave/storage/events.py | 26 | ||||
-rw-r--r-- | synapse/replication/slave/storage/push_rule.py | 24 | ||||
-rw-r--r-- | synapse/replication/slave/storage/pushers.py | 12 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 21 | ||||
-rw-r--r-- | synapse/storage/account_data.py | 76 | ||||
-rw-r--r-- | synapse/storage/event_push_actions.py | 215 | ||||
-rw-r--r-- | synapse/storage/events.py | 395 | ||||
-rw-r--r-- | synapse/storage/events_worker.py | 395 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 72 | ||||
-rw-r--r-- | synapse/storage/pusher.py | 11 | ||||
-rw-r--r-- | synapse/storage/tags.py | 16 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 10 | ||||
-rw-r--r-- | tests/storage/test_event_push_actions.py | 4 |
16 files changed, 734 insertions, 628 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 46bcf8b081..8832ba58bc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1447,16 +1447,24 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) - if not event.internal_metadata.is_outlier() and not backfilled: - yield self.action_generator.handle_push_actions_for_event( - event, context - ) + try: + if not event.internal_metadata.is_outlier() and not backfilled: + yield self.action_generator.handle_push_actions_for_event( + event, context + ) - event_stream_id, max_stream_id = yield self.store.persist_event( - event, - context=context, - backfilled=backfilled, - ) + event_stream_id, max_stream_id = yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, + ) + except: # noqa: E722, as we reraise the exception this is fine. + # Ensure that we actually remove the entries in the push actions + # staging area + logcontext.preserve_fn( + self.store.remove_push_actions_from_staging + )(event.event_id) + raise if not backfilled: # this intentionally does not yield: we don't care about the result diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index bf4f1c5836..7c680659b6 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -144,6 +144,7 @@ class BulkPushRuleEvaluator(object): Deferred """ rules_by_user = yield self._get_rules_for_event(event, context) + actions_by_user = {} room_members = yield self.store.get_joined_users_from_context( event, context @@ -189,14 +190,17 @@ class BulkPushRuleEvaluator(object): if matches: actions = [x for x in rule['actions'] if x != 'dont_notify'] if actions and 'notify' in actions: - # Push rules say we should notify the user of this event, - # so we mark it in the DB in the staging area. (This - # will then get handled when we persist the event) - yield self.store.add_push_actions_to_staging( - event.event_id, uid, actions, - ) + # Push rules say we should notify the user of this event + actions_by_user[uid] = actions break + # Mark in the DB staging area the push actions for users who should be + # notified for this event. (This will then get handled when we persist + # the event) + yield self.store.add_push_actions_to_staging( + event.event_id, actions_by_user, + ) + def _condition_checker(evaluator, conditions, uid, display_name, cache): for cond in conditions: diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index efbd87918e..6c8d2954d7 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,50 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.account_data import AccountDataStore -from synapse.storage.tags import TagsStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.account_data import AccountDataWorkerStore +from synapse.storage.tags import TagsWorkerStore -class SlavedAccountDataStore(BaseSlavedStore): +class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - super(SlavedAccountDataStore, self).__init__(db_conn, hs) self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id", ) - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", - self._account_data_id_gen.get_current_token(), - ) - - get_account_data_for_user = ( - AccountDataStore.__dict__["get_account_data_for_user"] - ) - - get_global_account_data_by_type_for_users = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] - ) - get_global_account_data_by_type_for_user = ( - AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] - ) - - get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] - get_tags_for_room = ( - DataStore.get_tags_for_room.__func__ - ) - get_account_data_for_room = ( - DataStore.get_account_data_for_room.__func__ - ) - - get_updated_tags = DataStore.get_updated_tags.__func__ - get_updated_account_data_for_user = ( - DataStore.get_updated_account_data_for_user.__func__ - ) + super(SlavedAccountDataStore, self).__init__(db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 74a81a0a51..44499dc73f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,8 +18,8 @@ import logging from synapse.api.constants import EventTypes from synapse.storage import DataStore from synapse.storage.event_federation import EventFederationStore -from synapse.storage.event_push_actions import EventPushActionsStore -from synapse.storage.events import EventsWorkerStore +from synapse.storage.event_push_actions import EventPushActionsWorkerStore +from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.roommember import RoomMemberStore from synapse.storage.state import StateGroupWorkerStore from synapse.storage.stream import StreamStore @@ -39,7 +40,8 @@ logger = logging.getLogger(__name__) # the method descriptor on the DataStore and chuck them into our class. -class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore): +class SlavedEventStore(EventPushActionsWorkerStore, EventsWorkerStore, + StateGroupWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedEventStore, self).__init__(db_conn, hs) @@ -81,30 +83,12 @@ class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore get_invited_rooms_for_user = RoomMemberStore.__dict__[ "get_invited_rooms_for_user" ] - get_unread_event_push_actions_by_room_for_user = ( - EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] - ) - _get_unread_counts_by_receipt_txn = ( - DataStore._get_unread_counts_by_receipt_txn.__func__ - ) - _get_unread_counts_by_pos_txn = ( - DataStore._get_unread_counts_by_pos_txn.__func__ - ) get_recent_event_ids_for_room = ( StreamStore.__dict__["get_recent_event_ids_for_room"] ) _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] has_room_changed_since = DataStore.has_room_changed_since.__func__ - get_unread_push_actions_for_user_in_range_for_http = ( - DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ - ) - get_unread_push_actions_for_user_in_range_for_email = ( - DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ - ) - get_push_action_users_in_range = ( - DataStore.get_push_action_users_in_range.__func__ - ) get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 83e880fdd2..bb2c40b6e3 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,29 +16,15 @@ from .events import SlavedEventStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.push_rule import PushRuleStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.push_rule import PushRulesWorkerStore -class SlavedPushRuleStore(SlavedEventStore): +class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): def __init__(self, db_conn, hs): - super(SlavedPushRuleStore, self).__init__(db_conn, hs) self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", ) - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_current_token(), - ) - - get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] - get_push_rules_enabled_for_user = ( - PushRuleStore.__dict__["get_push_rules_enabled_for_user"] - ) - have_push_rules_changed_for_user = ( - DataStore.have_push_rules_changed_for_user.__func__ - ) + super(SlavedPushRuleStore, self).__init__(db_conn, hs) def get_push_rules_stream_token(self): return ( @@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore): self._stream_id_gen.get_current_token(), ) + def get_max_push_rules_stream_id(self): + return self._push_rules_stream_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedPushRuleStore, self).stream_positions() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 4e8d68ece9..a7cd5a7291 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,10 +17,10 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore +from synapse.storage.pusher import PusherWorkerStore -class SlavedPusherStore(BaseSlavedStore): +class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedPusherStore, self).__init__(db_conn, hs) @@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore): extra_tables=[("deleted_pushers", "stream_id")], ) - get_all_pushers = DataStore.get_all_pushers.__func__ - get_pushers_by = DataStore.get_pushers_by.__func__ - get_pushers_by_app_id_and_pushkey = ( - DataStore.get_pushers_by_app_id_and_pushkey.__func__ - ) - _decode_pushers_rows = DataStore._decode_pushers_rows.__func__ - def stream_positions(self): result = super(SlavedPusherStore, self).stream_positions() result["pushers"] = self._pushers_id_gen.get_current_token() diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e1c4fe086e..0f136f8a06 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -104,9 +105,6 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) - self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -159,11 +157,6 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_current_token() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, - ) - self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( @@ -177,18 +170,6 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_current_token()[0], - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, - prefilled_cache=push_rules_prefill, - ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( db_conn, "device_inbox", diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 56a0bde549..466194e96f 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +14,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator + +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks +import abc import ujson as json import logging logger = logging.getLogger(__name__) -class AccountDataStore(SQLBaseStore): +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): @@ -209,6 +238,36 @@ class AccountDataStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + defer.returnValue(False) + + defer.returnValue( + ignored_user_id in ignored_account_data.get("ignored_users", {}) + ) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. @@ -321,16 +380,3 @@ class AccountDataStore(SQLBaseStore): "update_account_data_max_stream_id", _update, ) - - @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) - def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): - ignored_account_data = yield self.get_global_account_data_by_type_for_user( - "m.ignored_user_list", ignorer_user_id, - on_invalidate=cache_context.invalidate, - ) - if not ignored_account_data: - defer.returnValue(False) - - defer.returnValue( - ignored_user_id in ignored_account_data.get("ignored_users", {}) - ) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index f787431b7a..fe6887414e 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,77 +63,7 @@ def _deserialize_action(actions, is_highlight): return DEFAULT_NOTIF_ACTION -class EventPushActionsStore(SQLBaseStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) - - self.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) - - self.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1" - ) - - self._doing_notif_rotation = False - self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 60 * 1000 - ) - - def _set_push_actions_for_event_and_users_txn(self, txn, event): - """ - Args: - event: the event set actions for - tuples: list of tuples of (user_id, actions) - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? - """ - - txn.execute(sql, ( - event.room_id, event.internal_metadata.stream_ordering, - event.depth, event.event_id, - )) - - user_ids = self._simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={ - "event_id": event.event_id, - }, - retcol="user_id", - ) - - self._simple_delete_txn( - txn, - table="event_push_actions_staging", - keyvalues={ - "event_id": event.event_id, - }, - ) - - for uid in user_ids: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid,) - ) - +class EventPushActionsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id @@ -449,6 +380,95 @@ class EventPushActionsStore(SQLBaseStore): # Now return the first `limit` defer.returnValue(notifs[:limit]) + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + + def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, + all_events_and_contexts): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany(sql, ( + ( + event.room_id, event.internal_metadata.stream_ordering, + event.depth, event.event_id, + ) + for event, _ in events_and_contexts + )) + + for event, _ in events_and_contexts: + user_ids = self._simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + }, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid,) + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ( + (event.event_id,) + for event, _ in all_events_and_contexts + ) + ) + @defer.inlineCallbacks def get_push_actions_for_user(self, user_id, before=None, limit=50, only_highlight=False): @@ -755,32 +775,51 @@ class EventPushActionsStore(SQLBaseStore): (rotate_to_stream_ordering,) ) - def add_push_actions_to_staging(self, event_id, user_id, actions): - """Add the push actions for the user and event to the push - action staging area. + def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. Args: event_id (str) - user_id (str) - actions (list[dict|str]): An action can either be a string or - dict. + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. Returns: Deferred """ - is_highlight = 1 if _action_has_highlight(actions) else 0 + if not user_id_actions: + return - return self._simple_insert( - table="event_push_actions_staging", - values={ - "event_id": event_id, - "user_id": user_id, - "actions": _serialize_action(actions, is_highlight), - "notif": 1, - "highlight": is_highlight, - }, - desc="add_push_actions_to_staging", + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. + def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use _simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) + """ + + txn.executemany(sql, ( + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.iteritems() + )) + + return self.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) def remove_push_actions_from_staging(self, event_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 681a33314d..b63392a6cd 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,16 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer, reactor +from synapse.storage.events_worker import EventsWorkerStore -from synapse.events import FrozenEvent, USE_FROZEN_DICTS -from synapse.events.utils import prune_event +from twisted.internet import defer + +from synapse.events import USE_FROZEN_DICTS from synapse.util.async import ObservableDeferred from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, make_deferred_yieldable + PreserveLoggingContext, make_deferred_yieldable ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -61,16 +62,6 @@ def encode_json(json_object): return json.dumps(json_object, ensure_ascii=False) -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events - - class _EventPeristenceQueue(object): """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. @@ -199,362 +190,12 @@ def _retry_on_integrity_error(func): return f -class EventsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventsWorkerStore, self).__init__(db_conn, hs) - - self._event_persist_queue = _EventPeristenceQueue() - - @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False): - """Get an event from the database by event_id. - - Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw an exception. - - Returns: - Deferred : A FrozenEvent. - """ - events = yield self._get_events( - [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - if not events and not allow_none: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - defer.returnValue(events[0] if events else None) - - @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - """Get events from the database - - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - - Returns: - Deferred : Dict from event_id to event. - """ - events = yield self._get_events( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - defer.returnValue({e.event_id: e for e in events}) - - @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - defer.returnValue([]) - - event_id_list = event_ids - event_ids = set(event_ids) - - event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - allow_rejected=allow_rejected, - ) - - event_entry_map.update(missing_events) - - events = [] - for event_id in event_id_list: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if allow_rejected or not entry.event.rejected_reason: - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - defer.returnValue(events) - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches - - Args: - events (list(str)): list of event_ids to fetch - allow_rejected (bool): Whether to teturn events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics - - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` - """ - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get( - (event_id,), None, - update_metrics=update_metrics, - ) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the _event_fetch_list queue. - """ - event_list = [] - i = 0 - while True: - try: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - event_id_lists = zip(*event_list)[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] - - rows = self._new_transaction( - conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids - ) - - row_dict = { - r["event_id"]: r - for r in rows - } - - # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) - except Exception: - logger.exception("Failed to callback") - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list, row_dict) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(e) - - if event_list: - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list) - - @defer.inlineCallbacks - def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. - """ - if not events: - defer.returnValue({}) - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - with PreserveLoggingContext(): - self.runWithConnection( - self._do_fetch - ) - - logger.debug("Loading %d events", len(events)) - with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = yield make_deferred_yieldable(defer.gatherResults( - [ - preserve_fn(self._get_event_from_row)( - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - ) - for row in rows - ], - consumeErrors=True - )) - - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) - - def _fetch_event_rows(self, txn, events): - rows = [] - N = 200 - for i in range(1 + len(events) / N): - evs = events[i * N:(i + 1) * N] - if not evs: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " e.internal_metadata," - " e.json," - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) - - txn.execute(sql, evs) - rows.extend(self.cursor_to_dict(txn)) - - return rows - - @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - rejected_reason=None): - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - desc="_get_event_from_row_rejected_reason", - ) - - original_ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = None - if redacted: - redacted_event = prune_event(original_ev) - - redaction_id = yield self._simple_select_one_onecol( - table="redactions", - keyvalues={"redacts": redacted_event.event_id}, - retcol="event_id", - desc="_get_event_from_row_redactions", - ) - - redacted_event.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, - ) - - if because: - # It's fine to do add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = because - - cache_entry = _EventCacheEntry( - event=original_ev, - redacted_event=redacted_event, - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - defer.returnValue(cache_entry) - - class EventsStore(EventsWorkerStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) - self._clock = hs.get_clock() self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) @@ -583,6 +224,8 @@ class EventsStore(EventsWorkerStore): psql_only=True, ) + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() def persist_events(self, events_and_contexts, backfilled=False): @@ -984,6 +627,8 @@ class EventsStore(EventsWorkerStore): list of the event ids which are the forward extremities. """ + all_events_and_contexts = events_and_contexts + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) @@ -1046,6 +691,7 @@ class EventsStore(EventsWorkerStore): self._update_metadata_tables_txn( txn, events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, backfilled=backfilled, ) @@ -1443,26 +1089,33 @@ class EventsStore(EventsWorkerStore): ec for ec in events_and_contexts if ec[0] not in to_remove ] - def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled): + def _update_metadata_tables_txn(self, txn, events_and_contexts, + all_events_and_contexts, backfilled): """Update all the miscellaneous tables for new events Args: txn (twisted.enterprise.adbapi.Connection): db connection events_and_contexts (list[(EventBase, EventContext)]): events we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. backfilled (bool): True if the events were backfilled """ + # Insert all the push actions into the event_push_actions table. + self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + if not events_and_contexts: # nothing to do here return for event, context in events_and_contexts: - # Insert all the push actions into the event_push_actions table. - self._set_push_actions_for_event_and_users_txn( - txn, event, - ) - if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py new file mode 100644 index 0000000000..86c3b48ad4 --- /dev/null +++ b/synapse/storage/events_worker.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore + +from twisted.internet import defer, reactor + +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from synapse.util.logcontext import ( + preserve_fn, PreserveLoggingContext, make_deferred_yieldable +) +from synapse.util.metrics import Measure +from synapse.api.errors import SynapseError + +from collections import namedtuple + +import logging +import ujson as json + +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventsWorkerStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_event(self, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False, + allow_none=False): + """Get an event from the database by event_id. + + Args: + event_id (str): The event_id of the event to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + allow_none (bool): If True, return None if no event found, if + False throw an exception. + + Returns: + Deferred : A FrozenEvent. + """ + events = yield self._get_events( + [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + if not events and not allow_none: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue(events[0] if events else None) + + @defer.inlineCallbacks + def get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_id_list = event_ids + event_ids = set(event_ids) + + event_entry_map = self._get_events_from_cache( + event_ids, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + allow_rejected=allow_rejected, + ) + + event_entry_map.update(missing_events) + + events = [] + for event_id in event_id_list: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if allow_rejected or not entry.event.rejected_reason: + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + defer.returnValue(events) + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (list(str)): list of event_ids to fetch + allow_rejected (bool): Whether to teturn events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, + update_metrics=update_metrics, + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + with PreserveLoggingContext(): + d.callback([ + res[i] + for i in ids + if i in res + ]) + except Exception: + logger.exception("Failed to callback") + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(e) + + if event_list: + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + with PreserveLoggingContext(): + self.runWithConnection( + self._do_fetch + ) + + logger.debug("Loading %d events", len(events)) + with PreserveLoggingContext(): + rows = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield make_deferred_yieldable(defer.gatherResults( + [ + preserve_fn(self._get_event_from_row)( + row["internal_metadata"], row["json"], row["redacts"], + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + )) + + defer.returnValue({ + e.event.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i * N:(i + 1) * N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"] * len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + rejected_reason=None): + with Measure(self._clock, "_get_event_from_row"): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row_rejected_reason", + ) + + original_ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + redacted_event = None + if redacted: + redacted_event = prune_event(original_ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": redacted_event.event_id}, + retcol="event_id", + desc="_get_event_from_row_redactions", + ) + + redacted_event.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False, + allow_none=True, + ) + + if because: + # It's fine to do add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = because + + cache_entry = _EventCacheEntry( + event=original_ev, + redacted_event=redacted_event, + ) + + self._get_event_cache.prefill((original_ev.event_id,), cache_entry) + + defer.returnValue(cache_entry) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 8758b1c0c7..583efb7bdf 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,12 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.push.baserules import list_with_base_rules from synapse.api.constants import EventTypes from twisted.internet import defer +import abc import logging import simplejson as json @@ -48,7 +51,39 @@ def _load_rules(rawrules, enabled_map): return rules -class PushRuleStore(SQLBaseStore): +class PushRulesWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( @@ -89,6 +124,24 @@ class PushRuleStore(SQLBaseStore): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + + +class PushRuleStore(PushRulesWorkerStore): @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): @@ -526,21 +579,8 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_current_token() - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] class RuleNotFoundException(Exception): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 3d8b4d5d5b..f4af3e4caa 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +28,7 @@ import types logger = logging.getLogger(__name__) -class PusherStore(SQLBaseStore): +class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: dataJson = r['data'] @@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - def get_all_updated_pushers(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed(([], [])) @@ -177,6 +175,11 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index bff73f3f04..fc46bf7bb3 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from synapse.storage.account_data import AccountDataWorkerStore + from synapse.util.caches.descriptors import cached from twisted.internet import defer @@ -23,15 +25,7 @@ import logging logger = logging.getLogger(__name__) -class TagsStore(SQLBaseStore): - def get_max_account_data_stream_id(self): - """Get the current max stream id for the private user data stream - - Returns: - A deferred int. - """ - return self._account_data_id_gen.get_current_token() - +class TagsWorkerStore(AccountDataWorkerStore): @cached() def get_tags_for_user(self, user_id): """Get all the tags for a user. @@ -170,6 +164,8 @@ class TagsStore(SQLBaseStore): row["tag"]: json.loads(row["content"]) for row in rows }) + +class TagsStore(TagsWorkerStore): @defer.inlineCallbacks def add_tag_to_room(self, user_id, room_id, tag, content): """Add a tag to a room for a user. diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 4780f2ab72..cb058d3142 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -230,10 +230,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) - for user_id, actions in push_actions: - yield self.master_store.add_push_actions_to_staging( - event.event_id, user_id, actions, - ) + yield self.master_store.add_push_actions_to_staging( + event.event_id, { + user_id: actions + for user_id, actions in push_actions + }, + ) ordering = None if backfill: diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index d483e7cf9e..6c1aad149b 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -71,11 +71,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.depth = stream yield self.store.add_push_actions_to_staging( - event.event_id, user_id, action, + event.event_id, {user_id: action}, ) yield self.store.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, - event, + [(event, None)], [(event, None)], ) def _rotate(stream): |