diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 14 | ||||
-rw-r--r-- | synapse/storage/account_data.py | 44 | ||||
-rw-r--r-- | synapse/storage/appservice.py | 34 | ||||
-rw-r--r-- | synapse/storage/engines/__init__.py | 5 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 5 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite3.py | 5 | ||||
-rw-r--r-- | synapse/storage/events.py | 51 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 13 | ||||
-rw-r--r-- | synapse/storage/presence.py | 20 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 4 | ||||
-rw-r--r-- | synapse/storage/pusher.py | 2 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 20 | ||||
-rw-r--r-- | synapse/storage/registration.py | 6 | ||||
-rw-r--r-- | synapse/storage/schema/delta/30/as_users.py | 59 | ||||
-rw-r--r-- | synapse/storage/state.py | 2 | ||||
-rw-r--r-- | synapse/storage/tags.py | 61 | ||||
-rw-r--r-- | synapse/storage/transactions.py | 2 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 69 |
18 files changed, 312 insertions, 104 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9be1d12fac..f257721ea3 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -115,13 +115,13 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "presence_stream", "stream_id" ) - self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) - self._state_groups_id_gen = IdGenerator("state_groups", "id", self) - self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) - self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) - self._pushers_id_gen = IdGenerator("pushers", "id", self) - self._push_rule_id_gen = IdGenerator("push_rules", "id", self) - self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") + self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") + self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") events_max = self._stream_id_gen.get_max_token() event_cache_prefill, min_event_val = self._get_cache_dict( diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 91cbf399b6..faddefe219 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) - def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None): - """Get all the client account_data for a that's changed. + def get_all_updated_account_data(self, last_global_id, last_room_id, + current_id, limit): + """Get all the client account_data that has changed on the server + Args: + last_global_id(int): The position to fetch from for top level data + last_room_id(int): The position to fetch from for per room data + current_id(int): The position to fetch up to. + Returns: + A deferred pair of lists of tuples of stream_id int, user_id string, + room_id string, type string, and content string. + """ + def get_updated_account_data_txn(txn): + sql = ( + "SELECT stream_id, user_id, account_data_type, content" + " FROM account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_global_id, current_id, limit)) + global_results = txn.fetchall() + + sql = ( + "SELECT stream_id, user_id, room_id, account_data_type, content" + " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_room_id, current_id, limit)) + room_results = txn.fetchall() + return (global_results, room_results) + return self.runInteraction( + "get_all_updated_account_data_txn", get_updated_account_data_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed for a user Args: user_id(str): The user to get the account_data for. @@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_room_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore): ) self._update_max_stream_id(txn, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction( "add_user_account_data", add_account_data_txn, next_id ) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_max_stream_id(self, txn, next_id): diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 1100c67714..371600eebb 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore): def __init__(self, hs): super(ApplicationServiceStore, self).__init__(hs) self.hostname = hs.hostname - self.services_cache = [] - self._populate_appservice_cache( + self.services_cache = ApplicationServiceStore.load_appservices( + hs.hostname, hs.config.app_service_config_files ) @@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore): return rooms_for_user_matching_user_id - def _load_appservice(self, as_info): + @classmethod + def _load_appservice(cls, hostname, as_info, config_filename): required_string_fields = [ - # TODO: Add id here when it's stable to release - "url", "as_token", "hs_token", "sender_localpart" + "id", "url", "as_token", "hs_token", "sender_localpart" ] for field in required_string_fields: if not isinstance(as_info.get(field), basestring): - raise KeyError("Required string field: '%s'", field) + raise KeyError("Required string field: '%s' (%s)" % ( + field, config_filename, + )) localpart = as_info["sender_localpart"] if urllib.quote(localpart) != localpart: raise ValueError( "sender_localpart needs characters which are not URL encoded." ) - user = UserID(localpart, self.hostname) + user = UserID(localpart, hostname) user_id = user.to_string() # namespace checks @@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore): namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], sender=user_id, - id=as_info["id"] if "id" in as_info else as_info["as_token"], + id=as_info["id"], ) - def _populate_appservice_cache(self, config_files): - """Populates a cache of Application Services from the config files.""" + @classmethod + def load_appservices(cls, hostname, config_files): + """Returns a list of Application Services from the config files.""" if not isinstance(config_files, list): logger.warning( "Expected %s to be a list of AS config files.", config_files ) - return + return [] # Dicts of value -> filename seen_as_tokens = {} seen_ids = {} + appservices = [] + for config_file in config_files: try: with open(config_file, 'r') as f: - appservice = self._load_appservice(yaml.load(f)) + appservice = ApplicationServiceStore._load_appservice( + hostname, yaml.load(f), config_file + ) if appservice.id in seen_ids: raise ConfigError( "Cannot reuse ID across application services: " @@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore): ) seen_as_tokens[appservice.token] = config_file logger.info("Loaded application service: %s", appservice) - self.services_cache.append(appservice) + appservices.append(appservice) except Exception as e: logger.error("Failed to load appservice from '%s'", config_file) logger.exception(e) raise + return appservices class ApplicationServiceTransactionStore(SQLBaseStore): diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 4290aea83a..a48230b93f 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -26,12 +26,13 @@ SUPPORTED_MODULE = { } -def create_engine(name): +def create_engine(config): + name = config.database_config["name"] engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: module = importlib.import_module(name) - return engine_class(module) + return engine_class(module, config=config) raise RuntimeError( "Unsupported database engine '%s'" % (name,) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 17b7a9c077..a09685b4df 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): single_threaded = False - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) + self.config = config def check_database(self, txn): txn.execute("SHOW SERVER_ENCODING") @@ -44,7 +45,7 @@ class PostgresEngine(object): ) def prepare_database(self, db_conn): - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 91fac33b8b..522b905949 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -23,8 +23,9 @@ import struct class Sqlite3Engine(object): single_threaded = True - def __init__(self, database_module): + def __init__(self, database_module, config): self.module = database_module + self.config = config def check_database(self, txn): pass @@ -38,7 +39,7 @@ class Sqlite3Engine(object): def prepare_database(self, db_conn): prepare_sqlite3_database(db_conn) - prepare_database(db_conn, self) + prepare_database(db_conn, self, config=self.config) def is_deadlock(self, error): return False diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 1dd3236829..60936500d8 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore): yield stream_orderings stream_ordering_manager = stream_ordering_manager() else: - stream_ordering_manager = yield self._stream_id_gen.get_next_mult( - self, len(events_and_contexts) + stream_ordering_manager = self._stream_id_gen.get_next_mult( + len(events_and_contexts) ) with stream_ordering_manager as stream_orderings: @@ -109,7 +109,7 @@ class EventsStore(SQLBaseStore): stream_ordering = self.min_stream_token if stream_ordering is None: - stream_ordering_manager = yield self._stream_id_gen.get_next(self) + stream_ordering_manager = self._stream_id_gen.get_next() else: @contextmanager def stream_ordering_manager(): @@ -1064,3 +1064,48 @@ class EventsStore(SQLBaseStore): yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) defer.returnValue(result) + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + + # TODO: Fix race with the persit_event txn by using one of the + # stream id managers + return -self.min_stream_token + + def get_all_new_events(self, last_backfill_id, last_forward_id, + current_backfill_id, current_forward_id, limit): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + if last_forward_id != current_forward_id: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + else: + new_forward_events = [] + + sql = ( + "SELECT -e.stream_ordering, ej.internal_metadata, ej.json" + " FROM events as e" + " JOIN event_json as ej" + " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" + " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" + " ORDER BY e.stream_ordering DESC" + " LIMIT ?" + ) + if last_backfill_id != current_backfill_id: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + else: + new_backfill_events = [] + + return (new_forward_events, new_backfill_events) + return self.runInteraction("get_all_new_events", get_all_new_events_txn) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 0fd5d497ab..3f29aad1e8 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -50,7 +50,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine): +def prepare_database(db_conn, database_engine, config): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. """ @@ -61,10 +61,10 @@ def prepare_database(db_conn, database_engine): if version_info: user_version, delta_files, upgraded = version_info _upgrade_existing_database( - cur, user_version, delta_files, upgraded, database_engine + cur, user_version, delta_files, upgraded, database_engine, config ) else: - _setup_new_database(cur, database_engine) + _setup_new_database(cur, database_engine, config) # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) @@ -75,7 +75,7 @@ def prepare_database(db_conn, database_engine): raise -def _setup_new_database(cur, database_engine): +def _setup_new_database(cur, database_engine, config): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -148,11 +148,12 @@ def _setup_new_database(cur, database_engine): applied_delta_files=[], upgraded=False, database_engine=database_engine, + config=config, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded, database_engine): + upgraded, database_engine, config): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -245,7 +246,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module_name, absolute_path, python_file ) logger.debug("Running script %s", relative_path) - module.run_upgrade(cur, database_engine) + module.run_upgrade(cur, database_engine, config=config) elif ext == ".pyc": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. from distribution package diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 3ef91d34db..4cec31e316 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -58,8 +58,8 @@ class UserPresenceState(namedtuple("UserPresenceState", class PresenceStore(SQLBaseStore): @defer.inlineCallbacks def update_presence(self, presence_states): - stream_ordering_manager = yield self._presence_id_gen.get_next_mult( - self, len(presence_states) + stream_ordering_manager = self._presence_id_gen.get_next_mult( + len(presence_states) ) with stream_ordering_manager as stream_orderings: @@ -115,6 +115,22 @@ class PresenceStore(SQLBaseStore): args ) + def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates_txn(txn): + sql = ( + "SELECT stream_id, user_id, state, last_active_ts," + " last_federation_update_ts, last_user_sync_ts, status_msg," + " currently_active" + " FROM presence_stream" + " WHERE ? < stream_id AND stream_id <= ?" + ) + txn.execute(sql, (last_id, current_id)) + return txn.fetchall() + + return self.runInteraction( + "get_all_presence_updates", get_all_presence_updates_txn + ) + @defer.inlineCallbacks def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index bb5c14d912..56e69495b1 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -226,7 +226,7 @@ class PushRuleStore(SQLBaseStore): if txn.rowcount == 0: # We didn't update a row with the given rule_id so insert one - push_rule_id = self._push_rule_id_gen.get_next_txn(txn) + push_rule_id = self._push_rule_id_gen.get_next() self._simple_insert_txn( txn, @@ -279,7 +279,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(ret) def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): - new_id = self._push_rules_enable_id_gen.get_next_txn(txn) + new_id = self._push_rules_enable_id_gen.get_next() self._simple_upsert_txn( txn, "push_rules_enable", diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index c23648cdbc..7693ab9082 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -84,7 +84,7 @@ class PusherStore(SQLBaseStore): app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, profile_tag=""): try: - next_id = yield self._pushers_id_gen.get_next() + next_id = self._pushers_id_gen.get_next() yield self._simple_upsert( "pushers", dict( diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index a7343c97f7..dbc074d6b5 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -330,7 +330,7 @@ class ReceiptsStore(SQLBaseStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = yield self._receipts_id_gen.get_next(self) + stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: have_persisted = yield self.runInteraction( "insert_linearized_receipt", @@ -347,7 +347,7 @@ class ReceiptsStore(SQLBaseStore): room_id, receipt_type, user_id, event_ids, data ) - max_persisted_id = yield self._stream_id_gen.get_max_token() + max_persisted_id = self._stream_id_gen.get_max_token() defer.returnValue((stream_id, max_persisted_id)) @@ -390,3 +390,19 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) + + def get_all_updated_receipts(self, last_id, current_id, limit): + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 03a9b66e4a..ad1157f979 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -40,7 +40,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._access_tokens_id_gen.get_next() + next_id = self._access_tokens_id_gen.get_next() yield self._simple_insert( "access_tokens", @@ -62,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._refresh_tokens_id_gen.get_next() + next_id = self._refresh_tokens_id_gen.get_next() yield self._simple_insert( "refresh_tokens", @@ -99,7 +99,7 @@ class RegistrationStore(SQLBaseStore): def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next_txn(txn) + next_id = self._access_tokens_id_gen.get_next() try: if was_guest: diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py new file mode 100644 index 0000000000..4cf4dd0917 --- /dev/null +++ b/synapse/storage/schema/delta/30/as_users.py @@ -0,0 +1,59 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from synapse.storage.appservice import ApplicationServiceStore + + +logger = logging.getLogger(__name__) + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + # NULL indicates user was not registered by an appservice. + cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") + + cur.execute("SELECT name FROM users") + rows = cur.fetchall() + + config_files = [] + try: + config_files = config.app_service_config_files + except AttributeError: + logger.warning("Could not get app_service_config_files from config") + pass + + appservices = ApplicationServiceStore.load_appservices( + config.server_name, config_files + ) + + owned = {} + + for row in rows: + user_id = row[0] + for appservice in appservices: + if appservice.is_exclusive_user(user_id): + if user_id in owned.keys(): + logger.error( + "user_id %s was owned by more than one application" + " service (IDs %s and %s); assigning arbitrarily to %s" % + (user_id, owned[user_id], appservice.id, owned[user_id]) + ) + owned[user_id] = appservice.id + + for user_id, as_id in owned.items(): + cur.execute( + database_engine.convert_param_style( + "UPDATE users SET appservice_id = ? WHERE name = ?" + ), + (as_id, user_id) + ) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 372b540002..8ed8a21b0a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -83,7 +83,7 @@ class StateStore(SQLBaseStore): if event.is_state(): state_events[(event.type, event.state_key)] = event - state_group = self._state_groups_id_gen.get_next_txn(txn) + state_group = self._state_groups_id_gen.get_next() self._simple_insert_txn( txn, table="state_groups", diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 9551aa9739..a0e6b42b30 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -59,6 +59,59 @@ class TagsStore(SQLBaseStore): return deferred @defer.inlineCallbacks + def get_all_updated_tags(self, last_id, current_id, limit): + """Get all the client tags that have changed on the server + Args: + last_id(int): The position to fetch from. + current_id(int): The position to fetch up to. + Returns: + A deferred list of tuples of stream_id int, user_id string, + room_id string, tag string and content string. + """ + def get_all_updated_tags_txn(txn): + sql = ( + "SELECT stream_id, user_id, room_id" + " FROM room_tags_revisions as r" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + tag_ids = yield self.runInteraction( + "get_all_updated_tags", get_all_updated_tags_txn + ) + + def get_tag_content(txn, tag_ids): + sql = ( + "SELECT tag, content" + " FROM room_tags" + " WHERE user_id=? AND room_id=?" + ) + results = [] + for stream_id, user_id, room_id in tag_ids: + txn.execute(sql, (user_id, room_id)) + tags = [] + for tag, content in txn.fetchall(): + tags.append(json.dumps(tag) + ":" + content) + tag_json = "{" + ",".join(tags) + "}" + results.append((stream_id, user_id, room_id, tag_json)) + + return results + + batch_size = 50 + results = [] + for i in xrange(0, len(tag_ids), batch_size): + tags = yield self.runInteraction( + "get_all_updated_tag_content", + get_tag_content, + tag_ids[i:i + batch_size], + ) + results.extend(tags) + + defer.returnValue(results) + + @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): """Get all the tags for the rooms where the tags have changed since the given version @@ -142,12 +195,12 @@ class TagsStore(SQLBaseStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) @defer.inlineCallbacks @@ -164,12 +217,12 @@ class TagsStore(SQLBaseStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with (yield self._account_data_id_gen.get_next(self)) as next_id: + with self._account_data_id_gen.get_next() as next_id: yield self.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = yield self._account_data_id_gen.get_max_token() + result = self._account_data_id_gen.get_max_token() defer.returnValue(result) def _update_revision_txn(self, txn, user_id, room_id, next_id): diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 4475c451c1..d338dfcf0a 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -117,7 +117,7 @@ class TransactionStore(SQLBaseStore): def _prep_send_transaction(self, txn, transaction_id, destination, origin_server_ts): - next_id = self._transaction_id_gen.get_next_txn(txn) + next_id = self._transaction_id_gen.get_next() # First we find out what the prev_txns should be. # Since we know that we are only sending one transaction at a time, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ef5e4a4668..efe3f68e6e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -13,51 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from collections import deque import contextlib import threading class IdGenerator(object): - def __init__(self, table, column, store): + def __init__(self, db_conn, table, column): self.table = table self.column = column - self.store = store self._lock = threading.Lock() - self._next_id = None + cur = db_conn.cursor() + self._next_id = self._load_next_id(cur) + cur.close() - @defer.inlineCallbacks - def get_next(self): - if self._next_id is None: - yield self.store.runInteraction( - "IdGenerator_%s" % (self.table,), - self.get_next_txn, - ) + def _load_next_id(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,)) + val, = txn.fetchone() + return val + 1 if val else 1 + def get_next(self): with self._lock: i = self._next_id self._next_id += 1 - defer.returnValue(i) - - def get_next_txn(self, txn): - with self._lock: - if self._next_id: - i = self._next_id - self._next_id += 1 - return i - else: - txn.execute( - "SELECT MAX(%s) FROM %s" % (self.column, self.table,) - ) - - val, = txn.fetchone() - cur = val or 0 - cur += 1 - self._next_id = cur + 1 - - return cur + return i class StreamIdGenerator(object): @@ -69,7 +48,7 @@ class StreamIdGenerator(object): persistence of events can complete out of order. Usage: - with stream_id_gen.get_next_txn(txn) as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ def __init__(self, db_conn, table, column): @@ -79,15 +58,21 @@ class StreamIdGenerator(object): self._lock = threading.Lock() cur = db_conn.cursor() - self._current_max = self._get_or_compute_current_max(cur) + self._current_max = self._load_current_max(cur) cur.close() self._unfinished_ids = deque() - def get_next(self, store): + def _load_current_max(self, txn): + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) + rows = txn.fetchall() + val, = rows[0] + return int(val) if val else 1 + + def get_next(self): """ Usage: - with yield stream_id_gen.get_next as stream_id: + with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -106,10 +91,10 @@ class StreamIdGenerator(object): return manager() - def get_next_mult(self, store, n): + def get_next_mult(self, n): """ Usage: - with yield stream_id_gen.get_next(store, n) as stream_ids: + with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -139,13 +124,3 @@ class StreamIdGenerator(object): return self._unfinished_ids[0] - 1 return self._current_max - - def _get_or_compute_current_max(self, txn): - with self._lock: - txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) - rows = txn.fetchall() - val, = rows[0] - - self._current_max = int(val) if val else 1 - - return self._current_max |