From ddd25def01e3909a34c52954100763bb2a91f648 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 25 Jan 2016 13:36:02 +0000 Subject: Implement a _simple_select_many_batch --- synapse/storage/_base.py | 67 ++++++++++++++++++++++++++++++++++++++++++++ synapse/storage/presence.py | 35 ++++++++++++----------- synapse/storage/push_rule.py | 63 +++++++++++++---------------------------- 3 files changed, 105 insertions(+), 60 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 183a752387..897c5d8d73 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -629,6 +629,73 @@ class SQLBaseStore(object): return self.cursor_to_dict(txn) + @defer.inlineCallbacks + def _simple_select_many_batch(self, table, column, iterable, retcols, + keyvalues={}, desc="_simple_select_many_batch", + batch_size=100): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] + + chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self._simple_select_many_txn, + table, column, chunk, keyvalues, retcols + ) + + results.extend(rows) + + defer.returnValue(results) + + def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. 
+ + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + + clauses = [] + values = [] + clauses.append( + "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) + ) + values.extend(iterable) + + for key, value in keyvalues.items(): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % ( + sql, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return self.cursor_to_dict(txn) + def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): """Executes an UPDATE query on the named table, setting new values for diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 1095d52ace..9b3aecaf8c 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -48,24 +48,25 @@ class PresenceStore(SQLBaseStore): desc="get_presence_state", ) - @cachedList(get_presence_state.cache, list_name="user_localparts") + @cachedList(get_presence_state.cache, list_name="user_localparts", + inlineCallbacks=True) def get_presence_states(self, user_localparts): - def f(txn): - results = {} - for user_localpart in user_localparts: - res = self._simple_select_one_txn( - txn, - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - allow_none=True, - ) - if res: - results[user_localpart] = res - - return results - - return self.runInteraction("get_presence_states", f) + rows = yield self._simple_select_many_batch( + table="presence", + column="user_id", + iterable=user_localparts, + retcols=("user_id", "state", "status_msg", "mtime",), + desc="get_presence_states", + ) + + defer.returnValue({ + row["user_id"]: { + "state": 
row["state"], + "status_msg": row["status_msg"], + "mtime": row["mtime"], + } + for row in rows + }) def set_presence_state(self, user_localpart, new_state): res = self._simple_update_one( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 35ec7e8cef..1f51c90ee5 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -65,32 +65,20 @@ class PushRuleStore(SQLBaseStore): if not user_ids: defer.returnValue({}) - batch_size = 100 - - def f(txn, user_ids_to_fetch): - sql = ( - "SELECT pr.*" - " FROM push_rules AS pr" - " LEFT JOIN push_rules_enable AS pre" - " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id" - " WHERE pr.user_name" - " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")" - " AND (pre.enabled IS NULL OR pre.enabled = 1)" - " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC" - ) - txn.execute(sql, user_ids_to_fetch) - return self.cursor_to_dict(txn) - results = {} - chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)] - for batch_user_ids in chunks: - rows = yield self.runInteraction( - "bulk_get_push_rules", f, batch_user_ids - ) + rows = yield self._simple_select_many_batch( + table="push_rules", + column="user_name", + iterable=user_ids, + retcols=("*",), + desc="bulk_get_push_rules", + ) + + rows.sort(key=lambda e: (-e["priority_class"], -e["priority"])) - for row in rows: - results.setdefault(row['user_name'], []).append(row) + for row in rows: + results.setdefault(row['user_name'], []).append(row) defer.returnValue(results) @defer.inlineCallbacks @@ -98,28 +86,17 @@ class PushRuleStore(SQLBaseStore): if not user_ids: defer.returnValue({}) - batch_size = 100 - - def f(txn, user_ids_to_fetch): - sql = ( - "SELECT user_name, rule_id, enabled" - " FROM push_rules_enable" - " WHERE user_name" - " IN (" + ",".join("?" 
for _ in user_ids_to_fetch) + ")" - ) - txn.execute(sql, user_ids_to_fetch) - return self.cursor_to_dict(txn) - results = {} - chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)] - for batch_user_ids in chunks: - rows = yield self.runInteraction( - "bulk_get_push_rules_enabled", f, batch_user_ids - ) - - for row in rows: - results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] + rows = yield self._simple_select_many_batch( + table="push_rules_enable", + column="user_name", + iterable=user_ids, + retcols=("user_name", "rule_id", "enabled",), + desc="bulk_get_push_rules_enabled", + ) + for row in rows: + results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] defer.returnValue(results) @defer.inlineCallbacks -- cgit 1.4.1 From 53cb17366391a13e8d8297c9f4fb3f77fd6b85b7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 25 Jan 2016 13:53:05 +0000 Subject: Push: Use storage apis that are cached --- synapse/push/__init__.py | 30 +++++++++++++----------------- synapse/storage/roommember.py | 1 + 2 files changed, 14 insertions(+), 17 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index e6a28bd8c0..9bc0b356f4 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -17,7 +17,6 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig from synapse.types import StreamToken -from synapse.api.constants import Membership import synapse.util.async import push_rule_evaluator as push_rule_evaluator @@ -296,31 +295,28 @@ class Pusher(object): @defer.inlineCallbacks def _get_badge_count(self): - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=self.user_id, - membership_list=(Membership.INVITE, Membership.JOIN) - ) + invites, joins = yield defer.gatherResults([ + self.store.get_invites_for_user(self.user_id), + self.store.get_rooms_for_user(self.user_id), + ], 
consumeErrors=True) my_receipts_by_room = yield self.store.get_receipts_for_user( self.user_id, "m.read", ) - badge = 0 + badge = len(invites) - for r in room_list: - if r.membership == Membership.INVITE: - badge += 1 - else: - if r.room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[r.room_id] + for r in joins: + if r.room_id in my_receipts_by_room: + last_unread_event_id = my_receipts_by_room[r.room_id] - notifs = yield ( - self.store.get_unread_event_push_actions_by_room_for_user( - r.room_id, self.user_id, last_unread_event_id - ) + notifs = yield ( + self.store.get_unread_event_push_actions_by_room_for_user( + r.room_id, self.user_id, last_unread_event_id ) - badge += len(notifs) + ) + badge += len(notifs) defer.returnValue(badge) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 68ac88905f..edfecced05 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -110,6 +110,7 @@ class RoomMemberStore(SQLBaseStore): membership=membership, ).addCallback(self._get_events) + @cached() def get_invites_for_user(self, user_id): """ Get all the invite events for a user Args: -- cgit 1.4.1 From 86896408b0a3c0d0c82356c8cc0bdaf3fe236b45 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 25 Jan 2016 15:30:32 +0000 Subject: Add index to event_push_actions --- synapse/storage/schema/delta/28/event_push_actions.sql | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql index bdf6ae3f24..4d519849df 100644 --- a/synapse/storage/schema/delta/28/event_push_actions.sql +++ b/synapse/storage/schema/delta/28/event_push_actions.sql @@ -24,3 +24,4 @@ CREATE TABLE IF NOT EXISTS event_push_actions( CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag); +CREATE INDEX 
event_push_actions_room_id_user_id on event_push_actions(room_id, user_id); -- cgit 1.4.1 From 1ebf5e3d03a2a2ce9ff278b2eac07acc0f7cde66 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 25 Jan 2016 15:53:36 +0000 Subject: Correct docstring --- synapse/storage/_base.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 897c5d8d73..304ebdc825 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -639,7 +639,6 @@ class SQLBaseStore(object): Filters rows by if value of `column` is in `iterable`. Args: - txn : Transaction object table : string giving the table name column : column name to test for inclusion against `iterable` iterable : list -- cgit 1.4.1 From aea5da0ef6f4d3907ace2c1fdba743312118660c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 25 Jan 2016 15:59:29 +0000 Subject: Guard against empty iterables --- synapse/storage/_base.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 304ebdc825..90d7aee94a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -647,6 +647,9 @@ class SQLBaseStore(object): """ results = [] + if not iterable: + defer.returnValue(results) + chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)] for chunk in chunks: rows = yield self.runInteraction( @@ -673,6 +676,9 @@ class SQLBaseStore(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ + if not iterable: + return [] + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) clauses = [] -- cgit 1.4.1 From 87f9477b105b4e8216d1df186492ec6d9872967f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Jan 2016 15:51:06 +0000 Subject: Add a Homeserver.setup method. This is for setting up dependencies that require work on startup. 
This is useful for the DataStore that wants to read a bunch from the database before initiliazing. --- synapse/app/homeserver.py | 33 ++++++++++++++--------- synapse/server.py | 32 ++++++++++++----------- synapse/storage/__init__.py | 45 +++++++++++++++++++++++++++++--- synapse/storage/_base.py | 49 ++++++++++++++++------------------- synapse/storage/events.py | 14 ++++------ synapse/storage/receipts.py | 8 +++--- synapse/storage/stream.py | 13 ---------- synapse/storage/tags.py | 7 ----- synapse/storage/util/id_generators.py | 36 +++++++------------------ 9 files changed, 121 insertions(+), 116 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 795c655ae3..fb76be58a2 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -254,6 +254,17 @@ class SynapseHomeServer(HomeServer): except IncorrectDatabaseSetup as e: quit_with_error(e.message) + def get_db_conn(self): + db_conn = self.database_engine.module.connect( + **{ + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + ) + + self.database_engine.on_new_connection(db_conn) + return db_conn + def quit_with_error(error_string): message_lines = error_string.split("\n") @@ -390,13 +401,7 @@ def setup(config_options): logger.info("Preparing database: %s...", config.database_config['name']) try: - db_conn = database_engine.module.connect( - **{ - k: v for k, v in config.database_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - + db_conn = hs.get_db_conn() database_engine.prepare_database(db_conn) hs.run_startup_checks(db_conn, database_engine) @@ -411,13 +416,17 @@ def setup(config_options): logger.info("Database prepared in %s.", config.database_config['name']) + hs.setup() hs.start_listening() - hs.get_pusherpool().start() - hs.get_state_handler().start_caching() - hs.get_datastore().start_profiling() - hs.get_datastore().start_doing_background_updates() - 
hs.get_replication_layer().start_get_pdu_cache() + def start(): + hs.get_pusherpool().start() + hs.get_state_handler().start_caching() + hs.get_datastore().start_profiling() + hs.get_datastore().start_doing_background_updates() + hs.get_replication_layer().start_get_pdu_cache() + + reactor.callWhenRunning(start) return hs diff --git a/synapse/server.py b/synapse/server.py index a59e46ca2d..006e91b37c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -21,6 +21,7 @@ # Imports required for the default HomeServer() implementation from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.enterprise import adbapi +from twisted.internet import defer from synapse.federation import initialize_http_replication from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory @@ -28,7 +29,7 @@ from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers from synapse.state import StateHandler -from synapse.storage import DataStore +from synapse.storage import get_datastore from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.streams.events import EventSources @@ -40,6 +41,11 @@ from synapse.api.filtering import Filtering from synapse.http.matrixfederationclient import MatrixFederationHttpClient +import logging + + +logger = logging.getLogger(__name__) + class HomeServer(object): """A basic homeserver object without lazy component builders. @@ -102,10 +108,19 @@ class HomeServer(object): self.hostname = hostname self._building = {} + self.clock = Clock() + self.distributor = Distributor() + self.ratelimiter = Ratelimiter() + # Other kwargs are explicit dependencies for depname in kwargs: setattr(self, depname, kwargs[depname]) + def setup(self): + logger.info("Setting up.") + self.datastore = get_datastore(self) + logger.info("Finished setting up.") + def get_ip_from_request(self, request): # X-Forwarded-For is handled by our custom request type. 
return request.getClientIP() @@ -116,15 +131,9 @@ class HomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname - def build_clock(self): - return Clock() - def build_replication_layer(self): return initialize_http_replication(self) - def build_datastore(self): - return DataStore(self) - def build_handlers(self): return Handlers(self) @@ -135,10 +144,9 @@ class HomeServer(object): return Auth(self) def build_http_client_context_factory(self): - config = self.get_config() return ( InsecureInterceptableContextFactory() - if config.use_insecure_ssl_client_just_for_testing_do_not_use + if self.config.use_insecure_ssl_client_just_for_testing_do_not_use else BrowserLikePolicyForHTTPS() ) @@ -157,15 +165,9 @@ class HomeServer(object): def build_state_handler(self): return StateHandler(self) - def build_distributor(self): - return Distributor() - def build_event_sources(self): return EventSources(self) - def build_ratelimiter(self): - return Ratelimiter() - def build_keyring(self): return Keyring(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7a3f6c4662..c8cab45f77 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -46,6 +46,9 @@ from .tags import TagsStore from .account_data import AccountDataStore +from util.id_generators import IdGenerator, StreamIdGenerator + + import logging @@ -58,6 +61,22 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120*1000 +def get_datastore(hs): + logger.info("getting called!") + + conn = hs.get_db_conn() + try: + cur = conn.cursor() + cur.execute("SELECT MIN(stream_ordering) FROM events",) + rows = cur.fetchall() + min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 + min_token = min(min_token, -1) + + return DataStore(conn, hs, min_token) + finally: + conn.close() + + class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, @@ -79,18 +98,36 @@ 
class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore ): - def __init__(self, hs): - super(DataStore, self).__init__(hs) + def __init__(self, db_conn, hs, min_stream_token): self.hs = hs - self.min_token_deferred = self._get_min_token() - self.min_token = None + self.min_stream_token = min_stream_token self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, ) + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering" + ) + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) + self._state_groups_id_gen = IdGenerator("state_groups", "id", self) + self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) + self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) + self._pushers_id_gen = IdGenerator("pushers", "id", self) + self._push_rule_id_gen = IdGenerator("push_rules", "id", self) + self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + + super(DataStore, self).__init__(hs) + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 90d7aee94a..5e77320540 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,13 +15,11 @@ import logging from synapse.api.errors import StoreError -from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache import synapse.metrics -from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer @@ -175,16 +173,6 @@ class SQLBaseStore(object): 
self.database_engine = hs.database_engine - self._stream_id_gen = StreamIdGenerator("events", "stream_ordering") - self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) - self._state_groups_id_gen = IdGenerator("state_groups", "id", self) - self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) - self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) - self._pushers_id_gen = IdGenerator("pushers", "id", self) - self._push_rule_id_gen = IdGenerator("push_rules", "id", self) - self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) - self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id") - def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -345,7 +333,8 @@ class SQLBaseStore(object): defer.returnValue(result) - def cursor_to_dict(self, cursor): + @staticmethod + def cursor_to_dict(cursor): """Converts a SQL cursor into an list of dicts. Args: @@ -402,8 +391,8 @@ class SQLBaseStore(object): if not or_ignore: raise - @log_function - def _simple_insert_txn(self, txn, table, values): + @staticmethod + def _simple_insert_txn(txn, table, values): keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -414,7 +403,8 @@ class SQLBaseStore(object): txn.execute(sql, vals) - def _simple_insert_many_txn(self, txn, table, values): + @staticmethod + def _simple_insert_many_txn(txn, table, values): if not values: return @@ -537,9 +527,10 @@ class SQLBaseStore(object): table, keyvalues, retcol, allow_none=allow_none, ) - def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol, + @classmethod + def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol, allow_none=False): - ret = self._simple_select_onecol_txn( + ret = cls._simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, @@ -554,7 +545,8 @@ class SQLBaseStore(object): else: raise StoreError(404, "No row found") - def _simple_select_onecol_txn(self, 
txn, table, keyvalues, retcol): + @staticmethod + def _simple_select_onecol_txn(txn, table, keyvalues, retcol): sql = ( "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" ) % { @@ -603,7 +595,8 @@ class SQLBaseStore(object): table, keyvalues, retcols ) - def _simple_select_list_txn(self, txn, table, keyvalues, retcols): + @classmethod + def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -627,7 +620,7 @@ class SQLBaseStore(object): ) txn.execute(sql) - return self.cursor_to_dict(txn) + return cls.cursor_to_dict(txn) @defer.inlineCallbacks def _simple_select_many_batch(self, table, column, iterable, retcols, @@ -662,7 +655,8 @@ class SQLBaseStore(object): defer.returnValue(results) - def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols): + @classmethod + def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -699,7 +693,7 @@ class SQLBaseStore(object): ) txn.execute(sql, values) - return self.cursor_to_dict(txn) + return cls.cursor_to_dict(txn) def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): @@ -726,7 +720,8 @@ class SQLBaseStore(object): table, keyvalues, updatevalues, ) - def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues): + @staticmethod + def _simple_update_one_txn(txn, table, keyvalues, updatevalues): update_sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" 
% (k,) for k in updatevalues), @@ -743,7 +738,8 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched") - def _simple_select_one_txn(self, txn, table, keyvalues, retcols, + @staticmethod + def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), @@ -784,7 +780,8 @@ class SQLBaseStore(object): raise StoreError(500, "more than one row matched") return self.runInteraction(desc, func) - def _simple_delete_txn(self, txn, table, keyvalues): + @staticmethod + def _simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k, ) for k in keyvalues) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index ba368a3eca..298cb9bada 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore): return if backfilled: - if not self.min_token_deferred.called: - yield self.min_token_deferred - start = self.min_token - 1 - self.min_token -= len(events_and_contexts) + 1 - stream_orderings = range(start, self.min_token, -1) + start = self.min_stream_token - 1 + self.min_stream_token -= len(events_and_contexts) + 1 + stream_orderings = range(start, self.min_stream_token, -1) @contextmanager def stream_ordering_manager(): @@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore): is_new_state=True, current_state=None): stream_ordering = None if backfilled: - if not self.min_token_deferred.called: - yield self.min_token_deferred - self.min_token -= 1 - stream_ordering = self.min_token + self.min_stream_token -= 1 + stream_ordering = self.min_stream_token if stream_ordering is None: stream_ordering_manager = yield self._stream_id_gen.get_next(self) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index c4232bdc65..c0593e23ee 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,9 @@ 
class ReceiptsStore(SQLBaseStore): def __init__(self, hs): super(ReceiptsStore, self).__init__(hs) - self._receipts_stream_cache = _RoomStreamChangeCache() + self._receipts_stream_cache = _RoomStreamChangeCache( + self._receipts_id_gen.get_max_token(None) + ) @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): @@ -377,11 +379,11 @@ class _RoomStreamChangeCache(object): may have changed since that key. If the key is too old then the cache will simply return all rooms. """ - def __init__(self, size_of_cache=10000): + def __init__(self, current_key, size_of_cache=10000): self._size_of_cache = size_of_cache self._room_to_key = {} self._cache = sorteddict() - self._earliest_key = None + self._earliest_key = current_key self.name = "ReceiptsRoomChangeCache" caches_by_name[self.name] = self._cache diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 02b1913e26..e31bad258a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -444,19 +444,6 @@ class StreamStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] if rows else 0 - @defer.inlineCallbacks - def _get_min_token(self): - row = yield self._execute( - "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events" - ) - - self.min_token = row[0][0] if row and row[0] and row[0][0] else -1 - self.min_token = min(self.min_token, -1) - - logger.debug("min_token is: %s", self.min_token) - - defer.returnValue(self.min_token) - @staticmethod def _set_before_and_after(events, rows): for event, row in zip(events, rows): diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index ed9c91e5ea..4c39e07cbd 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -16,7 +16,6 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached from twisted.internet import defer -from .util.id_generators import StreamIdGenerator import ujson as json import logging @@ -25,12 +24,6 @@ logger = logging.getLogger(__name__) class 
TagsStore(SQLBaseStore): - def __init__(self, hs): - super(TagsStore, self).__init__(hs) - - self._account_data_id_gen = StreamIdGenerator( - "account_data_max_stream_id", "stream_id" - ) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index f58bf7fd2c..5c522f4ab9 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -72,28 +72,24 @@ class StreamIdGenerator(object): with stream_id_gen.get_next_txn(txn) as stream_id: # ... persist event ... """ - def __init__(self, table, column): + def __init__(self, db_conn, table, column): self.table = table self.column = column self._lock = threading.Lock() - self._current_max = None + cur = db_conn.cursor() + self._current_max = self._get_or_compute_current_max(cur) + cur.close() + self._unfinished_ids = deque() - @defer.inlineCallbacks def get_next(self, store): """ Usage: with yield stream_id_gen.get_next as stream_id: # ... persist event ... """ - if not self._current_max: - yield store.runInteraction( - "_compute_current_max", - self._get_or_compute_current_max, - ) - with self._lock: self._current_max += 1 next_id = self._current_max @@ -108,21 +104,14 @@ class StreamIdGenerator(object): with self._lock: self._unfinished_ids.remove(next_id) - defer.returnValue(manager()) + return manager() - @defer.inlineCallbacks def get_next_mult(self, store, n): """ Usage: with yield stream_id_gen.get_next(store, n) as stream_ids: # ... persist events ... 
""" - if not self._current_max: - yield store.runInteraction( - "_compute_current_max", - self._get_or_compute_current_max, - ) - with self._lock: next_ids = range(self._current_max + 1, self._current_max + n + 1) self._current_max += n @@ -139,24 +128,17 @@ class StreamIdGenerator(object): for next_id in next_ids: self._unfinished_ids.remove(next_id) - defer.returnValue(manager()) + return manager() - @defer.inlineCallbacks def get_max_token(self, store): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ - if not self._current_max: - yield store.runInteraction( - "_compute_current_max", - self._get_or_compute_current_max, - ) - with self._lock: if self._unfinished_ids: - defer.returnValue(self._unfinished_ids[0] - 1) + return self._unfinished_ids[0] - 1 - defer.returnValue(self._current_max) + return self._current_max def _get_or_compute_current_max(self, txn): with self._lock: -- cgit 1.4.1 From 8c94833b72f27d18a57e424a0f4f0c823c3a0aa1 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 27 Jan 2016 10:24:20 +0000 Subject: Fix adding push rules relative to other rules --- synapse/rest/client/v1/push_rule.py | 15 ++++++++++----- synapse/storage/push_rule.py | 3 ++- 2 files changed, 12 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index cb3ec23872..96633a176c 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -66,11 +66,12 @@ class PushRuleRestServlet(ClientV1RestServlet): raise SynapseError(400, e.message) before = request.args.get("before", None) - if before and len(before): - before = before[0] + if before: + before = _namespaced_rule_id(spec, before[0]) + after = request.args.get("after", None) - if after and len(after): - after = after[0] + if after: + after = _namespaced_rule_id(spec, after[0]) try: yield self.hs.get_datastore().add_push_rule( @@ 
-452,11 +453,15 @@ def _strip_device_condition(rule): def _namespaced_rule_id_from_spec(spec): + return _namespaced_rule_id(spec, spec['rule_id']) + + +def _namespaced_rule_id(spec, rule_id): if spec['scope'] == 'global': scope = 'global' else: scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], spec['rule_id']) + return "%s/%s/%s" % (scope, spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 1f51c90ee5..f9a48171ba 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -130,7 +130,8 @@ class PushRuleStore(SQLBaseStore): def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): after = kwargs.pop("after", None) - relative_to_rule = kwargs.pop("before", after) + before = kwargs.pop("before", None) + relative_to_rule = before or after res = self._simple_select_one_txn( txn, -- cgit 1.4.1 From b97f6626b6f9b91498d06a7ae113b9d20f1fc2ef Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 Jan 2016 09:54:30 +0000 Subject: Add cache to room stream --- synapse/handlers/sync.py | 42 +++++++--- synapse/storage/events.py | 2 + synapse/storage/receipts.py | 65 +-------------- synapse/storage/stream.py | 133 +++++++++++++++++++++++++++++++ synapse/util/caches/room_change_cache.py | 86 ++++++++++++++++++++ 5 files changed, 254 insertions(+), 74 deletions(-) create mode 100644 synapse/util/caches/room_change_cache.py (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 328c049b03..1fdf978313 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -514,13 +514,6 @@ class SyncHandler(BaseHandler): timeline_limit = sync_config.filter_collection.timeline_limit() - room_events, _ = yield self.store.get_room_events_stream( - sync_config.user.to_string(), - from_key=since_token.room_key, - to_key=now_token.room_key, - limit=timeline_limit + 1, - ) - 
tags_by_room = yield self.store.get_updated_tags( sync_config.user.to_string(), since_token.account_data_key, @@ -533,6 +526,32 @@ class SyncHandler(BaseHandler): ) ) + rooms_changed = yield self.store.get_room_changes_for_user( + sync_config.user.to_string(), since_token.room_key, now_token.room_key + ) + + room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_ids=room_ids, + from_key=since_token.room_key, + to_key=now_token.room_key, + limit=timeline_limit + 1, + ) + + room_events = [ + event + for events, _ in room_to_events.values() + for event in events + ] + + room_events.extend(rooms_changed) + + # room_events, _ = yield self.store.get_room_events_stream( + # sync_config.user.to_string(), + # from_key=since_token.room_key, + # to_key=now_token.room_key, + # limit=timeline_limit + 1, + # ) + joined = [] archived = [] if len(room_events) <= timeline_limit: @@ -694,14 +713,12 @@ class SyncHandler(BaseHandler): end_key = room_key while limited and len(recents) < timeline_limit and max_repeat: - events, keys = yield self.store.get_recent_events_for_room( + events, end_key = yield self.store.get_recent_room_events_stream_for_room( room_id, limit=load_limit + 1, - from_token=since_token.room_key if since_token else None, - end_token=end_key, + from_key=since_token.room_key if since_token else None, + to_key=end_key, ) - room_key, _ = keys - end_key = "s" + room_key.split('-')[-1] loaded_recents = sync_config.filter_collection.filter_room_timeline(events) loaded_recents = yield self._filter_events_for_client( sync_config.user.to_string(), @@ -712,6 +729,7 @@ class SyncHandler(BaseHandler): recents = loaded_recents if len(events) <= load_limit: limited = False + break max_repeat -= 1 if len(recents) > timeline_limit: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 298cb9bada..d96ea3a30e 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -128,6 +128,8 @@ class EventsStore(SQLBaseStore): 
is_new_state=is_new_state, current_state=current_state, ) + logger.info("Invalidating %r at %r", event.room_id, stream_ordering) + self._events_stream_cache.room_has_changed(None, event.room_id, stream_ordering) except _RollbackButIsFineException: pass diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index c0593e23ee..b7a4e77748 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -16,6 +16,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches.room_change_cache import RoomStreamChangeCache from twisted.internet import defer @@ -31,8 +32,8 @@ class ReceiptsStore(SQLBaseStore): def __init__(self, hs): super(ReceiptsStore, self).__init__(hs) - self._receipts_stream_cache = _RoomStreamChangeCache( - self._receipts_id_gen.get_max_token(None) + self._receipts_stream_cache = RoomStreamChangeCache( + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) ) @cached(num_args=2) @@ -370,63 +371,3 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) - - -class _RoomStreamChangeCache(object): - """Keeps track of the stream_id of the latest change in rooms. - - Given a list of rooms and stream key, it will give a subset of rooms that - may have changed since that key. If the key is too old then the cache - will simply return all rooms. - """ - def __init__(self, current_key, size_of_cache=10000): - self._size_of_cache = size_of_cache - self._room_to_key = {} - self._cache = sorteddict() - self._earliest_key = current_key - self.name = "ReceiptsRoomChangeCache" - caches_by_name[self.name] = self._cache - - @defer.inlineCallbacks - def get_rooms_changed(self, store, room_ids, key): - """Returns subset of room ids that have had new receipts since the - given key. If the key is too old it will just return the given list. 
- """ - if key > (yield self._get_earliest_key(store)): - keys = self._cache.keys() - i = keys.bisect_right(key) - - result = set( - self._cache[k] for k in keys[i:] - ).intersection(room_ids) - - cache_counter.inc_hits(self.name) - else: - result = room_ids - cache_counter.inc_misses(self.name) - - defer.returnValue(result) - - @defer.inlineCallbacks - def room_has_changed(self, store, room_id, key): - """Informs the cache that the room has been changed at the given key. - """ - if key > (yield self._get_earliest_key(store)): - old_key = self._room_to_key.get(room_id, None) - if old_key: - key = max(key, old_key) - self._cache.pop(old_key, None) - self._cache[key] = room_id - - while len(self._cache) > self._size_of_cache: - k, r = self._cache.popitem() - self._earliest_key = max(k, self._earliest_key) - self._room_to_key.pop(r, None) - - @defer.inlineCallbacks - def _get_earliest_key(self, store): - if self._earliest_key is None: - self._earliest_key = yield store.get_max_receipt_stream_id() - self._earliest_key = int(self._earliest_key) - - defer.returnValue(self._earliest_key) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index e31bad258a..3a32a0019a 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -37,6 +37,7 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.room_change_cache import RoomStreamChangeCache from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function @@ -77,6 +78,12 @@ def upper_bound(token): class StreamStore(SQLBaseStore): + def __init__(self, hs): + super(StreamStore, self).__init__(hs) + + self._events_stream_cache = RoomStreamChangeCache( + "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None) + ) @defer.inlineCallbacks def get_appservice_room_stream(self, service, from_key, to_key, limit=0): @@ 
-157,6 +164,132 @@ class StreamStore(SQLBaseStore): results = yield self.runInteraction("get_appservice_room_stream", f) defer.returnValue(results) + @defer.inlineCallbacks + def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0): + from_id = RoomStreamToken.parse_stream_token(from_key).stream + + room_ids = yield self._events_stream_cache.get_rooms_changed( + self, room_ids, from_id + ) + + if not room_ids: + defer.returnValue({}) + + results = {} + room_ids = list(room_ids) + for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)): + res = yield defer.gatherResults([ + self.get_recent_room_events_stream_for_room( + room_id, from_key, to_key, limit + ).addCallback(lambda r, rm: (rm, r), room_id) + for room_id in room_ids + ]) + results.update(dict(res)) + + defer.returnValue(results) + + @defer.inlineCallbacks + def get_recent_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0): + if from_key is not None: + from_id = RoomStreamToken.parse_stream_token(from_key).stream + else: + from_id = None + to_id = RoomStreamToken.parse_stream_token(to_key).stream + + if from_key == to_key: + defer.returnValue(([], from_key)) + + has_changed = yield self._events_stream_cache.get_room_has_changed( + room_id, from_id + ) + + if not has_changed: + defer.returnValue(([], from_key)) + + def f(txn): + if from_id is not None: + sql = ( + "SELECT event_id, stream_ordering FROM events WHERE" + " room_id = ?" + " AND not outlier" + " AND stream_ordering > ? AND stream_ordering <= ?" + " ORDER BY stream_ordering DESC LIMIT ?" + ) + txn.execute(sql, (room_id, from_id, to_id, limit)) + else: + sql = ( + "SELECT event_id, stream_ordering FROM events WHERE" + " room_id = ?" + " AND not outlier" + " AND stream_ordering <= ?" + " ORDER BY stream_ordering DESC LIMIT ?" 
+ ) + txn.execute(sql, (room_id, to_id, limit)) + + rows = self.cursor_to_dict(txn) + + ret = self._get_events_txn( + txn, + [r["event_id"] for r in rows], + get_prev_content=True + ) + + ret.reverse() + + self._set_before_and_after(ret, rows) + + if rows: + key = "s%d" % min(r["stream_ordering"] for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = from_key + + return ret, key + res = yield self.runInteraction("get_recent_room_events_stream_for_room", f) + defer.returnValue(res) + + def get_room_changes_for_user(self, user_id, from_key, to_key): + if from_key is not None: + from_id = RoomStreamToken.parse_stream_token(from_key).stream + else: + from_id = None + to_id = RoomStreamToken.parse_stream_token(to_key).stream + + if from_key == to_key: + return defer.succeed([]) + + def f(txn): + if from_id is not None: + sql = ( + "SELECT m.event_id, stream_ordering FROM events AS e, room_memberships AS m" + " WHERE e.event_id = m.event_id" + " AND m.user_id = ?" + " AND e.stream_ordering > ? AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + ) + txn.execute(sql, (user_id, from_id, to_id,)) + else: + sql = ( + "SELECT m.event_id, stream_ordering FROM events AS e, room_memberships AS m" + " WHERE e.event_id = m.event_id" + " AND m.user_id = ?" + " AND stream_ordering <= ?" 
+ " ORDER BY stream_ordering ASC" + ) + txn.execute(sql, (user_id, to_id,)) + rows = self.cursor_to_dict(txn) + + ret = self._get_events_txn( + txn, + [r["event_id"] for r in rows], + get_prev_content=True + ) + + return ret + + return self.runInteraction("get_room_changes_for_user", f) + @log_function def get_room_events_stream( self, diff --git a/synapse/util/caches/room_change_cache.py b/synapse/util/caches/room_change_cache.py new file mode 100644 index 0000000000..3a873c9c30 --- /dev/null +++ b/synapse/util/caches/room_change_cache.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches import cache_counter, caches_by_name + + +from blist import sorteddict +import logging + + +logger = logging.getLogger(__name__) + + +class RoomStreamChangeCache(object): + """Keeps track of the stream_id of the latest change in rooms. + + Given a list of rooms and stream key, it will give a subset of rooms that + may have changed since that key. If the key is too old then the cache + will simply return all rooms. 
+ """ + def __init__(self, name, current_key, size_of_cache=10000): + self._size_of_cache = size_of_cache + self._room_to_key = {} + self._cache = sorteddict() + self._earliest_known_key = current_key + self.name = name + caches_by_name[self.name] = self._cache + + def get_room_has_changed(self, room_id, key): + if key <= self._earliest_known_key: + return True + + room_key = self._room_to_key.get(room_id, None) + if room_key is None: + return True + + if key < room_key: + return True + + return False + + def get_rooms_changed(self, store, room_ids, key): + """Returns subset of room ids that have had new things since the + given key. If the key is too old it will just return the given list. + """ + if key > self._earliest_known_key: + keys = self._cache.keys() + i = keys.bisect_right(key) + + result = set( + self._cache[k] for k in keys[i:] + ).intersection(room_ids) + + cache_counter.inc_hits(self.name) + else: + result = room_ids + cache_counter.inc_misses(self.name) + + return result + + def room_has_changed(self, store, room_id, key): + """Informs the cache that the room has been changed at the given key. 
+ """ + if key > self._earliest_known_key: + old_key = self._room_to_key.get(room_id, None) + if old_key: + key = max(key, old_key) + self._cache.pop(old_key, None) + self._cache[key] = room_id + + while len(self._cache) > self._size_of_cache: + k, r = self._cache.popitem() + self._earliest_key = max(k, self._earliest_key) + self._room_to_key.pop(r, None) -- cgit 1.4.1 From aca3193efb8c5f9f20049f61c96e5ff12f328b05 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 Jan 2016 17:06:52 +0000 Subject: Use the same path for incremental with gap or without gap --- synapse/handlers/sync.py | 352 +++++++++++++++++++--------------------------- synapse/storage/events.py | 1 - synapse/storage/stream.py | 6 +- 3 files changed, 147 insertions(+), 212 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1fdf978313..f5e20d6a6e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [ ) -class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ +class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [ "room_id", # str "timeline", # TimelineBatch "state", # dict[(str, str), FrozenEvent] @@ -429,44 +429,20 @@ class SyncHandler(BaseHandler): defer.returnValue((now_token, ephemeral_by_room)) - @defer.inlineCallbacks def full_state_sync_for_archived_room(self, room_id, sync_config, leave_event_id, leave_token, timeline_since_token, tags_by_room, account_data_by_room): """Sync a room for a client which is starting without any state Returns: - A Deferred JoinedSyncResult. + A Deferred ArchivedSyncResult. 
""" - batch = yield self.load_filtered_recents( - room_id, sync_config, leave_token, since_token=timeline_since_token + return self.incremental_sync_for_archived_room( + sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room, + account_data_by_room, full_state=True, leave_token=leave_token, ) - leave_state = yield self.store.get_state_for_event(leave_event_id) - - leave_state = { - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state( - leave_state.values() - ) - } - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - defer.returnValue(ArchivedSyncResult( - room_id=room_id, - timeline=batch, - state=leave_state, - account_data=account_data, - )) - @defer.inlineCallbacks def incremental_sync_with_gap(self, sync_config, since_token): """ Get the incremental delta needed to bring the client up to @@ -512,173 +488,127 @@ class SyncHandler(BaseHandler): sync_config.user ) + user_id = sync_config.user.to_string() + timeline_limit = sync_config.filter_collection.timeline_limit() tags_by_room = yield self.store.get_updated_tags( - sync_config.user.to_string(), + user_id, since_token.account_data_key, ) account_data, account_data_by_room = ( yield self.store.get_updated_account_data_for_user( - sync_config.user.to_string(), + user_id, since_token.account_data_key, ) ) + # Get a list of membership change events that have happened. 
rooms_changed = yield self.store.get_room_changes_for_user( - sync_config.user.to_string(), since_token.room_key, now_token.room_key + user_id, since_token.room_key, now_token.room_key ) + mem_change_events_by_room_id = {} + for event in rooms_changed: + mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) + + newly_joined_rooms = [] + archived = [] + invited = [] + for room_id, events in mem_change_events_by_room_id.items(): + non_joins = [e for e in events if e.membership != Membership.JOIN] + has_join = len(non_joins) != len(events) + + # We want to figure out if we joined the room at some point since + # the last sync (even if we have since left). This is to make sure + # we do send down the room, and with full state, where necessary + if room_id in joined_room_ids or has_join: + old_state = yield self.get_state_at(room_id, since_token) + old_mem_ev = old_state.get((EventTypes.Member, user_id), None) + if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: + newly_joined_rooms.append(room_id) + + if room_id in joined_room_ids: + continue + + if not non_joins: + continue + + # Only bother if we're still currently invited + should_invite = non_joins[-1].membership == Membership.INVITE + if should_invite: + room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if room_sync: + invited.append(room_sync) + + # Always include leave/ban events. Just take the last one. + # TODO: How do we handle ban -> leave in same batch? + leave_events = [ + e for e in non_joins + if e.membership in (Membership.LEAVE, Membership.BAN) + ] + + if leave_events: + leave_event = leave_events[-1] + room_sync = yield self.incremental_sync_for_archived_room( + sync_config, room_id, leave_event.event_id, since_token, + tags_by_room, account_data_by_room, + full_state=room_id in newly_joined_rooms + ) + if room_sync: + archived.append(room_sync) + + # Get all events for rooms we're currently joined to. 
room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=room_ids, + room_ids=joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, limit=timeline_limit + 1, ) - room_events = [ - event - for events, _ in room_to_events.values() - for event in events - ] - - room_events.extend(rooms_changed) - - # room_events, _ = yield self.store.get_room_events_stream( - # sync_config.user.to_string(), - # from_key=since_token.room_key, - # to_key=now_token.room_key, - # limit=timeline_limit + 1, - # ) - joined = [] - archived = [] - if len(room_events) <= timeline_limit: - # There is no gap in any of the rooms. Therefore we can just - # partition the new events by room and return them. - logger.debug("Got %i events for incremental sync - not limited", - len(room_events)) - - invite_events = [] - leave_events = [] - events_by_room_id = {} - for event in room_events: - events_by_room_id.setdefault(event.room_id, []).append(event) - if event.room_id not in joined_room_ids: - if (event.type == EventTypes.Member - and event.state_key == sync_config.user.to_string()): - if event.membership == Membership.INVITE: - invite_events.append(event) - elif event.membership in (Membership.LEAVE, Membership.BAN): - leave_events.append(event) - - for room_id in joined_room_ids: - recents = events_by_room_id.get(room_id, []) - logger.debug("Events for room %s: %r", room_id, recents) - state = { - (event.type, event.state_key): event - for event in recents if event.is_state()} - limited = False + # We loop through all room ids, even if there are no new events, in case + # there are non room events taht we need to notify about. 
+ for room_id in joined_room_ids: + room_entry = room_to_events.get(room_id, None) - if recents: - prev_batch = now_token.copy_and_replace( - "room_key", recents[0].internal_metadata.before - ) - else: - prev_batch = now_token - - just_joined = yield self.check_joined_room(sync_config, state) - if just_joined: - logger.debug("User has just joined %s: needs full state", - room_id) - state = yield self.get_state_at(room_id, now_token) - # the timeline is inherently limited if we've just joined - limited = True - - recents = sync_config.filter_collection.filter_room_timeline(recents) - - state = { - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state( - state.values() - ) - } - - acc_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) + if room_entry: + events, start_key = room_entry - acc_data = sync_config.filter_collection.filter_room_account_data( - acc_data - ) + prev_batch_token = now_token.copy_and_replace("room_key", start_key) - ephemeral = sync_config.filter_collection.filter_room_ephemeral( - ephemeral_by_room.get(room_id, []) - ) + newly_joined_room = room_id in newly_joined_rooms + full_state = newly_joined_room - room_sync = JoinedSyncResult( - room_id=room_id, - timeline=TimelineBatch( - events=recents, - prev_batch=prev_batch, - limited=limited, - ), - state=state, - ephemeral=ephemeral, - account_data=acc_data, - unread_notifications={}, + batch = yield self.load_filtered_recents( + room_id, sync_config, prev_batch_token, + since_token=since_token, + recents=events, + newly_joined_room=newly_joined_room, ) - logger.debug("Result for room %s: %r", room_id, room_sync) - - if room_sync: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config, all_ephemeral_by_room - ) - - if notifs is not None: - notif_dict = room_sync.unread_notifications - notif_dict["notification_count"] = len(notifs) - notif_dict["highlight_count"] = len([ - 1 for notif in notifs - if 
_action_has_highlight(notif["actions"]) - ]) - - joined.append(room_sync) - - else: - logger.debug("Got %i events for incremental sync - hit limit", - len(room_events)) - - invite_events = yield self.store.get_invites_for_user( - sync_config.user.to_string() - ) - - leave_events = yield self.store.get_leave_and_ban_events_for_user( - sync_config.user.to_string() - ) - - for room_id in joined_room_ids: - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room, account_data_by_room, - all_ephemeral_by_room=all_ephemeral_by_room, + else: + batch = TimelineBatch( + events=[], + prev_batch=since_token, + limited=False, ) - if room_sync: - joined.append(room_sync) + full_state = False - for leave_event in leave_events: - room_sync = yield self.incremental_sync_for_archived_room( - sync_config, leave_event, since_token, tags_by_room, - account_data_by_room + room_sync = yield self.incremental_sync_with_gap_for_room( + room_id=room_id, + sync_config=sync_config, + since_token=since_token, + now_token=now_token, + ephemeral_by_room=ephemeral_by_room, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + all_ephemeral_by_room=all_ephemeral_by_room, + batch=batch, + full_state=full_state, ) if room_sync: - archived.append(room_sync) - - invited = [ - InvitedSyncResult(room_id=event.room_id, invite=event) - for event in invite_events - ] + joined.append(room_sync) account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) @@ -699,12 +629,10 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None): + since_token=None, recents=None, newly_joined_room=False): """ :returns a Deferred TimelineBatch """ - limited = True - recents = [] filtering_factor = 2 timeline_limit = sync_config.filter_collection.timeline_limit() load_limit = 
max(timeline_limit * filtering_factor, 100) @@ -712,11 +640,27 @@ class SyncHandler(BaseHandler): room_key = now_token.room_key end_key = room_key + limited = recents is None or newly_joined_room or timeline_limit < len(recents) + + if recents is not None: + recents = sync_config.filter_collection.filter_room_timeline(recents) + recents = yield self._filter_events_for_client( + sync_config.user.to_string(), + recents, + is_peeking=sync_config.is_guest, + ) + else: + recents = [] + + since_key = None + if since_token and not newly_joined_room: + since_key = since_token.room_key + while limited and len(recents) < timeline_limit and max_repeat: - events, end_key = yield self.store.get_recent_room_events_stream_for_room( + events, end_key = yield self.store.get_room_events_stream_for_room( room_id, limit=load_limit + 1, - from_key=since_token.room_key if since_token else None, + from_key=since_key, to_key=end_key, ) loaded_recents = sync_config.filter_collection.filter_room_timeline(events) @@ -727,6 +671,7 @@ class SyncHandler(BaseHandler): ) loaded_recents.extend(recents) recents = loaded_recents + if len(events) <= load_limit: limited = False break @@ -742,7 +687,9 @@ class SyncHandler(BaseHandler): ) defer.returnValue(TimelineBatch( - events=recents, prev_batch=prev_batch_token, limited=limited + events=recents, + prev_batch=prev_batch_token, + limited=limited or newly_joined_room )) @defer.inlineCallbacks @@ -750,24 +697,8 @@ class SyncHandler(BaseHandler): since_token, now_token, ephemeral_by_room, tags_by_room, account_data_by_room, - all_ephemeral_by_room): - """ Get the incremental delta needed to bring the client up to date for - the room. Gives the client the most recent events and the changes to - state. - Returns: - A Deferred JoinedSyncResult - """ - logger.debug("Doing incremental sync for room %s between %s and %s", - room_id, since_token, now_token) - - # TODO(mjark): Check for redactions we might have missed. 
- - batch = yield self.load_filtered_recents( - room_id, sync_config, now_token, since_token, - ) - - logger.debug("Recents %r", batch) - + all_ephemeral_by_room, + batch, full_state=False): if batch.limited: current_state = yield self.get_state_at(room_id, now_token) @@ -832,43 +763,48 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) @defer.inlineCallbacks - def incremental_sync_for_archived_room(self, sync_config, leave_event, + def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id, since_token, tags_by_room, - account_data_by_room): + account_data_by_room, full_state, + leave_token=None): """ Get the incremental delta needed to bring the client up to date for the archived room. Returns: A Deferred ArchivedSyncResult """ - stream_token = yield self.store.get_stream_token_for_event( - leave_event.event_id - ) + if not leave_token: + stream_token = yield self.store.get_stream_token_for_event( + leave_event_id + ) - leave_token = since_token.copy_and_replace("room_key", stream_token) + leave_token = since_token.copy_and_replace("room_key", stream_token) - if since_token.is_after(leave_token): + if since_token and since_token.is_after(leave_token): defer.returnValue(None) batch = yield self.load_filtered_recents( - leave_event.room_id, sync_config, leave_token, since_token, + room_id, sync_config, leave_token, since_token, ) logger.debug("Recents %r", batch) state_events_at_leave = yield self.store.get_state_for_event( - leave_event.event_id + leave_event_id ) - state_at_previous_sync = yield self.get_state_at( - leave_event.room_id, stream_position=since_token - ) + if not full_state: + state_at_previous_sync = yield self.get_state_at( + room_id, stream_position=since_token + ) - state_events_delta = yield self.compute_state_delta( - since_token=since_token, - previous_state=state_at_previous_sync, - current_state=state_events_at_leave, - ) + state_events_delta = yield self.compute_state_delta( + since_token=since_token, + 
previous_state=state_at_previous_sync, + current_state=state_events_at_leave, + ) + else: + state_events_delta = state_events_at_leave state_events_delta = { (e.type, e.state_key): e @@ -878,7 +814,7 @@ class SyncHandler(BaseHandler): } account_data = self.account_data_for_room( - leave_event.room_id, tags_by_room, account_data_by_room + room_id, tags_by_room, account_data_by_room ) account_data = sync_config.filter_collection.filter_room_account_data( @@ -886,7 +822,7 @@ class SyncHandler(BaseHandler): ) room_sync = ArchivedSyncResult( - room_id=leave_event.room_id, + room_id=room_id, timeline=batch, state=state_events_delta, account_data=account_data, diff --git a/synapse/storage/events.py b/synapse/storage/events.py index d96ea3a30e..0dd1daaa2e 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -128,7 +128,6 @@ class EventsStore(SQLBaseStore): is_new_state=is_new_state, current_state=current_state, ) - logger.info("Invalidating %r at %r", event.room_id, stream_ordering) self._events_stream_cache.room_has_changed(None, event.room_id, stream_ordering) except _RollbackButIsFineException: pass diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 3a32a0019a..563e289c4e 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -179,7 +179,7 @@ class StreamStore(SQLBaseStore): room_ids = list(room_ids) for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)): res = yield defer.gatherResults([ - self.get_recent_room_events_stream_for_room( + self.get_room_events_stream_for_room( room_id, from_key, to_key, limit ).addCallback(lambda r, rm: (rm, r), room_id) for room_id in room_ids @@ -189,7 +189,7 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @defer.inlineCallbacks - def get_recent_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0): + def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0): if from_key is not None: from_id = 
RoomStreamToken.parse_stream_token(from_key).stream else: @@ -246,7 +246,7 @@ class StreamStore(SQLBaseStore): key = from_key return ret, key - res = yield self.runInteraction("get_recent_room_events_stream_for_room", f) + res = yield self.runInteraction("get_room_events_stream_for_room", f) defer.returnValue(res) def get_room_changes_for_user(self, user_id, from_key, to_key): -- cgit 1.4.1 From e7febf4fbb1f1beb11e7a03252f6844f84af7f30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 Jan 2016 17:11:04 +0000 Subject: PEP8 --- synapse/storage/events.py | 4 +++- synapse/storage/receipts.py | 2 -- synapse/storage/stream.py | 9 ++++++--- 3 files changed, 9 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0dd1daaa2e..80187722ea 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -128,7 +128,9 @@ class EventsStore(SQLBaseStore): is_new_state=is_new_state, current_state=current_state, ) - self._events_stream_cache.room_has_changed(None, event.room_id, stream_ordering) + self._events_stream_cache.room_has_changed( + None, event.room_id, stream_ordering + ) except _RollbackButIsFineException: pass diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index b7a4e77748..7118368d97 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -15,12 +15,10 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached -from synapse.util.caches import cache_counter, caches_by_name from synapse.util.caches.room_change_cache import RoomStreamChangeCache from twisted.internet import defer -from blist import sorteddict import logging import ujson as json diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 563e289c4e..0b22251790 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -262,7 +262,8 @@ class StreamStore(SQLBaseStore): def 
f(txn): if from_id is not None: sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e, room_memberships AS m" + "SELECT m.event_id, stream_ordering FROM events AS e," + " room_memberships AS m" " WHERE e.event_id = m.event_id" " AND m.user_id = ?" " AND e.stream_ordering > ? AND e.stream_ordering <= ?" @@ -271,7 +272,8 @@ class StreamStore(SQLBaseStore): txn.execute(sql, (user_id, from_id, to_id,)) else: sql = ( - "SELECT m.event_id, stream_ordering FROM events AS e, room_memberships AS m" + "SELECT m.event_id, stream_ordering FROM events AS e," + " room_memberships AS m" " WHERE e.event_id = m.event_id" " AND m.user_id = ?" " AND stream_ordering <= ?" @@ -307,7 +309,8 @@ class StreamStore(SQLBaseStore): "SELECT c.room_id FROM history_visibility AS h" " INNER JOIN current_state_events AS c" " ON h.event_id = c.event_id" - " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % ( + " WHERE c.room_id IN (%s)" + " AND h.history_visibility = 'world_readable'" % ( ",".join(map(lambda _: "?", room_ids)) ) ) -- cgit 1.4.1 From c5e7c0e436d9073c07887676817bb5a45314aea5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 09:58:45 +0000 Subject: Up get_rooms_for_user cache size --- synapse/storage/roommember.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index edfecced05..1d3e004c90 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -241,7 +241,7 @@ class RoomMemberStore(SQLBaseStore): return rows - @cached() + @cached(max_entries=5000) def get_rooms_for_user(self, user_id): return self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], -- cgit 1.4.1 From ba8931829b0b601eb14049c92e0f21a10772576d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 11:34:17 +0000 Subject: Return correct type of token --- synapse/storage/stream.py | 9 ++++++--- 1 
file changed, 6 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 0b22251790..28721e6994 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -236,7 +236,7 @@ class StreamStore(SQLBaseStore): ret.reverse() - self._set_before_and_after(ret, rows) + self._set_before_and_after(ret, rows, topo_order=False) if rows: key = "s%d" % min(r["stream_ordering"] for r in rows) @@ -581,10 +581,13 @@ class StreamStore(SQLBaseStore): return rows[0][0] if rows else 0 @staticmethod - def _set_before_and_after(events, rows): + def _set_before_and_after(events, rows, topo_order=True): for event, row in zip(events, rows): stream = row["stream_ordering"] - topo = event.depth + if topo_order: + topo = event.depth + else: + topo = None internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) -- cgit 1.4.1 From 4e7948b47a3f197682de82fc0cda07ebb08a581d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 11:52:34 +0000 Subject: Allow paginating backwards from stream token --- synapse/handlers/message.py | 15 +++++++++------ synapse/storage/stream.py | 16 ++++++++++++++-- tests/rest/client/v1/test_rooms.py | 9 +-------- 3 files changed, 24 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b73ad62147..82c8cb5f0c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError, AuthError, Codes +from synapse.api.errors import AuthError, Codes from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -119,9 +119,12 @@ class 
MessageHandler(BaseHandler): if source_config.direction == 'b': # if we're going backwards, we might need to backfill. This # requires that we have a topo token. - if room_token.topological is None: - raise SynapseError(400, "Invalid token: cannot paginate " - "backwards from a stream token") + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token_for_stream_and_room( + room_id, room_token.stream + ) if membership == Membership.LEAVE: # If they have left the room then clamp the token to be before @@ -131,11 +134,11 @@ class MessageHandler(BaseHandler): member_event_id ) leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < room_token.topological: + if leave_token.topological < max_topo: source_config.from_key = str(leave_token) yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, room_token.topological + room_id, max_topo ) events, next_key = yield data_source.get_pagination_rows( diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 28721e6994..5096b46864 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -234,10 +234,10 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) - ret.reverse() - self._set_before_and_after(ret, rows, topo_order=False) + ret.reverse() + if rows: key = "s%d" % min(r["stream_ordering"] for r in rows) else: @@ -570,6 +570,18 @@ class StreamStore(SQLBaseStore): row["topological_ordering"], row["stream_ordering"],) ) + def get_max_topological_token_for_stream_and_room(self, room_id, stream_key): + sql = ( + "SELECT max(topological_ordering) FROM events" + " WHERE room_id = ? AND stream_ordering < ?" 
+ ) + return self._execute( + "get_max_topological_token_for_stream_and_room", None, + sql, room_id, stream_key, + ).addCallback( + lambda r: r[0][0] if r else 0 + ) + def _get_max_topological_txn(self, txn): txn.execute( "SELECT MAX(topological_ordering) FROM events" diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2fe6f695f5..ad5dd3bd6e 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1044,13 +1044,6 @@ class RoomMessageListTestCase(RestTestCase): self.assertTrue("chunk" in response) self.assertTrue("end" in response) - @defer.inlineCallbacks - def test_stream_token_is_rejected_for_back_pagination(self): - (code, response) = yield self.mock_resource.trigger_get( - "/rooms/%s/messages?access_token=x&from=s0_0_0_0_0&dir=b" % - self.room_id) - self.assertEquals(400, code) - @defer.inlineCallbacks def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0" @@ -1061,4 +1054,4 @@ class RoomMessageListTestCase(RestTestCase): self.assertTrue("start" in response) self.assertEquals(token, response['start']) self.assertTrue("chunk" in response) - self.assertTrue("end" in response) \ No newline at end of file + self.assertTrue("end" in response) -- cgit 1.4.1 From 7ed2bbeb11d99ef97672497879e480f91db9b99b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 14:32:05 +0000 Subject: Clean up a bit. 
Add comment --- synapse/app/homeserver.py | 13 +++++++------ synapse/server.py | 4 ++-- synapse/storage/__init__.py | 27 +++++++++------------------ 3 files changed, 18 insertions(+), 26 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 504557b2fc..65562222cf 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -255,12 +255,13 @@ class SynapseHomeServer(HomeServer): quit_with_error(e.message) def get_db_conn(self): - db_conn = self.database_engine.module.connect( - **{ - k: v for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) + # Any param beginning with cp_ is a parameter for adbapi, and should + # not be passed to the database engine. + db_params = { + k: v for k, v in self.db_config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = self.database_engine.module.connect(**db_params) self.database_engine.on_new_connection(db_conn) return db_conn diff --git a/synapse/server.py b/synapse/server.py index e013a349c9..5fee7fe130 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -28,7 +28,7 @@ from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers from synapse.state import StateHandler -from synapse.storage import get_datastore +from synapse.storage import DataStore from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.streams.events import EventSources @@ -117,7 +117,7 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - self.datastore = get_datastore(self) + self.datastore = DataStore(self.get_db_conn(), self) logger.info("Finished setting up.") def get_ip_from_request(self, request): diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c8cab45f77..eb88842308 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -61,22 +61,6 @@ logger = 
logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120*1000 -def get_datastore(hs): - logger.info("getting called!") - - conn = hs.get_db_conn() - try: - cur = conn.cursor() - cur.execute("SELECT MIN(stream_ordering) FROM events",) - rows = cur.fetchall() - min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 - min_token = min(min_token, -1) - - return DataStore(conn, hs, min_token) - finally: - conn.close() - - class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, @@ -98,10 +82,17 @@ class DataStore(RoomMemberStore, RoomStore, EventPushActionsStore ): - def __init__(self, db_conn, hs, min_stream_token): + def __init__(self, db_conn, hs): self.hs = hs - self.min_stream_token = min_stream_token + cur = db_conn.cursor() + try: + cur.execute("SELECT MIN(stream_ordering) FROM events",) + rows = cur.fetchall() + self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1 + self.min_stream_token = min(self.min_stream_token, -1) + finally: + cur.close() self.client_ip_last_seen = Cache( name="client_ip_last_seen", -- cgit 1.4.1 From e1941442d442fe62570551071edfd936304697e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 15:02:37 +0000 Subject: Invalidate caches properly. 
Remove unused arg --- synapse/storage/events.py | 9 ++++++--- synapse/storage/receipts.py | 10 ++++++---- synapse/storage/stream.py | 2 +- synapse/util/caches/room_change_cache.py | 4 ++-- 4 files changed, 15 insertions(+), 10 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 80187722ea..2d2270b297 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -128,9 +128,6 @@ class EventsStore(SQLBaseStore): is_new_state=is_new_state, current_state=current_state, ) - self._events_stream_cache.room_has_changed( - None, event.room_id, stream_ordering - ) except _RollbackButIsFineException: pass @@ -213,6 +210,12 @@ class EventsStore(SQLBaseStore): for event, _ in events_and_contexts: txn.call_after(self._invalidate_get_event_cache, event.event_id) + if not backfilled: + txn.call_after( + self._events_stream_cache.room_has_changed, + event.room_id, event.internal_metadata.stream_ordering, + ) + depth_updates = {} for event, _ in events_and_contexts: if event.internal_metadata.is_outlier(): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 7118368d97..5ffbfdec51 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -78,7 +78,7 @@ class ReceiptsStore(SQLBaseStore): if from_key: room_ids = yield self._receipts_stream_cache.get_rooms_changed( - self, room_ids, from_key + room_ids, from_key ) results = yield self._get_linearized_receipts_for_rooms( @@ -221,6 +221,11 @@ class ReceiptsStore(SQLBaseStore): # FIXME: This shouldn't invalidate the whole cache txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) + txn.call_after( + self._receipts_stream_cache.room_has_changed, + room_id, stream_id + ) + # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts sql = ( @@ -308,9 +313,6 @@ class ReceiptsStore(SQLBaseStore): stream_id_manager = yield 
self._receipts_id_gen.get_next(self) with stream_id_manager as stream_id: - yield self._receipts_stream_cache.room_has_changed( - self, room_id, stream_id - ) have_persisted = yield self.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 5096b46864..67e7e6a76f 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -169,7 +169,7 @@ class StreamStore(SQLBaseStore): from_id = RoomStreamToken.parse_stream_token(from_key).stream room_ids = yield self._events_stream_cache.get_rooms_changed( - self, room_ids, from_id + room_ids, from_id ) if not room_ids: diff --git a/synapse/util/caches/room_change_cache.py b/synapse/util/caches/room_change_cache.py index 3a873c9c30..eb2ab5f1e4 100644 --- a/synapse/util/caches/room_change_cache.py +++ b/synapse/util/caches/room_change_cache.py @@ -51,7 +51,7 @@ class RoomStreamChangeCache(object): return False - def get_rooms_changed(self, store, room_ids, key): + def get_rooms_changed(self, room_ids, key): """Returns subset of room ids that have had new things since the given key. If the key is too old it will just return the given list. """ @@ -70,7 +70,7 @@ class RoomStreamChangeCache(object): return result - def room_has_changed(self, store, room_id, key): + def room_has_changed(self, room_id, key): """Informs the cache that the room has been changed at the given key. 
""" if key > self._earliest_known_key: -- cgit 1.4.1 From c23a8c783382a0789c757e16e104cf08654e6cf8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 15:55:26 +0000 Subject: Ensure keys to RoomStreamChangeCache are ints --- synapse/storage/stream.py | 11 ++++++----- synapse/util/caches/room_change_cache.py | 6 ++++++ 2 files changed, 12 insertions(+), 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 67e7e6a76f..6a724193e1 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -199,12 +199,13 @@ class StreamStore(SQLBaseStore): if from_key == to_key: defer.returnValue(([], from_key)) - has_changed = yield self._events_stream_cache.get_room_has_changed( - room_id, from_id - ) + if from_id: + has_changed = yield self._events_stream_cache.get_room_has_changed( + room_id, from_id + ) - if not has_changed: - defer.returnValue(([], from_key)) + if not has_changed: + defer.returnValue(([], from_key)) def f(txn): if from_id is not None: diff --git a/synapse/util/caches/room_change_cache.py b/synapse/util/caches/room_change_cache.py index eb2ab5f1e4..e8bfedd72f 100644 --- a/synapse/util/caches/room_change_cache.py +++ b/synapse/util/caches/room_change_cache.py @@ -39,6 +39,8 @@ class RoomStreamChangeCache(object): caches_by_name[self.name] = self._cache def get_room_has_changed(self, room_id, key): + assert type(key) is int + if key <= self._earliest_known_key: return True @@ -55,6 +57,8 @@ class RoomStreamChangeCache(object): """Returns subset of room ids that have had new things since the given key. If the key is too old it will just return the given list. """ + assert type(key) is int + if key > self._earliest_known_key: keys = self._cache.keys() i = keys.bisect_right(key) @@ -73,6 +77,8 @@ class RoomStreamChangeCache(object): def room_has_changed(self, room_id, key): """Informs the cache that the room has been changed at the given key. 
""" + assert type(key) is int + if key > self._earliest_known_key: old_key = self._room_to_key.get(room_id, None) if old_key: -- cgit 1.4.1 From 00cb3eb24b277bb37bd1b7d8449c08a37cb4b014 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 16:37:41 +0000 Subject: Cache tags and account data --- synapse/storage/account_data.py | 20 ++++++- synapse/storage/events.py | 2 +- synapse/storage/receipts.py | 8 +-- synapse/storage/stream.py | 8 +-- synapse/storage/tags.py | 14 +++++ synapse/util/caches/room_change_cache.py | 92 ----------------------------- synapse/util/caches/stream_change_cache.py | 95 ++++++++++++++++++++++++++++++ 7 files changed, 137 insertions(+), 102 deletions(-) delete mode 100644 synapse/util/caches/room_change_cache.py create mode 100644 synapse/util/caches/stream_change_cache.py (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 9c6597e012..95294c3f6c 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -14,6 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore +from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer import ujson as json @@ -23,6 +24,13 @@ logger = logging.getLogger(__name__) class AccountDataStore(SQLBaseStore): + def __init__(self, hs): + super(AccountDataStore, self).__init__(hs) + + self._account_data_stream_cache = StreamChangeCache( + "AccountDataChangeCache", self._account_data_id_gen.get_max_token(None), + max_size=1000, + ) def get_account_data_for_user(self, user_id): """Get all the client account_data for a user. @@ -83,7 +91,7 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id): + def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): """Get all the client account_data for a that's changed. 
Args: @@ -120,6 +128,12 @@ class AccountDataStore(SQLBaseStore): return (global_account_data, account_data_by_room) + changed = self._account_data_stream_cache.get_entity_has_changed( + user_id, int(stream_id) + ) + if not changed: + defer.returnValue(({}, {})) + return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -186,6 +200,10 @@ class AccountDataStore(SQLBaseStore): "content": content_json, } ) + txn.call_after( + self._account_data_stream_cache.entity_has_changed, + user_id, next_id, + ) self._update_max_stream_id(txn, next_id) with (yield self._account_data_id_gen.get_next(self)) as next_id: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 2d2270b297..5e85552029 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -212,7 +212,7 @@ class EventsStore(SQLBaseStore): if not backfilled: txn.call_after( - self._events_stream_cache.room_has_changed, + self._events_stream_cache.entity_has_changed, event.room_id, event.internal_metadata.stream_ordering, ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 5ffbfdec51..8068c73740 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -15,7 +15,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached -from synapse.util.caches.room_change_cache import RoomStreamChangeCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer @@ -30,7 +30,7 @@ class ReceiptsStore(SQLBaseStore): def __init__(self, hs): super(ReceiptsStore, self).__init__(hs) - self._receipts_stream_cache = RoomStreamChangeCache( + self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) ) @@ -77,7 +77,7 @@ class ReceiptsStore(SQLBaseStore): room_ids = set(room_ids) if from_key: - room_ids = yield 
self._receipts_stream_cache.get_rooms_changed( + room_ids = yield self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) @@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore): txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) txn.call_after( - self._receipts_stream_cache.room_has_changed, + self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 6a724193e1..c7d7893328 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -37,7 +37,7 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks -from synapse.util.caches.room_change_cache import RoomStreamChangeCache +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function @@ -81,7 +81,7 @@ class StreamStore(SQLBaseStore): def __init__(self, hs): super(StreamStore, self).__init__(hs) - self._events_stream_cache = RoomStreamChangeCache( + self._events_stream_cache = StreamChangeCache( "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None) ) @@ -168,7 +168,7 @@ class StreamStore(SQLBaseStore): def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0): from_id = RoomStreamToken.parse_stream_token(from_key).stream - room_ids = yield self._events_stream_cache.get_rooms_changed( + room_ids = yield self._events_stream_cache.get_entities_changed( room_ids, from_id ) @@ -200,7 +200,7 @@ class StreamStore(SQLBaseStore): defer.returnValue(([], from_key)) if from_id: - has_changed = yield self._events_stream_cache.get_room_has_changed( + has_changed = yield self._events_stream_cache.get_entity_has_changed( room_id, from_id ) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 4c39e07cbd..50af899192 100644 --- 
a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -15,6 +15,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached +from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer import ujson as json @@ -24,6 +25,13 @@ logger = logging.getLogger(__name__) class TagsStore(SQLBaseStore): + def __init__(self, hs): + super(TagsStore, self).__init__(hs) + + self._tags_stream_cache = StreamChangeCache( + "TagsChangeCache", self._account_data_id_gen.get_max_token(None), + max_size=1000, + ) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream @@ -80,6 +88,10 @@ class TagsStore(SQLBaseStore): room_ids = [row[0] for row in txn.fetchall()] return room_ids + changed = self._tags_stream_cache.get_entity_has_changed(user_id, int(stream_id)) + if not changed: + defer.returnValue({}) + room_ids = yield self.runInteraction( "get_updated_tags", get_updated_tags_txn ) @@ -177,6 +189,8 @@ class TagsStore(SQLBaseStore): next_id(int): The the revision to advance to. """ + txn.call_after(self._tags_stream_cache.entity_has_changed, user_id, next_id) + update_max_id_sql = ( "UPDATE account_data_max_stream_id" " SET stream_id = ?" diff --git a/synapse/util/caches/room_change_cache.py b/synapse/util/caches/room_change_cache.py deleted file mode 100644 index e8bfedd72f..0000000000 --- a/synapse/util/caches/room_change_cache.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.caches import cache_counter, caches_by_name - - -from blist import sorteddict -import logging - - -logger = logging.getLogger(__name__) - - -class RoomStreamChangeCache(object): - """Keeps track of the stream_id of the latest change in rooms. - - Given a list of rooms and stream key, it will give a subset of rooms that - may have changed since that key. If the key is too old then the cache - will simply return all rooms. - """ - def __init__(self, name, current_key, size_of_cache=10000): - self._size_of_cache = size_of_cache - self._room_to_key = {} - self._cache = sorteddict() - self._earliest_known_key = current_key - self.name = name - caches_by_name[self.name] = self._cache - - def get_room_has_changed(self, room_id, key): - assert type(key) is int - - if key <= self._earliest_known_key: - return True - - room_key = self._room_to_key.get(room_id, None) - if room_key is None: - return True - - if key < room_key: - return True - - return False - - def get_rooms_changed(self, room_ids, key): - """Returns subset of room ids that have had new things since the - given key. If the key is too old it will just return the given list. - """ - assert type(key) is int - - if key > self._earliest_known_key: - keys = self._cache.keys() - i = keys.bisect_right(key) - - result = set( - self._cache[k] for k in keys[i:] - ).intersection(room_ids) - - cache_counter.inc_hits(self.name) - else: - result = room_ids - cache_counter.inc_misses(self.name) - - return result - - def room_has_changed(self, room_id, key): - """Informs the cache that the room has been changed at the given key. 
- """ - assert type(key) is int - - if key > self._earliest_known_key: - old_key = self._room_to_key.get(room_id, None) - if old_key: - key = max(key, old_key) - self._cache.pop(old_key, None) - self._cache[key] = room_id - - while len(self._cache) > self._size_of_cache: - k, r = self._cache.popitem() - self._earliest_key = max(k, self._earliest_key) - self._room_to_key.pop(r, None) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py new file mode 100644 index 0000000000..33b37f7f29 --- /dev/null +++ b/synapse/util/caches/stream_change_cache.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.caches import cache_counter, caches_by_name + + +from blist import sorteddict +import logging + + +logger = logging.getLogger(__name__) + + +class StreamChangeCache(object): + """Keeps track of the stream positions of the latest change in a set of entities. + + Typically the entity will be a room or user id. + + Given a list of entities and a stream position, it will give a subset of + entities that may have changed since that position. If position key is too + old then the cache will simply return all given entities. 
+ """ + def __init__(self, name, current_stream_pos, max_size=10000): + self._max_size = max_size + self._entity_to_key = {} + self._cache = sorteddict() + self._earliest_known_stream_pos = current_stream_pos + self.name = name + caches_by_name[self.name] = self._cache + + def get_entity_has_changed(self, entity, stream_pos): + assert type(stream_pos) is int + + if stream_pos <= self._earliest_known_stream_pos: + return True + + latest_entity_change_pos = self._entity_to_key.get(entity, None) + if latest_entity_change_pos is None: + return True + + if stream_pos < latest_entity_change_pos: + return True + + return False + + def get_entities_changed(self, entities, stream_pos): + """Returns subset of entities that have had new things since the + given position. If the position is too old it will just return the given list. + """ + assert type(stream_pos) is int + + if stream_pos > self._earliest_known_stream_pos: + keys = self._cache.keys() + i = keys.bisect_right(stream_pos) + + result = set( + self._cache[k] for k in keys[i:] + ).intersection(entities) + + cache_counter.inc_hits(self.name) + else: + result = entities + cache_counter.inc_misses(self.name) + + return result + + def entity_has_changed(self, entitiy, stream_pos): + """Informs the cache that the entitiy has been changed at the given + position. 
+ """ + assert type(stream_pos) is int + + if stream_pos > self._earliest_known_stream_pos: + old_pos = self._entity_to_key.get(entitiy, None) + if old_pos: + stream_pos = max(stream_pos, old_pos) + self._cache.pop(old_pos, None) + self._cache[stream_pos] = entitiy + + while len(self._cache) > self._max_size: + k, r = self._cache.popitem() + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + self._entity_to_key.pop(r, None) -- cgit 1.4.1 From 45cf827c8fe7163a51f1d0d7c9e2531da9b58c8d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 16:39:18 +0000 Subject: Change name and doc has_entity_changed --- synapse/storage/account_data.py | 2 +- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 2 +- synapse/util/caches/stream_change_cache.py | 4 +++- 4 files changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 95294c3f6c..62e49e1c0e 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -128,7 +128,7 @@ class AccountDataStore(SQLBaseStore): return (global_account_data, account_data_by_room) - changed = self._account_data_stream_cache.get_entity_has_changed( + changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c7d7893328..6e81d46c60 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -200,7 +200,7 @@ class StreamStore(SQLBaseStore): defer.returnValue(([], from_key)) if from_id: - has_changed = yield self._events_stream_cache.get_entity_has_changed( + has_changed = yield self._events_stream_cache.has_entity_changed( room_id, from_id ) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 50af899192..75ce04092d 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -88,7 +88,7 @@ class TagsStore(SQLBaseStore): room_ids 
= [row[0] for row in txn.fetchall()] return room_ids - changed = self._tags_stream_cache.get_entity_has_changed(user_id, int(stream_id)) + changed = self._tags_stream_cache.has_entity_changed(user_id, int(stream_id)) if not changed: defer.returnValue({}) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 33b37f7f29..3ca0e57780 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -40,7 +40,9 @@ class StreamChangeCache(object): self.name = name caches_by_name[self.name] = self._cache - def get_entity_has_changed(self, entity, stream_pos): + def has_entity_changed(self, entity, stream_pos): + """Returns True if the entity may have been updated since stream_pos + """ assert type(stream_pos) is int if stream_pos <= self._earliest_known_stream_pos: -- cgit 1.4.1 From fdca8ec4187e1b4ea93cbfe17ece6ca4cbadd519 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 16:41:59 +0000 Subject: Add events index --- synapse/storage/schema/delta/28/events_room_stream.sql | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 synapse/storage/schema/delta/28/events_room_stream.sql (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/schema/delta/28/events_room_stream.sql new file mode 100644 index 0000000000..200c35e6e2 --- /dev/null +++ b/synapse/storage/schema/delta/28/events_room_stream.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +CREATE INDEX events_room_stream on events(room_id, stream_ordering); -- cgit 1.4.1 From 8fe8951a8d997a652d16758a7abed2b8afd117e2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 17:09:09 +0000 Subject: Cache filters --- synapse/storage/filtering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index f8fc9bdddc..5248736816 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -16,12 +16,13 @@ from twisted.internet import defer from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cachedInlineCallbacks import simplejson as json class FilteringStore(SQLBaseStore): - @defer.inlineCallbacks + @cachedInlineCallbacks(num_args=2) def get_user_filter(self, user_localpart, filter_id): def_json = yield self._simple_select_one_onecol( table="user_filters", -- cgit 1.4.1 From 03b2c2577cbf51c80a42319689e1cb4903b8c4af Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 17:29:24 +0000 Subject: Don't use defer.returnValue --- synapse/storage/account_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 62e49e1c0e..88404059e8 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -132,7 +132,7 @@ class AccountDataStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - defer.returnValue(({}, {})) + return 
({}, {}) return self.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn -- cgit 1.4.1 From 467c27a1f90b873d6838ad1351399551cfa9cc07 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2016 18:20:00 +0000 Subject: Amalgamate tags and account data stream caches --- synapse/storage/account_data.py | 3 ++- synapse/storage/tags.py | 18 +++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 88404059e8..822c8bbe00 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -28,7 +28,8 @@ class AccountDataStore(SQLBaseStore): super(AccountDataStore, self).__init__(hs) self._account_data_stream_cache = StreamChangeCache( - "AccountDataChangeCache", self._account_data_id_gen.get_max_token(None), + "AccountDataAndTagsChangeCache", + self._account_data_id_gen.get_max_token(None), max_size=1000, ) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 75ce04092d..e1a9c0c261 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -15,7 +15,6 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached -from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer import ujson as json @@ -25,14 +24,6 @@ logger = logging.getLogger(__name__) class TagsStore(SQLBaseStore): - def __init__(self, hs): - super(TagsStore, self).__init__(hs) - - self._tags_stream_cache = StreamChangeCache( - "TagsChangeCache", self._account_data_id_gen.get_max_token(None), - max_size=1000, - ) - def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream @@ -88,7 +79,9 @@ class TagsStore(SQLBaseStore): room_ids = [row[0] for row in txn.fetchall()] return room_ids - changed = self._tags_stream_cache.has_entity_changed(user_id, int(stream_id)) + changed = 
self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) if not changed: defer.returnValue({}) @@ -189,7 +182,10 @@ class TagsStore(SQLBaseStore): next_id(int): The the revision to advance to. """ - txn.call_after(self._tags_stream_cache.entity_has_changed, user_id, next_id) + txn.call_after( + self._account_data_stream_cache.entity_has_changed, + user_id, next_id + ) update_max_id_sql = ( "UPDATE account_data_max_stream_id" -- cgit 1.4.1 From ebc5f00efed5c7f72601933f55032947077c50a0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 13:37:40 +0000 Subject: Bump AccountDataAndTagsChangeCache size --- synapse/storage/account_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 822c8bbe00..ed6587429b 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -30,7 +30,7 @@ class AccountDataStore(SQLBaseStore): self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", self._account_data_id_gen.get_max_token(None), - max_size=1000, + max_size=10000, ) def get_account_data_for_user(self, user_id): -- cgit 1.4.1 From 18579534ea67f2d98c189e2ddeccc4bfecb491eb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 14:37:59 +0000 Subject: Prefill stream change caches --- synapse/storage/__init__.py | 49 +++++++++++++++++++++++++++++- synapse/storage/account_data.py | 9 ------ synapse/storage/stream.py | 8 ----- synapse/util/caches/stream_change_cache.py | 5 ++- 4 files changed, 52 insertions(+), 19 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index eb88842308..95ae97d507 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -45,9 +45,10 @@ from .search import SearchStore from .tags import TagsStore from .account_data import AccountDataStore - from 
util.id_generators import IdGenerator, StreamIdGenerator +from synapse.util.caches.stream_change_cache import StreamChangeCache + import logging @@ -117,8 +118,54 @@ class DataStore(RoomMemberStore, RoomStore, self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + events_max = self._stream_id_gen.get_max_token(None) + event_cache_prefill = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", events_max, + prefilled_cache=event_cache_prefill, + ) + + account_max = self._account_data_id_gen.get_max_token(None) + account_cache_prefill = self._get_cache_dict( + db_conn, "account_data", + entity_column="user_id", + stream_column="stream_id", + max_value=account_max, + ) + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + prefilled_cache=account_cache_prefill, + ) + super(DataStore, self).__init__(hs) + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > max(? - 100000, 0)" + " GROUP BY %(entity)s" + " ORDER BY MAX(%(stream)s) DESC" + " LIMIT 10000" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + } + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + rows = txn.fetchall() + + return { + row[0]: row[1] + for row in rows + } + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index ed6587429b..625d062eb1 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -14,7 +14,6 @@ # limitations under the License. 
from ._base import SQLBaseStore -from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer import ujson as json @@ -24,14 +23,6 @@ logger = logging.getLogger(__name__) class AccountDataStore(SQLBaseStore): - def __init__(self, hs): - super(AccountDataStore, self).__init__(hs) - - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", - self._account_data_id_gen.get_max_token(None), - max_size=10000, - ) def get_account_data_for_user(self, user_id): """Get all the client account_data for a user. diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 6e81d46c60..e245d2f914 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -37,7 +37,6 @@ from twisted.internet import defer from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks -from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logutils import log_function @@ -78,13 +77,6 @@ def upper_bound(token): class StreamStore(SQLBaseStore): - def __init__(self, hs): - super(StreamStore, self).__init__(hs) - - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None) - ) - @defer.inlineCallbacks def get_appservice_room_stream(self, service, from_key, to_key, limit=0): # NB this lives here instead of appservice.py so we can reuse the diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c673b1bdfc..891cb619fa 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -32,7 +32,7 @@ class StreamChangeCache(object): entities that may have changed since that position. If position key is too old then the cache will simply return all given entities. 
""" - def __init__(self, name, current_stream_pos, max_size=10000): + def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}): self._max_size = max_size self._entity_to_key = {} self._cache = sorteddict() @@ -40,6 +40,9 @@ class StreamChangeCache(object): self.name = name caches_by_name[self.name] = self._cache + for entity, stream_pos in prefilled_cache.items(): + self.entity_has_changed(entity, stream_pos) + def has_entity_changed(self, entity, stream_pos): """Returns True if the entity may have been updated since stream_pos """ -- cgit 1.4.1 From f67d60496a8a9b2c95fcacb6d4c539a1d4b6a105 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 14:41:16 +0000 Subject: Convert param style --- synapse/storage/__init__.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 95ae97d507..2ed505cb1e 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -85,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore, def __init__(self, db_conn, hs): self.hs = hs + self.database_engine = hs.database_engine cur = db_conn.cursor() try: @@ -157,6 +158,8 @@ class DataStore(RoomMemberStore, RoomStore, "stream": stream_column, } + sql = self.database_engine.convert_param_style(sql) + txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() -- cgit 1.4.1 From 45488e0ffae5100c3a82568642736aff203e1602 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 14:42:01 +0000 Subject: Max is not a function --- synapse/storage/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 2ed505cb1e..4d374a8b07 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -148,7 +148,7 @@ class DataStore(RoomMemberStore, RoomStore, def _get_cache_dict(self, db_conn, table, 
entity_column, stream_column, max_value): sql = ( "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > max(? - 100000, 0)" + " WHERE %(stream)s > ? - 100000" " GROUP BY %(entity)s" " ORDER BY MAX(%(stream)s) DESC" " LIMIT 10000" -- cgit 1.4.1 From 3d60686c0ceeb88c4f6269110e92dc0c7bf5a3b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 14:49:11 +0000 Subject: Actually use cache --- synapse/storage/__init__.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4d374a8b07..957fff3c23 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -120,26 +120,26 @@ class DataStore(RoomMemberStore, RoomStore, self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) events_max = self._stream_id_gen.get_max_token(None) - event_cache_prefill = self._get_cache_dict( + event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", stream_column="stream_ordering", max_value=events_max, ) self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", events_max, + "EventsRoomStreamChangeCache", min_event_val, prefilled_cache=event_cache_prefill, ) account_max = self._account_data_id_gen.get_max_token(None) - account_cache_prefill = self._get_cache_dict( + account_cache_prefill, min_acc_val = self._get_cache_dict( db_conn, "account_data", entity_column="user_id", stream_column="stream_id", max_value=account_max, ) self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, + "AccountDataAndTagsChangeCache", min_acc_val, prefilled_cache=account_cache_prefill, ) @@ -151,7 +151,6 @@ class DataStore(RoomMemberStore, RoomStore, " WHERE %(stream)s > ? 
- 100000" " GROUP BY %(entity)s" " ORDER BY MAX(%(stream)s) DESC" - " LIMIT 10000" ) % { "table": table, "entity": entity_column, @@ -164,11 +163,18 @@ class DataStore(RoomMemberStore, RoomStore, txn.execute(sql, (int(max_value),)) rows = txn.fetchall() - return { - row[0]: row[1] + cache = { + row[0]: int(row[1]) for row in rows } + if cache: + min_val = min(cache.values()) + else: + min_val = max_value + + return cache, min_val + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) -- cgit 1.4.1 From b5dbced9389d072d4bd15002c7ddffba9e54340e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 14:53:59 +0000 Subject: Don't prefill account data --- synapse/storage/__init__.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 957fff3c23..a6cb588563 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -132,15 +132,8 @@ class DataStore(RoomMemberStore, RoomStore, ) account_max = self._account_data_id_gen.get_max_token(None) - account_cache_prefill, min_acc_val = self._get_cache_dict( - db_conn, "account_data", - entity_column="user_id", - stream_column="stream_id", - max_value=account_max, - ) self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", min_acc_val, - prefilled_cache=account_cache_prefill, + "AccountDataAndTagsChangeCache", account_max, ) super(DataStore, self).__init__(hs) -- cgit 1.4.1 From 8da95b6f1bb1a37597f0b89c4da88b064401b0b8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 15:39:17 +0000 Subject: Comment. 
Remove superfluous order by --- synapse/storage/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a6cb588563..ee2153737d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -139,11 +139,13 @@ class DataStore(RoomMemberStore, RoomStore, super(DataStore, self).__init__(hs) def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. sql = ( "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" " WHERE %(stream)s > ? - 100000" " GROUP BY %(entity)s" - " ORDER BY MAX(%(stream)s) DESC" ) % { "table": table, "entity": entity_column, -- cgit 1.4.1 From cc9c97e0dc0cf399d5d6013f12746063091b619e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 16:41:51 +0000 Subject: Invalidate _account_data_stream_cache correctly --- synapse/storage/account_data.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 625d062eb1..b8387fc500 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -157,6 +157,10 @@ class AccountDataStore(SQLBaseStore): "content": content_json, } ) + txn.call_after( + self._account_data_stream_cache.entity_has_changed, + user_id, next_id, + ) self._update_max_stream_id(txn, next_id) with (yield self._account_data_id_gen.get_next(self)) as next_id: -- cgit 1.4.1 From 25c311eaf603cef8cbf9e6501aad83d53c304ebb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2016 16:52:48 +0000 Subject: Cache get_room_changes_for_user --- synapse/storage/__init__.py | 4 ++++ synapse/storage/roommember.py | 4 ++++ 
synapse/storage/stream.py | 7 +++++++ 3 files changed, 15 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index ee2153737d..c91c7a3729 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -131,6 +131,10 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=event_cache_prefill, ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, + ) + account_max = self._account_data_id_gen.get_max_token(None) self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 1d3e004c90..3065b0c1a5 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore): txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, event.internal_metadata.stream_ordering + ) def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index e245d2f914..cc9e623608 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -252,6 +252,13 @@ class StreamStore(SQLBaseStore): if from_key == to_key: return defer.succeed([]) + if from_id: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_id) + ) + if not has_changed: + return defer.succeed([]) + def f(txn): if from_id is not None: sql = ( -- cgit 1.4.1 From ceb6b8680a8e419c3c132dcdb675517c5e9f69fd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 1 Feb 2016 10:33:52 +0000 Subject: Only use room_ids if in get_room_events_stream if is_guest --- synapse/storage/stream.py | 5 ----- 1 file changed, 5 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index e245d2f914..a60e662f7d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -316,11 +316,6 @@ class StreamStore(SQLBaseStore): " WHERE m.user_id = ? AND m.membership = 'join'" ) current_room_membership_args = [user_id] - if room_ids: - current_room_membership_sql += " AND m.room_id in (%s)" % ( - ",".join(map(lambda _: "?", room_ids)) - ) - current_room_membership_args = [user_id] + room_ids # We also want to get any membership events about that user, e.g. # invites or leave notifications. 
-- cgit 1.4.1 From 4bf448be254808c83aeb5ae28e601752664bc9e2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 1 Feb 2016 16:26:51 +0000 Subject: Switch over /events to use per room caches --- synapse/handlers/room.py | 25 ++++++++++++++++++++----- synapse/storage/stream.py | 4 ++-- 2 files changed, 22 insertions(+), 7 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 58e2d25f97..aca795e1c4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1008,15 +1008,30 @@ class RoomEventSource(object): limit=limit, ) else: - events, end_key = yield self.store.get_room_events_stream( - user_id=user.to_string(), + room_events = yield self.store.get_room_changes_for_user( + user.to_string(), from_key, to_key + ) + + room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_ids=room_ids, from_key=from_key, to_key=to_key, - limit=limit, - room_ids=room_ids, - is_guest=is_guest, + limit=limit or 10, ) + events = list(room_events) + events.extend(e for evs, _ in room_to_events.values() for e in evs) + + events.sort(key=lambda e: e.internal_metadata.after) + + if limit: + events[:] = events[:limit] + + if events: + end_key = events[-1].internal_metadata.after + else: + end_key = to_key + defer.returnValue((events, end_key)) def get_current_key(self, direction='f'): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 8dc8f5c640..fd84aa8996 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -39,7 +39,6 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken -from synapse.util.logutils import log_function import logging @@ -288,11 +287,12 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(ret, rows, topo_order=False) + return ret return 
self.runInteraction("get_room_changes_for_user", f) - @log_function def get_room_events_stream( self, user_id, -- cgit 1.4.1 From 89b40b225cda4326081f6735b2a8a9bff5ce3446 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 1 Feb 2016 16:32:46 +0000 Subject: Order things correctly --- synapse/handlers/room.py | 2 +- synapse/storage/stream.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index aca795e1c4..a71cba8ef1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1022,7 +1022,7 @@ class RoomEventSource(object): events = list(room_events) events.extend(e for evs, _ in room_to_events.values() for e in evs) - events.sort(key=lambda e: e.internal_metadata.after) + events.sort(key=lambda e: e.internal_metadata.order) if limit: events[:] = events[:limit] diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index fd84aa8996..a03458c2fc 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -598,6 +598,10 @@ class StreamStore(SQLBaseStore): internal = event.internal_metadata internal.before = str(RoomStreamToken(topo, stream - 1)) internal.after = str(RoomStreamToken(topo, stream)) + internal.order = ( + int(topo) if topo else 0, + int(stream), + ) @defer.inlineCallbacks def get_events_around(self, room_id, event_id, before_limit, after_limit): -- cgit 1.4.1 From 65e92eca4912848b03f71b7b7d29727015be31ce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Feb 2016 15:19:34 +0000 Subject: Change the way we do public room list fetching --- synapse/handlers/room.py | 86 ++++++++++++++++------ synapse/storage/room.py | 2 +- .../storage/schema/delta/28/public_roms_index.sql | 16 ++++ 3 files changed, 80 insertions(+), 24 deletions(-) create mode 100644 synapse/storage/schema/delta/28/public_roms_index.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 
a71cba8ef1..1b3f624c67 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -876,31 +876,71 @@ class RoomListHandler(BaseHandler): @defer.inlineCallbacks def get_public_room_list(self): - chunk = yield self.store.get_rooms(is_public=True) - - room_members = yield defer.gatherResults( - [ - self.store.get_users_in_room(room["room_id"]) - for room in chunk - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - avatar_urls = yield defer.gatherResults( - [ - self.get_room_avatar_url(room["room_id"]) - for room in chunk - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - for i, room in enumerate(chunk): - room["num_joined_members"] = len(room_members[i]) - if avatar_urls[i]: - room["avatar_url"] = avatar_urls[i] + room_ids = yield self.store.get_public_room_ids() + + @defer.inlineCallbacks + def handle_room(room_id): + aliases = yield self.store.get_aliases_for_room(room_id) + if not aliases: + defer.returnValue(None) + + state = yield self.state_handler.get_current_state(room_id) + + result = {"aliases": aliases, "room_id": room_id} + + name_event = state.get((EventTypes.Name, ""), None) + if name_event: + name = name_event.content.get("name", None) + if name: + result["name"] = name + + topic_event = state.get((EventTypes.Topic, ""), None) + if topic_event: + topic = topic_event.content.get("topic", None) + if topic: + result["topic"] = topic + + canonical_event = state.get((EventTypes.CanonicalAlias, ""), None) + if canonical_event: + canonical_alias = canonical_event.content.get("alias", None) + if canonical_alias: + result["canonical_alias"] = canonical_alias + + visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) + visibility = None + if visibility_event: + visibility = visibility_event.content.get("history_visibility", None) + result["world_readable"] = visibility == "world_readable" + + guest_event = state.get((EventTypes.GuestAccess, ""), None) + guest = None + if guest_event: + guest = 
guest_event.content.get("guest_access", None) + result["guest_can_join"] = guest == "can_join" + + avatar_event = state.get(("m.room.avatar", ""), None) + if avatar_event: + avatar_url = avatar_event.content.get("url", None) + if avatar_url: + result["avatar_url"] = avatar_url + + result["num_joined_members"] = sum( + 1 for (event_type, _), ev in state.items() + if event_type == EventTypes.Member and ev.membership == Membership.JOIN + ) + + defer.returnValue(result) + + result = [] + for chunk in (room_ids[i:i+10] for i in xrange(0, len(room_ids), 10)): + chunk_result = yield defer.gatherResults([ + handle_room(room_id) + for room_id in chunk + ], consumeErrors=True).addErrback(unwrapFirstError) + result.extend(v for v in chunk_result if v) # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) + defer.returnValue({"start": "START", "end": "END", "chunk": result}) @defer.inlineCallbacks def get_room_avatar_url(self, room_id): diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dc09a3aaba..1b6311f332 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cachedInlineCallbacks, cached from .engines import PostgresEngine, Sqlite3Engine import collections diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/schema/delta/28/public_roms_index.sql new file mode 100644 index 0000000000..ba62a974a4 --- /dev/null +++ b/synapse/storage/schema/delta/28/public_roms_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +CREATE INDEX public_room_index on rooms(is_public); -- cgit 1.4.1 From 477b1ed6cfd130e5a004cda0c0b84509da2aa006 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Feb 2016 15:58:14 +0000 Subject: Fetch events in a separate transaction. This has a couple of benefits: - It reduces the time of transactions, allowing other database requests to run. - Fetching events is given a dedicated database thread, and so can't starve other database requests. --- synapse/storage/stream.py | 55 +++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 26 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index a03458c2fc..bcae3d718e 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -220,27 +220,29 @@ class StreamStore(SQLBaseStore): rows = self.cursor_to_dict(txn) - ret = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) + return rows - self._set_before_and_after(ret, rows, topo_order=False) + rows = yield self.runInteraction("get_room_events_stream_for_room", f) - ret.reverse() + ret = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) - if rows: - key = "s%d" % min(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. 
- key = from_key + self._set_before_and_after(ret, rows, topo_order=False) - return ret, key - res = yield self.runInteraction("get_room_events_stream_for_room", f) - defer.returnValue(res) + ret.reverse() + if rows: + key = "s%d" % min(r["stream_ordering"] for r in rows) + else: + # Assume we didn't get anything because there was nothing to + # get. + key = from_key + + defer.returnValue((ret, key)) + + @defer.inlineCallbacks def get_room_changes_for_user(self, user_id, from_key, to_key): if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream @@ -249,14 +251,14 @@ class StreamStore(SQLBaseStore): to_id = RoomStreamToken.parse_stream_token(to_key).stream if from_key == to_key: - return defer.succeed([]) + defer.returnValue([]) if from_id: has_changed = self._membership_stream_cache.has_entity_changed( user_id, int(from_id) ) if not has_changed: - return defer.succeed([]) + defer.returnValue([]) def f(txn): if from_id is not None: @@ -281,17 +283,18 @@ class StreamStore(SQLBaseStore): txn.execute(sql, (user_id, to_id,)) rows = self.cursor_to_dict(txn) - ret = self._get_events_txn( - txn, - [r["event_id"] for r in rows], - get_prev_content=True - ) + return rows + + rows = yield self.runInteraction("get_room_changes_for_user", f) - self._set_before_and_after(ret, rows, topo_order=False) + ret = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) - return ret + self._set_before_and_after(ret, rows, topo_order=False) - return self.runInteraction("get_room_changes_for_user", f) + defer.returnValue(ret) def get_room_events_stream( self, -- cgit 1.4.1 From 8a391e33ae21f9a62c57cca8eea47435a14a6247 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 2 Feb 2016 16:12:10 +0000 Subject: s/get_room_changes_for_user/get_membership_changes_for_user/ --- synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/storage/stream.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) (limited to 
'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 68e2c75a48..799221c198 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1013,7 +1013,7 @@ class RoomEventSource(object): limit=limit, ) else: - room_events = yield self.store.get_room_changes_for_user( + room_events = yield self.store.get_membership_changes_for_user( user.to_string(), from_key, to_key ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 8d8d10da33..dc686db541 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -479,7 +479,7 @@ class SyncHandler(BaseHandler): ) # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_room_changes_for_user( + rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index bcae3d718e..338a9d40d5 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -243,7 +243,7 @@ class StreamStore(SQLBaseStore): defer.returnValue((ret, key)) @defer.inlineCallbacks - def get_room_changes_for_user(self, user_id, from_key, to_key): + def get_membership_changes_for_user(self, user_id, from_key, to_key): if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream else: @@ -285,7 +285,7 @@ class StreamStore(SQLBaseStore): return rows - rows = yield self.runInteraction("get_room_changes_for_user", f) + rows = yield self.runInteraction("get_membership_changes_for_user", f) ret = yield self._get_events( [r["event_id"] for r in rows], -- cgit 1.4.1 From d83d004ccdb7ace1dcb51b8acf7645bc176b10a5 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 2 Feb 2016 17:18:50 +0000 Subject: Fix flake8 warnings for new flake8 --- setup.cfg | 1 + synapse/api/auth.py | 2 +- synapse/app/__init__.py | 19 ++++++++++++++++ synapse/app/homeserver.py | 38 +++++++++----------------------- 
synapse/appservice/api.py | 2 +- synapse/federation/federation_client.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/directory.py | 4 ++-- synapse/handlers/events.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/http/matrixfederationclient.py | 2 +- synapse/notifier.py | 2 +- synapse/push/push_rule_evaluator.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v1/pusher.py | 4 ++-- synapse/rest/client/v1/register.py | 3 ++- synapse/rest/client/v2_alpha/register.py | 3 ++- synapse/rest/client/versions.py | 4 +--- synapse/server.py | 2 +- synapse/state.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 7 ++++-- synapse/storage/engines/sqlite3.py | 2 +- synapse/storage/event_federation.py | 2 +- synapse/storage/events.py | 6 ++--- synapse/storage/stream.py | 2 +- synapse/util/__init__.py | 2 +- synapse/util/caches/descriptors.py | 4 ++-- synapse/util/caches/expiringcache.py | 2 +- synapse/util/caches/treecache.py | 2 +- synapse/util/logutils.py | 2 +- synapse/util/ratelimitutils.py | 2 +- 34 files changed, 73 insertions(+), 66 deletions(-) (limited to 'synapse/storage') diff --git a/setup.cfg b/setup.cfg index ba027c7d13..e7fc5ffe78 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,3 +16,4 @@ ignore = [flake8] max-line-length = 90 +ignore = W503 diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b5536e8565..c5a2865e26 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -574,7 +574,7 @@ class Auth(object): raise AuthError( 403, "Application service has not registered this user" - ) + ) defer.returnValue(user_id) @defer.inlineCallbacks diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index bfebb0f644..1bc4279807 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -12,3 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + +import sys +sys.dont_write_bytecode = True + +from synapse.python_dependencies import ( + check_requirements, MissingRequirementError +) # NOQA + +try: + check_requirements() +except MissingRequirementError as e: + message = "\n".join([ + "Missing Requirement: %s" % (e.message,), + "To install run:", + " pip install --upgrade --force \"%s\"" % (e.dependency,), + "", + ]) + sys.stderr.writelines(message) + sys.exit(1) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e5066c48ef..c3066d6a0d 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -14,27 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import synapse + +import contextlib +import logging +import os +import re +import resource +import subprocess import sys -from synapse.rest import ClientRestResource +import time -sys.dont_write_bytecode = True from synapse.python_dependencies import ( - check_requirements, DEPENDENCY_LINKS, MissingRequirementError + check_requirements, DEPENDENCY_LINKS ) -if __name__ == '__main__': - try: - check_requirements() - except MissingRequirementError as e: - message = "\n".join([ - "Missing Requirement: %s" % (e.message,), - "To install run:", - " pip install --upgrade --force \"%s\"" % (e.dependency,), - "", - ]) - sys.stderr.writelines(message) - sys.exit(1) - +from synapse.rest import ClientRestResource from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage import are_all_users_on_domain from synapse.storage.prepare_database import UpgradeDatabaseException @@ -73,17 +68,6 @@ from synapse import events from daemonize import Daemonize -import synapse - -import contextlib -import logging -import os -import re -import resource -import subprocess -import time - - logger = logging.getLogger("synapse.app.homeserver") diff --git 
a/synapse/appservice/api.py b/synapse/appservice/api.py index e1c07028e8..bc90605324 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -29,7 +29,7 @@ class ApplicationServiceApi(SimpleHttpClient): pushing. """ - def __init__(self, hs): + def __init__(self, hs): super(ApplicationServiceApi, self).__init__(hs) self.clock = hs.get_clock() diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c6259f9dc8..e30e2da58d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -57,7 +57,7 @@ class FederationClient(FederationBase): cache_name="get_pdu_cache", clock=self._clock, max_len=1000, - expiry_ms=120*1000, + expiry_ms=120 * 1000, reset_expiry_on_get=False, ) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 744a9ee507..1423df6cf3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -147,7 +147,7 @@ class BaseHandler(object): ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000*(time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)), ) @defer.inlineCallbacks diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 691564c651..4efecb1ffd 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler): # If this server is in the list of servers, return it first. 
if self.server_name in servers: servers = ( - [self.server_name] - + [s for s in servers if s != self.server_name] + [self.server_name] + + [s for s in servers if s != self.server_name] ) else: servers = list(servers) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 254b483da6..5ad8f3779a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -130,7 +130,7 @@ class EventStreamHandler(BaseHandler): # Add some randomness to this value to try and mitigate against # thundering herds on restart. - timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) + timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) events, tokens = yield self.notifier.get_events_for( auth_user, pagin_config, timeout, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d36eb3b8d7..d0c21ff5c9 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__) # Don't bother bumping "last active" time if it differs by less than 60 seconds -LAST_ACTIVE_GRANULARITY = 60*1000 +LAST_ACTIVE_GRANULARITY = 60 * 1000 # Keep no more than this number of offline serial revisions MAX_OFFLINE_SERIALS = 1000 diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index abd1a16a41..b8fbcf9233 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -213,7 +213,7 @@ class RegistrationHandler(BaseHandler): 400, "User ID must only contain characters which do not" " require URL encoding." - ) + ) user = UserID(localpart, self.hs.hostname) user_id = user.to_string() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 799221c198..088b76d237 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -927,7 +927,7 @@ class RoomContextHandler(BaseHandler): Returns: dict, or None if the event isn't found """ - before_limit = math.floor(limit/2.) + before_limit = math.floor(limit / 2.) 
after_limit = limit - before_limit now_token = yield self.hs.get_event_sources().get_current_token() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index da13e32e78..c3589534f8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object): return self.clock.time_bound_deferred( request_deferred, - time_out=timeout/1000. if timeout else 60, + time_out=timeout / 1000. if timeout else 60, ) response = yield preserve_context_over_fn( diff --git a/synapse/notifier.py b/synapse/notifier.py index 29965a9ab5..1a90bd55cd 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -308,7 +308,7 @@ class Notifier(object): def timed_out(): if listener: listener.deferred.cancel() - timer = self.clock.call_later(timeout/1000., timed_out) + timer = self.clock.call_later(timeout / 1000., timed_out) prev_token = from_token while not result: diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index dca018af95..2a2b4437dc 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}): if isinstance(value, basestring): result[".".join(prefix + [key])] = value.lower() elif hasattr(value, "items"): - _flatten_dict(value, prefix=(prefix+[key]), result=result) + _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 07836709fb..7199113dac 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet): LoginRestServlet.SAML2_TYPE): relay_state = "" if "relay_state" in login_submission: - relay_state = "&RelayState="+urllib.quote( + relay_state = "&RelayState=" + urllib.quote( login_submission["relay_state"]) result = { "uri": "%s%s" % 
(self.idp_redirect_url, relay_state) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index e218ed215c..5547f1b112 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet): if i not in content: missing.append(i) if len(missing): - raise SynapseError(400, "Missing parameters: "+','.join(missing), + raise SynapseError(400, "Missing parameters: " + ','.join(missing), errcode=Codes.MISSING_PARAM) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) @@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet): data=content['data'] ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: "+pce.message, + raise SynapseError(400, "Config Error: " + pce.message, errcode=Codes.MISSING_PARAM) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 5378a9a938..2bfd4d96bf 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -38,7 +38,8 @@ logger = logging.getLogger(__name__) if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: - compare_digest = lambda a, b: a == b + def compare_digest(a, b): + return a == b class RegisterRestServlet(ClientV1RestServlet): diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 5d50dd9e3d..56a5bbec30 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: - compare_digest = lambda a, b: a == b + def compare_digest(a, b): + return a == b logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 349ef6b396..ca5468c402 100644 --- 
a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet): def on_GET(self, request): return (200, { - "versions": [ - "r0.0.1", - ] + "versions": ["r0.0.1"] }) diff --git a/synapse/server.py b/synapse/server.py index 5fee7fe130..368d615576 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -23,7 +23,7 @@ from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.enterprise import adbapi from synapse.federation import initialize_http_replication -from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory +from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier from synapse.api.auth import Auth from synapse.handlers import Handlers diff --git a/synapse/state.py b/synapse/state.py index 0acf309fe0..b9a1387520 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -63,7 +63,7 @@ class StateHandler(object): cache_name="state_cache", clock=self.clock, max_len=SIZE_OF_CACHE, - expiry_ms=EVICTION_TIMEOUT_SECONDS*1000, + expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, reset_expiry_on_get=True, ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index c91c7a3729..5a9e7720d9 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -59,7 +59,7 @@ logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. 
Smaller # times give more inserts into the database even for readonly API hits # 120 seconds == 2 minutes -LAST_SEEN_GRANULARITY = 120*1000 +LAST_SEEN_GRANULARITY = 120 * 1000 class DataStore(RoomMemberStore, RoomStore, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 5e77320540..cfb87d9328 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -185,7 +185,7 @@ class SQLBaseStore(object): time_then = self._previous_loop_ts self._previous_loop_ts = time_now - ratio = (curr - prev)/(time_now - time_then) + ratio = (curr - prev) / (time_now - time_then) top_three_counters = self._txn_perf_counters.interval( time_now - time_then, limit=3 @@ -643,7 +643,10 @@ class SQLBaseStore(object): if not iterable: defer.returnValue(results) - chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)] + chunks = [ + iterable[i:i + batch_size] + for i in xrange(0, len(iterable), batch_size) + ] for chunk in chunks: rows = yield self.runInteraction( desc, diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 400c10103c..91fac33b8b 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -54,7 +54,7 @@ class Sqlite3Engine(object): def _parse_match_info(buf): bufsize = len(buf) - return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] + return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)] def _rank(raw_match_info): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 5f32eec6f8..ce2c794025 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore): new_front = set() front_list = list(front) chunks = [ - front_list[x:x+100] + front_list[x:x + 100] for x in xrange(0, len(front), 100) ] for chunk in chunks: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 
5e85552029..4d7cdd00d0 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -84,7 +84,7 @@ class EventsStore(SQLBaseStore): event.internal_metadata.stream_ordering = stream chunks = [ - events_and_contexts[x:x+100] + events_and_contexts[x:x + 100] for x in xrange(0, len(events_and_contexts), 100) ] @@ -740,7 +740,7 @@ class EventsStore(SQLBaseStore): rows = [] N = 200 for i in range(1 + len(events) / N): - evs = events[i*N:(i + 1)*N] + evs = events[i * N:(i + 1) * N] if not evs: break @@ -755,7 +755,7 @@ class EventsStore(SQLBaseStore): " LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN redactions as r ON e.event_id = r.redacts" " WHERE e.event_id IN (%s)" - ) % (",".join(["?"]*len(evs)),) + ) % (",".join(["?"] * len(evs)),) txn.execute(sql, evs) rows.extend(self.cursor_to_dict(txn)) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 338a9d40d5..2c49a5e499 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -168,7 +168,7 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) - for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)): + for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): res = yield defer.gatherResults([ self.get_room_events_stream_for_room( room_id, from_key, to_key, limit diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index f1fe963adf..7566d9eb33 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -46,7 +46,7 @@ class Clock(object): def looping_call(self, f, msec): l = task.LoopingCall(f) - l.start(msec/1000.0, now=False) + l.start(msec / 1000.0, now=False) return l def stop_looping_call(self, loop): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 88e56e3302..e27917c63a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -149,7 +149,7 @@ class CacheDescriptor(object): self.lru = lru self.tree = tree - 
self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] if len(self.arg_names) < self.num_args: raise Exception( @@ -250,7 +250,7 @@ class CacheListDescriptor(object): self.num_args = num_args self.list_name = list_name - self.arg_names = inspect.getargspec(orig).args[1:num_args+1] + self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) self.cache = cache diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 494226f5ea..62cae99649 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -55,7 +55,7 @@ class ExpiringCache(object): def f(): self._prune_cache() - self._clock.looping_call(f, self._expiry_ms/2) + self._clock.looping_call(f, self._expiry_ms / 2) def __setitem__(self, key, value): now = self._clock.time_msec() diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 29d02f7e95..03bc1401b7 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -58,7 +58,7 @@ class TreeCache(object): if n: break - node_and_keys[i+1][0].pop(k) + node_and_keys[i + 1][0].pop(k) popped, cnt = _strip_and_count_entires(popped) self.size -= cnt diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index d5b1a37eff..c37a157787 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -111,7 +111,7 @@ def time_function(f): _log_debug_as_f( f, "[FUNC END] {%s-%d} %f", - (func_name, id, end-start,), + (func_name, id, end - start,), ) return r diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index c37d6f12e3..ea321bc6a9 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -163,7 +163,7 @@ class _PerHostRatelimiter(object): "Ratelimit [%s]: sleeping req", id(request_id), ) - ret_defer = sleep(self.sleep_msec/1000.0) + ret_defer = 
sleep(self.sleep_msec / 1000.0) self.sleeping_requests.add(request_id) -- cgit 1.4.1 From b32121a5d1c833ab3e2c164e642ef982364069e2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 10:30:56 +0000 Subject: Unused import --- synapse/storage/room.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 1b6311f332..dc09a3aaba 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks, cached +from synapse.util.caches.descriptors import cachedInlineCallbacks from .engines import PostgresEngine, Sqlite3Engine import collections -- cgit 1.4.1 From 771528ab1323715271b9e968d2d337b88910fb2f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 10:50:49 +0000 Subject: Change event_push_actions_rm_tokens schema --- synapse/handlers/sync.py | 6 +-- synapse/push/__init__.py | 2 +- synapse/storage/event_push_actions.py | 47 ++++++++++++++++-------- synapse/storage/prepare_database.py | 2 +- synapse/storage/schema/delta/29/push_actions.sql | 31 ++++++++++++++++ 5 files changed, 67 insertions(+), 21 deletions(-) create mode 100644 synapse/storage/schema/delta/29/push_actions.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dc686db541..0292e06733 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -706,10 +706,8 @@ class SyncHandler(BaseHandler): ) if notifs is not None: - unread_notifications["notification_count"] = len(notifs) - unread_notifications["highlight_count"] = len([ - 1 for notif in notifs if _action_has_highlight(notif["actions"]) - ]) + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = 
notifs["highlight_count"] logger.debug("Room sync: %r", room_sync) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 9bc0b356f4..8b9d0f03e5 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -316,7 +316,7 @@ class Pusher(object): r.room_id, self.user_id, last_unread_event_id ) ) - badge += len(notifs) + badge += notifs["notify_count"] defer.returnValue(badge) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index a05c4f84cf..aca3219206 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -37,7 +37,11 @@ class EventPushActionsStore(SQLBaseStore): 'event_id': event.event_id, 'user_id': uid, 'profile_tag': profile_tag, - 'actions': json.dumps(actions) + 'actions': json.dumps(actions), + 'stream_ordering': event.internal_metadata.stream_ordering, + 'topological_ordering': event.depth, + 'notif': 1, + 'highlight': 1 if _action_has_highlight(actions) else 0, }) def f(txn): @@ -74,26 +78,28 @@ class EventPushActionsStore(SQLBaseStore): topological_ordering = results[0][1] sql = ( - "SELECT ea.event_id, ea.actions" - " FROM event_push_actions ea, events e" - " WHERE ea.room_id = e.room_id" - " AND ea.event_id = e.event_id" - " AND ea.user_id = ?" - " AND ea.room_id = ?" + "SELECT sum(notif), sum(highlight)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" " AND (" - " e.topological_ordering > ?" - " OR (e.topological_ordering = ? AND e.stream_ordering > ?)" + " topological_ordering > ?" + " OR (topological_ordering = ? 
AND stream_ordering > ?)" ")" ) txn.execute(sql, ( user_id, room_id, topological_ordering, topological_ordering, stream_ordering - ) - ) - return [ - {"event_id": row[0], "actions": json.loads(row[1])} - for row in txn.fetchall() - ] + )) + row = txn.fetchone() + if row: + return { + "notify_count": row[0] or 0, + "highlight_count": row[1] or 0, + } + else: + return {"notify_count": 0, "highlight_count": 0} ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", @@ -117,3 +123,14 @@ class EventPushActionsStore(SQLBaseStore): "remove_push_actions_for_event_id", f ) + + +def _action_has_highlight(actions): + for action in actions: + try: + if action.get("set_tweak", None) == "highlight": + return action.get("value", True) + except AttributeError: + pass + + return False diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index c1f5f99789..d782b8e25b 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 28 +SCHEMA_VERSION = 29 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/schema/delta/29/push_actions.sql new file mode 100644 index 0000000000..7e7b09820a --- /dev/null +++ b/synapse/storage/schema/delta/29/push_actions.sql @@ -0,0 +1,31 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT; +ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT; +ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT; +ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT; + +UPDATE event_push_actions SET stream_ordering = ( + SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id +), topological_ordering = ( + SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id +); + +UPDATE event_push_actions SET notif = 1, highlight = 0; + +CREATE INDEX event_push_actions_rm_tokens on event_push_actions( + user_id, room_id, topological_ordering, stream_ordering +); -- cgit 1.4.1 From f8aae79a72e462f4af65a22d0665192867522174 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 13:23:32 +0000 Subject: Simplify get_rooms --- synapse/app/homeserver.py | 4 +-- synapse/storage/room.py | 84 ++++------------------------------------------ tests/storage/test_room.py | 26 -------------- 3 files changed, 9 insertions(+), 105 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index c3066d6a0d..0a6a19033d 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -674,8 +674,8 @@ def run(hs): stats["uptime_seconds"] = uptime stats["total_users"] = yield hs.get_datastore().count_all_users() - all_rooms = yield hs.get_datastore().get_rooms(False) - stats["total_room_count"] = len(all_rooms) + room_count = yield 
hs.get_datastore().get_room_count() + stats["total_room_count"] = room_count stats["daily_active_users"] = yield hs.get_datastore().count_daily_users() daily_messages = yield hs.get_datastore().count_daily_messages() diff --git a/synapse/storage/room.py b/synapse/storage/room.py index dc09a3aaba..46ab38a313 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore): desc="get_public_room_ids", ) - @defer.inlineCallbacks - def get_rooms(self, is_public): - """Retrieve a list of all public rooms. - - Args: - is_public (bool): True if the rooms returned should be public. - Returns: - A list of room dicts containing at least a "room_id" key, a - "topic" key if one is set, and a "name" key if one is set + def get_room_count(self): + """Retrieve a list of all rooms """ def f(txn): - def subquery(table_name, column_name=None): - column_name = column_name or table_name - return ( - "SELECT %(table_name)s.event_id as event_id, " - "%(table_name)s.room_id as room_id, %(column_name)s " - "FROM %(table_name)s " - "INNER JOIN current_state_events as c " - "ON c.event_id = %(table_name)s.event_id " % { - "column_name": column_name, - "table_name": table_name, - } - ) - - sql = ( - "SELECT" - " r.room_id," - " max(n.name)," - " max(t.topic)," - " max(v.history_visibility)," - " max(g.guest_access)" - " FROM rooms AS r" - " LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id" - " LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id" - " LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id" - " LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id" - " WHERE r.is_public = ?" 
- " GROUP BY r.room_id" % { - "topic": subquery("topics", "topic"), - "name": subquery("room_names", "name"), - "history_visibility": subquery("history_visibility"), - "guest_access": subquery("guest_access"), - } - ) - - txn.execute(sql, (is_public,)) - - rows = txn.fetchall() - - for i, row in enumerate(rows): - room_id = row[0] - aliases = self._simple_select_onecol_txn( - txn, - table="room_aliases", - keyvalues={ - "room_id": room_id - }, - retcol="room_alias", - ) + sql = "SELECT count(*) FROM rooms" + txn.execute(sql) + row = txn.fetchone() + return row[0] or 0 - rows[i] = list(row) + [aliases] - - return rows - - rows = yield self.runInteraction( + return self.runInteraction( "get_rooms", f ) - ret = [ - { - "room_id": r[0], - "name": r[1], - "topic": r[2], - "world_readable": r[3] == "world_readable", - "guest_can_join": r[4] == "can_join", - "aliases": r[5], - } - for r in rows - if r[5] # We only return rooms that have at least one alias. - ] - - defer.returnValue(ret) - def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: self._simple_insert_txn( diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 7fdbfc60f1..0baaf3df21 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -51,32 +51,6 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())) ) - @defer.inlineCallbacks - def test_get_rooms(self): - # get_rooms does an INNER JOIN on the room_aliases table :( - - rooms = yield self.store.get_rooms(is_public=True) - # Should be empty before we add the alias - self.assertEquals([], rooms) - - yield self.store.create_room_alias_association( - room_alias=self.alias, - room_id=self.room.to_string(), - servers=["test"] - ) - - rooms = yield self.store.get_rooms(is_public=True) - - self.assertEquals(1, len(rooms)) - self.assertEquals({ - "name": None, - "room_id": self.room.to_string(), - "topic": None, - "aliases": 
[self.alias.to_string()], - "world_readable": False, - "guest_can_join": False, - }, rooms[0]) - class RoomEventsStoreTestCase(unittest.TestCase): -- cgit 1.4.1 From b84d59c5f01914fe53d2673c5c7e372f5c61d088 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 16:22:35 +0000 Subject: Add descriptions --- synapse/storage/appservice.py | 3 ++- synapse/storage/keys.py | 1 + synapse/storage/registration.py | 1 + synapse/storage/stream.py | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index b5aa55c0a3..1100c67714 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -276,7 +276,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore): "application_services_state", dict(as_id=service.id), ["state"], - allow_none=True + allow_none=True, + desc="get_appservice_state", ) if result: defer.returnValue(result.get("state")) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 8022b8cfc6..fd05bfe54e 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore): table="server_tls_certificates", keyvalues={"server_name": server_name}, retcols=("tls_certificate",), + desc="get_server_certificate", ) tls_certificate = OpenSSL.crypto.load_certificate( OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes, diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 70cde0d04d..bd35e19be6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -134,6 +134,7 @@ class RegistrationStore(SQLBaseStore): }, retcols=["name", "password_hash", "is_guest"], allow_none=True, + desc="get_user_by_id", ) def get_users_by_id_case_insensitive(self, user_id): diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2c49a5e499..50436cb2d2 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ 
-564,6 +564,7 @@ class StreamStore(SQLBaseStore): table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), + desc="get_topological_token_for_event", ).addCallback(lambda row: "t%d-%d" % ( row["topological_ordering"], row["stream_ordering"],) ) -- cgit 1.4.1 From aa4af94c69b8b1c263dacfce0358aaef97d3e323 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 16:29:32 +0000 Subject: We return dicts now. --- synapse/storage/event_push_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index aca3219206..2742e0c008 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -72,7 +72,7 @@ class EventPushActionsStore(SQLBaseStore): ) results = txn.fetchall() if len(results) == 0: - return [] + return {} stream_ordering = results[0][0] topological_ordering = results[0][1] -- cgit 1.4.1 From 4d36e732307ad35eb070af384058f227d7d85dd0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 3 Feb 2016 16:35:00 +0000 Subject: Actually return something sensible --- synapse/storage/event_push_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 2742e0c008..d0a969f50b 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -72,7 +72,7 @@ class EventPushActionsStore(SQLBaseStore): ) results = txn.fetchall() if len(results) == 0: - return {} + return {"notify_count": 0, "highlight_count": 0} stream_ordering = results[0][0] topological_ordering = results[0][1] -- cgit 1.4.1 From 79a1c0574b33955d28bfb12697ccd5a7be779b36 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Fri, 5 Feb 2016 11:22:30 +0000 Subject: Allocate guest user IDs numericcally The current random IDs are ugly 
and confusing when presented in UIs. This makes them prettier and easier to read. Also, disable non-automated registration of numeric IDs so that we don't need to worry so much about people carving out our automated address space and us needing to keep retrying ID registration. --- synapse/handlers/register.py | 55 +++++++++++++++++++++++++++-------------- synapse/storage/registration.py | 36 +++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 19 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b8fbcf9233..2660fd21a2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -21,7 +21,6 @@ from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) from ._base import BaseHandler -import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient @@ -45,6 +44,8 @@ class RegistrationHandler(BaseHandler): self.distributor.declare("registered_user") self.captcha_client = CaptchaServerHttpClient(hs) + self._next_generated_user_id = None + @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None): yield run_on_reactor() @@ -91,7 +92,7 @@ class RegistrationHandler(BaseHandler): Args: localpart : The local part of the user ID to register. If None, - one will be randomly generated. + one will be generated. password (str) : The password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). @@ -108,6 +109,18 @@ class RegistrationHandler(BaseHandler): if localpart: yield self.check_username(localpart, guest_access_token=guest_access_token) + was_guest = guest_access_token is not None + + if not was_guest: + try: + int(localpart) + raise RegistrationError( + 400, + "Numeric user IDs are reserved for guest users." 
+ ) + except ValueError: + pass + user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -118,40 +131,36 @@ class RegistrationHandler(BaseHandler): user_id=user_id, token=token, password_hash=password_hash, - was_guest=guest_access_token is not None, + was_guest=was_guest, make_guest=make_guest, ) yield registered_user(self.distributor, user) else: - # autogen a random user ID + # autogen a sequential user ID attempts = 0 - user_id = None token = None - while not user_id: + user = None + while not user: + localpart = yield self._generate_user_id(attempts > 0) + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + yield self.check_user_id_is_valid(user_id) + if generate_token: + token = self.auth_handler().generate_access_token(user_id) try: - localpart = self._generate_user_id() - user = UserID(localpart, self.hs.hostname) - user_id = user.to_string() - yield self.check_user_id_is_valid(user_id) - if generate_token: - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, password_hash=password_hash, make_guest=make_guest ) - - yield registered_user(self.distributor, user) except SynapseError: # if user id is taken, just generate another user_id = None token = None attempts += 1 - if attempts > 5: - raise RegistrationError( - 500, "Cannot generate user ID.") + yield registered_user(self.distributor, user) # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding @@ -283,8 +292,16 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - def _generate_user_id(self): - return "-" + stringutils.random_string(18) + @defer.inlineCallbacks + def _generate_user_id(self, reseed=False): + if reseed or self._next_generated_user_id is None: + self._next_generated_user_id = ( + yield self.store.find_next_generated_user_id_localpart() + ) + + id = self._next_generated_user_id + 
self._next_generated_user_id += 1 + defer.returnValue(str(id)) @defer.inlineCallbacks def _validate_captcha(self, ip_addr, private_key, challenge, response): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd35e19be6..967c732bda 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + from twisted.internet import defer from synapse.api.errors import StoreError, Codes @@ -351,3 +353,37 @@ class RegistrationStore(SQLBaseStore): ret = yield self.runInteraction("count_users", _count_users) defer.returnValue(ret) + + @defer.inlineCallbacks + def find_next_generated_user_id_localpart(self): + """ + Gets the localpart of the next generated user ID. + + Generated user IDs are integers, and we aim for them to be as small as + we can. Unfortunately, it's possible some of them are already taken by + existing users, and there may be gaps in the already taken range. This + function returns the start of the first allocatable gap. This is to + avoid the case of ID 10000000 being pre-allocated, so us wasting the + first (and shortest) many generated user IDs. 
+ """ + def _find_next_generated_user_id(txn): + txn.execute("SELECT name FROM users") + rows = self.cursor_to_dict(txn) + + regex = re.compile("^@(\d+):") + + found = set() + + for r in rows: + user_id = r["name"] + match = regex.search(user_id) + if match: + found.add(int(match.group(1))) + for i in xrange(len(found) + 1): + if i not in found: + return i + + defer.returnValue((yield self.runInteraction( + "find_next_generated_user_id", + _find_next_generated_user_id + ))) -- cgit 1.4.1 From 2c1fbea5319db2c64fa486adb32b5e66680b6daf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 4 Feb 2016 10:22:44 +0000 Subject: Fix up logcontexts --- synapse/api/auth.py | 4 +- synapse/app/homeserver.py | 2 + synapse/crypto/keyring.py | 83 ++++++++++++----------- synapse/federation/federation_server.py | 4 +- synapse/federation/transaction_queue.py | 3 - synapse/handlers/_base.py | 10 +-- synapse/handlers/events.py | 11 +++- synapse/handlers/federation.py | 50 ++------------ synapse/handlers/presence.py | 20 +++--- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 11 +++- synapse/handlers/sync.py | 40 ++++++------ synapse/http/server.py | 5 +- synapse/notifier.py | 58 ++++++++-------- synapse/push/__init__.py | 2 +- synapse/push/pusherpool.py | 9 +-- synapse/rest/client/v2_alpha/account_data.py | 4 +- synapse/rest/client/v2_alpha/tags.py | 4 +- synapse/storage/_base.py | 18 ++--- synapse/storage/events.py | 34 ++++++---- synapse/storage/presence.py | 5 +- synapse/storage/stream.py | 9 +-- synapse/util/__init__.py | 6 +- synapse/util/async.py | 11 +++- synapse/util/caches/descriptors.py | 16 +++-- synapse/util/caches/snapshot_cache.py | 3 +- synapse/util/distributor.py | 15 +++-- synapse/util/logcontext.py | 98 ++++++++++++++++++++++++++-- synapse/util/logutils.py | 35 ++++++++++ synapse/util/metrics.py | 10 +-- synapse/util/ratelimitutils.py | 3 +- 31 files changed, 356 insertions(+), 229 deletions(-) (limited to 'synapse/storage') diff --git 
a/synapse/api/auth.py b/synapse/api/auth.py index 5bba9343f6..e2f84c4d57 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.types import Requester, RoomID, UserID, EventID from synapse.util.logutils import log_function +from synapse.util.logcontext import preserve_context_over_fn from unpaddedbase64 import decode_base64 import logging @@ -529,7 +530,8 @@ class Auth(object): default=[""] )[0] if user and access_token and ip_addr: - self.store.insert_client_ip( + preserve_context_over_fn( + self.store.insert_client_ip, user=user, access_token=access_token, ip=ip_addr, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index e5c7e39cf9..2b4be7bdd0 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -709,6 +709,8 @@ def run(hs): phone_home_task.start(60 * 60 * 24, now=False) def in_thread(): + # Uncomment to enable tracing of log context changes. + # sys.settrace(logcontext_tracer) with LoggingContext("run"): change_resource_limit(hs.config.soft_file_limit) reactor.run() diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index cddec0b2bc..d08ee0aa91 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter from synapse.util import unwrapFirstError from synapse.util.async import ObservableDeferred +from synapse.util.logcontext import ( + preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, + preserve_fn +) from twisted.internet import defer @@ -142,40 +146,43 @@ class Keyring(object): for server_name, _ in server_and_json } - # We want to wait for any previous lookups to complete before - # proceeding. 
- wait_on_deferred = self.wait_for_previous_lookups( - [server_name for server_name, _ in server_and_json], - server_to_deferred, - ) + with PreserveLoggingContext(): - # Actually start fetching keys. - wait_on_deferred.addBoth( - lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) - ) + # We want to wait for any previous lookups to complete before + # proceeding. + wait_on_deferred = self.wait_for_previous_lookups( + [server_name for server_name, _ in server_and_json], + server_to_deferred, + ) - # When we've finished fetching all the keys for a given server_name, - # resolve the deferred passed to `wait_for_previous_lookups` so that - # any lookups waiting will proceed. - server_to_gids = {} + # Actually start fetching keys. + wait_on_deferred.addBoth( + lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) + ) + + # When we've finished fetching all the keys for a given server_name, + # resolve the deferred passed to `wait_for_previous_lookups` so that + # any lookups waiting will proceed. 
+ server_to_gids = {} - def remove_deferreds(res, server_name, group_id): - server_to_gids[server_name].discard(group_id) - if not server_to_gids[server_name]: - d = server_to_deferred.pop(server_name, None) - if d: - d.callback(None) - return res + def remove_deferreds(res, server_name, group_id): + server_to_gids[server_name].discard(group_id) + if not server_to_gids[server_name]: + d = server_to_deferred.pop(server_name, None) + if d: + d.callback(None) + return res - for g_id, deferred in deferreds.items(): - server_name = group_id_to_group[g_id].server_name - server_to_gids.setdefault(server_name, set()).add(g_id) - deferred.addBoth(remove_deferreds, server_name, g_id) + for g_id, deferred in deferreds.items(): + server_name = group_id_to_group[g_id].server_name + server_to_gids.setdefault(server_name, set()).add(g_id) + deferred.addBoth(remove_deferreds, server_name, g_id) # Pass those keys to handle_key_deferred so that the json object # signatures can be verified return [ - handle_key_deferred( + preserve_context_over_fn( + handle_key_deferred, group_id_to_group[g_id], deferreds[g_id], ) @@ -198,12 +205,13 @@ class Keyring(object): if server_name in self.key_downloads ] if wait_on: - yield defer.DeferredList(wait_on) + with PreserveLoggingContext(): + yield defer.DeferredList(wait_on) else: break for server_name, deferred in server_to_deferred.items(): - d = ObservableDeferred(deferred) + d = ObservableDeferred(preserve_context_over_deferred(deferred)) self.key_downloads[server_name] = d def rm(r, server_name): @@ -244,12 +252,13 @@ class Keyring(object): for group in group_id_to_group.values(): for key_id in group.key_ids: if key_id in merged_results[group.server_name]: - group_id_to_deferred[group.group_id].callback(( - group.group_id, - group.server_name, - key_id, - merged_results[group.server_name][key_id], - )) + with PreserveLoggingContext(): + group_id_to_deferred[group.group_id].callback(( + group.group_id, + group.server_name, + key_id, + 
merged_results[group.server_name][key_id], + )) break else: missing_groups.setdefault( @@ -504,7 +513,7 @@ class Keyring(object): yield defer.gatherResults( [ - self.store_keys( + preserve_fn(self.store_keys)( server_name=key_server_name, from_server=server_name, verify_keys=verify_keys, @@ -573,7 +582,7 @@ class Keyring(object): yield defer.gatherResults( [ - self.store.store_server_keys_json( + preserve_fn(self.store.store_server_keys_json)( server_name=server_name, key_id=key_id, from_server=server_name, @@ -675,7 +684,7 @@ class Keyring(object): # TODO(markjh): Store whether the keys have expired. yield defer.gatherResults( [ - self.store.store_server_verify_key( + preserve_fn(self.store.store_server_verify_key)( server_name, server_name, key.time_added, key ) for key_id, key in verify_keys.items() diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index a97aa0c94a..90718192dd 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -126,10 +126,8 @@ class FederationServer(FederationBase): results = [] for pdu in pdu_list: - d = self._handle_new_pdu(transaction.origin, pdu) - try: - yield d + yield self._handle_new_pdu(transaction.origin, pdu) results.append({}) except FederationError as e: self.send_failure(e, transaction.origin) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 622adad3ae..1928da03b3 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -103,7 +103,6 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - @defer.inlineCallbacks def enqueue_pdu(self, pdu, destinations, order): # We loop through all destinations to see whether we already have # a transaction in progress. 
If we do, stick it in the pending_pdus @@ -141,8 +140,6 @@ class TransactionQueue(object): deferreds.append(deferred) - yield defer.DeferredList(deferreds, consumeErrors=True) - # NO inlineCallbacks def enqueue_edu(self, edu): destination = edu.destination diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 1423df6cf3..fa83d3e464 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -293,19 +293,11 @@ class BaseHandler(object): with PreserveLoggingContext(): # Don't block waiting on waking up all the listeners. - notify_d = self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) - def log_failure(f): - logger.warn( - "Failed to notify about %s: %s", - event.event_id, f.value - ) - - notify_d.addErrback(log_failure) - # If invite, remove room_state from unsigned before sending. event.unsigned.pop("invite_room_state", None) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 5ad8f3779a..4933c31c19 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event +from synapse.util.logcontext import preserve_context_over_fn from ._base import BaseHandler @@ -29,11 +30,17 @@ logger = logging.getLogger(__name__) def started_user_eventstream(distributor, user): - return distributor.fire("started_user_eventstream", user) + return preserve_context_over_fn( + distributor.fire, + "started_user_eventstream", user + ) def stopped_user_eventstream(distributor, user): - return distributor.fire("stopped_user_eventstream", user) + return preserve_context_over_fn( + distributor.fire, + "stopped_user_eventstream", user + ) class EventStreamHandler(BaseHandler): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 
2ce1e9d6c7..b78b0502d9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -221,19 +221,11 @@ class FederationHandler(BaseHandler): extra_users.append(target_user) with PreserveLoggingContext(): - d = self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) - def log_failure(f): - logger.warn( - "Failed to notify about %s: %s", - event.event_id, f.value - ) - - d.addErrback(log_failure) - if event.type == EventTypes.Member: if event.membership == Membership.JOIN: prev_state = context.current_state.get((event.type, event.state_key)) @@ -643,19 +635,11 @@ class FederationHandler(BaseHandler): ) with PreserveLoggingContext(): - d = self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=[joinee] ) - def log_failure(f): - logger.warn( - "Failed to notify about %s: %s", - event.event_id, f.value - ) - - d.addErrback(log_failure) - logger.debug("Finished joining %s to %s", joinee, room_id) finally: room_queue = self.room_queues[room_id] @@ -730,18 +714,10 @@ class FederationHandler(BaseHandler): extra_users.append(target_user) with PreserveLoggingContext(): - d = self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) - def log_failure(f): - logger.warn( - "Failed to notify about %s: %s", - event.event_id, f.value - ) - - d.addErrback(log_failure) - if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: user = UserID.from_string(event.state_key) @@ -811,19 +787,11 @@ class FederationHandler(BaseHandler): target_user = UserID.from_string(event.state_key) with PreserveLoggingContext(): - d = self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=[target_user], ) - def log_failure(f): - logger.warn( - "Failed to notify about %s: %s", - 
event.event_id, f.value - ) - - d.addErrback(log_failure) - defer.returnValue(event) @defer.inlineCallbacks @@ -948,18 +916,10 @@ class FederationHandler(BaseHandler): extra_users.append(target_user) with PreserveLoggingContext(): - d = self.notifier.on_new_room_event( + self.notifier.on_new_room_event( event, event_stream_id, max_stream_id, extra_users=extra_users ) - def log_failure(f): - logger.warn( - "Failed to notify about %s: %s", - event.event_id, f.value - ) - - d.addErrback(log_failure) - new_pdu = event destinations = set() diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d0c21ff5c9..b61394f2b5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler): was_polling = target_user in self._user_cachemap if now_online and not was_polling: - self.start_polling_presence(target_user, state=state) + yield self.start_polling_presence(target_user, state=state) elif not now_online and was_polling: - self.stop_polling_presence(target_user) + yield self.stop_polling_presence(target_user) # TODO(paul): perform a presence push as part of start/stop poll so # we don't have to do this all the time @@ -394,7 +394,8 @@ class PresenceHandler(BaseHandler): if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: return - self.changed_presencelike_data(user, {"last_active": now}) + with PreserveLoggingContext(): + self.changed_presencelike_data(user, {"last_active": now}) def get_joined_rooms_for_user(self, user): """Get the list of rooms a user is joined to. 
@@ -466,11 +467,12 @@ class PresenceHandler(BaseHandler): local_user, room_ids=[room_id], add_to_cache=False ) - self.push_update_to_local_and_remote( - observed_user=local_user, - users_to_push=[user], - statuscache=statuscache, - ) + with PreserveLoggingContext(): + self.push_update_to_local_and_remote( + observed_user=local_user, + users_to_push=[user], + statuscache=statuscache, + ) @defer.inlineCallbacks def send_presence_invite(self, observer_user, observed_user): @@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.start_polling_presence( + yield self.start_polling_presence( observer_user, target_user=observed_user ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 2660fd21a2..24c850ae9b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -186,7 +186,7 @@ class RegistrationHandler(BaseHandler): token=token, password_hash="" ) - registered_user(self.distributor, user) + yield registered_user(self.distributor, user) defer.returnValue((user_id, token)) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bfd7e44e9f..a8e3a9029c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,6 +25,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, StoreError, SynapseError, Codes from synapse.util import stringutils, unwrapFirstError from synapse.util.async import run_on_reactor +from synapse.util.logcontext import preserve_context_over_fn from signedjson.sign import verify_signed_json from signedjson.key import decode_verify_key_bytes @@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content): def user_left_room(distributor, user, room_id): - return distributor.fire("user_left_room", user=user, room_id=room_id) + return preserve_context_over_fn( + distributor.fire, + "user_left_room", user=user, room_id=room_id + ) def 
user_joined_room(distributor, user, room_id): - return distributor.fire("user_joined_room", user=user, room_id=room_id) + return preserve_context_over_fn( + distributor.fire, + "user_joined_room", user=user, room_id=room_id + ) class RoomCreationHandler(BaseHandler): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 72271f2626..3f1cda5b0b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -18,7 +18,7 @@ from ._base import BaseHandler from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from twisted.internet import defer @@ -241,15 +241,16 @@ class SyncHandler(BaseHandler): deferreds = [] for event in room_list: if event.membership == Membership.JOIN: - room_sync_deferred = self.full_state_sync_for_joined_room( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) + with PreserveLoggingContext(LoggingContext.current_context()): + room_sync_deferred = self.full_state_sync_for_joined_room( + room_id=event.room_id, + sync_config=sync_config, + now_token=now_token, + timeline_since_token=timeline_since_token, + ephemeral_by_room=ephemeral_by_room, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + ) room_sync_deferred.addCallback(joined.append) deferreds.append(room_sync_deferred) elif event.membership == Membership.INVITE: @@ -262,15 +263,16 @@ class SyncHandler(BaseHandler): leave_token = now_token.copy_and_replace( "room_key", "s%d" % (event.stream_ordering,) ) - room_sync_deferred = self.full_state_sync_for_archived_room( - sync_config=sync_config, - room_id=event.room_id, - 
leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) + with PreserveLoggingContext(LoggingContext.current_context()): + room_sync_deferred = self.full_state_sync_for_archived_room( + sync_config=sync_config, + room_id=event.room_id, + leave_event_id=event.event_id, + leave_token=leave_token, + timeline_since_token=timeline_since_token, + tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, + ) room_sync_deferred.addCallback(archived.append) deferreds.append(room_sync_deferred) diff --git a/synapse/http/server.py b/synapse/http/server.py index 06935783ca..a90e2e1125 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -99,9 +99,8 @@ def request_handler(request_handler): request_context.request = request_id with request.processing(): try: - d = request_handler(self, request) - with PreserveLoggingContext(): - yield d + with PreserveLoggingContext(request_context): + yield request_handler(self, request) except CodeMessageException as e: code = e.code if isinstance(e, SynapseError): diff --git a/synapse/notifier.py b/synapse/notifier.py index 1a90bd55cd..560866b26e 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -18,7 +18,8 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.util.logutils import log_function -from synapse.util.async import run_on_reactor, ObservableDeferred +from synapse.util.async import ObservableDeferred +from synapse.util.logcontext import PreserveLoggingContext from synapse.types import StreamToken import synapse.metrics @@ -73,7 +74,8 @@ class _NotifierUserStream(object): self.current_token = current_token self.last_notified_ms = time_now_ms - self.notify_deferred = ObservableDeferred(defer.Deferred()) + with PreserveLoggingContext(): + self.notify_deferred = ObservableDeferred(defer.Deferred()) def notify(self, stream_key, 
stream_id, time_now_ms): """Notify any listeners for this user of a new event from an @@ -88,8 +90,10 @@ class _NotifierUserStream(object): ) self.last_notified_ms = time_now_ms noify_deferred = self.notify_deferred - self.notify_deferred = ObservableDeferred(defer.Deferred()) - noify_deferred.callback(self.current_token) + + with PreserveLoggingContext(): + self.notify_deferred = ObservableDeferred(defer.Deferred()) + noify_deferred.callback(self.current_token) def remove(self, notifier): """ Remove this listener from all the indexes in the Notifier @@ -184,8 +188,6 @@ class Notifier(object): lambda: count(bool, self.appservice_to_user_streams.values()), ) - @log_function - @defer.inlineCallbacks def on_new_room_event(self, event, room_stream_id, max_room_stream_id, extra_users=[]): """ Used by handlers to inform the notifier something has happened @@ -199,12 +201,11 @@ class Notifier(object): until all previous events have been persisted before notifying the client streams. """ - yield run_on_reactor() - - self.pending_new_room_events.append(( - room_stream_id, event, extra_users - )) - self._notify_pending_new_room_events(max_room_stream_id) + with PreserveLoggingContext(): + self.pending_new_room_events.append(( + room_stream_id, event, extra_users + )) + self._notify_pending_new_room_events(max_room_stream_id) def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous @@ -251,31 +252,29 @@ class Notifier(object): extra_streams=app_streams, ) - @defer.inlineCallbacks - @log_function def on_new_event(self, stream_key, new_token, users=[], rooms=[], extra_streams=set()): """ Used to inform listeners that something has happend event wise. Will wake up all listeners for the given users and rooms. 
""" - yield run_on_reactor() - user_streams = set() + with PreserveLoggingContext(): + user_streams = set() - for user in users: - user_stream = self.user_to_user_stream.get(str(user)) - if user_stream is not None: - user_streams.add(user_stream) + for user in users: + user_stream = self.user_to_user_stream.get(str(user)) + if user_stream is not None: + user_streams.add(user_stream) - for room in rooms: - user_streams |= self.room_to_user_streams.get(room, set()) + for room in rooms: + user_streams |= self.room_to_user_streams.get(room, set()) - time_now_ms = self.clock.time_msec() - for user_stream in user_streams: - try: - user_stream.notify(stream_key, new_token, time_now_ms) - except: - logger.exception("Failed to notify listener") + time_now_ms = self.clock.time_msec() + for user_stream in user_streams: + try: + user_stream.notify(stream_key, new_token, time_now_ms) + except: + logger.exception("Failed to notify listener") @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, @@ -325,7 +324,8 @@ class Notifier(object): # that we don't miss any current_token updates. 
prev_token = current_token listener = user_stream.new_listener(prev_token) - yield listener.deferred + with PreserveLoggingContext(): + yield listener.deferred except defer.CancelledError: break diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 64e581b8ba..8da2d8716c 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -111,7 +111,7 @@ class Pusher(object): self.user_id, config, timeout=0, affect_presence=False ) self.last_token = chunk['end'] - self.store.update_pusher_last_token( + yield self.store.update_pusher_last_token( self.app_id, self.pushkey, self.user_id, self.last_token ) logger.info("New pusher %s for user %s starting from token %s", diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d1b7c0802f..d7dcb2de4b 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -18,6 +18,7 @@ from twisted.internet import defer from httppusher import HttpPusher from synapse.push import PusherConfigException +from synapse.util.logcontext import preserve_fn import logging @@ -76,7 +77,7 @@ class PusherPool: "Removing pusher for app id %s, pushkey %s, user %s", app_id, pushkey, p['user_name'] ) - self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks def remove_pushers_by_user(self, user_id): @@ -91,7 +92,7 @@ class PusherPool: "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] ) - self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) + yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, @@ -110,7 +111,7 @@ class PusherPool: lang=lang, data=data, ) - self._refresh_pusher(app_id, pushkey, user_id) + yield self._refresh_pusher(app_id, pushkey, user_id) def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': @@ -166,7 
+167,7 @@ class PusherPool: if fullid in self.pushers: self.pushers[fullid].stop() self.pushers[fullid] = p - p.start() + preserve_fn(p.start)() logger.info("Started pushers") diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 985efe2a62..1456881c1a 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet): user_id, account_data_type, body ) - yield self.notifier.on_new_event( + self.notifier.on_new_event( "account_data_key", max_id, users=[user_id] ) @@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet): user_id, room_id, account_data_type, body ) - yield self.notifier.on_new_event( + self.notifier.on_new_event( "account_data_key", max_id, users=[user_id] ) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 42f2203f3d..79c436a8cf 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -80,7 +80,7 @@ class TagServlet(RestServlet): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) - yield self.notifier.on_new_event( + self.notifier.on_new_event( "account_data_key", max_id, users=[user_id] ) @@ -94,7 +94,7 @@ class TagServlet(RestServlet): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) - yield self.notifier.on_new_event( + self.notifier.on_new_event( "account_data_key", max_id, users=[user_id] ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index cfb87d9328..2e97ac84a8 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,7 +15,7 @@ import logging from synapse.api.errors import StoreError -from synapse.util.logcontext import preserve_context_over_fn, LoggingContext +from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from 
synapse.util.caches.descriptors import Cache import synapse.metrics @@ -298,10 +298,10 @@ class SQLBaseStore(object): func, *args, **kwargs ) - result = yield preserve_context_over_fn( - self._db_pool.runWithConnection, - inner_func, *args, **kwargs - ) + with PreserveLoggingContext(): + result = yield self._db_pool.runWithConnection( + inner_func, *args, **kwargs + ) for after_callback, after_args in after_callbacks: after_callback(*after_args) @@ -326,10 +326,10 @@ class SQLBaseStore(object): return func(conn, *args, **kwargs) - result = yield preserve_context_over_fn( - self._db_pool.runWithConnection, - inner_func, *args, **kwargs - ) + with PreserveLoggingContext(): + result = yield self._db_pool.runWithConnection( + inner_func, *args, **kwargs + ) defer.returnValue(result) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 4d7cdd00d0..c6ed54721c 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,7 +19,7 @@ from twisted.internet import defer, reactor from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events.utils import prune_event -from synapse.util.logcontext import preserve_context_over_deferred +from synapse.util.logcontext import preserve_fn, PreserveLoggingContext from synapse.util.logutils import log_function from synapse.api.constants import EventTypes @@ -664,14 +664,16 @@ class EventsStore(SQLBaseStore): for ids, d in lst: if not d.called: try: - d.callback([ - res[i] - for i in ids - if i in res - ]) + with PreserveLoggingContext(): + d.callback([ + res[i] + for i in ids + if i in res + ]) except: logger.exception("Failed to callback") - reactor.callFromThread(fire, event_list, row_dict) + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list, row_dict) except Exception as e: logger.exception("do_fetch") @@ -679,10 +681,12 @@ class EventsStore(SQLBaseStore): def fire(evs): for _, d in evs: if not d.called: - d.errback(e) + with PreserveLoggingContext(): + 
d.errback(e) if event_list: - reactor.callFromThread(fire, event_list) + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list) @defer.inlineCallbacks def _enqueue_events(self, events, check_redacted=True, @@ -709,18 +713,20 @@ class EventsStore(SQLBaseStore): should_start = False if should_start: - self.runWithConnection( - self._do_fetch - ) + with PreserveLoggingContext(): + self.runWithConnection( + self._do_fetch + ) - rows = yield preserve_context_over_deferred(events_d) + with PreserveLoggingContext(): + rows = yield events_d if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] res = yield defer.gatherResults( [ - self._get_event_from_row( + preserve_fn(self._get_event_from_row)( row["internal_metadata"], row["json"], row["redacts"], check_redacted=check_redacted, get_prev_content=get_prev_content, diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 9b3aecaf8c..ef525f34c5 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore): for row in rows }) + @defer.inlineCallbacks def set_presence_state(self, user_localpart, new_state): - res = self._simple_update_one( + res = yield self._simple_update_one( table="presence", keyvalues={"user_id": user_localpart}, updatevalues={"state": new_state["state"], @@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore): ) self.get_presence_state.invalidate((user_localpart,)) - return res + defer.returnValue(res) def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 50436cb2d2..367ffc9543 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -39,6 +39,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken +from 
synapse.util.logcontext import preserve_fn import logging @@ -170,12 +171,12 @@ class StreamStore(SQLBaseStore): room_ids = list(room_ids) for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): res = yield defer.gatherResults([ - self.get_room_events_stream_for_room( - room_id, from_key, to_key, limit - ).addCallback(lambda r, rm: (rm, r), room_id) + preserve_fn(self.get_room_events_stream_for_room)( + room_id, from_key, to_key, limit, + ) for room_id in room_ids ]) - results.update(dict(res)) + results.update(dict(zip(rm_ids, res))) defer.returnValue(results) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 7566d9eb33..133671e238 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext +from synapse.util.logcontext import PreserveLoggingContext from twisted.internet import defer, reactor, task @@ -61,10 +61,8 @@ class Clock(object): *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. 
""" - current_context = LoggingContext.current_context() - def wrapped_callback(*args, **kwargs): - with PreserveLoggingContext(current_context): + with PreserveLoggingContext(): callback(*args, **kwargs) with PreserveLoggingContext(): diff --git a/synapse/util/async.py b/synapse/util/async.py index 200edd404c..640fae3890 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -16,13 +16,16 @@ from twisted.internet import defer, reactor -from .logcontext import preserve_context_over_deferred +from .logcontext import PreserveLoggingContext +@defer.inlineCallbacks def sleep(seconds): d = defer.Deferred() - reactor.callLater(seconds, d.callback, seconds) - return preserve_context_over_deferred(d) + with PreserveLoggingContext(): + reactor.callLater(seconds, d.callback, seconds) + res = yield d + defer.returnValue(res) def run_on_reactor(): @@ -54,6 +57,7 @@ class ObservableDeferred(object): object.__setattr__(self, "_result", (True, r)) while self._observers: try: + # TODO: Handle errors here. self._observers.pop().callback(r) except: pass @@ -63,6 +67,7 @@ class ObservableDeferred(object): object.__setattr__(self, "_result", (False, f)) while self._observers: try: + # TODO: Handle errors here. self._observers.pop().errback(f) except: pass diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index e27917c63a..277854ccbc 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred from synapse.util import unwrapFirstError from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn +) from . 
import caches_by_name, DEBUG_CACHES, cache_counter @@ -190,7 +193,7 @@ class CacheDescriptor(object): defer.returnValue(cached_result) observer.addCallback(check_result) - return observer + return preserve_context_over_deferred(observer) except KeyError: # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated @@ -198,6 +201,7 @@ class CacheDescriptor(object): sequence = self.cache.sequence ret = defer.maybeDeferred( + preserve_context_over_fn, self.function_to_call, obj, *args, **kwargs ) @@ -211,7 +215,7 @@ class CacheDescriptor(object): ret = ObservableDeferred(ret, consumeErrors=True) self.cache.update(sequence, cache_key, ret) - return ret.observe() + return preserve_context_over_deferred(ret.observe()) wrapped.invalidate = self.cache.invalidate wrapped.invalidate_all = self.cache.invalidate_all @@ -299,6 +303,7 @@ class CacheListDescriptor(object): args_to_call[self.list_name] = missing ret_d = defer.maybeDeferred( + preserve_context_over_fn, self.function_to_call, **args_to_call ) @@ -308,7 +313,8 @@ class CacheListDescriptor(object): # We need to create deferreds for each arg in the list so that # we can insert the new deferred into the cache. 
for arg in missing: - observer = ret_d.observe() + with PreserveLoggingContext(): + observer = ret_d.observe() observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) @@ -327,10 +333,10 @@ class CacheListDescriptor(object): cached[arg] = res - return defer.gatherResults( + return preserve_context_over_deferred(defer.gatherResults( cached.values(), consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) + ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))) obj.__dict__[self.orig.__name__] = wrapped diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py index b1e40417fd..d03678b8c8 100644 --- a/synapse/util/caches/snapshot_cache.py +++ b/synapse/util/caches/snapshot_cache.py @@ -87,7 +87,8 @@ class SnapshotCache(object): # expire from the rotation of that cache. self.next_result_cache[key] = result self.pending_result_cache.pop(key, None) + return r - result.observe().addBoth(shuffle_along) + result.addBoth(shuffle_along) return result.observe() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 4ebfebf701..8875813de4 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -15,9 +15,7 @@ from twisted.internet import defer -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_deferred, -) +from synapse.util.logcontext import PreserveLoggingContext from synapse.util import unwrapFirstError @@ -97,6 +95,7 @@ class Signal(object): Each observer callable may return a Deferred.""" self.observers.append(observer) + @defer.inlineCallbacks def fire(self, *args, **kwargs): """Invokes every callable in the observer list, passing in the args and kwargs. Exceptions thrown by observers are logged but ignored. 
It is @@ -116,6 +115,7 @@ class Signal(object): failure.getTracebackObject())) if not self.suppress_failures: return failure + return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) with PreserveLoggingContext(): @@ -124,8 +124,11 @@ class Signal(object): for observer in self.observers ] - d = defer.gatherResults(deferreds, consumeErrors=True) + res = yield defer.gatherResults( + deferreds, consumeErrors=True + ).addErrback(unwrapFirstError) - d.addErrback(unwrapFirstError) + defer.returnValue(res) - return preserve_context_over_deferred(d) + def __repr__(self): + return "" % (self.name,) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index e701092cd8..9134e67908 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -48,7 +48,7 @@ class LoggingContext(object): __slots__ = [ "parent_context", "name", "usage_start", "usage_end", "main_thread", - "__dict__", "tag", + "__dict__", "tag", "alive", ] thread_local = threading.local() @@ -88,6 +88,7 @@ class LoggingContext(object): self.usage_start = None self.main_thread = threading.current_thread() self.tag = "" + self.alive = True def __str__(self): return "%s@%x" % (self.name, id(self)) @@ -106,6 +107,7 @@ class LoggingContext(object): The context that was previously active """ current = cls.current_context() + if current is not context: current.stop() cls.thread_local.current_context = context @@ -117,6 +119,7 @@ class LoggingContext(object): if self.parent_context is not None: raise Exception("Attempt to enter logging context multiple times") self.parent_context = self.set_current_context(self) + self.alive = True return self def __exit__(self, type, value, traceback): @@ -136,6 +139,7 @@ class LoggingContext(object): self ) self.parent_context = None + self.alive = False def __getattr__(self, name): """Delegate member lookup to parent context""" @@ -213,7 +217,7 @@ class PreserveLoggingContext(object): exited. 
Used to restore the context after a function using @defer.inlineCallbacks is resumed by a callback from the reactor.""" - __slots__ = ["current_context", "new_context"] + __slots__ = ["current_context", "new_context", "has_parent"] def __init__(self, new_context=LoggingContext.sentinel): self.new_context = new_context @@ -224,11 +228,26 @@ class PreserveLoggingContext(object): self.new_context ) + if self.current_context: + self.has_parent = self.current_context.parent_context is not None + if not self.current_context.alive: + logger.warn( + "Entering dead context: %s", + self.current_context, + ) + def __exit__(self, type, value, traceback): """Restores the current logging context""" - LoggingContext.set_current_context(self.current_context) + context = LoggingContext.set_current_context(self.current_context) + + if context != self.new_context: + logger.warn( + "Unexpected logging context: %s is not %s", + context, self.new_context, + ) + if self.current_context is not LoggingContext.sentinel: - if self.current_context.parent_context is None: + if not self.current_context.alive: logger.warn( "Restoring dead context: %s", self.current_context, @@ -289,3 +308,74 @@ def preserve_context_over_deferred(deferred): d = _PreservingContextDeferred(current_context) deferred.chainDeferred(d) return d + + +def preserve_fn(f): + """Ensures that function is called with correct context and that context is + restored after return. Useful for wrapping functions that return a deferred + which you don't yield on. + """ + current = LoggingContext.current_context() + + def g(*args, **kwargs): + with PreserveLoggingContext(current): + return f(*args, **kwargs) + + return g + + +# modules to ignore in `logcontext_tracer` +_to_ignore = [ + "synapse.util.logcontext", + "synapse.http.server", + "synapse.storage._base", + "synapse.util.async", +] + + +def logcontext_tracer(frame, event, arg): + """A tracer that logs whenever a logcontext "unexpectedly" changes within + a function. 
Probably inaccurate. + + Use by calling `sys.settrace(logcontext_tracer)` in the main thread. + """ + if event == 'call': + name = frame.f_globals["__name__"] + if name.startswith("synapse"): + if name == "synapse.util.logcontext": + if frame.f_code.co_name in ["__enter__", "__exit__"]: + tracer = frame.f_back.f_trace + if tracer: + tracer.just_changed = True + + tracer = frame.f_trace + if tracer: + return tracer + + if not any(name.startswith(ig) for ig in _to_ignore): + return LineTracer() + + +class LineTracer(object): + __slots__ = ["context", "just_changed"] + + def __init__(self): + self.context = LoggingContext.current_context() + self.just_changed = False + + def __call__(self, frame, event, arg): + if event in 'line': + if self.just_changed: + self.context = LoggingContext.current_context() + self.just_changed = False + else: + c = LoggingContext.current_context() + if c != self.context: + logger.info( + "Context changed! %s -> %s, %s, %s", + self.context, c, + frame.f_code.co_filename, frame.f_lineno + ) + self.context = c + + return self diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py index c37a157787..3a83828d25 100644 --- a/synapse/util/logutils.py +++ b/synapse/util/logutils.py @@ -168,3 +168,38 @@ def trace_function(f): wrapped.__name__ = func_name return wrapped + + +def get_previous_frames(): + s = inspect.currentframe().f_back.f_back + to_return = [] + while s: + if s.f_globals["__name__"].startswith("synapse"): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + to_return.append("{{ %s:%d %s - Args: %s }}" % ( + filename, lineno, function, args_string + )) + + s = s.f_back + + return ", ". 
join(to_return) + + +def get_previous_frame(ignore=[]): + s = inspect.currentframe().f_back.f_back + + while s: + if s.f_globals["__name__"].startswith("synapse"): + if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore): + filename, lineno, function, _, _ = inspect.getframeinfo(s) + args_string = inspect.formatargvalues(*inspect.getargvalues(s)) + + return "{{ %s:%d %s - Args: %s }}" % ( + filename, lineno, function, args_string + ) + + s = s.f_back + + return None diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index daf6087fe0..ca48007218 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -68,16 +68,18 @@ class Measure(object): block_timer.inc_by(duration, self.name) context = LoggingContext.current_context() - if not context: - return if context != self.start_context: logger.warn( - "Context have unexpectedly changed %r, %r", - context, self.start_context + "Context have unexpectedly changed from '%s' to '%s'. (%r)", + context, self.start_context, self.name ) return + if not context: + logger.warn("Expected context. 
(%r)", self.name) + return + ru_utime, ru_stime = context.get_resource_usage() block_ru_utime.inc_by(ru_utime, self.name) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index ea321bc6a9..4076eed269 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError from synapse.util.async import sleep +from synapse.util.logcontext import preserve_fn import collections import contextlib @@ -163,7 +164,7 @@ class _PerHostRatelimiter(object): "Ratelimit [%s]: sleeping req", id(request_id), ) - ret_defer = sleep(self.sleep_msec / 1000.0) + ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0) self.sleeping_requests.add(request_id) -- cgit 1.4.1 From eff12e838ce10588ca8103c9131dcfe2f2e7950e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 13:55:59 +0000 Subject: Don't load all ephemeral state for a room on every sync --- synapse/handlers/sync.py | 20 ++++++-------------- synapse/storage/receipts.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 446f8bbe93..6a5868f87e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -319,7 +319,6 @@ class SyncHandler(BaseHandler): ephemeral_by_room=ephemeral_by_room, tags_by_room=tags_by_room, account_data_by_room=account_data_by_room, - all_ephemeral_by_room=ephemeral_by_room, batch=batch, full_state=True, ) @@ -453,13 +452,6 @@ class SyncHandler(BaseHandler): ) now_token = now_token.copy_and_replace("presence_key", presence_key) - # We now fetch all ephemeral events for this room in order to get - # this users current read receipt. This could almost certainly be - # optimised. 
- _, all_ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token - ) - now_token, ephemeral_by_room = yield self.ephemeral_by_room( sync_config, now_token, since_token ) @@ -591,7 +583,6 @@ class SyncHandler(BaseHandler): ephemeral_by_room=ephemeral_by_room, tags_by_room=tags_by_room, account_data_by_room=account_data_by_room, - all_ephemeral_by_room=all_ephemeral_by_room, batch=batch, full_state=full_state, ) @@ -691,7 +682,6 @@ class SyncHandler(BaseHandler): since_token, now_token, ephemeral_by_room, tags_by_room, account_data_by_room, - all_ephemeral_by_room, batch, full_state=False): state = yield self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, @@ -722,7 +712,7 @@ class SyncHandler(BaseHandler): if room_sync: notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config, all_ephemeral_by_room + room_id, sync_config ) if notifs is not None: @@ -906,10 +896,12 @@ class SyncHandler(BaseHandler): return False @defer.inlineCallbacks - def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room): + def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = self.last_read_event_id_for_room_and_user( - room_id, sync_config.user.to_string(), ephemeral_by_room + last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( + user_id=sync_config.user.to_string(), + room_id=room_id, + receipt_type="m.read" ) notifs = [] diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 8068c73740..1aff9f070e 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -46,6 +46,20 @@ class ReceiptsStore(SQLBaseStore): desc="get_receipts_for_room", ) + @cached(num_args=3) + def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): + return self._simple_select_one_onecol( + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": 
receipt_type, + "user_id": user_id + }, + retcol="event_id", + desc="get_own_receipt_for_user", + allow_none=True, + ) + @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): def f(txn): -- cgit 1.4.1 From 70a8608749e0c1ec7a993a9effc424303af24738 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 14:27:29 +0000 Subject: Invalidate get_last_receipt_event_id_for_user cache --- synapse/storage/receipts.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 1aff9f070e..4202a6b3dc 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -240,6 +240,11 @@ class ReceiptsStore(SQLBaseStore): room_id, stream_id ) + txn.call_after( + self.get_last_receipt_event_id_for_user.invalidate, + (user_id, room_id, receipt_type) + ) + # We don't want to clobber receipts for more recent events, so we # have to compare orderings of existing receipts sql = ( -- cgit 1.4.1 From 78d6c1b5bec800671d3ff66acecb2f8bbdf41aa1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 9 Feb 2016 14:44:12 +0000 Subject: Change a log from debug to info --- synapse/storage/prepare_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index d782b8e25b..850736c85e 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -211,7 +211,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, logger.debug("applied_delta_files: %s", applied_delta_files) for v in range(start_ver, SCHEMA_VERSION + 1): - logger.debug("Upgrading schema to v%d", v) + logger.info("Upgrading schema to v%d", v) delta_dir = os.path.join(dir_path, "schema", "delta", str(v)) -- cgit 1.4.1 From 7b0d846407a593ccd204f82aaa1090b8af8df84c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: 
Tue, 9 Feb 2016 16:19:15 +0000 Subject: Atomically persit push actions when we persist the event --- synapse/events/snapshot.py | 1 + synapse/handlers/_base.py | 10 ++++---- synapse/handlers/federation.py | 12 +++++----- synapse/push/action_generator.py | 20 ++++------------ synapse/storage/event_push_actions.py | 45 +++++++++++++---------------------- synapse/storage/events.py | 26 ++++++++++++-------- 6 files changed, 49 insertions(+), 65 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f51200d18e..8a475417a6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -20,3 +20,4 @@ class EventContext(object): self.current_state = current_state self.state_group = None self.rejected = False + self.push_actions = [] diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d3f722b22e..064e8723c8 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -264,13 +264,13 @@ class BaseHandler(object): "You don't have permission to redact events" ) - (event_stream_id, max_stream_id) = yield self.store.persist_event( - event, context=context - ) - action_generator = ActionGenerator(self.hs) yield action_generator.handle_push_actions_for_event( - event, self, context.current_state + event, context, self + ) + + (event_stream_id, max_stream_id) = yield self.store.persist_event( + event, context=context ) destinations = set() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b78b0502d9..da55d43541 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -236,12 +236,6 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - if not backfilled and not event.internal_metadata.is_outlier(): - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( - event, self - ) - 
@defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): event_to_state = yield self.store.get_state_for_events( @@ -1073,6 +1067,12 @@ class FederationHandler(BaseHandler): auth_events=auth_events, ) + if not backfilled and not event.internal_metadata.is_outlier(): + action_generator = ActionGenerator(self.hs) + yield action_generator.handle_push_actions_for_event( + event, context, self + ) + event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index d8f8256a1f..e0da0868ec 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -19,8 +19,6 @@ import bulk_push_rule_evaluator import logging -from synapse.api.constants import EventTypes - logger = logging.getLogger(__name__) @@ -36,23 +34,15 @@ class ActionGenerator: # tag (ie. we just need all the users). @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, handler, current_state): - if event.type == EventTypes.Redaction and event.redacts is not None: - yield self.store.remove_push_actions_for_event_id( - event.room_id, event.redacts - ) - + def handle_push_actions_for_event(self, event, context, handler): bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( event.room_id, self.hs, self.store ) actions_by_user = yield bulk_evaluator.action_for_event_by_user( - event, handler, current_state + event, handler, context.current_state ) - yield self.store.set_push_actions_for_event_and_users( - event, - [ - (uid, None, actions) for uid, actions in actions_by_user.items() - ] - ) + context.push_actions = [ + (uid, None, actions) for uid, actions in actions_by_user.items() + ] diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d0a969f50b..466f07a1c4 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -24,8 +24,7 @@ 
logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): - @defer.inlineCallbacks - def set_push_actions_for_event_and_users(self, event, tuples): + def _set_push_actions_for_event_and_users(self, txn, event, tuples): """ :param event: the event set actions for :param tuples: list of tuples of (user_id, profile_tag, actions) @@ -44,18 +43,12 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - def f(txn): - for uid, _, __ in tuples: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid) - ) - return self._simple_insert_many_txn(txn, "event_push_actions", values) - - yield self.runInteraction( - "set_actions_for_event_and_users", - f, - ) + for uid, _, __ in tuples: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid) + ) + self._simple_insert_many_txn(txn, "event_push_actions", values) @cachedInlineCallbacks(num_args=3, lru=True, tree=True) def get_unread_event_push_actions_by_room_for_user( @@ -107,21 +100,15 @@ class EventPushActionsStore(SQLBaseStore): ) defer.returnValue(ret) - @defer.inlineCallbacks - def remove_push_actions_for_event_id(self, room_id, event_id): - def f(txn): - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,) - ) - txn.execute( - "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id) - ) - yield self.runInteraction( - "remove_push_actions_for_event_id", - f + def _remove_push_actions_for_event_id(self, txn, room_id, event_id): + # Sad that we have to blow away the cache for the whole room here + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id,) + ) + txn.execute( + "DELETE FROM event_push_actions WHERE room_id = ? 
AND event_id = ?", + (room_id, event_id) ) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c6ed54721c..7d4012c414 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -205,23 +205,29 @@ class EventsStore(SQLBaseStore): @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, is_new_state=True): - - # Remove the any existing cache entries for the event_ids - for event, _ in events_and_contexts: + depth_updates = {} + for event, context in events_and_contexts: + # Remove the any existing cache entries for the event_ids txn.call_after(self._invalidate_get_event_cache, event.event_id) - if not backfilled: txn.call_after( self._events_stream_cache.entity_has_changed, event.room_id, event.internal_metadata.stream_ordering, ) - depth_updates = {} - for event, _ in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - depth_updates[event.room_id] = max( - event.depth, depth_updates.get(event.room_id, event.depth) + if not event.internal_metadata.is_outlier(): + depth_updates[event.room_id] = max( + event.depth, depth_updates.get(event.room_id, event.depth) + ) + + if context.push_actions: + self._set_push_actions_for_event_and_users( + txn, event, context.push_actions + ) + + if event.type == EventTypes.Redaction and event.redacts is not None: + self._remove_push_actions_for_event_id( + txn, event.room_id, event.redacts ) for room_id, depth in depth_updates.items(): -- cgit 1.4.1 From 02147452396c67e7874b201460f8b1cc8996a90a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Feb 2016 11:09:56 +0000 Subject: Rename functions --- synapse/storage/event_push_actions.py | 4 ++-- synapse/storage/events.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 466f07a1c4..d77a817682 100644 --- a/synapse/storage/event_push_actions.py +++ 
b/synapse/storage/event_push_actions.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) class EventPushActionsStore(SQLBaseStore): - def _set_push_actions_for_event_and_users(self, txn, event, tuples): + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for :param tuples: list of tuples of (user_id, profile_tag, actions) @@ -100,7 +100,7 @@ class EventPushActionsStore(SQLBaseStore): ) defer.returnValue(ret) - def _remove_push_actions_for_event_id(self, txn, room_id, event_id): + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 7d4012c414..3a5c6ee4b1 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -221,12 +221,12 @@ class EventsStore(SQLBaseStore): ) if context.push_actions: - self._set_push_actions_for_event_and_users( + self._set_push_actions_for_event_and_users_txn( txn, event, context.push_actions ) if event.type == EventTypes.Redaction and event.redacts is not None: - self._remove_push_actions_for_event_id( + self._remove_push_actions_for_event_id_txn( txn, event.room_id, event.redacts ) -- cgit 1.4.1 From 24f00a6c33900cf701330ff324b0479c1898d5ce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Feb 2016 12:57:50 +0000 Subject: Use _simple_select_many for _get_state_group_for_events --- synapse/handlers/sync.py | 2 +- synapse/storage/state.py | 26 ++++++++++---------------- 2 files changed, 11 insertions(+), 17 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 84f29e3867..1d0f0058a2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -18,7 +18,7 @@ from ._base import BaseHandler from synapse.streams.config import PaginationConfig 
from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn +from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure from twisted.internet import defer diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 6c32e8f7b3..90ec50bb50 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -264,26 +264,20 @@ class StateStore(SQLBaseStore): ) @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", - num_args=1) + num_args=1, inlineCallbacks=True) def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - def f(txn): - results = {} - for event_id in event_ids: - results[event_id] = self._simple_select_one_onecol_txn( - txn, - table="event_to_state_groups", - keyvalues={ - "event_id": event_id, - }, - retcol="state_group", - allow_none=True, - ) - - return results + rows = yield self._simple_select_many_batch( + table="event_to_state_groups", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=("event_id", "state_group",), + desc="_get_state_group_for_events", + ) - return self.runInteraction("_get_state_group_for_events", f) + defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) def _get_some_state_from_cache(self, group, types): """Checks if group is in cache. 
See `_get_state_for_groups` -- cgit 1.4.1 From 5189bfdef4c87a7b0527de603eae52ac27bd500c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Feb 2016 13:24:42 +0000 Subject: Batch fetch _get_state_groups_from_groups --- synapse/storage/state.py | 66 +++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 32 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 90ec50bb50..372b540002 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -171,41 +171,43 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - def _get_state_groups_from_groups(self, groups_and_types): + def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> state event ids - - Args: - groups_and_types (list): list of 2-tuple (`group`, `types`) """ - def f(txn): - results = {} - for group, types in groups_and_types: - if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) - else: - where_clause = "" - - sql = ( - "SELECT event_id FROM state_groups_state WHERE" - " state_group = ? %s" - ) % (where_clause,) + def f(txn, groups): + if types is not None: + where_clause = "AND (%s)" % ( + " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), + ) + else: + where_clause = "" - args = [group] - if types is not None: - args.extend([i for typ in types for i in typ]) + sql = ( + "SELECT state_group, event_id FROM state_groups_state WHERE" + " state_group IN (%s) %s" % ( + ",".join("?" 
for _ in groups), + where_clause, + ) + ) - txn.execute(sql, args) + args = list(groups) + if types is not None: + args.extend([i for typ in types for i in typ]) - results[group] = [r[0] for r in txn.fetchall()] + txn.execute(sql, args) + rows = self.cursor_to_dict(txn) + results = {} + for row in rows: + results.setdefault(row["state_group"], []).append(row["event_id"]) return results - return self.runInteraction( - "_get_state_groups_from_groups", - f, - ) + chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] + for chunk in chunks: + return self.runInteraction( + "_get_state_groups_from_groups", + f, chunk + ) @defer.inlineCallbacks def get_state_for_events(self, event_ids, types): @@ -349,7 +351,7 @@ class StateStore(SQLBaseStore): all events are returned. """ results = {} - missing_groups_and_types = [] + missing_groups = [] if types is not None: for group in set(groups): state_dict, missing_types, got_all = self._get_some_state_from_cache( @@ -358,7 +360,7 @@ class StateStore(SQLBaseStore): results[group] = state_dict if not got_all: - missing_groups_and_types.append((group, missing_types)) + missing_groups.append(group) else: for group in set(groups): state_dict, got_all = self._get_all_state_from_cache( @@ -367,9 +369,9 @@ class StateStore(SQLBaseStore): results[group] = state_dict if not got_all: - missing_groups_and_types.append((group, None)) + missing_groups.append(group) - if not missing_groups_and_types: + if not missing_groups: defer.returnValue({ group: { type_tuple: event @@ -383,7 +385,7 @@ class StateStore(SQLBaseStore): cache_seq_num = self._state_group_cache.sequence group_state_dict = yield self._get_state_groups_from_groups( - missing_groups_and_types + missing_groups, types ) state_events = yield self._get_events( -- cgit 1.4.1 From 0eff7405239c794d4c000ee93c0e38a87ea2b1bd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Feb 2016 10:07:27 +0000 Subject: Return events in correct order for /events --- 
synapse/handlers/room.py | 1 + synapse/storage/stream.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a8e3a9029c..b2de2cd0c0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1061,6 +1061,7 @@ class RoomEventSource(object): from_key=from_key, to_key=to_key, limit=limit or 10, + order='ASC', ) events = list(room_events) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 367ffc9543..0d1034c6f1 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -157,7 +157,8 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @defer.inlineCallbacks - def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0): + def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, + order='DESC'): from_id = RoomStreamToken.parse_stream_token(from_key).stream room_ids = yield self._events_stream_cache.get_entities_changed( @@ -172,7 +173,7 @@ class StreamStore(SQLBaseStore): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): res = yield defer.gatherResults([ preserve_fn(self.get_room_events_stream_for_room)( - room_id, from_key, to_key, limit, + room_id, from_key, to_key, limit, order=order, ) for room_id in room_ids ]) @@ -181,7 +182,8 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) @defer.inlineCallbacks - def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0): + def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, + order='DESC'): if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream else: @@ -206,8 +208,8 @@ class StreamStore(SQLBaseStore): " room_id = ?" " AND not outlier" " AND stream_ordering > ? AND stream_ordering <= ?" - " ORDER BY stream_ordering DESC LIMIT ?" - ) + " ORDER BY stream_ordering %s LIMIT ?" 
+ ) % (order,) txn.execute(sql, (room_id, from_id, to_id, limit)) else: sql = ( @@ -215,8 +217,8 @@ class StreamStore(SQLBaseStore): " room_id = ?" " AND not outlier" " AND stream_ordering <= ?" - " ORDER BY stream_ordering DESC LIMIT ?" - ) + " ORDER BY stream_ordering %s LIMIT ?" + ) % (order,) txn.execute(sql, (room_id, to_id, limit)) rows = self.cursor_to_dict(txn) @@ -232,7 +234,8 @@ class StreamStore(SQLBaseStore): self._set_before_and_after(ret, rows, topo_order=False) - ret.reverse() + if order.lower() == "desc": + ret.reverse() if rows: key = "s%d" % min(r["stream_ordering"] for r in rows) -- cgit 1.4.1 From ce14c7a9954ebfa80831efc0901ca04d8cfe6ab5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Feb 2016 15:02:56 +0000 Subject: Fix SYN-627, events are in incorrect room in /sync --- synapse/storage/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 0d1034c6f1..c236dafafb 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -175,7 +175,7 @@ class StreamStore(SQLBaseStore): preserve_fn(self.get_room_events_stream_for_room)( room_id, from_key, to_key, limit, order=order, ) - for room_id in room_ids + for room_id in rm_ids ]) results.update(dict(zip(rm_ids, res))) -- cgit 1.4.1 From 763360594dfb90433f693056d3d64ac82409fc87 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 11 Feb 2016 14:10:00 +0000 Subject: Mark AS users with their AS's ID --- scripts/synapse_port_db | 9 ++++- synapse/app/homeserver.py | 2 +- synapse/storage/appservice.py | 34 ++++++++++------- synapse/storage/engines/__init__.py | 5 ++- synapse/storage/engines/postgres.py | 5 ++- synapse/storage/engines/sqlite3.py | 5 ++- synapse/storage/prepare_database.py | 15 ++++---- synapse/storage/schema/delta/30/as_users.py | 59 +++++++++++++++++++++++++++++ tests/storage/test_base.py | 3 +- tests/utils.py | 20 +++++++--- 10 files changed, 121 
insertions(+), 36 deletions(-) create mode 100644 synapse/storage/schema/delta/30/as_users.py (limited to 'synapse/storage') diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index fc92bbf2d8..a2a0f364cf 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -309,8 +309,8 @@ class Porter(object): **self.postgres_config["args"] ) - sqlite_engine = create_engine("sqlite3") - postgres_engine = create_engine("psycopg2") + sqlite_engine = create_engine(FakeConfig(sqlite_config)) + postgres_engine = create_engine(FakeConfig(postgres_config)) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine) @@ -792,3 +792,8 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) + + +class FakeConfig: + def __init__(self, database_config): + self.database_config = database_config diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b4be7bdd0..d2e758c401 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -382,7 +382,7 @@ def setup(config_options): tls_server_context_factory = context_factory.ServerContextFactory(config) - database_engine = create_engine(config.database_config["name"]) + database_engine = create_engine(config) config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection hs = SynapseHomeServer( diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 1100c67714..371600eebb 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore): def __init__(self, hs): super(ApplicationServiceStore, self).__init__(hs) self.hostname = hs.hostname - self.services_cache = [] - self._populate_appservice_cache( + self.services_cache = ApplicationServiceStore.load_appservices( + hs.hostname, 
hs.config.app_service_config_files ) @@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore): return rooms_for_user_matching_user_id - def _load_appservice(self, as_info): + @classmethod + def _load_appservice(cls, hostname, as_info, config_filename): required_string_fields = [ - # TODO: Add id here when it's stable to release - "url", "as_token", "hs_token", "sender_localpart" + "id", "url", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: if not isinstance(as_info.get(field), basestring): - raise KeyError("Required string field: '%s'", field) + raise KeyError("Required string field: '%s' (%s)" % ( + field, config_filename, + )) localpart = as_info["sender_localpart"] if urllib.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) - user = UserID(localpart, self.hostname) + user = UserID(localpart, hostname) user_id = user.to_string() # namespace checks @@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore): namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], sender=user_id, - id=as_info["id"] if "id" in as_info else as_info["as_token"], + id=as_info["id"], ) - def _populate_appservice_cache(self, config_files): - """Populates a cache of Application Services from the config files.""" + @classmethod + def load_appservices(cls, hostname, config_files): + """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): logger.warning( "Expected %s to be a list of AS config files.", config_files ) - return + return [] # Dicts of value -> filename seen_as_tokens = {} seen_ids = {} + appservices = [] + for config_file in config_files: try: with open(config_file, 'r') as f: - appservice = self._load_appservice(yaml.load(f)) + appservice = ApplicationServiceStore._load_appservice( + hostname, yaml.load(f), config_file + ) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application 
services: " @@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore): ) seen_as_tokens[appservice.token] = config_file logger.info("Loaded application service: %s", appservice) - self.services_cache.append(appservice) + appservices.append(appservice) except Exception as e: logger.error("Failed to load appservice from '%s'", config_file) logger.exception(e) raise + return appservices class ApplicationServiceTransactionStore(SQLBaseStore): diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 4290aea83a..a48230b93f 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,12 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(name): +def create_engine(config): + name = config.database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module) + return engine_class(module, config=config) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 17b7a9c077..a09685b4df 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) + self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,7 +45,7 @@ class PostgresEngine(object): ) def prepare_database(self, db_conn): - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 
91fac33b8b..522b905949 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -23,8 +23,9 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module + self.config = config def check_database(self, txn): pass @@ -38,7 +39,7 @@ class Sqlite3Engine(object): def prepare_database(self, db_conn): prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): return False diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..3f29aad1e8 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -50,7 +50,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine): +def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. 
""" @@ -61,10 +61,10 @@ def prepare_database(db_conn, database_engine): if version_info: user_version, delta_files, upgraded = version_info _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine + cur, user_version, delta_files, upgraded, database_engine, config ) else: - _setup_new_database(cur, database_engine) + _setup_new_database(cur, database_engine, config) # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) @@ -75,7 +75,7 @@ def prepare_database(db_conn, database_engine): raise -def _setup_new_database(cur, database_engine): +def _setup_new_database(cur, database_engine, config): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,11 +148,12 @@ def _setup_new_database(cur, database_engine): applied_delta_files=[], upgraded=False, database_engine=database_engine, + config=config, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): + upgraded, database_engine, config): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -245,7 +246,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py new file mode 100644 index 0000000000..4cf4dd0917 --- /dev/null +++ b/synapse/storage/schema/delta/30/as_users.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from synapse.storage.appservice import ApplicationServiceStore + + +logger = logging.getLogger(__name__) + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # NULL indicates user was not registered by an appservice. + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + + cur.execute("SELECT name FROM users") + rows = cur.fetchall() + + config_files = [] + try: + config_files = config.app_service_config_files + except AttributeError: + logger.warning("Could not get app_service_config_files from config") + pass + + appservices = ApplicationServiceStore.load_appservices( + config.server_name, config_files + ) + + owned = {} + + for row in rows: + user_id = row[0] + for appservice in appservices: + if appservice.is_exclusive_user(user_id): + if user_id in owned.keys(): + logger.error( + "user_id %s was owned by more than one application" + " service (IDs %s and %s); assigning arbitrarily to %s" % + (user_id, owned[user_id], appservice.id, owned[user_id]) + ) + owned[user_id] = appservice.id + + for user_id, as_id in owned.items(): + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name = ?" 
+ ), + (as_id, user_id) + ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 152d027663..0684fb6f70 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -48,11 +48,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config.event_cache_size = 1 + config.database_config = {"name": "sqlite3"} hs = HomeServer( "test", db_pool=self.db_pool, config=config, - database_engine=create_engine("sqlite3"), + database_engine=create_engine(config), ) self.datastore = SQLBaseStore(hs) diff --git a/tests/utils.py b/tests/utils.py index 3b1eb50d8d..f3935648a0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,6 +51,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.server_name = "server.under.test" config.trusted_third_party_id_servers = [] + config.database_config = {"name": "sqlite3"} + if "clock" not in kargs: kargs["clock"] = MockClock() @@ -60,7 +62,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=db_pool, config=config, version_string="Synapse/tests", - database_engine=create_engine("sqlite3"), + database_engine=create_engine(config), get_db_conn=db_pool.get_db_conn, **kargs ) @@ -69,7 +71,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs = HomeServer( name, db_pool=None, datastore=datastore, config=config, version_string="Synapse/tests", - database_engine=create_engine("sqlite3"), + database_engine=create_engine(config), **kargs ) @@ -277,18 +279,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object): cp_max=1, ) + self.config = Mock() + self.config.database_config = {"name": "sqlite3"} + def prepare(self): - engine = create_engine("sqlite3") + engine = self.create_engine() return self.runWithConnection( - lambda conn: prepare_database(conn, engine) + lambda conn: prepare_database(conn, engine, self.config) ) def get_db_conn(self): conn = self.connect() - engine = 
create_engine("sqlite3") - prepare_database(conn, engine) + engine = self.create_engine() + prepare_database(conn, engine, self.config) return conn + def create_engine(self): + return create_engine(self.config) + class MemoryDataStore(object): -- cgit 1.4.1 From a9c9868957277408c7ae3956d73ff87964692b73 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 16 Feb 2016 15:53:38 +0000 Subject: Make adding push rules idempotent Also remove the **kwargs from the add_push_rule method. Fixes https://matrix.org/jira/browse/SYN-391 --- synapse/storage/push_rule.py | 168 ++++++++++++++++++++++--------------------- 1 file changed, 86 insertions(+), 82 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index f9a48171ba..e19a81e41f 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -99,38 +99,36 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] defer.returnValue(results) - @defer.inlineCallbacks - def add_push_rule(self, before, after, **kwargs): - vals = kwargs - if 'conditions' in vals: - vals['conditions'] = json.dumps(vals['conditions']) - if 'actions' in vals: - vals['actions'] = json.dumps(vals['actions']) - - # we could check the rest of the keys are valid column names - # but sqlite will do that anyway so I think it's just pointless. 
- vals.pop("id", None) + def add_push_rule( + self, user_id, rule_id, priority_class, conditions, actions, + before=None, after=None + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) if before or after: - ret = yield self.runInteraction( + return self.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, - before=before, - after=after, - **vals + user_id, rule_id, priority_class, + conditions_json, actions_json, before, after, ) - defer.returnValue(ret) else: - ret = yield self.runInteraction( + return self.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, - **vals + user_id, rule_id, priority_class, + conditions_json, actions_json, ) - defer.returnValue(ret) - def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): - after = kwargs.pop("after", None) - before = kwargs.pop("before", None) + def _add_push_rule_relative_txn( + self, txn, user_id, rule_id, priority_class, + conditions_json, actions_json, before, after + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. 
+ self.database_engine.lock_table(txn, "push_rules") + relative_to_rule = before or after res = self._simple_select_one_txn( @@ -149,69 +147,45 @@ class PushRuleStore(SQLBaseStore): "before/after rule not found: %s" % (relative_to_rule,) ) - priority_class = res["priority_class"] + base_priority_class = res["priority_class"] base_rule_priority = res["priority"] - if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + if base_priority_class != priority_class: raise InconsistentRuleException( "Given priority class does not match class of relative rule" ) - new_rule = kwargs - new_rule.pop("before", None) - new_rule.pop("after", None) - new_rule['priority_class'] = priority_class - new_rule['user_name'] = user_id - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - - # check if the priority before/after is free - new_rule_priority = base_rule_priority - if after: - new_rule_priority -= 1 + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 else: - new_rule_priority += 1 - - new_rule['priority'] = new_rule_priority + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority sql = ( - "SELECT COUNT(*) FROM push_rules" - " WHERE user_name = ? AND priority_class = ? AND priority = ?" + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" 
) + txn.execute(sql, (user_id, priority_class, new_rule_priority)) - res = txn.fetchall() - num_conflicting = res[0][0] - - # if there are conflicting rules, bump everything - if num_conflicting: - sql = "UPDATE push_rules SET priority = priority " - if after: - sql += "-1" - else: - sql += "+1" - sql += " WHERE user_name = ? AND priority_class = ? AND priority " - if after: - sql += "<= ?" - else: - sql += ">= ?" - - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, new_rule_priority, + conditions_json, actions_json, ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) + def _add_push_rule_highest_priority_txn( + self, txn, user_id, rule_id, priority_class, + conditions_json, actions_json + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. 
+ self.database_engine.lock_table(txn, "push_rules") - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) - - def _add_push_rule_highest_priority_txn(self, txn, user_id, - priority_class, **kwargs): # find the highest priority rule in that class sql = ( "SELECT COUNT(*), MAX(priority) FROM push_rules" @@ -225,12 +199,48 @@ class PushRuleStore(SQLBaseStore): if how_many > 0: new_prio = highest_prio + 1 - # and insert the new rule - new_rule = kwargs - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - new_rule['user_name'] = user_id - new_rule['priority_class'] = priority_class - new_rule['priority'] = new_prio + self._upsert_push_rule_txn( + txn, + user_id, rule_id, priority_class, new_prio, + conditions_json, actions_json, + ) + + def _upsert_push_rule_txn( + self, txn, user_id, rule_id, priority_class, + priority, conditions_json, actions_json + ): + """Specialised version of _simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" 
+ ) + + txn.execute(sql, ( + priority_class, priority, conditions_json, actions_json, + user_id, rule_id, + )) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next_txn(txn) + + self._simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) txn.call_after( self.get_push_rules_for_user.invalidate, (user_id,) @@ -239,12 +249,6 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ -- cgit 1.4.1 From a4e278bfe7972b367e7782102461881c65720c08 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 17 Feb 2016 15:25:12 +0000 Subject: Respond to federated invite with non-empty context Currently, we magically perform an extra database hit to find the inviter, and use this to guess where we should send the event. Instead, fill in a valid context, so that other callers relying on the context actually have one. 
--- synapse/handlers/_base.py | 51 ++++++++++++++++++++++++++++++++++-- synapse/handlers/room.py | 52 +++++++++---------------------------- synapse/storage/event_federation.py | 8 +++--- 3 files changed, 65 insertions(+), 46 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index cad37f50e7..41e153c934 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -147,7 +147,7 @@ class BaseHandler(object): @defer.inlineCallbacks def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_events_in_room( + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( builder.room_id, ) @@ -156,7 +156,10 @@ class BaseHandler(object): else: depth = 1 - prev_events = [(e, h) for e, h, _ in latest_ret] + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] builder.prev_events = prev_events builder.depth = depth @@ -165,6 +168,31 @@ class BaseHandler(object): context = yield state_handler.compute_event_context(builder) + # If we've received an invite over federation, there are no latest + # events in the room, because we don't know enough about the graph + # fragment we received to treat it like a graph, so the above returned + # no relevant events. It may have returned some events (if we have + # joined and left the room), but not useful ones, like the invite. So we + # forcibly set our context to the invite we received over federation. 
+ if ( + not self.is_host_in_room(context.current_state) and + builder.type == EventTypes.Member + ): + prev_member_event = yield self.store.get_room_member( + builder.sender, builder.room_id + ) + if prev_member_event: + builder.prev_events = ( + prev_member_event.event_id, + prev_member_event.prev_events + ) + + context = yield state_handler.compute_event_context( + builder, + old_state=(prev_member_event,), + outlier=True + ) + if builder.is_state(): builder.prev_state = yield self.store.add_event_hashes( context.prev_state_events @@ -187,6 +215,25 @@ class BaseHandler(object): (event, context,) ) + def is_host_in_room(self, current_state): + room_members = [ + (state_key, event.membership) + for ((event_type, state_key), event) in current_state.items() + if event_type == EventTypes.Member + ] + if len(room_members) == 0: + # has the room been created so we can join it? + create_event = current_state.get(("m.room.create", "")) + if create_event: + return True + for (state_key, membership) in room_members: + if ( + UserID.from_string(state_key).domain == self.hs.hostname + and membership == Membership.JOIN + ): + return True + return False + @defer.inlineCallbacks def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): # We now need to go and hit out to wherever we need to hit out to. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d4bb21e69e..f85a5f2677 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -510,10 +510,9 @@ class RoomMemberHandler(BaseHandler): # so don't really fit into the general auth process. 
raise AuthError(403, "Guest access not allowed") - should_do_dance, room_hosts = yield self._should_do_dance( - room_id, + should_do_dance, room_hosts = self._should_do_dance( context, - (yield self.get_inviter(target_user.to_string(), room_id)), + (self.get_inviter(target_user.to_string(), context.current_state)), room_hosts, ) @@ -534,11 +533,11 @@ class RoomMemberHandler(BaseHandler): ) handled = True if event.membership == Membership.LEAVE: - is_host_in_room = yield self.is_host_in_room(room_id, context) + is_host_in_room = self.is_host_in_room(context.current_state) if not is_host_in_room: # Rejecting an invite, rather than leaving a joined room handler = self.hs.get_handlers().federation_handler - inviter = yield self.get_inviter(target_user.to_string(), room_id) + inviter = self.get_inviter(target_user.to_string(), context.current_state) if not inviter: # return the same error as join_room_alias does raise SynapseError(404, "No known servers") @@ -584,20 +583,18 @@ class RoomMemberHandler(BaseHandler): and guest_access.content["guest_access"] == "can_join" ) - @defer.inlineCallbacks - def _should_do_dance(self, room_id, context, inviter, room_hosts=None): + def _should_do_dance(self, context, inviter, room_hosts=None): # TODO: Shouldn't this be remote_room_host? room_hosts = room_hosts or [] - # TODO(danielwh): This shouldn't need to yield for this check, we have a context. 
- is_host_in_room = yield self.is_host_in_room(room_id, context) + is_host_in_room = self.is_host_in_room(context.current_state) if is_host_in_room: - defer.returnValue((False, room_hosts)) + return False, room_hosts if inviter and not self.hs.is_mine(inviter): room_hosts.append(inviter.domain) - defer.returnValue((True, room_hosts)) + return True, room_hosts @defer.inlineCallbacks def lookup_room_alias(self, room_alias): @@ -624,36 +621,11 @@ class RoomMemberHandler(BaseHandler): defer.returnValue((RoomID.from_string(room_id), hosts)) - # TODO(danielwh): This should use the context, rather than looking up the store. - @defer.inlineCallbacks - def get_inviter(self, user_id, room_id): - # TODO(markjh): get prev_state from snapshot - prev_state = yield self.store.get_room_member( - user_id, room_id - ) + def get_inviter(self, user_id, current_state): + prev_state = current_state.get((EventTypes.Member, user_id)) if prev_state and prev_state.membership == Membership.INVITE: - defer.returnValue(UserID.from_string(prev_state.user_id)) - - # TODO(danielwh): This looks insane. Please make it not insane. - @defer.inlineCallbacks - def is_host_in_room(self, room_id, context): - is_host_in_room = yield self.auth.check_host_in_room( - room_id, - self.hs.hostname - ) - if not is_host_in_room: - # is *anyone* in the room? - room_member_keys = [ - v for (k, v) in context.current_state.keys() if ( - k == "m.room.member" - ) - ] - if len(room_member_keys) == 0: - # has the room been created so we can join it? 
- create_event = context.current_state.get(("m.room.create", "")) - if create_event: - is_host_in_room = True - defer.returnValue(is_host_in_room) + return UserID.from_string(prev_state.user_id) + return None @defer.inlineCallbacks def get_joined_rooms_for_user(self, user): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index ce2c794025..3489315e0d 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) - def get_latest_events_in_room(self, room_id): + def get_latest_event_ids_and_hashes_in_room(self, room_id): return self.runInteraction( - "get_latest_events_in_room", - self._get_latest_events_in_room, + "get_latest_event_ids_and_hashes_in_room", + self._get_latest_event_ids_and_hashes_in_room, room_id, ) @@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore): desc="get_latest_event_ids_in_room", ) - def _get_latest_events_in_room(self, txn, room_id): + def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): sql = ( "SELECT e.event_id, e.depth FROM events as e " "INNER JOIN event_forward_extremities as f " -- cgit 1.4.1 From e5999bfb1a4aab56acecb59ed6d068442f5b11a0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Feb 2016 17:10:40 +0000 Subject: Initial cut --- synapse/handlers/events.py | 43 +- synapse/handlers/message.py | 14 +- synapse/handlers/presence.py | 1662 ++++++++------------ synapse/handlers/profile.py | 3 + synapse/handlers/sync.py | 22 + synapse/rest/client/v1/presence.py | 26 +- synapse/rest/client/v1/room.py | 18 +- synapse/rest/client/v2_alpha/receipts.py | 3 + synapse/rest/client/v2_alpha/sync.py | 16 +- synapse/storage/__init__.py | 50 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/presence.py | 170 +- .../storage/schema/delta/30/presence_stream.sql | 30 + synapse/storage/util/id_generators.py | 4 +- synapse/util/__init__.py | 2 +- tests/utils.py | 
4 +- 16 files changed, 933 insertions(+), 1136 deletions(-) create mode 100644 synapse/storage/schema/delta/30/presence_stream.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4933c31c19..72a31a9755 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,8 @@ from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event from synapse.util.logcontext import preserve_context_over_fn +from synapse.api.constants import Membership, EventTypes +from synapse.events import EventBase from ._base import BaseHandler @@ -126,11 +128,12 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) + presence_handler = self.hs.get_handlers().presence_handler - try: - if affect_presence: - yield self.started_stream(auth_user) - + context = yield presence_handler.user_syncing( + auth_user_id, affect_presence=affect_presence, + ) + with context: if timeout: # If they've set a timeout set a minimum limit. timeout = max(timeout, 500) @@ -145,6 +148,34 @@ class EventStreamHandler(BaseHandler): is_guest=is_guest, explicit_room_id=room_id ) + # When the user joins a new room, or another user joins a currently + # joined room, we need to send down presence for those users. + to_add = [] + for event in events: + if not isinstance(event, EventBase): + continue + if event.type == EventTypes.Member: + if event.membership != Membership.JOIN: + continue + # Send down presence. + if event.state_key == auth_user_id: + # Send down presence for everyone in the room. 
+ users = yield self.store.get_users_in_room(event.room_id) + states = yield presence_handler.get_states( + users, + as_event=True, + ) + to_add.extend(states) + else: + + ev = yield presence_handler.get_state( + UserID.from_string(event.state_key), + as_event=True, + ) + to_add.append(ev) + + events.extend(to_add) + time_now = self.clock.time_msec() chunks = [ @@ -159,10 +190,6 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) - finally: - if affect_presence: - self.stopped_stream(auth_user) - class EventHandler(BaseHandler): diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 82c8cb5f0c..77894d9132 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -21,7 +21,6 @@ from synapse.streams.config import PaginationConfig from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.types import UserID, RoomStreamToken, StreamToken @@ -254,8 +253,7 @@ class MessageHandler(BaseHandler): if event.type == EventTypes.Message: presence = self.hs.get_handlers().presence_handler - with PreserveLoggingContext(): - presence.bump_presence_active_time(user) + yield presence.bump_presence_active_time(user) @defer.inlineCallbacks def create_and_send_event(self, event_dict, ratelimit=True, @@ -660,10 +658,6 @@ class MessageHandler(BaseHandler): room_id=room_id, ) - # TODO(paul): I wish I was called with user objects not user_id - # strings... 
- auth_user = UserID.from_string(user_id) - # TODO: These concurrently time_now = self.clock.time_msec() state = [ @@ -688,13 +682,11 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_presence(): states = yield presence_handler.get_states( - target_users=[UserID.from_string(m.user_id) for m in room_members], - auth_user=auth_user, + [m.user_id for m in room_members], as_event=True, - check_auth=False, ) - defer.returnValue(states.values()) + defer.returnValue(states) @defer.inlineCallbacks def get_receipts(): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b61394f2b5..26f2e669ce 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -13,13 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +"""This module is responsible for keeping track of presence status of local +and remote users. -from synapse.api.errors import SynapseError, AuthError +The methods that define policy are: + - PresenceHandler._update_states + - PresenceHandler._handle_timeouts + - should_notify +""" + +from twisted.internet import defer, reactor +from contextlib import contextmanager + +from synapse.api.errors import SynapseError from synapse.api.constants import PresenceState +from synapse.storage.presence import UserPresenceState -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function +from synapse.util.wheel_timer import WheelTimer from synapse.types import UserID import synapse.metrics @@ -33,33 +45,24 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) -# Don't bother bumping "last active" time if it differs by less than 60 seconds +# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them +# "currently_active" LAST_ACTIVE_GRANULARITY = 60 * 1000 -# Keep no more 
than this number of offline serial revisions -MAX_OFFLINE_SERIALS = 1000 - - -# TODO(paul): Maybe there's one of these I can steal from somewhere -def partition(l, func): - """Partition the list by the result of func applied to each element.""" - ret = {} +# How long to wait until a new /events or /sync request before assuming +# the client has gone. +SYNC_ONLINE_TIMEOUT = 30 * 1000 - for x in l: - key = func(x) - if key not in ret: - ret[key] = [] - ret[key].append(x) +# How long to wait before marking the user as idle. Compared against last active +IDLE_TIMER = 5 * 60 * 1000 - return ret +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 30 * 60 * 1000 +# How often to resend presence to remote servers +FEDERATION_PING_INTERVAL = 25 * 60 * 1000 -def partitionbool(l, func): - def boolfunc(x): - return bool(func(x)) - - ret = partition(l, boolfunc) - return ret.get(True, []), ret.get(False, []) +assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER def user_presence_changed(distributor, user, statuscache): @@ -72,45 +75,13 @@ def collect_presencelike_data(distributor, user, content): class PresenceHandler(BaseHandler): - STATE_LEVELS = { - PresenceState.OFFLINE: 0, - PresenceState.UNAVAILABLE: 1, - PresenceState.ONLINE: 2, - PresenceState.FREE_FOR_CHAT: 3, - } - def __init__(self, hs): super(PresenceHandler, self).__init__(hs) - - self.homeserver = hs - + self.hs = hs self.clock = hs.get_clock() - - distributor = hs.get_distributor() - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "started_user_eventstream", self.started_user_eventstream - ) - distributor.observe( - "stopped_user_eventstream", self.stopped_user_eventstream - ) - - distributor.observe("user_joined_room", self.user_joined_room) - - distributor.declare("collect_presencelike_data") - - distributor.declare("changed_presencelike_data") - distributor.observe( - "changed_presencelike_data", self.changed_presencelike_data - ) - - # outbound 
signal from the presence module to advertise when a user's - # presence has changed - distributor.declare("user_presence_changed") - - self.distributor = distributor - + self.store = hs.get_datastore() + self.wheel_timer = WheelTimer() + self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() self.federation.register_edu_handler( @@ -138,348 +109,574 @@ class PresenceHandler(BaseHandler): ) ) - # IN-MEMORY store, mapping local userparts to sets of local users to - # be informed of state changes. - self._local_pushmap = {} - # map local users to sets of remote /domain names/ who are interested - # in them - self._remote_sendmap = {} - # map remote users to sets of local users who're interested in them - self._remote_recvmap = {} - # list of (serial, set of(userids)) tuples, ordered by serial, latest - # first - self._remote_offline_serials = [] - - # map any user to a UserPresenceCache - self._user_cachemap = {} - self._user_cachemap_latest_serial = 0 - - # map room_ids to the latest presence serial for a member of that - # room - self._room_serials = {} - - metrics.register_callback( - "userCachemap:size", - lambda: len(self._user_cachemap), + distributor = hs.get_distributor() + distributor.observe("user_joined_room", self.user_joined_room) + + active_presence = self.store.take_presence_startup_info() + + # A dictionary of the current state of users. This is prefilled with + # non-offline presence from the DB. We should fetch from the DB if + # we can't find a users presence in here. 
+ self.user_to_current_state = { + state.user_id: state + for state in active_presence + } + + now = self.clock.time_msec() + for state in active_presence: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_active + IDLE_TIMER, + ) + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_user_sync + SYNC_ONLINE_TIMEOUT, + ) + if self.hs.is_mine_id(state.user_id): + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update + FEDERATION_PING_INTERVAL, + ) + else: + self.wheel_timer.insert( + now=now, + obj=state.user_id, + then=state.last_federation_update + FEDERATION_TIMEOUT, + ) + + # Set of users who have presence in the `user_to_current_state` that + # have not yet been persisted + self.unpersisted_users_changes = set() + + reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) + + self.serial_to_user = {} + self._next_serial = 1 + + # Keeps track of the number of *ongoing* syncs. While this is non zero + # a user will never go offline. + self.user_to_num_current_syncs = {} + + # Start a LoopingCall in 30s that fires every 5s. + # The initial delay is to allow disconnected clients a chance to + # reconnect before we treat them as offline. + self.clock.call_later( + 0 * 1000, + self.clock.looping_call, + self._handle_timeouts, + 5000, ) - def _get_or_make_usercache(self, user): - """If the cache entry doesn't exist, initialise a new one.""" - if user not in self._user_cachemap: - self._user_cachemap[user] = UserPresenceCache() - return self._user_cachemap[user] - - def _get_or_offline_usercache(self, user): - """If the cache entry doesn't exist, return an OFFLINE one but do not - store it into the cache.""" - if user in self._user_cachemap: - return self._user_cachemap[user] - else: - return UserPresenceCache() + @defer.inlineCallbacks + def _on_shutdown(self): + """Gets called when shutting down. This lets us persist any updates that + we haven't yet persisted, e.g. 
updates that only change some internal + timers. This allows changes to persist across startup without having to + persist every single change. + + If this does not run it simply means that some of the timers will fire + earlier than they should when synapse is restarted. The effect of this + is some spurious presence changes that will self-correct. + """ + logger.info( + "Performing _on_shutdown. Persiting %d unpersisted changes", + len(self.user_to_current_state) + ) - def registered_user(self, user): - return self.store.create_presence(user.localpart) + if self.unpersisted_users_changes: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in self.unpersisted_users_changes + ]) + logger.info("Finished _on_shutdown") @defer.inlineCallbacks - def is_presence_visible(self, observer_user, observed_user): - assert(self.hs.is_mine(observed_user)) + def _update_states(self, new_states): + """Updates presence of users. Sets the appropriate timeouts. Pokes + the notifier and federation if and only if the changed presence state + should be sent to clients/servers. + """ + now = self.clock.time_msec() - if observer_user == observed_user: - defer.returnValue(True) + # NOTE: We purposefully don't yield between now and when we've + # calculated what we want to do with the new states, to avoid races.
- if (yield self.store.user_rooms_intersect( - [u.to_string() for u in observer_user, observed_user])): - defer.returnValue(True) + to_notify = {} # Changes we want to notify everyone about + to_federation_ping = {} # These need sending keep-alives + for new_state in new_states: + user_id = new_state.user_id + prev_state = self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) - if (yield self.store.is_presence_visible( - observed_localpart=observed_user.localpart, - observer_userid=observer_user.to_string())): - defer.returnValue(True) + # If the users are ours then we want to set up a bunch of timers + # to time things out. + if self.hs.is_mine_id(user_id): + if new_state.state == PresenceState.ONLINE: + # Idle timer + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_active + IDLE_TIMER + ) - defer.returnValue(False) + if new_state.state != PresenceState.OFFLINE: + # User has stopped syncing + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_user_sync + SYNC_ONLINE_TIMEOUT + ) - @defer.inlineCallbacks - def get_state(self, target_user, auth_user, as_event=False, check_auth=True): - """Get the current presence state of the given user. + last_federate = new_state.last_federation_update + if now - last_federate > FEDERATION_PING_INTERVAL: + # Been a while since we've poked remote servers + new_state = new_state.copy_and_replace( + last_federation_update=now, + ) + to_federation_ping[user_id] = new_state - Args: - target_user (UserID): The user whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_user` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? 
+ else: + self.wheel_timer.insert( + now=now, + obj=user_id, + then=new_state.last_federation_update + FEDERATION_TIMEOUT + ) - Returns: - dict: The presence state of the `target_user`, whose format depends - on the `as_event` argument. - """ - if self.hs.is_mine(target_user): - if check_auth: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=target_user + if new_state.state == PresenceState.ONLINE: + currently_active = now - new_state.last_active < LAST_ACTIVE_GRANULARITY + new_state = new_state.copy_and_replace( + currently_active=currently_active, ) - if not visible: - raise SynapseError(404, "Presence information not visible") + # Check whether the change was something worth notifying about + if should_notify(prev_state, new_state): + new_state.copy_and_replace( + last_federation_update=now, + ) + to_notify[user_id] = new_state - if target_user in self._user_cachemap: - state = self._user_cachemap[target_user].get_state() - else: - state = yield self.store.get_presence_state(target_user.localpart) - if "mtime" in state: - del state["mtime"] - state["presence"] = state.pop("state") - else: - # TODO(paul): Have remote server send us permissions set - state = self._get_or_offline_usercache(target_user).get_state() + self.user_to_current_state[user_id] = new_state + + # TODO: We should probably ensure there are no races hereafter - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") + if to_notify: + yield self._persist_and_notify(to_notify.values()) + + self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes -= set(to_notify.keys()) + + to_federation_ping = { + user_id: state for user_id, state in to_federation_ping.items() + if user_id not in to_notify + } + if to_federation_ping: + _, _, hosts_to_states = yield self._get_interested_parties( + to_federation_ping.values() ) - if as_event: - content = state + 
self._push_to_remotes(hosts_to_states) - content["user_id"] = target_user.to_string() + def _handle_timeouts(self): + """Checks the presence of users that have timed out and updates as + appropriate. + """ + now = self.clock.time_msec() - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = self.wheel_timer.fetch(now) - defer.returnValue({"type": "m.presence", "content": content}) - else: - defer.returnValue(state) + changes = {} # Actual changes we need to notify people about - @defer.inlineCallbacks - def get_states(self, target_users, auth_user, as_event=False, check_auth=True): - """A batched version of the `get_state` method that accepts a list of - `target_users` + for user_id in set(users_to_check): + state = self.user_to_current_state.get(user_id, None) + if not state: + continue - Args: - target_users (list): The list of UserID's whose presence we want - auth_user (UserID): The user requesting the presence, used for - checking if said user is allowed to see the persence of the - `target_users` - as_event (bool): Format the return as an event or not? - check_auth (bool): Perform the auth checks or not? + if self.hs.is_mine_id(user_id): + if state.state == PresenceState.OFFLINE: + continue - Returns: - dict: A mapping from user -> presence_state + if state.state == PresenceState.ONLINE: + if now - state.last_active > IDLE_TIMER: + # Currently online, but last activity ages ago so auto + # idle + changes[user_id] = state.copy_and_replace( + state=PresenceState.UNAVAILABLE, + ) + elif now - state.last_active > LAST_ACTIVE_GRANULARITY: + # So that we send down a notification that we've + # stopped updating. 
+ changes[user_id] = state + + if now - state.last_federation_update > FEDERATION_PING_INTERVAL: + # Need to send ping to other servers to ensure they don't + # timeout and set us to offline + changes[user_id] = state + + # If there has been no sync for a while (and none ongoing), + # set presence to offline + if not self.user_to_num_current_syncs.get(user_id, 0): + if now - state.last_user_sync > SYNC_ONLINE_TIMEOUT: + changes[user_id] = state.copy_and_replace( + state=PresenceState.OFFLINE, + ) + else: + # We expect to be poked occasionally by the other side. + # This is to protect against forgetful/buggy servers, so that + # no one gets stuck online forever. + if now - state.last_federation_update > FEDERATION_TIMEOUT: + if state.state != PresenceState.OFFLINE: + # The other side seems to have disappeared. + changes[user_id] = state.copy_and_replace( + state=PresenceState.OFFLINE, + ) + + preserve_fn(self._update_states)(changes.values()) + + @defer.inlineCallbacks + def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. """ - local_users, remote_users = partitionbool( - target_users, - lambda u: self.hs.is_mine(u) - ) + user_id = user.to_string() - if check_auth: - for user in local_users: - visible = yield self.is_presence_visible( - observer_user=auth_user, - observed_user=user - ) + prev_state = yield self.current_state_for_user(user_id) - if not visible: - raise SynapseError(404, "Presence information not visible") + yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active=self.clock.time_msec(), + )]) - results = {} - if local_users: - for user in local_users: - if user in self._user_cachemap: - results[user] = self._user_cachemap[user].get_state() + @defer.inlineCallbacks + def user_syncing(self, user_id, affect_presence=True): + """Returns a context manager that should surround any stream requests + from the user.
- local_to_user = {u.localpart: u for u in local_users} + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. - states = yield self.store.get_presence_states( - [u.localpart for u in local_users if u not in results] - ) + Args: + user_id (str) + affect_presence (bool): If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + if affect_presence: + curr_sync = self.user_to_num_current_syncs.get(user_id, 0) + self.user_to_num_current_syncs[user_id] = curr_sync + 1 + + prev_state = yield self.current_state_for_user(user_id) + if prev_state.state == PresenceState.OFFLINE: + # If they're currently offline then bring them online, otherwise + # just update the last sync times. + yield self._update_states([prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active=self.clock.time_msec(), + last_user_sync=self.clock.time_msec(), + )]) + else: + yield self._update_states([prev_state.copy_and_replace( + last_user_sync=self.clock.time_msec(), + )]) - for local_part, state in states.items(): - if state is None: - continue - res = {"presence": state["state"]} - if "status_msg" in state and state["status_msg"]: - res["status_msg"] = state["status_msg"] - results[local_to_user[local_part]] = res - - for user in remote_users: - # TODO(paul): Have remote server send us permissions set - results[user] = self._get_or_offline_usercache(user).get_state() - - for state in results.values(): - if "last_active" in state: - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) + @defer.inlineCallbacks + def _end(): + if affect_presence: + self.user_to_num_current_syncs[user_id] -= 1 - if as_event: - for user, state in results.items(): - content = state - content["user_id"] = user.to_string() + prev_state = yield 
self.current_state_for_user(user_id) + yield self._update_states([prev_state.copy_and_replace( + last_user_sync=self.clock.time_msec(), + )]) - if "last_active" in content: - content["last_active_ago"] = int( - self._clock.time_msec() - content.pop("last_active") - ) + @contextmanager + def _user_syncing(): + try: + yield + finally: + preserve_fn(_end)() - results[user] = {"type": "m.presence", "content": content} + defer.returnValue(_user_syncing()) - defer.returnValue(results) + @defer.inlineCallbacks + def current_state_for_user(self, user_id): + """Get the current presence state for a user. + """ + res = yield self.current_state_for_users([user_id]) + defer.returnValue(res[user_id]) @defer.inlineCallbacks - @log_function - def set_state(self, target_user, auth_user, state): - # return - # TODO (erikj): Turn this back on. Why did we end up sending EDUs - # everywhere? + def current_state_for_users(self, user_ids): + """Get the current presence state for multiple users. - if not self.hs.is_mine(target_user): - raise SynapseError(400, "User is not hosted on this Home Server") + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. 
+ res = yield self.store.get_presence_for_users(missing) + states.update({state.user_id: state for state in res}) + + missing = [user_id for user_id, state in states.items() if not state] + if missing: + states.update({ + user_id: UserPresenceState.default(user_id) + for user_id in missing + }) - if target_user != auth_user: - raise AuthError(400, "Cannot set another user's presence") + defer.returnValue(states) - if "status_msg" not in state: - state["status_msg"] = None + @defer.inlineCallbacks + def _get_interested_parties(self, states): + """Given a list of states return which entities (rooms, users, servers) + are interested in the given states. - for k in state.keys(): - if k not in ("presence", "status_msg"): - raise SynapseError( - 400, "Unexpected presence state key '%s'" % (k,) - ) + Returns: + 3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`, + with each item being a dict of `entity_name` -> `[UserPresenceState]` + """ + room_ids_to_states = {} + users_to_states = {} + for state in states: + events = yield self.store.get_rooms_for_user(state.user_id) + for e in events: + room_ids_to_states.setdefault(e.room_id, []).append(state) + + plist = yield self.store.get_presence_list_observers_accepted(state.user_id) + for u in plist: + users_to_states.setdefault(u, []).append(state) + + # Always notify self + users_to_states.setdefault(state.user_id, []).append(state) + + hosts_to_states = {} + for room_id, states in room_ids_to_states.items(): + hosts = yield self.store.get_joined_hosts_for_room(room_id) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(states) - if state["presence"] not in self.STATE_LEVELS: - raise SynapseError(400, "'%s' is not a valid presence state" % ( - state["presence"], - )) + for user_id, states in users_to_states.items(): + host = UserID.from_string(user_id).domain + hosts_to_states.setdefault(host, []).extend(states) - logger.debug("Updating presence state of %s to %s", - target_user.localpart, 
state["presence"]) + # TODO: de-dup hosts_to_states, as a single host might have multiple + # of same presence - state_to_store = dict(state) - state_to_store["state"] = state_to_store.pop("presence") + defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states)) + + @defer.inlineCallbacks + def _persist_and_notify(self, states): + """Persist states in the database, poke the notifier and send to + interested remote servers + """ + stream_id, max_token = yield self.store.update_presence(states) - statuscache = self._get_or_offline_usercache(target_user) - was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] - now_level = self.STATE_LEVELS[state["presence"]] + parties = yield self._get_interested_parties(states) + room_ids_to_states, users_to_states, hosts_to_states = parties - yield self.store.set_presence_state( - target_user.localpart, state_to_store + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] ) - yield collect_presencelike_data(self.distributor, target_user, state) - if now_level > was_level: - state["last_active"] = self.clock.time_msec() + self._push_to_remotes(hosts_to_states) + + def _push_to_remotes(self, hosts_to_states): + """Sends state updates to remote servers. + + Args: + hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` + """ + now = self.clock.time_msec() + for host, states in hosts_to_states.items(): + self.federation.send_edu( + destination=host, + edu_type="m.presence", + content={ + "push": [ + _format_user_presence_state(state, now) + for state in states + ] + } + ) + + @defer.inlineCallbacks + def incoming_presence(self, origin, content): + """Called when we receive a `m.presence` EDU from a remote server. + """ + now = self.clock.time_msec() + updates = [] + for push in content.get("push", []): + # A "push" contains a list of presence that we are probably interested + # in. 
+ # TODO: Actually check if we're interested, rather than blindly + # accepting presence updates. + user_id = push.get("user_id", None) + if not user_id: + logger.info( + "Got presence update from %r with no 'user_id': %r", + origin, push, + ) + continue - now_online = state["presence"] != PresenceState.OFFLINE - was_polling = target_user in self._user_cachemap + presence_state = push.get("presence", None) + if not presence_state: + logger.info( + "Got presence update from %r with no 'presence_state': %r", + origin, push, + ) + continue - if now_online and not was_polling: - yield self.start_polling_presence(target_user, state=state) - elif not now_online and was_polling: - yield self.stop_polling_presence(target_user) + new_fields = { + "state": presence_state, + "last_federation_update": now, + } - # TODO(paul): perform a presence push as part of start/stop poll so - # we don't have to do this all the time - yield self.changed_presencelike_data(target_user, state) + last_active_ago = push.get("last_active_ago", None) + if last_active_ago is not None: + new_fields["last_active"] = now - last_active_ago - def bump_presence_active_time(self, user, now=None): - if now is None: - now = self.clock.time_msec() + new_fields["status_msg"] = push.get("status_msg", None) - prev_state = self._get_or_make_usercache(user) - if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY: - return + prev_state = yield self.current_state_for_user(user_id) + updates.append(prev_state.copy_and_replace(**new_fields)) - with PreserveLoggingContext(): - self.changed_presencelike_data(user, {"last_active": now}) + if updates: + yield self._update_states(updates) - def get_joined_rooms_for_user(self, user): - """Get the list of rooms a user is joined to. 
+ @defer.inlineCallbacks + def get_state(self, target_user, as_event=False): + results = yield self.get_states( + [target_user.to_string()], + as_event=as_event, + ) + + defer.returnValue(results[0]) + + @defer.inlineCallbacks + def get_states(self, target_user_ids, as_event=False): + """Get the presence state for users. Args: - user(UserID): The user. + target_user_ids (list) + as_event (bool): Whether to format it as a client event or not. + Returns: - A Deferred of a list of room id strings. + list """ - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_joined_rooms_for_user(user) - def get_joined_users_for_room_id(self, room_id): - rm_handler = self.homeserver.get_handlers().room_member_handler - return rm_handler.get_room_members(room_id) + updates = yield self.current_state_for_users(target_user_ids) + updates = updates.values() - @defer.inlineCallbacks - def changed_presencelike_data(self, user, state): - """Updates the presence state of a local user. + for user_id in set(target_user_ids) - set(u.user_id for u in updates): + updates.append(UserPresenceState.default(user_id)) - Args: - user(UserID): The user being updated. - state(dict): The new presence state for the user. - Returns: - A Deferred + now = self.clock.time_msec() + if as_event: + defer.returnValue([ + { + "type": "m.presence", + "content": _format_user_presence_state(state, now), + } + for state in updates + ]) + else: + defer.returnValue([ + _format_user_presence_state(state, now) for state in updates + ]) + + @defer.inlineCallbacks + def set_state(self, target_user, state): + """Set the presence state of the user. 
""" - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache(user, state) - yield self.push_presence(user, statuscache=statuscache) + status_msg = state.get("status_msg", None) + presence = state["presence"] - @log_function - def started_user_eventstream(self, user): - # TODO(paul): Use "last online" state - return self.set_state(user, user, {"presence": PresenceState.ONLINE}) + user_id = target_user.to_string() - @log_function - def stopped_user_eventstream(self, user): - # TODO(paul): Save current state as "last online" state - return self.set_state(user, user, {"presence": PresenceState.OFFLINE}) + prev_state = yield self.current_state_for_user(user_id) + + new_fields = { + "state": presence, + "status_msg": status_msg + } + + if presence == PresenceState.ONLINE: + new_fields["last_active"] = self.clock.time_msec() + + yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @defer.inlineCallbacks def user_joined_room(self, user, room_id): - """Called via the distributor whenever a user joins a room. - Notifies the new member of the presence of the current members. - Notifies the current members of the room of the new member's presence. - - Args: - user(UserID): The user who joined the room. - room_id(str): The room id the user joined. + """Called (via the distributor) when a user joins a room. This funciton + sends presence updates to servers, either: + 1. the joining user is a local user and we send their presence to + all servers in the room. + 2. the joining user is a remote user and so we send presence for all + local users in the room. """ + # We only need to send presence to servers that don't have it yet. We + # don't need to send to local clients here, as that is done as part + # of the event stream/sync. + # TODO: Only send to servers not already in the room. 
if self.hs.is_mine(user): - # No actual update but we need to bump the serial anyway for the - # event source - self._user_cachemap_latest_serial += 1 - statuscache = yield self.update_presence_cache( - user, room_ids=[room_id] - ) - self.push_update_to_local_and_remote( - observed_user=user, - room_ids=[room_id], - statuscache=statuscache, - ) + state = yield self.current_state_for_user(user.to_string()) - # We also want to tell them about current presence of people. - curr_users = yield self.get_joined_users_for_room_id(room_id) + hosts = yield self.store.get_joined_hosts_for_room(room_id) + self._push_to_remotes({host: (state,) for host in hosts}) + else: + user_ids = yield self.store.get_users_in_room(room_id) + user_ids = filter(self.hs.is_mine_id, user_ids) - for local_user in [c for c in curr_users if self.hs.is_mine(c)]: - statuscache = yield self.update_presence_cache( - local_user, room_ids=[room_id], add_to_cache=False - ) + states = yield self.current_state_for_users(user_ids) - with PreserveLoggingContext(): - self.push_update_to_local_and_remote( - observed_user=local_user, - users_to_push=[user], - statuscache=statuscache, - ) + self._push_to_remotes({user.domain: states.values()}) @defer.inlineCallbacks - def send_presence_invite(self, observer_user, observed_user): - """Request the presence of a local or remote user for a local user""" + def get_presence_list(self, observer_user, accepted=None): + """Returns the presence for all users in their presence list. 
+ """ if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") + presence_list = yield self.store.get_presence_list( + observer_user.localpart, accepted=accepted + ) + + results = yield self.get_states( + target_user_ids=[row["observed_user_id"] for row in presence_list], + as_event=False, + ) + + is_accepted = { + row["observed_user_id"]: row["accepted"] for row in presence_list + } + + for result in results: + result.update({ + "accepted": is_accepted, + }) + + defer.returnValue(results) + + @defer.inlineCallbacks + def send_presence_invite(self, observer_user, observed_user): + """Sends a presence invite. + """ yield self.store.add_presence_list_pending( observer_user.localpart, observed_user.to_string() ) @@ -496,60 +693,41 @@ class PresenceHandler(BaseHandler): } ) - @defer.inlineCallbacks - def _should_accept_invite(self, observed_user, observer_user): - if not self.hs.is_mine(observed_user): - defer.returnValue(False) - - row = yield self.store.has_presence_state(observed_user.localpart) - if not row: - defer.returnValue(False) - - # TODO(paul): Eventually we'll ask the user's permission for this - # before accepting. For now just accept any invite request - defer.returnValue(True) - @defer.inlineCallbacks def invite_presence(self, observed_user, observer_user): - """Handles a m.presence_invite EDU. A remote or local user has - requested presence updates for a local user. If the invite is accepted - then allow the local or remote user to see the presence of the local - user. - - Args: - observed_user(UserID): The local user whose presence is requested. - observer_user(UserID): The remote or local user requesting presence. + """Handles new presence invites. 
""" - accept = yield self._should_accept_invite(observed_user, observer_user) - - if accept: - yield self.store.allow_presence_visible( - observed_user.localpart, observer_user.to_string() - ) + if not self.hs.is_mine(observed_user): + raise SynapseError(400, "User is not hosted on this Home Server") + # TODO: Don't auto accept if self.hs.is_mine(observer_user): - if accept: - yield self.accept_presence(observed_user, observer_user) - else: - yield self.deny_presence(observed_user, observer_user) + yield self.accept_presence(observed_user, observer_user) else: - edu_type = "m.presence_accept" if accept else "m.presence_deny" - - yield self.federation.send_edu( + self.federation.send_edu( destination=observer_user.domain, - edu_type=edu_type, + edu_type="m.presence_accept", content={ "observed_user": observed_user.to_string(), "observer_user": observer_user.to_string(), } ) + state_dict = yield self.get_state(observed_user, as_event=False) + + self.federation.send_edu( + destination=observer_user.domain, + edu_type="m.presence", + content={ + "push": [state_dict] + } + ) + @defer.inlineCallbacks def accept_presence(self, observed_user, observer_user): """Handles a m.presence_accept EDU. Mark a presence invite from a local or remote user as accepted in a local user's presence list. Starts polling for presence updates from the local or remote user. - Args: observed_user(UserID): The user to update in the presence list. observer_user(UserID): The owner of the presence list to update. @@ -558,15 +736,10 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - yield self.start_polling_presence( - observer_user, target_user=observed_user - ) - @defer.inlineCallbacks def deny_presence(self, observed_user, observer_user): """Handle a m.presence_deny EDU. Removes a local or remote user from a local user's presence list. - Args: observed_user(UserID): The local or remote user to remove from the list. 
@@ -584,7 +757,6 @@ class PresenceHandler(BaseHandler): def drop(self, observed_user, observer_user): """Remove a local or remote user from a local user's presence list and unsubscribe the local user from updates that user. - Args: observed_user(UserId): The local or remote user to remove from the list. @@ -599,710 +771,138 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - self.stop_polling_presence( - observer_user, target_user=observed_user - ) - - @defer.inlineCallbacks - def get_presence_list(self, observer_user, accepted=None): - """Get the presence list for a local user. The retured list includes - the current presence state for each user listed. - - Args: - observer_user(UserID): The local user whose presence list to fetch. - accepted(bool or None): If not none then only include users who - have or have not accepted the presence invite request. - Returns: - A Deferred list of presence state events. - """ - if not self.hs.is_mine(observer_user): - raise SynapseError(400, "User is not hosted on this Home Server") - - presence_list = yield self.store.get_presence_list( - observer_user.localpart, accepted=accepted - ) - - results = [] - for row in presence_list: - observed_user = UserID.from_string(row["observed_user_id"]) - result = { - "observed_user": observed_user, "accepted": row["accepted"] - } - result.update( - self._get_or_offline_usercache(observed_user).get_state() - ) - if "last_active" in result: - result["last_active_ago"] = int( - self.clock.time_msec() - result.pop("last_active") - ) - results.append(result) - - defer.returnValue(results) - - @defer.inlineCallbacks - @log_function - def start_polling_presence(self, user, target_user=None, state=None): - """Subscribe a local user to presence updates from a local or remote - user. If no target_user is supplied then subscribe to all users stored - in the presence list for the local user. 
- - Additonally this pushes the current presence state of this user to all - target_users. That state can be provided directly or will be read from - the stored state for the local user. - - Also this attempts to notify the local user of the current state of - any local target users. - - Args: - user(UserID): The local user that whishes for presence updates. - target_user(UserID): The local or remote user whose updates are - wanted. - state(dict): Optional presence state for the local user. - """ - logger.debug("Start polling for presence from %s", user) - - if target_user: - target_users = set([target_user]) - room_ids = [] - else: - presence = yield self.store.get_presence_list( - user.localpart, accepted=True - ) - target_users = set([ - UserID.from_string(x["observed_user_id"]) for x in presence - ]) - - # Also include people in all my rooms - - room_ids = yield self.get_joined_rooms_for_user(user) - - if state is None: - state = yield self.store.get_presence_state(user.localpart) - else: - # statuscache = self._get_or_make_usercache(user) - # self._user_cachemap_latest_serial += 1 - # statuscache.update(state, self._user_cachemap_latest_serial) - pass - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=target_users, - room_ids=room_ids, - statuscache=self._get_or_make_usercache(user), - ) - - for target_user in target_users: - if self.hs.is_mine(target_user): - self._start_polling_local(user, target_user) - - # We want to tell the person that just came online - # presence state of people they are interested in? 
- self.push_update_to_clients( - users_to_push=[user], - ) - - deferreds = [] - remote_users = [u for u in target_users if not self.hs.is_mine(u)] - remoteusers_by_domain = partition(remote_users, lambda u: u.domain) - # Only poll for people in our get_presence_list - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append(self._start_polling_remote( - user, domain, remoteusers - )) - - yield defer.DeferredList(deferreds, consumeErrors=True) - - def _start_polling_local(self, user, target_user): - """Subscribe a local user to presence updates for a local user - - Args: - user(UserId): The local user that wishes for updates. - target_user(UserId): The local users whose updates are wanted. - """ - target_localpart = target_user.localpart - - if target_localpart not in self._local_pushmap: - self._local_pushmap[target_localpart] = set() - - self._local_pushmap[target_localpart].add(user) - - def _start_polling_remote(self, user, domain, remoteusers): - """Subscribe a local user to presence updates for remote users on a - given remote domain. - - Args: - user(UserID): The local user that wishes for updates. - domain(str): The remote server the local user wants updates from. - remoteusers(UserID): The remote users that local user wants to be - told about. - Returns: - A Deferred. - """ - to_poll = set() - - for u in remoteusers: - if u not in self._remote_recvmap: - self._remote_recvmap[u] = set() - to_poll.add(u) - - self._remote_recvmap[u].add(user) - - if not to_poll: - return defer.succeed(None) - - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"poll": [u.to_string() for u in to_poll]} - ) - - @log_function - def stop_polling_presence(self, user, target_user=None): - """Unsubscribe a local user from presence updates from a local or - remote user. If no target user is supplied then unsubscribe the user - from all presence updates that the user had subscribed to. 
- - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID or None): The user whose updates are no longer - wanted. - Returns: - A Deferred. - """ - logger.debug("Stop polling for presence from %s", user) - - if not target_user or self.hs.is_mine(target_user): - self._stop_polling_local(user, target_user=target_user) - - deferreds = [] - - if target_user: - if target_user not in self._remote_recvmap: - return - target_users = set([target_user]) - else: - target_users = self._remote_recvmap.keys() - - remoteusers = [u for u in target_users - if user in self._remote_recvmap[u]] - remoteusers_by_domain = partition(remoteusers, lambda u: u.domain) - - for domain in remoteusers_by_domain: - remoteusers = remoteusers_by_domain[domain] - - deferreds.append( - self._stop_polling_remote(user, domain, remoteusers) - ) - - return defer.DeferredList(deferreds, consumeErrors=True) - - def _stop_polling_local(self, user, target_user): - """Unsubscribe a local user from presence updates from a local user on - this server. - - Args: - user(UserID): The local user that no longer wishes for updates. - target_user(UserID): The user whose updates are no longer wanted. - """ - for localpart in self._local_pushmap.keys(): - if target_user and localpart != target_user.localpart: - continue - - if user in self._local_pushmap[localpart]: - self._local_pushmap[localpart].remove(user) - - if not self._local_pushmap[localpart]: - del self._local_pushmap[localpart] - - @log_function - def _stop_polling_remote(self, user, domain, remoteusers): - """Unsubscribe a local user from presence updates from remote users on - a given domain. - - Args: - user(UserID): The local user that no longer wishes for updates. - domain(str): The remote server to unsubscribe from. - remoteusers([UserID]): The users on that remote server that the - local user no longer wishes to be updated about. - Returns: - A Deferred. 
- """ - to_unpoll = set() - - for u in remoteusers: - self._remote_recvmap[u].remove(user) - - if not self._remote_recvmap[u]: - del self._remote_recvmap[u] - to_unpoll.add(u) - - if not to_unpoll: - return defer.succeed(None) - - return self.federation.send_edu( - destination=domain, - edu_type="m.presence", - content={"unpoll": [u.to_string() for u in to_unpoll]} - ) - - @defer.inlineCallbacks - @log_function - def push_presence(self, user, statuscache): - """ - Notify local and remote users of a change in presence of a local user. - Pushes the update to local clients and remote domains that are directly - subscribed to the presence of the local user. - Also pushes that update to any local user or remote domain that shares - a room with the local user. - - Args: - user(UserID): The local user whose presence was updated. - statuscache(UserPresenceCache): Cache of the user's presence state - Returns: - A Deferred. - """ - assert(self.hs.is_mine(user)) - - logger.debug("Pushing presence update from %s", user) - - localusers = set(self._local_pushmap.get(user.localpart, set())) - remotedomains = set(self._remote_sendmap.get(user.localpart, set())) - - # Reflect users' status changes back to themselves, so UIs look nice - # and also user is informed of server-forced pushes - localusers.add(user) - - room_ids = yield self.get_joined_rooms_for_user(user) - - if not localusers and not room_ids: - defer.returnValue(None) - - yield self.push_update_to_local_and_remote( - observed_user=user, - users_to_push=localusers, - remote_domains=remotedomains, - room_ids=room_ids, - statuscache=statuscache, - ) - yield user_presence_changed(self.distributor, user, statuscache) - - @defer.inlineCallbacks - def incoming_presence(self, origin, content): - """Handle an incoming m.presence EDU. - For each presence update in the "push" list update our local cache and - notify the appropriate local clients. 
Only clients that share a room - or are directly subscribed to the presence for a user should be - notified of the update. - For each subscription request in the "poll" list start pushing presence - updates to the remote server. - For unsubscribe request in the "unpoll" list stop pushing presence - updates to the remote server. - - Args: - orgin(str): The source of this m.presence EDU. - content(dict): The content of this m.presence EDU. - Returns: - A Deferred. - """ - deferreds = [] - - for push in content.get("push", []): - user = UserID.from_string(push["user_id"]) - - logger.debug("Incoming presence update from %s", user) - - observers = set(self._remote_recvmap.get(user, set())) - if observers: - logger.debug( - " | %d interested local observers %r", len(observers), observers - ) - - room_ids = yield self.get_joined_rooms_for_user(user) - if room_ids: - logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) - - state = dict(push) - del state["user_id"] - - if "presence" not in state: - logger.warning( - "Received a presence 'push' EDU from %s without a" - " 'presence' key", origin - ) - continue - - if "last_active_ago" in state: - state["last_active"] = int( - self.clock.time_msec() - state.pop("last_active_ago") - ) - - self._user_cachemap_latest_serial += 1 - yield self.update_presence_cache(user, state, room_ids=room_ids) - - if not observers and not room_ids: - logger.debug(" | no interested observers or room IDs") - continue - - self.push_update_to_clients( - users_to_push=observers, room_ids=room_ids - ) - - user_id = user.to_string() - - if state["presence"] == PresenceState.OFFLINE: - self._remote_offline_serials.insert( - 0, - (self._user_cachemap_latest_serial, set([user_id])) - ) - while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS: - self._remote_offline_serials.pop() # remove the oldest - if user in self._user_cachemap: - del self._user_cachemap[user] - else: - # Remove the user from remote_offline_serials now that 
they're - # no longer offline - for idx, elem in enumerate(self._remote_offline_serials): - (_, user_ids) = elem - user_ids.discard(user_id) - if not user_ids: - self._remote_offline_serials.pop(idx) - - for poll in content.get("poll", []): - user = UserID.from_string(poll) - - if not self.hs.is_mine(user): - continue - - # TODO(paul) permissions checks - - if user not in self._remote_sendmap: - self._remote_sendmap[user] = set() - - self._remote_sendmap[user].add(origin) - - deferreds.append(self._push_presence_remote(user, origin)) - - for unpoll in content.get("unpoll", []): - user = UserID.from_string(unpoll) - - if not self.hs.is_mine(user): - continue - - if user in self._remote_sendmap: - self._remote_sendmap[user].remove(origin) - - if not self._remote_sendmap[user]: - del self._remote_sendmap[user] - - yield defer.DeferredList(deferreds, consumeErrors=True) - - @defer.inlineCallbacks - def update_presence_cache(self, user, state={}, room_ids=None, - add_to_cache=True): - """Update the presence cache for a user with a new state and bump the - serial to the latest value. - - Args: - user(UserID): The user being updated - state(dict): The presence state being updated - room_ids(None or list of str): A list of room_ids to update. If - room_ids is None then fetch the list of room_ids the user is - joined to. - add_to_cache: Whether to add an entry to the presence cache if the - user isn't already in the cache. - Returns: - A Deferred UserPresenceCache for the user being updated. - """ - if room_ids is None: - room_ids = yield self.get_joined_rooms_for_user(user) - - for room_id in room_ids: - self._room_serials[room_id] = self._user_cachemap_latest_serial - if add_to_cache: - statuscache = self._get_or_make_usercache(user) - else: - statuscache = self._get_or_offline_usercache(user) - statuscache.update(state, serial=self._user_cachemap_latest_serial) - defer.returnValue(statuscache) + # TODO: Inform the remote that we've dropped the presence list. 
@defer.inlineCallbacks - def push_update_to_local_and_remote(self, observed_user, statuscache, - users_to_push=[], room_ids=[], - remote_domains=[]): - """Notify local clients and remote servers of a change in the presence - of a user. - - Args: - observed_user(UserID): The user to push the presence state for. - statuscache(UserPresenceCache): The cache for the presence state to - push. - users_to_push([UserID]): A list of local and remote users to - notify. - room_ids([str]): Notify the local and remote occupants of these - rooms. - remote_domains([str]): A list of remote servers to notify in - addition to those implied by the users_to_push and the - room_ids. - Returns: - A Deferred. - """ + def is_visible(self, observed_user, observer_user): + observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string()) + observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string()) - localusers, remoteusers = partitionbool( - users_to_push, - lambda u: self.hs.is_mine(u) - ) + observer_room_ids = set(r.room_id for r in observer_rooms) + observed_room_ids = set(r.room_id for r in observed_rooms) - localusers = set(localusers) + if observer_room_ids & observed_room_ids: + defer.returnValue(True) - self.push_update_to_clients( - users_to_push=localusers, room_ids=room_ids + accepted_observers = yield self.store.get_presence_list_observers_accepted( + observed_user.to_string() ) - remote_domains = set(remote_domains) - remote_domains |= set([r.domain for r in remoteusers]) - for room_id in room_ids: - remote_domains.update( - (yield self.store.get_joined_hosts_for_room(room_id)) - ) + defer.returnValue(observer_user.to_string() in accepted_observers) - remote_domains.discard(self.hs.hostname) - - deferreds = [] - for domain in remote_domains: - logger.debug(" | push to remote domain %s", domain) - deferreds.append( - self._push_presence_remote( - observed_user, domain, state=statuscache.get_state() - ) - ) - yield defer.DeferredList(deferreds, 
consumeErrors=True) +def should_notify(old_state, new_state): + """Decides if a presence state change should be sent to interested parties. + """ + if old_state.status_msg != new_state.status_msg: + return True - defer.returnValue((localusers, remote_domains)) + if old_state.state == PresenceState.ONLINE: + if new_state.state != PresenceState.ONLINE: + # Always notify for online -> anything + return True - def push_update_to_clients(self, users_to_push=[], room_ids=[]): - """Notify clients of a new presence event. + if new_state.currently_active != old_state.currently_active: + return True - Args: - users_to_push([UserID]): List of users to notify. - room_ids([str]): List of room_ids to notify. - """ - with PreserveLoggingContext(): - self.notifier.on_new_event( - "presence_key", - self._user_cachemap_latest_serial, - users_to_push, - room_ids, - ) + if new_state.last_active - old_state.last_active > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + return True - @defer.inlineCallbacks - def _push_presence_remote(self, user, destination, state=None): - """Push a user's presence to a remote server. If a presence state event - that event is sent. Otherwise a new state event is constructed from the - stored presence state. - The last_active is replaced with last_active_ago in case the wallclock - time on the remote server is different to the time on this server. - Sends an EDU to the remote server with the current presence state. + if old_state.state != new_state.state: + # Nothing to report. + return True - Args: - user(UserID): The user to push the presence state for. - destination(str): The remote server to send state to. - state(dict): The state to push, or None to use the current stored - state. - Returns: - A Deferred. 
- """ - if state is None: - state = yield self.store.get_presence_state(user.localpart) - del state["mtime"] - state["presence"] = state.pop("state") - - if user in self._user_cachemap: - state["last_active"] = ( - self._user_cachemap[user].get_state()["last_active"] - ) + return False - yield collect_presencelike_data(self.distributor, user, state) - if "last_active" in state: - state = dict(state) - state["last_active_ago"] = int( - self.clock.time_msec() - state.pop("last_active") - ) - - user_state = {"user_id": user.to_string(), } - user_state.update(state) +def _format_user_presence_state(state, now): + """Convert UserPresenceState to a format that can be sent down to clients + and to other servers. + """ + content = { + "presence": state.state, + "user_id": state.user_id, + } + if state.last_active: + content["last_active_ago"] = now - state.last_active + if state.status_msg and state.state != PresenceState.OFFLINE: + content["status_msg"] = state.status_msg + if state.state == PresenceState.ONLINE: + content["currently_active"] = state.currently_active - yield self.federation.send_edu( - destination=destination, - edu_type="m.presence", - content={"push": [user_state, ], } - ) + return content class PresenceEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() + self.store = hs.get_datastore() @defer.inlineCallbacks @log_function - def get_new_events(self, user, from_key, room_ids=None, **kwargs): - from_key = int(from_key) + def get_new_events(self, user, from_key, room_ids=None, include_offline=True, + **kwargs): + # The process for getting presence events are: + # 1. Get the rooms the user is in. + # 2. Get the list of user in the rooms. + # 3. Get the list of users that are in the user's presence list. + # 4. If there is a from_key set, cross reference the list of users + # with the `presence_stream_cache` to see which ones we actually + # need to check. + # 5. Load current state for the users. 
+ # + # We don't try and limit the presence updates by the current token, as + # sending down the rare duplicate is not a concern. + + user_id = user.to_string() + if from_key is not None: + from_key = int(from_key) room_ids = room_ids or [] presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - max_serial = presence._user_cachemap_latest_serial - - clock = self.clock - latest_serial = 0 - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list - ) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] > from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - cached = cachemap[observed_user] - - if cached.serial <= from_key or cached.serial > max_serial: - continue - - latest_serial = max(cached.serial, latest_serial) - updates.append(cached.make_event(user=observed_user, clock=clock)) + if not room_ids: + rooms = yield self.store.get_rooms_for_user(user_id) + room_ids = set(e.room_id for e in rooms) - # TODO(paul): limit - - for serial, user_ids in presence._remote_offline_serials: - if serial <= from_key: - break - - if serial > max_serial: - continue - - latest_serial = max(latest_serial, serial) - for u in user_ids: - updates.append({ - "type": "m.presence", - "content": {"user_id": u, "presence": PresenceState.OFFLINE}, - }) - # TODO(paul): For the v2 API we want to tell the client their from_key - # is too old if we fell off the end of the _remote_offline_serials - # list, and get them to invalidate+resync. In v1 we have no such - # concept so this is a best-effort result. 
- - if updates: - defer.returnValue((updates, latest_serial)) - else: - defer.returnValue(([], presence._user_cachemap_latest_serial)) - - def get_current_key(self): - presence = self.hs.get_handlers().presence_handler - return presence._user_cachemap_latest_serial + user_ids_to_check = set() + for room_id in room_ids: + users = yield self.store.get_users_in_room(room_id) + user_ids_to_check.update(users) - @defer.inlineCallbacks - def get_pagination_rows(self, user, pagination_config, key): - # TODO (erikj): Does this make sense? Ordering? + plist = yield self.store.get_presence_list_accepted(user.localpart) + user_ids_to_check.update([row["observed_user_id"] for row in plist]) - from_key = int(pagination_config.from_key) + # Always include yourself. Only really matters for when the user is + # not in any rooms, but still. + user_ids_to_check.add(user_id) - if pagination_config.to_key: - to_key = int(pagination_config.to_key) - else: - to_key = -1 + max_token = self.store.get_current_presence_token() - presence = self.hs.get_handlers().presence_handler - cachemap = presence._user_cachemap - - user_ids_to_check = {user} - presence_list = yield presence.store.get_presence_list( - user.localpart, accepted=True - ) - if presence_list is not None: - user_ids_to_check |= set( - UserID.from_string(p["observed_user_id"]) for p in presence_list + if from_key: + user_ids_changed = self.store.presence_stream_cache.get_entities_changed( + user_ids_to_check, from_key, ) - room_ids = yield presence.get_joined_rooms_for_user(user) - for room_id in set(room_ids) & set(presence._room_serials): - if presence._room_serials[room_id] >= from_key: - joined = yield presence.get_joined_users_for_room_id(room_id) - user_ids_to_check |= set(joined) - - updates = [] - for observed_user in user_ids_to_check & set(cachemap): - if not (to_key < cachemap[observed_user].serial <= from_key): - continue - - updates.append((observed_user, cachemap[observed_user])) - - # TODO(paul): limit - - if 
updates: - clock = self.clock - - earliest_serial = max([x[1].serial for x in updates]) - data = [x[1].make_event(user=x[0], clock=clock) for x in updates] - - defer.returnValue((data, earliest_serial)) else: - defer.returnValue(([], 0)) - + user_ids_changed = user_ids_to_check -class UserPresenceCache(object): - """Store an observed user's state and status message. + updates = yield presence.current_state_for_users(user_ids_changed) - Includes the update timestamp. - """ - def __init__(self): - self.state = {"presence": PresenceState.OFFLINE} - self.serial = None - - def __repr__(self): - return "UserPresenceCache(state=%r, serial=%r)" % ( - self.state, self.serial - ) - - def update(self, state, serial): - assert("mtime_age" not in state) + now = self.clock.time_msec() - self.state.update(state) - # Delete keys that are now 'None' - for k in self.state.keys(): - if self.state[k] is None: - del self.state[k] - - self.serial = serial - - if "status_msg" in state: - self.status_msg = state["status_msg"] - else: - self.status_msg = None - - def get_state(self): - # clone it so caller can't break our cache - state = dict(self.state) - return state - - def make_event(self, user, clock): - content = self.get_state() - content["user_id"] = user.to_string() + defer.returnValue(([ + { + "type": "m.presence", + "content": _format_user_presence_state(s, now), + } + for s in updates.values() + if include_offline or s.state != PresenceState.OFFLINE + ], max_token)) - if "last_active" in content: - content["last_active_ago"] = int( - clock.time_msec() - content.pop("last_active") - ) + def get_current_key(self): + return self.store.get_current_presence_token() - return {"type": "m.presence", "content": content} + def get_pagination_rows(self, user, pagination_config, key): + return self.get_new_events(user, from_key=None, include_offline=False) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 629e6e3594..7084a7396f 100644 --- 
a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -49,6 +49,9 @@ class ProfileHandler(BaseHandler): distributor = hs.get_distributor() self.distributor = distributor + distributor.declare("collect_presencelike_data") + distributor.declare("changed_presencelike_data") + distributor.observe("registered_user", self.registered_user) distributor.observe( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1d0f0058a2..c5c13e085b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -582,6 +582,28 @@ class SyncHandler(BaseHandler): if room_sync: joined.append(room_sync) + # For each newly joined room, we want to send down presence of + # existing users. + presence_handler = self.hs.get_handlers().presence_handler + extra_presence_users = set() + for room_id in newly_joined_rooms: + users = yield self.store.get_users_in_room(event.room_id) + extra_presence_users.update(users) + + # For each new member, send down presence. + for joined_sync in joined: + it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + extra_presence_users.add(event.state_key) + + states = yield presence_handler.get_states( + [u for u in extra_presence_users if u != user_id], + as_event=True, + ) + presence.extend(states) + account_data_for_user = sync_config.filter_collection.filter_account_data( self.account_data_for_user(account_data) ) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index a6f8754e32..27ea5f2a43 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -17,7 +17,7 @@ """ from twisted.internet import defer -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, AuthError from synapse.types import UserID from .base import ClientV1RestServlet, client_path_patterns @@ -35,8 +35,15 @@ class 
PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - state = yield self.handlers.presence_handler.get_state( - target_user=user, auth_user=requester.user) + if requester.user != user: + allowed = yield self.handlers.presence_handler.is_visible( + observed_user=user, observer_user=requester.user, + ) + + if not allowed: + raise AuthError(403, "You are allowed to see their presence.") + + state = yield self.handlers.presence_handler.get_state(target_user=user) defer.returnValue((200, state)) @@ -45,6 +52,9 @@ class PresenceStatusRestServlet(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) + if requester.user != user: + raise AuthError(403, "Can only set your own presence state") + state = {} try: content = json.loads(request.content.read()) @@ -63,8 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): except: raise SynapseError(400, "Unable to parse state") - yield self.handlers.presence_handler.set_state( - target_user=user, auth_user=requester.user, state=state) + yield self.handlers.presence_handler.set_state(user, state) defer.returnValue((200, {})) @@ -87,11 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet): raise SynapseError(400, "Cannot get another user's presence list") presence = yield self.handlers.presence_handler.get_presence_list( - observer_user=user, accepted=True) - - for p in presence: - observed_user = p.pop("observed_user") - p["user_id"] = observed_user.to_string() + observer_user=user, accepted=True + ) defer.returnValue((200, presence)) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 24706f9383..a8e89c7fe9 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -304,18 +304,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet): if event["type"] != EventTypes.Member: continue chunk.append(event) - # FIXME: should 
probably be state_key here, not user_id - target_user = UserID.from_string(event["user_id"]) - # Presence is an optional cache; don't fail if we can't fetch it - try: - presence_handler = self.handlers.presence_handler - presence_state = yield presence_handler.get_state( - target_user=target_user, - auth_user=requester.user, - ) - event["content"].update(presence_state) - except: - pass defer.returnValue((200, { "chunk": chunk @@ -541,6 +529,10 @@ class RoomTypingRestServlet(ClientV1RestServlet): "/rooms/(?P[^/]*)/typing/(?P[^/]*)$" ) + def __init__(self, hs): + super(RoomTypingRestServlet, self).__init__(hs) + self.presence_handler = hs.get_handlers().presence_handler + @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): requester = yield self.auth.get_user_by_req(request) @@ -552,6 +544,8 @@ class RoomTypingRestServlet(ClientV1RestServlet): typing_handler = self.handlers.typing_notification_handler + yield self.presence_handler.bump_presence_active_time(requester.user) + if content["typing"]: yield typing_handler.started_typing( target_user=target_user, diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index eb4b369a3d..b831d8c95e 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_handlers().receipts_handler + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_POST(self, request, room_id, receipt_type, event_id): @@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") + yield self.presence_handler.bump_presence_active_time(requester.user) + yield self.receipts_handler.received_client_receipt( room_id, receipt_type, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py 
index accbc6cfac..de4a020ad4 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,6 +25,7 @@ from synapse.events.utils import ( ) from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION from synapse.api.errors import SynapseError +from synapse.api.constants import PresenceState from ._base import client_v2_patterns import copy @@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet): self.sync_handler = hs.get_handlers().sync_handler self.clock = hs.get_clock() self.filtering = hs.get_filtering() + self.presence_handler = hs.get_handlers().presence_handler @defer.inlineCallbacks def on_GET(self, request): @@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet): else: since_token = None - if set_presence == "online": - yield self.event_stream_handler.started_stream(user) + affect_presence = set_presence != PresenceState.OFFLINE - try: + if affect_presence: + yield self.presence_handler.set_state(user, {"presence": set_presence}) + + context = yield self.presence_handler.user_syncing( + user.to_string(), affect_presence=affect_presence, + ) + with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, full_state=full_state ) - finally: - if set_presence == "online": - self.event_stream_handler.stopped_stream(user) time_now = self.clock.time_msec() diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 5a9e7720d9..8c3cf9e801 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,7 @@ from .appservice import ( from ._base import Cache from .directory import DirectoryStore from .events import EventsStore -from .presence import PresenceStore +from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore from .registration import RegistrationStore from .room import RoomStore @@ -47,6 +47,7 @@ from .account_data import AccountDataStore from 
util.id_generators import IdGenerator, StreamIdGenerator +from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -110,6 +111,9 @@ class DataStore(RoomMemberStore, RoomStore, self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) @@ -119,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore, self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) - events_max = self._stream_id_gen.get_max_token(None) + events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -135,13 +139,31 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token(None) + account_max = self._account_data_id_gen.get_max_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) + self.__presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_max_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", min_presence_val, + prefilled_cache=presence_cache_prefill + ) + super(DataStore, self).__init__(hs) + def take_presence_startup_info(self): + active_on_startup = self.__presence_on_startup + self.__presence_on_startup = None + return active_on_startup + def _get_cache_dict(self, db_conn, table, entity_column, 
stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will @@ -161,6 +183,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() + txn.close() cache = { row[0]: int(row[1]) @@ -174,6 +197,27 @@ class DataStore(RoomMemberStore, RoomStore, return cache, min_val + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active, last_federation_update," + " last_user_sync, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..0fd5d497ab 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index ef525f34c5..b133979102 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,73 +14,128 @@ # limitations under the License. 
from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.api.constants import PresenceState +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from collections import namedtuple from twisted.internet import defer -class PresenceStore(SQLBaseStore): - def create_presence(self, user_localpart): - res = self._simple_insert( - table="presence", - values={"user_id": user_localpart}, - desc="create_presence", +class UserPresenceState(namedtuple("UserPresenceState", + ("user_id", "state", "last_active", "last_federation_update", + "last_user_sync", "status_msg", "currently_active"))): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. 
+ """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active=0, + last_federation_update=0, + last_user_sync=0, + status_msg=None, + currently_active=False, ) - self.get_presence_state.invalidate((user_localpart,)) - return res - def has_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["user_id"], - allow_none=True, - desc="has_presence_state", +class PresenceStore(SQLBaseStore): + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_id_manager = yield self._presence_id_gen.get_next(self) + with stream_id_manager as stream_id: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, stream_id, presence_states, + ) + + defer.returnValue((stream_id, self._presence_id_gen.get_max_token())) + + def _update_presence_txn(self, txn, stream_id, presence_states): + for state in presence_states: + txn.call_after( + self.presence_stream_cache.entity_has_changed, + state.user_id, stream_id, + ) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active": state.last_active, + "last_federation_update": state.last_federation_update, + "last_user_sync": state.last_user_sync, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], ) - @cached(max_entries=2000) - def get_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + # Delete old rows to stop database from getting really big + sql = ( + "DELETE FROM presence_stream WHERE" + " stream_id < ?" 
+ " AND user_id IN (%s)" ) - @cachedList(get_presence_state.cache, list_name="user_localparts", - inlineCallbacks=True) - def get_presence_states(self, user_localparts): + batches = ( + presence_states[i:i + 50] + for i in xrange(0, len(presence_states), 50) + ) + for states in batches: + args = [stream_id] + args.extend(s.user_id for s in states) + txn.execute( + sql % (",".join("?" for _ in states),), + args + ) + + @defer.inlineCallbacks + def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( - table="presence", + table="presence_stream", column="user_id", - iterable=user_localparts, - retcols=("user_id", "state", "status_msg", "mtime",), - desc="get_presence_states", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active", + "last_federation_update", + "last_user_sync", + "status_msg", + "currently_active", + ), ) - defer.returnValue({ - row["user_id"]: { - "state": row["state"], - "status_msg": row["status_msg"], - "mtime": row["mtime"], - } - for row in rows - }) + for row in rows: + row["currently_active"] = bool(row["currently_active"]) - @defer.inlineCallbacks - def set_presence_state(self, user_localpart, new_state): - res = yield self._simple_update_one( - table="presence", - keyvalues={"user_id": user_localpart}, - updatevalues={"state": new_state["state"], - "status_msg": new_state["status_msg"], - "mtime": self._clock.time_msec()}, - desc="set_presence_state", - ) + defer.returnValue([UserPresenceState(**row) for row in rows]) - self.get_presence_state.invalidate((user_localpart,)) - defer.returnValue(res) + def get_current_presence_token(self): + return self._presence_id_gen.get_max_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -128,6 +183,7 @@ class PresenceStore(SQLBaseStore): desc="set_presence_list_accepted", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + 
self.get_presence_list_observers_accepted.invalidate((observed_userid,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -154,6 +210,19 @@ class PresenceStore(SQLBaseStore): desc="get_presence_list_accepted", ) + @cachedInlineCallbacks() + def get_presence_list_observers_accepted(self, observed_userid): + user_localparts = yield self._simple_select_onecol( + table="presence_list", + keyvalues={"observed_user_id": observed_userid, "accepted": True}, + retcol="user_id", + desc="get_presence_list_accepted", + ) + + defer.returnValue([ + "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts + ]) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): yield self._simple_delete_one( @@ -163,3 +232,4 @@ class PresenceStore(SQLBaseStore): desc="del_presence_list", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..14f5e3d30a --- /dev/null +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active BIGINT, + last_federation_update BIGINT, + last_user_sync BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4ab9..5ce54f76de 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -130,9 +130,11 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self, *args): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. + + Used to take a DataStore param, which is no longer needed. """ with self._lock: if self._unfinished_ids: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 133671e238..3b9da5b34a 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -42,7 +42,7 @@ class Clock(object): def time_msec(self): """Returns the current system time in miliseconds since epoch.""" - return self.time() * 1000 + return int(self.time() * 1000) def looping_call(self, f, msec): l = task.LoopingCall(f) diff --git a/tests/utils.py b/tests/utils.py index 3b1eb50d8d..f71125042b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -224,12 +224,12 @@ class MockClock(object): def time_msec(self): return self.time() * 1000 - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() def wrapped_callback(): LoggingContext.thread_local.current_context = current_context - callback() + callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] self.timers.append(t) -- cgit 1.4.1 From 
112283e23005bdaa17b6184cb55fd786facff47d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Feb 2016 10:11:43 +0000 Subject: Prefix TS fields with _ts --- synapse/handlers/presence.py | 54 +++++++++++----------- synapse/storage/__init__.py | 4 +- synapse/storage/presence.py | 23 ++++----- .../storage/schema/delta/30/presence_stream.sql | 6 +-- 4 files changed, 44 insertions(+), 43 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 3137c23509..d296953651 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -127,24 +127,24 @@ class PresenceHandler(BaseHandler): self.wheel_timer.insert( now=now, obj=state.user_id, - then=state.last_active + IDLE_TIMER, + then=state.last_active_ts + IDLE_TIMER, ) self.wheel_timer.insert( now=now, obj=state.user_id, - then=state.last_user_sync + SYNC_ONLINE_TIMEOUT, + then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) if self.hs.is_mine_id(state.user_id): self.wheel_timer.insert( now=now, obj=state.user_id, - then=state.last_federation_update + FEDERATION_PING_INTERVAL, + then=state.last_federation_update_ts + FEDERATION_PING_INTERVAL, ) else: self.wheel_timer.insert( now=now, obj=state.user_id, - then=state.last_federation_update + FEDERATION_TIMEOUT, + then=state.last_federation_update_ts + FEDERATION_TIMEOUT, ) # Set of users who have presence in the `user_to_current_state` that @@ -225,7 +225,7 @@ class PresenceHandler(BaseHandler): self.wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_active + IDLE_TIMER + then=new_state.last_active_ts + IDLE_TIMER ) if new_state.state != PresenceState.OFFLINE: @@ -233,14 +233,14 @@ class PresenceHandler(BaseHandler): self.wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_user_sync + SYNC_ONLINE_TIMEOUT + then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT ) - last_federate = new_state.last_federation_update + last_federate = 
new_state.last_federation_update_ts if now - last_federate > FEDERATION_PING_INTERVAL: # Been a while since we've poked remote servers new_state = new_state.copy_and_replace( - last_federation_update=now, + last_federation_update_ts=now, ) to_federation_ping[user_id] = new_state @@ -248,11 +248,11 @@ class PresenceHandler(BaseHandler): self.wheel_timer.insert( now=now, obj=user_id, - then=new_state.last_federation_update + FEDERATION_TIMEOUT + then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT ) if new_state.state == PresenceState.ONLINE: - currently_active = now - new_state.last_active < LAST_ACTIVE_GRANULARITY + currently_active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY new_state = new_state.copy_and_replace( currently_active=currently_active, ) @@ -260,7 +260,7 @@ class PresenceHandler(BaseHandler): # Check whether the change was something worth notifying about if should_notify(prev_state, new_state): new_state.copy_and_replace( - last_federation_update=now, + last_federation_update_ts=now, ) to_notify[user_id] = new_state @@ -309,18 +309,18 @@ class PresenceHandler(BaseHandler): if self.hs.is_mine_id(user_id): if state.state == PresenceState.ONLINE: - if now - state.last_active > IDLE_TIMER: + if now - state.last_active_ts > IDLE_TIMER: # Currently online, but last activity ages ago so auto # idle changes[user_id] = state.copy_and_replace( state=PresenceState.UNAVAILABLE, ) - elif now - state.last_active > LAST_ACTIVE_GRANULARITY: + elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY: # So that we send down a notification that we've # stopped updating. 
changes[user_id] = state - if now - state.last_federation_update > FEDERATION_PING_INTERVAL: + if now - state.last_federation_update_ts > FEDERATION_PING_INTERVAL: # Need to send ping to other servers to ensure they don't # timeout and set us to offline changes[user_id] = state @@ -328,7 +328,7 @@ class PresenceHandler(BaseHandler): # If there are have been no sync for a while (and none ongoing), # set presence to offline if not self.user_to_num_current_syncs.get(user_id, 0): - if now - state.last_user_sync > SYNC_ONLINE_TIMEOUT: + if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: changes[user_id] = state.copy_and_replace( state=PresenceState.OFFLINE, status_msg=None, @@ -337,7 +337,7 @@ class PresenceHandler(BaseHandler): # We expect to be poked occaisonally by the other side. # This is to protect against forgetful/buggy servers, so that # no one gets stuck online forever. - if now - state.last_federation_update > FEDERATION_TIMEOUT: + if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. changes[user_id] = state.copy_and_replace( state=PresenceState.OFFLINE, @@ -356,7 +356,7 @@ class PresenceHandler(BaseHandler): prev_state = yield self.current_state_for_user(user_id) new_fields = { - "last_active": self.clock.time_msec(), + "last_active_ts": self.clock.time_msec(), } if prev_state.state == PresenceState.UNAVAILABLE: new_fields["state"] = PresenceState.ONLINE @@ -388,12 +388,12 @@ class PresenceHandler(BaseHandler): # just update the last sync times. 
yield self._update_states([prev_state.copy_and_replace( state=PresenceState.ONLINE, - last_active=self.clock.time_msec(), - last_user_sync=self.clock.time_msec(), + last_active_ts=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), )]) else: yield self._update_states([prev_state.copy_and_replace( - last_user_sync=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), )]) @defer.inlineCallbacks @@ -403,7 +403,7 @@ class PresenceHandler(BaseHandler): prev_state = yield self.current_state_for_user(user_id) yield self._update_states([prev_state.copy_and_replace( - last_user_sync=self.clock.time_msec(), + last_user_sync_ts=self.clock.time_msec(), )]) @contextmanager @@ -553,12 +553,12 @@ class PresenceHandler(BaseHandler): new_fields = { "state": presence_state, - "last_federation_update": now, + "last_federation_update_ts": now, } last_active_ago = push.get("last_active_ago", None) if last_active_ago is not None: - new_fields["last_active"] = now - last_active_ago + new_fields["last_active_ts"] = now - last_active_ago new_fields["status_msg"] = push.get("status_msg", None) @@ -632,7 +632,7 @@ class PresenceHandler(BaseHandler): } if presence == PresenceState.ONLINE: - new_fields["last_active"] = self.clock.time_msec() + new_fields["last_active_ts"] = self.clock.time_msec() yield self._update_states([prev_state.copy_and_replace(**new_fields)]) @@ -823,7 +823,7 @@ def should_notify(old_state, new_state): if new_state.currently_active != old_state.currently_active: return True - if new_state.last_active - old_state.last_active > LAST_ACTIVE_GRANULARITY: + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. 
return True @@ -841,8 +841,8 @@ def _format_user_presence_state(state, now): "presence": state.state, "user_id": state.user_id, } - if state.last_active: - content["last_active_ago"] = now - state.last_active + if state.last_active_ts: + content["last_active_ago"] = now - state.last_active_ts if state.status_msg and state.state != PresenceState.OFFLINE: content["status_msg"] = state.status_msg if state.state == PresenceState.ONLINE: diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 8c3cf9e801..fcb968e8f4 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -203,8 +203,8 @@ class DataStore(RoomMemberStore, RoomStore, """ sql = ( - "SELECT user_id, state, last_active, last_federation_update," - " last_user_sync, status_msg, currently_active FROM presence_stream" + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" " WHERE state != ?" ) sql = self.database_engine.convert_param_style(sql) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index b133979102..70ece56548 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -22,8 +22,9 @@ from twisted.internet import defer class UserPresenceState(namedtuple("UserPresenceState", - ("user_id", "state", "last_active", "last_federation_update", - "last_user_sync", "status_msg", "currently_active"))): + ("user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active"))): """Represents the current presence state of the user. 
user_id (str) @@ -46,9 +47,9 @@ class UserPresenceState(namedtuple("UserPresenceState", return cls( user_id=user_id, state=PresenceState.OFFLINE, - last_active=0, - last_federation_update=0, - last_user_sync=0, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, status_msg=None, currently_active=False, ) @@ -82,9 +83,9 @@ class PresenceStore(SQLBaseStore): "stream_id": stream_id, "user_id": state.user_id, "state": state.state, - "last_active": state.last_active, - "last_federation_update": state.last_federation_update, - "last_user_sync": state.last_user_sync, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, "status_msg": state.status_msg, "currently_active": state.currently_active, } @@ -121,9 +122,9 @@ class PresenceStore(SQLBaseStore): retcols=( "user_id", "state", - "last_active", - "last_federation_update", - "last_user_sync", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", "status_msg", "currently_active", ), diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql index 14f5e3d30a..606bbb037d 100644 --- a/synapse/storage/schema/delta/30/presence_stream.sql +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -18,9 +18,9 @@ stream_id BIGINT, user_id TEXT, state TEXT, - last_active BIGINT, - last_federation_update BIGINT, - last_user_sync BIGINT, + last_active_ts BIGINT, + last_federation_update_ts BIGINT, + last_user_sync_ts BIGINT, status_msg TEXT, currently_active BOOLEAN ); -- cgit 1.4.1 From b9977ea667889f6cf89464c92fc57cbcae7cca28 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 18 Feb 2016 16:05:13 +0000 Subject: Remove dead code for setting device specific rules. It wasn't possible to hit the code from the API because of a typo in parsing the request path. Since no-one was using the feature we might as well remove the dead code. 
--- synapse/push/__init__.py | 7 ++- synapse/push/action_generator.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/httppusher.py | 3 +- synapse/push/push_rule_evaluator.py | 15 ++---- synapse/push/pusherpool.py | 48 +++++++---------- synapse/rest/client/v1/push_rule.py | 90 ++------------------------------ synapse/rest/client/v1/pusher.py | 6 +-- synapse/storage/event_push_actions.py | 7 ++- synapse/storage/pusher.py | 6 +-- 10 files changed, 45 insertions(+), 141 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 8da2d8716c..4c6c3b83a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -47,14 +47,13 @@ class Pusher(object): MAX_BACKOFF = 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000 - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): self.hs = _hs self.evStreamHandler = self.hs.get_handlers().event_stream_handler self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self.profile_tag = profile_tag self.user_id = user_id self.app_id = app_id self.app_display_name = app_display_name @@ -186,8 +185,8 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id_and_profile_tag( - self.user_id, self.profile_tag, single_event['room_id'], self.store + push_rule_evaluator.evaluator_for_user_id( + self.user_id, single_event['room_id'], self.store ) actions = yield rule_evaluator.actions_for_event(single_event) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index e0da0868ec..c6c1dc769e 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -44,5 +44,5 @@ class ActionGenerator: ) context.push_actions = [ - (uid, None, actions) for uid, actions in actions_by_user.items() + (uid, 
actions) for uid, actions in actions_by_user.items() ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8ac5ceb9ef..0a23b3f102 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -152,7 +152,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): elif res is True: continue - res = evaluator.matches(cond, uid, display_name, None) + res = evaluator.matches(cond, uid, display_name) if _id: cache[_id] = bool(res) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index cdc4494928..9be4869360 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -23,12 +23,11 @@ logger = logging.getLogger(__name__) class HttpPusher(Pusher): - def __init__(self, _hs, profile_tag, user_id, app_id, + def __init__(self, _hs, user_id, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, data, last_token, last_success, failing_since): super(HttpPusher, self).__init__( _hs, - profile_tag, user_id, app_id, app_display_name, diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 2a2b4437dc..98e2a2015e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") @defer.inlineCallbacks -def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): +def evaluator_for_user_id(user_id, room_id, store): rawrules = yield store.get_push_rules_for_user(user_id) enabled_map = yield store.get_push_rules_enabled_for_user(user_id) our_member_event = yield store.get_current_state( @@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store): ) defer.returnValue(PushRuleEvaluator( - user_id, profile_tag, rawrules, enabled_map, + user_id, rawrules, enabled_map, room_id, our_member_event, store )) @@ -77,10 +77,9 @@ def _room_member_count(ev, 
condition, room_member_count): class PushRuleEvaluator: DEFAULT_ACTIONS = [] - def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id, + def __init__(self, user_id, raw_rules, enabled_map, room_id, our_member_event, store): self.user_id = user_id - self.profile_tag = profile_tag self.room_id = room_id self.our_member_event = our_member_event self.store = store @@ -152,7 +151,7 @@ class PushRuleEvaluator: matches = True for c in conditions: matches = evaluator.matches( - c, self.user_id, my_display_name, self.profile_tag + c, self.user_id, my_display_name ) if not matches: break @@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object): # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) - def matches(self, condition, user_id, display_name, profile_tag): + def matches(self, condition, user_id, display_name): if condition['kind'] == 'event_match': return self._event_match(condition, user_id) - elif condition['kind'] == 'device': - if 'profile_tag' not in condition: - return True - return condition['profile_tag'] == profile_tag elif condition['kind'] == 'contains_display_name': return self._contains_display_name(display_name) elif condition['kind'] == 'room_member_count': diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d7dcb2de4b..a05aa5f661 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -29,6 +29,7 @@ class PusherPool: def __init__(self, _hs): self.hs = _hs self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() self.pushers = {} self.last_pusher_started = -1 @@ -38,8 +39,11 @@ class PusherPool: self._start_pushers(pushers) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, pushkey, lang, data): + def add_pusher(self, user_id, access_token, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data, + profile_tag=""): + 
time_now_msec = self.clock.time_msec() + # we try to create the pusher just to validate the config: it # will then get pulled out of the database, # recreated, added and started: this means we have only one @@ -47,23 +51,31 @@ class PusherPool: self._create_pusher({ "user_name": user_id, "kind": kind, - "profile_tag": profile_tag, "app_id": app_id, "app_display_name": app_display_name, "device_display_name": device_display_name, "pushkey": pushkey, - "ts": self.hs.get_clock().time_msec(), + "ts": time_now_msec, "lang": lang, "data": data, "last_token": None, "last_success": None, "failing_since": None }) - yield self._add_pusher_to_store( - user_id, access_token, profile_tag, kind, app_id, - app_display_name, device_display_name, - pushkey, lang, data + yield self.store.add_pusher( + user_id=user_id, + access_token=access_token, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=time_now_msec, + lang=lang, + data=data, + profile_tag=profile_tag, ) + yield self._refresh_pusher(app_id, pushkey, user_id) @defer.inlineCallbacks def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey, @@ -94,30 +106,10 @@ class PusherPool: ) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) - @defer.inlineCallbacks - def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind, - app_id, app_display_name, device_display_name, - pushkey, lang, data): - yield self.store.add_pusher( - user_id=user_id, - access_token=access_token, - profile_tag=profile_tag, - kind=kind, - app_id=app_id, - app_display_name=app_display_name, - device_display_name=device_display_name, - pushkey=pushkey, - pushkey_ts=self.hs.get_clock().time_msec(), - lang=lang, - data=data, - ) - yield self._refresh_pusher(app_id, pushkey, user_id) - def _create_pusher(self, pusherdict): if pusherdict['kind'] == 'http': return HttpPusher( self.hs, - profile_tag=pusherdict['profile_tag'], 
user_id=pusherdict['user_name'], app_id=pusherdict['app_id'], app_display_name=pusherdict['app_display_name'], diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 7766b8be1d..5db2805d68 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -60,7 +60,6 @@ class PushRuleRestServlet(ClientV1RestServlet): spec['template'], spec['rule_id'], content, - device=spec['device'] if 'device' in spec else None ) except InvalidRuleException as e: raise SynapseError(400, e.message) @@ -153,23 +152,7 @@ class PushRuleRestServlet(ClientV1RestServlet): elif pattern_type == "user_localpart": c["pattern"] = user.localpart - if r['priority_class'] > PRIORITY_CLASS_MAP['override']: - # per-device rule - profile_tag = _profile_tag_from_conditions(r["conditions"]) - r = _strip_device_condition(r) - if not profile_tag: - continue - if profile_tag not in rules['device']: - rules['device'][profile_tag] = {} - rules['device'][profile_tag] = ( - _add_empty_priority_class_arrays( - rules['device'][profile_tag] - ) - ) - - rulearray = rules['device'][profile_tag][template_name] - else: - rulearray = rules['global'][template_name] + rulearray = rules['global'][template_name] template_rule = _rule_to_template(r) if template_rule: @@ -195,24 +178,6 @@ class PushRuleRestServlet(ClientV1RestServlet): path = path[1:] result = _filter_ruleset_with_path(rules['global'], path) defer.returnValue((200, result)) - elif path[0] == 'device': - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if path[0] == '': - defer.returnValue((200, rules['device'])) - - profile_tag = path[0] - path = path[1:] - if profile_tag not in rules['device']: - ret = {} - ret = _add_empty_priority_class_arrays(ret) - defer.returnValue((200, ret)) - ruleset = rules['device'][profile_tag] - result = _filter_ruleset_with_path(ruleset, path) - defer.returnValue((200, 
result)) else: raise UnrecognizedRequestError() @@ -252,16 +217,9 @@ def _rule_spec_from_path(path): scope = path[1] path = path[2:] - if scope not in ['global', 'device']: + if scope != 'global': raise UnrecognizedRequestError() - device = None - if scope == 'device': - if len(path) == 0: - raise UnrecognizedRequestError() - device = path[0] - path = path[1:] - if len(path) == 0: raise UnrecognizedRequestError() @@ -278,8 +236,6 @@ def _rule_spec_from_path(path): 'template': template, 'rule_id': rule_id } - if device: - spec['profile_tag'] = device path = path[1:] @@ -289,7 +245,7 @@ def _rule_spec_from_path(path): return spec -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None): +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): if rule_template in ['override', 'underride']: if 'conditions' not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -322,12 +278,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if device: - conditions.append({ - 'kind': 'device', - 'profile_tag': device - }) - if 'actions' not in req_obj: raise InvalidRuleException("No actions found") actions = req_obj['actions'] @@ -349,17 +299,6 @@ def _add_empty_priority_class_arrays(d): return d -def _profile_tag_from_conditions(conditions): - """ - Given a list of conditions, return the profile tag of the - device rule if there is one - """ - for c in conditions: - if c['kind'] == 'device': - return c['profile_tag'] - return None - - def _filter_ruleset_with_path(ruleset, path): if path == []: raise UnrecognizedRequestError( @@ -403,19 +342,11 @@ def _priority_class_from_spec(spec): raise InvalidRuleException("Unknown template: %s" % (spec['template'])) pc = PRIORITY_CLASS_MAP[spec['template']] - if spec['scope'] == 'device': - pc += len(PRIORITY_CLASS_MAP) - return pc def _priority_class_to_template_name(pc): - if 
pc > PRIORITY_CLASS_MAP['override']: - # per-device - prio_class_index = pc - len(PRIORITY_CLASS_MAP) - return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] - else: - return PRIORITY_CLASS_INVERSE_MAP[pc] + return PRIORITY_CLASS_INVERSE_MAP[pc] def _rule_to_template(rule): @@ -445,23 +376,12 @@ def _rule_to_template(rule): return templaterule -def _strip_device_condition(rule): - for i, c in enumerate(rule['conditions']): - if c['kind'] == 'device': - del rule['conditions'][i] - return rule - - def _namespaced_rule_id_from_spec(spec): return _namespaced_rule_id(spec, spec['rule_id']) def _namespaced_rule_id(spec, rule_id): - if spec['scope'] == 'global': - scope = 'global' - else: - scope = 'device/%s' % (spec['profile_tag']) - return "%s/%s/%s" % (scope, spec['template'], rule_id) + return "global/%s/%s" % (spec['template'], rule_id) def _rule_id_from_namespaced(in_rule_id): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 5547f1b112..4c662e6e3c 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -45,7 +45,7 @@ class PusherRestServlet(ClientV1RestServlet): ) defer.returnValue((200, {})) - reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name', + reqd = ['kind', 'app_id', 'app_display_name', 'device_display_name', 'pushkey', 'lang', 'data'] missing = [] for i in reqd: @@ -73,14 +73,14 @@ class PusherRestServlet(ClientV1RestServlet): yield pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - profile_tag=content['profile_tag'], kind=content['kind'], app_id=content['app_id'], app_display_name=content['app_display_name'], device_display_name=content['device_display_name'], pushkey=content['pushkey'], lang=content['lang'], - data=content['data'] + data=content['data'], + profile_tag=content.get('profile_tag', ""), ) except PusherConfigException as pce: raise SynapseError(400, "Config Error: " + pce.message, diff --git 
a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d77a817682..5820539a92 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for - :param tuples: list of tuples of (user_id, profile_tag, actions) + :param tuples: list of tuples of (user_id, actions) """ values = [] - for uid, profile_tag, actions in tuples: + for uid, actions in tuples: values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'profile_tag': profile_tag, 'actions': json.dumps(actions), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, @@ -43,7 +42,7 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - for uid, _, __ in tuples: + for uid, __ in tuples: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8ec706178a..c23648cdbc 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -80,9 +80,9 @@ class PusherStore(SQLBaseStore): defer.returnValue(rows) @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, + def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data): + pushkey, pushkey_ts, lang, data, profile_tag=""): try: next_id = yield self._pushers_id_gen.get_next() yield self._simple_upsert( @@ -95,12 +95,12 @@ class PusherStore(SQLBaseStore): dict( access_token=access_token, kind=kind, - profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, data=encode_canonical_json(data), + 
profile_tag=profile_tag, ), insertion_values=dict( id=next_id, -- cgit 1.4.1 From 42109a62a43ca0b8c1e0d3f797bbc70e0018ca5c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Feb 2016 16:37:28 +0000 Subject: Remove unused param from get_max_token --- synapse/storage/account_data.py | 4 ++-- synapse/storage/events.py | 2 +- synapse/storage/receipts.py | 6 +++--- synapse/storage/stream.py | 2 +- synapse/storage/tags.py | 6 +++--- synapse/storage/util/id_generators.py | 4 +--- 6 files changed, 11 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index b8387fc500..91cbf399b6 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -168,7 +168,7 @@ class AccountDataStore(SQLBaseStore): "add_room_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -207,7 +207,7 @@ class AccountDataStore(SQLBaseStore): "add_user_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 3a5c6ee4b1..1dd3236829 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -131,7 +131,7 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = yield self._stream_id_gen.get_max_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4202a6b3dc..a7343c97f7 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py 
@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() ) @cached(num_args=2) @@ -222,7 +222,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token(self) + return self._receipts_id_gen.get_max_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = yield self._stream_id_gen.get_max_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c236dafafb..8908d5b5da 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -531,7 +531,7 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token(self) + token = yield self._stream_id_gen.get_max_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index e1a9c0c261..9551aa9739 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. 
""" - return self._account_data_id_gen.get_max_token(self) + return self._account_data_id_gen.get_max_token() @cached() def get_tags_for_user(self, user_id): @@ -147,7 +147,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -169,7 +169,7 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = yield self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5ce54f76de..ef5e4a4668 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -130,11 +130,9 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, *args): + def get_max_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. - - Used to take a DataStore param, which is no longer needed. 
""" with self._lock: if self._unfinished_ids: -- cgit 1.4.1 From e6c5e3f28a234be342cf5e4f15a925e65e82fd0d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Feb 2016 16:39:28 +0000 Subject: Close cursor --- synapse/storage/__init__.py | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index fcb968e8f4..9be1d12fac 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -212,6 +212,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) rows = self.cursor_to_dict(txn) + txn.close() for row in rows: row["currently_active"] = bool(row["currently_active"]) -- cgit 1.4.1 From 6451fcd08580cc6142d15b4be5f8956ba496f178 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Feb 2016 15:51:39 +0000 Subject: Create a new stream_id per presence update --- synapse/storage/presence.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 70ece56548..3ef91d34db 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -58,17 +58,20 @@ class UserPresenceState(namedtuple("UserPresenceState", class PresenceStore(SQLBaseStore): @defer.inlineCallbacks def update_presence(self, presence_states): - stream_id_manager = yield self._presence_id_gen.get_next(self) - with stream_id_manager as stream_id: + stream_ordering_manager = yield self._presence_id_gen.get_next_mult( + self, len(presence_states) + ) + + with stream_ordering_manager as stream_orderings: yield self.runInteraction( "update_presence", - self._update_presence_txn, stream_id, presence_states, + self._update_presence_txn, stream_orderings, presence_states, ) - defer.returnValue((stream_id, self._presence_id_gen.get_max_token())) + defer.returnValue((stream_orderings[-1], 
self._presence_id_gen.get_max_token())) - def _update_presence_txn(self, txn, stream_id, presence_states): - for state in presence_states: + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( self.presence_stream_cache.entity_has_changed, state.user_id, stream_id, -- cgit 1.4.1 From 33300673b7a6f79802f691ac121e720cb44c0dfc Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Wed, 24 Feb 2016 14:41:25 +0000 Subject: Generate guest access token on 3pid invites This means that following the same link across multiple sessions or devices can re-use the same guest account. Note that this is somewhat of an abuse vector; we can't throw up captchas on this flow, so this is a way of registering ephemeral accounts for spam, whose sign-up we don't rate limit. --- synapse/handlers/register.py | 15 ++++++++ synapse/handlers/room.py | 8 ++++ synapse/storage/registration.py | 44 ++++++++++++++++++++++ .../delta/30/threepid_guest_access_tokens.sql | 24 ++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f8959e5d82..6d155d57e7 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -349,3 +349,18 @@ class RegistrationHandler(BaseHandler): def auth_handler(self): return self.hs.get_handlers().auth_handler + + @defer.inlineCallbacks + def guest_access_token_for(self, medium, address, inviter_user_id): + access_token = yield self.store.get_3pid_guest_access_token(medium, address) + if access_token: + defer.returnValue(access_token) + + _, access_token = yield self.register( + generate_token=True, + make_guest=True + ) + access_token = yield self.store.save_or_get_3pid_guest_access_token( + medium, address, access_token, inviter_user_id + ) + defer.returnValue(access_token) 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index eb9700a35b..d2de23a6cc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -848,6 +848,13 @@ class RoomMemberHandler(BaseHandler): user. """ + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( id_server_scheme, id_server, ) @@ -864,6 +871,7 @@ class RoomMemberHandler(BaseHandler): "sender": inviter_user_id, "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, + "guest_access_token": guest_access_token, } ) # TODO: Check for success diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 967c732bda..03a9b66e4a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -387,3 +387,47 @@ class RegistrationStore(SQLBaseStore): "find_next_generated_user_id", _find_next_generated_user_id ))) + + @defer.inlineCallbacks + def get_3pid_guest_access_token(self, medium, address): + ret = yield self._simple_select_one( + "threepid_guest_access_tokens", + { + "medium": medium, + "address": address + }, + ["guest_access_token"], True, 'get_3pid_guest_access_token' + ) + if ret: + defer.returnValue(ret["guest_access_token"]) + defer.returnValue(None) + + @defer.inlineCallbacks + def save_or_get_3pid_guest_access_token( + self, medium, address, access_token, inviter_user_id + ): + """ + Gets the 3pid's guest access token if exists, else saves access_token. + + :param medium (str): Medium of the 3pid. Must be "email". + :param address (str): 3pid address. + :param access_token (str): The access token to persist if none is + already persisted. + :param inviter_user_id (str): User ID of the inviter. 
+ :return (deferred str): Whichever access token is persisted at the end + of this function call. + """ + def insert(txn): + txn.execute( + "INSERT INTO threepid_guest_access_tokens " + "(medium, address, guest_access_token, first_inviter) " + "VALUES (?, ?, ?, ?)", + (medium, address, access_token, inviter_user_id) + ) + + try: + yield self.runInteraction("save_3pid_guest_access_token", insert) + defer.returnValue(access_token) + except self.database_engine.module.IntegrityError: + ret = yield self.get_3pid_guest_access_token(medium, address) + defer.returnValue(ret) diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql new file mode 100644 index 0000000000..0dd2f1360c --- /dev/null +++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores guest account access tokens generated for unbound 3pids. +CREATE TABLE threepid_guest_access_tokens( + medium TEXT, -- The medium of the 3pid. Must be "email". + address TEXT, -- The 3pid address. + guest_access_token TEXT, -- The access token for a guest user for this 3pid. + first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. 
+); + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); -- cgit 1.4.1 From de27f7fc79b785961181d13749468ae3e2019772 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 26 Feb 2016 14:28:19 +0000 Subject: Add support for changing the actions for default rules See matrix-org/matrix-doc#283 Works by adding dummy rules to the push rules table with a negative priority class and then using those rules to clobber the default rule actions when adding the default rules in ``list_with_base_rules`` --- synapse/push/baserules.py | 57 ++++++++++++++++++++++++++++++++----- synapse/rest/client/v1/push_rule.py | 31 +++++++++++++++++--- synapse/storage/push_rule.py | 25 ++++++++++++++++ 3 files changed, 102 insertions(+), 11 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 0832c77cb4..86a2998bcc 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -13,46 +13,67 @@ # limitations under the License. from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP +import copy def list_with_base_rules(rawrules): + """Combine the list of rules set by the user with the default push rules + + :param list rawrules: The rules the user has modified or set. + :returns: A new list with the rules set by the user combined with the + defaults. + """ ruleslist = [] + # Grab the base rules that the user has modified. + # The modified base rules have a priority_class of -1. + modified_base_rules = { + r['rule_id']: r for r in rawrules if r['priority_class'] < 0 + } + + # Remove the modified base rules from the list, They'll be added back + # in the default postions in the list. 
+ rawrules = [r for r in rawrules if r['priority_class'] >= 0] + # shove the server default rules for each kind onto the end of each current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1] ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules )) for r in rawrules: if r['priority_class'] < current_prio_class: while r['priority_class'] < current_prio_class: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) ruleslist.append(r) while current_prio_class > 0: ruleslist.extend(make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend(make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class] + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], + modified_base_rules, )) return ruleslist -def make_base_append_rules(kind): +def make_base_append_rules(kind, modified_base_rules): rules = [] if kind == 'override': @@ -62,15 +83,31 @@ def make_base_append_rules(kind): elif kind == 'content': rules = BASE_APPEND_CONTENT_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. 
+ modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules -def make_base_prepend_rules(kind): +def make_base_prepend_rules(kind, modified_base_rules): rules = [] if kind == 'override': rules = BASE_PREPEND_OVERRIDE_RULES + # Copy the rules before modifying them + rules = copy.deepcopy(rules) + for r in rules: + # Only modify the actions, keep the conditions the same. + modified = modified_base_rules.get(r['rule_id']) + if modified: + r['actions'] = modified['actions'] + return rules @@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] +BASE_RULE_IDS = set() + for r in BASE_APPEND_CONTENT_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['content'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_PREPEND_OVERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_OVRRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['override'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) for r in BASE_APPEND_UNDERRIDE_RULES: r['priority_class'] = PRIORITY_CLASS_MAP['underride'] r['default'] = True + BASE_RULE_IDS.add(r['rule_id']) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index d26e4cde3e..970a019223 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -22,7 +22,7 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) -import synapse.push.baserules as baserules +from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.rulekinds import ( PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP ) @@ -55,6 +55,10 @@ class PushRuleRestServlet(ClientV1RestServlet): yield self.set_rule_attr(requester.user.to_string(), spec, content) defer.returnValue((200, {})) + if spec['rule_id'].startswith('.'): + # Rule ids 
starting with '.' are reserved for server default rules. + raise SynapseError(400, "cannot add new rule_ids that start with '.'") + try: (conditions, actions) = _rule_tuple_from_request_object( spec['template'], @@ -128,7 +132,7 @@ class PushRuleRestServlet(ClientV1RestServlet): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist)) + ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) rules = {'global': {}, 'device': {}} @@ -197,6 +201,18 @@ class PushRuleRestServlet(ClientV1RestServlet): return self.hs.get_datastore().set_push_rule_enabled( user_id, namespaced_rule_id, val ) + elif spec['attr'] == 'actions': + actions = val.get('actions') + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec['rule_id'] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return self.hs.get_datastore().set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) else: raise UnrecognizedRequestError() @@ -274,6 +290,15 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): raise InvalidRuleException("No actions found") actions = req_obj['actions'] + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + for a in actions: if a in ['notify', 'dont_notify', 'coalesce']: pass @@ -282,8 +307,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): else: raise InvalidRuleException("Unrecognised action") - return conditions, actions - def _add_empty_priority_class_arrays(d): for pc in PRIORITY_CLASS_MAP.keys(): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e19a81e41f..bb5c14d912 100644 --- 
a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -294,6 +294,31 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, user_id, rule_id, priority_class, priority, + "[]", actions_json + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + {'actions': actions_json}, + ) + + return self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + ) + class RuleNotFoundException(Exception): pass -- cgit 1.4.1 From 54172924c834a954fcbbc6224318140b9e95aa7d Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Mar 2016 14:32:56 +0000 Subject: Load the current id in the IdGenerator constructor Rather than loading them lazily. This allows us to remove all the yield statements and spurious arguments for the get_next methods. It also allows us to replace all instances of get_next_txn with get_next since get_next no longer needs to access the db. 
--- synapse/storage/__init__.py | 14 +++---- synapse/storage/account_data.py | 8 ++-- synapse/storage/events.py | 6 +-- synapse/storage/presence.py | 4 +- synapse/storage/push_rule.py | 4 +- synapse/storage/pusher.py | 2 +- synapse/storage/receipts.py | 4 +- synapse/storage/registration.py | 6 +-- synapse/storage/state.py | 2 +- synapse/storage/tags.py | 8 ++-- synapse/storage/transactions.py | 2 +- synapse/storage/util/id_generators.py | 69 +++++++++++------------------------ 12 files changed, 52 insertions(+), 77 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9be1d12fac..f257721ea3 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", "stream_id" ) - self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) - self._state_groups_id_gen = IdGenerator("state_groups", "id", self) - self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) - self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) - self._pushers_id_gen = IdGenerator("pushers", "id", self) - self._push_rule_id_gen = IdGenerator("push_rules", "id", self) - self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git 
a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 91cbf399b6..21a3240d9d 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -163,12 +163,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_room_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -202,12 +202,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_user_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1dd3236829..73a152bc07 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore): yield stream_orderings stream_ordering_manager = stream_ordering_manager() else: - stream_ordering_manager = yield self._stream_id_gen.get_next_mult( - self, len(events_and_contexts) + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) ) with stream_ordering_manager as stream_orderings: @@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore): stream_ordering = self.min_stream_token if stream_ordering is None: - stream_ordering_manager = yield self._stream_id_gen.get_next(self) + stream_ordering_manager = self._stream_id_gen.get_next() else: @contextmanager def stream_ordering_manager(): diff 
--git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3ef91d34db..eece7f8961 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState", class PresenceStore(SQLBaseStore): @defer.inlineCallbacks def update_presence(self, presence_states): - stream_ordering_manager = yield self._presence_id_gen.get_next_mult( - self, len(presence_states) + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) ) with stream_ordering_manager as stream_orderings: diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index bb5c14d912..56e69495b1 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore): if txn.rowcount == 0: # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next_txn(txn) + push_rule_id = self._push_rule_id_gen.get_next() self._simple_insert_txn( txn, @@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(ret) def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): - new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( txn, "push_rules_enable", diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index c23648cdbc..7693ab9082 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore): app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, profile_tag=""): try: - next_id = yield self._pushers_id_gen.get_next() + next_id = self._pushers_id_gen.get_next() yield self._simple_upsert( "pushers", dict( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a7343c97f7..cd6dca4901 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -330,7 +330,7 @@ 
class ReceiptsStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = yield self._receipts_id_gen.get_next(self) + stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: have_persisted = yield self.runInteraction( "insert_linearized_receipt", @@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_max_token() defer.returnValue((stream_id, max_persisted_id)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 03a9b66e4a..ad1157f979 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._access_tokens_id_gen.get_next() + next_id = self._access_tokens_id_gen.get_next() yield self._simple_insert( "access_tokens", @@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. 
""" - next_id = yield self._refresh_tokens_id_gen.get_next() + next_id = self._refresh_tokens_id_gen.get_next() yield self._simple_insert( "refresh_tokens", @@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore): def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next_txn(txn) + next_id = self._access_tokens_id_gen.get_next() try: if was_guest: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 372b540002..8ed8a21b0a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,7 +83,7 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next_txn(txn) + state_group = self._state_groups_id_gen.get_next() self._simple_insert_txn( txn, table="state_groups", diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 9551aa9739..1127b0bd7e 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -142,12 +142,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -164,12 +164,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield 
self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 4475c451c1..d338dfcf0a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): - next_id = self._transaction_id_gen.get_next_txn(txn) + next_id = self._transaction_id_gen.get_next() # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ef5e4a4668..efe3f68e6e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,51 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from twisted.internet import defer - from collections import deque import contextlib import threading class IdGenerator(object): - def __init__(self, table, column, store): + def __init__(self, db_conn, table, column): self.table = table self.column = column - self.store = store self._lock = threading.Lock() - self._next_id = None + cur = db_conn.cursor() + self._next_id = self._load_next_id(cur) + cur.close() - @defer.inlineCallbacks - def get_next(self): - if self._next_id is None: - yield self.store.runInteraction( - "IdGenerator_%s" % (self.table,), - self.get_next_txn, - ) + def _load_next_id(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) + val, = txn.fetchone() + return val + 1 if val else 1 + def get_next(self): with self._lock: i = self._next_id self._next_id += 1 - defer.returnValue(i) - - def get_next_txn(self, txn): - with self._lock: - if self._next_id: - i = self._next_id - self._next_id += 1 - return i - else: - txn.execute( - "SELECT MAX(%s) FROM %s" % (self.column, self.table,) - ) - - val, = txn.fetchone() - cur = val or 0 - cur += 1 - self._next_id = cur + 1 - - return cur + return i class StreamIdGenerator(object): @@ -69,7 +48,7 @@ class StreamIdGenerator(object): persistence of events can complete out of order. Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... 
""" def __init__(self, db_conn, table, column): @@ -79,15 +58,21 @@ class StreamIdGenerator(object): self._lock = threading.Lock() cur = db_conn.cursor() - self._current_max = self._get_or_compute_current_max(cur) + self._current_max = self._load_current_max(cur) cur.close() self._unfinished_ids = deque() - def get_next(self, store): + def _load_current_max(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) + rows = txn.fetchall() + val, = rows[0] + return int(val) if val else 1 + + def get_next(self): """ Usage: - with yield stream_id_gen.get_next as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -106,10 +91,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, store, n): + def get_next_mult(self, n): """ Usage: - with yield stream_id_gen.get_next(store, n) as stream_ids: + with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -139,13 +124,3 @@ class StreamIdGenerator(object): return self._unfinished_ids[0] - 1 return self._current_max - - def _get_or_compute_current_max(self, txn): - with self._lock: - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] - - self._current_max = int(val) if val else 1 - - return self._current_max -- cgit 1.4.1 From f9af8962f8ea6201ed3910eb248b8668f1262fef Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Mar 2016 14:46:31 +0000 Subject: Allow alias creators to delete aliases --- synapse/handlers/directory.py | 27 ++++++++++++++++++----- synapse/rest/client/v1/directory.py | 3 --- synapse/storage/directory.py | 15 ++++++++++++- synapse/storage/schema/delta/30/alias_creator.sql | 16 ++++++++++++++ 4 files changed, 51 insertions(+), 10 deletions(-) create mode 100644 synapse/storage/schema/delta/30/alias_creator.sql (limited to 'synapse/storage') diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py 
index e0a778e7ff..cce6f76f0e 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -17,9 +17,9 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.errors import SynapseError, Codes, CodeMessageException +from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias +from synapse.types import RoomAlias, UserID import logging import string @@ -38,7 +38,7 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def _create_association(self, room_alias, room_id, servers=None): + def _create_association(self, room_alias, room_id, servers=None, creator=None): # general association creation for both human users and app services for wchar in string.whitespace: @@ -60,7 +60,8 @@ class DirectoryHandler(BaseHandler): yield self.store.create_room_alias_association( room_alias, room_id, - servers + servers, + creator=creator, ) @defer.inlineCallbacks @@ -77,7 +78,7 @@ class DirectoryHandler(BaseHandler): 400, "This alias is reserved by an application service.", errcode=Codes.EXCLUSIVE ) - yield self._create_association(room_alias, room_id, servers) + yield self._create_association(room_alias, room_id, servers, creator=user_id) @defer.inlineCallbacks def create_appservice_association(self, service, room_alias, room_id, @@ -95,7 +96,11 @@ class DirectoryHandler(BaseHandler): def delete_association(self, user_id, room_alias): # association deletion for human users - # TODO Check if server admin + can_delete = yield self._user_can_delete_alias(room_alias, user_id) + if not can_delete: + raise AuthError( + 403, "You don't have permission to delete the alias.", + ) can_delete = yield self.can_modify_alias( room_alias, @@ -257,3 +262,13 @@ class DirectoryHandler(BaseHandler): return # either no interested services, or no service with an exclusive lock defer.returnValue(True) + + @defer.inlineCallbacks + def 
_user_can_delete_alias(self, alias, user_id): + creator = yield self.store.get_room_alias_creator(alias.to_string()) + + if creator and creator == user_id: + defer.returnValue(True) + + is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) + defer.returnValue(is_admin) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 74ec1e50e0..55c22000fd 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -118,9 +118,6 @@ class ClientDirectoryServer(ClientV1RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - is_admin = yield self.auth.is_server_admin(user) - if not is_admin: - raise AuthError(403, "You need to be a server admin") room_alias = RoomAlias.from_string(room_alias) diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 1556619d5e..012a0b414a 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore): ) @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers): + def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers Args: room_alias (RoomAlias) room_id (str) servers (list) + creator (str): Optional user_id of creator. 
Returns: Deferred @@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore): { "room_alias": room_alias.to_string(), "room_id": room_id, + "creator": creator, }, desc="create_room_alias_association", ) @@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore): ) self.get_aliases_for_room.invalidate((room_id,)) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/schema/delta/30/alias_creator.sql new file mode 100644 index 0000000000..c9d0dde638 --- /dev/null +++ b/synapse/storage/schema/delta/30/alias_creator.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_aliases ADD COLUMN creator TEXT; -- cgit 1.4.1 From 60a0f81c7a2da86bf959227a440e3f7a2b727bb5 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Mar 2016 14:49:41 +0000 Subject: Add a /replication API for extracting the updates that happened on synapse This is necessary for replicating the data in synapse to be visible to a separate service because presence and typing notifications aren't stored in a database so won't be visible to another process. 
This API can be used to either get the raw data by requesting the tables themselves or to just receive notifications for updates by following the streams meta-stream. Returns updates for each table requested a JSON array of arrays with a row for each row in the table. Each table is prefixed by a header row with the: name of the table, current stream_id position for the table, number of rows, number of columns and the names of the columns. This is followed by the rows that have been added to the server since the requester last asked. The API has a timeout and is hooked up to the notifier so that a slave can long poll for updates. --- scripts-dev/tail-synapse.py | 67 ++++++++ synapse/app/homeserver.py | 4 + synapse/handlers/presence.py | 19 +++ synapse/handlers/typing.py | 14 ++ synapse/notifier.py | 48 ++++++ synapse/replication/__init__.py | 14 ++ synapse/replication/resource.py | 320 +++++++++++++++++++++++++++++++++++++ synapse/storage/account_data.py | 36 ++++- synapse/storage/events.py | 45 ++++++ synapse/storage/presence.py | 16 ++ synapse/storage/receipts.py | 16 ++ synapse/storage/tags.py | 53 ++++++ tests/replication/__init__.py | 14 ++ tests/replication/test_resource.py | 179 +++++++++++++++++++++ tests/utils.py | 5 +- 15 files changed, 846 insertions(+), 4 deletions(-) create mode 100644 scripts-dev/tail-synapse.py create mode 100644 synapse/replication/__init__.py create mode 100644 synapse/replication/resource.py create mode 100644 tests/replication/__init__.py create mode 100644 tests/replication/test_resource.py (limited to 'synapse/storage') diff --git a/scripts-dev/tail-synapse.py b/scripts-dev/tail-synapse.py new file mode 100644 index 0000000000..18be711e92 --- /dev/null +++ b/scripts-dev/tail-synapse.py @@ -0,0 +1,67 @@ +import requests +import collections +import sys +import time +import json + +Entry = collections.namedtuple("Entry", "name position rows") + +ROW_TYPES = {} + + +def row_type_for_columns(name, column_names): + column_names = 
tuple(column_names) + row_type = ROW_TYPES.get((name, column_names)) + if row_type is None: + row_type = collections.namedtuple(name, column_names) + ROW_TYPES[(name, column_names)] = row_type + return row_type + + +def parse_response(content): + streams = json.loads(content) + result = {} + for name, value in streams.items(): + row_type = row_type_for_columns(name, value["field_names"]) + position = value["position"] + rows = [row_type(*row) for row in value["rows"]] + result[name] = Entry(name, position, rows) + return result + + +def replicate(server, streams): + return parse_response(requests.get( + server + "/_synapse/replication", + verify=False, + params=streams + ).content) + + +def main(): + server = sys.argv[1] + + streams = None + while not streams: + try: + streams = { + row.name: row.position + for row in replicate(server, {"streams":"-1"})["streams"].rows + } + except requests.exceptions.ConnectionError as e: + time.sleep(0.1) + + while True: + try: + results = replicate(server, streams) + except: + sys.stdout.write("connection_lost("+ repr(streams) + ")\n") + break + for update in results.values(): + for row in update.rows: + sys.stdout.write(repr(row) + "\n") + streams[update.name] = update.position + + + +if __name__=='__main__': + main() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b4be7bdd0..de5ee988f1 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.federation.transport.server import TransportLayerServer from synapse import events @@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer): if name == "metrics" and self.get_config().enable_metrics: resources[METRICS_PREFIX] 
= MetricsResource(self) + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationResource(self) + root_resource = create_resource_tree(resources) if tls: reactor.listenSSL( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 08e38cdd25..d98e80086e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -774,6 +774,25 @@ class PresenceHandler(BaseHandler): defer.returnValue(observer_user.to_string() in accepted_observers) + @defer.inlineCallbacks + def get_all_presence_updates(self, last_id, current_id): + """ + Gets a list of presence update rows from between the given stream ids. + Each row has: + - stream_id(str) + - user_id(str) + - state(str) + - last_active_ts(int) + - last_federation_update_ts(int) + - last_user_sync_ts(int) + - status_msg(int) + - currently_active(int) + """ + # TODO(markjh): replicate the unpersisted changes. + # This could use the in-memory stores for recent changes. + rows = yield self.store.get_all_presence_updates(last_id, current_id) + defer.returnValue(rows) + def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index b16d0017df..8ce27f49ec 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -25,6 +25,7 @@ from synapse.types import UserID import logging from collections import namedtuple +import ujson as json logger = logging.getLogger(__name__) @@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler): "typing_key", self._latest_room_serial, rooms=[room_id] ) + def get_all_typing_updates(self, last_id, current_id): + # TODO: Work out a way to do this without scanning the entire state. 
+ rows = [] + for room_id, serial in self._room_serials.items(): + if last_id < serial and serial <= current_id: + typing = self._room_typing[room_id] + typing_bytes = json.dumps([ + u.to_string() for u in typing + ], ensure_ascii=False) + rows.append((serial, room_id, typing_bytes)) + rows.sort() + return rows + class TypingNotificationEventSource(object): def __init__(self, hs): diff --git a/synapse/notifier.py b/synapse/notifier.py index 560866b26e..3c36a20868 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -159,6 +159,8 @@ class Notifier(object): self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS ) + self.replication_deferred = ObservableDeferred(defer.Deferred()) + # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. @@ -207,6 +209,8 @@ class Notifier(object): )) self._notify_pending_new_room_events(max_room_stream_id) + self.notify_replication() + def _notify_pending_new_room_events(self, max_room_stream_id): """Notify for the room events that were queued waiting for a previous event to be persisted. @@ -276,6 +280,8 @@ class Notifier(object): except: logger.exception("Failed to notify listener") + self.notify_replication() + @defer.inlineCallbacks def wait_for_events(self, user_id, timeout, callback, room_ids=None, from_token=StreamToken("s0", "0", "0", "0", "0")): @@ -479,3 +485,45 @@ class Notifier(object): room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams.add(new_user_stream) new_user_stream.rooms.add(room_id) + + def notify_replication(self): + """Notify any replication listeners that there's a new event""" + with PreserveLoggingContext(): + deferred = self.replication_deferred + self.replication_deferred = ObservableDeferred(defer.Deferred()) + deferred.callback(None) + + @defer.inlineCallbacks + def wait_for_replication(self, callback, timeout): + """Wait for an event to happen.
+ + :param callback: + Gets called whenever an event happens. If this returns a truthy + value then ``wait_for_replication`` returns, otherwise it waits + for another event. + :param int timeout: + How many milliseconds to wait for the callback to return a truthy value. + :returns: + A deferred that resolves with the value returned by the callback. + """ + listener = _NotificationListener(None) + + def timed_out(): + listener.deferred.cancel() + + timer = self.clock.call_later(timeout / 1000., timed_out) + while True: + listener.deferred = self.replication_deferred.observe() + result = yield callback() + if result: + break + + try: + with PreserveLoggingContext(): + yield listener.deferred + except defer.CancelledError: + break + + self.clock.cancel_call_later(timer, ignore_errs=True) + + defer.returnValue(result) diff --git a/synapse/replication/__init__.py b/synapse/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/synapse/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py new file mode 100644 index 0000000000..e0d039518d --- /dev/null +++ b/synapse/replication/resource.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.servlet import parse_integer, parse_string +from synapse.http.server import request_handler, finish_request + +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET +from twisted.internet import defer + +import ujson as json + +import collections +import logging + +logger = logging.getLogger(__name__) + +REPLICATION_PREFIX = "/_synapse/replication" + +STREAM_NAMES = ( + ("events",), + ("presence",), + ("typing",), + ("receipts",), + ("user_account_data", "room_account_data", "tag_account_data",), + ("backfill",), +) + + +class ReplicationResource(Resource): + """ + HTTP endpoint for extracting data from synapse. + + The streams of data returned by the endpoint are controlled by the + parameters given to the API. To return a given stream pass a query + parameter with a position in the stream to return data from or the + special value "-1" to return data from the start of the stream. + + If there is no data for any of the supplied streams after the given + position then the request will block until there is data for one + of the streams. This allows clients to long-poll this API. 
+ + The possible streams are: + + * "streams": A special stream returning the positions of other streams. + * "events": The new events seen on the server. + * "presence": Presence updates. + * "typing": Typing updates. + * "receipts": Receipt updates. + * "user_account_data": Top-level per user account data. + * "room_account_data": Per room per user account data. + * "tag_account_data": Per room per user tags. + * "backfill": Old events that have been backfilled from other servers. + + The API takes two additional query parameters: + + * "timeout": How long to wait before returning an empty response. + * "limit": The maximum number of rows to return for the selected streams. + + The response is a JSON object with keys for each stream with updates. Under + each key is a JSON object with: + + * "position": The current position of the stream. + * "field_names": The names of the fields in each row. + * "rows": The updates as an array of arrays. + + There are a number of ways this API could be used: + + 1) To replicate the contents of the backing database to another database. + 2) To be notified when the contents of a shared backing database changes. + 3) To "tail" the activity happening on a server for debugging. + + In the first case the client would track all of the streams and store its + own copy of the data. + + In the second case the client might theoretically just be able to follow + the "streams" stream to track where the other streams are. However in + practice it will probably need to get the contents of the streams in + order to expire any in-memory caches. Whether it gets the contents + of the streams from this replication API or directly from the backing + store is a matter of taste. + + In the third case the client would use the "streams" stream to find what + streams are available and their current positions. Then it can start + long-polling this replication API for new data on those streams. 
+ """ + + isLeaf = True + + def __init__(self, hs): + Resource.__init__(self) # Resource is old-style, so no super() + + self.version_string = hs.version_string + self.store = hs.get_datastore() + self.sources = hs.get_event_sources() + self.presence_handler = hs.get_handlers().presence_handler + self.typing_handler = hs.get_handlers().typing_notification_handler + self.notifier = hs.notifier + + def render_GET(self, request): + self._async_render_GET(request) + return NOT_DONE_YET + + @defer.inlineCallbacks + def current_replication_token(self): + stream_token = yield self.sources.get_current_token() + backfill_token = yield self.store.get_current_backfill_token() + + defer.returnValue(_ReplicationToken( + stream_token.room_stream_id, + int(stream_token.presence_key), + int(stream_token.typing_key), + int(stream_token.receipt_key), + int(stream_token.account_data_key), + backfill_token, + )) + + @request_handler + @defer.inlineCallbacks + def _async_render_GET(self, request): + limit = parse_integer(request, "limit", 100) + timeout = parse_integer(request, "timeout", 10 * 1000) + + request.setHeader(b"Content-Type", b"application/json") + writer = _Writer(request) + + @defer.inlineCallbacks + def replicate(): + current_token = yield self.current_replication_token() + logger.info("Replicating up to %r", current_token) + + yield self.account_data(writer, current_token, limit) + yield self.events(writer, current_token, limit) + yield self.presence(writer, current_token) # TODO: implement limit + yield self.typing(writer, current_token) # TODO: implement limit + yield self.receipts(writer, current_token, limit) + self.streams(writer, current_token) + + logger.info("Replicated %d rows", writer.total) + defer.returnValue(writer.total) + + yield self.notifier.wait_for_replication(replicate, timeout) + + writer.finish() + + def streams(self, writer, current_token): + request_token = parse_string(writer.request, "streams") + + streams = [] + + if request_token is not None: 
+ if request_token == "-1": + for names, position in zip(STREAM_NAMES, current_token): + streams.extend((name, position) for name in names) + else: + items = zip( + STREAM_NAMES, + current_token, + _ReplicationToken(request_token) + ) + for names, current_id, last_id in items: + if last_id < current_id: + streams.extend((name, current_id) for name in names) + + if streams: + writer.write_header_and_rows( + "streams", streams, ("name", "position"), + position=str(current_token) + ) + + @defer.inlineCallbacks + def events(self, writer, current_token, limit): + request_events = parse_integer(writer.request, "events") + request_backfill = parse_integer(writer.request, "backfill") + + if request_events is not None or request_backfill is not None: + if request_events is None: + request_events = current_token.events + if request_backfill is None: + request_backfill = current_token.backfill + events_rows, backfill_rows = yield self.store.get_all_new_events( + request_backfill, request_events, + current_token.backfill, current_token.events, + limit + ) + writer.write_header_and_rows( + "events", events_rows, ("position", "internal", "json") + ) + writer.write_header_and_rows( + "backfill", backfill_rows, ("position", "internal", "json") + ) + + @defer.inlineCallbacks + def presence(self, writer, current_token): + current_position = current_token.presence + + request_presence = parse_integer(writer.request, "presence") + + if request_presence is not None: + presence_rows = yield self.presence_handler.get_all_presence_updates( + request_presence, current_position + ) + writer.write_header_and_rows("presence", presence_rows, ( + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + )) + + @defer.inlineCallbacks + def typing(self, writer, current_token): + current_position = current_token.presence + + request_typing = parse_integer(writer.request, "typing") + + if request_typing is not None: + 
typing_rows = yield self.typing_handler.get_all_typing_updates( + request_typing, current_position + ) + writer.write_header_and_rows("typing", typing_rows, ( + "position", "room_id", "typing" + )) + + @defer.inlineCallbacks + def receipts(self, writer, current_token, limit): + current_position = current_token.receipts + + request_receipts = parse_integer(writer.request, "receipts") + + if request_receipts is not None: + receipts_rows = yield self.store.get_all_updated_receipts( + request_receipts, current_position, limit + ) + writer.write_header_and_rows("receipts", receipts_rows, ( + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + )) + + @defer.inlineCallbacks + def account_data(self, writer, current_token, limit): + current_position = current_token.account_data + + user_account_data = parse_integer(writer.request, "user_account_data") + room_account_data = parse_integer(writer.request, "room_account_data") + tag_account_data = parse_integer(writer.request, "tag_account_data") + + if user_account_data is not None or room_account_data is not None: + if user_account_data is None: + user_account_data = current_position + if room_account_data is None: + room_account_data = current_position + user_rows, room_rows = yield self.store.get_all_updated_account_data( + user_account_data, room_account_data, current_position, limit + ) + writer.write_header_and_rows("user_account_data", user_rows, ( + "position", "user_id", "type", "content" + )) + writer.write_header_and_rows("room_account_data", room_rows, ( + "position", "user_id", "room_id", "type", "content" + )) + + if tag_account_data is not None: + tag_rows = yield self.store.get_all_updated_tags( + tag_account_data, current_position, limit + ) + writer.write_header_and_rows("tag_account_data", tag_rows, ( + "position", "user_id", "room_id", "tags" + )) + + +class _Writer(object): + """Writes the streams as a JSON object as the response to the request""" + def __init__(self, request): + 
self.streams = {} + self.request = request + self.total = 0 + + def write_header_and_rows(self, name, rows, fields, position=None): + if not rows: + return + + if position is None: + position = rows[-1][0] + + self.streams[name] = { + "position": str(position), + "field_names": fields, + "rows": rows, + } + + self.total += len(rows) + + def finish(self): + self.request.write(json.dumps(self.streams, ensure_ascii=False)) + finish_request(self.request) + + +class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( + "events", "presence", "typing", "receipts", "account_data", "backfill", +))): + __slots__ = [] + + def __new__(cls, *args): + if len(args) == 1: + return cls(*(int(value) for value in args[0].split("_"))) + else: + return super(_ReplicationToken, cls).__new__(cls, *args) + + def __str__(self): + return "_".join(str(value) for value in self) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 91cbf399b6..05f98a9a29 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): - """Get all the client account_data for a that's changed. + def get_all_updated_account_data(self, last_global_id, last_room_id, + current_id, limit): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" 
+ " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return (global_results, room_results) + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1dd3236829..c0872dd7e2 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore): yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) defer.returnValue(result) + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + + # TODO: Fix race with the persit_event txn by using one of the + # stream id managers + return -self.min_stream_token + + def get_all_new_events(self, last_backfill_id, last_forward_id, + current_backfill_id, current_forward_id, limit): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" 
+ ) + if last_forward_id != current_forward_id: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + else: + new_forward_events = [] + + sql = ( + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" + " ORDER BY e.stream_ordering DESC" + " LIMIT ?" + ) + if last_backfill_id != current_backfill_id: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + else: + new_backfill_events = [] + + return (new_forward_events, new_backfill_events) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3ef91d34db..de15741893 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -115,6 +115,22 @@ class PresenceStore(SQLBaseStore): args ) + def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" 
+ ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + @defer.inlineCallbacks def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a7343c97f7..6567fa844f 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) + + def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 9551aa9739..b225f508a5 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -58,6 +58,59 @@ class TagsStore(SQLBaseStore): return deferred + @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" 
+ ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" + " FROM room_tags" + " WHERE user_id=? AND room_id=?" + ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn.fetchall(): + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in xrange(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i:i + batch_size], + ) + results.extend(tags) + + defer.returnValue(results) + @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): """Get all the tags for the rooms where the tags have changed since the diff --git a/tests/replication/__init__.py b/tests/replication/__init__.py new file mode 100644 index 0000000000..b7df13c9ee --- /dev/null +++ b/tests/replication/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py new file mode 100644 index 0000000000..38daaf87e2 --- /dev/null +++ b/tests/replication/test_resource.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.resource import ReplicationResource +from synapse.types import Requester, UserID + +from twisted.internet import defer +from tests import unittest +from tests.utils import setup_test_homeserver +from mock import Mock, NonCallableMock +import json +import contextlib + + +class ReplicationResourceCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver( + "red", + http_client=None, + replication_layer=Mock(), + ratelimiter=NonCallableMock(spec_set=[ + "send_message", + ]), + ) + self.user = UserID.from_string("@seeing:red") + + self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + self.resource = ReplicationResource(self.hs) + + @defer.inlineCallbacks + def test_streams(self): + # Passing "-1" returns the current stream positions + code, body = yield self.get(streams="-1") + self.assertEquals(code, 200) + self.assertEquals(body["streams"]["field_names"], ["name", "position"]) + position = body["streams"]["position"] + # Passing the current position returns an empty response after the + # timeout + get = self.get(streams=str(position), timeout="0") + 
self.hs.clock.advance_time_msec(1) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body, {}) + + @defer.inlineCallbacks + def test_events(self): + get = self.get(events="-1", timeout="0") + yield self.hs.get_handlers().room_creation_handler.create_room( + Requester(self.user, "", False), {} + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["events"]["field_names"], [ + "position", "internal", "json" + ]) + + @defer.inlineCallbacks + def test_presence(self): + get = self.get(presence="-1") + yield self.hs.get_handlers().presence_handler.set_state( + self.user, {"presence": "online"} + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["presence"]["field_names"], [ + "position", "user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active", + ]) + + @defer.inlineCallbacks + def test_typing(self): + room_id = yield self.create_room() + get = self.get(typing="-1") + yield self.hs.get_handlers().typing_notification_handler.started_typing( + self.user, self.user, room_id, timeout=2 + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["typing"]["field_names"], [ + "position", "room_id", "typing" + ]) + + @defer.inlineCallbacks + def test_receipts(self): + room_id = yield self.create_room() + event_id = yield self.send_text_message(room_id, "Hello, World") + get = self.get(receipts="-1") + yield self.hs.get_handlers().receipts_handler.received_client_receipt( + room_id, "m.read", self.user.to_string(), event_id + ) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body["receipts"]["field_names"], [ + "position", "room_id", "receipt_type", "user_id", "event_id", "data" + ]) + + def _test_timeout(stream): + """Check that a request for the given stream timesout""" + @defer.inlineCallbacks + def test_timeout(self): + get = self.get(**{stream: "-1", "timeout": "0"}) 
+ self.hs.clock.advance_time_msec(1) + code, body = yield get + self.assertEquals(code, 200) + self.assertEquals(body, {}) + test_timeout.__name__ = "test_timeout_%s" % (stream) + return test_timeout + + test_timeout_events = _test_timeout("events") + test_timeout_presence = _test_timeout("presence") + test_timeout_typing = _test_timeout("typing") + test_timeout_receipts = _test_timeout("receipts") + test_timeout_user_account_data = _test_timeout("user_account_data") + test_timeout_room_account_data = _test_timeout("room_account_data") + test_timeout_tag_account_data = _test_timeout("tag_account_data") + test_timeout_backfill = _test_timeout("backfill") + + @defer.inlineCallbacks + def send_text_message(self, room_id, message): + handler = self.hs.get_handlers().message_handler + event = yield handler.create_and_send_nonmember_event({ + "type": "m.room.message", + "content": {"body": "message", "msgtype": "m.text"}, + "room_id": room_id, + "sender": self.user.to_string(), + }) + defer.returnValue(event.event_id) + + @defer.inlineCallbacks + def create_room(self): + result = yield self.hs.get_handlers().room_creation_handler.create_room( + Requester(self.user, "", False), {} + ) + defer.returnValue(result["room_id"]) + + @defer.inlineCallbacks + def get(self, **params): + request = NonCallableMock(spec_set=[ + "write", "finish", "setResponseCode", "setHeader", "args", + "method", "processing" + ]) + + request.method = "GET" + request.args = {k: [v] for k, v in params.items()} + + @contextlib.contextmanager + def processing(): + yield + request.processing = processing + + yield self.resource._async_render_GET(request) + self.assertTrue(request.finish.called) + + if request.setResponseCode.called: + response_code = request.setResponseCode.call_args[0][0] + else: + response_code = 200 + + response_json = "".join( + call[0][0] for call in request.write.call_args_list + ) + response_body = json.loads(response_json) + + defer.returnValue((response_code, response_body)) 
diff --git a/tests/utils.py b/tests/utils.py index bf7a31ff9e..dfbee5c23a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -239,9 +239,10 @@ class MockClock(object): def looping_call(self, function, interval): pass - def cancel_call_later(self, timer): + def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: - raise Exception("Cannot cancel an expired timer") + if not ignore_errs: + raise Exception("Cannot cancel an expired timer") timer[2] = True self.timers = [t for t in self.timers if t != timer] -- cgit 1.4.1 From a1cf9e3bf343c3e5adb8dce7923726aa9b09115e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Mar 2016 13:35:37 +0000 Subject: Add a stream for push rule updates --- synapse/storage/__init__.py | 5 +- synapse/storage/_base.py | 25 ++- synapse/storage/push_rule.py | 173 ++++++++++++++++----- .../storage/schema/delta/30/push_rule_stream.sql | 38 +++++ synapse/storage/util/id_generators.py | 84 ++++++---- 5 files changed, 251 insertions(+), 74 deletions(-) create mode 100644 synapse/storage/schema/delta/30/push_rule_stream.sql (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f257721ea3..e2d7b52569 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -45,7 +45,7 @@ from .search import SearchStore from .tags import TagsStore from .account_data import AccountDataStore -from util.id_generators import IdGenerator, StreamIdGenerator +from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore, self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + self._push_rules_stream_id_gen = ChainedIdGenerator( + 
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2e97ac84a8..7dc67ecd57 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -766,6 +766,19 @@ class SQLBaseStore(object): """Executes a DELETE query on the named table, expecting to delete a single row. + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction( + desc, self._simple_delete_one_txn, table, keyvalues + ) + + @staticmethod + def _simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + Args: table : string giving the table name keyvalues : dict of column names and values to select the row with @@ -775,13 +788,11 @@ class SQLBaseStore(object): " AND ".join("%s = ?" 
% (k, ) for k in keyvalues) ) - def func(txn): - txn.execute(sql, keyvalues.values()) - if txn.rowcount == 0: - raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "more than one row matched") - return self.runInteraction(desc, func) + txn.execute(sql, keyvalues.values()) + if txn.rowcount == 0: + raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "more than one row matched") @staticmethod def _simple_delete_txn(txn, table, keyvalues): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 56e69495b1..f3ebd49492 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -99,30 +99,31 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] defer.returnValue(results) + @defer.inlineCallbacks def add_push_rule( self, user_id, rule_id, priority_class, conditions, actions, before=None, after=None ): conditions_json = json.dumps(conditions) actions_json = json.dumps(actions) - - if before or after: - return self.runInteraction( - "_add_push_rule_relative_txn", - self._add_push_rule_relative_txn, - user_id, rule_id, priority_class, - conditions_json, actions_json, before, after, - ) - else: - return self.runInteraction( - "_add_push_rule_highest_priority_txn", - self._add_push_rule_highest_priority_txn, - user_id, rule_id, priority_class, - conditions_json, actions_json, - ) + with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + if before or after: + yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + stream_id, stream_ordering, user_id, rule_id, priority_class, + conditions_json, actions_json, before, after, + ) + else: + yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + stream_id, stream_ordering, user_id, rule_id, priority_class, + conditions_json, actions_json, + ) 
def _add_push_rule_relative_txn( - self, txn, user_id, rule_id, priority_class, + self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class, conditions_json, actions_json, before, after ): # Lock the table since otherwise we'll have annoying races between the @@ -174,12 +175,12 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_id, priority_class, new_rule_priority)) self._upsert_push_rule_txn( - txn, user_id, rule_id, priority_class, new_rule_priority, - conditions_json, actions_json, + txn, stream_id, stream_ordering, user_id, rule_id, priority_class, + new_rule_priority, conditions_json, actions_json, ) def _add_push_rule_highest_priority_txn( - self, txn, user_id, rule_id, priority_class, + self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class, conditions_json, actions_json ): # Lock the table since otherwise we'll have annoying races between the @@ -201,13 +202,13 @@ class PushRuleStore(SQLBaseStore): self._upsert_push_rule_txn( txn, - user_id, rule_id, priority_class, new_prio, + stream_id, stream_ordering, user_id, rule_id, priority_class, new_prio, conditions_json, actions_json, ) def _upsert_push_rule_txn( - self, txn, user_id, rule_id, priority_class, - priority, conditions_json, actions_json + self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class, + priority, conditions_json, actions_json, update_stream=True ): """Specialised version of _simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. 
It assumes @@ -242,6 +243,23 @@ class PushRuleStore(SQLBaseStore): }, ) + if update_stream: + self._simple_insert_txn( + txn, + table="push_rules_stream", + values={ + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": "ADD", + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + } + ) + txn.call_after( self.get_push_rules_for_user.invalidate, (user_id,) ) @@ -260,25 +278,47 @@ class PushRuleStore(SQLBaseStore): user_id (str): The matrix ID of the push rule owner rule_id (str): The rule_id of the rule to be deleted """ - yield self._simple_delete_one( - "push_rules", - {'user_name': user_id, 'rule_id': rule_id}, - desc="delete_push_rule", - ) + def delete_push_rule_txn(txn, stream_id, stream_ordering): + self._simple_delete_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + ) + self._simple_insert_txn( + txn, + table="push_rules_stream", + values={ + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": "DELETE", + } + ) + txn.call_after( + self.get_push_rules_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, (user_id,) + ) - self.get_push_rules_for_user.invalidate((user_id,)) - self.get_push_rules_enabled_for_user.invalidate((user_id,)) + with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + yield self.runInteraction( + "delete_push_rule", delete_push_rule_txn, stream_id, stream_ordering + ) @defer.inlineCallbacks def set_push_rule_enabled(self, user_id, rule_id, enabled): - ret = yield self.runInteraction( - "_set_push_rule_enabled_txn", - self._set_push_rule_enabled_txn, - user_id, rule_id, enabled - ) - defer.returnValue(ret) + with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + yield self.runInteraction( + 
"_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + stream_id, stream_ordering, user_id, rule_id, enabled + ) - def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): + def _set_push_rule_enabled_txn( + self, txn, stream_id, stream_ordering, user_id, rule_id, enabled + ): new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( txn, @@ -287,6 +327,19 @@ class PushRuleStore(SQLBaseStore): {'enabled': 1 if enabled else 0}, {'id': new_id}, ) + + self._simple_insert_txn( + txn, + "push_rules_stream", + values={ + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": "ENABLE" if enabled else "DISABLE", + } + ) + txn.call_after( self.get_push_rules_for_user.invalidate, (user_id,) ) @@ -294,18 +347,20 @@ class PushRuleStore(SQLBaseStore): self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + @defer.inlineCallbacks def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): actions_json = json.dumps(actions) - def set_push_rule_actions_txn(txn): + def set_push_rule_actions_txn(txn, stream_id, stream_ordering): if is_default_rule: # Add a dummy rule to the rules table with the user specified # actions. 
priority_class = -1 priority = 1 self._upsert_push_rule_txn( - txn, user_id, rule_id, priority_class, priority, - "[]", actions_json + txn, stream_id, stream_ordering, user_id, rule_id, + priority_class, priority, "[]", actions_json, + update_stream=False ) else: self._simple_update_one_txn( @@ -315,8 +370,46 @@ class PushRuleStore(SQLBaseStore): {'actions': actions_json}, ) + self._simple_insert_txn( + txn, + "push_rules_stream", + values={ + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": "ACTIONS", + "actions": actions_json, + } + ) + + txn.call_after( + self.get_push_rules_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, (user_id,) + ) + + with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + yield self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + stream_id, stream_ordering + ) + + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id and stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" 
+ ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() return self.runInteraction( - "set_push_rule_actions", set_push_rule_actions_txn, + "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql new file mode 100644 index 0000000000..e8418bb35f --- /dev/null +++ b/synapse/storage/schema/delta/30/push_rule_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +CREATE TABLE push_rules_stream( + stream_id BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + user_id TEXT NOT NULL, + rule_id TEXT NOT NULL, + op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" + priority_class SMALLINT, + priority INTEGER, + conditions TEXT, + actions TEXT +); + +-- The extra data for each operation is: +-- * ENABLE, DISABLE, DELETE: [] +-- * ACTIONS: ["actions"] +-- * ADD: ["priority_class", "priority", "actions", "conditions"] + +-- Index for replication queries. +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +-- Index for /sync queries. 
+CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index efe3f68e6e..af425ba9a4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -20,23 +20,21 @@ import threading class IdGenerator(object): def __init__(self, db_conn, table, column): - self.table = table - self.column = column self._lock = threading.Lock() - cur = db_conn.cursor() - self._next_id = self._load_next_id(cur) - cur.close() - - def _load_next_id(self, txn): - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) - val, = txn.fetchone() - return val + 1 if val else 1 + self._next_id = _load_max_id(db_conn, table, column) def get_next(self): with self._lock: - i = self._next_id self._next_id += 1 - return i + return self._next_id + + +def _load_max_id(db_conn, table, column): + cur = db_conn.cursor() + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + val, = cur.fetchone() + cur.close() + return val if val else 1 class StreamIdGenerator(object): @@ -52,23 +50,10 @@ class StreamIdGenerator(object): # ... persist event ... """ def __init__(self, db_conn, table, column): - self.table = table - self.column = column - self._lock = threading.Lock() - - cur = db_conn.cursor() - self._current_max = self._load_current_max(cur) - cur.close() - + self._current_max = _load_max_id(db_conn, table, column) self._unfinished_ids = deque() - def _load_current_max(self, txn): - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] - return int(val) if val else 1 - def get_next(self): """ Usage: @@ -124,3 +109,50 @@ class StreamIdGenerator(object): return self._unfinished_ids[0] - 1 return self._current_max + + +class ChainedIdGenerator(object): + """Used to generate new stream ids where the stream must be kept in sync + with another stream. 
It generates pairs of IDs, the first element is an + integer ID for this stream, the second element is the ID for the stream + that this stream needs to be kept in sync with.""" + + def __init__(self, chained_generator, db_conn, table, column): + self.chained_generator = chained_generator + self._lock = threading.Lock() + self._current_max = _load_max_id(db_conn, table, column) + self._unfinished_ids = deque() + + def get_next(self): + """ + Usage: + with stream_id_gen.get_next() as (stream_id, chained_id): + # ... persist event ... + """ + with self._lock: + self._current_max += 1 + next_id = self._current_max + chained_id = self.chained_generator.get_max_token() + + self._unfinished_ids.append((next_id, chained_id)) + + @contextlib.contextmanager + def manager(): + try: + yield (next_id, chained_id) + finally: + with self._lock: + self._unfinished_ids.remove((next_id, chained_id)) + + return manager() + + def get_max_token(self): + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. 
+ """ + with self._lock: + if self._unfinished_ids: + stream_id, chained_id = self._unfinished_ids[0] + return (stream_id - 1, chained_id) + + return (self._current_max, self.chained_generator.get_max_token()) -- cgit 1.4.1 From 2223204ebaf7624f4d640f2c56d3a4eb7ff6d98e Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 2 Mar 2016 17:26:20 +0000 Subject: Hook push rules up to the replication API --- synapse/replication/resource.py | 28 ++++++++++++++++++++++++++-- synapse/storage/push_rule.py | 6 ++++++ tests/replication/test_resource.py | 6 ++++-- 3 files changed, 36 insertions(+), 4 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index e0d039518d..15b7898a45 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -36,6 +36,7 @@ STREAM_NAMES = ( ("receipts",), ("user_account_data", "room_account_data", "tag_account_data",), ("backfill",), + ("push_rules",), ) @@ -63,6 +64,7 @@ class ReplicationResource(Resource): * "room_account_data: Per room per user account data. * "tag_account_data": Per room per user tags. * "backfill": Old events that have been backfilled from other servers. + * "push_rules": Per user changes to push rules. 
The API takes two additional query parameters: @@ -117,14 +119,16 @@ class ReplicationResource(Resource): def current_replication_token(self): stream_token = yield self.sources.get_current_token() backfill_token = yield self.store.get_current_backfill_token() + push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() defer.returnValue(_ReplicationToken( - stream_token.room_stream_id, + room_stream_token, int(stream_token.presence_key), int(stream_token.typing_key), int(stream_token.receipt_key), int(stream_token.account_data_key), backfill_token, + push_rules_token, )) @request_handler @@ -146,6 +150,7 @@ class ReplicationResource(Resource): yield self.presence(writer, current_token) # TODO: implement limit yield self.typing(writer, current_token) # TODO: implement limit yield self.receipts(writer, current_token, limit) + yield self.push_rules(writer, current_token, limit) self.streams(writer, current_token) logger.info("Replicated %d rows", writer.total) @@ -277,6 +282,21 @@ class ReplicationResource(Resource): "position", "user_id", "room_id", "tags" )) + @defer.inlineCallbacks + def push_rules(self, writer, current_token, limit): + current_position = current_token.push_rules + + push_rules = parse_integer(writer.request, "push_rules") + + if push_rules is not None: + rows = yield self.store.get_all_push_rule_updates( + push_rules, current_position, limit + ) + writer.write_header_and_rows("push_rules", rows, ( + "position", "stream_ordering", "user_id", "rule_id", "op", + "priority_class", "priority", "conditions", "actions" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -307,12 +327,16 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", + "push_rules" ))): __slots__ = [] def __new__(cls, *args): if len(args) == 1: - return cls(*(int(value) for value in 
args[0].split("_"))) + streams = [int(value) for value in args[0].split("_")] + if len(streams) < len(cls._fields): + streams.extend([0] * (len(cls._fields) - len(streams))) + return cls(*streams) else: return super(_ReplicationToken, cls).__new__(cls, *args) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index f3ebd49492..e034024108 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -412,6 +412,12 @@ class PushRuleStore(SQLBaseStore): "get_all_push_rule_updates", get_all_push_rule_updates_txn ) + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. + Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_max_token() + class RuleNotFoundException(Exception): pass diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 38daaf87e2..a30d59a865 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -35,7 +35,8 @@ class ReplicationResourceCase(unittest.TestCase): "send_message", ]), ) - self.user = UserID.from_string("@seeing:red") + self.user_id = "@seeing:red" + self.user = UserID.from_string(self.user_id) self.hs.get_ratelimiter().send_message.return_value = (True, 0) @@ -101,7 +102,7 @@ class ReplicationResourceCase(unittest.TestCase): event_id = yield self.send_text_message(room_id, "Hello, World") get = self.get(receipts="-1") yield self.hs.get_handlers().receipts_handler.received_client_receipt( - room_id, "m.read", self.user.to_string(), event_id + room_id, "m.read", self.user_id, event_id ) code, body = yield get self.assertEquals(code, 200) @@ -129,6 +130,7 @@ class ReplicationResourceCase(unittest.TestCase): test_timeout_room_account_data = _test_timeout("room_account_data") test_timeout_tag_account_data = _test_timeout("tag_account_data") test_timeout_backfill = _test_timeout("backfill") + 
test_timeout_push_rules = _test_timeout("push_rules") @defer.inlineCallbacks def send_text_message(self, room_id, message): -- cgit 1.4.1 From 1b4f4a936fb416d81203fcd66be690f9a04b2b62 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 14:44:01 +0000 Subject: Hook up the push rules stream to account_data in /sync --- synapse/handlers/sync.py | 22 +++++++ synapse/rest/client/v1/push_rule.py | 2 +- synapse/storage/__init__.py | 5 ++ synapse/storage/push_rule.py | 125 ++++++++++++++++-------------------- 4 files changed, 85 insertions(+), 69 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index fded6e4009..92eab20c7c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.util import unwrapFirstError from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.metrics import Measure +from synapse.push.clientformat import format_push_rules_for_user from twisted.internet import defer @@ -224,6 +225,10 @@ class SyncHandler(BaseHandler): ) ) + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + tags_by_room = yield self.store.get_tags_for_user( sync_config.user.to_string() ) @@ -322,6 +327,14 @@ class SyncHandler(BaseHandler): defer.returnValue(room_sync) + @defer.inlineCallbacks + def push_rules_for_user(self, user): + user_id = user.to_string() + rawrules = yield self.store.get_push_rules_for_user(user_id) + enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) + rules = format_push_rules_for_user(user, rawrules, enabled_map) + defer.returnValue(rules) + def account_data_for_user(self, account_data): account_data_events = [] @@ -481,6 +494,15 @@ class SyncHandler(BaseHandler): ) ) + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if 
push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + # Get a list of membership change events that have happened. rooms_changed = yield self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index edfe28c79b..981d7708db 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -156,7 +156,7 @@ class PushRuleRestServlet(ClientV1RestServlet): return 200, {} def notify_user(self, user_id): - stream_id = self.store.get_push_rules_stream_token() + stream_id, _ = self.store.get_push_rules_stream_token() self.notifier.on_new_event( "push_rules_key", stream_id, users=[user_id] ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e2d7b52569..7b7b03d052 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -160,6 +160,11 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", + self._push_rules_stream_id_gen.get_max_token()[0], + ) + super(DataStore, self).__init__(hs) def take_presence_startup_info(self): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index e034024108..792fcbdf5b 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -244,15 +244,10 @@ class PushRuleStore(SQLBaseStore): ) if update_stream: - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ADD", + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ADD", + data={ "priority_class": priority_class, "priority": priority, "conditions": conditions_json, @@ -260,13 +255,6 @@ class 
PushRuleStore(SQLBaseStore): } ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) - @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -284,22 +272,10 @@ class PushRuleStore(SQLBaseStore): "push_rules", {'user_name': user_id, 'rule_id': rule_id}, ) - self._simple_insert_txn( - txn, - table="push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "DELETE", - } - ) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="DELETE" ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -328,23 +304,9 @@ class PushRuleStore(SQLBaseStore): {'id': new_id}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ENABLE" if enabled else "DISABLE", - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ENABLE" if enabled else "DISABLE" ) @defer.inlineCallbacks @@ -370,24 +332,9 @@ class PushRuleStore(SQLBaseStore): {'actions': actions_json}, ) - self._simple_insert_txn( - txn, - "push_rules_stream", - values={ - "stream_id": stream_id, - "stream_ordering": stream_ordering, - "user_id": user_id, - "rule_id": rule_id, - "op": "ACTIONS", - "actions": actions_json, - } - ) - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - 
self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._insert_push_rules_update_txn( + txn, stream_id, stream_ordering, user_id, rule_id, + op="ACTIONS", data={"actions": actions_json} ) with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): @@ -396,6 +343,31 @@ class PushRuleStore(SQLBaseStore): stream_id, stream_ordering ) + def _insert_push_rules_update_txn( + self, txn, stream_id, stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "stream_ordering": stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + + txn.call_after( + self.get_push_rules_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.get_push_rules_enabled_for_user.invalidate, (user_id,) + ) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + def get_all_push_rule_updates(self, last_id, current_id, limit): """Get all the push rules changes that have happend on the server""" def get_all_push_rule_updates_txn(txn): @@ -403,7 +375,7 @@ class PushRuleStore(SQLBaseStore): "SELECT stream_id, stream_ordering, user_id, rule_id," " op, priority_class, priority, conditions, actions" " FROM push_rules_stream" - " WHERE ? < stream_id and stream_id <= ?" + " WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) @@ -418,6 +390,23 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_max_token() + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + logger.error("FNARG") + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? 
AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + class RuleNotFoundException(Exception): pass -- cgit 1.4.1 From 7e9fc9b6af2052441a54613627aebeb4999d1efe Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 15:54:09 +0000 Subject: /FNARG/d --- synapse/storage/push_rule.py | 1 - 1 file changed, 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 792fcbdf5b..57e1ca5509 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -392,7 +392,6 @@ class PushRuleStore(SQLBaseStore): def have_push_rules_changed_for_user(self, user_id, last_id): if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - logger.error("FNARG") return defer.succeed(False) else: def have_push_rules_changed_txn(txn): -- cgit 1.4.1 From ebcbb23226904f080e6a9c1e2f2901886c286445 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 16:15:23 +0000 Subject: s/stream_ordering/event_stream_ordering/ in push --- synapse/replication/resource.py | 2 +- synapse/storage/push_rule.py | 54 ++++++++++++---------- .../storage/schema/delta/30/push_rule_stream.sql | 2 +- 3 files changed, 31 insertions(+), 27 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index 15b7898a45..adc1eb1d0b 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -293,7 +293,7 @@ class ReplicationResource(Resource): push_rules, current_position, limit ) writer.write_header_and_rows("push_rules", rows, ( - "position", "stream_ordering", "user_id", "rule_id", "op", + "position", "event_stream_ordering", "user_id", "rule_id", "op", "priority_class", "priority", "conditions", "actions" )) diff --git a/synapse/storage/push_rule.py 
b/synapse/storage/push_rule.py index 57e1ca5509..9dbad2fd5f 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -106,24 +106,25 @@ class PushRuleStore(SQLBaseStore): ): conditions_json = json.dumps(conditions) actions_json = json.dumps(actions) - with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids if before or after: yield self.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, - stream_id, stream_ordering, user_id, rule_id, priority_class, + stream_id, event_stream_ordering, user_id, rule_id, priority_class, conditions_json, actions_json, before, after, ) else: yield self.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, - stream_id, stream_ordering, user_id, rule_id, priority_class, + stream_id, event_stream_ordering, user_id, rule_id, priority_class, conditions_json, actions_json, ) def _add_push_rule_relative_txn( - self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class, + self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, conditions_json, actions_json, before, after ): # Lock the table since otherwise we'll have annoying races between the @@ -175,12 +176,12 @@ class PushRuleStore(SQLBaseStore): txn.execute(sql, (user_id, priority_class, new_rule_priority)) self._upsert_push_rule_txn( - txn, stream_id, stream_ordering, user_id, rule_id, priority_class, + txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_rule_priority, conditions_json, actions_json, ) def _add_push_rule_highest_priority_txn( - self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class, + self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, conditions_json, actions_json ): # Lock the table since otherwise we'll have annoying races between the @@ -202,12 +203,12 
@@ class PushRuleStore(SQLBaseStore): self._upsert_push_rule_txn( txn, - stream_id, stream_ordering, user_id, rule_id, priority_class, new_prio, + stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio, conditions_json, actions_json, ) def _upsert_push_rule_txn( - self, txn, stream_id, stream_ordering, user_id, rule_id, priority_class, + self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, priority, conditions_json, actions_json, update_stream=True ): """Specialised version of _simple_upsert_txn that picks a push_rule_id @@ -245,7 +246,7 @@ class PushRuleStore(SQLBaseStore): if update_stream: self._insert_push_rules_update_txn( - txn, stream_id, stream_ordering, user_id, rule_id, + txn, stream_id, event_stream_ordering, user_id, rule_id, op="ADD", data={ "priority_class": priority_class, @@ -266,7 +267,7 @@ class PushRuleStore(SQLBaseStore): user_id (str): The matrix ID of the push rule owner rule_id (str): The rule_id of the rule to be deleted """ - def delete_push_rule_txn(txn, stream_id, stream_ordering): + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): self._simple_delete_one_txn( txn, "push_rules", @@ -274,26 +275,28 @@ class PushRuleStore(SQLBaseStore): ) self._insert_push_rules_update_txn( - txn, stream_id, stream_ordering, user_id, rule_id, + txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids yield self.runInteraction( - "delete_push_rule", delete_push_rule_txn, stream_id, stream_ordering + "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering ) @defer.inlineCallbacks def set_push_rule_enabled(self, user_id, rule_id, enabled): - with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, 
event_stream_ordering = ids yield self.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, - stream_id, stream_ordering, user_id, rule_id, enabled + stream_id, event_stream_ordering, user_id, rule_id, enabled ) def _set_push_rule_enabled_txn( - self, txn, stream_id, stream_ordering, user_id, rule_id, enabled + self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( @@ -305,7 +308,7 @@ class PushRuleStore(SQLBaseStore): ) self._insert_push_rules_update_txn( - txn, stream_id, stream_ordering, user_id, rule_id, + txn, stream_id, event_stream_ordering, user_id, rule_id, op="ENABLE" if enabled else "DISABLE" ) @@ -313,14 +316,14 @@ class PushRuleStore(SQLBaseStore): def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): actions_json = json.dumps(actions) - def set_push_rule_actions_txn(txn, stream_id, stream_ordering): + def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): if is_default_rule: # Add a dummy rule to the rules table with the user specified # actions. 
priority_class = -1 priority = 1 self._upsert_push_rule_txn( - txn, stream_id, stream_ordering, user_id, rule_id, + txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, priority, "[]", actions_json, update_stream=False ) @@ -333,22 +336,23 @@ class PushRuleStore(SQLBaseStore): ) self._insert_push_rules_update_txn( - txn, stream_id, stream_ordering, user_id, rule_id, + txn, stream_id, event_stream_ordering, user_id, rule_id, op="ACTIONS", data={"actions": actions_json} ) - with self._push_rules_stream_id_gen.get_next() as (stream_id, stream_ordering): + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids yield self.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, - stream_id, stream_ordering + stream_id, event_stream_ordering ) def _insert_push_rules_update_txn( - self, txn, stream_id, stream_ordering, user_id, rule_id, op, data=None + self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None ): values = { "stream_id": stream_id, - "stream_ordering": stream_ordering, + "event_stream_ordering": event_stream_ordering, "user_id": user_id, "rule_id": rule_id, "op": op, @@ -372,7 +376,7 @@ class PushRuleStore(SQLBaseStore): """Get all the push rules changes that have happend on the server""" def get_all_push_rule_updates_txn(txn): sql = ( - "SELECT stream_id, stream_ordering, user_id, rule_id," + "SELECT stream_id, event_stream_ordering, user_id, rule_id," " op, priority_class, priority, conditions, actions" " FROM push_rules_stream" " WHERE ? < stream_id AND stream_id <= ?" 
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql index e8418bb35f..735aa8d5f6 100644 --- a/synapse/storage/schema/delta/30/push_rule_stream.sql +++ b/synapse/storage/schema/delta/30/push_rule_stream.sql @@ -17,7 +17,7 @@ CREATE TABLE push_rules_stream( stream_id BIGINT NOT NULL, - stream_ordering BIGINT NOT NULL, + event_stream_ordering BIGINT NOT NULL, user_id TEXT NOT NULL, rule_id TEXT NOT NULL, op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" -- cgit 1.4.1 From deda48068c24083750a9bfc21d114c12e8347969 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 16:19:42 +0000 Subject: prefill the push rules stream change cache --- synapse/storage/__init__.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7b7b03d052..ab2f115adf 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -160,9 +160,16 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._push_rules_stream_id_gen.get_max_token()[0], + ) + self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_max_token()[0], + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, ) super(DataStore, self).__init__(hs) -- cgit 1.4.1 From 9848b54cac2c7e077317eec85ee0de2cb567c561 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 4 Mar 2016 16:20:22 +0000 Subject: Prefill from the correct stream --- synapse/storage/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 
ab2f115adf..6f37a85d09 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -161,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore, ) push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "presence_stream", + db_conn, "push_rules_stream", entity_column="user_id", stream_column="stream_id", max_value=self._push_rules_stream_id_gen.get_max_token()[0], -- cgit 1.4.1 From 7076082ae677b280c5b68df37b1fee2fc72752ff Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 8 Mar 2016 11:45:50 +0000 Subject: Fix relative imports so they work in both py3 and py27 --- synapse/push/__init__.py | 4 ++-- synapse/push/action_generator.py | 4 ++-- synapse/push/bulk_push_rule_evaluator.py | 6 +++--- synapse/push/push_rule_evaluator.py | 4 ++-- synapse/push/pusherpool.py | 2 +- synapse/rest/client/v1/admin.py | 2 +- synapse/rest/client/v1/initial_sync.py | 2 +- synapse/rest/client/v1/login.py | 2 +- synapse/rest/client/v1/register.py | 2 +- synapse/rest/client/v1/room.py | 2 +- synapse/rest/client/v1/voip.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/end_to_end_keys.py | 2 +- synapse/storage/events.py | 2 +- synapse/storage/keys.py | 2 +- synapse/storage/media_repository.py | 2 +- synapse/storage/signatures.py | 2 +- 17 files changed, 22 insertions(+), 22 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 4c6c3b83a2..65ef1b68a3 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -21,7 +21,7 @@ from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure import synapse.util.async -import push_rule_evaluator as push_rule_evaluator +from .push_rule_evaluator import evaluator_for_user_id import logging import random @@ -185,7 +185,7 @@ class Pusher(object): processed = False rule_evaluator = yield \ - push_rule_evaluator.evaluator_for_user_id( + evaluator_for_user_id( self.user_id, single_event['room_id'], self.store ) 
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index c6c1dc769e..84efcdd184 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -import bulk_push_rule_evaluator +from .bulk_push_rule_evaluator import evaluator_for_room_id import logging @@ -35,7 +35,7 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context, handler): - bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( + bulk_evaluator = yield evaluator_for_room_id( event.room_id, self.hs, self.store ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 5d8be483e5..87d5061fb0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -18,8 +18,8 @@ import ujson as json from twisted.internet import defer -import baserules -from push_rule_evaluator import PushRuleEvaluatorForEvent +from .baserules import list_with_base_rules +from .push_rule_evaluator import PushRuleEvaluatorForEvent from synapse.api.constants import EventTypes @@ -39,7 +39,7 @@ def _get_rules(room_id, user_ids, store): rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids) rules_by_user = { - uid: baserules.list_with_base_rules([ + uid: list_with_base_rules([ decode_rule_json(rule_list) for rule_list in rules_by_user.get(uid, []) ]) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 98e2a2015e..51f73a5b78 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -import baserules +from .baserules import list_with_base_rules import logging import simplejson as json @@ -91,7 +91,7 @@ class PushRuleEvaluator: rule['actions'] = json.loads(raw_rule['actions']) rules.append(rule) - self.rules = baserules.list_with_base_rules(rules) 
+ self.rules = list_with_base_rules(rules) self.enabled_map = enabled_map diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index a05aa5f661..772a095f8b 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -16,7 +16,7 @@ from twisted.internet import defer -from httppusher import HttpPusher +from .httppusher import HttpPusher from synapse.push import PusherConfigException from synapse.util.logcontext import preserve_fn diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index e2f5eb7b29..aa05b3f023 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import logging diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index ad161bdbab..36c3520567 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns # TODO: Needs unit testing diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index c14e8af00e..f6902a60a8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, LoginError, Codes from synapse.types import UserID from synapse.http.server import finish_request -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib diff --git a/synapse/rest/client/v1/register.py 
b/synapse/rest/client/v1/register.py index 6d6d03c34c..040a7a7ffa 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cbf3673eff..4b7d198c52 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,7 +16,7 @@ """ This module contains REST servlets to do with rooms: /rooms/ """ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index ec4cf8db79..c40442f958 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_patterns +from .base import ClientV1RestServlet, client_path_patterns import hmac diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 6f37a85d09..168eb27b03 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -45,7 +45,7 @@ from .search import SearchStore from .tags import TagsStore from .account_data import AccountDataStore -from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from synapse.api.constants import 
PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 5dd32b1413..2e89066515 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 60936500d8..552e7ca35b 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index fd05bfe54e..a495a8a7d9 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 0894384780..9d3ba32478 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from _base import SQLBaseStore +from ._base import SQLBaseStore class MediaRepositoryStore(SQLBaseStore): diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 70c6a06cd1..b10f2a5787 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from _base import SQLBaseStore +from ._base import SQLBaseStore from unpaddedbase64 import encode_base64 from synapse.crypto.event_signing import compute_event_reference_hash -- cgit 1.4.1 From edca2d989137680093c4acdcaf6ca4029f11a335 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 8 Mar 2016 17:32:29 +0000 Subject: Idempotent-ise schema update script If any ASes don't have an ID, the schema will fail, and then it will error when trying to add the column again. --- synapse/storage/schema/delta/30/as_users.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4cf4dd0917..4da3c59de2 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -20,7 +20,11 @@ logger = logging.getLogger(__name__) def run_upgrade(cur, database_engine, config, *args, **kwargs): # NULL indicates user was not registered by an appservice. - cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + try: + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + except: + # Maybe we already added the column? Hope so... 
+ pass cur.execute("SELECT name FROM users") rows = cur.fetchall() -- cgit 1.4.1 From 158a322e8274b8ed031a20a00a3700d0798ae1c2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Mar 2016 10:20:48 +0000 Subject: Ensure integer is an integer --- synapse/storage/util/id_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index af425ba9a4..610ddad423 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -34,7 +34,7 @@ def _load_max_id(db_conn, table, column): cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) val, = cur.fetchone() cur.close() - return val if val else 1 + return int(val) if val else 1 class StreamIdGenerator(object): -- cgit 1.4.1 From 3ecaabc7fd8a1a461c319248a94617cf24dd4070 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Mar 2016 15:45:34 +0000 Subject: Use topological orders for initial sync timeline --- synapse/storage/stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 8908d5b5da..613b54cd1d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -217,8 +217,8 @@ class StreamStore(SQLBaseStore): " room_id = ?" " AND not outlier" " AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) + " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?" 
+ ) % (order, order,) txn.execute(sql, (room_id, to_id, limit)) rows = self.cursor_to_dict(txn) -- cgit 1.4.1 From af2fe6110c6e249ce0c679ca39bb67d2c16c59c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Mar 2016 16:11:53 +0000 Subject: Return the correct token form --- synapse/storage/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 613b54cd1d..c9f70a08a7 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -232,7 +232,7 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=False) + self._set_before_and_after(ret, rows, topo_order=from_id is None) if order.lower() == "desc": ret.reverse() -- cgit 1.4.1 From 8a88684736efcb6792d17249e2976caefe00ee3d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 9 Mar 2016 16:51:22 +0000 Subject: Add comment --- synapse/storage/stream.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'synapse/storage') diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c9f70a08a7..7f4a827528 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -184,6 +184,9 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): + # Note: If from_key is None then we return in topological order. 
This + # is because in that case we're using this as a "get the last few messages + # in a room" function, rather than "get new messages since last sync" if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream else: -- cgit 1.4.1 From 9669a99d1a76f346b2cfb9b4197636ac3142f9d2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 10 Mar 2016 15:12:19 +0000 Subject: Update users table in a batched manner --- synapse/storage/schema/delta/30/as_users.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 4da3c59de2..4f6e9dd540 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -52,12 +52,17 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): " service (IDs %s and %s); assigning arbitrarily to %s" % (user_id, owned[user_id], appservice.id, owned[user_id]) ) - owned[user_id] = appservice.id - - for user_id, as_id in owned.items(): - cur.execute( - database_engine.convert_param_style( - "UPDATE users SET appservice_id = ? WHERE name = ?" - ), - (as_id, user_id) - ) + owned.setdefault(appservice.id, []).append(user_id) + + for as_id, user_ids in owned.items(): + n = 100 + user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + for chunk in user_chunks: + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % ( + ",".join("?" 
for _ in chunk), + ) + ), + [as_id] + chunk + ) -- cgit 1.4.1 From 465605d616c991760ce021932f0453fc6bc477ef Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 10 Mar 2016 15:58:22 +0000 Subject: Store appservice ID on register --- synapse/handlers/register.py | 5 ++++- synapse/storage/registration.py | 40 ++++++++++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 7 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index e2ace6a4e5..6ffb8c0da6 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -182,6 +182,8 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) + service_id = service.id if service.is_exclusive_user(user_id) else None + yield self.check_user_id_not_appservice_exclusive( user_id, allowed_appservice=service ) @@ -190,7 +192,8 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash="" + password_hash="", + appservice_id=service_id, ) yield registered_user(self.distributor, user) defer.returnValue((user_id, token)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index ad1157f979..aa49f53458 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -76,7 +76,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False): + was_guest=False, make_guest=False, appservice_id=None): """Attempts to register an account. Args: @@ -87,16 +87,32 @@ class RegistrationStore(SQLBaseStore): upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. Raises: StoreError if the user_id could not be registered. 
""" yield self.runInteraction( "register", - self._register, user_id, token, password_hash, was_guest, make_guest + self._register, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id ) self.is_guest.invalidate((user_id,)) - def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): + def _register( + self, + txn, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id + ): now = int(self.clock.time()) next_id = self._access_tokens_id_gen.get_next() @@ -111,9 +127,21 @@ class RegistrationStore(SQLBaseStore): [password_hash, now, 1 if make_guest else 0, user_id]) else: txn.execute("INSERT INTO users " - "(name, password_hash, creation_ts, is_guest) " - "VALUES (?,?,?,?)", - [user_id, password_hash, now, 1 if make_guest else 0]) + "(" + " name," + " password_hash," + " creation_ts," + " is_guest," + " appservice_id" + ") " + "VALUES (?,?,?,?,?)", + [ + user_id, + password_hash, + now, + 1 if make_guest else 0, + appservice_id, + ]) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE -- cgit 1.4.1 From aa11db5f119b9fa88242b0df95cfddd00e196ca1 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 11 Mar 2016 13:14:18 +0000 Subject: Fix cache invalidation so deleting access tokens (which we did when changing password) actually takes effect without HS restart. 
Reinstate the code to avoid logging out the session that changed the password, removed in 415c2f05491ce65a4fc34326519754cd1edd9c54 --- synapse/handlers/auth.py | 13 +++++++++---- synapse/push/pusherpool.py | 8 ++++---- synapse/rest/client/v2_alpha/account.py | 2 +- synapse/storage/registration.py | 28 ++++++++++++++++++++-------- 4 files changed, 34 insertions(+), 17 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 7a4afe446d..a740cc3da3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -432,13 +432,18 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def set_password(self, user_id, newpassword): + def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) + except_access_token_ids = [requester.access_token_id] if requester else [] + yield self.store.user_set_password_hash(user_id, password_hash) - yield self.store.user_delete_access_tokens(user_id) - yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) - yield self.store.flush_user(user_id) + yield self.store.user_delete_access_tokens_except( + user_id, except_access_token_ids + ) + yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens( + user_id, except_access_token_ids + ) @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 772a095f8b..28ec94d866 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -92,14 +92,14 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user(self, user_id): + def remove_pushers_by_user_except_access_tokens(self, user_id, except_token_ids): all = yield self.store.get_all_pushers() logger.info( - "Removing all pushers for user %s", - user_id, + "Removing all pushers for user %s except access tokens ids 
%r", + user_id, except_token_ids ) for p in all: - if p['user_name'] == user_id: + if p['user_name'] == user_id and p['access_token'] not in except_token_ids: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", p['app_id'], p['pushkey'], p['user_name'] diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 688b051580..dd4ea45588 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet): new_password = params['new_password'] yield self.auth_handler.set_password( - user_id, new_password + user_id, new_password, requester ) defer.returnValue((200, {})) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index aa49f53458..5eef7ebcc7 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -208,14 +208,26 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks - def flush_user(self, user_id): - rows = yield self._execute( - 'flush_user', None, - "SELECT token FROM access_tokens WHERE user_id = ?", - user_id - ) - for r in rows: - self.get_user_by_access_token.invalidate((r,)) + def user_delete_access_tokens_except(self, user_id, except_token_ids): + def f(txn): + txn.execute( + "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50", + (user_id,) + ) + rows = txn.fetchall() + for r in rows: + if r[0] in except_token_ids: + continue + + txn.call_after(self.get_user_by_access_token.invalidate, (r[1],)) + txn.execute( + "DELETE FROM access_tokens WHERE id in (%s)" % ",".join( + ["?" 
for _ in rows] + ), [r[0] for r in rows] + ) + return len(rows) == 50 + while (yield self.runInteraction("user_delete_access_tokens_except", f)): + pass @cached() def get_user_by_access_token(self, token): -- cgit 1.4.1 From 57c444b3ad9e69ae99fe694c0ef9a1961ec9366a Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 11 Mar 2016 14:25:05 +0000 Subject: Dear PyCharm, please indent sensibly for me. Thx. --- synapse/handlers/auth.py | 4 ++-- synapse/storage/registration.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a740cc3da3..0f02493fa2 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -439,10 +439,10 @@ class AuthHandler(BaseHandler): yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_delete_access_tokens_except( - user_id, except_access_token_ids + user_id, except_access_token_ids ) yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens( - user_id, except_access_token_ids + user_id, except_access_token_ids ) @defer.inlineCallbacks diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5eef7ebcc7..d3d84c9b94 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -212,7 +212,7 @@ class RegistrationStore(SQLBaseStore): def f(txn): txn.execute( "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50", - (user_id,) + (user_id,) ) rows = txn.fetchall() for r in rows: -- cgit 1.4.1 From f523177850df7fbe480086b281f09815f3d5c656 Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 11 Mar 2016 14:29:01 +0000 Subject: Delete old, unused methods and rename new one to just be `user_delete_access_tokens` with an `except_token_ids` argument doing what it says on the tin. 
--- synapse/handlers/auth.py | 2 +- synapse/storage/registration.py | 17 ++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0f02493fa2..43f2cdc2c4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -438,7 +438,7 @@ class AuthHandler(BaseHandler): except_access_token_ids = [requester.access_token_id] if requester else [] yield self.store.user_set_password_hash(user_id, password_hash) - yield self.store.user_delete_access_tokens_except( + yield self.store.user_delete_access_tokens( user_id, except_access_token_ids ) yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d3d84c9b94..266e29f5bc 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -195,20 +195,7 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id): - yield self.runInteraction( - "user_delete_access_tokens", - self._user_delete_access_tokens, user_id - ) - - def _user_delete_access_tokens(self, txn, user_id): - txn.execute( - "DELETE FROM access_tokens WHERE user_id = ?", - (user_id, ) - ) - - @defer.inlineCallbacks - def user_delete_access_tokens_except(self, user_id, except_token_ids): + def user_delete_access_tokens(self, user_id, except_token_ids): def f(txn): txn.execute( "SELECT id, token FROM access_tokens WHERE user_id = ? 
LIMIT 50", @@ -226,7 +213,7 @@ class RegistrationStore(SQLBaseStore): ), [r[0] for r in rows] ) return len(rows) == 50 - while (yield self.runInteraction("user_delete_access_tokens_except", f)): + while (yield self.runInteraction("user_delete_access_tokens", f)): pass @cached() -- cgit 1.4.1 From af59826a2fda31a6382f60659796a1c8db6f21ce Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 11 Mar 2016 14:34:09 +0000 Subject: Make select more sensible when dseleting access tokens, rename pusher deletion to match access token deletion and make exception arg optional. --- synapse/handlers/auth.py | 2 +- synapse/push/pusherpool.py | 2 +- synapse/storage/registration.py | 8 +++----- 3 files changed, 5 insertions(+), 7 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 43f2cdc2c4..5c0ea636bc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -441,7 +441,7 @@ class AuthHandler(BaseHandler): yield self.store.user_delete_access_tokens( user_id, except_access_token_ids ) - yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens( + yield self.hs.get_pusherpool().remove_pushers_by_user( user_id, except_access_token_ids ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 28ec94d866..0b463c6fdb 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -92,7 +92,7 @@ class PusherPool: yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) @defer.inlineCallbacks - def remove_pushers_by_user_except_access_tokens(self, user_id, except_token_ids): + def remove_pushers_by_user(self, user_id, except_token_ids=[]): all = yield self.store.get_all_pushers() logger.info( "Removing all pushers for user %s except access tokens ids %r", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 266e29f5bc..3a050621d9 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -198,14 
+198,12 @@ class RegistrationStore(SQLBaseStore): def user_delete_access_tokens(self, user_id, except_token_ids): def f(txn): txn.execute( - "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50", - (user_id,) + "SELECT id, token FROM access_tokens " + "WHERE user_id = ? AND id not in LIMIT 50", + (user_id,except_token_ids) ) rows = txn.fetchall() for r in rows: - if r[0] in except_token_ids: - continue - txn.call_after(self.get_user_by_access_token.invalidate, (r[1],)) txn.execute( "DELETE FROM access_tokens WHERE id in (%s)" % ",".join( -- cgit 1.4.1 From 2dee03aee513eddaf50a4249747beac67445b3cd Mon Sep 17 00:00:00 2001 From: David Baker Date: Fri, 11 Mar 2016 14:38:23 +0000 Subject: more pep8 --- synapse/storage/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 3a050621d9..5d45f0c651 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -200,7 +200,7 @@ class RegistrationStore(SQLBaseStore): txn.execute( "SELECT id, token FROM access_tokens " "WHERE user_id = ? AND id not in LIMIT 50", - (user_id,except_token_ids) + (user_id, except_token_ids) ) rows = txn.fetchall() for r in rows: -- cgit 1.4.1 From c08122843982307dad8d099784bc1a4bc0202eae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Mar 2016 15:09:17 +0000 Subject: Fix SQL statement --- synapse/storage/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5d45f0c651..5e7a4e371d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -199,7 +199,7 @@ class RegistrationStore(SQLBaseStore): def f(txn): txn.execute( "SELECT id, token FROM access_tokens " - "WHERE user_id = ? AND id not in LIMIT 50", + "WHERE user_id = ? AND id NOT IN ? 
LIMIT 50", (user_id, except_token_ids) ) rows = txn.fetchall() -- cgit 1.4.1 From b13035cc91410634421820e5175d0596f5a67549 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Mar 2016 16:27:50 +0000 Subject: Implement logout --- synapse/rest/__init__.py | 2 ++ synapse/rest/client/v1/logout.py | 72 ++++++++++++++++++++++++++++++++++++++++ synapse/storage/registration.py | 49 +++++++++++++++++++-------- 3 files changed, 109 insertions(+), 14 deletions(-) create mode 100644 synapse/rest/client/v1/logout.py (limited to 'synapse/storage') diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 433237c204..6688fa8fa0 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client.v1 import ( push_rule, register as v1_register, login as v1_login, + logout, ) from synapse.rest.client.v2_alpha import ( @@ -72,6 +73,7 @@ class ClientRestResource(JsonResource): admin.register_servlets(hs, client_resource) pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) + logout.register_servlets(hs, client_resource) # "v2" sync.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py new file mode 100644 index 0000000000..9bff02ee4e --- /dev/null +++ b/synapse/rest/client/v1/logout.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from twisted.internet import defer + +from synapse.api.errors import AuthError, Codes + +from .base import ClientV1RestServlet, client_path_patterns + +import logging + + +logger = logging.getLogger(__name__) + + +class LogoutRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/logout$") + + def __init__(self, hs): + super(LogoutRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + try: + access_token = request.args["access_token"][0] + except KeyError: + raise AuthError( + self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", + errcode=Codes.MISSING_TOKEN + ) + yield self.store.delete_access_token(access_token) + defer.returnValue((200, {})) + + +class LogoutAllRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/logout/all$") + + def __init__(self, hs): + super(LogoutAllRestServlet, self).__init__(hs) + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + def on_OPTIONS(self, request): + return (200, {}) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + yield self.store.user_delete_access_tokens(user_id) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + LogoutRestServlet(hs).register(http_server) + LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 5e7a4e371d..18898c44eb 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -195,24 +195,45 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids): + def user_delete_access_tokens(self, user_id, except_token_ids=[]): def f(txn): txn.execute( - "SELECT id, token FROM access_tokens " - "WHERE user_id = ? AND id NOT IN ? 
LIMIT 50", - (user_id, except_token_ids) + "SELECT token FROM access_tokens" + " WHERE user_id = ? AND id NOT IN (%s)" % ( + ",".join(["?" for _ in except_token_ids]), + ), + [user_id] + except_token_ids ) - rows = txn.fetchall() - for r in rows: - txn.call_after(self.get_user_by_access_token.invalidate, (r[1],)) - txn.execute( - "DELETE FROM access_tokens WHERE id in (%s)" % ",".join( - ["?" for _ in rows] - ), [r[0] for r in rows] + + while True: + rows = txn.fetchmany(100) + if not rows: + break + + for row in rows: + txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + + txn.execute( + "DELETE FROM access_tokens WHERE token in (%s)" % ( + ",".join(["?" for _ in rows]), + ), [r[0] for r in rows] + ) + + yield self.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self._simple_delete_one_txn( + txn, + table="access_tokens", + keyvalues={ + "token": access_token + }, ) - return len(rows) == 50 - while (yield self.runInteraction("user_delete_access_tokens", f)): - pass + + txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + + return self.runInteraction("delete_access_token", f) @cached() def get_user_by_access_token(self, token): -- cgit 1.4.1 From 15122da0e275ba18ec4633129715067a637f38af Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Mar 2016 16:45:27 +0000 Subject: Thats not how transactions work. 
--- synapse/storage/registration.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 18898c44eb..bd4eb88a92 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -197,26 +197,29 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_ids=[]): def f(txn): - txn.execute( - "SELECT token FROM access_tokens" - " WHERE user_id = ? AND id NOT IN (%s)" % ( + sql = "SELECT token FROM access_tokens WHERE user_id = ?" + clauses = [user_id] + + if except_token_ids: + sql += " AND id NOT IN (%s)" % ( ",".join(["?" for _ in except_token_ids]), - ), - [user_id] + except_token_ids - ) + ) + clauses += except_token_ids + + txn.execute(sql, clauses) - while True: - rows = txn.fetchmany(100) - if not rows: - break + rows = txn.fetchall() - for row in rows: + n = 100 + chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] + for chunk in chunks: + for row in chunk: txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) txn.execute( "DELETE FROM access_tokens WHERE token in (%s)" % ( - ",".join(["?" for _ in rows]), - ), [r[0] for r in rows] + ",".join(["?" 
for _ in chunk]), + ), [r[0] for r in chunk] ) yield self.runInteraction("user_delete_access_tokens", f) -- cgit 1.4.1 From b6e8420aeed9921ba7d0fd4c8ebaf1b64d5f677c Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 15 Mar 2016 17:01:43 +0000 Subject: Add replication stream for pushers --- synapse/replication/resource.py | 25 ++++++++- synapse/storage/__init__.py | 5 +- synapse/storage/pusher.py | 63 ++++++++++++++++------ .../storage/schema/delta/30/deleted_pushers.sql | 24 +++++++++ synapse/storage/util/id_generators.py | 7 ++- tests/replication/test_resource.py | 1 + 6 files changed, 107 insertions(+), 18 deletions(-) create mode 100644 synapse/storage/schema/delta/30/deleted_pushers.sql (limited to 'synapse/storage') diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index adc1eb1d0b..8c1ae0fbc7 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -37,6 +37,7 @@ STREAM_NAMES = ( ("user_account_data", "room_account_data", "tag_account_data",), ("backfill",), ("push_rules",), + ("pushers",), ) @@ -65,6 +66,7 @@ class ReplicationResource(Resource): * "tag_account_data": Per room per user tags. * "backfill": Old events that have been backfilled from other servers. * "push_rules": Per user changes to push rules. + * "pushers": Per user changes to their pushers. 
The API takes two additional query parameters: @@ -120,6 +122,7 @@ class ReplicationResource(Resource): stream_token = yield self.sources.get_current_token() backfill_token = yield self.store.get_current_backfill_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() + pushers_token = self.store.get_pushers_stream_token() defer.returnValue(_ReplicationToken( room_stream_token, @@ -129,6 +132,7 @@ class ReplicationResource(Resource): int(stream_token.account_data_key), backfill_token, push_rules_token, + pushers_token, )) @request_handler @@ -151,6 +155,7 @@ class ReplicationResource(Resource): yield self.typing(writer, current_token) # TODO: implement limit yield self.receipts(writer, current_token, limit) yield self.push_rules(writer, current_token, limit) + yield self.pushers(writer, current_token, limit) self.streams(writer, current_token) logger.info("Replicated %d rows", writer.total) @@ -297,6 +302,24 @@ class ReplicationResource(Resource): "priority_class", "priority", "conditions", "actions" )) + @defer.inlineCallbacks + def pushers(self, writer, current_token, limit): + current_position = current_token.pushers + + pushers = parse_integer(writer.request, "pushers") + if pushers is not None: + updated, deleted = yield self.store.get_all_updated_pushers( + pushers, current_position, limit + ) + writer.write_header_and_rows("pushers", updated, ( + "position", "user_id", "access_token", "profile_tag", "kind", + "app_id", "app_display_name", "device_display_name", "pushkey", + "ts", "lang", "data" + )) + writer.write_header_and_rows("deleted", deleted, ( + "position", "user_id", "app_id", "pushkey" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -327,7 +350,7 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", - "push_rules" + "push_rules", "pushers" ))): 
__slots__ = [] diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 168eb27b03..250ba536ea 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -119,12 +119,15 @@ class DataStore(RoomMemberStore, RoomStore, self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") - self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_stream_id_gen = ChainedIdGenerator( self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" ) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 7693ab9082..29da3bbd13 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -16,8 +16,6 @@ from ._base import SQLBaseStore from twisted.internet import defer -from synapse.api.errors import StoreError - from canonicaljson import encode_canonical_json import logging @@ -79,12 +77,41 @@ class PusherStore(SQLBaseStore): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_max_token() + + def get_all_updated_pushers(self, last_id, current_id, limit): + def get_all_updated_pushers_txn(txn): + sql = ( + "SELECT id, user_name, access_token, profile_tag, kind," + " app_id, app_display_name, device_display_name, pushkey, ts," + " lang, data" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" 
+ ) + txn.execute(sql, (last_id, current_id, limit)) + updated = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + deleted = txn.fetchall() + + return (updated, deleted) + return self.runInteraction( + "get_all_updated_pushers", get_all_updated_pushers_txn + ) + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, profile_tag=""): - try: - next_id = self._pushers_id_gen.get_next() + with self._pushers_id_gen.get_next() as stream_id: yield self._simple_upsert( "pushers", dict( @@ -101,23 +128,29 @@ class PusherStore(SQLBaseStore): lang=lang, data=encode_canonical_json(data), profile_tag=profile_tag, - ), - insertion_values=dict( - id=next_id, + id=stream_id, ), desc="add_pusher", ) - except Exception as e: - logger.error("create_pusher with failed: %s", e) - raise StoreError(500, "Problem creating pusher.") @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): - yield self._simple_delete_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id}, - desc="delete_pusher_by_app_id_pushkey_user_id", - ) + def delete_pusher_txn(txn, stream_id): + self._simple_delete_one( + txn, + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} + ) + self._simple_upsert_txn( + txn, + "deleted_pushers", + {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, + {"stream_id", stream_id}, + ) + with self._pushers_id_gen.get_next() as stream_id: + yield self.runInteraction( + "delete_pusher", delete_pusher_txn, stream_id + ) @defer.inlineCallbacks def update_pusher_last_token(self, app_id, pushkey, user_id, last_token): diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql 
b/synapse/storage/schema/delta/30/deleted_pushers.sql new file mode 100644 index 0000000000..cdcf79ac81 --- /dev/null +++ b/synapse/storage/schema/delta/30/deleted_pushers.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS deleted_pushers( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (app_id, pushkey, user_id) +); + +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 610ddad423..a02dfc7d58 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -49,9 +49,14 @@ class StreamIdGenerator(object): with stream_id_gen.get_next() as stream_id: # ... persist event ... 
""" - def __init__(self, db_conn, table, column): + def __init__(self, db_conn, table, column, extra_tables=[]): self._lock = threading.Lock() self._current_max = _load_max_id(db_conn, table, column) + for table, column in extra_tables: + self._current_max = max( + self._current_max, + _load_max_id(db_conn, table, column) + ) self._unfinished_ids = deque() def get_next(self): diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 4a42eb3365..f4b5fb3328 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -131,6 +131,7 @@ class ReplicationResourceCase(unittest.TestCase): test_timeout_tag_account_data = _test_timeout("tag_account_data") test_timeout_backfill = _test_timeout("backfill") test_timeout_push_rules = _test_timeout("push_rules") + test_timeout_pushers = _test_timeout("pushers") @defer.inlineCallbacks def send_text_message(self, room_id, message): -- cgit 1.4.1 From ee32d622cec56f2ab7b11577d15e4b805477d13f Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 15 Mar 2016 17:47:36 +0000 Subject: Fix a couple of errors when deleting pushers --- synapse/storage/pusher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 29da3bbd13..87b2ac5773 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -136,7 +136,7 @@ class PusherStore(SQLBaseStore): @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_txn(txn, stream_id): - self._simple_delete_one( + self._simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} @@ -145,7 +145,7 @@ class PusherStore(SQLBaseStore): txn, "deleted_pushers", {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, - {"stream_id", stream_id}, + {"stream_id": stream_id}, ) with self._pushers_id_gen.get_next() as stream_id: 
yield self.runInteraction( -- cgit 1.4.1 From ba660ecde20544ac1cfc163a5586f4f202627afa Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Wed, 16 Mar 2016 10:35:00 +0000 Subject: Add a comment to offer a hint to an explanation for why we have a unique constraint on (app_id, pushkey, user_id) --- synapse/storage/schema/delta/30/deleted_pushers.sql | 1 + 1 file changed, 1 insertion(+) (limited to 'synapse/storage') diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/schema/delta/30/deleted_pushers.sql index cdcf79ac81..712c454aa1 100644 --- a/synapse/storage/schema/delta/30/deleted_pushers.sql +++ b/synapse/storage/schema/delta/30/deleted_pushers.sql @@ -18,6 +18,7 @@ CREATE TABLE IF NOT EXISTS deleted_pushers( app_id TEXT NOT NULL, pushkey TEXT NOT NULL, user_id TEXT NOT NULL, + /* We only track the most recent delete for each app_id, pushkey and user_id. */ UNIQUE (app_id, pushkey, user_id) ); -- cgit 1.4.1 From 673c96ce97052126f5bfd11c7dcc19880614ec25 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Thu, 17 Mar 2016 11:01:28 +0000 Subject: Remove dead code left over from presence changes --- synapse/handlers/events.py | 70 ---------------------------------------- synapse/handlers/presence.py | 4 --- synapse/storage/roommember.py | 24 -------------- tests/storage/test_roommember.py | 10 ------ 4 files changed, 108 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 72a31a9755..f25a252523 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.types import UserID from synapse.events.utils import serialize_event -from synapse.util.logcontext import preserve_context_over_fn from synapse.api.constants import Membership, EventTypes from synapse.events import EventBase @@ -31,20 +30,6 @@ import random logger = logging.getLogger(__name__) -def 
started_user_eventstream(distributor, user): - return preserve_context_over_fn( - distributor.fire, - "started_user_eventstream", user - ) - - -def stopped_user_eventstream(distributor, user): - return preserve_context_over_fn( - distributor.fire, - "stopped_user_eventstream", user - ) - - class EventStreamHandler(BaseHandler): def __init__(self, hs): @@ -63,61 +48,6 @@ class EventStreamHandler(BaseHandler): self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def started_stream(self, user): - """Tells the presence handler that we have started an eventstream for - the user: - - Args: - user (User): The user who started a stream. - Returns: - A deferred that completes once their presence has been updated. - """ - if user not in self._streams_per_user: - # Make sure we set the streams per user to 1 here rather than - # setting it to zero and incrementing the value below. - # Otherwise this may race with stopped_stream causing the - # user to be erased from the map before we have a chance - # to increment it. - self._streams_per_user[user] = 1 - if user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(user) - ) - except: - logger.exception("Failed to cancel event timer") - else: - yield started_user_eventstream(self.distributor, user) - else: - self._streams_per_user[user] += 1 - - def stopped_stream(self, user): - """If there are no streams for a user this starts a timer that will - notify the presence handler that we haven't got an event stream for - the user unless the user starts a new stream in 30 seconds. - - Args: - user (User): The user who stopped a stream. 
- """ - self._streams_per_user[user] -= 1 - if not self._streams_per_user[user]: - del self._streams_per_user[user] - - # 30 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug("_later stopped_user_eventstream %s", user) - - self._stop_timer_per_user.pop(user, None) - - return stopped_user_eventstream(self.distributor, user) - - logger.debug("Scheduling _later: for %s", user) - self._stop_timer_per_user[user] = ( - self.clock.call_later(30, _later) - ) - @defer.inlineCallbacks @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index f6cf343174..cfbcf2d32c 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -73,10 +73,6 @@ FEDERATION_PING_INTERVAL = 25 * 60 * 1000 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -def user_presence_changed(distributor, user, statuscache): - return distributor.fire("user_presence_changed", user, statuscache) - - def collect_presencelike_data(distributor, user, content): return distributor.fire("collect_presencelike_data", user, content) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3065b0c1a5..0cd89260f2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -251,30 +251,6 @@ class RoomMemberStore(SQLBaseStore): user_id, membership_list=[Membership.JOIN], ) - @defer.inlineCallbacks - def user_rooms_intersect(self, user_id_list): - """ Checks whether all the users whose IDs are given in a list share a - room. - - This is a "hot path" function that's called a lot, e.g. by presence for - generating the event stream. As such, it is implemented locally by - wrapping logic around heavily-cached database queries. 
- """ - if len(user_id_list) < 2: - defer.returnValue(True) - - deferreds = [self.get_rooms_for_user(u) for u in user_id_list] - - results = yield defer.DeferredList(deferreds, consumeErrors=True) - - # A list of sets of strings giving room IDs for each user - room_id_lists = [set([r.room_id for r in result[1]]) for result in results] - - # There isn't a setintersection(*list_of_sets) - ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 - - defer.returnValue(ret) - @defer.inlineCallbacks def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 677d11f68d..b029ff0584 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -91,11 +91,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): ) )] ) - self.assertFalse( - (yield self.store.user_rooms_intersect( - [self.u_alice.to_string(), self.u_bob.to_string()] - )) - ) @defer.inlineCallbacks def test_two_members(self): @@ -108,11 +103,6 @@ class RoomMemberStoreTestCase(unittest.TestCase): yield self.store.get_room_members(self.room.to_string()) )} ) - self.assertTrue(( - yield self.store.user_rooms_intersect([ - self.u_alice.to_string(), self.u_bob.to_string() - ]) - )) @defer.inlineCallbacks def test_room_hosts(self): -- cgit 1.4.1 From 67ed8065dba960055c2e3d1740af12229b7d19a4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Mar 2016 14:31:31 +0000 Subject: Dedupe requested event list in _get_events --- synapse/storage/events.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 552e7ca35b..285c586cfe 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -526,6 +526,9 @@ class EventsStore(SQLBaseStore): if not event_ids: defer.returnValue([]) + event_id_list = event_ids + event_ids 
= set(event_ids) + event_map = self._get_events_from_cache( event_ids, check_redacted=check_redacted, @@ -535,23 +538,18 @@ class EventsStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_map] - if not missing_events_ids: - defer.returnValue([ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ]) - - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) - event_map.update(missing_events) + event_map.update(missing_events) defer.returnValue([ - event_map[e_id] for e_id in event_ids + event_map[e_id] for e_id in event_id_list if e_id in event_map and event_map[e_id] ]) -- cgit 1.4.1