diff options
Diffstat (limited to 'synapse/storage')
35 files changed, 1314 insertions, 499 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 5a9e7720d9..250ba536ea 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -20,7 +20,7 @@ from .appservice import ( from ._base import Cache from .directory import DirectoryStore from .events import EventsStore -from .presence import PresenceStore +from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore from .registration import RegistrationStore from .room import RoomStore @@ -45,8 +45,9 @@ from .search import SearchStore from .tags import TagsStore from .account_data import AccountDataStore -from util.id_generators import IdGenerator, StreamIdGenerator +from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator +from synapse.api.constants import PresenceState from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -110,16 +111,25 @@ class DataStore(RoomMemberStore, RoomStore, self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) + self._presence_id_gen = StreamIdGenerator( + db_conn, "presence_stream", "stream_id" + ) - self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) - self._state_groups_id_gen = IdGenerator("state_groups", "id", self) - self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) - self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) - self._pushers_id_gen = IdGenerator("pushers", "id", self) - self._push_rule_id_gen = IdGenerator("push_rules", "id", self) - self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) + self._pushers_id_gen = StreamIdGenerator( + db_conn, "pushers", "id", + extra_tables=[("deleted_pushers", "stream_id")], + ) - events_max = self._stream_id_gen.get_max_token(None) + events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( db_conn, "events", entity_column="room_id", @@ -135,13 +145,43 @@ class DataStore(RoomMemberStore, RoomStore, "MembershipStreamChangeCache", events_max, ) - account_max = self._account_data_id_gen.get_max_token(None) + account_max = self._account_data_id_gen.get_max_token() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max, ) + self.__presence_on_startup = self._get_active_presence(db_conn) + + presence_cache_prefill, min_presence_val = self._get_cache_dict( + db_conn, "presence_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._presence_id_gen.get_max_token(), + ) + self.presence_stream_cache = StreamChangeCache( + "PresenceStreamChangeCache", min_presence_val, + prefilled_cache=presence_cache_prefill + ) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self._push_rules_stream_id_gen.get_max_token()[0], + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, + ) + super(DataStore, self).__init__(hs) + def take_presence_startup_info(self): + active_on_startup = self.__presence_on_startup + self.__presence_on_startup = None + return active_on_startup + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will @@ -161,6 +201,7 @@ class DataStore(RoomMemberStore, RoomStore, txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) rows = txn.fetchall() + txn.close() cache = { row[0]: int(row[1]) @@ -174,6 +215,28 @@ class DataStore(RoomMemberStore, RoomStore, return cache, min_val + def _get_active_presence(self, db_conn): + """Fetch non-offline presence from the database so that we can register + the appropriate time outs. + """ + + sql = ( + "SELECT user_id, state, last_active_ts, last_federation_update_ts," + " last_user_sync_ts, status_msg, currently_active FROM presence_stream" + " WHERE state != ?" + ) + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (PresenceState.OFFLINE,)) + rows = self.cursor_to_dict(txn) + txn.close() + + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + return [UserPresenceState(**row) for row in rows] + @defer.inlineCallbacks def insert_client_ip(self, user, access_token, ip, user_agent): now = int(self._clock.time_msec()) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2e97ac84a8..b75b79df36 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,6 +18,7 @@ from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache +from synapse.util.caches import intern_dict import synapse.metrics @@ -26,6 +27,10 @@ from twisted.internet import defer import sys import time import threading +import os + + +CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) logger = logging.getLogger(__name__) @@ -163,7 +168,9 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000) + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + ) self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] @@ -344,7 +351,7 @@ class SQLBaseStore(object): """ col_headers = list(column[0] for column in cursor.description) results = list( - dict(zip(col_headers, row)) for row in cursor.fetchall() + intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() ) return results @@ -770,18 +777,29 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ + return self.runInteraction( + desc, self._simple_delete_one_txn, table, keyvalues + ) + + @staticmethod + def _simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) - def func(txn): - txn.execute(sql, keyvalues.values()) - if txn.rowcount == 0: - raise StoreError(404, "No row found") - if txn.rowcount > 1: - raise StoreError(500, "more than one row matched") - return self.runInteraction(desc, func) + txn.execute(sql, keyvalues.values()) + if txn.rowcount == 0: + raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "more than one row matched") @staticmethod def _simple_delete_txn(txn, table, keyvalues): diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index b8387fc500..faddefe219 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): - """Get all the client account_data for a that's changed. + def get_all_updated_account_data(self, last_global_id, last_room_id, + current_id, limit): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return (global_results, room_results) + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. @@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_room_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_user_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 1100c67714..371600eebb 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore): def __init__(self, hs): super(ApplicationServiceStore, self).__init__(hs) self.hostname = hs.hostname - self.services_cache = [] - self._populate_appservice_cache( + self.services_cache = ApplicationServiceStore.load_appservices( + hs.hostname, hs.config.app_service_config_files ) @@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore): return rooms_for_user_matching_user_id - def _load_appservice(self, as_info): + @classmethod + def _load_appservice(cls, hostname, as_info, config_filename): required_string_fields = [ - # TODO: Add id here when it's stable to release - "url", "as_token", "hs_token", "sender_localpart" + "id", "url", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: if not isinstance(as_info.get(field), basestring): - raise KeyError("Required string field: '%s'", field) + raise KeyError("Required string field: '%s' (%s)" % ( + field, config_filename, + )) localpart = as_info["sender_localpart"] if urllib.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) - user = UserID(localpart, self.hostname) + user = UserID(localpart, hostname) user_id = user.to_string() # namespace checks @@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore): namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], sender=user_id, - id=as_info["id"] if "id" in as_info else as_info["as_token"], + id=as_info["id"], ) - def _populate_appservice_cache(self, config_files): - """Populates a cache of Application Services from the config files.""" + @classmethod + def load_appservices(cls, hostname, config_files): + """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): logger.warning( "Expected %s to be a list of AS config files.", config_files ) - return + return [] # Dicts of value -> filename seen_as_tokens = {} seen_ids = {} + appservices = [] + for config_file in config_files: try: with open(config_file, 'r') as f: - appservice = self._load_appservice(yaml.load(f)) + appservice = ApplicationServiceStore._load_appservice( + hostname, yaml.load(f), config_file + ) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application services: " @@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore): ) seen_as_tokens[appservice.token] = config_file logger.info("Loaded application service: %s", appservice) - self.services_cache.append(appservice) + appservices.append(appservice) except Exception as e: logger.error("Failed to load appservice from '%s'", config_file) logger.exception(e) raise + return appservices class ApplicationServiceTransactionStore(SQLBaseStore): diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 1556619d5e..ef231a04dc 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore): ) @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers): + def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers Args: room_alias (RoomAlias) room_id (str) servers (list) + creator (str): Optional user_id of creator. Returns: Deferred @@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore): { "room_alias": room_alias.to_string(), "room_id": room_id, + "creator": creator, }, desc="create_room_alias_association", ) @@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore): ) self.get_aliases_for_room.invalidate((room_id,)) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( @@ -142,7 +155,7 @@ class DirectoryStore(SQLBaseStore): return room_id - @cached() + @cached(max_entries=5000) def get_aliases_for_room(self, room_id): return self._simple_select_onecol( "room_aliases", diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 5dd32b1413..2e89066515 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 4290aea83a..a48230b93f 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,12 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(name): +def create_engine(config): + name = config.database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module) + return engine_class(module, config=config) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 17b7a9c077..a09685b4df 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) + self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,7 +45,7 @@ class PostgresEngine(object): ) def prepare_database(self, db_conn): - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 91fac33b8b..522b905949 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -23,8 +23,9 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module + self.config = config def check_database(self, txn): pass @@ -38,7 +39,7 @@ class Sqlite3Engine(object): def prepare_database(self, db_conn): prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): return False diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index ce2c794025..3489315e0d 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) - def get_latest_events_in_room(self, room_id): + def get_latest_event_ids_and_hashes_in_room(self, room_id): return self.runInteraction( - "get_latest_events_in_room", - self._get_latest_events_in_room, + "get_latest_event_ids_and_hashes_in_room", + self._get_latest_event_ids_and_hashes_in_room, room_id, ) @@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore): desc="get_latest_event_ids_in_room", ) - def _get_latest_events_in_room(self, txn, room_id): + def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id): sql = ( "SELECT e.event_id, e.depth FROM events as e " "INNER JOIN event_forward_extremities as f " diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index d77a817682..dc5830450a 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ :param event: the event set actions for - :param tuples: list of tuples of (user_id, profile_tag, actions) + :param tuples: list of tuples of (user_id, actions) """ values = [] - for uid, profile_tag, actions in tuples: + for uid, actions in tuples: values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'profile_tag': profile_tag, 'actions': json.dumps(actions), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, @@ -43,14 +42,14 @@ class EventPushActionsStore(SQLBaseStore): 'highlight': 1 if _action_has_highlight(actions) else 0, }) - for uid, _, __ in tuples: + for uid, __ in tuples: txn.call_after( self.get_unread_event_push_actions_by_room_for_user.invalidate_many, (event.room_id, uid) ) self._simple_insert_many_txn(txn, "event_push_actions", values) - @cachedInlineCallbacks(num_args=3, lru=True, tree=True) + @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000) def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 3a5c6ee4b1..5233430028 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer, reactor @@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore): yield stream_orderings stream_ordering_manager = stream_ordering_manager() else: - stream_ordering_manager = yield self._stream_id_gen.get_next_mult( - self, len(events_and_contexts) + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) ) with stream_ordering_manager as stream_orderings: @@ -101,37 +101,23 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function - def persist_event(self, event, context, backfilled=False, + def persist_event(self, event, context, is_new_state=True, current_state=None): - stream_ordering = None - if backfilled: - self.min_stream_token -= 1 - stream_ordering = self.min_stream_token - - if stream_ordering is None: - stream_ordering_manager = yield self._stream_id_gen.get_next(self) - else: - @contextmanager - def stream_ordering_manager(): - yield stream_ordering - stream_ordering_manager = stream_ordering_manager() - try: - with stream_ordering_manager as stream_ordering: + with self._stream_id_gen.get_next() as stream_ordering: event.internal_metadata.stream_ordering = stream_ordering yield self.runInteraction( "persist_event", self._persist_event_txn, event=event, context=context, - backfilled=backfilled, is_new_state=is_new_state, current_state=current_state, ) except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = yield self._stream_id_gen.get_max_token() defer.returnValue((stream_ordering, max_persisted_id)) @defer.inlineCallbacks @@ -165,13 +151,38 @@ class EventsStore(SQLBaseStore): defer.returnValue(events[0] if events else None) + @defer.inlineCallbacks + def get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + @log_function - def _persist_event_txn(self, txn, event, context, backfilled, + def _persist_event_txn(self, txn, event, context, is_new_state=True, current_state=None): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: - txn.call_after(self.get_current_state_for_key.invalidate_all) + txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) @@ -198,7 +209,7 @@ class EventsStore(SQLBaseStore): return self._persist_events_txn( txn, [(event, context)], - backfilled=backfilled, + backfilled=False, is_new_state=is_new_state, ) @@ -455,7 +466,7 @@ class EventsStore(SQLBaseStore): for event, _ in state_events_and_contexts: if not context.rejected: txn.call_after( - self.get_current_state_for_key.invalidate, + self._get_current_state_for_key.invalidate, (event.room_id, event.type, event.state_key,) ) @@ -526,6 +537,9 @@ class EventsStore(SQLBaseStore): if not event_ids: defer.returnValue([]) + event_id_list = event_ids + event_ids = set(event_ids) + event_map = self._get_events_from_cache( event_ids, check_redacted=check_redacted, @@ -535,23 +549,18 @@ class EventsStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_map] - if not missing_events_ids: - defer.returnValue([ - event_map[e_id] for e_id in event_ids - if e_id in event_map and event_map[e_id] - ]) - - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) - event_map.update(missing_events) + event_map.update(missing_events) defer.returnValue([ - event_map[e_id] for e_id in event_ids + event_map[e_id] for e_id in event_id_list if e_id in event_map and event_map[e_id] ]) @@ -1064,3 +1073,48 @@ class EventsStore(SQLBaseStore): yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) defer.returnValue(result) + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + + # TODO: Fix race with the persit_event txn by using one of the + # stream id managers + return -self.min_stream_token + + def get_all_new_events(self, last_backfill_id, last_forward_id, + current_backfill_id, current_forward_id, limit): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + if last_forward_id != current_forward_id: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + else: + new_forward_events = [] + + sql = ( + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" + " ORDER BY e.stream_ordering DESC" + " LIMIT ?" + ) + if last_backfill_id != current_backfill_id: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + else: + new_backfill_events = [] + + return (new_forward_events, new_backfill_events) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index fd05bfe54e..a495a8a7d9 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks from twisted.internet import defer diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 0894384780..9d3ba32478 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from _base import SQLBaseStore +from ._base import SQLBaseStore class MediaRepositoryStore(SQLBaseStore): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 850736c85e..3f29aad1e8 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 29 +SCHEMA_VERSION = 30 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -50,7 +50,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine): +def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. """ @@ -61,10 +61,10 @@ def prepare_database(db_conn, database_engine): if version_info: user_version, delta_files, upgraded = version_info _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine + cur, user_version, delta_files, upgraded, database_engine, config ) else: - _setup_new_database(cur, database_engine) + _setup_new_database(cur, database_engine, config) # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) @@ -75,7 +75,7 @@ def prepare_database(db_conn, database_engine): raise -def _setup_new_database(cur, database_engine): +def _setup_new_database(cur, database_engine, config): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,11 +148,12 @@ def _setup_new_database(cur, database_engine): applied_delta_files=[], upgraded=False, database_engine=database_engine, + config=config, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): + upgraded, database_engine, config): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -245,7 +246,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index ef525f34c5..4cec31e316 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -14,73 +14,148 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.api.constants import PresenceState +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from collections import namedtuple from twisted.internet import defer +class UserPresenceState(namedtuple("UserPresenceState", + ("user_id", "state", "last_active_ts", + "last_federation_update_ts", "last_user_sync_ts", + "status_msg", "currently_active"))): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. + """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ) + + class PresenceStore(SQLBaseStore): - def create_presence(self, user_localpart): - res = self._simple_insert( - table="presence", - values={"user_id": user_localpart}, - desc="create_presence", + @defer.inlineCallbacks + def update_presence(self, presence_states): + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) ) - self.get_presence_state.invalidate((user_localpart,)) - return res + with stream_ordering_manager as stream_orderings: + yield self.runInteraction( + "update_presence", + self._update_presence_txn, stream_orderings, presence_states, + ) - def has_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["user_id"], - allow_none=True, - desc="has_presence_state", + defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token())) + + def _update_presence_txn(self, txn, stream_orderings, presence_states): + for stream_id, state in zip(stream_orderings, presence_states): + txn.call_after( + self.presence_stream_cache.entity_has_changed, + state.user_id, stream_id, + ) + + # Actually insert new rows + self._simple_insert_many_txn( + txn, + table="presence_stream", + values=[ + { + "stream_id": stream_id, + "user_id": state.user_id, + "state": state.state, + "last_active_ts": state.last_active_ts, + "last_federation_update_ts": state.last_federation_update_ts, + "last_user_sync_ts": state.last_user_sync_ts, + "status_msg": state.status_msg, + "currently_active": state.currently_active, + } + for state in presence_states + ], ) - @cached(max_entries=2000) - def get_presence_state(self, user_localpart): - return self._simple_select_one( - table="presence", - keyvalues={"user_id": user_localpart}, - retcols=["state", "status_msg", "mtime"], - desc="get_presence_state", + # Delete old rows to stop database from getting really big + sql = ( + "DELETE FROM presence_stream WHERE" + " stream_id < ?" + " AND user_id IN (%s)" ) - @cachedList(get_presence_state.cache, list_name="user_localparts", - inlineCallbacks=True) - def get_presence_states(self, user_localparts): - rows = yield self._simple_select_many_batch( - table="presence", - column="user_id", - iterable=user_localparts, - retcols=("user_id", "state", "status_msg", "mtime",), - desc="get_presence_states", + batches = ( + presence_states[i:i + 50] + for i in xrange(0, len(presence_states), 50) ) + for states in batches: + args = [stream_id] + args.extend(s.user_id for s in states) + txn.execute( + sql % (",".join("?" for _ in states),), + args + ) + + def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() - defer.returnValue({ - row["user_id"]: { - "state": row["state"], - "status_msg": row["status_msg"], - "mtime": row["mtime"], - } - for row in rows - }) + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) @defer.inlineCallbacks - def set_presence_state(self, user_localpart, new_state): - res = yield self._simple_update_one( - table="presence", - keyvalues={"user_id": user_localpart}, - updatevalues={"state": new_state["state"], - "status_msg": new_state["status_msg"], - "mtime": self._clock.time_msec()}, - desc="set_presence_state", + def get_presence_for_users(self, user_ids): + rows = yield self._simple_select_many_batch( + table="presence_stream", + column="user_id", + iterable=user_ids, + keyvalues={}, + retcols=( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), ) - self.get_presence_state.invalidate((user_localpart,)) - defer.returnValue(res) + for row in rows: + row["currently_active"] = bool(row["currently_active"]) + + defer.returnValue([UserPresenceState(**row) for row in rows]) + + def get_current_presence_token(self): + return self._presence_id_gen.get_max_token() def allow_presence_visible(self, observed_localpart, observer_userid): return self._simple_insert( @@ -128,6 +203,7 @@ class PresenceStore(SQLBaseStore): desc="set_presence_list_accepted", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) defer.returnValue(result) def get_presence_list(self, observer_localpart, accepted=None): @@ -154,6 +230,19 @@ class PresenceStore(SQLBaseStore): desc="get_presence_list_accepted", ) + @cachedInlineCallbacks() + def get_presence_list_observers_accepted(self, observed_userid): + user_localparts = yield self._simple_select_onecol( + table="presence_list", + keyvalues={"observed_user_id": observed_userid, "accepted": True}, + retcol="user_id", + desc="get_presence_list_accepted", + ) + + defer.returnValue([ + "@%s:%s" % (u, self.hs.hostname,) for u in user_localparts + ]) + @defer.inlineCallbacks def del_presence_list(self, observer_localpart, observed_userid): yield self._simple_delete_one( @@ -163,3 +252,4 @@ class PresenceStore(SQLBaseStore): desc="del_presence_list", ) self.get_presence_list_accepted.invalidate((observer_localpart,)) + self.get_presence_list_observers_accepted.invalidate((observed_userid,)) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index f9a48171ba..9dbad2fd5f 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -100,37 +100,37 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) @defer.inlineCallbacks - def add_push_rule(self, before, after, **kwargs): - vals = kwargs - if 'conditions' in vals: - vals['conditions'] = json.dumps(vals['conditions']) - if 'actions' in vals: - vals['actions'] = json.dumps(vals['actions']) - - # we could check the rest of the keys are valid column names - # but sqlite will do that anyway so I think it's just pointless. - vals.pop("id", None) - - if before or after: - ret = yield self.runInteraction( - "_add_push_rule_relative_txn", - self._add_push_rule_relative_txn, - before=before, - after=after, - **vals - ) - defer.returnValue(ret) - else: - ret = yield self.runInteraction( - "_add_push_rule_highest_priority_txn", - self._add_push_rule_highest_priority_txn, - **vals - ) - defer.returnValue(ret) + def add_push_rule( + self, user_id, rule_id, priority_class, conditions, actions, + before=None, after=None + ): + conditions_json = json.dumps(conditions) + actions_json = json.dumps(actions) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + if before or after: + yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + stream_id, event_stream_ordering, user_id, rule_id, priority_class, + conditions_json, actions_json, before, after, + ) + else: + yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + stream_id, event_stream_ordering, user_id, rule_id, priority_class, + conditions_json, actions_json, + ) + + def _add_push_rule_relative_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, + conditions_json, actions_json, before, after + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") - def _add_push_rule_relative_txn(self, txn, user_id, **kwargs): - after = kwargs.pop("after", None) - before = kwargs.pop("before", None) relative_to_rule = before or after res = self._simple_select_one_txn( @@ -149,69 +149,45 @@ class PushRuleStore(SQLBaseStore): "before/after rule not found: %s" % (relative_to_rule,) ) - priority_class = res["priority_class"] + base_priority_class = res["priority_class"] base_rule_priority = res["priority"] - if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + if base_priority_class != priority_class: raise InconsistentRuleException( "Given priority class does not match class of relative rule" ) - new_rule = kwargs - new_rule.pop("before", None) - new_rule.pop("after", None) - new_rule['priority_class'] = priority_class - new_rule['user_name'] = user_id - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - - # check if the priority before/after is free - new_rule_priority = base_rule_priority - if after: - new_rule_priority -= 1 + if before: + # Higher priority rules are executed first, So adding a rule before + # a rule means giving it a higher priority than that rule. + new_rule_priority = base_rule_priority + 1 else: - new_rule_priority += 1 - - new_rule['priority'] = new_rule_priority + # We increment the priority of the existing rules to make space for + # the new rule. Therefore if we want this rule to appear after + # an existing rule we give it the priority of the existing rule, + # and then increment the priority of the existing rule. + new_rule_priority = base_rule_priority sql = ( - "SELECT COUNT(*) FROM push_rules" - " WHERE user_name = ? AND priority_class = ? AND priority = ?" + "UPDATE push_rules SET priority = priority + 1" + " WHERE user_name = ? AND priority_class = ? AND priority >= ?" ) - txn.execute(sql, (user_id, priority_class, new_rule_priority)) - res = txn.fetchall() - num_conflicting = res[0][0] - - # if there are conflicting rules, bump everything - if num_conflicting: - sql = "UPDATE push_rules SET priority = priority " - if after: - sql += "-1" - else: - sql += "+1" - sql += " WHERE user_name = ? AND priority_class = ? AND priority " - if after: - sql += "<= ?" - else: - sql += ">= ?" - txn.execute(sql, (user_id, priority_class, new_rule_priority)) + txn.execute(sql, (user_id, priority_class, new_rule_priority)) - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) + self._upsert_push_rule_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, + new_rule_priority, conditions_json, actions_json, ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) - ) + def _add_push_rule_highest_priority_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, + conditions_json, actions_json + ): + # Lock the table since otherwise we'll have annoying races between the + # SELECT here and the UPSERT below. + self.database_engine.lock_table(txn, "push_rules") - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, - ) - - def _add_push_rule_highest_priority_txn(self, txn, user_id, - priority_class, **kwargs): # find the highest priority rule in that class sql = ( "SELECT COUNT(*), MAX(priority) FROM push_rules" @@ -225,26 +201,61 @@ class PushRuleStore(SQLBaseStore): if how_many > 0: new_prio = highest_prio + 1 - # and insert the new rule - new_rule = kwargs - new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn) - new_rule['user_name'] = user_id - new_rule['priority_class'] = priority_class - new_rule['priority'] = new_prio - - txn.call_after( - self.get_push_rules_for_user.invalidate, (user_id,) - ) - txn.call_after( - self.get_push_rules_enabled_for_user.invalidate, (user_id,) + self._upsert_push_rule_txn( + txn, + stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio, + conditions_json, actions_json, ) - self._simple_insert_txn( - txn, - table="push_rules", - values=new_rule, + def _upsert_push_rule_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class, + priority, conditions_json, actions_json, update_stream=True + ): + """Specialised version of _simple_upsert_txn that picks a push_rule_id + using the _push_rule_id_gen if it needs to insert the rule. It assumes + that the "push_rules" table is locked""" + + sql = ( + "UPDATE push_rules" + " SET priority_class = ?, priority = ?, conditions = ?, actions = ?" + " WHERE user_name = ? AND rule_id = ?" ) + txn.execute(sql, ( + priority_class, priority, conditions_json, actions_json, + user_id, rule_id, + )) + + if txn.rowcount == 0: + # We didn't update a row with the given rule_id so insert one + push_rule_id = self._push_rule_id_gen.get_next() + + self._simple_insert_txn( + txn, + table="push_rules", + values={ + "id": push_rule_id, + "user_name": user_id, + "rule_id": rule_id, + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + }, + ) + + if update_stream: + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, + op="ADD", + data={ + "priority_class": priority_class, + "priority": priority, + "conditions": conditions_json, + "actions": actions_json, + } + ) + @defer.inlineCallbacks def delete_push_rule(self, user_id, rule_id): """ @@ -256,26 +267,38 @@ class PushRuleStore(SQLBaseStore): user_id (str): The matrix ID of the push rule owner rule_id (str): The rule_id of the rule to be deleted """ - yield self._simple_delete_one( - "push_rules", - {'user_name': user_id, 'rule_id': rule_id}, - desc="delete_push_rule", - ) + def delete_push_rule_txn(txn, stream_id, event_stream_ordering): + self._simple_delete_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, + op="DELETE" + ) - self.get_push_rules_for_user.invalidate((user_id,)) - self.get_push_rules_enabled_for_user.invalidate((user_id,)) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering + ) @defer.inlineCallbacks def set_push_rule_enabled(self, user_id, rule_id, enabled): - ret = yield self.runInteraction( - "_set_push_rule_enabled_txn", - self._set_push_rule_enabled_txn, - user_id, rule_id, enabled - ) - defer.returnValue(ret) + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "_set_push_rule_enabled_txn", + self._set_push_rule_enabled_txn, + stream_id, event_stream_ordering, user_id, rule_id, enabled + ) - def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): - new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + def _set_push_rule_enabled_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled + ): + new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( txn, "push_rules_enable", @@ -283,12 +306,109 @@ class PushRuleStore(SQLBaseStore): {'enabled': 1 if enabled else 0}, {'id': new_id}, ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, + op="ENABLE" if enabled else "DISABLE" + ) + + @defer.inlineCallbacks + def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + actions_json = json.dumps(actions) + + def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): + if is_default_rule: + # Add a dummy rule to the rules table with the user specified + # actions. + priority_class = -1 + priority = 1 + self._upsert_push_rule_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, + priority_class, priority, "[]", actions_json, + update_stream=False + ) + else: + self._simple_update_one_txn( + txn, + "push_rules", + {'user_name': user_id, 'rule_id': rule_id}, + {'actions': actions_json}, + ) + + self._insert_push_rules_update_txn( + txn, stream_id, event_stream_ordering, user_id, rule_id, + op="ACTIONS", data={"actions": actions_json} + ) + + with self._push_rules_stream_id_gen.get_next() as ids: + stream_id, event_stream_ordering = ids + yield self.runInteraction( + "set_push_rule_actions", set_push_rule_actions_txn, + stream_id, event_stream_ordering + ) + + def _insert_push_rules_update_txn( + self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None + ): + values = { + "stream_id": stream_id, + "event_stream_ordering": event_stream_ordering, + "user_id": user_id, + "rule_id": rule_id, + "op": op, + } + if data is not None: + values.update(data) + + self._simple_insert_txn(txn, "push_rules_stream", values=values) + txn.call_after( self.get_push_rules_for_user.invalidate, (user_id,) ) txn.call_after( self.get_push_rules_enabled_for_user.invalidate, (user_id,) ) + txn.call_after( + self.push_rules_stream_cache.entity_has_changed, user_id, stream_id + ) + + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, event_stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + return self.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. + Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_max_token() + + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) class RuleNotFoundException(Exception): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8ec706178a..87b2ac5773 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -16,8 +16,6 @@ from ._base import SQLBaseStore from twisted.internet import defer -from synapse.api.errors import StoreError - from canonicaljson import encode_canonical_json import logging @@ -79,12 +77,41 @@ class PusherStore(SQLBaseStore): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_max_token() + + def get_all_updated_pushers(self, last_id, current_id, limit): + def get_all_updated_pushers_txn(txn): + sql = ( + "SELECT id, user_name, access_token, profile_tag, kind," + " app_id, app_display_name, device_display_name, pushkey, ts," + " lang, data" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + updated = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + deleted = txn.fetchall() + + return (updated, deleted) + return self.runInteraction( + "get_all_updated_pushers", get_all_updated_pushers_txn + ) + @defer.inlineCallbacks - def add_pusher(self, user_id, access_token, profile_tag, kind, app_id, + def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, - pushkey, pushkey_ts, lang, data): - try: - next_id = yield self._pushers_id_gen.get_next() + pushkey, pushkey_ts, lang, data, profile_tag=""): + with self._pushers_id_gen.get_next() as stream_id: yield self._simple_upsert( "pushers", dict( @@ -95,29 +122,35 @@ class PusherStore(SQLBaseStore): dict( access_token=access_token, kind=kind, - profile_tag=profile_tag, app_display_name=app_display_name, device_display_name=device_display_name, ts=pushkey_ts, lang=lang, data=encode_canonical_json(data), - ), - insertion_values=dict( - id=next_id, + profile_tag=profile_tag, + id=stream_id, ), desc="add_pusher", ) - except Exception as e: - logger.error("create_pusher with failed: %s", e) - raise StoreError(500, "Problem creating pusher.") @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): - yield self._simple_delete_one( - "pushers", - {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id}, - desc="delete_pusher_by_app_id_pushkey_user_id", - ) + def delete_pusher_txn(txn, stream_id): + self._simple_delete_one_txn( + txn, + "pushers", + {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} + ) + self._simple_upsert_txn( + txn, + "deleted_pushers", + {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, + {"stream_id": stream_id}, + ) + with self._pushers_id_gen.get_next() as stream_id: + yield self.runInteraction( + "delete_pusher", delete_pusher_txn, stream_id + ) @defer.inlineCallbacks def update_pusher_last_token(self, app_id, pushkey, user_id, last_token): diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index 4202a6b3dc..6b9d848eaa 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore): super(ReceiptsStore, self).__init__(hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None) + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token() ) @cached(num_args=2) @@ -62,18 +62,17 @@ class ReceiptsStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - def f(txn): - sql = ( - "SELECT room_id,event_id " - "FROM receipts_linearized " - "WHERE user_id = ? AND receipt_type = ? " - ) - txn.execute(sql, (user_id, receipt_type)) - return txn.fetchall() + rows = yield self._simple_select_list( + table="receipts_linearized", + keyvalues={ + "user_id": user_id, + "receipt_type": receipt_type, + }, + retcols=("room_id", "event_id"), + desc="get_receipts_for_user", + ) - defer.returnValue(dict( - (yield self.runInteraction("get_receipts_for_user", f)) - )) + defer.returnValue({row["room_id"]: row["event_id"] for row in rows}) @defer.inlineCallbacks def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): @@ -222,7 +221,7 @@ class ReceiptsStore(SQLBaseStore): defer.returnValue(results) def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_max_token(self) + return self._receipts_id_gen.get_max_token() def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, user_id, event_id, data, stream_id): @@ -330,7 +329,7 @@ class ReceiptsStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = yield self._receipts_id_gen.get_next(self) + stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: have_persisted = yield self.runInteraction( "insert_linearized_receipt", @@ -347,7 +346,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = yield self._stream_id_gen.get_max_token(self) + max_persisted_id = self._stream_id_gen.get_max_token() defer.returnValue((stream_id, max_persisted_id)) @@ -390,3 +389,19 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) + + def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 967c732bda..bd4eb88a92 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._access_tokens_id_gen.get_next() + next_id = self._access_tokens_id_gen.get_next() yield self._simple_insert( "access_tokens", @@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._refresh_tokens_id_gen.get_next() + next_id = self._refresh_tokens_id_gen.get_next() yield self._simple_insert( "refresh_tokens", @@ -76,7 +76,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False): + was_guest=False, make_guest=False, appservice_id=None): """Attempts to register an account. Args: @@ -87,19 +87,35 @@ class RegistrationStore(SQLBaseStore): upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. Raises: StoreError if the user_id could not be registered. """ yield self.runInteraction( "register", - self._register, user_id, token, password_hash, was_guest, make_guest + self._register, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id ) self.is_guest.invalidate((user_id,)) - def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): + def _register( + self, + txn, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id + ): now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next_txn(txn) + next_id = self._access_tokens_id_gen.get_next() try: if was_guest: @@ -111,9 +127,21 @@ class RegistrationStore(SQLBaseStore): [password_hash, now, 1 if make_guest else 0, user_id]) else: txn.execute("INSERT INTO users " - "(name, password_hash, creation_ts, is_guest) " - "VALUES (?,?,?,?)", - [user_id, password_hash, now, 1 if make_guest else 0]) + "(" + " name," + " password_hash," + " creation_ts," + " is_guest," + " appservice_id" + ") " + "VALUES (?,?,?,?,?)", + [ + user_id, + password_hash, + now, + 1 if make_guest else 0, + appservice_id, + ]) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE @@ -167,27 +195,48 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id): - yield self.runInteraction( - "user_delete_access_tokens", - self._user_delete_access_tokens, user_id - ) + def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def f(txn): + sql = "SELECT token FROM access_tokens WHERE user_id = ?" + clauses = [user_id] - def _user_delete_access_tokens(self, txn, user_id): - txn.execute( - "DELETE FROM access_tokens WHERE user_id = ?", - (user_id, ) - ) + if except_token_ids: + sql += " AND id NOT IN (%s)" % ( + ",".join(["?" for _ in except_token_ids]), + ) + clauses += except_token_ids - @defer.inlineCallbacks - def flush_user(self, user_id): - rows = yield self._execute( - 'flush_user', None, - "SELECT token FROM access_tokens WHERE user_id = ?", - user_id - ) - for r in rows: - self.get_user_by_access_token.invalidate((r,)) + txn.execute(sql, clauses) + + rows = txn.fetchall() + + n = 100 + chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] + for chunk in chunks: + for row in chunk: + txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + + txn.execute( + "DELETE FROM access_tokens WHERE token in (%s)" % ( + ",".join(["?" for _ in chunk]), + ), [r[0] for r in chunk] + ) + + yield self.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self._simple_delete_one_txn( + txn, + table="access_tokens", + keyvalues={ + "token": access_token + }, + ) + + txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + + return self.runInteraction("delete_access_token", f) @cached() def get_user_by_access_token(self, token): @@ -387,3 +436,47 @@ class RegistrationStore(SQLBaseStore): "find_next_generated_user_id", _find_next_generated_user_id ))) + + @defer.inlineCallbacks + def get_3pid_guest_access_token(self, medium, address): + ret = yield self._simple_select_one( + "threepid_guest_access_tokens", + { + "medium": medium, + "address": address + }, + ["guest_access_token"], True, 'get_3pid_guest_access_token' + ) + if ret: + defer.returnValue(ret["guest_access_token"]) + defer.returnValue(None) + + @defer.inlineCallbacks + def save_or_get_3pid_guest_access_token( + self, medium, address, access_token, inviter_user_id + ): + """ + Gets the 3pid's guest access token if exists, else saves access_token. + + :param medium (str): Medium of the 3pid. Must be "email". + :param address (str): 3pid address. + :param access_token (str): The access token to persist if none is + already persisted. + :param inviter_user_id (str): User ID of the inviter. + :return (deferred str): Whichever access token is persisted at the end + of this function call. + """ + def insert(txn): + txn.execute( + "INSERT INTO threepid_guest_access_tokens " + "(medium, address, guest_access_token, first_inviter) " + "VALUES (?, ?, ?, ?)", + (medium, address, access_token, inviter_user_id) + ) + + try: + yield self.runInteraction("save_3pid_guest_access_token", insert) + defer.returnValue(access_token) + except self.database_engine.module.IntegrityError: + ret = yield self.get_3pid_guest_access_token(medium, address) + defer.returnValue(ret) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 46ab38a313..9be977f387 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -77,6 +77,14 @@ class RoomStore(SQLBaseStore): allow_none=True, ) + def set_room_is_public(self, room_id, is_public): + return self._simple_update_one( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="set_room_is_public", + ) + def get_public_room_ids(self): return self._simple_select_onecol( table="rooms", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3065b0c1a5..430b49c12e 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -115,19 +115,17 @@ class RoomMemberStore(SQLBaseStore): ).addCallback(self._get_events) @cached() - def get_invites_for_user(self, user_id): - """ Get all the invite events for a user + def get_invited_rooms_for_user(self, user_id): + """ Get all the rooms the user is invited to Args: user_id (str): The user ID. Returns: - A deferred list of event objects. + A deferred list of RoomsForUser. """ return self.get_rooms_for_user_where_membership_is( user_id, [Membership.INVITE] - ).addCallback(lambda invites: self._get_events([ - invite.event_id for invite in invites - ])) + ) def get_leave_and_ban_events_for_user(self, user_id): """ Get all the leave events for a user @@ -252,30 +250,6 @@ class RoomMemberStore(SQLBaseStore): ) @defer.inlineCallbacks - def user_rooms_intersect(self, user_id_list): - """ Checks whether all the users whose IDs are given in a list share a - room. - - This is a "hot path" function that's called a lot, e.g. by presence for - generating the event stream. As such, it is implemented locally by - wrapping logic around heavily-cached database queries. - """ - if len(user_id_list) < 2: - defer.returnValue(True) - - deferreds = [self.get_rooms_for_user(u) for u in user_id_list] - - results = yield defer.DeferredList(deferreds, consumeErrors=True) - - # A list of sets of strings giving room IDs for each user - room_id_lists = [set([r.room_id for r in result[1]]) for result in results] - - # There isn't a setintersection(*list_of_sets) - ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 - - defer.returnValue(ret) - - @defer.inlineCallbacks def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/schema/delta/30/alias_creator.sql new file mode 100644 index 0000000000..c9d0dde638 --- /dev/null +++ b/synapse/storage/schema/delta/30/alias_creator.sql @@ -0,0 +1,16 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_aliases ADD COLUMN creator TEXT; diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py new file mode 100644 index 0000000000..4f6e9dd540 --- /dev/null +++ b/synapse/storage/schema/delta/30/as_users.py @@ -0,0 +1,68 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from synapse.storage.appservice import ApplicationServiceStore + + +logger = logging.getLogger(__name__) + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # NULL indicates user was not registered by an appservice. + try: + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + except: + # Maybe we already added the column? Hope so... + pass + + cur.execute("SELECT name FROM users") + rows = cur.fetchall() + + config_files = [] + try: + config_files = config.app_service_config_files + except AttributeError: + logger.warning("Could not get app_service_config_files from config") + pass + + appservices = ApplicationServiceStore.load_appservices( + config.server_name, config_files + ) + + owned = {} + + for row in rows: + user_id = row[0] + for appservice in appservices: + if appservice.is_exclusive_user(user_id): + if user_id in owned.keys(): + logger.error( + "user_id %s was owned by more than one application" + " service (IDs %s and %s); assigning arbitrarily to %s" % + (user_id, owned[user_id], appservice.id, owned[user_id]) + ) + owned.setdefault(appservice.id, []).append(user_id) + + for as_id, user_ids in owned.items(): + n = 100 + user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + for chunk in user_chunks: + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % ( + ",".join("?" for _ in chunk), + ) + ), + [as_id] + chunk + ) diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/schema/delta/30/deleted_pushers.sql new file mode 100644 index 0000000000..712c454aa1 --- /dev/null +++ b/synapse/storage/schema/delta/30/deleted_pushers.sql @@ -0,0 +1,25 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS deleted_pushers( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL, + /* We only track the most recent delete for each app_id, pushkey and user_id. */ + UNIQUE (app_id, pushkey, user_id) +); + +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/schema/delta/30/presence_stream.sql new file mode 100644 index 0000000000..606bbb037d --- /dev/null +++ b/synapse/storage/schema/delta/30/presence_stream.sql @@ -0,0 +1,30 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + CREATE TABLE presence_stream( + stream_id BIGINT, + user_id TEXT, + state TEXT, + last_active_ts BIGINT, + last_federation_update_ts BIGINT, + last_user_sync_ts BIGINT, + status_msg TEXT, + currently_active BOOLEAN + ); + + CREATE INDEX presence_stream_id ON presence_stream(stream_id, user_id); + CREATE INDEX presence_stream_user_id ON presence_stream(user_id); + CREATE INDEX presence_stream_state ON presence_stream(state); diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/schema/delta/30/public_rooms.sql new file mode 100644 index 0000000000..f09db4faa6 --- /dev/null +++ b/synapse/storage/schema/delta/30/public_rooms.sql @@ -0,0 +1,23 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/* This release removes the restriction that published rooms must have an alias, + * so we go back and ensure the only 'public' rooms are ones with an alias. + * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite + */ +UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in ( + SELECT room_id FROM room_aliases +); diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/schema/delta/30/push_rule_stream.sql new file mode 100644 index 0000000000..735aa8d5f6 --- /dev/null +++ b/synapse/storage/schema/delta/30/push_rule_stream.sql @@ -0,0 +1,38 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + + +CREATE TABLE push_rules_stream( + stream_id BIGINT NOT NULL, + event_stream_ordering BIGINT NOT NULL, + user_id TEXT NOT NULL, + rule_id TEXT NOT NULL, + op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE" + priority_class SMALLINT, + priority INTEGER, + conditions TEXT, + actions TEXT +); + +-- The extra data for each operation is: +-- * ENABLE, DISABLE, DELETE: [] +-- * ACTIONS: ["actions"] +-- * ADD: ["priority_class", "priority", "actions", "conditions"] + +-- Index for replication queries. +CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id); +-- Index for /sync queries. +CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id); diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql new file mode 100644 index 0000000000..0dd2f1360c --- /dev/null +++ b/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql @@ -0,0 +1,24 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Stores guest account access tokens generated for unbound 3pids. +CREATE TABLE threepid_guest_access_tokens( + medium TEXT, -- The medium of the 3pid. Must be "email". + address TEXT, -- The 3pid address. + guest_access_token TEXT, -- The access token for a guest user for this 3pid. + first_inviter TEXT -- User ID of the first user to invite this 3pid to a room. +); + +CREATE UNIQUE INDEX threepid_guest_access_tokens_index ON threepid_guest_access_tokens(medium, address); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index 70c6a06cd1..b10f2a5787 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from _base import SQLBaseStore +from ._base import SQLBaseStore from unpaddedbase64 import encode_base64 from synapse.crypto.event_signing import compute_event_reference_hash diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 372b540002..02cefdff26 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,9 +14,8 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import ( - cached, cachedInlineCallbacks, cachedList -) +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches import intern_string from twisted.internet import defer @@ -83,7 +82,7 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next_txn(txn) + state_group = self._state_groups_id_gen.get_next() self._simple_insert_txn( txn, table="state_groups", @@ -155,8 +154,14 @@ class StateStore(SQLBaseStore): events = yield self._get_events(event_ids, get_prev_content=False) defer.returnValue(events) - @cachedInlineCallbacks(num_args=3) + @defer.inlineCallbacks def get_current_state_for_key(self, room_id, event_type, state_key): + event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) + events = yield self._get_events(event_ids, get_prev_content=False) + defer.returnValue(events) + + @cached(num_args=3) + def _get_current_state_for_key(self, room_id, event_type, state_key): def f(txn): sql = ( "SELECT event_id FROM current_state_events" @@ -167,12 +172,10 @@ class StateStore(SQLBaseStore): txn.execute(sql, args) results = txn.fetchall() return [r[0] for r in results] - event_ids = yield self.runInteraction("get_current_state_for_key", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) + return self.runInteraction("get_current_state_for_key", f) def _get_state_groups_from_groups(self, groups, types): - """Returns dictionary state_group -> state event ids + """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ def f(txn, groups): if types is not None: @@ -183,7 +186,8 @@ class StateStore(SQLBaseStore): where_clause = "" sql = ( - "SELECT state_group, event_id FROM state_groups_state WHERE" + "SELECT state_group, event_id, type, state_key" + " FROM state_groups_state WHERE" " state_group IN (%s) %s" % ( ",".join("?" for _ in groups), where_clause, @@ -199,7 +203,8 @@ class StateStore(SQLBaseStore): results = {} for row in rows: - results.setdefault(row["state_group"], []).append(row["event_id"]) + key = (row["type"], row["state_key"]) + results.setdefault(row["state_group"], {})[key] = row["event_id"] return results chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] @@ -296,7 +301,7 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict = self._state_group_cache.get(group) + is_all, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() @@ -308,7 +313,7 @@ class StateStore(SQLBaseStore): if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict: + if (typ, state_key) not in state_dict_ids: missing_types.add((typ, state_key)) sentinel = object() @@ -326,7 +331,7 @@ class StateStore(SQLBaseStore): got_all = not (missing_types or types is None) return { - k: v for k, v in state_dict.items() + k: v for k, v in state_dict_ids.items() if include(k[0], k[1]) }, missing_types, got_all @@ -340,8 +345,9 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict = self._state_group_cache.get(group) - return state_dict, is_all + is_all, state_dict_ids = self._state_group_cache.get(group) + + return state_dict_ids, is_all @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): @@ -354,84 +360,72 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( group, types ) - results[group] = state_dict + results[group] = state_dict_ids if not got_all: missing_groups.append(group) else: for group in set(groups): - state_dict, got_all = self._get_all_state_from_cache( + state_dict_ids, got_all = self._get_all_state_from_cache( group ) - results[group] = state_dict + + results[group] = state_dict_ids if not got_all: missing_groups.append(group) - if not missing_groups: - defer.returnValue({ - group: { - type_tuple: event - for type_tuple, event in state.items() - if event - } - for group, state in results.items() - }) + if missing_groups: + # Okay, so we have some missing_types, lets fetch them. + cache_seq_num = self._state_group_cache.sequence - # Okay, so we have some missing_types, lets fetch them. - cache_seq_num = self._state_group_cache.sequence + group_to_state_dict = yield self._get_state_groups_from_groups( + missing_groups, types + ) - group_state_dict = yield self._get_state_groups_from_groups( - missing_groups, types - ) + # Now we want to update the cache with all the things we fetched + # from the database. + for group, group_state_dict in group_to_state_dict.items(): + if types: + # We delibrately put key -> None mappings into the cache to + # cache absence of the key, on the assumption that if we've + # explicitly asked for some types then we will probably ask + # for them again. + state_dict = { + (intern_string(etype), intern_string(state_key)): None + for (etype, state_key) in types + } + state_dict.update(results[group]) + results[group] = state_dict + else: + state_dict = results[group] + + state_dict.update(group_state_dict) + + self._state_group_cache.update( + cache_seq_num, + key=group, + value=state_dict, + full=(types is None), + ) state_events = yield self._get_events( - [e_id for l in group_state_dict.values() for e_id in l], + [ev_id for sd in results.values() for ev_id in sd.values()], get_prev_content=False ) state_events = {e.event_id: e for e in state_events} - # Now we want to update the cache with all the things we fetched - # from the database. - for group, state_ids in group_state_dict.items(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = {key: None for key in types} - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] - - for event_id in state_ids: - try: - state_event = state_events[event_id] - state_dict[(state_event.type, state_event.state_key)] = state_event - except KeyError: - # Hmm. So we do don't have that state event? Interesting. - logger.warn( - "Can't find state event %r for state group %r", - event_id, group, - ) - - self._state_group_cache.update( - cache_seq_num, - key=group, - value=state_dict, - full=(types is None), - ) - # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. for group, state_dict in results.items(): results[group] = { - key: event for key, event in state_dict.items() if event + key: state_events[event_id] + for key, event_id in state_dict.items() + if event_id and event_id in state_events } defer.returnValue(results) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index c236dafafb..cf84938be5 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -36,7 +36,7 @@ what sort order was used: from twisted.internet import defer from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken from synapse.util.logcontext import preserve_fn @@ -184,6 +184,9 @@ class StreamStore(SQLBaseStore): @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): + # Note: If from_key is None then we return in topological order. This + # is because in that case we're using this as a "get the last few messages + # in a room" function, rather than "get new messages since last sync" if from_key is not None: from_id = RoomStreamToken.parse_stream_token(from_key).stream else: @@ -217,8 +220,8 @@ class StreamStore(SQLBaseStore): " room_id = ?" " AND not outlier" " AND stream_ordering <= ?" - " ORDER BY stream_ordering %s LIMIT ?" - ) % (order,) + " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?" + ) % (order, order,) txn.execute(sql, (room_id, to_id, limit)) rows = self.cursor_to_dict(txn) @@ -232,7 +235,7 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) - self._set_before_and_after(ret, rows, topo_order=False) + self._set_before_and_after(ret, rows, topo_order=from_id is None) if order.lower() == "desc": ret.reverse() @@ -462,9 +465,25 @@ class StreamStore(SQLBaseStore): defer.returnValue((events, token)) - @cachedInlineCallbacks(num_args=4) + @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): + rows, token = yield self.get_recent_event_ids_for_room( + room_id, limit, end_token, from_token + ) + + logger.debug("stream before") + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) + logger.debug("stream after") + + self._set_before_and_after(events, rows) + defer.returnValue((events, token)) + + @cached(num_args=4) + def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None): end_token = RoomStreamToken.parse_stream_token(end_token) if from_token is None: @@ -514,24 +533,13 @@ class StreamStore(SQLBaseStore): return rows, token - rows, token = yield self.runInteraction( + return self.runInteraction( "get_recent_events_for_room", get_recent_events_for_room_txn ) - logger.debug("stream before") - events = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True - ) - logger.debug("stream after") - - self._set_before_and_after(events, rows) - - defer.returnValue((events, token)) - @defer.inlineCallbacks def get_room_events_max_id(self, direction='f'): - token = yield self._stream_id_gen.get_max_token(self) + token = yield self._stream_id_gen.get_max_token() if direction != 'b': defer.returnValue("s%d" % (token,)) else: diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index e1a9c0c261..a0e6b42b30 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore): Returns: A deferred int. """ - return self._account_data_id_gen.get_max_token(self) + return self._account_data_id_gen.get_max_token() @cached() def get_tags_for_user(self, user_id): @@ -59,6 +59,59 @@ class TagsStore(SQLBaseStore): return deferred @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" + " FROM room_tags" + " WHERE user_id=? AND room_id=?" + ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn.fetchall(): + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in xrange(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i:i + batch_size], + ) + results.extend(tags) + + defer.returnValue(results) + + @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): """Get all the tags for the rooms where the tags have changed since the given version @@ -142,12 +195,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -164,12 +217,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token(self) + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 4475c451c1..d338dfcf0a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): - next_id = self._transaction_id_gen.get_next_txn(txn) + next_id = self._transaction_id_gen.get_next() # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 5c522f4ab9..a02dfc7d58 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,51 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from collections import deque import contextlib import threading class IdGenerator(object): - def __init__(self, table, column, store): - self.table = table - self.column = column - self.store = store + def __init__(self, db_conn, table, column): self._lock = threading.Lock() - self._next_id = None + self._next_id = _load_max_id(db_conn, table, column) - @defer.inlineCallbacks def get_next(self): - if self._next_id is None: - yield self.store.runInteraction( - "IdGenerator_%s" % (self.table,), - self.get_next_txn, - ) - with self._lock: - i = self._next_id self._next_id += 1 - defer.returnValue(i) - - def get_next_txn(self, txn): - with self._lock: - if self._next_id: - i = self._next_id - self._next_id += 1 - return i - else: - txn.execute( - "SELECT MAX(%s) FROM %s" % (self.column, self.table,) - ) + return self._next_id - val, = txn.fetchone() - cur = val or 0 - cur += 1 - self._next_id = cur + 1 - return cur +def _load_max_id(db_conn, table, column): + cur = db_conn.cursor() + cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) + val, = cur.fetchone() + cur.close() + return int(val) if val else 1 class StreamIdGenerator(object): @@ -69,25 +46,23 @@ class StreamIdGenerator(object): persistence of events can complete out of order. Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - def __init__(self, db_conn, table, column): - self.table = table - self.column = column - + def __init__(self, db_conn, table, column, extra_tables=[]): self._lock = threading.Lock() - - cur = db_conn.cursor() - self._current_max = self._get_or_compute_current_max(cur) - cur.close() - + self._current_max = _load_max_id(db_conn, table, column) + for table, column in extra_tables: + self._current_max = max( + self._current_max, + _load_max_id(db_conn, table, column) + ) self._unfinished_ids = deque() - def get_next(self, store): + def get_next(self): """ Usage: - with yield stream_id_gen.get_next as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -106,10 +81,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, store, n): + def get_next_mult(self, n): """ Usage: - with yield stream_id_gen.get_next(store, n) as stream_ids: + with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -130,7 +105,7 @@ class StreamIdGenerator(object): return manager() - def get_max_token(self, store): + def get_max_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. """ @@ -140,12 +115,49 @@ class StreamIdGenerator(object): return self._current_max - def _get_or_compute_current_max(self, txn): + +class ChainedIdGenerator(object): + """Used to generate new stream ids where the stream must be kept in sync + with another stream. It generates pairs of IDs, the first element is an + integer ID for this stream, the second element is the ID for the stream + that this stream needs to be kept in sync with.""" + + def __init__(self, chained_generator, db_conn, table, column): + self.chained_generator = chained_generator + self._lock = threading.Lock() + self._current_max = _load_max_id(db_conn, table, column) + self._unfinished_ids = deque() + + def get_next(self): + """ + Usage: + with stream_id_gen.get_next() as (stream_id, chained_id): + # ... persist event ... + """ with self._lock: - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] + self._current_max += 1 + next_id = self._current_max + chained_id = self.chained_generator.get_max_token() - self._current_max = int(val) if val else 1 + self._unfinished_ids.append((next_id, chained_id)) - return self._current_max + @contextlib.contextmanager + def manager(): + try: + yield (next_id, chained_id) + finally: + with self._lock: + self._unfinished_ids.remove((next_id, chained_id)) + + return manager() + + def get_max_token(self): + """Returns the maximum stream id such that all stream ids less than or + equal to it have been successfully persisted. + """ + with self._lock: + if self._unfinished_ids: + stream_id, chained_id = self._unfinished_ids[0] + return (stream_id - 1, chained_id) + + return (self._current_max, self.chained_generator.get_max_token()) |