diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 46bcf8b081..8832ba58bc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1447,16 +1447,24 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
- if not event.internal_metadata.is_outlier() and not backfilled:
- yield self.action_generator.handle_push_actions_for_event(
- event, context
- )
+ try:
+ if not event.internal_metadata.is_outlier() and not backfilled:
+ yield self.action_generator.handle_push_actions_for_event(
+ event, context
+ )
- event_stream_id, max_stream_id = yield self.store.persist_event(
- event,
- context=context,
- backfilled=backfilled,
- )
+ event_stream_id, max_stream_id = yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
+ )
+ except: # noqa: E722, as we reraise the exception this is fine.
+ # Ensure that we actually remove the entries in the push actions
+ # staging area
+ logcontext.preserve_fn(
+ self.store.remove_push_actions_from_staging
+ )(event.event_id)
+ raise
if not backfilled:
# this intentionally does not yield: we don't care about the result
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index efbd87918e..6c8d2954d7 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,50 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.account_data import AccountDataStore
-from synapse.storage.tags import TagsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.tags import TagsWorkerStore
-class SlavedAccountDataStore(BaseSlavedStore):
+class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
- super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache",
- self._account_data_id_gen.get_current_token(),
- )
-
- get_account_data_for_user = (
- AccountDataStore.__dict__["get_account_data_for_user"]
- )
-
- get_global_account_data_by_type_for_users = (
- AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
- )
- get_global_account_data_by_type_for_user = (
- AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
- )
-
- get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
- get_tags_for_room = (
- DataStore.get_tags_for_room.__func__
- )
- get_account_data_for_room = (
- DataStore.get_account_data_for_room.__func__
- )
-
- get_updated_tags = DataStore.get_updated_tags.__func__
- get_updated_account_data_for_user = (
- DataStore.get_updated_account_data_for_user.__func__
- )
+ super(SlavedAccountDataStore, self).__init__(db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index c5d6c6bd86..44499dc73f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ from synapse.api.constants import EventTypes
from synapse.storage import DataStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore
@@ -38,8 +40,8 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(EventPushActionsWorkerStore, StateGroupWorkerStore,
- BaseSlavedStore):
+class SlavedEventStore(EventPushActionsWorkerStore, EventsWorkerStore,
+ StateGroupWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
@@ -87,8 +89,6 @@ class SlavedEventStore(EventPushActionsWorkerStore, StateGroupWorkerStore,
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
- get_event = DataStore.get_event.__func__
- get_events = DataStore.get_events.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
@@ -118,14 +118,6 @@ class SlavedEventStore(EventPushActionsWorkerStore, StateGroupWorkerStore,
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
- _get_events = DataStore._get_events.__func__
- _get_events_from_cache = DataStore._get_events_from_cache.__func__
-
- _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
- _enqueue_events = DataStore._enqueue_events.__func__
- _do_fetch = DataStore._do_fetch.__func__
- _fetch_event_rows = DataStore._fetch_event_rows.__func__
- _get_event_from_row = DataStore._get_event_from_row.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 83e880fdd2..bb2c40b6e3 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,29 +16,15 @@
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.push_rule import PushRuleStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.push_rule import PushRulesWorkerStore
-class SlavedPushRuleStore(SlavedEventStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
- super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache",
- self._push_rules_stream_id_gen.get_current_token(),
- )
-
- get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
- get_push_rules_enabled_for_user = (
- PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
- )
- have_push_rules_changed_for_user = (
- DataStore.have_push_rules_changed_for_user.__func__
- )
+ super(SlavedPushRuleStore, self).__init__(db_conn, hs)
def get_push_rules_stream_token(self):
return (
@@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore):
self._stream_id_gen.get_current_token(),
)
+ def get_max_push_rules_stream_id(self):
+ return self._push_rules_stream_id_gen.get_current_token()
+
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 4e8d68ece9..a7cd5a7291 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +17,10 @@
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
+from synapse.storage.pusher import PusherWorkerStore
-class SlavedPusherStore(BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
@@ -28,13 +29,6 @@ class SlavedPusherStore(BaseSlavedStore):
extra_tables=[("deleted_pushers", "stream_id")],
)
- get_all_pushers = DataStore.get_all_pushers.__func__
- get_pushers_by = DataStore.get_pushers_by.__func__
- get_pushers_by_app_id_and_pushkey = (
- DataStore.get_pushers_by_app_id_and_pushkey.__func__
- )
- _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
-
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e1c4fe086e..0f136f8a06 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -104,9 +105,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
- self._account_data_id_gen = StreamIdGenerator(
- db_conn, "account_data_max_stream_id", "stream_id"
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -159,11 +157,6 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
- account_max = self._account_data_id_gen.get_current_token()
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache", account_max,
- )
-
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -177,18 +170,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill
)
- push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn, "push_rules_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._push_rules_stream_id_gen.get_current_token()[0],
- )
-
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache", push_rules_id,
- prefilled_cache=push_rules_prefill,
- )
-
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 56a0bde549..466194e96f 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+import abc
import ujson as json
import logging
logger = logging.getLogger(__name__)
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_account_data_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ account_max = self.get_max_account_data_stream_id()
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+ @abc.abstractmethod
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream ID for account data stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
@@ -209,6 +238,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
+ @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+ ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", ignorer_user_id,
+ on_invalidate=cache_context.invalidate,
+ )
+ if not ignored_account_data:
+ defer.returnValue(False)
+
+ defer.returnValue(
+ ignored_user_id in ignored_account_data.get("ignored_users", {})
+ )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ def __init__(self, db_conn, hs):
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ super(AccountDataStore, self).__init__(db_conn, hs)
+
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream id for the private user data stream
+
+ Returns:
+ A deferred int.
+ """
+ return self._account_data_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
@@ -321,16 +380,3 @@ class AccountDataStore(SQLBaseStore):
"update_account_data_max_stream_id",
_update,
)
-
- @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
- def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
- ignored_account_data = yield self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", ignorer_user_id,
- on_invalidate=cache_context.invalidate,
- )
- if not ignored_account_data:
- defer.returnValue(False)
-
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c636da4b72..99d6cca585 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -13,16 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+from synapse.storage.events_worker import EventsWorkerStore
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+from twisted.internet import defer
+
+from synapse.events import USE_FROZEN_DICTS
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+ PreserveLoggingContext, make_deferred_yieldable
)
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
@@ -62,16 +62,6 @@ def encode_json(json_object):
return json.dumps(json_object, ensure_ascii=False)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
-
-
class _EventPeristenceQueue(object):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
@@ -200,13 +190,12 @@ def _retry_on_integrity_error(func):
return f
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
- self._clock = hs.get_clock()
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
@@ -610,62 +599,6 @@ class EventsStore(SQLBaseStore):
defer.returnValue((to_delete, to_insert))
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- events = yield self._get_events(
- [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- if not events and not allow_none:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
-
- defer.returnValue(events[0] if events else None)
-
- @defer.inlineCallbacks
- def get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- """Get events from the database
-
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
-
- Returns:
- Deferred : Dict from event_id to event.
- """
- events = yield self._get_events(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- defer.returnValue({e.event_id: e for e in events})
-
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False, state_delta_for_room={},
@@ -1377,292 +1310,6 @@ class EventsStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- if not event_ids:
- defer.returnValue([])
-
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids,
- allow_rejected=allow_rejected,
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- allow_rejected=allow_rejected,
- )
-
- event_entry_map.update(missing_events)
-
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
-
- def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
-
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
-
- Args:
- events (list(str)): list of event_ids to fetch
- allow_rejected (bool): Whether to teturn events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
-
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
- """
- event_map = {}
-
- for event_id in events:
- ret = self._get_event_cache.get(
- (event_id,), None,
- update_metrics=update_metrics,
- )
- if not ret:
- continue
-
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
-
- return event_map
-
- def _do_fetch(self, conn):
- """Takes a database connection and waits for requests for events from
- the _event_fetch_list queue.
- """
- event_list = []
- i = 0
- while True:
- try:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- event_id_lists = zip(*event_list)[0]
- event_ids = [
- item for sublist in event_id_lists for item in sublist
- ]
-
- rows = self._new_transaction(
- conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
- )
-
- row_dict = {
- r["event_id"]: r
- for r in rows
- }
-
- # We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
- except Exception:
- logger.exception("Failed to callback")
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
- except Exception as e:
- logger.exception("do_fetch")
-
- # We only want to resolve deferreds from the main thread
- def fire(evs):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(e)
-
- if event_list:
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
-
- @defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
- """Fetches events from the database using the _event_fetch_list. This
- allows batch and bulk fetching of events - it allows us to fetch events
- without having to create a new transaction for each request for events.
- """
- if not events:
- defer.returnValue({})
-
- events_d = defer.Deferred()
- with self._event_fetch_lock:
- self._event_fetch_list.append(
- (events, events_d)
- )
-
- self._event_fetch_lock.notify()
-
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- should_start = True
- else:
- should_start = False
-
- if should_start:
- with PreserveLoggingContext():
- self.runWithConnection(
- self._do_fetch
- )
-
- logger.debug("Loading %d events", len(events))
- with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = yield make_deferred_yieldable(defer.gatherResults(
- [
- preserve_fn(self._get_event_from_row)(
- row["internal_metadata"], row["json"], row["redacts"],
- rejected_reason=row["rejects"],
- )
- for row in rows
- ],
- consumeErrors=True
- ))
-
- defer.returnValue({
- e.event.event_id: e
- for e in res if e
- })
-
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) / N):
- evs = events[i * N:(i + 1) * N]
- if not evs:
- break
-
- sql = (
- "SELECT "
- " e.event_id as event_id, "
- " e.internal_metadata,"
- " e.json,"
- " r.redacts as redacts,"
- " rej.event_id as rejects "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
-
- txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
-
- return rows
-
- @defer.inlineCallbacks
- def _get_event_from_row(self, internal_metadata, js, redacted,
- rejected_reason=None):
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
- original_ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = None
- if redacted:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id,
- check_redacted=False,
- allow_none=True,
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- cache_entry = _EventCacheEntry(
- event=original_ev,
- redacted_event=redacted_event,
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- defer.returnValue(cache_entry)
-
- @defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..86c3b48ad4
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+ preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import ujson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thing air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_event(self, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False,
+ allow_none=False):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id (str): The event_id of the event to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+ allow_none (bool): If True, return None if no event found, if
+ False throw an exception.
+
+ Returns:
+ Deferred : A FrozenEvent.
+ """
+ events = yield self._get_events(
+ [event_id],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ if not events and not allow_none:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ defer.returnValue(events[0] if events else None)
+
+ @defer.inlineCallbacks
+ def get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred : Dict from event_id to event.
+ """
+ events = yield self._get_events(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ defer.returnValue({e.event_id: e for e in events})
+
+ @defer.inlineCallbacks
+ def _get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not event_ids:
+ defer.returnValue([])
+
+ event_id_list = event_ids
+ event_ids = set(event_ids)
+
+ event_entry_map = self._get_events_from_cache(
+ event_ids,
+ allow_rejected=allow_rejected,
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ allow_rejected=allow_rejected,
+ )
+
+ event_entry_map.update(missing_events)
+
+ events = []
+ for event_id in event_id_list:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if allow_rejected or not entry.event.rejected_reason:
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ defer.returnValue(events)
+
+ def _invalidate_get_event_cache(self, event_id):
+ self._get_event_cache.invalidate((event_id,))
+
+ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+ """Fetch events from the caches
+
+ Args:
+ events (list(str)): list of event_ids to fetch
+ allow_rejected (bool): Whether to teturn events that were rejected
+ update_metrics (bool): Whether to update the cache hit ratio metrics
+
+ Returns:
+ dict of event_id -> _EventCacheEntry for each event_id in cache. If
+ allow_rejected is `False` then there will still be an entry but it
+ will be `None`
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = self._get_event_cache.get(
+ (event_id,), None,
+ update_metrics=update_metrics,
+ )
+ if not ret:
+ continue
+
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
+ event_list = []
+ i = 0
+ while True:
+ try:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ event_id_lists = zip(*event_list)[0]
+ event_ids = [
+ item for sublist in event_id_lists for item in sublist
+ ]
+
+ rows = self._new_transaction(
+ conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+ )
+
+ row_dict = {
+ r["event_id"]: r
+ for r in rows
+ }
+
+ # We only want to resolve deferreds from the main thread
+ def fire(lst, res):
+ for ids, d in lst:
+ if not d.called:
+ try:
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
+ except Exception:
+ logger.exception("Failed to callback")
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list, row_dict)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(e)
+
+ if event_list:
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list)
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+ """
+ if not events:
+ defer.returnValue({})
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append(
+ (events, events_d)
+ )
+
+ self._event_fetch_lock.notify()
+
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ with PreserveLoggingContext():
+ self.runWithConnection(
+ self._do_fetch
+ )
+
+ logger.debug("Loading %d events", len(events))
+ with PreserveLoggingContext():
+ rows = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+ if not allow_rejected:
+ rows[:] = [r for r in rows if not r["rejects"]]
+
+ res = yield make_deferred_yieldable(defer.gatherResults(
+ [
+ preserve_fn(self._get_event_from_row)(
+ row["internal_metadata"], row["json"], row["redacts"],
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ],
+ consumeErrors=True
+ ))
+
+ defer.returnValue({
+ e.event.event_id: e
+ for e in res if e
+ })
+
+ def _fetch_event_rows(self, txn, events):
+ rows = []
+ N = 200
+ for i in range(1 + len(events) / N):
+ evs = events[i * N:(i + 1) * N]
+ if not evs:
+ break
+
+ sql = (
+ "SELECT "
+ " e.event_id as event_id, "
+ " e.internal_metadata,"
+ " e.json,"
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
+ " FROM event_json as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"] * len(evs)),)
+
+ txn.execute(sql, evs)
+ rows.extend(self.cursor_to_dict(txn))
+
+ return rows
+
+ @defer.inlineCallbacks
+ def _get_event_from_row(self, internal_metadata, js, redacted,
+ rejected_reason=None):
+ with Measure(self._clock, "_get_event_from_row"):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = yield self._simple_select_one_onecol(
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ desc="_get_event_from_row_rejected_reason",
+ )
+
+ original_ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ redacted_event = None
+ if redacted:
+ redacted_event = prune_event(original_ev)
+
+ redaction_id = yield self._simple_select_one_onecol(
+ table="redactions",
+ keyvalues={"redacts": redacted_event.event_id},
+ retcol="event_id",
+ desc="_get_event_from_row_redactions",
+ )
+
+ redacted_event.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
+
+ because = yield self.get_event(
+ redaction_id,
+ check_redacted=False,
+ allow_none=True,
+ )
+
+ if because:
+ # It's fine to do add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = because
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev,
+ redacted_event=redacted_event,
+ )
+
+ self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+ defer.returnValue(cache_entry)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..583efb7bdf 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +16,12 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes
from twisted.internet import defer
+import abc
import logging
import simplejson as json
@@ -48,7 +51,39 @@ def _load_rules(rawrules, enabled_map):
return rules
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_push_rules_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+ push_rules_prefill, push_rules_id = self._get_cache_dict(
+ db_conn, "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self.get_max_push_rules_stream_id(),
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache", push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ @abc.abstractmethod
+ def get_max_push_rules_stream_id(self):
+ """Get the position of the push rules stream.
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
@@ -89,6 +124,24 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ count, = txn.fetchone()
+ return bool(count)
+ return self.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
+
+
+class PushRuleStore(PushRulesWorkerStore):
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
@@ -526,21 +579,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
- def have_push_rules_changed_for_user(self, user_id, last_id):
- if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
- else:
- def have_push_rules_changed_txn(txn):
- sql = (
- "SELECT COUNT(stream_id) FROM push_rules_stream"
- " WHERE user_id = ? AND ? < stream_id"
- )
- txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
- return bool(count)
- return self.runInteraction(
- "have_push_rules_changed", have_push_rules_changed_txn
- )
+ def get_max_push_rules_stream_id(self):
+ return self.get_push_rules_stream_token()[0]
class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 3d8b4d5d5b..f4af3e4caa 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
logger = logging.getLogger(__name__)
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
for r in rows:
dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows)
- def get_pushers_stream_token(self):
- return self._pushers_id_gen.get_current_token()
-
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed(([], []))
@@ -177,6 +175,11 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
+
+class PusherStore(PusherWorkerStore):
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
+
@cachedInlineCallbacks(num_args=1, max_entries=15000)
def get_if_user_has_pusher(self, user_id):
# This only exists for the cachedList decorator
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index bff73f3f04..fc46bf7bb3 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
@@ -23,15 +25,7 @@ import logging
logger = logging.getLogger(__name__)
-class TagsStore(SQLBaseStore):
- def get_max_account_data_stream_id(self):
- """Get the current max stream id for the private user data stream
-
- Returns:
- A deferred int.
- """
- return self._account_data_id_gen.get_current_token()
-
+class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
@@ -170,6 +164,8 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows
})
+
+class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
|