summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-07-01 14:25:37 -0400
committerGitHub <noreply@github.com>2021-07-01 14:25:37 -0400
commit8d609435c0053fc4decbc3f9c3603e728912749c (patch)
tree71ac54e8aaf9a2c810dd374ef5f3e5ecbbd20d73
parentfix ordering of bg update (#10291) (diff)
downloadsynapse-8d609435c0053fc4decbc3f9c3603e728912749c.tar.xz
Move methods involving event authentication to EventAuthHandler. (#10268)
Instead of mixing them with user authentication methods.
Diffstat (limited to '')
-rw-r--r--changelog.d/10268.misc1
-rw-r--r--synapse/api/auth.py75
-rw-r--r--synapse/events/builder.py12
-rw-r--r--synapse/federation/federation_server.py6
-rw-r--r--synapse/handlers/event_auth.py62
-rw-r--r--synapse/handlers/federation.py36
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/handlers/room.py3
-rw-r--r--synapse/handlers/space_summary.py6
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--tests/handlers/test_presence.py4
11 files changed, 112 insertions, 106 deletions
diff --git a/changelog.d/10268.misc b/changelog.d/10268.misc
new file mode 100644
index 0000000000..9e3f60c72f
--- /dev/null
+++ b/changelog.d/10268.misc
@@ -0,0 +1 @@
+Move event authentication methods from `Auth` to `EventAuthHandler`.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f8b068e563..307f5f9a94 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import pymacaroons
 from netaddr import IPAddress
@@ -28,10 +28,8 @@ from synapse.api.errors import (
     InvalidClientTokenError,
     MissingClientTokenError,
 )
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
-from synapse.events.builder import EventBuilder
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
@@ -39,7 +37,6 @@ from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
-from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -47,15 +44,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-AuthEventTypes = (
-    EventTypes.Create,
-    EventTypes.Member,
-    EventTypes.PowerLevels,
-    EventTypes.JoinRules,
-    EventTypes.RoomHistoryVisibility,
-    EventTypes.ThirdPartyInvite,
-)
-
 # guests always get this device id.
 GUEST_DEVICE_ID = "guest_device"
 
@@ -66,9 +54,7 @@ class _InvalidMacaroonException(Exception):
 
 class Auth:
     """
-    FIXME: This class contains a mix of functions for authenticating users
-    of our client-server API and authenticating events added to room graphs.
-    The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
+    This class contains functions for authenticating users of our client-server API.
     """
 
     def __init__(self, hs: "HomeServer"):
@@ -90,18 +76,6 @@ class Auth:
         self._macaroon_secret_key = hs.config.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
 
-    async def check_from_context(
-        self, room_version: str, event, context, do_sig_check=True
-    ) -> None:
-        auth_event_ids = event.auth_event_ids()
-        auth_events_by_id = await self.store.get_events(auth_event_ids)
-        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
-
-        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-        event_auth.check(
-            room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
-        )
-
     async def check_user_in_room(
         self,
         room_id: str,
@@ -152,13 +126,6 @@ class Auth:
 
         raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
-    async def check_host_in_room(self, room_id: str, host: str) -> bool:
-        with Measure(self.clock, "check_host_in_room"):
-            return await self.store.is_host_joined(room_id, host)
-
-    def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
-        return event_auth.get_public_keys(invite_event)
-
     async def get_user_by_req(
         self,
         request: SynapseRequest,
@@ -489,44 +456,6 @@ class Auth:
         """
         return await self.store.is_server_admin(user)
 
-    def compute_auth_events(
-        self,
-        event: Union[EventBase, EventBuilder],
-        current_state_ids: StateMap[str],
-        for_verification: bool = False,
-    ) -> List[str]:
-        """Given an event and current state return the list of event IDs used
-        to auth an event.
-
-        If `for_verification` is False then only return auth events that
-        should be added to the event's `auth_events`.
-
-        Returns:
-            List of event IDs.
-        """
-
-        if event.type == EventTypes.Create:
-            return []
-
-        # Currently we ignore the `for_verification` flag even though there are
-        # some situations where we can drop particular auth events when adding
-        # to the event's `auth_events` (e.g. joins pointing to previous joins
-        # when room is publicly joinable). Dropping event IDs has the
-        # advantage that the auth chain for the room grows slower, but we use
-        # the auth chain in state resolution v2 to order events, which means
-        # care must be taken if dropping events to ensure that it doesn't
-        # introduce undesirable "state reset" behaviour.
-        #
-        # All of which sounds a bit tricky so we don't bother for now.
-
-        auth_ids = []
-        for etype, state_key in event_auth.auth_types_for_event(event):
-            auth_ev_id = current_state_ids.get((etype, state_key))
-            if auth_ev_id:
-                auth_ids.append(auth_ev_id)
-
-        return auth_ids
-
     async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index fb48ec8541..26e3950859 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
-    from synapse.api.auth import Auth
+    from synapse.handlers.event_auth import EventAuthHandler
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ class EventBuilder:
     """
 
     _state: StateHandler
-    _auth: "Auth"
+    _event_auth_handler: "EventAuthHandler"
     _store: DataStore
     _clock: Clock
     _hostname: str
@@ -125,7 +125,9 @@ class EventBuilder:
             state_ids = await self._state.get_current_state_ids(
                 self.room_id, prev_event_ids
             )
-            auth_event_ids = self._auth.compute_auth_events(self, state_ids)
+            auth_event_ids = self._event_auth_handler.compute_auth_events(
+                self, state_ids
+            )
 
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
@@ -193,7 +195,7 @@ class EventBuilderFactory:
 
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
-        self.auth = hs.get_auth()
+        self._event_auth_handler = hs.get_event_auth_handler()
 
     def new(self, room_version: str, key_values: dict) -> EventBuilder:
         """Generate an event builder appropriate for the given room version
@@ -229,7 +231,7 @@ class EventBuilderFactory:
         return EventBuilder(
             store=self.store,
             state=self.state,
-            auth=self.auth,
+            event_auth_handler=self._event_auth_handler,
             clock=self.clock,
             hostname=self.hostname,
             signing_key=self.signing_key,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e93b7577fe..b312d0b809 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -108,9 +108,9 @@ class FederationServer(FederationBase):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.auth = hs.get_auth()
         self.handler = hs.get_federation_handler()
         self.state = hs.get_state_handler()
+        self._event_auth_handler = hs.get_event_auth_handler()
 
         self.device_handler = hs.get_device_handler()
 
@@ -420,7 +420,7 @@ class FederationServer(FederationBase):
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -453,7 +453,7 @@ class FederationServer(FederationBase):
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, room_id)
 
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 989996b628..41dbdfd0a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Collection, Optional
+from typing import TYPE_CHECKING, Collection, List, Optional, Union
 
+from synapse import event_auth
 from synapse.api.constants import (
     EventTypes,
     JoinRules,
@@ -20,9 +21,11 @@ from synapse.api.constants import (
     RestrictedJoinRuleTypes,
 )
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
 from synapse.types import StateMap
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -34,8 +37,63 @@ class EventAuthHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
         self._store = hs.get_datastore()
 
+    async def check_from_context(
+        self, room_version: str, event, context, do_sig_check=True
+    ) -> None:
+        auth_event_ids = event.auth_event_ids()
+        auth_events_by_id = await self._store.get_events(auth_event_ids)
+        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
+
+        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+        event_auth.check(
+            room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
+        )
+
+    def compute_auth_events(
+        self,
+        event: Union[EventBase, EventBuilder],
+        current_state_ids: StateMap[str],
+        for_verification: bool = False,
+    ) -> List[str]:
+        """Given an event and current state return the list of event IDs used
+        to auth an event.
+
+        If `for_verification` is False then only return auth events that
+        should be added to the event's `auth_events`.
+
+        Returns:
+            List of event IDs.
+        """
+
+        if event.type == EventTypes.Create:
+            return []
+
+        # Currently we ignore the `for_verification` flag even though there are
+        # some situations where we can drop particular auth events when adding
+        # to the event's `auth_events` (e.g. joins pointing to previous joins
+        # when room is publicly joinable). Dropping event IDs has the
+        # advantage that the auth chain for the room grows slower, but we use
+        # the auth chain in state resolution v2 to order events, which means
+        # care must be taken if dropping events to ensure that it doesn't
+        # introduce undesirable "state reset" behaviour.
+        #
+        # All of which sounds a bit tricky so we don't bother for now.
+
+        auth_ids = []
+        for etype, state_key in event_auth.auth_types_for_event(event):
+            auth_ev_id = current_state_ids.get((etype, state_key))
+            if auth_ev_id:
+                auth_ids.append(auth_ev_id)
+
+        return auth_ids
+
+    async def check_host_in_room(self, room_id: str, host: str) -> bool:
+        with Measure(self._clock, "check_host_in_room"):
+            return await self._store.is_host_joined(room_id, host)
+
     async def check_restricted_join_rules(
         self,
         state_ids: StateMap[str],
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d929c65131..991ec9919a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -250,7 +250,9 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self._event_auth_handler.check_host_in_room(
+            room_id, self.server_name
+        )
         if not is_in_room:
             logger.info(
                 "Ignoring PDU from %s as we're not in the room",
@@ -1674,7 +1676,9 @@ class FederationHandler(BaseHandler):
         room_version = await self.store.get_room_version_id(room_id)
 
         # now check that we are *still* in the room
-        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self._event_auth_handler.check_host_in_room(
+            room_id, self.server_name
+        )
         if not is_in_room:
             logger.info(
                 "Got /make_join request for room %s we are no longer in",
@@ -1705,7 +1709,7 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        await self.auth.check_from_context(
+        await self._event_auth_handler.check_from_context(
             room_version, event, context, do_sig_check=False
         )
 
@@ -1877,7 +1881,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            await self.auth.check_from_context(
+            await self._event_auth_handler.check_from_context(
                 room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
@@ -1939,7 +1943,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_knock_request`
-            await self.auth.check_from_context(
+            await self._event_auth_handler.check_from_context(
                 room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
@@ -2111,7 +2115,7 @@ class FederationHandler(BaseHandler):
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int
     ) -> List[EventBase]:
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -2146,7 +2150,9 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            in_room = await self.auth.check_host_in_room(event.room_id, origin)
+            in_room = await self._event_auth_handler.check_host_in_room(
+                event.room_id, origin
+            )
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
@@ -2499,7 +2505,7 @@ class FederationHandler(BaseHandler):
         latest_events: List[str],
         limit: int,
     ) -> List[EventBase]:
-        in_room = await self.auth.check_host_in_room(room_id, origin)
+        in_room = await self._event_auth_handler.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -2562,7 +2568,7 @@ class FederationHandler(BaseHandler):
 
         if not auth_events:
             prev_state_ids = await context.get_prev_state_ids()
-            auth_events_ids = self.auth.compute_auth_events(
+            auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
             auth_events_x = await self.store.get_events(auth_events_ids)
@@ -2991,7 +2997,7 @@ class FederationHandler(BaseHandler):
             "state_key": target_user_id,
         }
 
-        if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+        if await self._event_auth_handler.check_host_in_room(room_id, self.hs.hostname):
             room_version = await self.store.get_room_version_id(room_id)
             builder = self.event_builder_factory.new(room_version, event_dict)
 
@@ -3011,7 +3017,9 @@ class FederationHandler(BaseHandler):
             event.internal_metadata.send_on_behalf_of = self.hs.hostname
 
             try:
-                await self.auth.check_from_context(room_version, event, context)
+                await self._event_auth_handler.check_from_context(
+                    room_version, event, context
+                )
             except AuthError as e:
                 logger.warning("Denying new third party invite %r because %s", event, e)
                 raise e
@@ -3054,7 +3062,9 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            await self.auth.check_from_context(room_version, event, context)
+            await self._event_auth_handler.check_from_context(
+                room_version, event, context
+            )
         except AuthError as e:
             logger.warning("Denying third party invite %r because %s", event, e)
             raise e
@@ -3142,7 +3152,7 @@ class FederationHandler(BaseHandler):
         last_exception = None  # type: Optional[Exception]
 
         # for each public key in the 3pid invite event
-        for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
+        for public_key_object in event_auth.get_public_keys(invite_event):
             try:
                 # for each sig on the third_party_invite block of the actual invite
                 for server, signature_block in signed["signatures"].items():
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 364c5cd2d3..66e40a915d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -385,6 +385,7 @@ class EventCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
+        self._event_auth_handler = hs.get_event_auth_handler()
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state = hs.get_state_handler()
@@ -597,7 +598,7 @@ class EventCreationHandler:
                 (e.type, e.state_key): e.event_id for e in auth_events
             }
             # Actually strip down and use the necessary auth events
-            auth_event_ids = self.auth.compute_auth_events(
+            auth_event_ids = self._event_auth_handler.compute_auth_events(
                 event=temp_event,
                 current_state_ids=auth_event_state_map,
                 for_verification=False,
@@ -1056,7 +1057,9 @@ class EventCreationHandler:
             assert event.content["membership"] == Membership.LEAVE
         else:
             try:
-                await self.auth.check_from_context(room_version, event, context)
+                await self._event_auth_handler.check_from_context(
+                    room_version, event, context
+                )
             except AuthError as err:
                 logger.warning("Denying new event %r because %s", event, err)
                 raise err
@@ -1381,7 +1384,7 @@ class EventCreationHandler:
                     raise AuthError(403, "Redacting server ACL events is not permitted")
 
             prev_state_ids = await context.get_prev_state_ids()
-            auth_events_ids = self.auth.compute_auth_events(
+            auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
             auth_events_map = await self.store.get_events(auth_events_ids)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 835d874cee..579b1b93c5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -83,6 +83,7 @@ class RoomCreationHandler(BaseHandler):
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
+        self._event_auth_handler = hs.get_event_auth_handler()
         self.config = hs.config
 
         # Room state based off defined presets
@@ -226,7 +227,7 @@ class RoomCreationHandler(BaseHandler):
             },
         )
         old_room_version = await self.store.get_room_version_id(old_room_id)
-        await self.auth.check_from_context(
+        await self._event_auth_handler.check_from_context(
             old_room_version, tombstone_event, tombstone_context
         )
 
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index 266f369883..b585057ec3 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -472,7 +472,7 @@ class SpaceSummaryHandler:
         # If this is a request over federation, check if the host is in the room or
         # is in one of the spaces specified via the join rules.
         elif origin:
-            if await self._auth.check_host_in_room(room_id, origin):
+            if await self._event_auth_handler.check_host_in_room(room_id, origin):
                 return True
 
             # Alternately, if the host has a user in any of the spaces specified
@@ -485,7 +485,9 @@ class SpaceSummaryHandler:
                     await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
                 )
                 for space_id in allowed_rooms:
-                    if await self._auth.check_host_in_room(space_id, origin):
+                    if await self._event_auth_handler.check_host_in_room(
+                        space_id, origin
+                    ):
                         return True
 
         # otherwise, check if the room is peekable
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 350646f458..669ea462e2 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -104,7 +104,7 @@ class BulkPushRuleEvaluator:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
-        self.auth = hs.get_auth()
+        self._event_auth_handler = hs.get_event_auth_handler()
 
         # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
         # time. Keyed off room_id.
@@ -172,7 +172,7 @@ class BulkPushRuleEvaluator:
             # not having a power level event is an extreme edge case
             auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
         else:
-            auth_events_ids = self.auth.compute_auth_events(
+            auth_events_ids = self._event_auth_handler.compute_auth_events(
                 event, prev_state_ids, for_verification=False
             )
             auth_events_dict = await self.store.get_events(auth_events_ids)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index dfb9b3a0fa..18e92e90d7 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -734,7 +734,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
-        self.auth = hs.get_auth()
+        self._event_auth_handler = hs.get_event_auth_handler()
 
         # We don't actually check signatures in tests, so lets just create a
         # random key to use.
@@ -846,7 +846,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
         builder = EventBuilder(
             state=self.state,
-            auth=self.auth,
+            event_auth_handler=self._event_auth_handler,
             store=self.store,
             clock=self.clock,
             hostname=hostname,