diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7a3f6c4662..250ba536ea 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -20,7 +20,7 @@ from .appservice import (
from ._base import Cache
from .directory import DirectoryStore
from .events import EventsStore
-from .presence import PresenceStore
+from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
from .registration import RegistrationStore
from .room import RoomStore
@@ -45,6 +45,11 @@ from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
+from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
+
+from synapse.api.constants import PresenceState
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
import logging
@@ -55,7 +60,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
-LAST_SEEN_GRANULARITY = 120*1000
+LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore,
@@ -79,18 +84,159 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore
):
- def __init__(self, hs):
- super(DataStore, self).__init__(hs)
+ def __init__(self, db_conn, hs):
self.hs = hs
+ self.database_engine = hs.database_engine
- self.min_token_deferred = self._get_min_token()
- self.min_token = None
+ cur = db_conn.cursor()
+ try:
+ cur.execute("SELECT MIN(stream_ordering) FROM events",)
+ rows = cur.fetchall()
+ self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+ self.min_stream_token = min(self.min_stream_token, -1)
+ finally:
+ cur.close()
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering"
+ )
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+ self._account_data_id_gen = StreamIdGenerator(
+ db_conn, "account_data_max_stream_id", "stream_id"
+ )
+ self._presence_id_gen = StreamIdGenerator(
+ db_conn, "presence_stream", "stream_id"
+ )
+
+ self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
+ self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ )
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id",
+ extra_tables=[("deleted_pushers", "stream_id")],
+ )
+
+ events_max = self._stream_id_gen.get_max_token()
+ event_cache_prefill, min_event_val = self._get_cache_dict(
+ db_conn, "events",
+ entity_column="room_id",
+ stream_column="stream_ordering",
+ max_value=events_max,
+ )
+ self._events_stream_cache = StreamChangeCache(
+ "EventsRoomStreamChangeCache", min_event_val,
+ prefilled_cache=event_cache_prefill,
+ )
+
+ self._membership_stream_cache = StreamChangeCache(
+ "MembershipStreamChangeCache", events_max,
+ )
+
+ account_max = self._account_data_id_gen.get_max_token()
+ self._account_data_stream_cache = StreamChangeCache(
+ "AccountDataAndTagsChangeCache", account_max,
+ )
+
+ self.__presence_on_startup = self._get_active_presence(db_conn)
+
+ presence_cache_prefill, min_presence_val = self._get_cache_dict(
+ db_conn, "presence_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._presence_id_gen.get_max_token(),
+ )
+ self.presence_stream_cache = StreamChangeCache(
+ "PresenceStreamChangeCache", min_presence_val,
+ prefilled_cache=presence_cache_prefill
+ )
+
+ push_rules_prefill, push_rules_id = self._get_cache_dict(
+ db_conn, "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache", push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ super(DataStore, self).__init__(hs)
+
+ def take_presence_startup_info(self):
+ active_on_startup = self.__presence_on_startup
+ self.__presence_on_startup = None
+ return active_on_startup
+
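The startup presence list is handed over exactly once so the store can drop its reference afterwards; a minimal sketch of the intended consumption (the caller and helper shown are hypothetical):

    # e.g. in a presence handler during startup (illustrative only):
    active = store.take_presence_startup_info()  # list of UserPresenceState
    for state in active:
        schedule_presence_timeout(state)         # hypothetical helper
    # A second call returns None: the store has dropped its reference so
    # the rows can be garbage collected.
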
+ def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
+ # Fetch a mapping of entity -> max stream position for "recent" entities.
+ # It doesn't really matter how many we get: the StreamChangeCache will
+ # do the right thing to ensure it respects the max size of the cache.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - 100000"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ }
+
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+ rows = txn.fetchall()
+ txn.close()
+
+ cache = {
+ row[0]: int(row[1])
+ for row in rows
+ }
+
+ if cache:
+ min_val = min(cache.values())
+ else:
+ min_val = max_value
+
+ return cache, min_val
+
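A note on the contract assumed by the prefill helper above: entities missing from the returned dict are only known to be unchanged since `min_val`, so positions older than `min_val` must be treated as cache misses. A sketch (names as in the diff):

    cache, min_val = self._get_cache_dict(
        db_conn, "events",
        entity_column="room_id",
        stream_column="stream_ordering",
        max_value=events_max,
    )
    # cache: {room_id: last stream position} for recently-changed rooms only.
    # min_val: horizon below which the StreamChangeCache cannot answer and
    # must fall back to the database.
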
+ def _get_active_presence(self, db_conn):
+ """Fetch non-offline presence from the database so that we can register
+ the appropriate timeouts.
+ """
+
+ sql = (
+ "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+ " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+ " WHERE state != ?"
+ )
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (PresenceState.OFFLINE,))
+ rows = self.cursor_to_dict(txn)
+ txn.close()
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ return [UserPresenceState(**row) for row in rows]
+
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 183a752387..7dc67ecd57 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
import logging
from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics
-from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
- self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
- self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
- self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
- self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
- self._pushers_id_gen = IdGenerator("pushers", "id", self)
- self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
- self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
- self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -197,7 +185,7 @@ class SQLBaseStore(object):
time_then = self._previous_loop_ts
self._previous_loop_ts = time_now
- ratio = (curr - prev)/(time_now - time_then)
+ ratio = (curr - prev) / (time_now - time_then)
top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3
@@ -310,10 +298,10 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
- result = yield preserve_context_over_fn(
- self._db_pool.runWithConnection,
- inner_func, *args, **kwargs
- )
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
@@ -338,14 +326,15 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
- result = yield preserve_context_over_fn(
- self._db_pool.runWithConnection,
- inner_func, *args, **kwargs
- )
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
defer.returnValue(result)
- def cursor_to_dict(self, cursor):
+ @staticmethod
+ def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
if not or_ignore:
raise
- @log_function
- def _simple_insert_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
- def _simple_insert_many_txn(self, txn, table, values):
+ @staticmethod
+ def _simple_insert_many_txn(txn, table, values):
if not values:
return
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
table, keyvalues, retcol, allow_none=allow_none,
)
- def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ @classmethod
+ def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
- ret = self._simple_select_onecol_txn(
+ ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
else:
raise StoreError(404, "No row found")
- def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ @staticmethod
+ def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
table, keyvalues, retcols
)
- def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+ @classmethod
+ def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -627,7 +620,83 @@ class SQLBaseStore(object):
)
txn.execute(sql)
- return self.cursor_to_dict(txn)
+ return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def _simple_select_many_batch(self, table, column, iterable, retcols,
+ keyvalues={}, desc="_simple_select_many_batch",
+ batch_size=100):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ desc : description of the transaction, for logging and metrics
+ batch_size : maximum number of values from `iterable` per query
+ """
+ results = []
+
+ if not iterable:
+ defer.returnValue(results)
+
+ chunks = [
+ iterable[i:i + batch_size]
+ for i in xrange(0, len(iterable), batch_size)
+ ]
+ for chunk in chunks:
+ rows = yield self.runInteraction(
+ desc,
+ self._simple_select_many_txn,
+ table, column, chunk, keyvalues, retcols
+ )
+
+ results.extend(rows)
+
+ defer.returnValue(results)
+
+ @classmethod
+ def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ if not iterable:
+ return []
+
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+
+ clauses = []
+ values = []
+ clauses.append(
+ "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+ )
+ values.extend(iterable)
+
+ for key, value in keyvalues.items():
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (
+ sql,
+ " AND ".join(clauses),
+ )
+
+ txn.execute(sql, values)
+ return cls.cursor_to_dict(txn)
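A hedged usage sketch of the new batched-select helper; the table and column names here are illustrative, not part of this change:

    @defer.inlineCallbacks
    def get_displaynames(self, user_ids):
        # The helper splits user_ids into IN (...) chunks of batch_size
        # and concatenates the resulting rows.
        rows = yield self._simple_select_many_batch(
            table="profiles",                    # assumed table
            column="user_id",
            iterable=user_ids,
            retcols=("user_id", "displayname"),  # assumed columns
            desc="get_displaynames",
        )
        defer.returnValue(dict((r["user_id"], r["displayname"]) for r in rows))
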
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
@@ -654,7 +723,8 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ @staticmethod
+ def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -671,7 +741,8 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
- def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ @staticmethod
+ def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
@@ -699,20 +770,32 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
+ return self.runInteraction(
+ desc, self._simple_delete_one_txn, table, keyvalues
+ )
+
+ @staticmethod
+ def _simple_delete_one_txn(txn, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- def func(txn):
- txn.execute(sql, keyvalues.values())
- if txn.rowcount == 0:
- raise StoreError(404, "No row found")
- if txn.rowcount > 1:
- raise StoreError(500, "more than one row matched")
- return self.runInteraction(desc, func)
+ txn.execute(sql, keyvalues.values())
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "more than one row matched")
- def _simple_delete_txn(self, txn, table, keyvalues):
+ @staticmethod
+ def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 9c6597e012..faddefe219 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_room", get_account_data_for_room_txn
)
+ def get_all_updated_account_data(self, last_global_id, last_room_id,
+ current_id, limit):
+ """Get all the client account_data that has changed on the server
+ Args:
+ last_global_id(int): The position to fetch from for top level data
+ last_room_id(int): The position to fetch from for per room data
+ current_id(int): The position to fetch up to.
+ Returns:
+ A deferred pair of lists of tuples of stream_id int, user_id string,
+ room_id string (per-room data only), type string, and content string.
+ """
+ def get_updated_account_data_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, account_data_type, content"
+ " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_global_id, current_id, limit))
+ global_results = txn.fetchall()
+
+ sql = (
+ "SELECT stream_id, user_id, room_id, account_data_type, content"
+ " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_room_id, current_id, limit))
+ room_results = txn.fetchall()
+ return (global_results, room_results)
+ return self.runInteraction(
+ "get_all_updated_account_data_txn", get_updated_account_data_txn
+ )
+
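A sketch of how a replication consumer might page this stream; the polling function is an assumption, not part of this diff (assumes twisted.internet.defer is imported):

    @defer.inlineCallbacks
    def poll_account_data(store, last_global_id, last_room_id, current_id):
        # Hypothetical consumer: request at most `limit` rows per call and
        # advance both positions monotonically from the returned stream_ids.
        global_rows, room_rows = yield store.get_all_updated_account_data(
            last_global_id, last_room_id, current_id, limit=100
        )
        for stream_id, user_id, data_type, content in global_rows:
            last_global_id = max(last_global_id, stream_id)
        for stream_id, user_id, room_id, data_type, content in room_rows:
            last_room_id = max(last_room_id, stream_id)
        defer.returnValue((last_global_id, last_room_id))
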
def get_updated_account_data_for_user(self, user_id, stream_id):
- """Get all the client account_data for a that's changed.
+ """Get all the client account_data that has changed for a user
Args:
user_id(str): The user to get the account_data for.
@@ -120,6 +152,12 @@ class AccountDataStore(SQLBaseStore):
return (global_account_data, account_data_by_room)
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return ({}, {})
+
return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
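The early return above relies on the StreamChangeCache contract, which is roughly (a sketch, not the actual implementation):

    # has_entity_changed(entity, pos) may return True spuriously (e.g. when
    # pos predates the cache's horizon) but must never return False when the
    # entity *has* changed since pos. A False result therefore lets us skip
    # the database query outright; True means "query to find out".
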
@@ -151,14 +189,18 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id,
+ )
self._update_max_stream_id(txn, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id
)
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -186,14 +228,18 @@ class AccountDataStore(SQLBaseStore):
"content": content_json,
}
)
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id,
+ )
self._update_max_stream_id(txn, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index b5aa55c0a3..371600eebb 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
self.hostname = hs.hostname
- self.services_cache = []
- self._populate_appservice_cache(
+ self.services_cache = ApplicationServiceStore.load_appservices(
+ hs.hostname,
hs.config.app_service_config_files
)
@@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
- def _load_appservice(self, as_info):
+ @classmethod
+ def _load_appservice(cls, hostname, as_info, config_filename):
required_string_fields = [
- # TODO: Add id here when it's stable to release
- "url", "as_token", "hs_token", "sender_localpart"
+ "id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
- raise KeyError("Required string field: '%s'", field)
+ raise KeyError("Required string field: '%s' (%s)" % (
+ field, config_filename,
+ ))
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
- user = UserID(localpart, self.hostname)
+ user = UserID(localpart, hostname)
user_id = user.to_string()
# namespace checks
@@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore):
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
- id=as_info["id"] if "id" in as_info else as_info["as_token"],
+ id=as_info["id"],
)
- def _populate_appservice_cache(self, config_files):
- """Populates a cache of Application Services from the config files."""
+ @classmethod
+ def load_appservices(cls, hostname, config_files):
+ """Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
- return
+ return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
+ appservices = []
+
for config_file in config_files:
try:
with open(config_file, 'r') as f:
- appservice = self._load_appservice(yaml.load(f))
+ appservice = ApplicationServiceStore._load_appservice(
+ hostname, yaml.load(f), config_file
+ )
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
@@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore):
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
- self.services_cache.append(appservice)
+ appservices.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
+ return appservices
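For reference, the shape of a single parsed registration that would satisfy `_load_appservice` above; all values are placeholders:

    # as_info as produced by yaml.load(f) on one registration file:
    as_info = {
        "id": "irc_bridge",                  # now a mandatory field
        "url": "https://localhost:9000",
        "as_token": "<random token>",
        "hs_token": "<random token>",
        "sender_localpart": "ircbot",        # must survive URL-quoting unchanged
        "namespaces": {"users": [], "aliases": [], "rooms": []},
    }
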
class ApplicationServiceTransactionStore(SQLBaseStore):
@@ -276,7 +284,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
"application_services_state",
dict(as_id=service.id),
["state"],
- allow_none=True
+ allow_none=True,
+ desc="get_appservice_state",
)
if result:
defer.returnValue(result.get("state"))
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 1556619d5e..012a0b414a 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def create_room_alias_association(self, room_alias, room_id, servers):
+ def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers
Args:
room_alias (RoomAlias)
room_id (str)
servers (list)
+ creator (str): Optional user_id of creator.
Returns:
Deferred
@@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore):
{
"room_alias": room_alias.to_string(),
"room_id": room_id,
+ "creator": creator,
},
desc="create_room_alias_association",
)
@@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore):
)
self.get_aliases_for_room.invalidate((room_id,))
+ def get_room_alias_creator(self, room_alias):
+ return self._simple_select_one_onecol(
+ table="room_aliases",
+ keyvalues={
+ "room_alias": room_alias,
+ },
+ retcol="creator",
+ desc="get_room_alias_creator",
+ allow_none=True
+ )
+
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 5dd32b1413..2e89066515 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from ._base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore):
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 4290aea83a..a48230b93f 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -26,12 +26,13 @@ SUPPORTED_MODULE = {
}
-def create_engine(name):
+def create_engine(config):
+ name = config.database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
module = importlib.import_module(name)
- return engine_class(module)
+ return engine_class(module, config=config)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 17b7a9c077..a09685b4df 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
- def __init__(self, database_module):
+ def __init__(self, database_module, config):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
+ self.config = config
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@@ -44,7 +45,7 @@ class PostgresEngine(object):
)
def prepare_database(self, db_conn):
- prepare_database(db_conn, self)
+ prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 400c10103c..522b905949 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -23,8 +23,9 @@ import struct
class Sqlite3Engine(object):
single_threaded = True
- def __init__(self, database_module):
+ def __init__(self, database_module, config):
self.module = database_module
+ self.config = config
def check_database(self, txn):
pass
@@ -38,7 +39,7 @@ class Sqlite3Engine(object):
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
- prepare_database(db_conn, self)
+ prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error):
return False
@@ -54,7 +55,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5f32eec6f8..3489315e0d 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
new_front = set()
front_list = list(front)
chunks = [
- front_list[x:x+100]
+ front_list[x:x + 100]
for x in xrange(0, len(front), 100)
]
for chunk in chunks:
@@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore):
retcol="event_id",
)
- def get_latest_events_in_room(self, room_id):
+ def get_latest_event_ids_and_hashes_in_room(self, room_id):
return self.runInteraction(
- "get_latest_events_in_room",
- self._get_latest_events_in_room,
+ "get_latest_event_ids_and_hashes_in_room",
+ self._get_latest_event_ids_and_hashes_in_room,
room_id,
)
@@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore):
desc="get_latest_event_ids_in_room",
)
- def _get_latest_events_in_room(self, txn, room_id):
+ def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
sql = (
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index a05c4f84cf..5820539a92 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -24,34 +24,30 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore):
- @defer.inlineCallbacks
- def set_push_actions_for_event_and_users(self, event, tuples):
+ def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
:param event: the event to set actions for
- :param tuples: list of tuples of (user_id, profile_tag, actions)
+ :param tuples: list of tuples of (user_id, actions)
"""
values = []
- for uid, profile_tag, actions in tuples:
+ for uid, actions in tuples:
values.append({
'room_id': event.room_id,
'event_id': event.event_id,
'user_id': uid,
- 'profile_tag': profile_tag,
- 'actions': json.dumps(actions)
+ 'actions': json.dumps(actions),
+ 'stream_ordering': event.internal_metadata.stream_ordering,
+ 'topological_ordering': event.depth,
+ 'notif': 1,
+ 'highlight': 1 if _action_has_highlight(actions) else 0,
})
- def f(txn):
- for uid, _, __ in tuples:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid)
- )
- return self._simple_insert_many_txn(txn, "event_push_actions", values)
-
- yield self.runInteraction(
- "set_actions_for_event_and_users",
- f,
- )
+ for uid, __ in tuples:
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid)
+ )
+ self._simple_insert_many_txn(txn, "event_push_actions", values)
@cachedInlineCallbacks(num_args=3, lru=True, tree=True)
def get_unread_event_push_actions_by_room_for_user(
@@ -68,32 +64,34 @@ class EventPushActionsStore(SQLBaseStore):
)
results = txn.fetchall()
if len(results) == 0:
- return []
+ return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
topological_ordering = results[0][1]
sql = (
- "SELECT ea.event_id, ea.actions"
- " FROM event_push_actions ea, events e"
- " WHERE ea.room_id = e.room_id"
- " AND ea.event_id = e.event_id"
- " AND ea.user_id = ?"
- " AND ea.room_id = ?"
+ "SELECT sum(notif), sum(highlight)"
+ " FROM event_push_actions ea"
+ " WHERE"
+ " user_id = ?"
+ " AND room_id = ?"
" AND ("
- " e.topological_ordering > ?"
- " OR (e.topological_ordering = ? AND e.stream_ordering > ?)"
+ " topological_ordering > ?"
+ " OR (topological_ordering = ? AND stream_ordering > ?)"
")"
)
txn.execute(sql, (
user_id, room_id,
topological_ordering, topological_ordering, stream_ordering
- )
- )
- return [
- {"event_id": row[0], "actions": json.loads(row[1])}
- for row in txn.fetchall()
- ]
+ ))
+ row = txn.fetchone()
+ if row:
+ return {
+ "notify_count": row[0] or 0,
+ "highlight_count": row[1] or 0,
+ }
+ else:
+ return {"notify_count": 0, "highlight_count": 0}
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -101,19 +99,24 @@ class EventPushActionsStore(SQLBaseStore):
)
defer.returnValue(ret)
- @defer.inlineCallbacks
- def remove_push_actions_for_event_id(self, room_id, event_id):
- def f(txn):
- # Sad that we have to blow away the cache for the whole room here
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id,)
- )
- txn.execute(
- "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
- (room_id, event_id)
- )
- yield self.runInteraction(
- "remove_push_actions_for_event_id",
- f
+ def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ # Sad that we have to blow away the cache for the whole room here
+ txn.call_after(
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (room_id,)
)
+ txn.execute(
+ "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+ (room_id, event_id)
+ )
+
+
+def _action_has_highlight(actions):
+ for action in actions:
+ try:
+ if action.get("set_tweak", None) == "highlight":
+ return action.get("value", True)
+ except AttributeError:
+ pass
+
+ return False
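A quick worked example of the helper's behaviour; the action payloads follow the push-rule tweak shape used elsewhere in this diff:

    assert _action_has_highlight(["notify", {"set_tweak": "highlight"}])
    # a missing "value" key defaults to True

    assert not _action_has_highlight([{"set_tweak": "highlight", "value": False}])
    assert not _action_has_highlight(["notify"])  # plain strings are skipped
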
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba368a3eca..285c586cfe 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore, _RollbackButIsFineException
+from ._base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
-from synapse.util.logcontext import preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
@@ -66,19 +66,17 @@ class EventsStore(SQLBaseStore):
return
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- start = self.min_token - 1
- self.min_token -= len(events_and_contexts) + 1
- stream_orderings = range(start, self.min_token, -1)
+ start = self.min_stream_token - 1
+ self.min_stream_token -= len(events_and_contexts) + 1
+ stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else:
- stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
- self, len(events_and_contexts)
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
+ len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
@@ -86,7 +84,7 @@ class EventsStore(SQLBaseStore):
event.internal_metadata.stream_ordering = stream
chunks = [
- events_and_contexts[x:x+100]
+ events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
@@ -107,13 +105,11 @@ class EventsStore(SQLBaseStore):
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
+ self.min_stream_token -= 1
+ stream_ordering = self.min_stream_token
if stream_ordering is None:
- stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+ stream_ordering_manager = self._stream_id_gen.get_next()
else:
@contextmanager
def stream_ordering_manager():
@@ -135,7 +131,7 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException:
pass
- max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+ max_persisted_id = yield self._stream_id_gen.get_max_token()
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks
@@ -209,17 +205,29 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
is_new_state=True):
-
- # Remove the any existing cache entries for the event_ids
- for event, _ in events_and_contexts:
+ depth_updates = {}
+ for event, context in events_and_contexts:
+ # Remove any existing cache entries for the event_ids
txn.call_after(self._invalidate_get_event_cache, event.event_id)
+ if not backfilled:
+ txn.call_after(
+ self._events_stream_cache.entity_has_changed,
+ event.room_id, event.internal_metadata.stream_ordering,
+ )
- depth_updates = {}
- for event, _ in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
- depth_updates[event.room_id] = max(
- event.depth, depth_updates.get(event.room_id, event.depth)
+ if not event.internal_metadata.is_outlier():
+ depth_updates[event.room_id] = max(
+ event.depth, depth_updates.get(event.room_id, event.depth)
+ )
+
+ if context.push_actions:
+ self._set_push_actions_for_event_and_users_txn(
+ txn, event, context.push_actions
+ )
+
+ if event.type == EventTypes.Redaction and event.redacts is not None:
+ self._remove_push_actions_for_event_id_txn(
+ txn, event.room_id, event.redacts
)
for room_id, depth in depth_updates.items():
@@ -518,6 +526,9 @@ class EventsStore(SQLBaseStore):
if not event_ids:
defer.returnValue([])
+ event_id_list = event_ids
+ event_ids = set(event_ids)
+
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
@@ -527,23 +538,18 @@ class EventsStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_map]
- if not missing_events_ids:
- defer.returnValue([
- event_map[e_id] for e_id in event_ids
- if e_id in event_map and event_map[e_id]
- ])
-
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
+ if missing_events_ids:
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
- event_map.update(missing_events)
+ event_map.update(missing_events)
defer.returnValue([
- event_map[e_id] for e_id in event_ids
+ event_map[e_id] for e_id in event_id_list
if e_id in event_map and event_map[e_id]
])
@@ -662,14 +668,16 @@ class EventsStore(SQLBaseStore):
for ids, d in lst:
if not d.called:
try:
- d.callback([
- res[i]
- for i in ids
- if i in res
- ])
+ with PreserveLoggingContext():
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
except:
logger.exception("Failed to callback")
- reactor.callFromThread(fire, event_list, row_dict)
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
@@ -677,10 +685,12 @@ class EventsStore(SQLBaseStore):
def fire(evs):
for _, d in evs:
if not d.called:
- d.errback(e)
+ with PreserveLoggingContext():
+ d.errback(e)
if event_list:
- reactor.callFromThread(fire, event_list)
+ with PreserveLoggingContext():
+ reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True,
@@ -707,18 +717,20 @@ class EventsStore(SQLBaseStore):
should_start = False
if should_start:
- self.runWithConnection(
- self._do_fetch
- )
+ with PreserveLoggingContext():
+ self.runWithConnection(
+ self._do_fetch
+ )
- rows = yield preserve_context_over_deferred(events_d)
+ with PreserveLoggingContext():
+ rows = yield events_d
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
- self._get_event_from_row(
+ preserve_fn(self._get_event_from_row)(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -738,7 +750,7 @@ class EventsStore(SQLBaseStore):
rows = []
N = 200
for i in range(1 + len(events) / N):
- evs = events[i*N:(i + 1)*N]
+ evs = events[i * N:(i + 1) * N]
if not evs:
break
@@ -753,7 +765,7 @@ class EventsStore(SQLBaseStore):
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
- ) % (",".join(["?"]*len(evs)),)
+ ) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
@@ -1050,3 +1062,48 @@ class EventsStore(SQLBaseStore):
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
defer.returnValue(result)
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+
+ # TODO: Fix race with the persist_event txn by using one of the
+ # stream id managers
+ return -self.min_stream_token
+
+ def get_all_new_events(self, last_backfill_id, last_forward_id,
+ current_backfill_id, current_forward_id, limit):
+ """Get all the new events that have arrived at the server either as
+ new events or as backfilled events"""
+ def get_all_new_events_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, ej.internal_metadata, ej.json"
+ " FROM events as e"
+ " JOIN event_json as ej"
+ " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+ " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ " LIMIT ?"
+ )
+ if last_forward_id != current_forward_id:
+ txn.execute(sql, (last_forward_id, current_forward_id, limit))
+ new_forward_events = txn.fetchall()
+ else:
+ new_forward_events = []
+
+ sql = (
+ "SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
+ " FROM events as e"
+ " JOIN event_json as ej"
+ " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+ " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
+ " ORDER BY e.stream_ordering DESC"
+ " LIMIT ?"
+ )
+ if last_backfill_id != current_backfill_id:
+ txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+ new_backfill_events = txn.fetchall()
+ else:
+ new_backfill_events = []
+
+ return (new_forward_events, new_backfill_events)
+ return self.runInteraction("get_all_new_events", get_all_new_events_txn)
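Backfilled events are stored with negative stream_ordering values, so the backfill token is exposed as a positive, growing number and negated again at query time. A short walkthrough with illustrative values:

    # Suppose backfilled rows have stream_ordering -1, -2, -3; then
    # get_current_backfill_token() == -min_stream_token == 3.
    # A caller paging from token 1 to 3 hits the second query above with
    # parameters (-1, -3), i.e.
    #   WHERE -1 > stream_ordering AND stream_ordering >= -3
    # which matches exactly the rows -2 and -3.
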
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index f8fc9bdddc..5248736816 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -16,12 +16,13 @@
from twisted.internet import defer
from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cachedInlineCallbacks
import simplejson as json
class FilteringStore(SQLBaseStore):
- @defer.inlineCallbacks
+ @cachedInlineCallbacks(num_args=2)
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 8022b8cfc6..a495a8a7d9 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
@@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore):
table="server_tls_certificates",
keyvalues={"server_name": server_name},
retcols=("tls_certificate",),
+ desc="get_server_certificate",
)
tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 0894384780..9d3ba32478 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from ._base import SQLBaseStore
class MediaRepositoryStore(SQLBaseStore):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c1f5f99789..3f29aad1e8 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 28
+SCHEMA_VERSION = 30
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -50,7 +50,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine):
+def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
"""
@@ -61,10 +61,10 @@ def prepare_database(db_conn, database_engine):
if version_info:
user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
- cur, user_version, delta_files, upgraded, database_engine
+ cur, user_version, delta_files, upgraded, database_engine, config
)
else:
- _setup_new_database(cur, database_engine)
+ _setup_new_database(cur, database_engine, config)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
@@ -75,7 +75,7 @@ def prepare_database(db_conn, database_engine):
raise
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, config):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
@@ -148,11 +148,12 @@ def _setup_new_database(cur, database_engine):
applied_delta_files=[],
upgraded=False,
database_engine=database_engine,
+ config=config,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
- upgraded, database_engine):
+ upgraded, database_engine, config):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -211,7 +212,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1):
- logger.debug("Upgrading schema to v%d", v)
+ logger.info("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
@@ -245,7 +246,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
- module.run_upgrade(cur, database_engine)
+ module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1095d52ace..4cec31e316 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -14,71 +14,148 @@
# limitations under the License.
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.api.constants import PresenceState
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from collections import namedtuple
from twisted.internet import defer
+class UserPresenceState(namedtuple("UserPresenceState",
+ ("user_id", "state", "last_active_ts",
+ "last_federation_update_ts", "last_user_sync_ts",
+ "status_msg", "currently_active"))):
+ """Represents the current presence state of the user.
+
+ user_id (str)
+ state (str): The user's presence state (one of the PresenceState constants).
+ last_active_ts (int): Time in msec that the user last interacted with server.
+ last_federation_update_ts (int): Time in msec since either a) we sent a presence
+ update to other servers or b) we received a presence update, depending
+ on whether this is a local user or not.
+ last_user_sync_ts (int): Time in msec that the user last *completed* a sync
+ (or event stream).
+ status_msg (str): User set status message.
+ currently_active (bool): Whether the user is considered currently active.
+ """
+
+ def copy_and_replace(self, **kwargs):
+ return self._replace(**kwargs)
+
+ @classmethod
+ def default(cls, user_id):
+ """Returns a default presence state.
+ """
+ return cls(
+ user_id=user_id,
+ state=PresenceState.OFFLINE,
+ last_active_ts=0,
+ last_federation_update_ts=0,
+ last_user_sync_ts=0,
+ status_msg=None,
+ currently_active=False,
+ )
+
+
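A short usage sketch of the new state object; the timestamp is illustrative:

    # namedtuples are immutable: copy_and_replace returns a new instance.
    prev = UserPresenceState.default("@alice:example.com")
    new = prev.copy_and_replace(
        state=PresenceState.ONLINE,
        last_active_ts=1450000000000,  # msec, illustrative
        currently_active=True,
    )
    assert prev.state == PresenceState.OFFLINE  # original unchanged
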
class PresenceStore(SQLBaseStore):
- def create_presence(self, user_localpart):
- res = self._simple_insert(
- table="presence",
- values={"user_id": user_localpart},
- desc="create_presence",
+ @defer.inlineCallbacks
+ def update_presence(self, presence_states):
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ len(presence_states)
)
- self.get_presence_state.invalidate((user_localpart,))
- return res
+ with stream_ordering_manager as stream_orderings:
+ yield self.runInteraction(
+ "update_presence",
+ self._update_presence_txn, stream_orderings, presence_states,
+ )
- def has_presence_state(self, user_localpart):
- return self._simple_select_one(
- table="presence",
- keyvalues={"user_id": user_localpart},
- retcols=["user_id"],
- allow_none=True,
- desc="has_presence_state",
+ defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
+
+ def _update_presence_txn(self, txn, stream_orderings, presence_states):
+ for stream_id, state in zip(stream_orderings, presence_states):
+ txn.call_after(
+ self.presence_stream_cache.entity_has_changed,
+ state.user_id, stream_id,
+ )
+
+ # Actually insert new rows
+ self._simple_insert_many_txn(
+ txn,
+ table="presence_stream",
+ values=[
+ {
+ "stream_id": stream_id,
+ "user_id": state.user_id,
+ "state": state.state,
+ "last_active_ts": state.last_active_ts,
+ "last_federation_update_ts": state.last_federation_update_ts,
+ "last_user_sync_ts": state.last_user_sync_ts,
+ "status_msg": state.status_msg,
+ "currently_active": state.currently_active,
+ }
+ for state in presence_states
+ ],
)
- @cached(max_entries=2000)
- def get_presence_state(self, user_localpart):
- return self._simple_select_one(
- table="presence",
- keyvalues={"user_id": user_localpart},
- retcols=["state", "status_msg", "mtime"],
- desc="get_presence_state",
+ # Delete old rows to stop database from getting really big
+ sql = (
+ "DELETE FROM presence_stream WHERE"
+ " stream_id < ?"
+ " AND user_id IN (%s)"
+ )
+
+ batches = (
+ presence_states[i:i + 50]
+ for i in xrange(0, len(presence_states), 50)
)
+ for states in batches:
+ args = [stream_id]
+ args.extend(s.user_id for s in states)
+ txn.execute(
+ sql % (",".join("?" for _ in states),),
+ args
+ )
+
+ def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, state, last_active_ts,"
+ " last_federation_update_ts, last_user_sync_ts, status_msg,"
+ " currently_active"
+ " FROM presence_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ )
+ txn.execute(sql, (last_id, current_id))
+ return txn.fetchall()
- @cachedList(get_presence_state.cache, list_name="user_localparts")
- def get_presence_states(self, user_localparts):
- def f(txn):
- results = {}
- for user_localpart in user_localparts:
- res = self._simple_select_one_txn(
- txn,
- table="presence",
- keyvalues={"user_id": user_localpart},
- retcols=["state", "status_msg", "mtime"],
- allow_none=True,
- )
- if res:
- results[user_localpart] = res
-
- return results
-
- return self.runInteraction("get_presence_states", f)
-
- def set_presence_state(self, user_localpart, new_state):
- res = self._simple_update_one(
- table="presence",
- keyvalues={"user_id": user_localpart},
- updatevalues={"state": new_state["state"],
- "status_msg": new_state["status_msg"],
- "mtime": self._clock.time_msec()},
- desc="set_presence_state",
+ return self.runInteraction(
+ "get_all_presence_updates", get_all_presence_updates_txn
)
- self.get_presence_state.invalidate((user_localpart,))
- return res
+ @defer.inlineCallbacks
+ def get_presence_for_users(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ )
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ defer.returnValue([UserPresenceState(**row) for row in rows])
+
+ def get_current_presence_token(self):
+ return self._presence_id_gen.get_max_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@@ -126,6 +203,7 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
+ self.get_presence_list_observers_accepted.invalidate((observed_userid,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@@ -152,6 +230,19 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_list_accepted",
)
+ @cachedInlineCallbacks()
+ def get_presence_list_observers_accepted(self, observed_userid):
+ user_localparts = yield self._simple_select_onecol(
+ table="presence_list",
+ keyvalues={"observed_user_id": observed_userid, "accepted": True},
+ retcol="user_id",
+ desc="get_presence_list_observers_accepted",
+ )
+
+ defer.returnValue([
+ "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
+ ])
+
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one(
@@ -161,3 +252,4 @@ class PresenceStore(SQLBaseStore):
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
+ self.get_presence_list_observers_accepted.invalidate((observed_userid,))
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 35ec7e8cef..9dbad2fd5f 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -65,32 +65,20 @@ class PushRuleStore(SQLBaseStore):
if not user_ids:
defer.returnValue({})
- batch_size = 100
-
- def f(txn, user_ids_to_fetch):
- sql = (
- "SELECT pr.*"
- " FROM push_rules AS pr"
- " LEFT JOIN push_rules_enable AS pre"
- " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
- " WHERE pr.user_name"
- " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
- " AND (pre.enabled IS NULL OR pre.enabled = 1)"
- " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC"
- )
- txn.execute(sql, user_ids_to_fetch)
- return self.cursor_to_dict(txn)
-
results = {}
- chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
- for batch_user_ids in chunks:
- rows = yield self.runInteraction(
- "bulk_get_push_rules", f, batch_user_ids
- )
+ rows = yield self._simple_select_many_batch(
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("*",),
+ desc="bulk_get_push_rules",
+ )
+
+ rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
- for row in rows:
- results.setdefault(row['user_name'], []).append(row)
+ for row in rows:
+ results.setdefault(row['user_name'], []).append(row)
defer.returnValue(results)
@defer.inlineCallbacks
@@ -98,62 +86,52 @@ class PushRuleStore(SQLBaseStore):
if not user_ids:
defer.returnValue({})
- batch_size = 100
-
- def f(txn, user_ids_to_fetch):
- sql = (
- "SELECT user_name, rule_id, enabled"
- " FROM push_rules_enable"
- " WHERE user_name"
- " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
- )
- txn.execute(sql, user_ids_to_fetch)
- return self.cursor_to_dict(txn)
-
results = {}
- chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
- for batch_user_ids in chunks:
- rows = yield self.runInteraction(
- "bulk_get_push_rules_enabled", f, batch_user_ids
- )
-
- for row in rows:
- results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+ rows = yield self._simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled",),
+ desc="bulk_get_push_rules_enabled",
+ )
+ for row in rows:
+ results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results)
@defer.inlineCallbacks
- def add_push_rule(self, before, after, **kwargs):
- vals = kwargs
- if 'conditions' in vals:
- vals['conditions'] = json.dumps(vals['conditions'])
- if 'actions' in vals:
- vals['actions'] = json.dumps(vals['actions'])
-
- # we could check the rest of the keys are valid column names
- # but sqlite will do that anyway so I think it's just pointless.
- vals.pop("id", None)
-
- if before or after:
- ret = yield self.runInteraction(
- "_add_push_rule_relative_txn",
- self._add_push_rule_relative_txn,
- before=before,
- after=after,
- **vals
- )
- defer.returnValue(ret)
- else:
- ret = yield self.runInteraction(
- "_add_push_rule_highest_priority_txn",
- self._add_push_rule_highest_priority_txn,
- **vals
- )
- defer.returnValue(ret)
-
- def _add_push_rule_relative_txn(self, txn, user_id, **kwargs):
- after = kwargs.pop("after", None)
- relative_to_rule = kwargs.pop("before", after)
+ def add_push_rule(
+ self, user_id, rule_id, priority_class, conditions, actions,
+ before=None, after=None
+ ):
+ conditions_json = json.dumps(conditions)
+ actions_json = json.dumps(actions)
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ if before or after:
+ yield self.runInteraction(
+ "_add_push_rule_relative_txn",
+ self._add_push_rule_relative_txn,
+ stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after,
+ )
+ else:
+ yield self.runInteraction(
+ "_add_push_rule_highest_priority_txn",
+ self._add_push_rule_highest_priority_txn,
+ stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json,
+ )
+
+ def _add_push_rule_relative_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json, before, after
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
+ relative_to_rule = before or after
res = self._simple_select_one_txn(
txn,
@@ -171,69 +149,45 @@ class PushRuleStore(SQLBaseStore):
"before/after rule not found: %s" % (relative_to_rule,)
)
- priority_class = res["priority_class"]
+ base_priority_class = res["priority_class"]
base_rule_priority = res["priority"]
- if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
+ if base_priority_class != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
- new_rule = kwargs
- new_rule.pop("before", None)
- new_rule.pop("after", None)
- new_rule['priority_class'] = priority_class
- new_rule['user_name'] = user_id
- new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
-
- # check if the priority before/after is free
- new_rule_priority = base_rule_priority
- if after:
- new_rule_priority -= 1
+ if before:
+ # Higher priority rules are executed first, so adding a rule before
+ # a rule means giving it a higher priority than that rule.
+ new_rule_priority = base_rule_priority + 1
else:
- new_rule_priority += 1
-
- new_rule['priority'] = new_rule_priority
+ # We increment the priority of the existing rules to make space for
+ # the new rule. Therefore if we want this rule to appear after
+ # an existing rule we give it the priority of the existing rule,
+ # and then increment the priority of the existing rule.
+ new_rule_priority = base_rule_priority
sql = (
- "SELECT COUNT(*) FROM push_rules"
- " WHERE user_name = ? AND priority_class = ? AND priority = ?"
+ "UPDATE push_rules SET priority = priority + 1"
+ " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
)
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
- res = txn.fetchall()
- num_conflicting = res[0][0]
-
- # if there are conflicting rules, bump everything
- if num_conflicting:
- sql = "UPDATE push_rules SET priority = priority "
- if after:
- sql += "-1"
- else:
- sql += "+1"
- sql += " WHERE user_name = ? AND priority_class = ? AND priority "
- if after:
- sql += "<= ?"
- else:
- sql += ">= ?"
-
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
- txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_id,)
- )
+        txn.execute(sql, (user_id, priority_class, new_rule_priority))
- txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ self._upsert_push_rule_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ new_rule_priority, conditions_json, actions_json,
)
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values=new_rule,
- )
+ def _add_push_rule_highest_priority_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ conditions_json, actions_json
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
- def _add_push_rule_highest_priority_txn(self, txn, user_id,
- priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM push_rules"
@@ -247,26 +201,61 @@ class PushRuleStore(SQLBaseStore):
if how_many > 0:
new_prio = highest_prio + 1
- # and insert the new rule
- new_rule = kwargs
- new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
- new_rule['user_name'] = user_id
- new_rule['priority_class'] = priority_class
- new_rule['priority'] = new_prio
-
- txn.call_after(
- self.get_push_rules_for_user.invalidate, (user_id,)
- )
- txn.call_after(
- self.get_push_rules_enabled_for_user.invalidate, (user_id,)
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
+ conditions_json, actions_json,
)
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values=new_rule,
+ def _upsert_push_rule_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
+ priority, conditions_json, actions_json, update_stream=True
+ ):
+ """Specialised version of _simple_upsert_txn that picks a push_rule_id
+ using the _push_rule_id_gen if it needs to insert the rule. It assumes
+    that the "push_rules" table is locked."""
+
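+        # Try the UPDATE first and only INSERT when no row matched. This is
+        # only safe because the caller holds the "push_rules" table lock;
+        # otherwise a concurrent insert could slip in between the two
+        # statements and leave us with duplicate rules.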
+ sql = (
+ "UPDATE push_rules"
+ " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+ " WHERE user_name = ? AND rule_id = ?"
)
+ txn.execute(sql, (
+ priority_class, priority, conditions_json, actions_json,
+ user_id, rule_id,
+ ))
+
+ if txn.rowcount == 0:
+ # We didn't update a row with the given rule_id so insert one
+ push_rule_id = self._push_rule_id_gen.get_next()
+
+ self._simple_insert_txn(
+ txn,
+ table="push_rules",
+ values={
+ "id": push_rule_id,
+ "user_name": user_id,
+ "rule_id": rule_id,
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
+
+ if update_stream:
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="ADD",
+ data={
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ }
+ )
+
@defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id):
"""
@@ -278,26 +267,38 @@ class PushRuleStore(SQLBaseStore):
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
- yield self._simple_delete_one(
- "push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
- desc="delete_push_rule",
- )
+ def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ self._simple_delete_one_txn(
+ txn,
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
+ )
- self.get_push_rules_for_user.invalidate((user_id,))
- self.get_push_rules_enabled_for_user.invalidate((user_id,))
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="DELETE"
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.runInteraction(
+ "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
+ )
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled):
- ret = yield self.runInteraction(
- "_set_push_rule_enabled_txn",
- self._set_push_rule_enabled_txn,
- user_id, rule_id, enabled
- )
- defer.returnValue(ret)
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.runInteraction(
+ "_set_push_rule_enabled_txn",
+ self._set_push_rule_enabled_txn,
+ stream_id, event_stream_ordering, user_id, rule_id, enabled
+ )
- def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled):
- new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
+ def _set_push_rule_enabled_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ ):
+ new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn(
txn,
"push_rules_enable",
@@ -305,12 +306,109 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
+
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="ENABLE" if enabled else "DISABLE"
+ )
+
+ @defer.inlineCallbacks
+ def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ actions_json = json.dumps(actions)
+
+ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+ if is_default_rule:
+                # Add a dummy rule to the rules table with the user-specified
+                # actions.
+ priority_class = -1
+ priority = 1
+ self._upsert_push_rule_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ priority_class, priority, "[]", actions_json,
+ update_stream=False
+ )
+ else:
+ self._simple_update_one_txn(
+ txn,
+ "push_rules",
+ {'user_name': user_id, 'rule_id': rule_id},
+ {'actions': actions_json},
+ )
+
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id,
+ op="ACTIONS", data={"actions": actions_json}
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.runInteraction(
+ "set_push_rule_actions", set_push_rule_actions_txn,
+ stream_id, event_stream_ordering
+ )
+
+ def _insert_push_rules_update_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+ ):
+ values = {
+ "stream_id": stream_id,
+ "event_stream_ordering": event_stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": op,
+ }
+ if data is not None:
+ values.update(data)
+
+ self._simple_insert_txn(txn, "push_rules_stream", values=values)
+
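+        # These cache invalidations are registered via call_after, so they
+        # only run once the surrounding transaction has committed.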
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
+ txn.call_after(
+ self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+ )
+
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+ """Get all the push rules changes that have happend on the server"""
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
+ )
+
+ def get_push_rules_stream_token(self):
+ """Get the position of the push rules stream.
+ Returns a pair of a stream id for the push_rules stream and the
+ room stream ordering it corresponds to."""
+ return self._push_rules_stream_id_gen.get_max_token()
+
+ def have_push_rules_changed_for_user(self, user_id, last_id):
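+        # Fast path: the in-memory stream change cache can often prove that
+        # nothing has changed for this user, avoiding a database query.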
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ count, = txn.fetchone()
+ return bool(count)
+ return self.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8ec706178a..87b2ac5773 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -16,8 +16,6 @@
from ._base import SQLBaseStore
from twisted.internet import defer
-from synapse.api.errors import StoreError
-
from canonicaljson import encode_canonical_json
import logging
@@ -79,12 +77,41 @@ class PusherStore(SQLBaseStore):
rows = yield self.runInteraction("get_all_pushers", get_pushers)
defer.returnValue(rows)
+ def get_pushers_stream_token(self):
+ return self._pushers_id_gen.get_max_token()
+
+ def get_all_updated_pushers(self, last_id, current_id, limit):
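+        """Get all pusher rows and deleted-pusher rows with stream ids in
+        the half-open range (last_id, current_id], up to `limit` of each."""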
+ def get_all_updated_pushers_txn(txn):
+ sql = (
+ "SELECT id, user_name, access_token, profile_tag, kind,"
+ " app_id, app_display_name, device_display_name, pushkey, ts,"
+ " lang, data"
+ " FROM pushers"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ updated = txn.fetchall()
+
+ sql = (
+ "SELECT stream_id, user_id, app_id, pushkey"
+ " FROM deleted_pushers"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ deleted = txn.fetchall()
+
+ return (updated, deleted)
+ return self.runInteraction(
+ "get_all_updated_pushers", get_all_updated_pushers_txn
+ )
+
@defer.inlineCallbacks
- def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
+ def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
- pushkey, pushkey_ts, lang, data):
- try:
- next_id = yield self._pushers_id_gen.get_next()
+ pushkey, pushkey_ts, lang, data, profile_tag=""):
+ with self._pushers_id_gen.get_next() as stream_id:
yield self._simple_upsert(
"pushers",
dict(
@@ -95,29 +122,35 @@ class PusherStore(SQLBaseStore):
dict(
access_token=access_token,
kind=kind,
- profile_tag=profile_tag,
app_display_name=app_display_name,
device_display_name=device_display_name,
ts=pushkey_ts,
lang=lang,
data=encode_canonical_json(data),
- ),
- insertion_values=dict(
- id=next_id,
+ profile_tag=profile_tag,
+ id=stream_id,
),
desc="add_pusher",
)
- except Exception as e:
- logger.error("create_pusher with failed: %s", e)
- raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
- yield self._simple_delete_one(
- "pushers",
- {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
- desc="delete_pusher_by_app_id_pushkey_user_id",
- )
+ def delete_pusher_txn(txn, stream_id):
+ self._simple_delete_one_txn(
+ txn,
+ "pushers",
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
+ )
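+            # Record the deletion in deleted_pushers so that other processes
+            # following the pushers stream can see it; the table keeps only
+            # the most recent deletion per (app_id, pushkey, user_id).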
+ self._simple_upsert_txn(
+ txn,
+ "deleted_pushers",
+ {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
+ {"stream_id": stream_id},
+ )
+ with self._pushers_id_gen.get_next() as stream_id:
+ yield self.runInteraction(
+ "delete_pusher", delete_pusher_txn, stream_id
+ )
@defer.inlineCallbacks
def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c4232bdc65..dbc074d6b5 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -15,11 +15,10 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
-from blist import sorteddict
import logging
import ujson as json
@@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
- self._receipts_stream_cache = _RoomStreamChangeCache()
+ self._receipts_stream_cache = StreamChangeCache(
+ "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
+ )
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -45,6 +46,20 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room",
)
+ @cached(num_args=3)
+ def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
+ return self._simple_select_one_onecol(
+ table="receipts_linearized",
+ keyvalues={
+ "room_id": room_id,
+ "receipt_type": receipt_type,
+ "user_id": user_id
+ },
+ retcol="event_id",
+ desc="get_own_receipt_for_user",
+ allow_none=True,
+ )
+
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
def f(txn):
@@ -76,8 +91,8 @@ class ReceiptsStore(SQLBaseStore):
room_ids = set(room_ids)
if from_key:
- room_ids = yield self._receipts_stream_cache.get_rooms_changed(
- self, room_ids, from_key
+ room_ids = yield self._receipts_stream_cache.get_entities_changed(
+ room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
@@ -207,7 +222,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_max_token(self)
+ return self._receipts_id_gen.get_max_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@@ -220,6 +235,16 @@ class ReceiptsStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+ txn.call_after(
+ self._receipts_stream_cache.entity_has_changed,
+ room_id, stream_id
+ )
+
+ txn.call_after(
+ self.get_last_receipt_event_id_for_user.invalidate,
+ (user_id, room_id, receipt_type)
+ )
+
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
@@ -305,11 +330,8 @@ class ReceiptsStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
- stream_id_manager = yield self._receipts_id_gen.get_next(self)
+ stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- yield self._receipts_stream_cache.room_has_changed(
- self, room_id, stream_id
- )
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
@@ -325,7 +347,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
- max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+        max_persisted_id = self._receipts_id_gen.get_max_token()
defer.returnValue((stream_id, max_persisted_id))
@@ -369,62 +391,18 @@ class ReceiptsStore(SQLBaseStore):
}
)
+ def get_all_updated_receipts(self, last_id, current_id, limit):
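+        """Get all receipts with stream ids in the half-open range
+        (last_id, current_id], ordered by stream id, up to `limit` rows."""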
+ def get_all_updated_receipts_txn(txn):
+ sql = (
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+ " FROM receipts_linearized"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
-class _RoomStreamChangeCache(object):
- """Keeps track of the stream_id of the latest change in rooms.
-
- Given a list of rooms and stream key, it will give a subset of rooms that
- may have changed since that key. If the key is too old then the cache
- will simply return all rooms.
- """
- def __init__(self, size_of_cache=10000):
- self._size_of_cache = size_of_cache
- self._room_to_key = {}
- self._cache = sorteddict()
- self._earliest_key = None
- self.name = "ReceiptsRoomChangeCache"
- caches_by_name[self.name] = self._cache
-
- @defer.inlineCallbacks
- def get_rooms_changed(self, store, room_ids, key):
- """Returns subset of room ids that have had new receipts since the
- given key. If the key is too old it will just return the given list.
- """
- if key > (yield self._get_earliest_key(store)):
- keys = self._cache.keys()
- i = keys.bisect_right(key)
-
- result = set(
- self._cache[k] for k in keys[i:]
- ).intersection(room_ids)
-
- cache_counter.inc_hits(self.name)
- else:
- result = room_ids
- cache_counter.inc_misses(self.name)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def room_has_changed(self, store, room_id, key):
- """Informs the cache that the room has been changed at the given key.
- """
- if key > (yield self._get_earliest_key(store)):
- old_key = self._room_to_key.get(room_id, None)
- if old_key:
- key = max(key, old_key)
- self._cache.pop(old_key, None)
- self._cache[key] = room_id
-
- while len(self._cache) > self._size_of_cache:
- k, r = self._cache.popitem()
- self._earliest_key = max(k, self._earliest_key)
- self._room_to_key.pop(r, None)
-
- @defer.inlineCallbacks
- def _get_earliest_key(self, store):
- if self._earliest_key is None:
- self._earliest_key = yield store.get_max_receipt_stream_id()
- self._earliest_key = int(self._earliest_key)
-
- defer.returnValue(self._earliest_key)
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_receipts", get_all_updated_receipts_txn
+ )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 70cde0d04d..bd4eb88a92 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
@@ -38,7 +40,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- next_id = yield self._access_tokens_id_gen.get_next()
+ next_id = self._access_tokens_id_gen.get_next()
yield self._simple_insert(
"access_tokens",
@@ -60,7 +62,7 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if there was a problem adding this.
"""
- next_id = yield self._refresh_tokens_id_gen.get_next()
+ next_id = self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
@@ -74,7 +76,7 @@ class RegistrationStore(SQLBaseStore):
@defer.inlineCallbacks
def register(self, user_id, token, password_hash,
- was_guest=False, make_guest=False):
+ was_guest=False, make_guest=False, appservice_id=None):
"""Attempts to register an account.
Args:
@@ -85,19 +87,35 @@ class RegistrationStore(SQLBaseStore):
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
+ appservice_id (str): The ID of the appservice registering the user.
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
"register",
- self._register, user_id, token, password_hash, was_guest, make_guest
+ self._register,
+ user_id,
+ token,
+ password_hash,
+ was_guest,
+ make_guest,
+ appservice_id
)
self.is_guest.invalidate((user_id,))
- def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
+ def _register(
+ self,
+ txn,
+ user_id,
+ token,
+ password_hash,
+ was_guest,
+ make_guest,
+ appservice_id
+ ):
now = int(self.clock.time())
- next_id = self._access_tokens_id_gen.get_next_txn(txn)
+ next_id = self._access_tokens_id_gen.get_next()
try:
if was_guest:
@@ -109,9 +127,21 @@ class RegistrationStore(SQLBaseStore):
[password_hash, now, 1 if make_guest else 0, user_id])
else:
txn.execute("INSERT INTO users "
- "(name, password_hash, creation_ts, is_guest) "
- "VALUES (?,?,?,?)",
- [user_id, password_hash, now, 1 if make_guest else 0])
+ "("
+ " name,"
+ " password_hash,"
+ " creation_ts,"
+ " is_guest,"
+ " appservice_id"
+ ") "
+ "VALUES (?,?,?,?,?)",
+ [
+ user_id,
+ password_hash,
+ now,
+ 1 if make_guest else 0,
+ appservice_id,
+ ])
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -134,6 +164,7 @@ class RegistrationStore(SQLBaseStore):
},
retcols=["name", "password_hash", "is_guest"],
allow_none=True,
+ desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id):
@@ -164,27 +195,48 @@ class RegistrationStore(SQLBaseStore):
})
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id):
- yield self.runInteraction(
- "user_delete_access_tokens",
- self._user_delete_access_tokens, user_id
- )
+    def user_delete_access_tokens(self, user_id, except_token_ids=None):
+ def f(txn):
+ sql = "SELECT token FROM access_tokens WHERE user_id = ?"
+ clauses = [user_id]
- def _user_delete_access_tokens(self, txn, user_id):
- txn.execute(
- "DELETE FROM access_tokens WHERE user_id = ?",
- (user_id, )
- )
+ if except_token_ids:
+ sql += " AND id NOT IN (%s)" % (
+ ",".join(["?" for _ in except_token_ids]),
+ )
+ clauses += except_token_ids
- @defer.inlineCallbacks
- def flush_user(self, user_id):
- rows = yield self._execute(
- 'flush_user', None,
- "SELECT token FROM access_tokens WHERE user_id = ?",
- user_id
- )
- for r in rows:
- self.get_user_by_access_token.invalidate((r,))
+ txn.execute(sql, clauses)
+
+ rows = txn.fetchall()
+
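+            # Invalidate and delete in chunks of 100 tokens so that a user
+            # with very many tokens does not produce an enormous IN clause.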
+ n = 100
+ chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
+ for chunk in chunks:
+ for row in chunk:
+ txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+
+ txn.execute(
+ "DELETE FROM access_tokens WHERE token in (%s)" % (
+ ",".join(["?" for _ in chunk]),
+ ), [r[0] for r in chunk]
+ )
+
+ yield self.runInteraction("user_delete_access_tokens", f)
+
+ def delete_access_token(self, access_token):
+ def f(txn):
+ self._simple_delete_one_txn(
+ txn,
+ table="access_tokens",
+ keyvalues={
+ "token": access_token
+ },
+ )
+
+ txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
+
+ return self.runInteraction("delete_access_token", f)
@cached()
def get_user_by_access_token(self, token):
@@ -350,3 +402,81 @@ class RegistrationStore(SQLBaseStore):
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def find_next_generated_user_id_localpart(self):
+ """
+ Gets the localpart of the next generated user ID.
+
+ Generated user IDs are integers, and we aim for them to be as small as
+ we can. Unfortunately, it's possible some of them are already taken by
+        existing users, and there may be gaps in the already-taken range.
+        This function returns the start of the first allocatable gap, so
+        that a single pre-allocated ID such as 10000000 does not force us to
+        skip the first (and shortest) generated user IDs.
+ """
+ def _find_next_generated_user_id(txn):
+ txn.execute("SELECT name FROM users")
+ rows = self.cursor_to_dict(txn)
+
+            regex = re.compile(r"^@(\d+):")
+
+ found = set()
+
+ for r in rows:
+ user_id = r["name"]
+ match = regex.search(user_id)
+ if match:
+ found.add(int(match.group(1)))
+ for i in xrange(len(found) + 1):
+ if i not in found:
+ return i
+
+ defer.returnValue((yield self.runInteraction(
+ "find_next_generated_user_id",
+ _find_next_generated_user_id
+ )))
+
+ @defer.inlineCallbacks
+ def get_3pid_guest_access_token(self, medium, address):
+ ret = yield self._simple_select_one(
+ "threepid_guest_access_tokens",
+ {
+ "medium": medium,
+ "address": address
+ },
+ ["guest_access_token"], True, 'get_3pid_guest_access_token'
+ )
+ if ret:
+ defer.returnValue(ret["guest_access_token"])
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def save_or_get_3pid_guest_access_token(
+ self, medium, address, access_token, inviter_user_id
+ ):
+ """
+ Gets the 3pid's guest access token if exists, else saves access_token.
+
+ :param medium (str): Medium of the 3pid. Must be "email".
+ :param address (str): 3pid address.
+ :param access_token (str): The access token to persist if none is
+ already persisted.
+ :param inviter_user_id (str): User ID of the inviter.
+ :return (deferred str): Whichever access token is persisted at the end
+ of this function call.
+ """
+ def insert(txn):
+ txn.execute(
+ "INSERT INTO threepid_guest_access_tokens "
+ "(medium, address, guest_access_token, first_inviter) "
+ "VALUES (?, ?, ?, ?)",
+ (medium, address, access_token, inviter_user_id)
+ )
+
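+        # Optimistically insert, then fall back to reading the existing row:
+        # the unique index on (medium, address) turns concurrent saves into
+        # an IntegrityError rather than duplicate rows.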
+ try:
+ yield self.runInteraction("save_3pid_guest_access_token", insert)
+ defer.returnValue(access_token)
+ except self.database_engine.module.IntegrityError:
+ ret = yield self.get_3pid_guest_access_token(medium, address)
+ defer.returnValue(ret)
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index dc09a3aaba..46ab38a313 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore):
desc="get_public_room_ids",
)
- @defer.inlineCallbacks
- def get_rooms(self, is_public):
- """Retrieve a list of all public rooms.
-
- Args:
- is_public (bool): True if the rooms returned should be public.
- Returns:
- A list of room dicts containing at least a "room_id" key, a
- "topic" key if one is set, and a "name" key if one is set
+ def get_room_count(self):
+ """Retrieve a list of all rooms
"""
def f(txn):
- def subquery(table_name, column_name=None):
- column_name = column_name or table_name
- return (
- "SELECT %(table_name)s.event_id as event_id, "
- "%(table_name)s.room_id as room_id, %(column_name)s "
- "FROM %(table_name)s "
- "INNER JOIN current_state_events as c "
- "ON c.event_id = %(table_name)s.event_id " % {
- "column_name": column_name,
- "table_name": table_name,
- }
- )
-
- sql = (
- "SELECT"
- " r.room_id,"
- " max(n.name),"
- " max(t.topic),"
- " max(v.history_visibility),"
- " max(g.guest_access)"
- " FROM rooms AS r"
- " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
- " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
- " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
- " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
- " WHERE r.is_public = ?"
- " GROUP BY r.room_id" % {
- "topic": subquery("topics", "topic"),
- "name": subquery("room_names", "name"),
- "history_visibility": subquery("history_visibility"),
- "guest_access": subquery("guest_access"),
- }
- )
-
- txn.execute(sql, (is_public,))
-
- rows = txn.fetchall()
-
- for i, row in enumerate(rows):
- room_id = row[0]
- aliases = self._simple_select_onecol_txn(
- txn,
- table="room_aliases",
- keyvalues={
- "room_id": room_id
- },
- retcol="room_alias",
- )
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
- rows[i] = list(row) + [aliases]
-
- return rows
-
- rows = yield self.runInteraction(
+ return self.runInteraction(
"get_rooms", f
)
- ret = [
- {
- "room_id": r[0],
- "name": r[1],
- "topic": r[2],
- "world_readable": r[3] == "world_readable",
- "guest_can_join": r[4] == "can_join",
- "aliases": r[5],
- }
- for r in rows
- if r[5] # We only return rooms that have at least one alias.
- ]
-
- defer.returnValue(ret)
-
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
self._simple_insert_txn(
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 68ac88905f..0cd89260f2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
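+        # Feed the membership stream change cache so that
+        # get_membership_changes_for_user can cheaply tell whether this user
+        # has any changes since a given stream position.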
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -110,6 +114,7 @@ class RoomMemberStore(SQLBaseStore):
membership=membership,
).addCallback(self._get_events)
+ @cached()
def get_invites_for_user(self, user_id):
""" Get all the invite events for a user
Args:
@@ -240,37 +245,13 @@ class RoomMemberStore(SQLBaseStore):
return rows
- @cached()
+ @cached(max_entries=5000)
def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
@defer.inlineCallbacks
- def user_rooms_intersect(self, user_id_list):
- """ Checks whether all the users whose IDs are given in a list share a
- room.
-
- This is a "hot path" function that's called a lot, e.g. by presence for
- generating the event stream. As such, it is implemented locally by
- wrapping logic around heavily-cached database queries.
- """
- if len(user_id_list) < 2:
- defer.returnValue(True)
-
- deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
-
- results = yield defer.DeferredList(deferreds, consumeErrors=True)
-
- # A list of sets of strings giving room IDs for each user
- room_id_lists = [set([r.room_id for r in result[1]]) for result in results]
-
- # There isn't a setintersection(*list_of_sets)
- ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
-
- defer.returnValue(ret)
-
- @defer.inlineCallbacks
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql
index bdf6ae3f24..4d519849df 100644
--- a/synapse/storage/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/schema/delta/28/event_push_actions.sql
@@ -24,3 +24,4 @@ CREATE TABLE IF NOT EXISTS event_push_actions(
CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag);
+CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id);
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql
new file mode 100644
index 0000000000..200c35e6e2
--- /dev/null
+++ b/synapse/storage/schema/delta/28/events_room_stream.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+CREATE INDEX events_room_stream on events(room_id, stream_ordering);
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql
new file mode 100644
index 0000000000..ba62a974a4
--- /dev/null
+++ b/synapse/storage/schema/delta/28/public_roms_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+CREATE INDEX public_room_index on rooms(is_public);
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql
new file mode 100644
index 0000000000..7e7b09820a
--- /dev/null
+++ b/synapse/storage/schema/delta/29/push_actions.sql
@@ -0,0 +1,31 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT;
+ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT;
+ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT;
+
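+-- Backfill the new ordering columns from the events table.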
+UPDATE event_push_actions SET stream_ordering = (
+ SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id
+), topological_ordering = (
+ SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id
+);
+
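+-- Treat all pre-existing push actions as plain notifications without
+-- highlights.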
+UPDATE event_push_actions SET notif = 1, highlight = 0;
+
+CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
+ user_id, room_id, topological_ordering, stream_ordering
+);
diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/schema/delta/30/alias_creator.sql
new file mode 100644
index 0000000000..c9d0dde638
--- /dev/null
+++ b/synapse/storage/schema/delta/30/alias_creator.sql
@@ -0,0 +1,16 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_aliases ADD COLUMN creator TEXT;
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
new file mode 100644
index 0000000000..4f6e9dd540
--- /dev/null
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -0,0 +1,68 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from synapse.storage.appservice import ApplicationServiceStore
+
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+ # NULL indicates user was not registered by an appservice.
+ try:
+ cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
+    except Exception:
+        # The column may already exist if this delta has been run before.
+ pass
+
+ cur.execute("SELECT name FROM users")
+ rows = cur.fetchall()
+
+ config_files = []
+ try:
+ config_files = config.app_service_config_files
+ except AttributeError:
+ logger.warning("Could not get app_service_config_files from config")
+
+ appservices = ApplicationServiceStore.load_appservices(
+ config.server_name, config_files
+ )
+
+ owned = {}
+
+ for row in rows:
+ user_id = row[0]
+ for appservice in appservices:
+ if appservice.is_exclusive_user(user_id):
+                if any(user_id in ids for ids in owned.values()):
+                    logger.error(
+                        "user_id %s is exclusive to more than one application"
+                        " service; assigning arbitrarily to %s" %
+                        (user_id, appservice.id)
+                    )
+ owned.setdefault(appservice.id, []).append(user_id)
+
+ for as_id, user_ids in owned.items():
+ n = 100
+        user_chunks = (user_ids[i:i + n] for i in xrange(0, len(user_ids), n))
+ for chunk in user_chunks:
+ cur.execute(
+ database_engine.convert_param_style(
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % (
+ ",".join("?" for _ in chunk),
+ )
+ ),
+ [as_id] + chunk
+ )
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/schema/delta/30/deleted_pushers.sql
new file mode 100644
index 0000000000..712c454aa1
--- /dev/null
+++ b/synapse/storage/schema/delta/30/deleted_pushers.sql
@@ -0,0 +1,25 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS deleted_pushers(
+ stream_id BIGINT NOT NULL,
+ app_id TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ /* We only track the most recent delete for each app_id, pushkey and user_id. */
+ UNIQUE (app_id, pushkey, user_id)
+);
+
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql
new file mode 100644
index 0000000000..606bbb037d
--- /dev/null
+++ b/synapse/storage/schema/delta/30/presence_stream.sql
@@ -0,0 +1,30 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE presence_stream(
+    stream_id BIGINT,
+    user_id TEXT,
+    state TEXT,
+    last_active_ts BIGINT,
+    last_federation_update_ts BIGINT,
+    last_user_sync_ts BIGINT,
+    status_msg TEXT,
+    currently_active BOOLEAN
+);
+
+CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id);
+CREATE INDEX presence_stream_user_id ON presence_stream(user_id);
+CREATE INDEX presence_stream_state ON presence_stream(state);
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql
new file mode 100644
index 0000000000..735aa8d5f6
--- /dev/null
+++ b/synapse/storage/schema/delta/30/push_rule_stream.sql
@@ -0,0 +1,38 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+
+CREATE TABLE push_rules_stream(
+ stream_id BIGINT NOT NULL,
+ event_stream_ordering BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ rule_id TEXT NOT NULL,
+ op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE"
+ priority_class SMALLINT,
+ priority INTEGER,
+ conditions TEXT,
+ actions TEXT
+);
+
+-- The extra data for each operation is:
+-- * ENABLE, DISABLE, DELETE: []
+-- * ACTIONS: ["actions"]
+-- * ADD: ["priority_class", "priority", "actions", "conditions"]
+
+-- Index for replication queries.
+CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id);
+-- Index for /sync queries.
+CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id);
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
new file mode 100644
index 0000000000..0dd2f1360c
--- /dev/null
+++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
@@ -0,0 +1,24 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Stores guest account access tokens generated for unbound 3pids.
+CREATE TABLE threepid_guest_access_tokens(
+ medium TEXT, -- The medium of the 3pid. Must be "email".
+ address TEXT, -- The 3pid address.
+ guest_access_token TEXT, -- The access token for a guest user for this 3pid.
+ first_inviter TEXT -- User ID of the first user to invite this 3pid to a room.
+);
+
+CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 70c6a06cd1..b10f2a5787 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -15,7 +15,7 @@
from twisted.internet import defer
-from _base import SQLBaseStore
+from ._base import SQLBaseStore
from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6c32e8f7b3..8ed8a21b0a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,7 +83,7 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_events[(event.type, event.state_key)] = event
- state_group = self._state_groups_id_gen.get_next_txn(txn)
+ state_group = self._state_groups_id_gen.get_next()
self._simple_insert_txn(
txn,
table="state_groups",
@@ -171,41 +171,43 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
- def _get_state_groups_from_groups(self, groups_and_types):
+ def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> state event ids
-
- Args:
- groups_and_types (list): list of 2-tuple (`group`, `types`)
"""
- def f(txn):
- results = {}
- for group, types in groups_and_types:
- if types is not None:
- where_clause = "AND (%s)" % (
- " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
- )
- else:
- where_clause = ""
-
- sql = (
- "SELECT event_id FROM state_groups_state WHERE"
- " state_group = ? %s"
- ) % (where_clause,)
+ def f(txn, groups):
+ if types is not None:
+ where_clause = "AND (%s)" % (
+ " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+ )
+ else:
+ where_clause = ""
- args = [group]
- if types is not None:
- args.extend([i for typ in types for i in typ])
+ sql = (
+ "SELECT state_group, event_id FROM state_groups_state WHERE"
+ " state_group IN (%s) %s" % (
+ ",".join("?" for _ in groups),
+ where_clause,
+ )
+ )
- txn.execute(sql, args)
+ args = list(groups)
+ if types is not None:
+ args.extend([i for typ in types for i in typ])
- results[group] = [r[0] for r in txn.fetchall()]
+ txn.execute(sql, args)
+ rows = self.cursor_to_dict(txn)
+ results = {}
+ for row in rows:
+ results.setdefault(row["state_group"], []).append(row["event_id"])
return results
- return self.runInteraction(
- "_get_state_groups_from_groups",
- f,
- )
+        @defer.inlineCallbacks
+        def run():
+            # Query in chunks of 100 groups to keep the IN clause bounded,
+            # merging the per-chunk results instead of returning after the
+            # first chunk.
+            results = {}
+            chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+            for chunk in chunks:
+                res = yield self.runInteraction(
+                    "_get_state_groups_from_groups",
+                    f, chunk
+                )
+                results.update(res)
+            defer.returnValue(results)
+
+        return run()
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
@@ -264,26 +266,20 @@ class StateStore(SQLBaseStore):
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
- num_args=1)
+ num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- def f(txn):
- results = {}
- for event_id in event_ids:
- results[event_id] = self._simple_select_one_onecol_txn(
- txn,
- table="event_to_state_groups",
- keyvalues={
- "event_id": event_id,
- },
- retcol="state_group",
- allow_none=True,
- )
-
- return results
+ rows = yield self._simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group",),
+ desc="_get_state_group_for_events",
+ )
- return self.runInteraction("_get_state_group_for_events", f)
+ defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
@@ -355,7 +351,7 @@ class StateStore(SQLBaseStore):
all events are returned.
"""
results = {}
- missing_groups_and_types = []
+ missing_groups = []
if types is not None:
for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache(
@@ -364,7 +360,7 @@ class StateStore(SQLBaseStore):
results[group] = state_dict
if not got_all:
- missing_groups_and_types.append((group, missing_types))
+ missing_groups.append(group)
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
@@ -373,9 +369,9 @@ class StateStore(SQLBaseStore):
results[group] = state_dict
if not got_all:
- missing_groups_and_types.append((group, None))
+ missing_groups.append(group)
- if not missing_groups_and_types:
+ if not missing_groups:
defer.returnValue({
group: {
type_tuple: event
@@ -389,7 +385,7 @@ class StateStore(SQLBaseStore):
cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups(
- missing_groups_and_types
+ missing_groups, types
)
state_events = yield self._get_events(
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..7f4a827528 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
-from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn
import logging
@@ -77,7 +77,6 @@ def upper_bound(token):
class StreamStore(SQLBaseStore):
-
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the
@@ -157,7 +156,153 @@ class StreamStore(SQLBaseStore):
results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results)
- @log_function
+ @defer.inlineCallbacks
+ def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
+ order='DESC'):
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+
+ room_ids = yield self._events_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+
+ if not room_ids:
+ defer.returnValue({})
+
+ results = {}
+ room_ids = list(room_ids)
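+        # Fetch up to 20 rooms at a time in parallel; preserve_fn keeps the
+        # logging context intact across the concurrent calls.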
+ for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
+ res = yield defer.gatherResults([
+ preserve_fn(self.get_room_events_stream_for_room)(
+ room_id, from_key, to_key, limit, order=order,
+ )
+ for room_id in rm_ids
+ ])
+ results.update(dict(zip(rm_ids, res)))
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
+ def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
+ order='DESC'):
+ # Note: If from_key is None then we return in topological order. This
+ # is because in that case we're using this as a "get the last few messages
+ # in a room" function, rather than "get new messages since last sync"
+ if from_key is not None:
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ else:
+ from_id = None
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+ if from_key == to_key:
+ defer.returnValue(([], from_key))
+
+ if from_id:
+ has_changed = yield self._events_stream_cache.has_entity_changed(
+ room_id, from_id
+ )
+
+ if not has_changed:
+ defer.returnValue(([], from_key))
+
+ def f(txn):
+ if from_id is not None:
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering > ? AND stream_ordering <= ?"
+ " ORDER BY stream_ordering %s LIMIT ?"
+ ) % (order,)
+ txn.execute(sql, (room_id, from_id, to_id, limit))
+ else:
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering <= ?"
+ " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
+ ) % (order, order,)
+ txn.execute(sql, (room_id, to_id, limit))
+
+ rows = self.cursor_to_dict(txn)
+
+ return rows
+
+ rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows, topo_order=from_id is None)
+
+ if order.lower() == "desc":
+ ret.reverse()
+
+ if rows:
+ key = "s%d" % min(r["stream_ordering"] for r in rows)
+ else:
+ # Assume we didn't get anything because there was nothing to
+ # get.
+ key = from_key
+
+ defer.returnValue((ret, key))
+
+ @defer.inlineCallbacks
+ def get_membership_changes_for_user(self, user_id, from_key, to_key):
+ if from_key is not None:
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ else:
+ from_id = None
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+ if from_key == to_key:
+ defer.returnValue([])
+
+ if from_id:
+ has_changed = self._membership_stream_cache.has_entity_changed(
+ user_id, int(from_id)
+ )
+ if not has_changed:
+ defer.returnValue([])
+
+ def f(txn):
+ if from_id is not None:
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, from_id, to_id,))
+ else:
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, to_id,))
+ rows = self.cursor_to_dict(txn)
+
+ return rows
+
+ rows = yield self.runInteraction("get_membership_changes_for_user", f)
+
+ ret = yield self._get_events(
+ [r["event_id"] for r in rows],
+ get_prev_content=True
+ )
+
+ self._set_before_and_after(ret, rows, topo_order=False)
+
+ defer.returnValue(ret)
+
def get_room_events_stream(
self,
user_id,
@@ -174,7 +319,8 @@ class StreamStore(SQLBaseStore):
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
- " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+ " WHERE c.room_id IN (%s)"
+ " AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
@@ -187,11 +333,6 @@ class StreamStore(SQLBaseStore):
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
- if room_ids:
- current_room_membership_sql += " AND m.room_id in (%s)" % (
- ",".join(map(lambda _: "?", room_ids))
- )
- current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
@@ -393,7 +534,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
- token = yield self._stream_id_gen.get_max_token(self)
+        token = self._stream_id_gen.get_max_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
@@ -430,10 +571,23 @@ class StreamStore(SQLBaseStore):
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
+ desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
+ def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
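+        """Get the highest topological ordering of any event in the room
+        before the given stream ordering."""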
+ sql = (
+ "SELECT max(topological_ordering) FROM events"
+ " WHERE room_id = ? AND stream_ordering < ?"
+ )
+ return self._execute(
+ "get_max_topological_token_for_stream_and_room", None,
+ sql, room_id, stream_key,
+ ).addCallback(
+ lambda r: r[0][0] if r else 0
+ )
+
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
@@ -444,27 +598,21 @@ class StreamStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0] if rows else 0
- @defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
-
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
-
- logger.debug("min_token is: %s", self.min_token)
-
- defer.returnValue(self.min_token)
-
@staticmethod
- def _set_before_and_after(events, rows):
+ def _set_before_and_after(events, rows, topo_order=True):
for event, row in zip(events, rows):
stream = row["stream_ordering"]
- topo = event.depth
+ if topo_order:
+ topo = event.depth
+ else:
+ topo = None
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
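+            # Expose a sort key: topological ordering first (zero when the
+            # results are purely stream-ordered), then stream ordering.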
+ internal.order = (
+ int(topo) if topo else 0,
+ int(stream),
+ )
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index ed9c91e5ea..a0e6b42b30 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -16,7 +16,6 @@
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
@@ -25,20 +24,13 @@ logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
- def __init__(self, hs):
- super(TagsStore, self).__init__(hs)
-
- self._account_data_id_gen = StreamIdGenerator(
- "account_data_max_stream_id", "stream_id"
- )
-
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
- return self._account_data_id_gen.get_max_token(self)
+ return self._account_data_id_gen.get_max_token()
@cached()
def get_tags_for_user(self, user_id):
@@ -67,6 +59,59 @@ class TagsStore(SQLBaseStore):
return deferred
@defer.inlineCallbacks
+ def get_all_updated_tags(self, last_id, current_id, limit):
+ """Get all the client tags that have changed on the server
+ Args:
+ last_id(int): The position to fetch from.
+            current_id(int): The position to fetch up to.
+            limit(int): The maximum number of rows to return.
+ Returns:
+            A deferred list of tuples of stream_id int, user_id string,
+            room_id string and a JSON string mapping tag names to content.
+ """
+ def get_all_updated_tags_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, room_id"
+ " FROM room_tags_revisions as r"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ tag_ids = yield self.runInteraction(
+ "get_all_updated_tags", get_all_updated_tags_txn
+ )
+
+ def get_tag_content(txn, tag_ids):
+ sql = (
+ "SELECT tag, content"
+ " FROM room_tags"
+ " WHERE user_id=? AND room_id=?"
+ )
+ results = []
+ for stream_id, user_id, room_id in tag_ids:
+ txn.execute(sql, (user_id, room_id))
+ tags = []
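+                # `content` is already JSON-encoded, so assemble the object
+                # by hand instead of re-encoding it with json.dumps.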
+ for tag, content in txn.fetchall():
+ tags.append(json.dumps(tag) + ":" + content)
+ tag_json = "{" + ",".join(tags) + "}"
+ results.append((stream_id, user_id, room_id, tag_json))
+
+ return results
+
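+        # Fetch the tag contents in batches of 50 rooms so that no single
+        # transaction grows too large.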
+ batch_size = 50
+ results = []
+ for i in xrange(0, len(tag_ids), batch_size):
+ tags = yield self.runInteraction(
+ "get_all_updated_tag_content",
+ get_tag_content,
+ tag_ids[i:i + batch_size],
+ )
+ results.extend(tags)
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
@@ -87,6 +132,12 @@ class TagsStore(SQLBaseStore):
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ defer.returnValue({})
+
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
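The early return added to `get_updated_tags` leans on the `StreamChangeCache` populated in `_update_revision_txn` below: if the cache can prove the user has had no account-data changes since `stream_id`, the database query is skipped entirely. A rough sketch of that short-circuit shape (`ChangeCacheSketch` is a hypothetical stand-in, not the real StreamChangeCache):

    class ChangeCacheSketch(object):
        # Remembers the latest stream position at which each entity changed,
        # back to some horizon before which nothing is known.
        def __init__(self, earliest_known):
            self._earliest_known = earliest_known
            self._latest_change = {}  # entity -> stream position

        def entity_has_changed(self, entity, stream_pos):
            current = self._latest_change.get(entity, 0)
            self._latest_change[entity] = max(current, stream_pos)

        def has_entity_changed(self, entity, stream_pos):
            if stream_pos < self._earliest_known:
                # Too far back for the cache to answer; assume changed.
                return True
            return self._latest_change.get(entity, 0) > stream_pos

    cache = ChangeCacheSketch(earliest_known=100)
    cache.entity_has_changed("@alice:example.com", 105)
    assert cache.has_entity_changed("@alice:example.com", 101)
    assert not cache.has_entity_changed("@alice:example.com", 105)
    assert cache.has_entity_changed("@bob:example.com", 50)  # below horizon

The cache errs on the side of "may have changed", so the query is only ever skipped when nothing can have been missed.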
@@ -144,12 +195,12 @@ class TagsStore(SQLBaseStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -166,12 +217,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with (yield self._account_data_id_gen.get_next(self)) as next_id:
+ with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = yield self._account_data_id_gen.get_max_token(self)
+ result = self._account_data_id_gen.get_max_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
@@ -184,6 +235,11 @@ class TagsStore(SQLBaseStore):
            next_id(int): The revision to advance to.
"""
+ txn.call_after(
+ self._account_data_stream_cache.entity_has_changed,
+ user_id, next_id
+ )
+
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 4475c451c1..d338dfcf0a 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts):
- next_id = self._transaction_id_gen.get_next_txn(txn)
+ next_id = self._transaction_id_gen.get_next()
# First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time,
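The one-line change to `_prep_send_transaction` reflects the new calling convention: because the generator's counter is seeded from the database connection at startup, allocating an id is a plain in-memory increment and no longer needs to run inside a database transaction. A toy version of the post-refactor shape, seeded with a constant rather than a live connection (`SeededIdGenerator` is a hypothetical stand-in for `IdGenerator`):

    import threading

    class SeededIdGenerator(object):
        # The max id is loaded once up front, after which get_next() is a
        # synchronous, lock-protected increment.
        def __init__(self, seed):
            self._lock = threading.Lock()
            self._next_id = seed

        def get_next(self):
            # Returns an int directly, not a Deferred, so callers drop
            # both the yield and the txn argument.
            with self._lock:
                self._next_id += 1
                return self._next_id

    gen = SeededIdGenerator(seed=7)
    assert gen.get_next() == 8
    assert gen.get_next() == 9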
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f58bf7fd2c..a02dfc7d58 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -13,51 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from collections import deque
import contextlib
import threading
class IdGenerator(object):
- def __init__(self, table, column, store):
- self.table = table
- self.column = column
- self.store = store
+ def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
- self._next_id = None
+ self._next_id = _load_max_id(db_conn, table, column)
- @defer.inlineCallbacks
def get_next(self):
- if self._next_id is None:
- yield self.store.runInteraction(
- "IdGenerator_%s" % (self.table,),
- self.get_next_txn,
- )
-
with self._lock:
- i = self._next_id
self._next_id += 1
- defer.returnValue(i)
+ return self._next_id
- def get_next_txn(self, txn):
- with self._lock:
- if self._next_id:
- i = self._next_id
- self._next_id += 1
- return i
- else:
- txn.execute(
- "SELECT MAX(%s) FROM %s" % (self.column, self.table,)
- )
-
- val, = txn.fetchone()
- cur = val or 0
- cur += 1
- self._next_id = cur + 1
- return cur
+def _load_max_id(db_conn, table, column):
+ cur = db_conn.cursor()
+ cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+ val, = cur.fetchone()
+ cur.close()
+ return int(val) if val else 1
class StreamIdGenerator(object):
@@ -69,31 +46,25 @@ class StreamIdGenerator(object):
persistence of events can complete out of order.
Usage:
- with stream_id_gen.get_next_txn(txn) as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- def __init__(self, table, column):
- self.table = table
- self.column = column
-
+ def __init__(self, db_conn, table, column, extra_tables=[]):
self._lock = threading.Lock()
-
- self._current_max = None
+ self._current_max = _load_max_id(db_conn, table, column)
+ for table, column in extra_tables:
+ self._current_max = max(
+ self._current_max,
+ _load_max_id(db_conn, table, column)
+ )
self._unfinished_ids = deque()
- @defer.inlineCallbacks
- def get_next(self, store):
+ def get_next(self):
"""
Usage:
- with yield stream_id_gen.get_next as stream_id:
+ with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
self._current_max += 1
next_id = self._current_max
@@ -108,21 +79,14 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
- def get_next_mult(self, store, n):
+ def get_next_mult(self, n):
"""
Usage:
- with yield stream_id_gen.get_next(store, n) as stream_ids:
+            with stream_id_gen.get_next_mult(n) as stream_ids:
# ... persist events ...
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
@@ -139,31 +103,61 @@ class StreamIdGenerator(object):
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- defer.returnValue(manager())
+ return manager()
- @defer.inlineCallbacks
- def get_max_token(self, store):
+ def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
- if not self._current_max:
- yield store.runInteraction(
- "_compute_current_max",
- self._get_or_compute_current_max,
- )
-
with self._lock:
if self._unfinished_ids:
- defer.returnValue(self._unfinished_ids[0] - 1)
+ return self._unfinished_ids[0] - 1
+
+ return self._current_max
+
+
+class ChainedIdGenerator(object):
+ """Used to generate new stream ids where the stream must be kept in sync
+    with another stream. It generates pairs of IDs: the first element is an
+    integer ID for this stream, and the second is the ID of the stream that
+    this stream must be kept in sync with."""
- defer.returnValue(self._current_max)
+ def __init__(self, chained_generator, db_conn, table, column):
+ self.chained_generator = chained_generator
+ self._lock = threading.Lock()
+ self._current_max = _load_max_id(db_conn, table, column)
+ self._unfinished_ids = deque()
- def _get_or_compute_current_max(self, txn):
+ def get_next(self):
+ """
+ Usage:
+ with stream_id_gen.get_next() as (stream_id, chained_id):
+ # ... persist event ...
+ """
with self._lock:
- txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
- rows = txn.fetchall()
- val, = rows[0]
+ self._current_max += 1
+ next_id = self._current_max
+ chained_id = self.chained_generator.get_max_token()
- self._current_max = int(val) if val else 1
+ self._unfinished_ids.append((next_id, chained_id))
- return self._current_max
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield (next_id, chained_id)
+ finally:
+ with self._lock:
+ self._unfinished_ids.remove((next_id, chained_id))
+
+ return manager()
+
+ def get_max_token(self):
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have been successfully persisted.
+ """
+ with self._lock:
+ if self._unfinished_ids:
+ stream_id, chained_id = self._unfinished_ids[0]
+ return (stream_id - 1, chained_id)
+
+ return (self._current_max, self.chained_generator.get_max_token())
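The deque of unfinished ids is what lets `get_max_token` advertise only a fully persisted prefix of the stream: while any allocation is still inside its `with` block, readers are told the position just before the oldest in-flight id. A self-contained sketch of that invariant, seeded with a constant instead of a db connection so it can run standalone:

    import contextlib
    import threading
    from collections import deque

    class StreamIdSketch(object):
        def __init__(self, current_max):
            self._lock = threading.Lock()
            self._current_max = current_max
            self._unfinished_ids = deque()

        def get_next(self):
            with self._lock:
                self._current_max += 1
                next_id = self._current_max
                self._unfinished_ids.append(next_id)

            @contextlib.contextmanager
            def manager():
                try:
                    yield next_id
                finally:
                    with self._lock:
                        self._unfinished_ids.remove(next_id)
            return manager()

        def get_max_token(self):
            with self._lock:
                if self._unfinished_ids:
                    # The oldest in-flight id may not be persisted yet, so
                    # only ids strictly before it are safe for readers.
                    return self._unfinished_ids[0] - 1
                return self._current_max

    gen = StreamIdSketch(current_max=10)
    with gen.get_next() as a:
        assert a == 11
        with gen.get_next() as b:
            assert b == 12
            assert gen.get_max_token() == 10
        assert gen.get_max_token() == 10  # 11 still unfinished
    assert gen.get_max_token() == 12      # everything persisted

ChainedIdGenerator applies the same bookkeeping to (stream_id, chained_id) pairs, so its get_max_token reports the last safe position on this stream together with the matching position on the stream it shadows.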