diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d01d46338a..de00cae447 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +20,6 @@ from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
-from ._base import LoggingTransaction
from .directory import DirectoryStore
from .events import EventsStore
from .presence import PresenceStore, UserPresenceState
@@ -104,12 +104,6 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
- )
- self._account_data_id_gen = StreamIdGenerator(
- db_conn, "account_data_max_stream_id", "stream_id"
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -124,7 +118,6 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
- self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -147,27 +140,6 @@ class DataStore(RoomMemberStore, RoomStore,
else:
self._cache_id_gen = None
- events_max = self._stream_id_gen.get_current_token()
- event_cache_prefill, min_event_val = self._get_cache_dict(
- db_conn, "events",
- entity_column="room_id",
- stream_column="stream_ordering",
- max_value=events_max,
- )
- self._events_stream_cache = StreamChangeCache(
- "EventsRoomStreamChangeCache", min_event_val,
- prefilled_cache=event_cache_prefill,
- )
-
- self._membership_stream_cache = StreamChangeCache(
- "MembershipStreamChangeCache", events_max,
- )
-
- account_max = self._account_data_id_gen.get_current_token()
- self._account_data_stream_cache = StreamChangeCache(
- "AccountDataAndTagsChangeCache", account_max,
- )
-
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -181,18 +153,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill
)
- push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn, "push_rules_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._push_rules_stream_id_gen.get_current_token()[0],
- )
-
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache", push_rules_id,
- prefilled_cache=push_rules_prefill,
- )
-
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
@@ -227,6 +187,7 @@ class DataStore(RoomMemberStore, RoomStore,
"DeviceListFederationStreamChangeCache", device_list_max,
)
+ events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
entity_column="room_id",
@@ -251,20 +212,6 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=_group_updates_prefill,
)
- cur = LoggingTransaction(
- db_conn.cursor(),
- name="_find_stream_orderings_for_times_txn",
- database_engine=self.database_engine,
- after_callbacks=[],
- final_callbacks=[],
- )
- self._find_stream_orderings_for_times_txn(cur)
- cur.close()
-
- self.find_stream_orderings_looping_call = self._clock.looping_call(
- self._find_stream_orderings_for_times, 10 * 60 * 1000
- )
-
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b971f0cb18..2fbebd4907 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -48,16 +48,16 @@ class LoggingTransaction(object):
passed to the constructor. Adds logging and metrics to the .execute()
method."""
__slots__ = [
- "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks,
- final_callbacks):
+ exception_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "final_callbacks", final_callbacks)
+ object.__setattr__(self, "exception_callbacks", exception_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -66,8 +66,8 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
- def call_finally(self, callback, *args, **kwargs):
- self.final_callbacks.append((callback, args, kwargs))
+ def call_on_exception(self, callback, *args, **kwargs):
+ self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -215,7 +215,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+ def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
logging_context, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -236,7 +236,7 @@ class SQLBaseStore(object):
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks,
- final_callbacks,
+ exception_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -291,52 +291,66 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Starts a transaction on the database and runs a given function
- start_time = time.time() * 1000
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
- after_callbacks = []
- final_callbacks = []
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ Returns:
+ Deferred: The result of func
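+
+        Example (an illustrative sketch, not part of this change; `store`,
+        the table and the query are assumptions):
+
+            def get_name_txn(txn, user_id):
+                txn.execute("SELECT name FROM users WHERE user_id = ?", (user_id,))
+                row = txn.fetchone()
+                return row[0] if row else None
+
+            name = yield store.runInteraction("get_name", get_name_txn, user_id)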
+ """
+ current_context = LoggingContext.current_context()
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
+ after_callbacks = []
+ exception_callbacks = []
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ def inner_func(conn, *args, **kwargs):
+ return self._new_transaction(
+ conn, desc, after_callbacks, exception_callbacks, current_context,
+ func, *args, **kwargs
+ )
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(inner_func, *args, **kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
- finally:
- for after_callback, after_args, after_kwargs in final_callbacks:
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
+ raise
defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
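+
+        Example (an illustrative sketch, not part of this change; `store`
+        and the table are assumptions):
+
+            def count_users_conn(conn):
+                txn = conn.cursor()
+                txn.execute("SELECT count(*) FROM users")
+                return txn.fetchone()[0]
+
+            count = yield store.runWithConnection(count_users_conn)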
+ """
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ sched_duration_ms = time.time() * 1000 - start_time
+ sql_scheduling_timer.inc_by(sched_duration_ms)
+ current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
@@ -987,7 +1001,8 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_finally(ctx.__exit__, None, None, None)
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
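+        # Registering ctx.__exit__ as both an exception callback and an
+        # after callback means it runs exactly once, whether the
+        # transaction commits or raises.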
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 40a2ad8d05..f83ff0454a 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+import abc
import simplejson as json
import logging
logger = logging.getLogger(__name__)
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+    `get_max_account_data_stream_id`, which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
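+    # (This is Python 2 syntax; in Python 3 it would be spelled
+    # `class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta)`.)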
+
+ def __init__(self, db_conn, hs):
+ account_max = self.get_max_account_data_stream_id()
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+ @abc.abstractmethod
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream ID for account data stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
@@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
for row in rows
})
+ @cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
@@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
+ @cached(num_args=3, max_entries=5000)
+ def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+ """Get the client account_data of given type for a user for a room.
+
+ Args:
+ user_id(str): The user to get the account_data for.
+ room_id(str): The room to get the account_data for.
+ account_data_type (str): The account data type to get.
+ Returns:
+ A deferred of the room account_data for that type, or None if
+ there isn't any set.
+ """
+ def get_account_data_for_room_and_type_txn(txn):
+ content_json = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_account_data",
+ keyvalues={
+ "user_id": user_id,
+ "room_id": room_id,
+ "account_data_type": account_data_type,
+ },
+ retcol="content",
+ allow_none=True
+ )
+
+ return json.loads(content_json) if content_json else None
+
+ return self.runInteraction(
+ "get_account_data_for_room_and_type",
+ get_account_data_for_room_and_type_txn,
+ )
+
def get_all_updated_account_data(self, last_global_id, last_room_id,
current_id, limit):
"""Get all the client account_data that has changed on the server
@@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
+ @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+ def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+ ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", ignorer_user_id,
+ on_invalidate=cache_context.invalidate,
+ )
+ if not ignored_account_data:
+ defer.returnValue(False)
+
+ defer.returnValue(
+ ignored_user_id in ignored_account_data.get("ignored_users", {})
+ )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+ def __init__(self, db_conn, hs):
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+
+ super(AccountDataStore, self).__init__(db_conn, hs)
+
+ def get_max_account_data_stream_id(self):
+ """Get the current max stream id for the private user data stream
+
+ Returns:
+            int
+ """
+ return self._account_data_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
@@ -251,6 +343,10 @@ class AccountDataStore(SQLBaseStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
+ self.get_account_data_for_room.invalidate((user_id, room_id,))
+ self.get_account_data_for_room_and_type.prefill(
+ (user_id, room_id, account_data_type,), content,
+ )
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@@ -321,16 +417,3 @@ class AccountDataStore(SQLBaseStore):
"update_account_data_max_stream_id",
_update,
)
-
- @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
- def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
- ignored_account_data = yield self.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", ignorer_user_id,
- on_invalidate=cache_context.invalidate,
- )
- if not ignored_account_data:
- defer.returnValue(False)
-
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index d8c84b7141..12ea8a158c 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +18,9 @@ import re
import simplejson as json
from twisted.internet import defer
-from synapse.api.constants import Membership
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.events import EventsWorkerStore
from ._base import SQLBaseStore
@@ -46,17 +46,16 @@ def _make_exclusive_regex(services_cache):
return exclusive_user_regex
-class ApplicationServiceStore(SQLBaseStore):
-
+class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
- super(ApplicationServiceStore, self).__init__(db_conn, hs)
- self.hostname = hs.hostname
self.services_cache = load_appservices(
hs.hostname,
hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+ super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+
def get_app_services(self):
return self.services_cache
@@ -99,83 +98,30 @@ class ApplicationServiceStore(SQLBaseStore):
return service
return None
- def get_app_service_rooms(self, service):
- """Get a list of RoomsForUser for this application service.
-
- Application services may be "interested" in lots of rooms depending on
- the room ID, the room aliases, or the members in the room. This function
- takes all of these into account and returns a list of RoomsForUser which
- represent the entire list of room IDs that this application service
- wants to know about.
+ def get_app_service_by_id(self, as_id):
+ """Get the application service with the given appservice ID.
Args:
- service: The application service to get a room list for.
+ as_id (str): The application service ID.
Returns:
- A list of RoomsForUser.
+ synapse.appservice.ApplicationService or None.
"""
- return self.runInteraction(
- "get_app_service_rooms",
- self._get_app_service_rooms_txn,
- service,
- )
-
- def _get_app_service_rooms_txn(self, txn, service):
- # get all rooms matching the room ID regex.
- room_entries = self._simple_select_list_txn(
- txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
- )
- matching_room_list = set([
- r["room_id"] for r in room_entries if
- service.is_interested_in_room(r["room_id"])
- ])
-
- # resolve room IDs for matching room alias regex.
- room_alias_mappings = self._simple_select_list_txn(
- txn=txn, table="room_aliases", keyvalues=None,
- retcols=["room_id", "room_alias"]
- )
- matching_room_list |= set([
- r["room_id"] for r in room_alias_mappings if
- service.is_interested_in_alias(r["room_alias"])
- ])
-
- # get all rooms for every user for this AS. This is scoped to users on
- # this HS only.
- user_list = self._simple_select_list_txn(
- txn=txn, table="users", keyvalues=None, retcols=["name"]
- )
- user_list = [
- u["name"] for u in user_list if
- service.is_interested_in_user(u["name"])
- ]
- rooms_for_user_matching_user_id = set() # RoomsForUser list
- for user_id in user_list:
- # FIXME: This assumes this store is linked with RoomMemberStore :(
- rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
- txn=txn,
- user_id=user_id,
- membership_list=[Membership.JOIN]
- )
- rooms_for_user_matching_user_id |= set(rooms_for_user)
-
- # make RoomsForUser tuples for room ids and aliases which are not in the
- # main rooms_for_user_list - e.g. they are rooms which do not have AS
- # registered users in it.
- known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
- missing_rooms_for_user = [
- RoomsForUser(r, service.sender, "join") for r in
- matching_room_list if r not in known_room_ids
- ]
- rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
-
- return rooms_for_user_matching_user_id
+ for service in self.services_cache:
+ if service.id == as_id:
+ return service
+ return None
-class ApplicationServiceTransactionStore(SQLBaseStore):
+class ApplicationServiceStore(ApplicationServiceWorkerStore):
+ # This is currently empty due to there not being any AS storage functions
+ # that can't be run on the workers. Since this may change in future, and
+ # to keep consistency with the other stores, we keep this empty class for
+ # now.
+ pass
- def __init__(self, db_conn, hs):
- super(ApplicationServiceTransactionStore, self).__init__(db_conn, hs)
+class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
+ EventsWorkerStore):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
@@ -420,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
events = yield self._get_events(event_ids)
defer.returnValue((upper_bound, events))
+
+
+class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
+ # This is currently empty due to there not being any AS storage functions
+ # that can't be run on the workers. Since this may change in future, and
+ # to keep consistency with the other stores, we keep this empty class for
+ # now.
+ pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 8f3bff311a..8af325a9f5 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -242,6 +242,25 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
+ def register_noop_background_update(self, update_name):
+ """Register a noop handler for a background update.
+
+ This is useful when we previously did a background update, but no
+ longer wish to do the update. In this case the background update should
+ be removed from the schema delta files, but there may still be some
+ users who have the background update queued, so this method should
+ also be called to clear the update.
+
+ Args:
+ update_name (str): Name of update
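+
+        Example (illustrative only; the update name is hypothetical):
+
+            self.register_noop_background_update("some_retired_update")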
+ """
+ @defer.inlineCallbacks
+ def noop_update(progress, batch_size):
+ yield self._end_background_update(update_name)
+ defer.returnValue(1)
+
+ self.register_background_update_handler(update_name, noop_update)
+
def register_background_index_update(self, update_name, index_name,
table, columns, where_clause=None,
unique=False,
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 79e7c540ad..d0c0059757 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple(
)
-class DirectoryStore(SQLBaseStore):
-
+class DirectoryWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
""" Get's the room_id and server list for a given room_alias
@@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore):
RoomAliasMapping(room_id, room_alias.to_string(), servers)
)
+ def get_room_alias_creator(self, room_alias):
+ return self._simple_select_one_onecol(
+ table="room_aliases",
+ keyvalues={
+ "room_alias": room_alias,
+ },
+ retcol="creator",
+ desc="get_room_alias_creator",
+ allow_none=True
+ )
+
+ @cached(max_entries=5000)
+ def get_aliases_for_room(self, room_id):
+ return self._simple_select_onecol(
+ "room_aliases",
+ {"room_id": room_id},
+ "room_alias",
+ desc="get_aliases_for_room",
+ )
+
+
+class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers
@@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore):
)
defer.returnValue(ret)
- def get_room_alias_creator(self, room_alias):
- return self._simple_select_one_onecol(
- table="room_aliases",
- keyvalues={
- "room_alias": room_alias,
- },
- retcol="creator",
- desc="get_room_alias_creator",
- allow_none=True
- )
-
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
@@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
- self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):
@@ -160,17 +169,12 @@ class DirectoryStore(SQLBaseStore):
(room_alias.to_string(),)
)
- return room_id
-
- @cached(max_entries=5000)
- def get_aliases_for_room(self, room_id):
- return self._simple_select_onecol(
- "room_aliases",
- {"room_id": room_id},
- "room_alias",
- desc="get_aliases_for_room",
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (room_id,)
)
+ return room_id
+
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ txn.execute("SELECT nextval('state_group_id_seq')")
+ return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database
import struct
+import threading
class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
+ # The current max state_group, or None if we haven't looked
+ # in the DB yet.
+ self._current_state_group_id = None
+ self._current_state_group_id_lock = threading.Lock()
+
def check_database(self, txn):
pass
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+        # We do application locking here since, if we're using sqlite, we
+        # are a single-process synapse.
+ with self._current_state_group_id_lock:
+ if self._current_state_group_id is None:
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ self._current_state_group_id = txn.fetchone()[0]
+
+ self._current_state_group_id += 1
+ return self._current_state_group_id
+
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 55a05c59d5..00ee82d300 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -15,7 +15,10 @@
from twisted.internet import defer
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+
from synapse.api.errors import StoreError
from synapse.util.caches.descriptors import cached
from unpaddedbase64 import encode_base64
@@ -27,30 +30,8 @@ from Queue import PriorityQueue, Empty
logger = logging.getLogger(__name__)
-class EventFederationStore(SQLBaseStore):
- """ Responsible for storing and serving up the various graphs associated
- with an event. Including the main event graph and the auth chains for an
- event.
-
- Also has methods for getting the front (latest) and back (oldest) edges
- of the event graphs. These are used to generate the parents for new events
- and backfilling from another server respectively.
- """
-
- EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
-
- def __init__(self, db_conn, hs):
- super(EventFederationStore, self).__init__(db_conn, hs)
-
- self.register_background_update_handler(
- self.EVENT_AUTH_STATE_ONLY,
- self._background_delete_non_state_event_auth,
- )
-
- hs.get_clock().looping_call(
- self._delete_old_forward_extrem_cache, 60 * 60 * 1000
- )
-
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
+ SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@@ -228,88 +209,6 @@ class EventFederationStore(SQLBaseStore):
return int(min_depth) if min_depth is not None else None
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
- min_depth = self._get_min_depth_interaction(txn, room_id)
-
- if min_depth and depth >= min_depth:
- return
-
- self._simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={
- "room_id": room_id,
- },
- values={
- "min_depth": depth,
- },
- )
-
- def _handle_mult_prev_events(self, txn, events):
- """
- For the given event, update the event edges table and forward and
- backward extremities tables.
- """
- self._simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
- for ev in events
- for e_id, _ in ev.prev_events
- ],
- )
-
- self._update_backward_extremeties(txn, events)
-
- def _update_backward_extremeties(self, txn, events):
- """Updates the event_backward_extremities tables based on the new/updated
- events being persisted.
-
- This is called for new events *and* for events that were outliers, but
- are now being persisted as non-outliers.
-
- Forward extremities are handled when we first start persisting the events.
- """
- events_by_room = {}
- for ev in events:
- events_by_room.setdefault(ev.room_id, []).append(ev)
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
-
- txn.executemany(query, [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
- for ev in events for e_id, _ in ev.prev_events
- if not ev.internal_metadata.is_outlier()
- ])
-
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.executemany(
- query,
- [
- (ev.event_id, ev.room_id) for ev in events
- if not ev.internal_metadata.is_outlier()
- ]
- )
-
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -371,28 +270,6 @@ class EventFederationStore(SQLBaseStore):
get_forward_extremeties_for_room_txn
)
- def _delete_old_forward_extrem_cache(self):
- def _delete_old_forward_extrem_cache_txn(txn):
- # Delete entries older than a month, while making sure we don't delete
- # the only entries for a room.
- sql = ("""
- DELETE FROM stream_ordering_to_exterm
- WHERE
- room_id IN (
- SELECT room_id
- FROM stream_ordering_to_exterm
- WHERE stream_ordering > ?
- ) AND stream_ordering < ?
- """)
- txn.execute(
- sql,
- (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
- )
- return self.runInteraction(
- "_delete_old_forward_extrem_cache",
- _delete_old_forward_extrem_cache_txn
- )
-
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -522,6 +399,135 @@ class EventFederationStore(SQLBaseStore):
return event_results
+
+class EventFederationStore(EventFederationWorkerStore):
+ """ Responsible for storing and serving up the various graphs associated
+ with an event. Including the main event graph and the auth chains for an
+ event.
+
+ Also has methods for getting the front (latest) and back (oldest) edges
+ of the event graphs. These are used to generate the parents for new events
+ and backfilling from another server respectively.
+ """
+
+ EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
+ def __init__(self, db_conn, hs):
+ super(EventFederationStore, self).__init__(db_conn, hs)
+
+ self.register_background_update_handler(
+ self.EVENT_AUTH_STATE_ONLY,
+ self._background_delete_non_state_event_auth,
+ )
+
+ hs.get_clock().looping_call(
+ self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+ )
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+ if min_depth and depth >= min_depth:
+ return
+
+ self._simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={
+ "room_id": room_id,
+ },
+ values={
+ "min_depth": depth,
+ },
+ )
+
+ def _handle_mult_prev_events(self, txn, events):
+ """
+ For the given event, update the event edges table and forward and
+ backward extremities tables.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="event_edges",
+ values=[
+ {
+ "event_id": ev.event_id,
+ "prev_event_id": e_id,
+ "room_id": ev.room_id,
+ "is_state": False,
+ }
+ for ev in events
+ for e_id, _ in ev.prev_events
+ ],
+ )
+
+ self._update_backward_extremeties(txn, events)
+
+ def _update_backward_extremeties(self, txn, events):
+ """Updates the event_backward_extremities tables based on the new/updated
+ events being persisted.
+
+ This is called for new events *and* for events that were outliers, but
+ are now being persisted as non-outliers.
+
+ Forward extremities are handled when we first start persisting the events.
+ """
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
+
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
+
+ txn.executemany(query, [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events for e_id, _ in ev.prev_events
+ if not ev.internal_metadata.is_outlier()
+ ])
+
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [
+ (ev.event_id, ev.room_id) for ev in events
+ if not ev.internal_metadata.is_outlier()
+ ]
+ )
+
+ def _delete_old_forward_extrem_cache(self):
+ def _delete_old_forward_extrem_cache_txn(txn):
+ # Delete entries older than a month, while making sure we don't delete
+ # the only entries for a room.
+ sql = ("""
+ DELETE FROM stream_ordering_to_exterm
+ WHERE
+ room_id IN (
+ SELECT room_id
+ FROM stream_ordering_to_exterm
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
+ """)
+ txn.execute(
+ sql,
+ (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+ )
+ return self.runInteraction(
+ "_delete_old_forward_extrem_cache",
+ _delete_old_forward_extrem_cache_txn
+ )
+
def clean_room_for_join(self, room_id):
return self.runInteraction(
"clean_room_for_join",
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 575d710d5d..e78f8d0114 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer
from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -62,60 +63,28 @@ def _deserialize_action(actions, is_highlight):
return DEFAULT_NOTIF_ACTION
-class EventPushActionsStore(SQLBaseStore):
- EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
+class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
- super(EventPushActionsStore, self).__init__(db_conn, hs)
-
- self.register_background_index_update(
- self.EPA_HIGHLIGHT_INDEX,
- index_name="event_push_actions_u_highlight",
- table="event_push_actions",
- columns=["user_id", "stream_ordering"],
+ super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+
+ # These get correctly set by _find_stream_orderings_for_times_txn
+ self.stream_ordering_month_ago = None
+ self.stream_ordering_day_ago = None
+
+ cur = LoggingTransaction(
+ db_conn.cursor(),
+ name="_find_stream_orderings_for_times_txn",
+ database_engine=self.database_engine,
+ after_callbacks=[],
+ exception_callbacks=[],
)
+ self._find_stream_orderings_for_times_txn(cur)
+ cur.close()
- self.register_background_index_update(
- "event_push_actions_highlights_index",
- index_name="event_push_actions_highlights_index",
- table="event_push_actions",
- columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
- where_clause="highlight=1"
+ self.find_stream_orderings_looping_call = self._clock.looping_call(
+ self._find_stream_orderings_for_times, 10 * 60 * 1000
)
- self._doing_notif_rotation = False
- self._rotate_notif_loop = self._clock.looping_call(
- self._rotate_notifs, 30 * 60 * 1000
- )
-
- def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
- """
- Args:
- event: the event set actions for
- tuples: list of tuples of (user_id, actions)
- """
- values = []
- for uid, actions in tuples:
- is_highlight = 1 if _action_has_highlight(actions) else 0
-
- values.append({
- 'room_id': event.room_id,
- 'event_id': event.event_id,
- 'user_id': uid,
- 'actions': _serialize_action(actions, is_highlight),
- 'stream_ordering': event.internal_metadata.stream_ordering,
- 'topological_ordering': event.depth,
- 'notif': 1,
- 'highlight': is_highlight,
- })
-
- for uid, __ in tuples:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid)
- )
- self._simple_insert_many_txn(txn, "event_push_actions", values)
-
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
@@ -432,6 +401,280 @@ class EventPushActionsStore(SQLBaseStore):
# Now return the first `limit`
defer.returnValue(notifs[:limit])
+ def add_push_actions_to_staging(self, event_id, user_id_actions):
+ """Add the push actions for the event to the push action staging area.
+
+ Args:
+ event_id (str)
+            user_id_actions (dict[str, list[dict|str]]): A dictionary mapping
+ user_id to list of push actions, where an action can either be
+ a string or dict.
+
+ Returns:
+ Deferred
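+
+        Example (illustrative; the event ID, user IDs and actions are
+        made up):
+
+            yield store.add_push_actions_to_staging(
+                "$event_id:example.com",
+                {
+                    "@alice:example.com": ["notify"],
+                    "@bob:example.com": ["notify", {"set_tweak": "highlight"}],
+                },
+            )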
+ """
+
+ if not user_id_actions:
+ return
+
+ # This is a helper function for generating the necessary tuple that
+        # can be used to insert into the `event_push_actions_staging` table.
+ def _gen_entry(user_id, actions):
+ is_highlight = 1 if _action_has_highlight(actions) else 0
+ return (
+ event_id, # event_id column
+ user_id, # user_id column
+ _serialize_action(actions, is_highlight), # actions column
+ 1, # notif column
+ is_highlight, # highlight column
+ )
+
+ def _add_push_actions_to_staging_txn(txn):
+ # We don't use _simple_insert_many here to avoid the overhead
+ # of generating lists of dicts.
+
+ sql = """
+ INSERT INTO event_push_actions_staging
+ (event_id, user_id, actions, notif, highlight)
+ VALUES (?, ?, ?, ?, ?)
+ """
+
+ txn.executemany(sql, (
+ _gen_entry(user_id, actions)
+ for user_id, actions in user_id_actions.iteritems()
+ ))
+
+ return self.runInteraction(
+ "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+ )
+
+ def remove_push_actions_from_staging(self, event_id):
+ """Called if we failed to persist the event to ensure that stale push
+ actions don't build up in the DB
+
+ Args:
+ event_id (str)
+ """
+
+ return self._simple_delete(
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event_id,
+ },
+ desc="remove_push_actions_from_staging",
+ )
+
+ @defer.inlineCallbacks
+ def _find_stream_orderings_for_times(self):
+ yield self.runInteraction(
+ "_find_stream_orderings_for_times",
+ self._find_stream_orderings_for_times_txn
+ )
+
+ def _find_stream_orderings_for_times_txn(self, txn):
+ logger.info("Searching for stream ordering 1 month ago")
+ self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 month ago: it's %d",
+ self.stream_ordering_month_ago
+ )
+ logger.info("Searching for stream ordering 1 day ago")
+ self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+ txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+ )
+ logger.info(
+ "Found stream ordering 1 day ago: it's %d",
+ self.stream_ordering_day_ago
+ )
+
+ def find_first_stream_ordering_after_ts(self, ts):
+ """Gets the stream ordering corresponding to a given timestamp.
+
+ Specifically, finds the stream_ordering of the first event that was
+ received on or after the timestamp. This is done by a binary search on
+        the events table (there is no index on received_ts), so it is
+        relatively slow.
+
+ Args:
+ ts (int): timestamp in millis
+
+ Returns:
+ Deferred[int]: stream ordering of the first event received on/after
+ the timestamp
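+
+        Example (illustrative; `store` and `now_ms` are assumptions):
+
+            # stream ordering of the first event in the last 24 hours
+            cutoff = yield store.find_first_stream_ordering_after_ts(
+                now_ms - 24 * 60 * 60 * 1000,
+            )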
+ """
+ return self.runInteraction(
+ "_find_first_stream_ordering_after_ts_txn",
+ self._find_first_stream_ordering_after_ts_txn,
+ ts,
+ )
+
+ @staticmethod
+ def _find_first_stream_ordering_after_ts_txn(txn, ts):
+ """
+ Find the stream_ordering of the first event that was received on or
+ after a given timestamp. This is relatively slow as there is no index
+ on received_ts but we can then use this to delete push actions before
+ this.
+
+ received_ts must necessarily be in the same order as stream_ordering
+ and stream_ordering is indexed, so we manually binary search using
+ stream_ordering
+
+ Args:
+ txn (twisted.enterprise.adbapi.Transaction):
+ ts (int): timestamp to search for
+
+ Returns:
+ int: stream ordering
+ """
+ txn.execute("SELECT MAX(stream_ordering) FROM events")
+ max_stream_ordering = txn.fetchone()[0]
+
+ if max_stream_ordering is None:
+ return 0
+
+ # We want the first stream_ordering in which received_ts is greater
+ # than or equal to ts. Call this point X.
+ #
+        # We maintain the invariant:
+ #
+ # range_start <= X <= range_end
+ #
+ range_start = 0
+ range_end = max_stream_ordering + 1
+
+ # Given a stream_ordering, look up the timestamp at that
+ # stream_ordering.
+ #
+ # The array may be sparse (we may be missing some stream_orderings).
+ # We treat the gaps as the same as having the same value as the
+ # preceding entry, because we will pick the lowest stream_ordering
+ # which satisfies our requirement of received_ts >= ts.
+ #
+ # For example, if our array of events indexed by stream_ordering is
+ # [10, <none>, 20], we should treat this as being equivalent to
+ # [10, 10, 20].
+ #
+ sql = (
+ "SELECT received_ts FROM events"
+ " WHERE stream_ordering <= ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT 1"
+ )
+
+ while range_end - range_start > 0:
+ middle = (range_end + range_start) // 2
+ txn.execute(sql, (middle,))
+ row = txn.fetchone()
+ if row is None:
+ # no rows with stream_ordering<=middle
+ range_start = middle + 1
+ continue
+
+ middle_ts = row[0]
+ if ts > middle_ts:
+ # we got a timestamp lower than the one we were looking for.
+ # definitely need to look higher: X > middle.
+ range_start = middle + 1
+ else:
+ # we got a timestamp higher than (or the same as) the one we
+ # were looking for. We aren't yet sure about the point we
+ # looked up, but we can be sure that X <= middle.
+ range_end = middle
+
+ return range_end
+
+
+class EventPushActionsStore(EventPushActionsWorkerStore):
+ EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+ def __init__(self, db_conn, hs):
+ super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ self.EPA_HIGHLIGHT_INDEX,
+ index_name="event_push_actions_u_highlight",
+ table="event_push_actions",
+ columns=["user_id", "stream_ordering"],
+ )
+
+ self.register_background_index_update(
+ "event_push_actions_highlights_index",
+ index_name="event_push_actions_highlights_index",
+ table="event_push_actions",
+ columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+ where_clause="highlight=1"
+ )
+
+ self._doing_notif_rotation = False
+ self._rotate_notif_loop = self._clock.looping_call(
+ self._rotate_notifs, 30 * 60 * 1000
+ )
+
+ def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
+ all_events_and_contexts):
+ """Handles moving push actions from staging table to main
+ event_push_actions table for all events in `events_and_contexts`.
+
+ Also ensures that all events in `all_events_and_contexts` are removed
+ from the push action staging area.
+
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+                we've already persisted, etc., that wouldn't appear in
+                events_and_contexts.
+ """
+
+ sql = """
+ INSERT INTO event_push_actions (
+ room_id, event_id, user_id, actions, stream_ordering,
+ topological_ordering, notif, highlight
+ )
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ FROM event_push_actions_staging
+ WHERE event_id = ?
+ """
+
+ if events_and_contexts:
+ txn.executemany(sql, (
+ (
+ event.room_id, event.internal_metadata.stream_ordering,
+ event.depth, event.event_id,
+ )
+ for event, _ in events_and_contexts
+ ))
+
+ for event, _ in events_and_contexts:
+ user_ids = self._simple_select_onecol_txn(
+ txn,
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event.event_id,
+ },
+ retcol="user_id",
+ )
+
+ for uid in user_ids:
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid,)
+ )
+
+ # Now we delete the staging area for *all* events that were being
+ # persisted.
+ txn.executemany(
+ "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+ (
+ (event.event_id,)
+ for event, _ in all_events_and_contexts
+ )
+ )
+
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50,
only_highlight=False):
@@ -551,69 +794,6 @@ class EventPushActionsStore(SQLBaseStore):
""", (room_id, user_id, stream_ordering))
@defer.inlineCallbacks
- def _find_stream_orderings_for_times(self):
- yield self.runInteraction(
- "_find_stream_orderings_for_times",
- self._find_stream_orderings_for_times_txn
- )
-
- def _find_stream_orderings_for_times_txn(self, txn):
- logger.info("Searching for stream ordering 1 month ago")
- self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
- txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
- )
- logger.info(
- "Found stream ordering 1 month ago: it's %d",
- self.stream_ordering_month_ago
- )
- logger.info("Searching for stream ordering 1 day ago")
- self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
- txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
- )
- logger.info(
- "Found stream ordering 1 day ago: it's %d",
- self.stream_ordering_day_ago
- )
-
- def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
- """
- Find the stream_ordering of the first event that was received after
- a given timestamp. This is relatively slow as there is no index on
- received_ts but we can then use this to delete push actions before
- this.
-
- received_ts must necessarily be in the same order as stream_ordering
- and stream_ordering is indexed, so we manually binary search using
- stream_ordering
- """
- txn.execute("SELECT MAX(stream_ordering) FROM events")
- max_stream_ordering = txn.fetchone()[0]
-
- if max_stream_ordering is None:
- return 0
-
- range_start = 0
- range_end = max_stream_ordering
-
- sql = (
- "SELECT received_ts FROM events"
- " WHERE stream_ordering > ?"
- " ORDER BY stream_ordering"
- " LIMIT 1"
- )
-
- while range_end - range_start > 1:
- middle = int((range_end + range_start) / 2)
- txn.execute(sql, (middle,))
- middle_ts = txn.fetchone()[0]
- if ts > middle_ts:
- range_start = middle
- else:
- range_end = middle
-
- return range_end
-
- @defer.inlineCallbacks
def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 73658a9927..85ce6bea1a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,23 +13,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
-from twisted.internet import defer, reactor
+from synapse.storage.events_worker import EventsWorkerStore
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+from twisted.internet import defer
+
+from synapse.events import USE_FROZEN_DICTS
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+ PreserveLoggingContext, make_deferred_yieldable
)
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.types import get_domain_from_id
from canonicaljson import encode_canonical_json
@@ -61,16 +61,6 @@ def encode_json(json_object):
return json.dumps(json_object, ensure_ascii=False)
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
-
-
class _EventPeristenceQueue(object):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
@@ -109,7 +99,7 @@ class _EventPeristenceQueue(object):
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred())
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@@ -145,18 +135,25 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
+ # handle_queue_loop runs in the sentinel logcontext, so
+ # there is no need to preserve_fn when running the
+ # callbacks on the deferred.
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
- except Exception as e:
- item.deferred.errback(e)
+ except Exception:
+ item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
- preserve_fn(handle_queue_loop)()
+        # set handle_queue_loop off in the background. We don't want to
+ # attribute work done in it to the current request, so we drop the
+ # logcontext altogether.
+ with PreserveLoggingContext():
+ handle_queue_loop()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -192,13 +189,12 @@ def _retry_on_integrity_error(func):
return f
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
- self._clock = hs.get_clock()
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
@@ -229,6 +225,8 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
@@ -284,10 +282,11 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
- yield self._persist_events(
- item.events_and_contexts,
- backfilled=item.backfilled,
- )
+ with Measure(self._clock, "persist_events"):
+ yield self._persist_events(
+ item.events_and_contexts,
+ backfilled=item.backfilled,
+ )
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@@ -334,8 +333,20 @@ class EventsStore(SQLBaseStore):
# NB: Assumes that we are only persisting events for one room
# at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events
current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where each entry is
+ # a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
@@ -378,11 +389,20 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state:
continue
- state = yield self._calculate_state_delta(
- room_id, ev_ctx_rm, new_latest_event_ids
+ logger.info(
+ "Calculating state delta for room %s", room_id,
)
- if state:
- current_state_for_room[room_id] = state
+ current_state = yield self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm, new_latest_event_ids,
+ )
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ if delta is not None:
+ state_delta_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -390,7 +410,7 @@ class EventsStore(SQLBaseStore):
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
- current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
@@ -407,7 +427,7 @@ class EventsStore(SQLBaseStore):
event_counter.inc(event.type, origin_type, origin_entity)
- for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+ for room_id, new_state in current_state_for_room.iteritems():
self.get_current_state_ids.prefill(
(room_id, ), new_state
)
@@ -459,22 +479,31 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
- """Calculate the new state deltas for a room.
+ def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+ """Calculate the current state dict after adding some new events to
+ a room
- Assumes that we are only persisting events for one room at a time.
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
Returns:
- 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
- i.e. (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert. `new_state` is the full set of state.
- May return None if there are no changes to be applied.
+ Deferred[dict[(str,str), str]|None]:
+ None if there are no changes to the room state, or
+                a dict of (type, state_key) -> event_id.
"""
- # Now we need to work out the different state sets for
- # each state extremities
- state_sets = []
- state_groups = set()
+
+ if not new_latest_event_ids:
+ defer.returnValue({})
+
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups = {}
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
@@ -485,16 +514,19 @@ class EventsStore(SQLBaseStore):
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (event_id, ),
+ )
+
# If we've already seen the state group don't bother adding
# it to the state sets again
if ctx.state_group not in state_groups:
- state_sets.append(ctx.current_state_ids)
+ state_groups[ctx.state_group] = ctx.current_state_ids
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
- if ctx.state_group:
- # Add this as a seen state group (if it has a state
- # group)
- state_groups.add(ctx.state_group)
break
else:
# If we couldn't find it, then we'll need to pull
@@ -502,60 +534,50 @@ class EventsStore(SQLBaseStore):
was_updated = True
missing_event_ids.append(event_id)
+ if not was_updated:
+ return
+
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
- groups = set(event_to_groups.itervalues()) - state_groups
+ groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
if groups:
group_to_state = yield self._get_state_for_groups(groups)
- state_sets.extend(group_to_state.itervalues())
+ state_groups.update(group_to_state)
- if not new_latest_event_ids:
- current_state = {}
- elif was_updated:
- if len(state_sets) == 1:
- # If there is only one state set, then we know what the current
- # state is.
- current_state = state_sets[0]
- else:
- # We work out the current state by passing the state sets to the
- # state resolution algorithm. It may ask for some events, including
- # the events we have yet to persist, so we need a slightly more
- # complicated event lookup function than simply looking the events
- # up in the db.
- events_map = {ev.event_id: ev for ev, _ in events_context}
-
- @defer.inlineCallbacks
- def get_events(ev_ids):
- # We get the events by first looking at the list of events we
- # are trying to persist, and then fetching the rest from the DB.
- db = []
- to_return = {}
- for ev_id in ev_ids:
- ev = events_map.get(ev_id, None)
- if ev:
- to_return[ev_id] = ev
- else:
- db.append(ev_id)
-
- if db:
- evs = yield self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- )
- to_return.update(evs)
- defer.returnValue(to_return)
+ if len(state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ defer.returnValue(state_groups.values()[0])
- current_state = yield resolve_events(
- state_sets,
- state_map_factory=get_events,
- )
- else:
- return
+ def get_events(ev_ids):
+ return self.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups, events_map, get_events
+ )
+
+ defer.returnValue(res.state)
+
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ 2-tuple (to_delete, to_insert) where both are state dicts,
+ i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+ first be deleted from current_state_events, `to_insert` are entries
+ to insert.
+ """
existing_state = yield self.get_current_state_ids(room_id)
existing_events = set(existing_state.itervalues())
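
The diff computation itself is elided by the hunk below; in spirit it reduces to a comparison of the two maps, as in this hypothetical helper (the real code also special-cases events being inserted in the same batch):

    # hypothetical sketch: derive (to_delete, to_insert) from the old and
    # new (type, state_key) -> event_id maps
    def _diff_state(existing_state, current_state):
        to_delete = {
            key: ev_id for key, ev_id in existing_state.iteritems()
            if current_state.get(key) != ev_id
        }
        to_insert = {
            key: ev_id for key, ev_id in current_state.iteritems()
            if existing_state.get(key) != ev_id
        }
        return to_delete, to_insert
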
@@ -575,67 +597,11 @@ class EventsStore(SQLBaseStore):
if ev_id in events_to_insert
}
- defer.returnValue((to_delete, to_insert, current_state))
-
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- events = yield self._get_events(
- [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- if not events and not allow_none:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
-
- defer.returnValue(events[0] if events else None)
-
- @defer.inlineCallbacks
- def get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- """Get events from the database
-
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
-
- Returns:
- Deferred : Dict from event_id to event.
- """
- events = yield self._get_events(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- defer.returnValue({e.event_id: e for e in events})
+ defer.returnValue((to_delete, to_insert))
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
- delete_existing=False, current_state_for_room={},
+ delete_existing=False, state_delta_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables.
@@ -651,7 +617,7 @@ class EventsStore(SQLBaseStore):
delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to
IntegrityError.
- current_state_for_room (dict[str, (list[str], list[str])]):
+ state_delta_for_room (dict[str, (list[str], list[str])]):
The current-state delta for each room. For each room, a tuple
(to_delete, to_insert), being a list of event ids to be removed
from the current state, and a list of event ids to be added to
@@ -661,9 +627,11 @@ class EventsStore(SQLBaseStore):
list of the event ids which are the forward extremities.
"""
+ all_events_and_contexts = events_and_contexts
+
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+ self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
self._update_forward_extremities_txn(
txn,
@@ -707,9 +675,8 @@ class EventsStore(SQLBaseStore):
events_and_contexts=events_and_contexts,
)
- # Insert into the state_groups, state_groups_state, and
- # event_to_state_groups tables.
- self._store_mult_state_groups_txn(txn, events_and_contexts)
+ # Insert into event_to_state_groups.
+ self._store_event_state_mappings_txn(txn, events_and_contexts)
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -724,12 +691,13 @@ class EventsStore(SQLBaseStore):
self._update_metadata_tables_txn(
txn,
events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
backfilled=backfilled,
)
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems():
- to_delete, to_insert, _ = current_state_tuple
+ to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
@@ -786,7 +754,7 @@ class EventsStore(SQLBaseStore):
for member in members_changed:
self._invalidate_cache_and_stream(
- txn, self.get_rooms_for_user, (member,)
+ txn, self.get_rooms_for_user_with_stream_ordering, (member,)
)
for host in set(get_domain_from_id(u) for u in members_changed):
@@ -944,10 +912,9 @@ class EventsStore(SQLBaseStore):
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
- # insert into the state_group, state_groups_state and
- # event_to_state_groups tables.
+ # insert into event_to_state_groups.
try:
- self._store_mult_state_groups_txn(txn, ((event, context),))
+ self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception:
logger.exception("")
raise
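
`_store_event_state_mappings_txn` is defined elsewhere in this PR; the assumption here is that it only records each event's state group, since the state_groups and state_groups_state rows are now written by the state store up front. A minimal sketch under that assumption:

    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
        # assumed behaviour: only the event -> state_group mapping is
        # written here; outliers have no state and are skipped
        self._simple_insert_many_txn(
            txn,
            table="event_to_state_groups",
            values=[
                {"event_id": event.event_id, "state_group": context.state_group}
                for event, context in events_and_contexts
                if not event.internal_metadata.is_outlier()
            ],
        )
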
@@ -1122,27 +1089,33 @@ class EventsStore(SQLBaseStore):
ec for ec in events_and_contexts if ec[0] not in to_remove
]
- def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled):
+ def _update_metadata_tables_txn(self, txn, events_and_contexts,
+ all_events_and_contexts, backfilled):
"""Update all the miscellaneous tables for new events
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+ we've already persisted, etc., that wouldn't appear in
+ events_and_contexts.
backfilled (bool): True if the events were backfilled
"""
+ # Insert all the push actions into the event_push_actions table.
+ self._set_push_actions_for_event_and_users_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ )
+
if not events_and_contexts:
# nothing to do here
return
for event, context in events_and_contexts:
- # Insert all the push actions into the event_push_actions table.
- if context.push_actions:
- self._set_push_actions_for_event_and_users_txn(
- txn, event, context.push_actions
- )
-
if event.type == EventTypes.Redaction and event.redacts is not None:
# Remove the entries in the event_push_actions table for the
# redacted event.
@@ -1347,292 +1320,6 @@ class EventsStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
- if not event_ids:
- defer.returnValue([])
-
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids,
- allow_rejected=allow_rejected,
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- allow_rejected=allow_rejected,
- )
-
- event_entry_map.update(missing_events)
-
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
-
- def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
-
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
-
- Args:
- events (list(str)): list of event_ids to fetch
- allow_rejected (bool): Whether to teturn events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
-
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
- """
- event_map = {}
-
- for event_id in events:
- ret = self._get_event_cache.get(
- (event_id,), None,
- update_metrics=update_metrics,
- )
- if not ret:
- continue
-
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
-
- return event_map
-
- def _do_fetch(self, conn):
- """Takes a database connection and waits for requests for events from
- the _event_fetch_list queue.
- """
- event_list = []
- i = 0
- while True:
- try:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- event_id_lists = zip(*event_list)[0]
- event_ids = [
- item for sublist in event_id_lists for item in sublist
- ]
-
- rows = self._new_transaction(
- conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
- )
-
- row_dict = {
- r["event_id"]: r
- for r in rows
- }
-
- # We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
- except Exception:
- logger.exception("Failed to callback")
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list, row_dict)
- except Exception as e:
- logger.exception("do_fetch")
-
- # We only want to resolve deferreds from the main thread
- def fire(evs):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(e)
-
- if event_list:
- with PreserveLoggingContext():
- reactor.callFromThread(fire, event_list)
-
- @defer.inlineCallbacks
- def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
- """Fetches events from the database using the _event_fetch_list. This
- allows batch and bulk fetching of events - it allows us to fetch events
- without having to create a new transaction for each request for events.
- """
- if not events:
- defer.returnValue({})
-
- events_d = defer.Deferred()
- with self._event_fetch_lock:
- self._event_fetch_list.append(
- (events, events_d)
- )
-
- self._event_fetch_lock.notify()
-
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- should_start = True
- else:
- should_start = False
-
- if should_start:
- with PreserveLoggingContext():
- self.runWithConnection(
- self._do_fetch
- )
-
- logger.debug("Loading %d events", len(events))
- with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = yield make_deferred_yieldable(defer.gatherResults(
- [
- preserve_fn(self._get_event_from_row)(
- row["internal_metadata"], row["json"], row["redacts"],
- rejected_reason=row["rejects"],
- )
- for row in rows
- ],
- consumeErrors=True
- ))
-
- defer.returnValue({
- e.event.event_id: e
- for e in res if e
- })
-
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) / N):
- evs = events[i * N:(i + 1) * N]
- if not evs:
- break
-
- sql = (
- "SELECT "
- " e.event_id as event_id, "
- " e.internal_metadata,"
- " e.json,"
- " r.redacts as redacts,"
- " rej.event_id as rejects "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
-
- txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
-
- return rows
-
- @defer.inlineCallbacks
- def _get_event_from_row(self, internal_metadata, js, redacted,
- rejected_reason=None):
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
- original_ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = None
- if redacted:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id,
- check_redacted=False,
- allow_none=True,
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- cache_entry = _EventCacheEntry(
- event=original_ev,
- redacted_event=redacted_event,
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- defer.returnValue(cache_entry)
-
- @defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -2017,16 +1704,32 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
- def delete_old_state(self, room_id, topological_ordering):
- return self.runInteraction(
- "delete_old_state",
- self._delete_old_state_txn, room_id, topological_ordering
- )
+ def purge_history(
+ self, room_id, topological_ordering, delete_local_events,
+ ):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
+
+ topological_ordering (int):
+ minimum topo ordering to preserve
- def _delete_old_state_txn(self, txn, room_id, topological_ordering):
- """Deletes old room state
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
"""
+ return self.runInteraction(
+ "purge_history",
+ self._purge_history_txn, room_id, topological_ordering,
+ delete_local_events,
+ )
+
+ def _purge_history_txn(
+ self, txn, room_id, topological_ordering, delete_local_events,
+ ):
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -2047,6 +1750,30 @@ class EventsStore(SQLBaseStore):
# state_groups
# state_groups_state
+ # we will build a temporary table listing the events so that we don't
+ # have to keep shovelling the list back and forth across the
+ # connection. Annoyingly the python sqlite driver commits the
+ # transaction on CREATE, so let's do this first.
+ #
+ # furthermore, we might already have the table from a previous (failed)
+ # purge attempt, so let's drop the table first.
+
+ txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+ txn.execute(
+ "CREATE TEMPORARY TABLE events_to_purge ("
+ " event_id TEXT NOT NULL,"
+ " should_delete BOOLEAN NOT NULL"
+ ")"
+ )
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)",
+ )
+
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -2067,42 +1794,48 @@ class EventsStore(SQLBaseStore):
400, "topological_ordering is greater than forward extremeties"
)
- logger.debug("[purge] looking for events to delete")
+ logger.info("[purge] looking for events to delete")
+
+ should_delete_expr = "state_key IS NULL"
+ should_delete_params = ()
+ if not delete_local_events:
+ should_delete_expr += " AND event_id NOT LIKE ?"
+ should_delete_params += ("%:" + self.hs.hostname, )
+
+ should_delete_params += (room_id, topological_ordering)
txn.execute(
- "SELECT event_id, state_key FROM events"
- " LEFT JOIN state_events USING (room_id, event_id)"
- " WHERE room_id = ? AND topological_ordering < ?",
- (room_id, topological_ordering,)
+ "INSERT INTO events_to_purge"
+ " SELECT event_id, %s"
+ " FROM events AS e LEFT JOIN state_events USING (event_id)"
+ " WHERE e.room_id = ? AND topological_ordering < ?" % (
+ should_delete_expr,
+ ),
+ should_delete_params,
+ )
+ txn.execute(
+ "SELECT event_id, should_delete FROM events_to_purge"
)
event_rows = txn.fetchall()
-
- to_delete = [
- (event_id,) for event_id, state_key in event_rows
- if state_key is None and not self.hs.is_mine_id(event_id)
- ]
logger.info(
- "[purge] found %i events before cutoff, of which %i are remote"
- " non-state events to delete", len(event_rows), len(to_delete))
-
- for event_id, state_key in event_rows:
- txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+ "[purge] found %i events before cutoff, of which %i can be deleted",
+ len(event_rows), sum(1 for e in event_rows if e[1]),
+ )
- logger.debug("[purge] Finding new backward extremities")
+ logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding
# all events that point to events that are to be purged
txn.execute(
- "SELECT DISTINCT e.event_id FROM events as e"
- " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
- " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
- " WHERE e.room_id = ? AND e.topological_ordering < ?"
- " AND e2.topological_ordering >= ?",
- (room_id, topological_ordering, topological_ordering)
+ "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+ " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+ " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
+ " WHERE e2.topological_ordering >= ?",
+ (topological_ordering, )
)
new_backwards_extrems = txn.fetchall()
- logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+ logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?",
@@ -2118,7 +1851,7 @@ class EventsStore(SQLBaseStore):
]
)
- logger.debug("[purge] finding redundant state groups")
+ logger.info("[purge] finding redundant state groups")
# Get all state groups that are only referenced by events that are
# to be deleted.
@@ -2126,24 +1859,23 @@ class EventsStore(SQLBaseStore):
"SELECT state_group FROM event_to_state_groups"
" INNER JOIN events USING (event_id)"
" WHERE state_group IN ("
- " SELECT DISTINCT state_group FROM events"
+ " SELECT DISTINCT state_group FROM events_to_purge"
" INNER JOIN event_to_state_groups USING (event_id)"
- " WHERE room_id = ? AND topological_ordering < ?"
" )"
" GROUP BY state_group HAVING MAX(topological_ordering) < ?",
- (room_id, topological_ordering, topological_ordering)
+ (topological_ordering, )
)
state_rows = txn.fetchall()
- logger.debug("[purge] found %i redundant state groups", len(state_rows))
+ logger.info("[purge] found %i redundant state groups", len(state_rows))
# make a set of the redundant state groups, so that we can look them up
# efficiently
state_groups_to_delete = set([sg for sg, in state_rows])
# Now we get all the state groups that rely on these state groups
- logger.debug("[purge] finding state groups which depend on redundant"
- " state groups")
+ logger.info("[purge] finding state groups which depend on redundant"
+ " state groups")
remaining_state_groups = []
for i in xrange(0, len(state_rows), 100):
chunk = [sg for sg, in state_rows[i:i + 100]]
@@ -2168,7 +1900,7 @@ class EventsStore(SQLBaseStore):
# Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions.
for sg in remaining_state_groups:
- logger.debug("[purge] de-delta-ing remaining state group %s", sg)
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None
)
@@ -2205,7 +1937,7 @@ class EventsStore(SQLBaseStore):
],
)
- logger.debug("[purge] removing redundant state groups")
+ logger.info("[purge] removing redundant state groups")
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
state_rows
@@ -2215,18 +1947,15 @@ class EventsStore(SQLBaseStore):
state_rows
)
- # Delete all non-state
- logger.debug("[purge] removing events from event_to_state_groups")
- txn.executemany(
- "DELETE FROM event_to_state_groups WHERE event_id = ?",
- [(event_id,) for event_id, _ in event_rows]
- )
-
- logger.debug("[purge] updating room_depth")
+ logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
- "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (topological_ordering, room_id,)
+ "DELETE FROM event_to_state_groups "
+ "WHERE event_id IN (SELECT event_id from events_to_purge)"
)
+ for event_id, _ in event_rows:
+ txn.call_after(self._get_state_group_for_event.invalidate, (
+ event_id,
+ ))
# Delete all remote non-state events
for table in (
@@ -2238,28 +1967,60 @@ class EventsStore(SQLBaseStore):
"event_edge_hashes",
"event_edges",
"event_forward_extremities",
- "event_push_actions",
"event_reference_hashes",
"event_search",
"event_signatures",
"rejections",
):
- logger.debug("[purge] removing remote non-state events from %s", table)
+ logger.info("[purge] removing events from %s", table)
- txn.executemany(
- "DELETE FROM %s WHERE event_id = ?" % (table,),
- to_delete
+ txn.execute(
+ "DELETE FROM %s WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ )
+
+ # event_push_actions lacks an index on event_id, and has one on
+ # (room_id, event_id) instead.
+ for table in (
+ "event_push_actions",
+ ):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ (room_id, )
)
# Mark all state and own events as outliers
- logger.debug("[purge] marking remaining events as outliers")
- txn.executemany(
+ logger.info("[purge] marking remaining events as outliers")
+ txn.execute(
"UPDATE events SET outlier = ?"
- " WHERE event_id = ?",
- [
- (True, event_id,) for event_id, state_key in event_rows
- if state_key is not None or self.hs.is_mine_id(event_id)
- ]
+ " WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge "
+ " WHERE NOT should_delete"
+ ")",
+ (True,),
+ )
+
+ # synapse tries to take out an exclusive lock on room_depth whenever it
+ # persists events (because of the upsert), and once we run this update,
+ # we will block that for the rest of our transaction.
+ #
+ # So, let's stick it at the end so that we don't block event
+ # persistence.
+ logger.info("[purge] updating room_depth")
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (topological_ordering, room_id,)
+ )
+
+ # finally, drop the temp table. this will commit the txn in sqlite,
+ # so make sure it really is the last statement.
+ txn.execute(
+ "DROP TABLE events_to_purge"
)
logger.info("[purge] done")
@@ -2272,7 +2033,7 @@ class EventsStore(SQLBaseStore):
to_2, so_2 = yield self._get_event_ordering(event_id2)
defer.returnValue((to_1, so_1) > (to_2, so_2))
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
res = yield self._simple_select_one(
table="events",
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..2e23dd78ba
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,395 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+ preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import simplejson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase # noqa: F401
+from synapse.events.snapshot import EventContext # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_event(self, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False,
+ allow_none=False):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id (str): The event_id of the event to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+ allow_none (bool): If True, return None if no event found, if
+ False throw an exception.
+
+ Returns:
+ Deferred : A FrozenEvent.
+ """
+ events = yield self._get_events(
+ [event_id],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ if not events and not allow_none:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ defer.returnValue(events[0] if events else None)
+
+ @defer.inlineCallbacks
+ def get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred : Dict from event_id to event.
+ """
+ events = yield self._get_events(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ defer.returnValue({e.event_id: e for e in events})
+
+ @defer.inlineCallbacks
+ def _get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not event_ids:
+ defer.returnValue([])
+
+ event_id_list = event_ids
+ event_ids = set(event_ids)
+
+ event_entry_map = self._get_events_from_cache(
+ event_ids,
+ allow_rejected=allow_rejected,
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ allow_rejected=allow_rejected,
+ )
+
+ event_entry_map.update(missing_events)
+
+ events = []
+ for event_id in event_id_list:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if allow_rejected or not entry.event.rejected_reason:
+ if check_redacted and entry.redacted_event:
+ event = entry.redacted_event
+ else:
+ event = entry.event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ defer.returnValue(events)
+
+ def _invalidate_get_event_cache(self, event_id):
+ self._get_event_cache.invalidate((event_id,))
+
+ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+ """Fetch events from the caches
+
+ Args:
+ events (list(str)): list of event_ids to fetch
+ allow_rejected (bool): Whether to return events that were rejected
+ update_metrics (bool): Whether to update the cache hit ratio metrics
+
+ Returns:
+ dict of event_id -> _EventCacheEntry for each event_id in cache. If
+ allow_rejected is `False` then there will still be an entry but it
+ will be `None`
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = self._get_event_cache.get(
+ (event_id,), None,
+ update_metrics=update_metrics,
+ )
+ if not ret:
+ continue
+
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
+ event_list = []
+ i = 0
+ while True:
+ try:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ event_id_lists = zip(*event_list)[0]
+ event_ids = [
+ item for sublist in event_id_lists for item in sublist
+ ]
+
+ rows = self._new_transaction(
+ conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+ )
+
+ row_dict = {
+ r["event_id"]: r
+ for r in rows
+ }
+
+ # We only want to resolve deferreds from the main thread
+ def fire(lst, res):
+ for ids, d in lst:
+ if not d.called:
+ try:
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
+ except Exception:
+ logger.exception("Failed to callback")
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list, row_dict)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(e)
+
+ if event_list:
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list)
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+ """
+ if not events:
+ defer.returnValue({})
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append(
+ (events, events_d)
+ )
+
+ self._event_fetch_lock.notify()
+
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ with PreserveLoggingContext():
+ self.runWithConnection(
+ self._do_fetch
+ )
+
+ logger.debug("Loading %d events", len(events))
+ with PreserveLoggingContext():
+ rows = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+ if not allow_rejected:
+ rows[:] = [r for r in rows if not r["rejects"]]
+
+ res = yield make_deferred_yieldable(defer.gatherResults(
+ [
+ preserve_fn(self._get_event_from_row)(
+ row["internal_metadata"], row["json"], row["redacts"],
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ],
+ consumeErrors=True
+ ))
+
+ defer.returnValue({
+ e.event.event_id: e
+ for e in res if e
+ })
+
+ def _fetch_event_rows(self, txn, events):
+ rows = []
+ N = 200
+ for i in range(1 + len(events) / N):
+ evs = events[i * N:(i + 1) * N]
+ if not evs:
+ break
+
+ sql = (
+ "SELECT "
+ " e.event_id as event_id, "
+ " e.internal_metadata,"
+ " e.json,"
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
+ " FROM event_json as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"] * len(evs)),)
+
+ txn.execute(sql, evs)
+ rows.extend(self.cursor_to_dict(txn))
+
+ return rows
+
+ @defer.inlineCallbacks
+ def _get_event_from_row(self, internal_metadata, js, redacted,
+ rejected_reason=None):
+ with Measure(self._clock, "_get_event_from_row"):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = yield self._simple_select_one_onecol(
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ desc="_get_event_from_row_rejected_reason",
+ )
+
+ original_ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ redacted_event = None
+ if redacted:
+ redacted_event = prune_event(original_ev)
+
+ redaction_id = yield self._simple_select_one_onecol(
+ table="redactions",
+ keyvalues={"redacts": redacted_event.event_id},
+ retcol="event_id",
+ desc="_get_event_from_row_redactions",
+ )
+
+ redacted_event.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
+
+ because = yield self.get_event(
+ redaction_id,
+ check_redacted=False,
+ allow_none=True,
+ )
+
+ if because:
+ # It's fine to add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = because
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev,
+ redacted_event=redacted_event,
+ )
+
+ self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+ defer.returnValue(cache_entry)
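
The new file follows the pattern applied across this PR: pure read paths move into a `*WorkerStore` base class that worker processes can load, while writes stay in the master-only subclass. Schematically (a sketch, not the literal class bodies):

    class EventsWorkerStore(SQLBaseStore):
        # read-only methods: safe on any worker process
        def get_event(self, event_id, **kwargs):
            pass  # elided

    class EventsStore(EventsWorkerStore):
        # master only: writes, stream id generation, cache invalidation
        def persist_events(self, events_and_contexts, **kwargs):
            pass  # elided
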
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index a66ff7c1e0..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause='url_cache IS NOT NULL',
)
- def get_default_thumbnails(self, top_level_type, sub_type):
- return []
-
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
@@ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
desc="store_cached_remote_media",
)
- def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+ def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+ """Updates the last access time of the given media
+
+ Args:
+ local_media (iterable[str]): Set of media_ids
+ remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+ time_ms (int): Current time in milliseconds
+ """
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
txn.executemany(sql, (
- (time_ts, media_origin, media_id)
- for media_origin, media_id in origin_id_tuples
+ (time_ms, media_origin, media_id)
+ for media_origin, media_id in remote_media
+ ))
+
+ sql = (
+ "UPDATE local_media_repository SET last_access_ts = ?"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, (
+ (time_ms, media_id)
+ for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
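
A hypothetical caller of the widened signature (previously only remote media could have its access time bumped):

    # mark one local and one remote piece of media as freshly accessed
    yield store.update_cached_last_access_time(
        local_media=["abcDEF0123"],
        remote_media=[("matrix.org", "GHIjkl4567")],
        time_ms=clock.time_msec(),
    )
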
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d1691bbac2..c845a0cec5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 46
+SCHEMA_VERSION = 47
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index ec02e73bc2..8612bd5ecc 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -21,14 +21,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore
-class ProfileStore(SQLBaseStore):
- def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles",
- values={"user_id": user_localpart},
- desc="create_profile",
- )
-
+class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
@@ -61,14 +54,6 @@ class ProfileStore(SQLBaseStore):
desc="get_profile_displayname",
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
- desc="set_profile_displayname",
- )
-
def get_profile_avatar_url(self, user_localpart):
return self._simple_select_one_onecol(
table="profiles",
@@ -77,14 +62,6 @@ class ProfileStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
- desc="set_profile_avatar_url",
- )
-
def get_from_remote_profile_cache(self, user_id):
return self._simple_select_one(
table="remote_profile_cache",
@@ -94,6 +71,31 @@ class ProfileStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
+
+class ProfileStore(ProfileWorkerStore):
+ def create_profile(self, user_localpart):
+ return self._simple_insert(
+ table="profiles",
+ values={"user_id": user_localpart},
+ desc="create_profile",
+ )
+
+ def set_profile_displayname(self, user_localpart, new_displayname):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"displayname": new_displayname},
+ desc="set_profile_displayname",
+ )
+
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"avatar_url": new_avatar_url},
+ desc="set_profile_avatar_url",
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 8758b1c0c7..04a0b59a39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +15,17 @@
# limitations under the License.
from ._base import SQLBaseStore
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.push.baserules import list_with_base_rules
from synapse.api.constants import EventTypes
from twisted.internet import defer
+import abc
import logging
import simplejson as json
@@ -48,7 +55,43 @@ def _load_rules(rawrules, enabled_map):
return rules
-class PushRuleStore(SQLBaseStore):
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+ ReceiptsWorkerStore,
+ PusherWorkerStore,
+ RoomMemberWorkerStore,
+ SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_push_rules_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+ push_rules_prefill, push_rules_id = self._get_cache_dict(
+ db_conn, "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self.get_max_push_rules_stream_id(),
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache", push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ @abc.abstractmethod
+ def get_max_push_rules_stream_id(self):
+ """Get the position of the push rules stream.
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
@@ -89,6 +132,22 @@ class PushRuleStore(SQLBaseStore):
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ count, = txn.fetchone()
+ return bool(count)
+ return self.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
+
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
def bulk_get_push_rules(self, user_ids):
@@ -228,6 +287,8 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
defer.returnValue(results)
+
+class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions,
@@ -526,21 +587,8 @@ class PushRuleStore(SQLBaseStore):
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
- def have_push_rules_changed_for_user(self, user_id, last_id):
- if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
- else:
- def have_push_rules_changed_txn(txn):
- sql = (
- "SELECT COUNT(stream_id) FROM push_rules_stream"
- " WHERE user_id = ? AND ? < stream_id"
- )
- txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
- return bool(count)
- return self.runInteraction(
- "have_push_rules_changed", have_push_rules_changed_txn
- )
+ def get_max_push_rules_stream_id(self):
+ return self.get_push_rules_stream_token()[0]
class RuleNotFoundException(Exception):
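
The moved `have_push_rules_changed_for_user` illustrates the usual StreamChangeCache discipline: consult the in-memory cache first and only fall through to SQL when the user may have changed. A hypothetical caller:

    # cheap in-memory check first; SQL runs only if the cache
    # cannot rule the change out
    changed = yield store.have_push_rules_changed_for_user(
        "@alice:example.com", last_id=1000,
    )
    if changed:
        rules = yield store.get_push_rules_for_user("@alice:example.com")
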
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 3d8b4d5d5b..307660b99a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
logger = logging.getLogger(__name__)
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
for r in rows:
dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows)
- def get_pushers_stream_token(self):
- return self._pushers_id_gen.get_current_token()
-
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed(([], []))
@@ -198,6 +196,11 @@ class PusherStore(SQLBaseStore):
defer.returnValue(result)
+
+class PusherStore(PusherWorkerStore):
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_current_token()
+
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
@@ -230,14 +233,18 @@ class PusherStore(SQLBaseStore):
)
if newly_inserted:
- # get_if_user_has_pusher only cares if the user has
- # at least *one* pusher.
- self.get_if_user_has_pusher.invalidate(user_id,)
+ self.runInteraction(
+ "add_pusher",
+ self._invalidate_cache_and_stream,
+ self.get_if_user_has_pusher, (user_id,)
+ )
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id):
- txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_if_user_has_pusher, (user_id,)
+ )
self._simple_delete_one_txn(
txn,
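
Both hunks swap a process-local `txn.call_after(cache.invalidate, ...)` for `_invalidate_cache_and_stream`, which must also tell worker processes to drop their copies. A sketch of the idea (not the exact helper in _base.py; the id source and table name are assumptions):

    def _invalidate_cache_and_stream(self, txn, cached_func, keys):
        # invalidate our own copy once the transaction commits...
        txn.call_after(cached_func.invalidate, keys)
        # ...and append a row to the cache-invalidation stream so that
        # replication clients (workers) invalidate theirs too
        stream_id = self._cache_id_gen.get_next_txn(txn)  # assumed API
        self._simple_insert_txn(
            txn,
            table="cache_invalidation_stream",
            values={
                "stream_id": stream_id,
                "cache_func": cached_func.__name__,
                "keys": list(keys),
                "invalidation_ts": self.clock.time_msec(),
            },
        )
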
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 2c3aa33693..63997ed449 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +15,13 @@
# limitations under the License.
from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
+import abc
import logging
import simplejson as json
@@ -26,39 +29,36 @@ import simplejson as json
logger = logging.getLogger(__name__)
-class ReceiptsStore(SQLBaseStore):
+class ReceiptsWorkerStore(SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_max_receipt_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
def __init__(self, db_conn, hs):
- super(ReceiptsStore, self).__init__(db_conn, hs)
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
+ @abc.abstractmethod
+ def get_max_receipt_stream_id(self):
+ """Get the current max stream ID for receipts stream
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
defer.returnValue(set(r['user_id'] for r in receipts))
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
- user_id):
- if receipt_type != "m.read":
- return
-
- # Returns an ObservableDeferred
- res = self.get_users_with_read_receipts_in_room.cache.get(
- room_id, None, update_metrics=False,
- )
-
- if res:
- if isinstance(res, defer.Deferred) and res.called:
- res = res.result
- if user_id in res:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
-
- self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
@@ -270,6 +270,59 @@ class ReceiptsStore(SQLBaseStore):
}
defer.returnValue(results)
+ def get_all_updated_receipts(self, last_id, current_id, limit=None):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ )
+ args = [last_id, current_id]
+ if limit is not None:
+ sql += " LIMIT ?"
+ args.append(limit)
+ txn.execute(sql, args)
+
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
+
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+ user_id):
+ if receipt_type != "m.read":
+ return
+
+ # Returns an ObservableDeferred
+ res = self.get_users_with_read_receipts_in_room.cache.get(
+ room_id, None, update_metrics=False,
+ )
+
+ if res:
+ if isinstance(res, defer.Deferred) and res.called:
+ res = res.result
+ if user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
+
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ def __init__(self, db_conn, hs):
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -457,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
"data": json.dumps(data),
}
)
-
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_receipts_txn(txn):
- sql = (
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
- " FROM receipts_linearized"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC"
- )
- args = [last_id, current_id]
- if limit is not None:
- sql += " LIMIT ?"
- args.append(limit)
- txn.execute(sql, args)
-
- return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_receipts", get_all_updated_receipts_txn
- )
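
On the worker side, the abstract `get_max_receipt_stream_id` would typically be backed by a read-only tracker advanced from the replication stream, rather than a StreamIdGenerator. A sketch, assuming a `SlavedIdTracker`-style class:

    class SlavedReceiptsStore(ReceiptsWorkerStore):
        def __init__(self, db_conn, hs):
            # read-only tracker, advanced as replication rows arrive
            self._receipts_id_gen = SlavedIdTracker(
                db_conn, "receipts_linearized", "stream_id",
            )
            super(SlavedReceiptsStore, self).__init__(db_conn, hs)

        def get_max_receipt_stream_id(self):
            return self._receipts_id_gen.get_current_token()
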
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 3aa810981f..d809b2ba46 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,10 +19,70 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-class RegistrationStore(background_updates.BackgroundUpdateStore):
+class RegistrationWorkerStore(SQLBaseStore):
+ @cached()
+ def get_user_by_id(self, user_id):
+ return self._simple_select_one(
+ table="users",
+ keyvalues={
+ "name": user_id,
+ },
+ retcols=["name", "password_hash", "is_guest"],
+ allow_none=True,
+ desc="get_user_by_id",
+ )
+
+ @cached()
+ def get_user_by_access_token(self, token):
+ """Get a user from the given access token.
+
+ Args:
+ token (str): The access token of a user.
+ Returns:
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
+ """
+ return self.runInteraction(
+ "get_user_by_access_token",
+ self._query_for_auth,
+ token
+ )
+
+ @defer.inlineCallbacks
+ def is_server_admin(self, user):
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user.to_string()},
+ retcol="admin",
+ allow_none=True,
+ desc="is_server_admin",
+ )
+
+ defer.returnValue(res if res else False)
+
+ def _query_for_auth(self, txn, token):
+ sql = (
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
+ " FROM users"
+ " INNER JOIN access_tokens on users.name = access_tokens.user_id"
+ " WHERE token = ?"
+ )
+
+ txn.execute(sql, (token,))
+ rows = self.cursor_to_dict(txn)
+ if rows:
+ return rows[0]
+
+ return None
+
+
+class RegistrationStore(RegistrationWorkerStore,
+ background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
@@ -39,12 +99,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
- @defer.inlineCallbacks
- def noop_update(progress, batch_size):
- yield self._end_background_update("refresh_tokens_device_index")
- defer.returnValue(1)
- self.register_background_update_handler(
- "refresh_tokens_device_index", noop_update)
+ self.register_noop_background_update("refresh_tokens_device_index")
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
@@ -192,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- def get_user_by_id(self, user_id):
- return self._simple_select_one(
- table="users",
- keyvalues={
- "name": user_id,
- },
- retcols=["name", "password_hash", "is_guest"],
- allow_none=True,
- desc="get_user_by_id",
- )
-
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
@@ -309,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
return self.runInteraction("delete_access_token", f)
- @cached()
- def get_user_by_access_token(self, token):
- """Get a user from the given access token.
-
- Args:
- token (str): The access token of a user.
- Returns:
- defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`.
- """
- return self.runInteraction(
- "get_user_by_access_token",
- self._query_for_auth,
- token
- )
-
- @defer.inlineCallbacks
- def is_server_admin(self, user):
- res = yield self._simple_select_one_onecol(
- table="users",
- keyvalues={"name": user.to_string()},
- retcol="admin",
- allow_none=True,
- desc="is_server_admin",
- )
-
- defer.returnValue(res if res else False)
-
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self._simple_select_one_onecol(
@@ -349,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
defer.returnValue(res if res else False)
- def _query_for_auth(self, txn, token):
- sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
- " access_tokens.device_id"
- " FROM users"
- " INNER JOIN access_tokens on users.name = access_tokens.user_id"
- " WHERE token = ?"
- )
-
- txn.execute(sql, (token,))
- rows = self.cursor_to_dict(txn)
- if rows:
- return rows[0]
-
- return None
-
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
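The diff above moves the read-only registration queries (`get_user_by_id`, `get_user_by_access_token`, `is_server_admin`) into a new `RegistrationWorkerStore` base class, leaving the write paths on `RegistrationStore`. The same worker/master split recurs throughout this patch; a minimal sketch of the idea follows (the class names and the dict-backed "database" are illustrative, not Synapse's API):

```python
class WorkerStore(object):
    """Read-only queries: safe to serve from any worker process."""

    def __init__(self, db):
        self._db = db  # stand-in for a real database connection

    def get_user_by_id(self, user_id):
        return self._db.get(user_id)


class MasterStore(WorkerStore):
    """Write paths live only on the single-writer master process."""

    def register_user(self, user_id, password_hash):
        self._db[user_id] = {"name": user_id, "password_hash": password_hash}


store = MasterStore({})
store.register_user("@alice:example.com", "...")
print(store.get_user_by_id("@alice:example.com"))
```

Workers instantiate only the worker subclass, so they cannot call the write methods at all; that is the safety property the split buys.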
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 2051d8506d..908551d6d9 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,11 +16,10 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
-
import collections
import logging
import simplejson as json
@@ -40,7 +39,138 @@ RatelimitOverride = collections.namedtuple(
)
-class RoomStore(SQLBaseStore):
+class RoomWorkerStore(SQLBaseStore):
+ def get_public_room_ids(self):
+ return self._simple_select_onecol(
+ table="rooms",
+ keyvalues={
+ "is_public": True,
+ },
+ retcol="room_id",
+ desc="get_public_room_ids",
+ )
+
+ @cached(num_args=2, max_entries=100)
+ def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+ """Get pulbic rooms for a particular list, or across all lists.
+
+ Args:
+ stream_id (int)
+ network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+ means the main list, None means all lists.
+ """
+ return self.runInteraction(
+ "get_public_room_ids_at_stream_id",
+ self.get_public_room_ids_at_stream_id_txn,
+ stream_id, network_tuple=network_tuple
+ )
+
+ def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+ network_tuple):
+ return {
+ rm
+ for rm, vis in self.get_published_at_stream_id_txn(
+ txn, stream_id, network_tuple=network_tuple
+ ).items()
+ if vis
+ }
+
+ def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+ if network_tuple:
+ # We want to get from a particular list. No aggregation required.
+
+ sql = ("""
+ SELECT room_id, visibility FROM public_room_list_stream
+ INNER JOIN (
+ SELECT room_id, max(stream_id) AS stream_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ? %s
+ GROUP BY room_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ if network_tuple.appservice_id is not None:
+ txn.execute(
+ sql % ("AND appservice_id = ? AND network_id = ?",),
+ (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+ )
+ else:
+ txn.execute(
+ sql % ("AND appservice_id IS NULL",),
+ (stream_id,)
+ )
+ return dict(txn)
+ else:
+ # We want to get from all lists, so we need to aggregate the results
+
+ logger.info("Executing full list")
+
+ sql = ("""
+ SELECT room_id, visibility
+ FROM public_room_list_stream
+ INNER JOIN (
+ SELECT
+ room_id, max(stream_id) AS stream_id, appservice_id,
+ network_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ?
+ GROUP BY room_id, appservice_id, network_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ txn.execute(
+ sql,
+ (stream_id,)
+ )
+
+ results = {}
+ # A room is visible if it's visible on any list.
+ for room_id, visibility in txn:
+ results[room_id] = bool(visibility) or results.get(room_id, False)
+
+ return results
+
+ def get_public_room_changes(self, prev_stream_id, new_stream_id,
+ network_tuple):
+ def get_public_room_changes_txn(txn):
+ then_rooms = self.get_public_room_ids_at_stream_id_txn(
+ txn, prev_stream_id, network_tuple
+ )
+
+ now_rooms_dict = self.get_published_at_stream_id_txn(
+ txn, new_stream_id, network_tuple
+ )
+
+ now_rooms_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if vis
+ )
+ now_rooms_not_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if not vis
+ )
+
+ newly_visible = now_rooms_visible - then_rooms
+ newly_unpublished = now_rooms_not_visible & then_rooms
+
+ return newly_visible, newly_unpublished
+
+ return self.runInteraction(
+ "get_public_room_changes", get_public_room_changes_txn
+ )
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -227,16 +357,6 @@ class RoomStore(SQLBaseStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_public_room_ids(self):
- return self._simple_select_onecol(
- table="rooms",
- keyvalues={
- "is_public": True,
- },
- retcol="room_id",
- desc="get_public_room_ids",
- )
-
def get_room_count(self):
"""Retrieve a list of all rooms
"""
@@ -263,8 +383,8 @@ class RoomStore(SQLBaseStore):
},
)
- self._store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"],
)
def _store_room_name_txn(self, txn, event):
@@ -279,14 +399,14 @@ class RoomStore(SQLBaseStore):
}
)
- self._store_event_search_txn(
- txn, event, "content.name", event.content["name"]
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"],
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
- self._store_event_search_txn(
- txn, event, "content.body", event.content["body"]
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"],
)
def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +428,6 @@ class RoomStore(SQLBaseStore):
event.content[key]
))
- def _store_event_search_txn(self, txn, event, key, value):
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
- txn.execute(
- sql,
- (
- event.event_id, event.room_id, key, value,
- event.internal_metadata.stream_ordering,
- event.origin_server_ts,
- )
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- txn.execute(sql, (event.event_id, event.room_id, key, value,))
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
@@ -353,113 +448,6 @@ class RoomStore(SQLBaseStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
-
- Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lsits.
- """
- return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id, network_tuple=network_tuple
- )
-
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
- network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
-
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
-
- sql = ("""
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """)
-
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
- )
- else:
- txn.execute(
- sql % ("AND appservice_id IS NULL",),
- (stream_id,)
- )
- return dict(txn)
- else:
- # We want to get from all lists, so we need to aggregate the results
-
- logger.info("Executing full list")
-
- sql = ("""
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """)
-
- txn.execute(
- sql,
- (stream_id,)
- )
-
- results = {}
- # A room is visible if its visible on any list.
- for room_id, visibility in txn:
- results[room_id] = bool(visibility) or results.get(room_id, False)
-
- return results
-
- def get_public_room_changes(self, prev_stream_id, new_stream_id,
- network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
-
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
-
- now_rooms_visible = set(
- rm for rm, vis in now_rooms_dict.items() if vis
- )
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
- )
-
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
-
- return newly_visible, newly_unpublished
-
- return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
- )
-
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
@@ -509,18 +497,6 @@ class RoomStore(SQLBaseStore):
else:
defer.returnValue(None)
- @cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
- table="blocked_rooms",
- keyvalues={
- "room_id": room_id,
- },
- retcol="1",
- allow_none=True,
- desc="is_room_blocked",
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
@@ -531,75 +507,120 @@ class RoomStore(SQLBaseStore):
},
desc="block_room",
)
- self.is_room_blocked.invalidate((room_id,))
+ yield self.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked, (room_id,),
+ )
+
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ A pair of lists of MXC URIs: those for local media and those for
+ remote media.
+ """
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+ return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
- def _get_media_ids_in_room(txn):
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
- next_token = self.get_current_events_token() + 1
+ # Now update all the tables to set the quarantined_by flag
- total_media_quarantined = 0
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_mxcs))
- while next_token:
- sql = """
- SELECT stream_ordering, content FROM events
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
+ txn.executemany(
"""
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- local_media_mxcs = []
- remote_media_mxcs = []
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- content = json.loads(content_json)
-
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany("""
- UPDATE local_media_repository
+ UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_media_mxcs
- )
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
)
+ )
- total_media_quarantined += len(local_media_mxcs)
- total_media_quarantined += len(remote_media_mxcs)
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
- return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+ return self.runInteraction(
+ "quarantine_media_in_room",
+ _quarantine_media_in_room_txn,
+ )
+
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+ A pair: a list of local media IDs (strings), and a list of
+ (hostname, media_id) tuples for remote media.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
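`_get_media_mxcs_in_room_txn` walks the room's events page by page and picks media IDs out of `mxc://` URIs with a regex. A self-contained sketch of just the URI handling (the example URIs and hostname are made up):

```python
import re

# same pattern as the mxc_re above: mxc://<hostname>/<media_id>
MXC_RE = re.compile("^mxc://([^/]+)/([^/#?]+)")


def split_media(urls, local_hostname):
    local_media, remote_media = [], []
    for url in urls:
        match = MXC_RE.match(url)
        if not match:
            continue  # not an mxc URI
        hostname, media_id = match.group(1), match.group(2)
        if hostname == local_hostname:
            local_media.append(media_id)
        else:
            remote_media.append((hostname, media_id))
    return local_media, remote_media


print(split_media(
    ["mxc://example.com/abc123", "mxc://other.org/def456", "https://x"],
    "example.com",
))
# -> (['abc123'], [('other.org', 'def456')])
```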
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c1ca299285..d662d1cfc0 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +18,7 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -37,6 +38,11 @@ RoomsForUser = namedtuple(
("room_id", "sender", "membership", "event_id", "stream_ordering")
)
+GetRoomsForUserWithStreamOrdering = namedtuple(
+ "_GetRoomsForUserWithStreamOrdering",
+ ("room_id", "stream_ordering",)
+)
+
# We store this using a namedtuple so that we save about 3x space over using a
# dict.
@@ -48,97 +54,7 @@ ProfileInfo = namedtuple(
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-class RoomMemberStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
- )
-
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ]
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key, event.internal_metadata.stream_ordering
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened.
- # The only current event that can also be an outlier is if its an
- # invite that has come in across federation.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_invite_from_remote()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- }
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(sql, (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ))
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (
- stream_ordering,
- True,
- room_id,
- user_id,
- ))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
+class RoomMemberWorkerStore(EventsWorkerStore):
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
@@ -270,12 +186,32 @@ class RoomMemberStore(SQLBaseStore):
return results
@cachedInlineCallbacks(max_entries=500000, iterable=True)
- def get_rooms_for_user(self, user_id):
+ def get_rooms_for_user_with_stream_ordering(self, user_id):
"""Returns a set of room_ids the user is currently joined to
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+ the rooms the user is in currently, along with the stream ordering
+ of the most recent join for that user and room.
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
+ defer.returnValue(frozenset(
+ GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+ for r in rooms
+ ))
+
+ @defer.inlineCallbacks
+ def get_rooms_for_user(self, user_id, on_invalidate=None):
+ """Returns a set of room_ids the user is currently joined to
+ """
+ rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ user_id, on_invalidate=on_invalidate,
+ )
defer.returnValue(frozenset(r.room_id for r in rooms))
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
@@ -295,89 +231,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(user_who_share_room)
- def forget(self, user_id, room_id):
- """Indicate that user_id wishes to discard history for room_id."""
- def f(txn):
- sql = (
- "UPDATE"
- " room_memberships"
- " SET"
- " forgotten = 1"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- )
- txn.execute(sql, (user_id, room_id))
-
- txn.call_after(self.was_forgotten_at.invalidate_all)
- txn.call_after(self.did_forget.invalidate, (user_id, room_id))
- self._invalidate_cache_and_stream(
- txn, self.who_forgot_in_room, (room_id,)
- )
- return self.runInteraction("forget_membership", f)
-
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
- @cachedInlineCallbacks(num_args=3)
- def was_forgotten_at(self, user_id, room_id, event_id):
- """Returns whether user_id has elected to discard history for room_id at
- event_id.
-
- event_id must be a membership event."""
- def f(txn):
- sql = (
- "SELECT"
- " forgotten"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " event_id = ?"
- )
- txn.execute(sql, (user_id, room_id, event_id))
- rows = txn.fetchall()
- return rows[0][0]
- forgot = yield self.runInteraction("did_forget_membership_at", f)
- defer.returnValue(forgot == 1)
-
- @cached()
- def who_forgot_in_room(self, room_id):
- return self._simple_select_list(
- table="room_memberships",
- retcols=("user_id", "event_id"),
- keyvalues={
- "room_id": room_id,
- "forgotten": 1,
- },
- desc="who_forgot"
- )
-
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
@@ -600,6 +453,185 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(joined_hosts)
+ @cached(max_entries=10000, iterable=True)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+ @cached()
+ def who_forgot_in_room(self, room_id):
+ return self._simple_select_list(
+ table="room_memberships",
+ retcols=("user_id", "event_id"),
+ keyvalues={
+ "room_id": room_id,
+ "forgotten": 1,
+ },
+ desc="who_forgot"
+ )
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ]
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+ # i.e., it's something that has just happened.
+ # The only current event that can also be an outlier is an
+ # invite that has come in across federation.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_invite_from_remote()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ }
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ))
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (
+ stream_ordering,
+ True,
+ room_id,
+ user_id,
+ ))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+
+ txn.call_after(self.was_forgotten_at.invalidate_all)
+ txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.who_forgot_in_room, (room_id,)
+ )
+ return self.runInteraction("forget_membership", f)
+
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ count = yield self.runInteraction("did_forget_membership", f)
+ defer.returnValue(count == 0)
+
+ @cachedInlineCallbacks(num_args=3)
+ def was_forgotten_at(self, user_id, room_id, event_id):
+ """Returns whether user_id has elected to discard history for room_id at
+ event_id.
+
+ event_id must be a membership event."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " forgotten"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " event_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id, event_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ forgot = yield self.runInteraction("did_forget_membership_at", f)
+ defer.returnValue(forgot == 1)
+
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@@ -675,10 +707,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(result)
- @cached(max_entries=10000, iterable=True)
- def _get_joined_hosts_cache(self, room_id):
- return _JoinedHostsCache(self, room_id)
-
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index c0176c41ee..6df57b5206 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
index f090a7b75a..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
@@ -13,5 +13,7 @@
* limitations under the License.
*/
- INSERT into background_updates (update_name, progress_json)
- VALUES ('event_search_postgres_gist', '{}');
+-- We no longer do this, since it is backed out again in schema 47.
+
+-- INSERT into background_updates (update_name, progress_json)
+-- VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
new file mode 100644
index 0000000000..31d7a817eb
--- /dev/null
+++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_search_postgres_gin', '{}');
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql
new file mode 100644
index 0000000000..edccf4a96f
--- /dev/null
+++ b/synapse/storage/schema/delta/47/push_actions_staging.sql
@@ -0,0 +1,28 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Temporary staging area for push actions that have been calculated for an
+-- event, but the event hasn't yet been persisted.
+-- When the event is persisted the rows are moved over to the
+-- event_push_actions table.
+CREATE TABLE event_push_actions_staging (
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ notif SMALLINT NOT NULL,
+ highlight SMALLINT NOT NULL
+);
+
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
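The staging table exists so push actions can be written before their event row does, then moved over once the event is persisted. A simplified sqlite sketch of that two-step flow (the real `event_push_actions` table has more columns; this is illustrative only):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE event_push_actions_staging (
        event_id TEXT NOT NULL, user_id TEXT NOT NULL,
        actions TEXT NOT NULL, notif SMALLINT NOT NULL,
        highlight SMALLINT NOT NULL);
    CREATE TABLE event_push_actions (
        event_id TEXT NOT NULL, user_id TEXT NOT NULL,
        actions TEXT NOT NULL, notif SMALLINT NOT NULL,
        highlight SMALLINT NOT NULL);
""")

# 1. push actions are calculated before the event is persisted
conn.execute(
    "INSERT INTO event_push_actions_staging VALUES (?, ?, ?, ?, ?)",
    ("$event1", "@alice:example.com", "[]", 1, 0),
)

# 2. once the event lands, the staged rows are moved across
conn.execute(
    "INSERT INTO event_push_actions"
    " SELECT * FROM event_push_actions_staging WHERE event_id = ?",
    ("$event1",),
)
conn.execute(
    "DELETE FROM event_push_actions_staging WHERE event_id = ?",
    ("$event1",),
)
print(conn.execute("SELECT count(*) FROM event_push_actions").fetchone())
```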
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # if we already have some state groups, we want to start making new
+ # ones with a higher id.
+ cur.execute("SELECT max(id) FROM state_groups")
+ row = cur.fetchone()
+
+ if row[0] is None:
+ start_val = 1
+ else:
+ start_val = row[0] + 1
+
+ cur.execute(
+ "CREATE SEQUENCE state_group_id_seq START WITH %s",
+ (start_val, ),
+ )
+
+
+def run_upgrade(*args, **kwargs):
+ pass
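With this sequence in place, allocating a state group ID no longer needs the Python-side `_state_groups_id_gen` removed earlier in this patch; callers go through `database_engine.get_next_state_group_id(txn)`, as `store_state_group` does below. A hypothetical sketch of what the two engine implementations might look like (the sqlite branch is an assumption for illustration; sqlite has no sequences and relies on the single-writer model):

```python
class PostgresEngineSketch(object):
    def get_next_state_group_id(self, txn):
        # postgres: just ask the sequence created by this delta
        txn.execute("SELECT nextval('state_group_id_seq')")
        return txn.fetchone()[0]


class Sqlite3EngineSketch(object):
    def __init__(self):
        self._current_state_group_id = None

    def get_next_state_group_id(self, txn):
        # no sequences in sqlite: seed from max(id) and count upwards,
        # which is safe only because there is a single writer
        if self._current_state_group_id is None:
            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
            self._current_state_group_id = txn.fetchone()[0]
        self._current_state_group_id += 1
        return self._current_state_group_id
```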
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index c19e4ea449..984643b057 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,25 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import namedtuple
+import logging
+import re
+import simplejson as json
+
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-import logging
-import re
-import simplejson as json
-
logger = logging.getLogger(__name__)
+SearchEntry = namedtuple('SearchEntry', [
+ 'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+ 'origin_server_ts',
+])
+
class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
+ EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
@@ -42,23 +49,34 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_reindex_search_order
)
- self.register_background_update_handler(
+
+ # we used to have a background update to turn the GIN index into a
+ # GIST one; we no longer do that (obviously) because we actually want
+ # a GIN index. However, it's possible that some people might still have
+ # the background update queued, so we register a handler to clear the
+ # background update.
+ self.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
- self._background_reindex_gist_search
+ )
+
+ self.register_background_update_handler(
+ self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
+ self._background_reindex_gin_search
)
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
+ # we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
- "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+ "SELECT stream_ordering, event_id, room_id, type, content, "
+ " origin_server_ts FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -67,6 +85,10 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+ # we could stream straight from the results into
+ # store_search_entries_txn with a generator function, but that
+ # would mean having two cursors open on the database at once.
+ # Instead we just build a list of results.
rows = self.cursor_to_dict(txn)
if not rows:
return 0
@@ -79,6 +101,8 @@ class SearchStore(BackgroundUpdateStore):
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
+ stream_ordering = row["stream_ordering"]
+ origin_server_ts = row["origin_server_ts"]
try:
content = json.loads(row["content"])
except Exception:
@@ -93,6 +117,8 @@ class SearchStore(BackgroundUpdateStore):
elif etype == "m.room.name":
key = "content.name"
value = content["name"]
+ else:
+ raise Exception("unexpected event type %s" % etype)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -103,25 +129,16 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it
continue
- event_search_rows.append((event_id, room_id, key, value))
+ event_search_rows.append(SearchEntry(
+ key=key,
+ value=value,
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ origin_server_ts=origin_server_ts,
+ ))
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, vector)"
- " VALUES (?,?,?,to_tsvector('english', ?))"
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
- for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
- clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -145,25 +162,48 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(result)
@defer.inlineCallbacks
- def _background_reindex_gist_search(self, progress, batch_size):
+ def _background_reindex_gin_search(self, progress, batch_size):
+ """This handles old synapses which used GIST indexes, if any;
+ converting them back to be GIN as per the actual schema.
+ """
+
def create_index(conn):
conn.rollback()
- conn.set_session(autocommit=True)
- c = conn.cursor()
- c.execute(
- "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
- " ON event_search USING GIST (vector)"
- )
+ # we have to set autocommit, because postgres refuses to
+ # CREATE INDEX CONCURRENTLY without it.
+ conn.set_session(autocommit=True)
- c.execute("DROP INDEX event_search_fts_idx")
+ try:
+ c = conn.cursor()
- conn.set_session(autocommit=False)
+ # if we skipped the conversion to GIST, we may already/still
+ # have an event_search_fts_idx; unfortunately postgres 9.4
+ # doesn't support CREATE INDEX IF NOT EXISTS, so we just catch the
+ # exception and ignore it.
+ import psycopg2
+ try:
+ c.execute(
+ "CREATE INDEX CONCURRENTLY event_search_fts_idx"
+ " ON event_search USING GIN (vector)"
+ )
+ except psycopg2.ProgrammingError as e:
+ logger.warn(
+ "Ignoring error %r when trying to switch from GIST to GIN",
+ e
+ )
+
+ # we should now be able to delete the GIST index.
+ c.execute(
+ "DROP INDEX IF EXISTS event_search_fts_idx_gist"
+ )
+ finally:
+ conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
yield self.runWithConnection(create_index)
- yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+ yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
defer.returnValue(1)
@defer.inlineCallbacks
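The shape of `create_index` is dictated by postgres: `CREATE INDEX CONCURRENTLY` refuses to run inside a transaction block, hence the rollback, the autocommit window, and the `finally` that restores normal transactional behaviour. Stripped of the Synapse plumbing, the pattern is (the DSN is a placeholder; this needs a real postgres connection):

```python
import psycopg2

conn = psycopg2.connect("dbname=synapse")  # placeholder DSN
conn.rollback()                 # discard any open transaction
conn.set_session(autocommit=True)
try:
    c = conn.cursor()
    try:
        c.execute(
            "CREATE INDEX CONCURRENTLY event_search_fts_idx"
            " ON event_search USING GIN (vector)"
        )
    except psycopg2.ProgrammingError as e:
        # the index may already exist; postgres 9.4 has no
        # CREATE INDEX IF NOT EXISTS, so swallow the error
        print("ignoring %r" % (e,))
    c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
finally:
    conn.set_session(autocommit=False)
```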
@@ -242,6 +282,85 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(num_rows)
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),),
+ )
+
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ entry.stream_ordering, entry.origin_server_ts,
+ ) for entry in entries)
+
+ # inserts to a GIN index are normally batched up into a pending
+ # list, and then all committed together once the list gets to a
+ # certain size. The trouble with that is that postgres (pre-9.5)
+ # uses work_mem to determine the length of the list, and work_mem
+ # is typically very large.
+ #
+ # We therefore reduce work_mem while we do the insert.
+ #
+ # (postgres 9.5 uses the separate gin_pending_list_limit setting,
+ # so doesn't suffer the same problem, but changing work_mem will
+ # be harmless)
+ #
+ # Note that we don't need to worry about restoring it on
+ # exception, because exceptions will cause the transaction to be
+ # rolled back, including the effects of the SET command.
+ #
+ # Also: we use SET rather than SET LOCAL because there's lots of
+ # other stuff going on in this transaction, which wants to have the
+ # normal work_mem setting.
+
+ txn.execute("SET work_mem='256kB'")
+ txn.executemany(sql, args)
+ txn.execute("RESET work_mem")
+
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ ) for entry in entries)
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 67d5d9969a..9e6eaaa532 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
-class SignatureStore(SQLBaseStore):
- """Persistence for event signatures and hashes"""
-
+class SignatureWorkerStore(SQLBaseStore):
@cached()
def get_event_reference_hash(self, event_id):
- return self._get_event_reference_hashes_txn(event_id)
+ # This is a dummy function to allow get_event_reference_hashes
+ # to use its cache
+ raise NotImplementedError()
@cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1)
@@ -74,6 +74,10 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, ))
return {k: v for k, v in txn}
+
+class SignatureStore(SignatureWorkerStore):
+ """Persistence for event signatures and hashes"""
+
def _store_event_reference_hashes_txn(self, txn, events):
"""Store a hash for a PDU
Args:
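The `SignatureWorkerStore` change above shows a small but recurring idiom: `@cached` on a deliberately unimplemented single-item getter, purely so the `@cachedList`-decorated batch method has a per-item cache to fill. A toy rendering of the mechanism (the decorator below is a stand-in for Synapse's `@cached`; the stub is only hit on a cache miss, which the batch method prevents):

```python
def cached():
    """Stand-in for synapse's @cached: memoise by single key."""
    def wrap(f):
        cache = {}

        def wrapper(self, key):
            if key not in cache:
                cache[key] = f(self, key)
            return cache[key]

        wrapper.cache = cache
        return wrapper
    return wrap


class Store(object):
    @cached()
    def get_event_reference_hash(self, event_id):
        # dummy: individual lookups should always be served from the
        # cache, which the batch method below populates
        raise NotImplementedError()

    def get_event_reference_hashes(self, event_ids):
        cache = Store.get_event_reference_hash.cache
        for event_id in event_ids:
            if event_id not in cache:
                # pretend database hit
                cache[event_id] = {"sha256": "hash-of-%s" % (event_id,)}
        return [cache[e] for e in event_ids]


store = Store()
store.get_event_reference_hashes(["$a", "$b"])  # fills the cache
print(store.get_event_reference_hash("$a"))     # cache hit, no DB call
```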
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 360e3e4355..ffa4246031 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupReadStore(SQLBaseStore):
- """The read-only parts of StateGroupStore
-
- None of these functions write to the state tables, so are suitable for
- including in the SlavedStores.
+class StateGroupWorkerStore(SQLBaseStore):
+ """The parts of StateGroupStore that can be called from workers.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
- super(StateGroupReadStore, self).__init__(db_conn, hs)
+ super(StateGroupWorkerStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@@ -143,6 +140,20 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(group_to_state)
@defer.inlineCallbacks
+ def get_state_ids_for_group(self, state_group):
+ """Get the state IDs for the given state group
+
+ Args:
+ state_group (int)
+
+ Returns:
+ Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = yield self._get_state_for_groups((state_group,))
+
+ defer.returnValue(group_to_state[state_group])
+
+ @defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
@@ -229,6 +240,9 @@ class StateGroupReadStore(SQLBaseStore):
(
"AND type = ? AND state_key = ?",
(etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
)
for etype, state_key in types
]
@@ -248,10 +262,19 @@ class StateGroupReadStore(SQLBaseStore):
key = (typ, state_key)
results[group][key] = event_id
else:
+ where_args = []
+ where_clauses = []
+ wildcard_types = False
if types is not None:
- where_clause = "AND (%s)" % (
- " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
- )
+ for typ in types:
+ if typ[1] is None:
+ where_clauses.append("(type = ?)")
+ where_args.append(typ[0])
+ wildcard_types = True
+ else:
+ where_clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend([typ[0], typ[1]])
+ where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
@@ -268,7 +291,7 @@ class StateGroupReadStore(SQLBaseStore):
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
- args.extend(i for typ in types for i in typ)
+ args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
@@ -281,9 +304,17 @@ class StateGroupReadStore(SQLBaseStore):
if (typ, state_key) not in results[group]
)
- # If the lengths match then we must have all the types,
- # so no need to go walk further down the tree.
- if types is not None and len(results[group]) == len(types):
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type, state_key) pairs we were searching
+ # for, then we must have found them all, so there is no need to walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones), in which case we have to do an exhaustive
+ # search.
+ if (
+ types is not None and
+ not wildcard_types and
+ len(results[group]) == len(types)
+ ):
break
next_group = self._simple_select_one_onecol_txn(
@@ -549,116 +580,66 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(results)
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
+ """Store a new set of state, returning a newly assigned state group.
-class StateStore(StateGroupReadStore, BackgroundUpdateStore):
- """ Keeps track of the state at a given event.
-
- This is done by the concept of `state groups`. Every event is a assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
-
- Hence, every change in the current state causes a new state group to be
- generated. However, if no change happens (e.g., if we get a message event
- with only one parent it inherits the state group from its parent.)
-
- There are three tables:
- * `state_groups`: Stores group name, first event with in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
- """
-
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
- def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
- )
- self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME,
- self._background_index_state,
- )
- self.register_background_index_update(
- self.CURRENT_STATE_INDEX_UPDATE_NAME,
- index_name="current_state_events_member_index",
- table="current_state_events",
- columns=["state_key"],
- where_clause="type='m.room.member'",
- )
-
- def _have_persisted_state_group_txn(self, txn, state_group):
- txn.execute(
- "SELECT count(*) FROM state_groups WHERE id = ?",
- (state_group,)
- )
- row = txn.fetchone()
- return row and row[0]
-
- def _store_mult_state_groups_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
- if context.current_state_ids is None:
+ Returns:
+ Deferred[int]: The state group ID
+ """
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
# AFAIK, this can never happen
- logger.error(
- "Non-outlier event %s had current_state_ids==None",
- event.event_id)
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.prev_group
- continue
+ raise Exception("current_state_ids cannot be None")
- state_groups[event.event_id] = context.state_group
-
- if self._have_persisted_state_group_txn(txn, context.state_group):
- continue
+ state_group = self.database_engine.get_next_state_group_id(txn)
self._simple_insert_txn(
txn,
table="state_groups",
values={
- "id": context.state_group,
- "room_id": event.room_id,
- "event_id": event.event_id,
+ "id": state_group,
+ "room_id": room_id,
+ "event_id": event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
- if context.prev_group:
+ if prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
- keyvalues={"id": context.prev_group},
+ keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
- % (context.prev_group,)
+ % (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
- txn, context.prev_group
+ txn, prev_group
)
- if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
- "state_group": context.state_group,
- "prev_state_group": context.prev_group,
+ "state_group": state_group,
+ "prev_state_group": prev_group,
},
)
@@ -667,13 +648,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
table="state_groups_state",
values=[
{
- "state_group": context.state_group,
- "room_id": event.room_id,
+ "state_group": state_group,
+ "room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.delta_ids.iteritems()
+ for key, state_id in delta_ids.iteritems()
],
)
else:
@@ -682,13 +663,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
table="state_groups_state",
values=[
{
- "state_group": context.state_group,
- "room_id": event.room_id,
+ "state_group": state_group,
+ "room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.current_state_ids.iteritems()
+ for key, state_id in current_state_ids.iteritems()
],
)
@@ -699,28 +680,14 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
- key=context.state_group,
- value=dict(context.current_state_ids),
+ key=state_group,
+ value=dict(current_state_ids),
full=True,
)
- self._simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {
- "state_group": state_group_id,
- "event_id": event_id,
- }
- for event_id, state_group_id in state_groups.iteritems()
- ],
- )
+ return state_group
- for event_id, state_group_id in state_groups.iteritems():
- txn.call_after(
- self._get_state_group_for_event.prefill,
- (event_id,), state_group_id
- )
+ return self.runInteraction("store_state_group", _store_state_group_txn)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
@@ -763,8 +730,79 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
return count
- def get_next_state_group(self):
- return self._state_groups_id_gen.get_next()
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+ This is done by the concept of `state groups`. Every event is assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+ generated. However, if no change happens (e.g., we get a message event
+ with only one parent), it inherits the state group from its parent.
+
+ There are three tables:
+ * `state_groups`: Stores group name, first event within the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+ def __init__(self, db_conn, hs):
+ super(StateStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME,
+ self._background_index_state,
+ )
+ self.register_background_index_update(
+ self.CURRENT_STATE_INDEX_UPDATE_NAME,
+ index_name="current_state_events_member_index",
+ table="current_state_events",
+ columns=["state_key"],
+ where_clause="type='m.room.member'",
+ )
+
+ def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
+ self._simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {
+ "state_group": state_group_id,
+ "event_id": event_id,
+ }
+ for event_id, state_group_id in state_groups.iteritems()
+ ],
+ )
+
+ for event_id, state_group_id in state_groups.iteritems():
+ txn.call_after(
+ self._get_state_group_for_event.prefill,
+ (event_id,), state_group_id
+ )
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
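`store_state_group` persists each group either in full or as a delta against `prev_group`, capping the chain length with `MAX_STATE_DELTA_HOPS` so reads stay bounded. A toy, in-memory model of that trade-off (the real store spreads this across the three tables above; the dict here is illustrative):

```python
MAX_STATE_DELTA_HOPS = 100

groups = {}  # state_group -> (prev_group, state dict: full or delta)


def store_state_group(group_id, prev_group, delta_ids, current_state_ids):
    # count hops back to a full snapshot, like _count_state_group_hops_txn
    hops, g = 0, prev_group
    while g is not None:
        g = groups[g][0]
        hops += 1
    if prev_group is not None and hops < MAX_STATE_DELTA_HOPS:
        groups[group_id] = (prev_group, dict(delta_ids))    # store a delta
    else:
        groups[group_id] = (None, dict(current_state_ids))  # store in full


def get_state(group_id):
    # walk the chain, applying older state first so newer deltas win
    chain, g = [], group_id
    while g is not None:
        prev, state = groups[g]
        chain.append(state)
        g = prev
    result = {}
    for state in reversed(chain):
        result.update(state)
    return result


store_state_group(1, None, None, {("m.room.name", ""): "$ev1"})
store_state_group(
    2, 1,
    {("m.room.topic", ""): "$ev2"},
    {("m.room.name", ""): "$ev1", ("m.room.topic", ""): "$ev2"},
)
print(get_state(2))
# -> {('m.room.name', ''): '$ev1', ('m.room.topic', ''): '$ev2'}
```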
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 52bdce5be2..2956c3b3e0 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,13 +35,16 @@ what sort order was used:
from twisted.internet import defer
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+
from synapse.util.caches.descriptors import cached
-from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+import abc
import logging
@@ -143,81 +146,41 @@ def filter_to_clause(event_filter):
return " AND ".join(clauses), args
-class StreamStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
- # NB this lives here instead of appservice.py so we can reuse the
- # 'private' StreamToken class in this file.
- if limit:
- limit = max(limit, MAX_STREAM_SIZE)
- else:
- limit = MAX_STREAM_SIZE
-
- # From and to keys should be integers from ordering.
- from_id = RoomStreamToken.parse_stream_token(from_key)
- to_id = RoomStreamToken.parse_stream_token(to_key)
-
- if from_key == to_key:
- defer.returnValue(([], to_key))
- return
-
- # select all the events between from/to with a sensible limit
- sql = (
- "SELECT e.event_id, e.room_id, e.type, s.state_key, "
- "e.stream_ordering FROM events AS e "
- "LEFT JOIN state_events as s ON "
- "e.event_id = s.event_id "
- "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
- "ORDER BY stream_ordering ASC LIMIT %(limit)d "
- ) % {
- "limit": limit
- }
-
- def f(txn):
- # pull out all the events between the tokens
- txn.execute(sql, (from_id.stream, to_id.stream,))
- rows = self.cursor_to_dict(txn)
-
- # Logic:
- # - We want ALL events which match the AS room_id regex
- # - We want ALL events which match the rooms represented by the AS
- # room_alias regex
- # - We want ALL events for rooms that AS users have joined.
- # This is currently supported via get_app_service_rooms (which is
- # used for the Notifier listener rooms). We can't reasonably make a
- # SQL query for these room IDs, so we'll pull all the events between
- # from/to and filter in python.
- rooms_for_as = self._get_app_service_rooms_txn(txn, service)
- room_ids_for_as = [r.room_id for r in rooms_for_as]
-
- def app_service_interested(row):
- if row["room_id"] in room_ids_for_as:
- return True
-
- if row["type"] == EventTypes.Member:
- if service.is_interested_in_user(row.get("state_key")):
- return True
- return False
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+ """This is an abstract base class where subclasses must implement
+ `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
+ which can be called in the initializer.
+ """
- return [r for r in rows if app_service_interested(r)]
+ __metaclass__ = abc.ABCMeta
- rows = yield self.runInteraction("get_appservice_room_stream", f)
+ def __init__(self, db_conn, hs):
+ super(StreamWorkerStore, self).__init__(db_conn, hs)
- ret = yield self._get_events(
- [r["event_id"] for r in rows],
- get_prev_content=True
+ events_max = self.get_room_max_stream_ordering()
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
)
- self._set_before_and_after(ret, rows, topo_order=from_id is None)
+ self._stream_order_on_start = self.get_room_max_stream_ordering()
- if rows:
- key = "s%d" % max(r["stream_ordering"] for r in rows)
- else:
- # Assume we didn't get anything because there was nothing to
- # get.
- key = to_key
+ @abc.abstractmethod
+ def get_room_max_stream_ordering(self):
+ raise NotImplementedError()
- defer.returnValue((ret, key))
+ @abc.abstractmethod
+ def get_room_min_stream_ordering(self):
+ raise NotImplementedError()
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
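The `abc.ABCMeta` pattern above exists because `StreamWorkerStore.__init__` reads the maximum stream ordering to prefill its caches, yet only the master process owns the stream ID generator. A toy sketch of the pattern (class and generator names are hypothetical; `__metaclass__` is the Python 2 spelling used in this codebase):

    import abc


    class BaseWorkerStore(object):
        # Python 2 metaclass spelling, matching the diff above.
        __metaclass__ = abc.ABCMeta

        def __init__(self):
            # The base initializer needs the maximum stream ordering, so the
            # abstract getter must already work by the time we get here.
            self._stream_order_on_start = self.get_room_max_stream_ordering()

        @abc.abstractmethod
        def get_room_max_stream_ordering(self):
            raise NotImplementedError()


    class FakeIdGen(object):
        def get_current_token(self):
            return 42


    class MasterStore(BaseWorkerStore):
        def __init__(self, id_gen):
            self._id_gen = id_gen  # the writer process owns the ID generator
            super(MasterStore, self).__init__()

        def get_room_max_stream_ordering(self):
            return self._id_gen.get_current_token()


    store = MasterStore(FakeIdGen())
    assert store._stream_order_on_start == 42

Subclasses therefore initialize whatever backs the getter before delegating to the base `__init__`, which is why `MasterStore` assigns `self._id_gen` first.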
@@ -381,88 +344,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1, event_filter=None):
- # Tokens really represent positions between elements, but we use
- # the convention of pointing to the event before the gap. Hence
- # we have a bit of asymmetry when it comes to equalities.
- args = [False, room_id]
- if direction == 'b':
- order = "DESC"
- bounds = upper_bound(
- RoomStreamToken.parse(from_key), self.database_engine
- )
- if to_key:
- bounds = "%s AND %s" % (bounds, lower_bound(
- RoomStreamToken.parse(to_key), self.database_engine
- ))
- else:
- order = "ASC"
- bounds = lower_bound(
- RoomStreamToken.parse(from_key), self.database_engine
- )
- if to_key:
- bounds = "%s AND %s" % (bounds, upper_bound(
- RoomStreamToken.parse(to_key), self.database_engine
- ))
-
- filter_clause, filter_args = filter_to_clause(event_filter)
-
- if filter_clause:
- bounds += " AND " + filter_clause
- args.extend(filter_args)
-
- if int(limit) > 0:
- args.append(int(limit))
- limit_str = " LIMIT ?"
- else:
- limit_str = ""
-
- sql = (
- "SELECT * FROM events"
- " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
- " ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s %(limit)s"
- ) % {
- "bounds": bounds,
- "order": order,
- "limit": limit_str
- }
-
- def f(txn):
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
-
- if rows:
- topo = rows[-1]["topological_ordering"]
- toke = rows[-1]["stream_ordering"]
- if direction == 'b':
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- toke -= 1
- next_token = str(RoomStreamToken(topo, toke))
- else:
- # TODO (erikj): We should work out what to do here instead.
- next_token = to_key if to_key else from_key
-
- return rows, next_token,
-
- rows, token = yield self.runInteraction("paginate_room_events", f)
-
- events = yield self._get_events(
- [r["event_id"] for r in rows],
- get_prev_content=True
- )
-
- self._set_before_and_after(events, rows)
-
- defer.returnValue((events, token))
-
- @defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
rows, token = yield self.get_recent_event_ids_for_room(
room_id, limit, end_token, from_token
@@ -534,6 +415,33 @@ class StreamStore(SQLBaseStore):
"get_recent_events_for_room", get_recent_events_for_room_txn
)
+ def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
+ """Gets details of the first event in a room at or after a stream ordering
+
+ Args:
+ room_id (str): the room to search in
+ stream_ordering (int): the stream ordering to start the search from
+
+ Returns:
+ Deferred[(int, int, str)]:
+ (stream ordering, topological ordering, event_id), or None if
+ there is no such event
+ """
+ def _f(txn):
+ sql = (
+ "SELECT stream_ordering, topological_ordering, event_id"
+ " FROM events"
+ " WHERE room_id = ? AND stream_ordering >= ?"
+ " AND NOT outlier"
+ " ORDER BY stream_ordering"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (room_id, stream_ordering, ))
+ return txn.fetchone()
+
+ return self.runInteraction(
+ "get_room_event_after_stream_ordering", _f,
+ )
+
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
"""Returns the current token for rooms stream.
@@ -542,7 +450,7 @@ class StreamStore(SQLBaseStore):
`room_id` causes it to return the current room specific topological
token.
"""
- token = yield self._stream_id_gen.get_current_token()
+ token = yield self.get_room_max_stream_ordering()
if room_id is None:
defer.returnValue("s%d" % (token,))
else:
@@ -552,12 +460,6 @@ class StreamStore(SQLBaseStore):
)
defer.returnValue("t%d-%d" % (topo, token))
- def get_room_max_stream_ordering(self):
- return self._stream_id_gen.get_current_token()
-
- def get_room_min_stream_ordering(self):
- return self._backfill_id_gen.get_current_token()
-
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
@@ -832,3 +734,93 @@ class StreamStore(SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+
+
+class StreamStore(StreamWorkerStore):
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
+
+ @defer.inlineCallbacks
+ def paginate_room_events(self, room_id, from_key, to_key=None,
+ direction='b', limit=-1, event_filter=None):
+ # Tokens really represent positions between elements, but we use
+ # the convention of pointing to the event before the gap. Hence
+ # we have a bit of asymmetry when it comes to equalities.
+ args = [False, room_id]
+ if direction == 'b':
+ order = "DESC"
+ bounds = upper_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
+ if to_key:
+ bounds = "%s AND %s" % (bounds, lower_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+ else:
+ order = "ASC"
+ bounds = lower_bound(
+ RoomStreamToken.parse(from_key), self.database_engine
+ )
+ if to_key:
+ bounds = "%s AND %s" % (bounds, upper_bound(
+ RoomStreamToken.parse(to_key), self.database_engine
+ ))
+
+ filter_clause, filter_args = filter_to_clause(event_filter)
+
+ if filter_clause:
+ bounds += " AND " + filter_clause
+ args.extend(filter_args)
+
+ if int(limit) > 0:
+ args.append(int(limit))
+ limit_str = " LIMIT ?"
+ else:
+ limit_str = ""
+
+ sql = (
+ "SELECT * FROM events"
+ " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
+ " ORDER BY topological_ordering %(order)s,"
+ " stream_ordering %(order)s %(limit)s"
+ ) % {
+ "bounds": bounds,
+ "order": order,
+ "limit": limit_str
+ }
+
+ def f(txn):
+ txn.execute(sql, args)
+
+ rows = self.cursor_to_dict(txn)
+
+ if rows:
+ topo = rows[-1]["topological_ordering"]
+ toke = rows[-1]["stream_ordering"]
+ if direction == 'b':
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ toke -= 1
+ next_token = str(RoomStreamToken(topo, toke))
+ else:
+ # TODO (erikj): We should work out what to do here instead.
+ next_token = to_key if to_key else from_key
+
+ return rows, next_token,
+
+ rows, token = yield self.runInteraction("paginate_room_events", f)
+
+ events = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(events, rows)
+
+ defer.returnValue((events, token))
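A condensed sketch of the token arithmetic in `paginate_room_events` above (the function name and direct `"t%d-%d"` formatting are illustrative; the real code renders tokens via `str(RoomStreamToken(topo, toke))`). Tokens denote positions between events, so a backwards chunk must hand back a token pointing just before its last event:

    def next_pagination_token(rows, direction, from_key, to_key):
        if not rows:
            # Nothing in range; hand back a bound so the caller stops cleanly.
            return to_key if to_key else from_key
        topo = rows[-1]["topological_ordering"]
        stream = rows[-1]["stream_ordering"]
        if direction == 'b':
            # Tokens sit *between* events; when walking backwards the next
            # token must point just before the last event we returned.
            stream -= 1
        return "t%d-%d" % (topo, stream)


    rows = [{"topological_ordering": 5, "stream_ordering": 17}]
    print(next_pagination_token(rows, 'b', "t5-20", None))   # t5-16
    print(next_pagination_token([], 'f', "t5-20", "t9-40"))  # t9-40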
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 982a500520..13bff9f055 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
@@ -23,15 +25,7 @@ import logging
logger = logging.getLogger(__name__)
-class TagsStore(SQLBaseStore):
- def get_max_account_data_stream_id(self):
- """Get the current max stream id for the private user data stream
-
- Returns:
- A deferred int.
- """
- return self._account_data_id_gen.get_current_token()
-
+class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
@@ -170,6 +164,8 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows
})
+
+class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index c9bff408ef..dfdcbb3181 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- join_clause = ""
- where_clause = "?<>''" # naughty hack to keep the same number of binds
+ # make s.user_id null to keep the ordering algorithm happy
+ join_clause = """
+ CROSS JOIN (SELECT NULL as user_id) AS s
+ """
+ join_args = ()
+ where_clause = "1=1"
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
@@ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
+ join_args = (user_id,)
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine):
@@ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@@ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, search_query, limit + 1)
+ args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
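The fix above pairs each `join_clause` with a matching `join_args` tuple, so the final bind list is assembled from the same branch that produced the SQL and the placeholder count can no longer drift (the old `?<>''` hack existed only to burn a surplus bind). A reduced sketch of the same shape (table and column names hypothetical):

    def build_search_sql(search_all_users, user_id, search_query, limit):
        if search_all_users:
            # No extra bind parameters; carry an empty tuple with the clause.
            join_clause = "CROSS JOIN (SELECT NULL AS user_id) AS s"
            join_args = ()
            where_clause = "1=1"
        else:
            # One extra bind parameter, carried next to the clause using it.
            join_clause = (
                "LEFT JOIN (SELECT user_id FROM shared_private_rooms"
                " WHERE user_id = ?) AS s USING (user_id)"
            )
            join_args = (user_id,)
            where_clause = "s.user_id IS NOT NULL"

        sql = (
            "SELECT user_id FROM user_directory_search %s"
            " WHERE %s AND display_name LIKE ? LIMIT ?"
        ) % (join_clause, where_clause)
        # Binds are assembled from the same pieces as the SQL string, so
        # placeholder and argument counts always agree.
        return sql, join_args + (search_query, limit + 1)


    sql, args = build_search_sql(False, "@alice:hs", "%bob%", 10)
    assert sql.count("?") == len(args)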