diff options
23 files changed, 643 insertions, 215 deletions
diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc index 1a11e30944..346741d982 100644 --- a/changelog.d/8868.misc +++ b/changelog.d/8868.misc @@ -1 +1 @@ -Improve efficiency of large state resolutions for new rooms. +Improve efficiency of large state resolutions. diff --git a/changelog.d/9029.misc b/changelog.d/9029.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9029.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/changelog.d/9098.misc b/changelog.d/9098.misc new file mode 100644 index 0000000000..907020d428 --- /dev/null +++ b/changelog.d/9098.misc @@ -0,0 +1 @@ +Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. diff --git a/changelog.d/9107.feature b/changelog.d/9107.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9107.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9114.bugfix b/changelog.d/9114.bugfix new file mode 100644 index 0000000000..211f26589d --- /dev/null +++ b/changelog.d/9114.bugfix @@ -0,0 +1 @@ +Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. diff --git a/changelog.d/9115.misc b/changelog.d/9115.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9115.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/changelog.d/9116.bugfix b/changelog.d/9116.bugfix new file mode 100644 index 0000000000..211f26589d --- /dev/null +++ b/changelog.d/9116.bugfix @@ -0,0 +1 @@ +Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. diff --git a/changelog.d/9118.misc b/changelog.d/9118.misc new file mode 100644 index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9118.misc @@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22dd169bfb..69bf9110a6 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { "events": ["processed", "outlier", "contains_url"], - "rooms": ["is_public"], + "rooms": ["is_public", "has_auth_chain_index"], "event_edges": ["is_state"], "presence_list": ["accepted"], "presence_stream": ["currently_active"], diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index cbecf23be6..57a2f5237c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -429,7 +429,6 @@ def setup(config_options): oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. await oidc.load_metadata() - await oidc.load_jwks() await _base.start(hs, config.listeners) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 84754e5c9c..d6347bb1b8 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -35,6 +35,7 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError +from synapse.config.oidc_config import OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -70,6 +71,131 @@ JWK = Dict[str, str] JWKS = TypedDict("JWKS", {"keys": List[JWK]}) +class OidcHandler: + """Handles requests related to the OpenID Connect login flow. + """ + + def __init__(self, hs: "HomeServer"): + self._sso_handler = hs.get_sso_handler() + + provider_conf = hs.config.oidc.oidc_provider + # we should not have been instantiated if there is no configured provider. + assert provider_conf is not None + + self._token_generator = OidcSessionTokenGenerator(hs) + + self._provider = OidcProvider(hs, self._token_generator, provider_conf) + + async def load_metadata(self) -> None: + """Validate the config and load the metadata from the remote endpoint. + + Called at startup to ensure we have everything we need. + """ + await self._provider.load_metadata() + await self._provider.load_jwks() + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._sso_handler.render_error`` which displays an HTML page for the error. + + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match with the one stored in the + session cookie + + Once we know the session is legit, we then delegate to the OIDC Provider + implementation, which will exchange the code with the provider and complete the + login/authentication. + + Args: + request: the incoming request from the browser. + """ + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + # error response from the auth server. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 + # https://openid.net/specs/openid-connect-core-1_0.html#AuthError + error = request.args[b"error"][0].decode() + description = request.args.get(b"error_description", [b""])[0].decode() + + # Most of the errors returned by the provider could be due by + # either the provider misbehaving or Synapse being misconfigured. + # The only exception of that is "access_denied", where the user + # probably cancelled the login flow. In other cases, log those errors. + if error != "access_denied": + logger.error("Error from the OIDC provider: %s %s", error, description) + + self._sso_handler.render_error(request, error, description) + return + + # otherwise, it is presumably a successful response. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2 + + # Fetch the session cookie + session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] + if session is None: + logger.info("No session cookie found") + self._sso_handler.render_error( + request, "missing_session", "No session cookie found" + ) + return + + # Remove the cookie. There is a good chance that if the callback failed + # once, it will fail next time and the code will already be exchanged. + # Removing it early avoids spamming the provider with token requests. + request.addCookie( + SESSION_COOKIE_NAME, + b"", + path="/_synapse/oidc", + expires="Thu, Jan 01 1970 00:00:00 UTC", + httpOnly=True, + sameSite="lax", + ) + + # Check for the state query parameter + if b"state" not in request.args: + logger.info("State parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "State parameter is missing" + ) + return + + state = request.args[b"state"][0].decode() + + # Deserialize the session token and verify it. + try: + session_data = self._token_generator.verify_oidc_session_token( + session, state + ) + except MacaroonDeserializationException as e: + logger.exception("Invalid session") + self._sso_handler.render_error(request, "invalid_session", str(e)) + return + except MacaroonInvalidSignatureException as e: + logger.exception("Could not verify session") + self._sso_handler.render_error(request, "mismatching_session", str(e)) + return + + if b"code" not in request.args: + logger.info("Code parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "Code parameter is missing" + ) + return + + code = request.args[b"code"][0].decode() + + await self._provider.handle_oidc_callback(request, session_data, code) + + class OidcError(Exception): """Used to catch errors when calling the token_endpoint """ @@ -84,21 +210,25 @@ class OidcError(Exception): return self.error -class OidcHandler: - """Handles requests related to the OpenID Connect login flow. +class OidcProvider: + """Wraps the config for a single OIDC IdentityProvider + + Provides methods for handling redirect requests and callbacks via that particular + IdP. """ - def __init__(self, hs: "HomeServer"): + def __init__( + self, + hs: "HomeServer", + token_generator: "OidcSessionTokenGenerator", + provider: OidcProviderConfig, + ): self._store = hs.get_datastore() - self._token_generator = OidcSessionTokenGenerator(hs) + self._token_generator = token_generator self._callback_url = hs.config.oidc_callback_url # type: str - provider = hs.config.oidc.oidc_provider - # we should not have been instantiated if there is no configured provider. - assert provider is not None - self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( @@ -552,22 +682,16 @@ class OidcHandler: nonce=nonce, ) - async def handle_oidc_callback(self, request: SynapseRequest) -> None: + async def handle_oidc_callback( + self, request: SynapseRequest, session_data: "OidcSessionData", code: str + ) -> None: """Handle an incoming request to /_synapse/oidc/callback - Since we might want to display OIDC-related errors in a user-friendly - way, we don't raise SynapseError from here. Instead, we call - ``self._sso_handler.render_error`` which displays an HTML page for the error. + By this time we have already validated the session on the synapse side, and + now need to do the provider-specific operations. This includes: - Most of the OpenID Connect logic happens here: - - - first, we check if there was any error returned by the provider and - display it - - then we fetch the session cookie, decode and verify it - - the ``state`` query parameter should match with the one stored in the - session cookie - - once we known this session is legit, exchange the code with the - provider using the ``token_endpoint`` (see ``_exchange_code``) + - exchange the code with the provider using the ``token_endpoint`` (see + ``_exchange_code``) - once we have the token, use it to either extract the UserInfo from the ``id_token`` (``_parse_id_token``), or use the ``access_token`` to fetch UserInfo from the ``userinfo_endpoint`` @@ -577,86 +701,12 @@ class OidcHandler: Args: request: the incoming request from the browser. + session_data: the session data, extracted from our cookie + code: The authorization code we got from the callback. """ - - # The provider might redirect with an error. - # In that case, just display it as-is. - if b"error" in request.args: - # error response from the auth server. see: - # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 - # https://openid.net/specs/openid-connect-core-1_0.html#AuthError - error = request.args[b"error"][0].decode() - description = request.args.get(b"error_description", [b""])[0].decode() - - # Most of the errors returned by the provider could be due by - # either the provider misbehaving or Synapse being misconfigured. - # The only exception of that is "access_denied", where the user - # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) - - self._sso_handler.render_error(request, error, description) - return - - # otherwise, it is presumably a successful response. see: - # https://tools.ietf.org/html/rfc6749#section-4.1.2 - - # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] - if session is None: - logger.info("No session cookie found") - self._sso_handler.render_error( - request, "missing_session", "No session cookie found" - ) - return - - # Remove the cookie. There is a good chance that if the callback failed - # once, it will fail next time and the code will already be exchanged. - # Removing it early avoids spamming the provider with token requests. - request.addCookie( - SESSION_COOKIE_NAME, - b"", - path="/_synapse/oidc", - expires="Thu, Jan 01 1970 00:00:00 UTC", - httpOnly=True, - sameSite="lax", - ) - - # Check for the state query parameter - if b"state" not in request.args: - logger.info("State parameter is missing") - self._sso_handler.render_error( - request, "invalid_request", "State parameter is missing" - ) - return - - state = request.args[b"state"][0].decode() - - # Deserialize the session token and verify it. - try: - session_data = self._token_generator.verify_oidc_session_token( - session, state - ) - except MacaroonDeserializationException as e: - logger.exception("Invalid session") - self._sso_handler.render_error(request, "invalid_session", str(e)) - return - except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") - self._sso_handler.render_error(request, "mismatching_session", str(e)) - return - # Exchange the code with the provider - if b"code" not in request.args: - logger.info("Code parameter is missing") - self._sso_handler.render_error( - request, "invalid_request", "Code parameter is missing" - ) - return - - logger.debug("Exchanging code") - code = request.args[b"code"][0].decode() try: + logger.debug("Exchanging code") token = await self._exchange_code(code) except OidcError as e: logger.exception("Could not exchange code") diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 3b756a7dc2..4c06a117d3 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -102,7 +102,6 @@ class MatrixFederationAgent: pool=self._pool, contextFactory=tls_client_options_factory, ), - self._reactor, ip_blacklist=ip_blacklist, ), user_agent=self.user_agent, diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6cfadc2b4e..a19d65ad23 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection # python 3 does not have a maximum int value @@ -412,6 +413,16 @@ class DatabasePool: self._check_safe_to_upsert, ) + # We define this sequence here so that it can be referenced from both + # the DataStore and PersistEventStore. + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self.event_chain_id_gen = build_sequence_generator( + engine, get_chain_id_txn, "event_auth_chain_id" + ) + def is_running(self) -> bool: """Is the database pool currently running """ diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1b6ccd51c8..c128889bf9 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.types import JsonDict from synapse.util import json_encoder @@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): for user_chunk in batch_iter(user_ids, 100): clause, params = make_in_list_sql_clause( - txn.database_engine, "k.user_id", user_chunk - ) - sql = ( - """ - SELECT k.user_id, k.keytype, k.keydata, k.stream_id - FROM e2e_cross_signing_keys k - INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id - FROM e2e_cross_signing_keys - GROUP BY user_id, keytype) s - USING (user_id, stream_id, keytype) - WHERE - """ - + clause + txn.database_engine, "user_id", user_chunk ) + # Fetch the latest key for each type per user. + if isinstance(self.database_engine, PostgresEngine): + # The `DISTINCT ON` clause will pick the *first* row it + # encounters, so ordering by stream ID desc will ensure we get + # the latest key. + sql = """ + SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id + FROM e2e_cross_signing_keys + WHERE %(clause)s + ORDER BY user_id, keytype, stream_id DESC + """ % { + "clause": clause + } + else: + # SQLite has special handling for bare columns when using + # MIN/MAX with a `GROUP BY` clause where it picks the value from + # a row that matches the MIN/MAX. + sql = """ + SELECT user_id, keytype, keydata, MAX(stream_id) + FROM e2e_cross_signing_keys + WHERE %(clause)s + GROUP BY user_id, keytype + """ % { + "clause": clause + } + txn.execute(sql, params) rows = self.db_pool.cursor_to_dict(txn) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 186f064036..3216b3f3c8 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -43,7 +43,6 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.storage.util.sequence import build_sequence_generator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -100,14 +99,6 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() - def get_chain_id_txn(txn): - txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] - - self._event_chain_id_gen = build_sequence_generator( - db.engine, get_chain_id_txn, "event_auth_chain_id" - ) - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -466,9 +457,6 @@ class PersistEventsStore: if not state_events: return - # Map from event ID to chain ID/sequence number. - chain_map = {} # type: Dict[str, Tuple[int, int]] - # We need to know the type/state_key and auth events of the events we're # calculating chain IDs for. We don't rely on having the full Event # instances as we'll potentially be pulling more events from the DB and @@ -479,19 +467,44 @@ class PersistEventsStore: event_to_auth_chain = { e.event_id: e.auth_event_ids() for e in state_events.values() } + event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} + + self._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + @staticmethod + def _add_chain_cover_index( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + ) -> None: + """Calculate the chain cover index for the given events. + + Args: + event_to_room_id: Event ID to the room ID of the event + event_to_types: Event ID to type and state_key of the event + event_to_auth_chain: Event ID to list of auth event IDs of the + event (events with no auth events can be excluded). + """ + + # Map from event ID to chain ID/sequence number. + chain_map = {} # type: Dict[str, Tuple[int, int]] # Set of event IDs to calculate chain ID/seq numbers for. - events_to_calc_chain_id_for = set(state_events) + events_to_calc_chain_id_for = set(event_to_room_id) # We check if there are any events that need to be handled in the rooms # we're looking at. These should just be out of band memberships, where # we didn't have the auth chain when we first persisted. - rows = self.db_pool.simple_select_many_txn( + rows = db_pool.simple_select_many_txn( txn, table="event_auth_chain_to_calculate", keyvalues={}, column="room_id", - iterable={e.room_id for e in state_events.values()}, + iterable=set(event_to_room_id.values()), retcols=("event_id", "type", "state_key"), ) for row in rows: @@ -502,7 +515,7 @@ class PersistEventsStore: # (We could pull out the auth events for all rows at once using # simple_select_many, but this case happens rarely and almost always # with a single row.) - auth_events = self.db_pool.simple_select_onecol_txn( + auth_events = db_pool.simple_select_onecol_txn( txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", ) @@ -551,9 +564,7 @@ class PersistEventsStore: events_to_calc_chain_id_for.add(auth_id) - event_to_auth_chain[ - auth_id - ] = self.db_pool.simple_select_onecol_txn( + event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn( txn, "event_auth", keyvalues={"event_id": auth_id}, @@ -582,16 +593,17 @@ class PersistEventsStore: # the list of events to calculate chain IDs for next time # around. (Otherwise we will have already added it to the # table). - event = state_events.get(event_id) - if event: - self.db_pool.simple_insert_txn( + room_id = event_to_room_id.get(event_id) + if room_id: + e_type, state_key = event_to_types[event_id] + db_pool.simple_insert_txn( txn, table="event_auth_chain_to_calculate", values={ - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, + "event_id": event_id, + "room_id": room_id, + "type": e_type, + "state_key": state_key, }, ) @@ -617,7 +629,7 @@ class PersistEventsStore: events_to_calc_chain_id_for, event_to_auth_chain ): existing_chain_id = None - for auth_id in event_to_auth_chain[event_id]: + for auth_id in event_to_auth_chain.get(event_id, []): if event_to_types.get(event_id) == event_to_types.get(auth_id): existing_chain_id = chain_map[auth_id] break @@ -629,7 +641,7 @@ class PersistEventsStore: proposed_new_id = existing_chain_id[0] proposed_new_seq = existing_chain_id[1] + 1 if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: - already_allocated = self.db_pool.simple_select_one_onecol_txn( + already_allocated = db_pool.simple_select_one_onecol_txn( txn, table="event_auth_chains", keyvalues={ @@ -650,14 +662,14 @@ class PersistEventsStore: ) if not new_chain_tuple: - new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1) + new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) chains_tuples_allocated.add(new_chain_tuple) chain_map[event_id] = new_chain_tuple new_chain_tuples[event_id] = new_chain_tuple - self.db_pool.simple_insert_many_txn( + db_pool.simple_insert_many_txn( txn, table="event_auth_chains", values=[ @@ -666,7 +678,7 @@ class PersistEventsStore: ], ) - self.db_pool.simple_delete_many_txn( + db_pool.simple_delete_many_txn( txn, table="event_auth_chain_to_calculate", keyvalues={}, @@ -699,7 +711,7 @@ class PersistEventsStore: # Step 1, fetch all existing links from all the chains we've seen # referenced. chain_links = _LinkMap() - rows = self.db_pool.simple_select_many_txn( + rows = db_pool.simple_select_many_txn( txn, table="event_auth_chain_links", column="origin_chain_id", @@ -730,11 +742,11 @@ class PersistEventsStore: # auth events (A, B) to check if B is reachable from A. reduction = { a_id - for a_id in event_to_auth_chain[event_id] + for a_id in event_to_auth_chain.get(event_id, []) if chain_map[a_id][0] != chain_id } for start_auth_id, end_auth_id in itertools.permutations( - event_to_auth_chain[event_id], r=2, + event_to_auth_chain.get(event_id, []), r=2, ): if chain_links.exists_path_from( chain_map[start_auth_id], chain_map[end_auth_id] @@ -763,7 +775,7 @@ class PersistEventsStore: (chain_id, sequence_number), (target_id, target_seq) ) - self.db_pool.simple_insert_many_txn( + db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", values=[ diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 7e4b175d08..7128dc1742 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -14,13 +14,14 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventContentFields from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, make_tuple_comparison_clause +from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor from synapse.types import JsonDict @@ -108,6 +109,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): "rejected_events_metadata", self._rejected_events_metadata, ) + self.db_pool.updates.register_background_update_handler( + "chain_cover", self._chain_cover_index, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -706,3 +711,191 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return len(results) + + async def _chain_cover_index(self, progress: dict, batch_size: int) -> int: + """A background updates that iterates over all rooms and generates the + chain cover index for them. + """ + + current_room_id = progress.get("current_room_id", "") + + # Have we finished processing the current room. + finished = progress.get("finished", True) + + # Where we've processed up to in the room, defaults to the start of the + # room. + last_depth = progress.get("last_depth", -1) + last_stream = progress.get("last_stream", -1) + + # Have we set the `has_auth_chain_index` for the room yet. + has_set_room_has_chain_index = progress.get( + "has_set_room_has_chain_index", False + ) + + if finished: + # If we've finished with the previous room (or its our first + # iteration) we move on to the next room. + + def _get_next_room(txn: Cursor) -> Optional[str]: + sql = """ + SELECT room_id FROM rooms + WHERE room_id > ? + AND ( + NOT has_auth_chain_index + OR has_auth_chain_index IS NULL + ) + ORDER BY room_id + LIMIT 1 + """ + txn.execute(sql, (current_room_id,)) + row = txn.fetchone() + if row: + return row[0] + + return None + + current_room_id = await self.db_pool.runInteraction( + "_chain_cover_index", _get_next_room + ) + if not current_room_id: + await self.db_pool.updates._end_background_update("chain_cover") + return 0 + + logger.debug("Adding chain cover to %s", current_room_id) + + def _calculate_auth_chain( + txn: Cursor, last_depth: int, last_stream: int + ) -> Tuple[int, int, int]: + # Get the next set of events in the room (that we haven't already + # computed chain cover for). We do this in topological order. + + # We want to do a `(topological_ordering, stream_ordering) > (?,?)` + # comparison, but that is not supported on older SQLite versions + tuple_clause, tuple_args = make_tuple_comparison_clause( + self.database_engine, + [ + ("topological_ordering", last_depth), + ("stream_ordering", last_stream), + ], + ) + + sql = """ + SELECT + event_id, state_events.type, state_events.state_key, + topological_ordering, stream_ordering + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + LEFT JOIN event_auth_chain_to_calculate USING (event_id) + WHERE events.room_id = ? + AND event_auth_chains.event_id IS NULL + AND event_auth_chain_to_calculate.event_id IS NULL + AND %(tuple_cmp)s + ORDER BY topological_ordering, stream_ordering + LIMIT ? + """ % { + "tuple_cmp": tuple_clause, + } + + args = [current_room_id] + args.extend(tuple_args) + args.append(batch_size) + + txn.execute(sql, args) + rows = txn.fetchall() + + # Put the results in the necessary format for + # `_add_chain_cover_index` + event_to_room_id = {row[0]: current_room_id for row in rows} + event_to_types = {row[0]: (row[1], row[2]) for row in rows} + + new_last_depth = rows[-1][3] if rows else last_depth # type: int + new_last_stream = rows[-1][4] if rows else last_stream # type: int + + count = len(rows) + + # We also need to fetch the auth events for them. + auth_events = self.db_pool.simple_select_many_txn( + txn, + table="event_auth", + column="event_id", + iterable=event_to_room_id, + keyvalues={}, + retcols=("event_id", "auth_id"), + ) + + event_to_auth_chain = {} # type: Dict[str, List[str]] + for row in auth_events: + event_to_auth_chain.setdefault(row["event_id"], []).append( + row["auth_id"] + ) + + # Calculate and persist the chain cover index for this set of events. + # + # Annoyingly we need to gut wrench into the persit event store so that + # we can reuse the function to calculate the chain cover for rooms. + PersistEventsStore._add_chain_cover_index( + txn, + self.db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, + ) + + return new_last_depth, new_last_stream, count + + last_depth, last_stream, count = await self.db_pool.runInteraction( + "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream + ) + + total_rows_processed = count + + if count < batch_size and not has_set_room_has_chain_index: + # If we've done all the events in the room we flip the + # `has_auth_chain_index` in the DB. Note that its possible for + # further events to be persisted between the above and setting the + # flag without having the chain cover calculated for them. This is + # fine as a) the code gracefully handles these cases and b) we'll + # calculate them below. + + await self.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": current_room_id}, + updatevalues={"has_auth_chain_index": True}, + desc="_chain_cover_index", + ) + has_set_room_has_chain_index = True + + # Handle any events that might have raced with us flipping the + # bit above. + last_depth, last_stream, count = await self.db_pool.runInteraction( + "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream + ) + + total_rows_processed += count + + # Note that at this point its technically possible that more events + # than our `batch_size` have been persisted without their chain + # cover, so we need to continue processing this room if the last + # count returned was equal to the `batch_size`. + + if count < batch_size: + # We've finished calculating the index for this room, move on to the + # next room. + await self.db_pool.updates._background_update_progress( + "chain_cover", {"current_room_id": current_room_id, "finished": True}, + ) + else: + # We still have outstanding events to calculate the index for. + await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + "has_auth_chain_index": has_set_room_has_chain_index, + "finished": False, + }, + ) + + return total_rows_processed diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql new file mode 100644 index 0000000000..fe3dca71dd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (5906, 'chain_cover', '{}', 'rejected_events_metadata'); diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 59207cadd4..cea595ff19 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore): txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str] ) -> List[str]: q = """ - SELECT destination FROM destinations - WHERE destination IN ( - SELECT destination FROM destination_rooms - WHERE destination_rooms.stream_ordering > - destinations.last_successful_stream_ordering - ) - AND destination > ? - AND ( - retry_last_ts IS NULL OR - retry_last_ts + retry_interval < ? - ) - ORDER BY destination - LIMIT 25 + SELECT DISTINCT destination FROM destinations + INNER JOIN destination_rooms USING (destination) + WHERE + stream_ordering > last_successful_stream_ordering + AND destination > ? + AND ( + retry_last_ts IS NULL OR + retry_last_ts + retry_interval < ? + ) + ORDER BY destination + LIMIT 25 """ txn.execute( q, diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 4386b6101e..412df6b8ef 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -15,9 +15,8 @@ import abc import logging import threading -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional -from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import ( BaseDatabaseEngine, IncorrectDatabaseSetup, @@ -25,6 +24,9 @@ from synapse.storage.engines import ( ) from synapse.storage.types import Connection, Cursor +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection + logger = logging.getLogger(__name__) @@ -55,7 +57,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta): @abc.abstractmethod def check_consistency( self, - db_conn: LoggingDatabaseConnection, + db_conn: "LoggingDatabaseConnection", table: str, id_column: str, positive: bool = True, @@ -88,7 +90,7 @@ class PostgresSequenceGenerator(SequenceGenerator): def check_consistency( self, - db_conn: LoggingDatabaseConnection, + db_conn: "LoggingDatabaseConnection", table: str, id_column: str, positive: bool = True, diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index f7b4857a84..6ef2b008a4 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -92,7 +92,7 @@ def sorted_topologically( node = heapq.heappop(zero_degree) yield node - for edge in reverse_graph[node]: + for edge in reverse_graph.get(node, []): if edge in degree_map: degree_map[edge] -= 1 if degree_map[edge] == 0: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 2abd7a83b5..5d338bea87 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) self.handler = hs.get_oidc_handler() + self.provider = self.handler._provider sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) @@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def metadata_edit(self, values): - return patch.dict(self.handler._provider_metadata, values) + return patch.dict(self.provider._provider_metadata, values) def assertRenderedError(self, error, error_description=None): + self.render_error.assert_called_once() args = self.render_error.call_args[0] self.assertEqual(args[1], error) if error_description is not None: @@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" - self.assertEqual(self.handler._callback_url, CALLBACK_URL) - self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) - self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + self.assertEqual(self.provider._callback_url, CALLBACK_URL) + self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = self.get_success(self.handler.load_metadata()) + metadata = self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = self.get_success(self.handler.load_jwks()) + jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cached… self.http_client.reset_mock() - self.get_success(self.handler.load_jwks()) + self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_not_called() # …unless forced self.http_client.reset_mock() - self.get_success(self.handler.load_jwks(force=True)) + self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - self.get_failure(self.handler.load_jwks(force=True), RuntimeError) + self.get_failure(self.provider.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used - self.handler._scopes = [] # not asking the openid scope + self.provider._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = self.get_success(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @override_config({"oidc_config": COMMON_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" - h = self.handler + h = self.provider # Default test config does not throw h._validate_metadata() @@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw - self.handler._validate_metadata() + self.provider._validate_metadata() def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) url = self.get_success( - self.handler.handle_redirect_request(req, b"http://client/redirect") + self.provider.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) @@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # ensure that we are correctly testing the fallback when "get_extra_attributes" # is not implemented. - mapping_provider = self.handler._user_mapping_provider + mapping_provider = self.provider._user_mapping_provider with self.assertRaises(AttributeError): _ = mapping_provider.get_extra_attributes @@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.assert_called_once_with( expected_user_id, request, client_redirect_url, None, ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._fetch_userinfo.assert_not_called() + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors with patch.object( - self.handler, + self.provider, "_remote_id_from_userinfo", new=Mock(side_effect=MappingException()), ): @@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.handler._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") auth_handler.complete_sso_login.reset_mock() - self.handler._exchange_code.reset_mock() - self.handler._parse_id_token.reset_mock() - self.handler._fetch_userinfo.reset_mock() + self.provider._exchange_code.reset_mock() + self.provider._parse_id_token.reset_mock() + self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.handler._scopes = [] # do not ask the "openid" scope + self.provider._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( expected_user_id, request, client_redirect_url, None, ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_not_called() - self.handler._fetch_userinfo.assert_called_once_with(token) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_not_called() + self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() # Handle userinfo fetching error - self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc_handler import OidcError - self.handler._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) @@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = self.get_success(self.handler._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) from synapse.handlers.oidc_handler import OidcError - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body @@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field @@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @override_config( @@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "foo", "phone": "1234567", } - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -979,9 +981,10 @@ async def _make_callback_with_userinfo( from synapse.handlers.oidc_handler import OidcSessionData handler = hs.get_oidc_handler() - handler._exchange_code = simple_async_mock(return_value={}) - handler._parse_id_token = simple_async_mock(return_value=userinfo) - handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider = handler._provider + provider._exchange_code = simple_async_mock(return_value={}) + provider._parse_id_token = simple_async_mock(return_value=userinfo) + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) state = "state" session = handler._token_generator.generate_oidc_session_token( diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 83c377824b..ff67a73749 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -20,7 +20,10 @@ from twisted.trial import unittest from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events import _LinkMap +from synapse.types import create_requester from tests.unittest import HomeserverTestCase @@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase): self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)]) self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)]) + + +class EventChainBackgroundUpdateTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_background_update(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + user_id = self.register_user("foo", "pass") + token = self.login("foo", "pass") + room_id = self.helper.create_room_as(user_id, tok=token) + requester = create_requester(user_id) + + store = self.hs.get_datastore() + + # Mark the room as not having a chain cover index + self.get_success( + store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"has_auth_chain_index": False}, + desc="test", + ) + ) + + # Create a fork in the DAG with different events. + event_handler = self.hs.get_event_creation_handler() + latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "some_state_type", + "state_key": "", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + state1 = list(self.get_success(context.get_current_state_ids()).values()) + + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "some_state_type", + "state_key": "", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + state2 = list(self.get_success(context.get_current_state_ids()).values()) + + # Delete the chain cover info. + + def _delete_tables(txn): + txn.execute("DELETE FROM event_auth_chains") + txn.execute("DELETE FROM event_auth_chain_links") + + self.get_success(store.db_pool.runInteraction("test", _delete_tables)) + + # Insert and run the background update. + self.get_success( + store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + store.db_pool.updates._all_done = False + + while not self.get_success( + store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + store.db_pool.runInteraction( + "test", + store._get_auth_chain_difference_using_cover_index_txn, + room_id, + [state1, state2], + ) + ) diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 1184cea5a3..522c8061f9 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -56,6 +56,14 @@ class SortTopologically(TestCase): graph = {} # type: Dict[int, List[int]] self.assertEqual(list(sorted_topologically([], graph)), []) + def test_handle_empty_graph(self): + "Test that a graph where a node doesn't have an entry is treated as empty" + + graph = {} # type: Dict[int, List[int]] + + # For disconnected nodes the output is simply sorted. + self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) + def test_disconnected(self): "Test that a graph with no edges work" |