diff --git a/changelog.d/9573.feature b/changelog.d/9573.feature
new file mode 100644
index 0000000000..5214b50d41
--- /dev/null
+++ b/changelog.d/9573.feature
@@ -0,0 +1 @@
+Add prometheus metrics for number of users successfully registering and logging in.
diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc
new file mode 100644
index 0000000000..bc257d05b7
--- /dev/null
+++ b/changelog.d/9576.misc
@@ -0,0 +1 @@
+Improve efficiency of calculating the auth chain in large rooms.
diff --git a/changelog.d/9580.doc b/changelog.d/9580.doc
new file mode 100644
index 0000000000..f9c8b328b3
--- /dev/null
+++ b/changelog.d/9580.doc
@@ -0,0 +1 @@
+Clarify the spam checker modules documentation example to mention that `parse_config` is a required method.
diff --git a/changelog.d/9586.misc b/changelog.d/9586.misc
new file mode 100644
index 0000000000..2def9d5f55
--- /dev/null
+++ b/changelog.d/9586.misc
@@ -0,0 +1 @@
+Convert `synapse.types.Requester` to an `attrs` class.
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index e615ac9910..2020eb9006 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
+All the methods must be defined.
There's a generic method for checking every event (`check_event_for_spam`), as
well as some specific methods:
@@ -24,6 +25,7 @@ well as some specific methods:
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
+* `check_media_file_for_spam`
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to
call back into the homeserver internals.
+Additionally, a `parse_config` method is mandatory and receives the plugin config
+dictionary. After parsing, It must return an object which will be
+passed to `__init__` later.
+
### Example
```python
@@ -41,6 +47,10 @@ class ExampleSpamChecker:
self.config = config
self.api = api
+ @staticmethod
+ def parse_config(config):
+ return config
+
async def check_event_for_spam(self, foo):
return False # allow all events
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 18aff7af9b..5e8b86bc96 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -448,7 +448,7 @@ class FederationServer(FederationBase):
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
- auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
+ auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
@@ -461,7 +461,9 @@ class FederationServer(FederationBase):
else:
pdus = (await self.state.get_current_state(room_id)).values()
- auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
+ auth_chain = await self.store.get_auth_chain(
+ room_id, [pdu.event_id for pdu in pdus]
+ )
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bec0c615d4..fb5f8118f0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -337,7 +337,8 @@ class AuthHandler(BaseHandler):
user is too high to proceed
"""
-
+ if not requester.access_token_id:
+ raise ValueError("Cannot validate a user without an access token")
if self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
@@ -1213,7 +1214,7 @@ class AuthHandler(BaseHandler):
async def delete_access_tokens_for_user(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6cafb5c227..0f10cc3dc1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1319,7 +1319,7 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)
@@ -1653,7 +1653,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
- auth_chain = await self.store.get_auth_chain(state_ids)
+ auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
state = await self.store.get_events(list(prev_state_ids.values()))
@@ -2413,7 +2413,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
- list(event.auth_event_ids()), include_given=True
+ room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 798c29748f..ac004ca7b9 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -16,7 +16,7 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
@@ -85,6 +85,7 @@ class RegistrationHandler(BaseHandler):
)
else:
self.device_handler = hs.get_device_handler()
+ self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@@ -758,17 +759,35 @@ class RegistrationHandler(BaseHandler):
Returns:
Tuple of device ID and access token
"""
+ res = await self._register_device_client(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ is_guest=is_guest,
+ is_appservice_ghost=is_appservice_ghost,
+ )
- if self.hs.config.worker_app:
- r = await self._register_device_client(
- user_id=user_id,
- device_id=device_id,
- initial_display_name=initial_display_name,
- is_guest=is_guest,
- is_appservice_ghost=is_appservice_ghost,
- )
- return r["device_id"], r["access_token"]
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
+ return res["device_id"], res["access_token"]
+
+ async def register_device_inner(
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ is_appservice_ghost: bool = False,
+ ) -> Dict[str, str]:
+ """Helper for register_device
+ Does the bits that need doing on the main process. Not for use outside this
+ class and RegisterDeviceReplicationServlet.
+ """
+ assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -793,12 +812,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
- login_counter.labels(
- guest=is_guest,
- auth_provider=(auth_provider_id or ""),
- ).inc()
-
- return (registered_device_id, access_token)
+ return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 36071feb36..4ec1bfa6ea 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
- device_id, access_token = await self.registration_handler.register_device(
+ res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
- return 200, {"device_id": device_id, "access_token": access_token}
+ return 200, res
def register_servlets(hs, http_server):
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0641924f18..8b4841ed5d 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
from synapse.config._base import ConfigError
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -145,7 +146,7 @@ class MediaRepository:
upload_name: Optional[str],
content: IO,
content_length: int,
- auth_user: str,
+ auth_user: UserID,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 18ddb92fcc..332193ad1c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
- self, event_ids: Collection[str], include_given: bool = False
+ self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
- event_ids, include_given=include_given
+ room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
+ room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
- An awaitable which resolve to a list of event_ids
+ list of event_ids
"""
+
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_ids_chains",
+ self._get_auth_chain_ids_using_cover_index_txn,
+ room_id,
+ event_ids,
+ include_given,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
+ def _get_auth_chain_ids_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
+ """Calculates the auth chain IDs using the chain index."""
+
+ # First we look up the chain ID/sequence numbers for the given events.
+
+ initial_events = set(event_ids)
+
+ # All the events that we've found that are reachable from the events.
+ seen_events = set() # type: Set[str]
+
+ # A map from chain ID to max sequence number of the given events.
+ event_chains = {} # type: Dict[int, int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ seen_events.add(event_id)
+ event_chains[chain_id] = max(
+ sequence_number, event_chains.get(chain_id, 0)
+ )
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(seen_events)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Now we look up all links for the chains we have, adding chains that
+ # are reachable from any event.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # A map from chain ID to max sequence number *reachable* from any event ID.
+ chains = {} # type: Dict[int, int]
+
+ # Add all linked chains reachable from initial set of chains.
+ for batch in batch_iter(event_chains, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ # chains are only reachable if the origin sequence number of
+ # the link is less than the max sequence number in the
+ # origin chain.
+ if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
+ )
+
+ # Add the initial set of chains, excluding the sequence corresponding to
+ # initial event.
+ for chain_id, seq_no in event_chains.items():
+ chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* event ID. Events with a sequence less than that are in the
+ # auth chain.
+ if include_given:
+ results = initial_events
+ else:
+ results = set()
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND sequence_number <= max_seq
+ """
+
+ rows = txn.execute_values(sql, chains.items())
+ results.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND sequence_number <= ?
+ """
+ for chain_id, max_no in chains.items():
+ txn.execute(sql, (chain_id, max_no))
+ results.update(r for r, in txn)
+
+ return list(results)
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
+ """Calculates the auth chain IDs.
+
+ This is used when we don't have a cover index for the room.
+ """
if include_given:
results = set(event_ids)
else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 2cfa3dface..a7ac68cf7e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
@@ -1614,7 +1614,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
async def user_delete_access_tokens(
self,
user_id: str,
- except_token_id: Optional[str] = None,
+ except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
@@ -1637,7 +1637,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
- values = [v for _, v in items]
+ values = [v for _, v in items] # type: List[Union[str, int]]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)
diff --git a/synapse/types.py b/synapse/types.py
index 6a41a3665d..00655f6dd2 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -84,33 +84,32 @@ class ISynapseReactor(
"""The interfaces necessary for Synapse to function."""
-class Requester(
- namedtuple(
- "Requester",
- [
- "user",
- "access_token_id",
- "is_guest",
- "shadow_banned",
- "device_id",
- "app_service",
- "authenticated_entity",
- ],
- )
-):
+@attr.s(frozen=True, slots=True)
+class Requester:
"""
Represents the user making a request
Attributes:
- user (UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request has been shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request has been shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
+ authenticated_entity: The entity that authenticated when making the request.
+ This is different to the user_id when an admin user or the server is
+ "puppeting" the user.
"""
+ user = attr.ib(type="UserID")
+ access_token_id = attr.ib(type=Optional[int])
+ is_guest = attr.ib(type=bool)
+ shadow_banned = attr.ib(type=bool)
+ device_id = attr.ib(type=Optional[str])
+ app_service = attr.ib(type=Optional["ApplicationService"])
+ authenticated_entity = attr.ib(type=str)
+
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -158,23 +157,23 @@ class Requester(
def create_requester(
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
- is_guest: Optional[bool] = False,
- shadow_banned: Optional[bool] = False,
+ is_guest: bool = False,
+ shadow_banned: bool = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
-):
+) -> Requester:
"""
Create a new ``Requester`` object
Args:
- user_id (str|UserID): id of the user making the request
- access_token_id (int|None): *ID* of the access token used for this
+ user_id: id of the user making the request
+ access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
- is_guest (bool): True if the user making this request is a guest user
- shadow_banned (bool): True if the user making this request is shadow-banned.
- device_id (str|None): device_id which was set at authentication time
- app_service (ApplicationService|None): the AS requesting on behalf of the user
+ is_guest: True if the user making this request is a guest user
+ shadow_banned: True if the user making this request is shadow-banned.
+ device_id: device_id which was set at authentication time
+ app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 06000f81a6..d597d712d6 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- @parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
- # Mark the room as not having a cover index
+ # Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
+ return room_id
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
+ # a and b have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["a", "b"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ # d and e have the same auth chain.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
+ self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
+ self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
+ self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
+ self.assertEqual(auth_chain_ids, ["k"])
+
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
+ self.assertEqual(auth_chain_ids, ["j"])
+
+ # j and k have no parents.
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
+ self.assertEqual(auth_chain_ids, [])
+ auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
+ self.assertEqual(auth_chain_ids, [])
+
+ # More complex input sequences.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["h", "i"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["k", "j"])
+
+ # e gets returned even though include_given is false, but it is in the
+ # auth chain of b.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["b", "e"])
+ )
+ self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
+
+ # Test include_given.
+ auth_chain_ids = self.get_success(
+ self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
+ )
+ self.assertCountEqual(auth_chain_ids, ["i", "j"])
+
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
+ room_id = self._setup_auth_chain(use_chain_cover_index)
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
|