From d59378d86b04b396ff493b45bd2b9b68513038e0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 15:17:13 +0100 Subject: Remove redundant calls to third_party_rules in `on_send_{join,leave}` There's not much point in calling these *after* we have decided to accept them into the DAG. --- synapse/handlers/federation.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 455acd7669..c38cb7a5c8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1567,15 +1567,6 @@ class FederationHandler(BaseHandler): context = await self._handle_new_event(origin, event) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of join %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", event.event_id, @@ -1789,16 +1780,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = await self._handle_new_event(origin, event) - - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of leave %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) + await self._handle_new_event(origin, event) logger.debug( "on_send_leave_request: After _handle_new_event: %s, sigs: %s", -- cgit 1.5.1 From 123711ed198bd5cf9984818f8bac1926ed1af5fa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 15:44:54 +0100 Subject: Move third_party_rules check to event creation time Rather than waiting until we handle the event, call the ThirdPartyRules check when we fist create the event. --- synapse/handlers/federation.py | 46 ++---------------------------------------- synapse/handlers/message.py | 19 +++++++++-------- 2 files changed, 13 insertions(+), 52 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c38cb7a5c8..fde8f00531 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1507,18 +1507,9 @@ class FederationHandler(BaseHandler): event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - except AuthError as e: + except SynapseError as e: logger.warning("Failed to create join to %s because %s", room_id, e) - raise e - - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Creation of join %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) + raise # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1739,15 +1730,6 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.warning("Creation of leave %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` @@ -2676,18 +2658,6 @@ class FederationHandler(BaseHandler): builder=builder ) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info( - "Creation of threepid invite %s forbidden by third-party rules", - event, - ) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) @@ -2738,18 +2708,6 @@ class FederationHandler(BaseHandler): event, context = await self.event_creation_handler.create_new_client_event( builder=builder ) - - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.warning( - "Exchange of threepid invite %s forbidden by third-party rules", event - ) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - event, context = await self.add_display_name_to_third_party_invite( room_version, event_dict, event, context ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c52e6824d3..987c759791 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -795,6 +795,17 @@ class EventCreationHandler: if requester: context.app_service = requester.app_service + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info( + "Event %s forbidden by third-party rules", event, + ) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender @@ -881,14 +892,6 @@ class EventCreationHandler: else: room_version = await self.store.get_room_version_id(event.room_id) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - if event.internal_metadata.is_out_of_band_membership(): # the only sort of out-of-band-membership events we expect to see here # are invite rejections we have generated ourselves. -- cgit 1.5.1 From d9d86c29965ceaf92927f061375a0d7a888aea67 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 23:06:36 +0100 Subject: Remove redundant `token_id` parameter to create_event this is always the same as requester.access_token_id. --- synapse/handlers/message.py | 8 +++----- synapse/handlers/room.py | 1 - synapse/handlers/room_member.py | 1 - tests/handlers/test_message.py | 1 - 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c52e6824d3..b8f541425b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -437,7 +437,6 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, - token_id: Optional[str] = None, txn_id: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, @@ -453,7 +452,6 @@ class EventCreationHandler: Args: requester event_dict: An entire event - token_id txn_id prev_event_ids: the forward extremities to use as the prev_events for the @@ -511,8 +509,8 @@ class EventCreationHandler: if require_consent and not is_exempt: await self.assert_accepted_privacy_policy(requester) - if token_id is not None: - builder.internal_metadata.token_id = token_id + if requester.access_token_id is not None: + builder.internal_metadata.token_id = requester.access_token_id if txn_id is not None: builder.internal_metadata.txn_id = txn_id @@ -726,7 +724,7 @@ class EventCreationHandler: return event, event.internal_metadata.stream_ordering event, context = await self.create_event( - requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id + requester, event_dict, txn_id=txn_id ) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 93ed51063a..ec300d8877 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler): "replacement_room": new_room_id, }, }, - token_id=requester.access_token_id, ) old_room_version = await self.store.get_room_version_id(old_room_id) await self.auth.check_from_context( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0080eeaf8d..d9b98d8acd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,7 +193,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # For backwards compatibility: "membership": membership, }, - token_id=requester.access_token_id, txn_id=txn_id, prev_event_ids=prev_event_ids, require_consent=require_consent, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 64e28bc639..9f6f21a6e2 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase): "sender": self.requester.user.to_string(), "content": {"msgtype": "m.text", "body": random_string(5)}, }, - token_id=self.token_id, txn_id=txn_id, ) ) -- cgit 1.5.1 From 617e8a46538af56bd5abbb2c9e4df8025841c338 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 18:53:56 +0100 Subject: Allow ThirdPartyRules modules to replace event content Support returning a new event dict from `check_event_allowed`. --- synapse/events/third_party_rules.py | 12 ++++-- synapse/handlers/message.py | 64 ++++++++++++++++++++++++++++- tests/rest/client/test_third_party_rules.py | 8 ++-- 3 files changed, 75 insertions(+), 9 deletions(-) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 1535cc5339..a9aabe00df 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable + +from typing import Callable, Union from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -44,15 +45,20 @@ class ThirdPartyEventRules: async def check_event_allowed( self, event: EventBase, context: EventContext - ) -> bool: + ) -> Union[bool, dict]: """Check if a provided event should be allowed in the given context. + The module can return: + * True: the event is allowed. + * False: the event is not allowed, and should be rejected with M_FORBIDDEN. + * a dict: replacement event data. + Args: event: The event to be checked. context: The context of the event. Returns: - True if the event should be allowed, False if not. + The result from the ThirdPartyRules module, as above """ if self.third_party_rules is None: return True diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 987c759791..0c6aec347e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -795,16 +795,22 @@ class EventCreationHandler: if requester: context.app_service = requester.app_service - event_allowed = await self.third_party_event_rules.check_event_allowed( + third_party_result = await self.third_party_event_rules.check_event_allowed( event, context ) - if not event_allowed: + if not third_party_result: logger.info( "Event %s forbidden by third-party rules", event, ) raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) + elif isinstance(third_party_result, dict): + # the third-party rules want to replace the event. We'll need to build a new + # event. + event, context = await self._rebuild_event_after_third_party_rules( + third_party_result, event + ) self.validator.validate_new(event, self.config) @@ -1294,3 +1300,57 @@ class EventCreationHandler: room_id, ) del self._rooms_to_exclude_from_dummy_event_insertion[room_id] + + async def _rebuild_event_after_third_party_rules( + self, third_party_result: dict, original_event: EventBase + ) -> Tuple[EventBase, EventContext]: + # the third_party_event_rules want to replace the event. + # we do some basic checks, and then return the replacement event and context. + + # Construct a new EventBuilder and validate it, which helps with the + # rest of these checks. + try: + builder = self.event_builder_factory.for_room_version( + original_event.room_version, third_party_result + ) + self.validator.validate_builder(builder) + except SynapseError as e: + raise Exception( + "Third party rules module created an invalid event: " + e.msg, + ) + + immutable_fields = [ + # changing the room is going to break things: we've already checked that the + # room exists, and are holding a concurrency limiter token for that room. + # Also, we might need to use a different room version. + "room_id", + # changing the type or state key might work, but we'd need to check that the + # calling functions aren't making assumptions about them. + "type", + "state_key", + ] + + for k in immutable_fields: + if getattr(builder, k, None) != original_event.get(k): + raise Exception( + "Third party rules module created an invalid event: " + "cannot change field " + k + ) + + # check that the new sender belongs to this HS + if not self.hs.is_mine_id(builder.sender): + raise Exception( + "Third party rules module created an invalid event: " + "invalid sender " + builder.sender + ) + + # copy over the original internal metadata + for k, v in original_event.internal_metadata.get_dict().items(): + setattr(builder.internal_metadata, k, v) + + event = await builder.build(prev_event_ids=original_event.prev_event_ids()) + + # we rebuild the event context, to be on the safe side. If nothing else, + # delta_ids might need an update. + context = await self.state.compute_event_context(event) + return event, context diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index b737625e33..d404550800 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -115,12 +115,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) def test_modify_event(self): - """Tests that the module can successfully tweak an event before it is persisted. - """ + """The module can return a modified version of the event""" # first patch the event checker so that it will modify the event async def check(ev: EventBase, state): - ev.content = {"x": "y"} - return True + d = ev.get_dict() + d["content"] = {"x": "y"} + return d current_rules_module().check_event_allowed = check -- cgit 1.5.1 From 898196f1cca419c0d2b60529c86ddff3cea83072 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 22:02:41 +0100 Subject: guard against accidental modification --- synapse/events/__init__.py | 6 ++++++ synapse/events/third_party_rules.py | 7 ++++--- tests/rest/client/test_third_party_rules.py | 20 ++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 7a51d0a22f..65df62107f 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -312,6 +312,12 @@ class EventBase(metaclass=abc.ABCMeta): """ return [e for e, _ in self.auth_events] + def freeze(self): + """'Freeze' the event dict, so it cannot be modified by accident""" + + # this will be a no-op if the event dict is already frozen. + self._dict = freeze(self._dict) + class FrozenEvent(EventBase): format_version = EventFormatVersions.V1 # All events of this type are V1 diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index a9aabe00df..77fbd3f68a 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -69,9 +69,10 @@ class ThirdPartyEventRules: events = await self.store.get_events(prev_state_ids.values()) state_events = {(ev.type, ev.state_key): ev for ev in events.values()} - # The module can modify the event slightly if it wants, but caution should be - # exercised, and it's likely to go very wrong if applied to events received over - # federation. + # Ensure that the event is frozen, to make sure that the module is not tempted + # to try to modify it. Any attempt to modify it at this point will invalidate + # the hashes and signatures. + event.freeze() return await self.third_party_rules.check_event_allowed(event, state_events) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index d404550800..0048bea54a 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -114,6 +114,26 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) + def test_cannot_modify_event(self): + """cannot accidentally modify an event before it is persisted""" + + # first patch the event checker so that it will try to modify the event + async def check(ev: EventBase, state): + ev.content = {"x": "y"} + return True + + current_rules_module().check_event_allowed = check + + # now send the event + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + {"x": "x"}, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.result["code"], b"500", channel.result) + def test_modify_event(self): """The module can return a modified version of the event""" # first patch the event checker so that it will modify the event -- cgit 1.5.1 From 091e9482af6171c19f4b967d3a973ec6c2088ab7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 18:57:37 +0100 Subject: changelog --- changelog.d/8535.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8535.feature diff --git a/changelog.d/8535.feature b/changelog.d/8535.feature new file mode 100644 index 0000000000..45342e66ad --- /dev/null +++ b/changelog.d/8535.feature @@ -0,0 +1 @@ +Support modifying event content in `ThirdPartyRules` modules. -- cgit 1.5.1 From a34b17e492129adba260915de401ab1cbabf6215 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 23:14:35 +0100 Subject: Simplify `_locally_reject_invite` Update `EventCreationHandler.create_event` to accept an auth_events param, and use it in `_locally_reject_invite` instead of reinventing the wheel. --- synapse/events/builder.py | 21 ++++---- synapse/handlers/message.py | 22 ++++++++- synapse/handlers/room_member.py | 58 ++++++----------------- tests/handlers/test_presence.py | 2 +- tests/replication/test_federation_sender_shard.py | 2 +- tests/storage/test_redaction.py | 4 +- 6 files changed, 52 insertions(+), 57 deletions(-) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index b6c47be646..df4f950fec 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -97,32 +97,37 @@ class EventBuilder: def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids: List[str]) -> EventBase: + async def build( + self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]] + ) -> EventBase: """Transform into a fully signed and hashed event Args: prev_event_ids: The event IDs to use as the prev events + auth_event_ids: The event IDs to use as the auth events. + Should normally be set to None, which will cause them to be calculated + based on the room state at the prev_events. Returns: The signed and hashed event. """ - - state_ids = await self._state.get_current_state_ids( - self.room_id, prev_event_ids - ) - auth_ids = self._auth.compute_auth_events(self, state_ids) + if auth_event_ids is None: + state_ids = await self._state.get_current_state_ids( + self.room_id, prev_event_ids + ) + auth_event_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: # The types of auth/prev events changes between event versions. auth_events = await self._store.add_event_hashes( - auth_ids + auth_event_ids ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] prev_events = await self._store.add_event_hashes( prev_event_ids ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: - auth_events = auth_ids + auth_events = auth_event_ids prev_events = prev_event_ids old_depth = await self._store.get_max_depth_of(prev_event_ids) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b8f541425b..f18f882596 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -439,6 +439,7 @@ class EventCreationHandler: event_dict: dict, txn_id: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -458,6 +459,12 @@ class EventCreationHandler: new event. If None, they will be requested from the database. + + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + require_consent: Whether to check if the requester has consented to the privacy policy. Raises: @@ -516,7 +523,10 @@ class EventCreationHandler: builder.internal_metadata.txn_id = txn_id event, context = await self.create_new_client_event( - builder=builder, requester=requester, prev_event_ids=prev_event_ids, + builder=builder, + requester=requester, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -755,6 +765,7 @@ class EventCreationHandler: builder: EventBuilder, requester: Optional[Requester] = None, prev_event_ids: Optional[List[str]] = None, + auth_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client @@ -767,6 +778,11 @@ class EventCreationHandler: If None, they will be requested from the database. + auth_event_ids: + The event ids to use as the auth_events for the new event. + Should normally be left as None, which will cause them to be calculated + based on the room state at the prev_events. + Returns: Tuple of created event, context """ @@ -788,7 +804,9 @@ class EventCreationHandler: builder.type == EventTypes.Create or len(prev_event_ids) > 0 ), "Attempting to create an event with no prev_events" - event = await builder.build(prev_event_ids=prev_event_ids) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids + ) context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d9b98d8acd..ec784030e9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,12 +17,10 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union - -from unpaddedbase64 import encode_base64 +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types -from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -31,12 +29,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.api.room_versions import EventFormatVersions -from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase -from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext -from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer @@ -1132,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id = invite_event.room_id target_user = invite_event.state_key - room_version = await self.store.get_room_version(room_id) content["membership"] = Membership.LEAVE - # the auth events for the new event are the same as that of the invite, plus - # the invite itself. - # - # the prev_events are just the invite. - invite_hash = invite_event.event_id # type: Union[str, Tuple] - if room_version.event_format == EventFormatVersions.V1: - alg, h = compute_event_reference_hash(invite_event) - invite_hash = (invite_event.event_id, {alg: encode_base64(h)}) - - auth_events = tuple(invite_event.auth_events) + (invite_hash,) - prev_events = (invite_hash,) - - # we cap depth of generated events, to ensure that they are not - # rejected by other servers (and so that they can be persisted in - # the db) - depth = min(invite_event.depth + 1, MAX_DEPTH) - event_dict = { - "depth": depth, - "auth_events": auth_events, - "prev_events": prev_events, "type": EventTypes.Member, "room_id": room_id, "sender": target_user, @@ -1164,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler): "state_key": target_user, } - event = create_local_event_from_event_dict( - clock=self.clock, - hostname=self.hs.hostname, - signing_key=self.hs.signing_key, - room_version=room_version, - event_dict=event_dict, + # the auth events for the new event are the same as that of the invite, plus + # the invite itself. + # + # the prev_events are just the invite. + prev_event_ids = [invite_event.event_id] + auth_event_ids = invite_event.auth_event_ids() + prev_event_ids + + event, context = await self.event_creation_handler.create_event( + requester, + event_dict, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, ) event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True - if txn_id is not None: - event.internal_metadata.txn_id = txn_id - if requester.access_token_id is not None: - event.internal_metadata.token_id = requester.access_token_id - - EventValidator().validate_new(event, self.config) - context = await self.state_handler.compute_event_context(event) - context.app_service = requester.app_service result_event = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[UserID.from_string(target_user)], ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 914c82e7a8..8ed67640f8 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.store.get_latest_event_ids_in_room(room_id) ) - event = self.get_success(builder.build(prev_event_ids)) + event = self.get_success(builder.build(prev_event_ids, None)) self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 9c4a9c3563..779745ae9d 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): } builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success(builder.build(prev_event_ids)) + join_event = self.get_success(builder.build(prev_event_ids, None)) self.get_success(federation.on_send_join_request(remote_server, join_event)) self.replicate() diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 1ea35d60c1..d4f9e809db 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self._event_id = event_id @defer.inlineCallbacks - def build(self, prev_event_ids): + def build(self, prev_event_ids, auth_event_ids): built_event = yield defer.ensureDeferred( - self._base_builder.build(prev_event_ids) + self._base_builder.build(prev_event_ids, auth_event_ids) ) built_event._event_id = self._event_id -- cgit 1.5.1 From d9dc6185d3596e2243dcbf257909f8cdb722e622 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 13 Oct 2020 23:19:49 +0100 Subject: changelog --- changelog.d/8537.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8537.misc diff --git a/changelog.d/8537.misc b/changelog.d/8537.misc new file mode 100644 index 0000000000..26309b5b93 --- /dev/null +++ b/changelog.d/8537.misc @@ -0,0 +1 @@ +Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`. -- cgit 1.5.1 From 1264c8ac89e5bfa757acee8d26437f14621206a4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 14 Oct 2020 13:53:20 +0100 Subject: Add basic tests for sync/pagination with vector clock tokens. (#8488) These are tests for #8439 --- changelog.d/8488.misc | 1 + tests/replication/test_sharded_event_persister.py | 217 ++++++++++++++++++++++ tests/unittest.py | 32 +++- 3 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8488.misc diff --git a/changelog.d/8488.misc b/changelog.d/8488.misc new file mode 100644 index 0000000000..237cb3b311 --- /dev/null +++ b/changelog.d/8488.misc @@ -0,0 +1 @@ +Allow events to be sent to clients sooner when using sharded event persisters. diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 6068d14905..82cf033d4e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -14,8 +14,12 @@ # limitations under the License. import logging +from mock import patch + +from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import sync from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.utils import USE_POSTGRES_FOR_TESTS @@ -36,6 +40,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, + sync.register_servlets, ] def prepare(self, reactor, clock, hs): @@ -43,6 +48,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.other_user_id = self.register_user("otheruser", "pass") self.other_access_token = self.login("otheruser", "pass") + self.room_creator = self.hs.get_room_creation_handler() + self.store = hs.get_datastore() + def default_config(self): conf = super().default_config() conf["redis"] = {"enabled": "true"} @@ -53,6 +61,29 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): } return conf + def _create_room(self, room_id: str, user_id: str, tok: str): + """Create a room with given room_id + """ + + # We control the room ID generation by patching out the + # `_generate_room_id` method + async def generate_room( + creator_id: str, is_public: bool, room_version: RoomVersion + ): + await self.store.store_room( + room_id=room_id, + room_creator_user_id=creator_id, + is_public=is_public, + room_version=room_version, + ) + return room_id + + with patch( + "synapse.handlers.room.RoomCreationHandler._generate_room_id" + ) as mock: + mock.side_effect = generate_room + self.helper.create_room_as(user_id, tok=tok) + def test_basic(self): """Simple test to ensure that multiple rooms can be created and joined, and that different rooms get handled by different instances. @@ -100,3 +131,189 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.assertTrue(persisted_on_1) self.assertTrue(persisted_on_2) + + def test_vector_clock_token(self): + """Tests that using a stream token with a vector clock component works + correctly with basic /sync and /messages usage. + """ + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker1"}, + ) + + worker_hs2 = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker2"}, + ) + + sync_hs = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "sync"}, + ) + + # Specially selected room IDs that get persisted on different workers. + room_id1 = "!foo:test" + room_id2 = "!baz:test" + + self.assertEqual( + self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1" + ) + self.assertEqual( + self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2" + ) + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + store = self.hs.get_datastore() + + # Create two room on the different workers. + self._create_room(room_id1, user_id, access_token) + self._create_room(room_id2, user_id, access_token) + + # The other user joins + self.helper.join( + room=room_id1, user=self.other_user_id, tok=self.other_access_token + ) + self.helper.join( + room=room_id2, user=self.other_user_id, tok=self.other_access_token + ) + + # Do an initial sync so that we're up to date. + request, channel = self.make_request("GET", "/sync", access_token=access_token) + self.render_on_worker(sync_hs, request) + next_batch = channel.json_body["next_batch"] + + # We now gut wrench into the events stream MultiWriterIdGenerator on + # worker2 to mimic it getting stuck persisting an event. This ensures + # that when we send an event on worker1 we end up in a state where + # worker2 events stream position lags that on worker1, resulting in a + # RoomStreamToken with a non-empty instance map component. + # + # Worker2's event stream position will not advance until we call + # __aexit__ again. + actx = worker_hs2.get_datastore()._stream_id_gen.get_next() + self.get_success(actx.__aenter__()) + + response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token) + first_event_in_room1 = response["event_id"] + + # Assert that the current stream token has an instance map component, as + # we are trying to test vector clock tokens. + room_stream_token = store.get_room_max_token() + self.assertNotEqual(len(room_stream_token.instance_map), 0) + + # Check that syncing still gets the new event, despite the gap in the + # stream IDs. + request, channel = self.make_request( + "GET", "/sync?since={}".format(next_batch), access_token=access_token + ) + self.render_on_worker(sync_hs, request) + + # We should only see the new event and nothing else + self.assertIn(room_id1, channel.json_body["rooms"]["join"]) + self.assertNotIn(room_id2, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"] + self.assertListEqual( + [first_event_in_room1], [event["event_id"] for event in events] + ) + + # Get the next batch and makes sure its a vector clock style token. + vector_clock_token = channel.json_body["next_batch"] + self.assertTrue(vector_clock_token.startswith("m")) + + # Now that we've got a vector clock token we finish the fake persisting + # an event we started above. + self.get_success(actx.__aexit__(None, None, None)) + + # Now try and send an event to the other rooom so that we can test that + # the vector clock style token works as a `since` token. + response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) + first_event_in_room2 = response["event_id"] + + request, channel = self.make_request( + "GET", + "/sync?since={}".format(vector_clock_token), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + + self.assertNotIn(room_id1, channel.json_body["rooms"]["join"]) + self.assertIn(room_id2, channel.json_body["rooms"]["join"]) + + events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"] + self.assertListEqual( + [first_event_in_room2], [event["event_id"] for event in events] + ) + + next_batch = channel.json_body["next_batch"] + + # We also want to test that the vector clock style token works with + # pagination. We do this by sending a couple of new events into the room + # and syncing again to get a prev_batch token for each room, then + # paginating from there back to the vector clock token. + self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) + self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) + + request, channel = self.make_request( + "GET", "/sync?since={}".format(next_batch), access_token=access_token + ) + self.render_on_worker(sync_hs, request) + + prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][ + "prev_batch" + ] + prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][ + "prev_batch" + ] + + # Paginating back in the first room should not produce any results, as + # no events have happened in it. This tests that we are correctly + # filtering results based on the vector clock portion. + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=b".format( + room_id1, prev_batch1, vector_clock_token + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertListEqual([], channel.json_body["chunk"]) + + # Paginating back on the second room should produce the first event + # again. This tests that pagination isn't completely broken. + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=b".format( + room_id2, prev_batch2, vector_clock_token + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertEqual( + channel.json_body["chunk"][0]["event_id"], first_event_in_room2 + ) + + # Paginating forwards should give the same results + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=f".format( + room_id1, vector_clock_token, prev_batch1 + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertListEqual([], channel.json_body["chunk"]) + + request, channel = self.make_request( + "GET", + "/rooms/{}/messages?from={}&to={}&dir=f".format( + room_id2, vector_clock_token, prev_batch2, + ), + access_token=access_token, + ) + self.render_on_worker(sync_hs, request) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertEqual( + channel.json_body["chunk"][0]["event_id"], first_event_in_room2 + ) diff --git a/tests/unittest.py b/tests/unittest.py index 6c1661c92c..040b126a27 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Optional, Tuple, Type, TypeVar, Union +from typing import Optional, Tuple, Type, TypeVar, Union, overload from mock import Mock, patch @@ -364,6 +364,36 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ + # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest + # when the `request` arg isn't given, so we define an explicit override to + # cover that case. + @overload + def make_request( + self, + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, dict] = b"", + access_token: Optional[str] = None, + shorthand: bool = True, + federation_auth_origin: str = None, + content_is_form: bool = False, + ) -> Tuple[SynapseRequest, FakeChannel]: + ... + + @overload + def make_request( + self, + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, dict] = b"", + access_token: Optional[str] = None, + request: Type[T] = SynapseRequest, + shorthand: bool = True, + federation_auth_origin: str = None, + content_is_form: bool = False, + ) -> Tuple[T, FakeChannel]: + ... + def make_request( self, method: Union[bytes, str], -- cgit 1.5.1 From 9e66f3761cdfe8adcd45be37466c86bdbfc57a35 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 14 Oct 2020 15:00:49 +0100 Subject: Update documentation on retention policies limits (#8529) * Update documentation on retention policies limits Document the changes from https://github.com/matrix-org/synapse/pull/8104 --- changelog.d/8529.doc | 1 + docs/message_retention_policies.md | 34 ++++++++++++++++++++++------------ 2 files changed, 23 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8529.doc diff --git a/changelog.d/8529.doc b/changelog.d/8529.doc new file mode 100644 index 0000000000..6e710e6527 --- /dev/null +++ b/changelog.d/8529.doc @@ -0,0 +1 @@ +Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration. diff --git a/docs/message_retention_policies.md b/docs/message_retention_policies.md index 1dd60bdad9..75d2028e17 100644 --- a/docs/message_retention_policies.md +++ b/docs/message_retention_policies.md @@ -136,24 +136,34 @@ the server's database. ### Lifetime limits -**Note: this feature is mainly useful within a closed federation or on -servers that don't federate, because there currently is no way to -enforce these limits in an open federation.** - -Server admins can restrict the values their local users are allowed to -use for both `min_lifetime` and `max_lifetime`. These limits can be -defined as such in the `retention` section of the configuration file: +Server admins can set limits on the values of `max_lifetime` to use when +purging old events in a room. These limits can be defined as such in the +`retention` section of the configuration file: ```yaml allowed_lifetime_min: 1d allowed_lifetime_max: 1y ``` -Here, `allowed_lifetime_min` is the lowest value a local user can set -for both `min_lifetime` and `max_lifetime`, and `allowed_lifetime_max` -is the highest value. Both parameters are optional (e.g. setting -`allowed_lifetime_min` but not `allowed_lifetime_max` only enforces a -minimum and no maximum). +The limits are considered when running purge jobs. If necessary, the +effective value of `max_lifetime` will be brought between +`allowed_lifetime_min` and `allowed_lifetime_max` (inclusive). +This means that, if the value of `max_lifetime` defined in the room's state +is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min` +will be used instead. Likewise, if the value of `max_lifetime` is higher +than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be +used instead. + +In the example above, we ensure Synapse never deletes events that are less +than one day old, and that it always deletes events that are over a year +old. + +If a default policy is set, and its `max_lifetime` value is lower than +`allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same +process applies. + +Both parameters are optional; if one is omitted Synapse won't use it to +adjust the effective value of `max_lifetime`. Like other settings in this section, these parameters can be expressed either as a duration or as a number of milliseconds. -- cgit 1.5.1 From 1cf4a68108a77607c8aff1ee8f6216df251c4e7e Mon Sep 17 00:00:00 2001 From: Christopher May-Townsend Date: Wed, 14 Oct 2020 15:28:59 +0100 Subject: Add note to manhole.md about bind_address when using with docker (#8526) Signed-off-by: Christopher May-Townsend --- changelog.d/8526.doc | 1 + docs/manhole.md | 46 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 7 deletions(-) create mode 100644 changelog.d/8526.doc diff --git a/changelog.d/8526.doc b/changelog.d/8526.doc new file mode 100644 index 0000000000..cbf48680c1 --- /dev/null +++ b/changelog.d/8526.doc @@ -0,0 +1 @@ +Added note about docker in manhole.md regarding which ip address to bind to. Contributed by @Maquis196. diff --git a/docs/manhole.md b/docs/manhole.md index 75b6ae40e0..37d1d7823c 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -5,22 +5,54 @@ The "manhole" allows server administrators to access a Python shell on a running Synapse installation. This is a very powerful mechanism for administration and debugging. +**_Security Warning_** + +Note that this will give administrative access to synapse to **all users** with +shell access to the server. It should therefore **not** be enabled in +environments where untrusted users have shell access. + +*** + To enable it, first uncomment the `manhole` listener configuration in -`homeserver.yaml`: +`homeserver.yaml`. The configuration is slightly different if you're using docker. + +#### Docker config + +If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown: ```yaml listeners: - port: 9000 - bind_addresses: ['::1', '127.0.0.1'] + bind_addresses: ['0.0.0.0'] type: manhole ``` -(`bind_addresses` in the above is important: it ensures that access to the -manhole is only possible for local users). +When using `docker run` to start the server, you will then need to change the command to the following to include the +`manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it +ensures that access to the `manhole` is only possible for local users. -Note that this will give administrative access to synapse to **all users** with -shell access to the server. It should therefore **not** be enabled in -environments where untrusted users have shell access. +```bash +docker run -d --name synapse \ + --mount type=volume,src=synapse-data,dst=/data \ + -p 8008:8008 \ + -p 127.0.0.1:9000:9000 \ + matrixdotorg/synapse:latest +``` + +#### Native config + +If you are not using docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown. +The `bind_addresses` in the example below is important: it ensures that access to the +`manhole` is only possible for local users). + +```yaml +listeners: + - port: 9000 + bind_addresses: ['::1', '127.0.0.1'] + type: manhole +``` + +#### Accessing synapse manhole Then restart synapse, and point an ssh client at port 9000 on localhost, using the username `matrix`: -- cgit 1.5.1 From 618d405a322590d022d839b6d72ba51e992a71c3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 14 Oct 2020 15:40:06 +0100 Subject: Remove racey assertion in MultiWriterIDGenerator (#8530) We asserted that the IDs returned by postgres sequence was greater than any we had seen, however this is technically racey as we may update the current positions out of order. We now assert that the sequences are correct on startup, so the assertion is no longer really required, so we remove them. --- changelog.d/8530.bugfix | 1 + synapse/storage/util/id_generators.py | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) create mode 100644 changelog.d/8530.bugfix diff --git a/changelog.d/8530.bugfix b/changelog.d/8530.bugfix new file mode 100644 index 0000000000..443d88424e --- /dev/null +++ b/changelog.d/8530.bugfix @@ -0,0 +1 @@ +Fix rare bug where sending an event would fail due to a racey assertion. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 3d8da48f2d..02d71302ea 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -618,14 +618,7 @@ class _MultiWriterCtxManager: db_autocommit=True, ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. with self.id_gen._lock: - assert max(self.id_gen._current_positions.values(), default=0) < min( - self.stream_ids - ) - self.id_gen._unfinished_ids.update(self.stream_ids) if self.multiple_ids is None: -- cgit 1.5.1 From 19b15d63e80dd1d925ff78158f7a191427d6403f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 14 Oct 2020 15:50:59 +0100 Subject: Use autocommit mode for single statement DB functions. (#8542) Autocommit means that we don't wrap the functions in transactions, and instead get executed directly. Introduced in #8456. This will help: 1. reduce the number of `could not serialize access due to concurrent delete` errors that we see (though there are a few functions that often cause serialization errors that we don't fix here); 2. improve the DB performance, as it no longer needs to deal with the overhead of `REPEATABLE READ` isolation levels; and 3. improve wall clock speed of these functions, as we no longer need to send `BEGIN` and `COMMIT` to the DB. Some notes about the differences between autocommit mode and our default `REPEATABLE READ` transactions: 1. Currently `autocommit` only applies when using PostgreSQL, and is ignored when using SQLite (due to silliness with [Twisted DB classes](https://twistedmatrix.com/trac/ticket/9998)). 2. Autocommit functions may get retried on error, which means they can get applied *twice* (or more) to the DB (since they are not in a transaction the previous call would not get rolled back). This means that the functions need to be idempotent (or otherwise not care about being called multiple times). Read queries, simple deletes, and updates/upserts that replace rows (rather than generating new values from existing rows) are all idempotent. 3. Autocommit functions no longer get executed in [`REPEATABLE READ`](https://www.postgresql.org/docs/current/transaction-iso.html) isolation level, and so data can change queries, which is fine for single statement queries. --- changelog.d/8542.misc | 1 + synapse/storage/database.py | 99 ++++++++++++++++++++++-- synapse/storage/databases/main/keys.py | 5 +- synapse/storage/databases/main/transactions.py | 76 ++++++++++-------- synapse/storage/databases/main/user_directory.py | 45 ++++------- 5 files changed, 156 insertions(+), 70 deletions(-) create mode 100644 changelog.d/8542.misc diff --git a/changelog.d/8542.misc b/changelog.d/8542.misc new file mode 100644 index 0000000000..63149fd9b9 --- /dev/null +++ b/changelog.d/8542.misc @@ -0,0 +1 @@ +Improve database performance by executing more queries without starting transactions. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0ba3a025cf..763722d6bc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -893,6 +893,12 @@ class DatabasePool: attempts = 0 while True: try: + # We can autocommit if we are going to use native upserts + autocommit = ( + self.engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ) + return await self.runInteraction( desc, self.simple_upsert_txn, @@ -901,6 +907,7 @@ class DatabasePool: values, insertion_values, lock=lock, + db_autocommit=autocommit, ) except self.engine.module.IntegrityError as e: attempts += 1 @@ -1063,6 +1070,43 @@ class DatabasePool: ) txn.execute(sql, list(allvalues.values())) + async def simple_upsert_many( + self, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + desc: str, + ) -> None: + """ + Upsert, many times. + + Args: + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. + """ + + # We can autocommit if we are going to use native upserts + autocommit = ( + self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables + ) + + return await self.runInteraction( + desc, + self.simple_upsert_many_txn, + table, + key_names, + key_values, + value_names, + value_values, + db_autocommit=autocommit, + ) + def simple_upsert_many_txn( self, txn: LoggingTransaction, @@ -1214,7 +1258,13 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ return await self.runInteraction( - desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, + self.simple_select_one_txn, + table, + keyvalues, + retcols, + allow_none, + db_autocommit=True, ) @overload @@ -1265,6 +1315,7 @@ class DatabasePool: keyvalues, retcol, allow_none=allow_none, + db_autocommit=True, ) @overload @@ -1346,7 +1397,12 @@ class DatabasePool: Results in a list """ return await self.runInteraction( - desc, self.simple_select_onecol_txn, table, keyvalues, retcol + desc, + self.simple_select_onecol_txn, + table, + keyvalues, + retcol, + db_autocommit=True, ) async def simple_select_list( @@ -1371,7 +1427,12 @@ class DatabasePool: A list of dictionaries. """ return await self.runInteraction( - desc, self.simple_select_list_txn, table, keyvalues, retcols + desc, + self.simple_select_list_txn, + table, + keyvalues, + retcols, + db_autocommit=True, ) @classmethod @@ -1450,6 +1511,7 @@ class DatabasePool: chunk, keyvalues, retcols, + db_autocommit=True, ) results.extend(rows) @@ -1548,7 +1610,12 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ await self.runInteraction( - desc, self.simple_update_one_txn, table, keyvalues, updatevalues + desc, + self.simple_update_one_txn, + table, + keyvalues, + updatevalues, + db_autocommit=True, ) @classmethod @@ -1607,7 +1674,9 @@ class DatabasePool: keyvalues: dict of column names and values to select the row with desc: description of the transaction, for logging and metrics """ - await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + await self.runInteraction( + desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True, + ) @staticmethod def simple_delete_one_txn( @@ -1646,7 +1715,9 @@ class DatabasePool: Returns: The number of deleted rows. """ - return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + return await self.runInteraction( + desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True + ) @staticmethod def simple_delete_txn( @@ -1694,7 +1765,13 @@ class DatabasePool: Number rows deleted """ return await self.runInteraction( - desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + desc, + self.simple_delete_many_txn, + table, + column, + iterable, + keyvalues, + db_autocommit=True, ) @staticmethod @@ -1860,7 +1937,13 @@ class DatabasePool: """ return await self.runInteraction( - desc, self.simple_search_list_txn, table, term, col, retcols + desc, + self.simple_search_list_txn, + table, + term, + col, + retcols, + db_autocommit=True, ) @classmethod diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index ad43bb05ab..f8f4bb9b3f 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). invalidations.append((server_name, key_id)) - await self.db_pool.runInteraction( - "store_server_verify_keys", - self.db_pool.simple_upsert_many_txn, + await self.db_pool.simple_upsert_many( table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, + desc="store_server_verify_keys", ) invalidate = self._get_server_verify_key.invalidate diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 7d46090267..59207cadd4 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -208,42 +208,56 @@ class TransactionStore(TransactionWorkerStore): """ self._destination_retry_cache.pop(destination, None) - return await self.db_pool.runInteraction( - "set_destination_retry_timings", - self._set_destination_retry_timings, - destination, - failure_ts, - retry_last_ts, - retry_interval, - ) + if self.database_engine.can_native_upsert: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_native, + destination, + failure_ts, + retry_last_ts, + retry_interval, + db_autocommit=True, # Safe as its a single upsert + ) + else: + return await self.db_pool.runInteraction( + "set_destination_retry_timings", + self._set_destination_retry_timings_emulated, + destination, + failure_ts, + retry_last_ts, + retry_interval, + ) - def _set_destination_retry_timings( + def _set_destination_retry_timings_native( self, txn, destination, failure_ts, retry_last_ts, retry_interval ): + assert self.database_engine.can_native_upsert + + # Upsert retry time interval if retry_interval is zero (i.e. we're + # resetting it) or greater than the existing retry interval. + # + # WARNING: This is executed in autocommit, so we shouldn't add any more + # SQL calls in here (without being very careful). + sql = """ + INSERT INTO destinations ( + destination, failure_ts, retry_last_ts, retry_interval + ) + VALUES (?, ?, ?, ?) + ON CONFLICT (destination) DO UPDATE SET + failure_ts = EXCLUDED.failure_ts, + retry_last_ts = EXCLUDED.retry_last_ts, + retry_interval = EXCLUDED.retry_interval + WHERE + EXCLUDED.retry_interval = 0 + OR destinations.retry_interval IS NULL + OR destinations.retry_interval < EXCLUDED.retry_interval + """ - if self.database_engine.can_native_upsert: - # Upsert retry time interval if retry_interval is zero (i.e. we're - # resetting it) or greater than the existing retry interval. - - sql = """ - INSERT INTO destinations ( - destination, failure_ts, retry_last_ts, retry_interval - ) - VALUES (?, ?, ?, ?) - ON CONFLICT (destination) DO UPDATE SET - failure_ts = EXCLUDED.failure_ts, - retry_last_ts = EXCLUDED.retry_last_ts, - retry_interval = EXCLUDED.retry_interval - WHERE - EXCLUDED.retry_interval = 0 - OR destinations.retry_interval IS NULL - OR destinations.retry_interval < EXCLUDED.retry_interval - """ - - txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) - - return + txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval)) + def _set_destination_retry_timings_emulated( + self, txn, destination, failure_ts, retry_last_ts, retry_interval + ): self.database_engine.lock_table(txn, "destinations") # We need to be careful here as the data may have changed from under us diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 5a390ff2f6..d87ceec6da 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id_tuples: iterable of 2-tuple of user IDs. """ - def _add_users_who_share_room_txn(txn): - self.db_pool.simple_upsert_many_txn( - txn, - table="users_who_share_private_rooms", - key_names=["user_id", "other_user_id", "room_id"], - key_values=[ - (user_id, other_user_id, room_id) - for user_id, other_user_id in user_id_tuples - ], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_who_share_room", _add_users_who_share_room_txn + await self.db_pool.simple_upsert_many( + table="users_who_share_private_rooms", + key_names=["user_id", "other_user_id", "room_id"], + key_values=[ + (user_id, other_user_id, room_id) + for user_id, other_user_id in user_id_tuples + ], + value_names=(), + value_values=None, + desc="add_users_who_share_room", ) async def add_users_in_public_rooms( @@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_ids """ - def _add_users_in_public_rooms_txn(txn): - - self.db_pool.simple_upsert_many_txn( - txn, - table="users_in_public_rooms", - key_names=["user_id", "room_id"], - key_values=[(user_id, room_id) for user_id in user_ids], - value_names=(), - value_values=None, - ) - - await self.db_pool.runInteraction( - "add_users_in_public_rooms", _add_users_in_public_rooms_txn + await self.db_pool.simple_upsert_many( + table="users_in_public_rooms", + key_names=["user_id", "room_id"], + key_values=[(user_id, room_id) for user_id in user_ids], + value_names=(), + value_values=None, + desc="add_users_in_public_rooms", ) async def delete_all_from_user_dir(self) -> None: -- cgit 1.5.1 From 7eff59ec91e59140c375b43a6dac05b833ab0051 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:40:53 +0100 Subject: Add some more type annotations to Cache --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/util/caches/descriptors.py | 81 ++++++++++++++++++------- synapse/util/caches/lrucache.py | 3 +- 3 files changed, 62 insertions(+), 24 deletions(-) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1f8dafe7ea..273d627fad 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) + ) # type: Cache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 98b34f2223..14458bc20f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,12 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import functools import inspect import logging import threading -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary from prometheus_client import Gauge @@ -38,6 +49,8 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) +KT = TypeVar("KT") +VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -61,13 +74,19 @@ cache_pending_metric = Gauge( ["name"], ) -_CacheSentinel = object() + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() class CacheEntry: __slots__ = ["deferred", "callbacks", "invalidated"] - def __init__(self, deferred, callbacks): + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): self.deferred = deferred self.callbacks = set(callbacks) self.invalidated = False @@ -80,7 +99,13 @@ class CacheEntry: self.callbacks.clear() -class Cache: +class Cache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + __slots__ = ( "cache", "name", @@ -103,19 +128,23 @@ class Cache: Args: name: The name of the cache max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. tree: Use a TreeCache instead of a dict as the underlying cache type iterable: If True, count each item in the cached object as an entry, rather than each cached object apply_cache_factor_from_config: Whether cache factors specified in the config file affect `max_entries` - - Returns: - Cache """ cache_type = TreeCache if tree else dict - self._pending_deferred_cache = cache_type() + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. self.cache = LruCache( max_size=max_entries, keylen=keylen, @@ -155,7 +184,13 @@ class Cache: "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): """Looks the key up in the caches. Args: @@ -166,30 +201,32 @@ class Cache: update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either an ObservableDeferred or the raw result + Either an ObservableDeferred or the result itself """ callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: val.callbacks.update(callbacks) if update_metrics: self.metrics.inc_hits() return val.deferred - val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) - if val is not _CacheSentinel: + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: self.metrics.inc_hits() return val if update_metrics: self.metrics.inc_misses() - if default is _CacheSentinel: + if default is _Sentinel.sentinel: raise KeyError() else: return default - def set(self, key, value, callback=None): + def set( + self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None + ) -> ObservableDeferred: if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") @@ -248,7 +285,7 @@ class Cache: observer.addCallbacks(cb, eb) return observable - def prefill(self, key, value, callback=None): + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) @@ -267,7 +304,7 @@ class Cache: if entry: entry.invalidate() - def invalidate_many(self, key): + def invalidate_many(self, key: KT): self.check_thread() if not isinstance(key, tuple): raise TypeError("The cache key must be a tuple not %r" % (type(key),)) @@ -275,7 +312,7 @@ class Cache: # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() @@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase): keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) + ) # type: Cache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4bc1a67b58..33eae2b7c4 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -64,7 +64,8 @@ class LruCache: Args: max_size: The maximum amount of entries the cache can hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + cache_type is `TreeCache`. cache_type (type): type of underlying cache to be used. Typically one of dict -- cgit 1.5.1 From 9f87da0a84f93096e228f01f1139c9b5db8ea3d4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:43:37 +0100 Subject: Rename Cache->DeferredCache --- synapse/replication/slave/storage/client_ips.py | 6 +++--- synapse/storage/databases/main/client_ips.py | 4 ++-- synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/events_worker.py | 4 ++-- synapse/util/caches/descriptors.py | 19 ++++++++++++------- tests/storage/test__base.py | 10 +++++----- tests/test_metrics.py | 4 ++-- tests/util/caches/test_descriptors.py | 4 ++-- 8 files changed, 30 insertions(+), 25 deletions(-) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 273d627fad..40ea78a353 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) # type: Cache[tuple, int] + ) # type: DeferredCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index a25a888443..ad32701d36 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache logger = logging.getLogger(__name__) @@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88fd97e1df..d903155e89 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,7 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util.caches.descriptors import DeferredCache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -1004,7 +1004,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = Cache( + self.device_id_exists_cache = DeferredCache( name="device_id_exists", keylen=2, max_entries=10000 ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3ec4d1d9c2..be7f60f2e8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import DeferredCache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -145,7 +145,7 @@ class EventsWorkerStore(SQLBaseStore): self._cleanup_old_transaction_ids, ) - self._get_event_cache = Cache( + self._get_event_cache = DeferredCache( "*getEvent*", keylen=3, max_entries=hs.config.caches.event_cache_size, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 14458bc20f..7c9fe199bf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -99,7 +99,7 @@ class CacheEntry: self.callbacks.clear() -class Cache(Generic[KT, VT]): +class DeferredCache(Generic[KT, VT]): """Wraps an LruCache, adding support for Deferred results. It expects that each entry added with set() will be a Deferred; likewise get() @@ -225,7 +225,10 @@ class Cache(Generic[KT, VT]): return default def set( - self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, ) -> ObservableDeferred: if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") @@ -427,13 +430,13 @@ class CacheDescriptor(_CacheDescriptorBase): self.iterable = iterable def __get__(self, obj, owner): - cache = Cache( + cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) # type: Cache[Tuple, Any] + ) # type: DeferredCache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -677,9 +680,9 @@ class _CacheContext: _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None self._cache = cache self._cache_key = cache_key @@ -688,7 +691,9 @@ class _CacheContext: self._cache.invalidate(self._cache_key) @classmethod - def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + def get_instance( + cls, cache, cache_key + ): # type: (DeferredCache, CacheKey) -> _CacheContext """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index f5afed017c..00adcab7b9 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,14 +20,14 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import DeferredCache, cached from tests import unittest -class CacheTestCase(unittest.HomeserverTestCase): +class DeferredCacheTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.cache = Cache("test") + self.cache = DeferredCache("test") def test_empty(self): failed = False @@ -56,7 +56,7 @@ class CacheTestCase(unittest.HomeserverTestCase): self.assertTrue(failed) def test_eviction(self): - cache = Cache("test", max_entries=2) + cache = DeferredCache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") @@ -74,7 +74,7 @@ class CacheTestCase(unittest.HomeserverTestCase): cache.get(3) def test_eviction_lru(self): - cache = Cache("test", max_entries=2) + cache = DeferredCache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f5f63d8ed6..1c03a52f7c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache from tests import unittest @@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase): Caches produce metrics reflecting their state when scraped. """ CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = Cache(CACHE_NAME, max_entries=777) + cache = DeferredCache(CACHE_NAME, max_entries=777) items = { x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 677e925477..bd870b4a33 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -42,9 +42,9 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class CacheTestCase(unittest.TestCase): +class DeferredCacheTestCase(unittest.TestCase): def test_invalidate_all(self): - cache = descriptors.Cache("testcache") + cache = descriptors.DeferredCache("testcache") callback_record = [False, False] -- cgit 1.5.1 From 4182bb812f21d7231ff0efc5e93e5f2e88f6605e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:25:23 +0100 Subject: move DeferredCache into its own module --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/devices.py | 3 +- synapse/storage/databases/main/events_worker.py | 3 +- synapse/util/caches/deferred_cache.py | 292 ++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 284 +---------------------- tests/storage/test__base.py | 3 +- tests/test_metrics.py | 2 +- tests/util/caches/test_deferred_cache.py | 64 ++++++ tests/util/caches/test_descriptors.py | 44 ---- 10 files changed, 367 insertions(+), 332 deletions(-) create mode 100644 synapse/util/caches/deferred_cache.py create mode 100644 tests/util/caches/test_deferred_cache.py diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 40ea78a353..4b0ea0cc01 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from ._base import BaseSlavedStore diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index ad32701d36..9e66e6648a 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d903155e89..e662a20d24 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import DeferredCache, cached, cachedList +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index be7f60f2e8..ff150f0be7 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py new file mode 100644 index 0000000000..f728cd2cf2 --- /dev/null +++ b/synapse/util/caches/deferred_cache.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import threading +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast + +from prometheus_client import Gauge + +from twisted.internet import defer + +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) + + +KT = TypeVar("KT") +VT = TypeVar("VT") + + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + +class DeferredCache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + + __slots__ = ( + "cache", + "name", + "keylen", + "thread", + "metrics", + "_pending_deferred_cache", + ) + + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + """ + cache_type = TreeCache if tree else dict + + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. + self.cache = LruCache( + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, + size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, + ) + + self.name = name + self.keylen = keylen + self.thread = None # type: Optional[threading.Thread] + self.metrics = register_cache( + "cache", + name, + self.cache, + collect_callback=self._metrics_collection_callback, + ) + + @property + def max_entries(self): + return self.cache.max_size + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + + def _metrics_collection_callback(self): + cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either an ObservableDeferred or the result itself + """ + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: + self.metrics.inc_hits() + return val + + if update_metrics: + self.metrics.inc_misses() + + if default is _Sentinel.sentinel: + raise KeyError() + else: + return default + + def set( + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, + ) -> ObservableDeferred: + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + + callbacks = [callback] if callback else [] + self.check_thread() + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): + self.cache.set(key, result, entry.callbacks) + else: + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. + entry.invalidate() + + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable + + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) + + def invalidate(self, key): + self.check_thread() + self.cache.pop(key, None) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. + entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. + if entry: + entry.invalidate() + + def invalidate_many(self, key: KT): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + self.cache.del_multi(key) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + + def invalidate_all(self): + self.check_thread() + self.cache.clear() + for entry in self._pending_deferred_cache.values(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class CacheEntry: + __slots__ = ["deferred", "callbacks", "invalidated"] + + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): + self.deferred = deferred + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 7c9fe199bf..1f43886804 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,44 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging -import threading -from typing import ( - Any, - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary -from prometheus_client import Gauge - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -from . import register_cache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -KT = TypeVar("KT") -VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -68,266 +48,6 @@ class _CachedFunction(Generic[F]): __call__ = None # type: F -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - - -class _Sentinel(enum.Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self): - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() - - -class DeferredCache(Generic[KT, VT]): - """Wraps an LruCache, adding support for Deferred results. - - It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. - """ - - __slots__ = ( - "cache", - "name", - "keylen", - "thread", - "metrics", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - keylen: int = 1, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key. Ignored unless - `tree` is True. - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: Whether cache factors specified in the - config file affect `max_entries` - """ - cache_type = TreeCache if tree else dict - - # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache = ( - cache_type() - ) # type: MutableMapping[KT, CacheEntry] - - # cache is used for completed results and maps to the result itself, rather than - # a Deferred. - self.cache = LruCache( - max_size=max_entries, - keylen=keylen, - cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, - apply_cache_factor_from_config=apply_cache_factor_from_config, - ) - - self.name = name - self.keylen = keylen - self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) - - @property - def max_entries(self): - return self.cache.max_size - - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get( - self, - key: KT, - default=_Sentinel.sentinel, - callback: Optional[Callable[[], None]] = None, - update_metrics: bool = True, - ): - """Looks the key up in the caches. - - Args: - key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - Either an ObservableDeferred or the result itself - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) - if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: - raise KeyError() - else: - return default - - def set( - self, - key: KT, - value: defer.Deferred, - callback: Optional[Callable[[], None]] = None, - ) -> ObservableDeferred: - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - self._pending_deferred_cache[key] = entry - - def compare_and_pop(): - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result): - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail): - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - return observable - - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key): - self.check_thread() - self.cache.pop(key, None) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. - if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): - entry.invalidate() - - def invalidate_all(self): - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - class _CacheDescriptorBase: def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 00adcab7b9..2598dbe0a7 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,8 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from tests import unittest diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1c03a52f7c..759e4cd048 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py new file mode 100644 index 0000000000..9b6acdfc43 --- /dev/null +++ b/tests/util/caches/test_deferred_cache.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +from twisted.internet import defer + +import synapse.util.caches.deferred_cache + + +class DeferredCacheTestCase(unittest.TestCase): + def test_invalidate_all(self): + cache = synapse.util.caches.deferred_cache.DeferredCache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return observable deferreds + self.assertFalse(cache.get("key1").has_called()) + self.assertFalse(cache.get("key2").has_called()) + + # let one of the lookups complete + d2.callback("result2") + + # for now at least, the cache will return real results rather than an + # observabledeferred + self.assertEqual(cache.get("key2"), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bd870b4a33..3d1f960869 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from functools import partial import mock @@ -42,49 +41,6 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class DeferredCacheTestCase(unittest.TestCase): - def test_invalidate_all(self): - cache = descriptors.DeferredCache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) - - # let one of the lookups complete - d2.callback("result2") - - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") - - # now do the invalidation - cache.invalidate_all() - - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) - - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - - # letting the other lookup complete should do nothing - d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) - - class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): -- cgit 1.5.1 From 470dedd2662536c309407d05085d04a7d61c5de8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:37:23 +0100 Subject: Combine the two sets of DeferredCache tests --- tests/storage/test__base.py | 72 ----------------------------- tests/util/caches/test_deferred_cache.py | 77 +++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 74 deletions(-) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 2598dbe0a7..8e69b1e9cc 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,83 +20,11 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached from tests import unittest -class DeferredCacheTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.cache = DeferredCache("test") - - def test_empty(self): - failed = False - try: - self.cache.get("foo") - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_hit(self): - self.cache.prefill("foo", 123) - - self.assertEquals(self.cache.get("foo"), 123) - - def test_invalidate(self): - self.cache.prefill(("foo",), 123) - self.cache.invalidate(("foo",)) - - failed = False - try: - self.cache.get(("foo",)) - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_eviction(self): - cache = DeferredCache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - cache.prefill(3, "three") # 1 will be evicted - - failed = False - try: - cache.get(1) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(2) - cache.get(3) - - def test_eviction_lru(self): - cache = DeferredCache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - - # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) - - cache.prefill(3, "three") - - failed = False - try: - cache.get(2) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(1) - cache.get(3) - - class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 9b6acdfc43..9717be56b6 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -18,12 +18,41 @@ from functools import partial from twisted.internet import defer -import synapse.util.caches.deferred_cache +from synapse.util.caches.deferred_cache import DeferredCache class DeferredCacheTestCase(unittest.TestCase): + def test_empty(self): + cache = DeferredCache("test") + failed = False + try: + cache.get("foo") + except KeyError: + failed = True + + self.assertTrue(failed) + + def test_hit(self): + cache = DeferredCache("test") + cache.prefill("foo", 123) + + self.assertEquals(cache.get("foo"), 123) + + def test_invalidate(self): + cache = DeferredCache("test") + cache.prefill(("foo",), 123) + cache.invalidate(("foo",)) + + failed = False + try: + cache.get(("foo",)) + except KeyError: + failed = True + + self.assertTrue(failed) + def test_invalidate_all(self): - cache = synapse.util.caches.deferred_cache.DeferredCache("testcache") + cache = DeferredCache("testcache") callback_record = [False, False] @@ -62,3 +91,47 @@ class DeferredCacheTestCase(unittest.TestCase): # letting the other lookup complete should do nothing d1.callback("result1") self.assertIsNone(cache.get("key1", None)) + + def test_eviction(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + cache.prefill(3, "three") # 1 will be evicted + + failed = False + try: + cache.get(1) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(2) + cache.get(3) + + def test_eviction_lru(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + + # Now access 1 again, thus causing 2 to be least-recently used + cache.get(1) + + cache.prefill(3, "three") + + failed = False + try: + cache.get(2) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(1) + cache.get(3) -- cgit 1.5.1 From 27cfd712b38bd4e925e17f01f19b746d0d060cff Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:38:47 +0100 Subject: changelog --- changelog.d/8548.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8548.misc diff --git a/changelog.d/8548.misc b/changelog.d/8548.misc new file mode 100644 index 0000000000..fba10bd731 --- /dev/null +++ b/changelog.d/8548.misc @@ -0,0 +1 @@ +Rename `Cache` to `DeferredCache`, to better reflect its purpose. -- cgit 1.5.1 From 1f3915507160a0eb64ed50931f80a94155e1b491 Mon Sep 17 00:00:00 2001 From: Neil Johnson Date: Thu, 15 Oct 2020 10:36:40 +0100 Subject: Include user agent in user daily visits table (#8503) Include user agent in user daily visits table. --- changelog.d/8503.misc | 1 + synapse/storage/databases/main/metrics.py | 11 ++++++++--- .../main/schema/delta/58/20user_daily_visits.sql | 18 ++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8503.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql diff --git a/changelog.d/8503.misc b/changelog.d/8503.misc new file mode 100644 index 0000000000..edb1be8aa8 --- /dev/null +++ b/changelog.d/8503.misc @@ -0,0 +1 @@ +Add user agent to user_daily_visits table. diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 0acf0617ca..79b01d16f9 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -281,9 +281,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): a_day_in_milliseconds = 24 * 60 * 60 * 1000 now = self._clock.time_msec() + # A note on user_agent. Technically a given device can have multiple + # user agents, so we need to decide which one to pick. We could have handled this + # in number of ways, but given that we don't _that_ much have gone for MAX() + # For more details of the other options considered see + # https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111 sql = """ - INSERT INTO user_daily_visits (user_id, device_id, timestamp) - SELECT u.user_id, u.device_id, ? + INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent) + SELECT u.user_id, u.device_id, ?, MAX(u.user_agent) FROM user_ips AS u LEFT JOIN ( SELECT user_id, device_id, timestamp FROM user_daily_visits @@ -294,7 +299,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): WHERE last_seen > ? AND last_seen <= ? AND udv.timestamp IS NULL AND users.is_guest=0 AND users.appservice_id IS NULL - GROUP BY u.user_id, u.device_id + GROUP BY u.user_id, u.device_id, u.user_agent """ # This means that the day has rolled over but there could still diff --git a/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql new file mode 100644 index 0000000000..b0b5dcddce --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20user_daily_visits.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + -- Add new column to user_daily_visits to track user agent +ALTER TABLE user_daily_visits + ADD COLUMN user_agent TEXT; -- cgit 1.5.1 From 8075504a600b47ac93faf9b605ade691ae0fbcd3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 15 Oct 2020 11:44:39 +0100 Subject: Enable mypy for synapse.util.caches (#8547) This seemed to entail dragging in a type stub for SortedList. --- changelog.d/8547.misc | 1 + mypy.ini | 4 +- stubs/sortedcontainers/__init__.pyi | 11 +-- stubs/sortedcontainers/sortedlist.pyi | 177 ++++++++++++++++++++++++++++++++++ synapse/util/caches/ttlcache.py | 2 +- 5 files changed, 185 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8547.misc create mode 100644 stubs/sortedcontainers/sortedlist.pyi diff --git a/changelog.d/8547.misc b/changelog.d/8547.misc new file mode 100644 index 0000000000..fafb1c8347 --- /dev/null +++ b/changelog.d/8547.misc @@ -0,0 +1 @@ +Enable mypy type checking for `synapse.util.caches`. diff --git a/mypy.ini b/mypy.ini index f08fe992a4..9748f6258c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -64,9 +64,7 @@ files = synapse/streams, synapse/types.py, synapse/util/async_helpers.py, - synapse/util/caches/descriptors.py, - synapse/util/caches/response_cache.py, - synapse/util/caches/stream_change_cache.py, + synapse/util/caches, synapse/util/metrics.py, tests/replication, tests/test_utils, diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi index 073b806d3c..fa307483fe 100644 --- a/stubs/sortedcontainers/__init__.pyi +++ b/stubs/sortedcontainers/__init__.pyi @@ -1,13 +1,12 @@ -from .sorteddict import ( - SortedDict, - SortedKeysView, - SortedItemsView, - SortedValuesView, -) +from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView +from .sortedlist import SortedKeyList, SortedList, SortedListWithKey __all__ = [ "SortedDict", "SortedKeysView", "SortedItemsView", "SortedValuesView", + "SortedKeyList", + "SortedList", + "SortedListWithKey", ] diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi new file mode 100644 index 0000000000..8f6086b3ff --- /dev/null +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -0,0 +1,177 @@ +# stub for SortedList. This is an exact copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Generic, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +_T = TypeVar("_T") +_SL = TypeVar("_SL", bound=SortedList) +_SKL = TypeVar("_SKL", bound=SortedKeyList) +_Key = Callable[[_T], Any] +_Repr = Callable[[], str] + +def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... + +class SortedList(MutableSequence[_T]): + + DEFAULT_LOAD_FACTOR: int = ... + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., + ): ... + # NB: currently mypy does not honour return type, see mypy #3307 + @overload + def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ... + @overload + def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ... + @overload + def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ... + @overload + def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ... + @property + def key(self) -> Optional[Callable[[_T], Any]]: ... + def _reset(self, load: int) -> None: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def _loc(self, pos: int, idx: int) -> int: ... + def _pos(self, idx: int) -> int: ... + def _build_index(self) -> None: ... + def __contains__(self, value: Any) -> bool: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + @overload + def _getitem(self, index: int) -> _T: ... + @overload + def _getitem(self, index: slice) -> List[_T]: ... + @overload + def __setitem__(self, index: int, value: _T) -> None: ... + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + def __iter__(self) -> Iterator[_T]: ... + def __reversed__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def reverse(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_T]: ... + def _islice( + self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, + ) -> Iterator[_T]: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ) -> Iterator[_T]: ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def _bisect_right(self, value: _T) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SL) -> _SL: ... + def __copy__(self: _SL) -> _SL: ... + def append(self, value: _T) -> None: ... + def extend(self, values: Iterable[_T]) -> None: ... + def insert(self, index: int, value: _T) -> None: ... + def pop(self, index: int = ...) -> _T: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __mul__(self: _SL, num: int) -> _SL: ... + def __rmul__(self: _SL, num: int) -> _SL: ... + def __imul__(self: _SL, num: int) -> _SL: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __lt__(self, other: Sequence[_T]) -> bool: ... + def __gt__(self, other: Sequence[_T]) -> bool: ... + def __le__(self, other: Sequence[_T]) -> bool: ... + def __ge__(self, other: Sequence[_T]) -> bool: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +class SortedKeyList(SortedList[_T]): + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> None: ... + def __new__( + cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> SortedKeyList[_T]: ... + @property + def key(self) -> Callable[[_T], Any]: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + # NB: Must be T to be safely passed to self.func, yet base class imposes Any + def __contains__(self, value: _T) -> bool: ... # type: ignore + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ): ... + def irange_key( + self, + min_key: Optional[Any] = ..., + max_key: Optional[Any] = ..., + inclusive: Tuple[bool, bool] = ..., + reserve: bool = ..., + ): ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def bisect_key_left(self, key: Any) -> int: ... + def _bisect_key_left(self, key: Any) -> int: ... + def bisect_key_right(self, key: Any) -> int: ... + def _bisect_key_right(self, key: Any) -> int: ... + def bisect_key(self, key: Any) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SKL) -> _SKL: ... + def __copy__(self: _SKL) -> _SKL: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __mul__(self: _SKL, num: int) -> _SKL: ... + def __rmul__(self: _SKL, num: int) -> _SKL: ... + def __imul__(self: _SKL, num: int) -> _SKL: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +SortedListWithKey = SortedKeyList diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 3e180cafd3..6ce2a3d12b 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -34,7 +34,7 @@ class TTLCache: self._data = {} # the _CacheEntries, sorted by expiry time - self._expiry_list = SortedList() + self._expiry_list = SortedList() # type: SortedList[_CacheEntry] self._timer = timer -- cgit 1.5.1 From 20fa83f3744b25e513fdc904261c87c324bbc87e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 14 Oct 2020 15:40:06 +0100 Subject: Remove racey assertion in MultiWriterIDGenerator (#8530) We asserted that the IDs returned by postgres sequence was greater than any we had seen, however this is technically racey as we may update the current positions out of order. We now assert that the sequences are correct on startup, so the assertion is no longer really required, so we remove them. --- changelog.d/8530.bugfix | 1 + synapse/storage/util/id_generators.py | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) create mode 100644 changelog.d/8530.bugfix diff --git a/changelog.d/8530.bugfix b/changelog.d/8530.bugfix new file mode 100644 index 0000000000..443d88424e --- /dev/null +++ b/changelog.d/8530.bugfix @@ -0,0 +1 @@ +Fix rare bug where sending an event would fail due to a racey assertion. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index ad017207aa..eccd2d5b7b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -612,14 +612,7 @@ class _MultiWriterCtxManager: db_autocommit=True, ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. with self.id_gen._lock: - assert max(self.id_gen._current_positions.values(), default=0) < min( - self.stream_ids - ) - self.id_gen._unfinished_ids.update(self.stream_ids) if self.multiple_ids is None: -- cgit 1.5.1 From 9991aaa49c7c044c16c37e4a75ee2a9b8c2376b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 15 Oct 2020 09:24:10 -0400 Subject: 1.21.2 --- CHANGES.md | 9 +++++++++ changelog.d/8530.bugfix | 1 - debian/changelog | 7 +++++++ synapse/__init__.py | 2 +- 4 files changed, 17 insertions(+), 2 deletions(-) delete mode 100644 changelog.d/8530.bugfix diff --git a/CHANGES.md b/CHANGES.md index 75dc5fa893..6ef499bd9e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +Synapse 1.21.2 (2020-10-15) +=========================== + +Bugfixes +-------- + +- Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530)) + + Synapse 1.21.1 (2020-10-13) =========================== diff --git a/changelog.d/8530.bugfix b/changelog.d/8530.bugfix deleted file mode 100644 index 443d88424e..0000000000 --- a/changelog.d/8530.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix rare bug where sending an event would fail due to a racey assertion. diff --git a/debian/changelog b/debian/changelog index eeafd4f50a..8d873a4845 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,10 @@ +matrix-synapse-py3 (1.21.2) stable; urgency=medium + + [ Synapse Packaging team ] + * New synapse release 1.21.2. + + -- Synapse Packaging team Thu, 15 Oct 2020 09:23:27 -0400 + matrix-synapse-py3 (1.21.1) stable; urgency=medium [ Synapse Packaging team ] diff --git a/synapse/__init__.py b/synapse/__init__.py index 722b53a67d..83b8e4897f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.21.1" +__version__ = "1.21.2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From f49708dee3c46be87a23a934ecba17e7e58d4b16 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 15 Oct 2020 10:18:02 -0400 Subject: Add additional release notes. --- CHANGES.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 6ef499bd9e..af5a9bafb8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,23 @@ Synapse 1.21.2 (2020-10-15) =========================== +Security advisory +----------------- + +* HTML pages served via Synapse were vulernable to cross-site scripting (XSS) + attacks. All server administrators are encouraged to upgrade. + ([34ff8da8](https://github.com/matrix-org/synapse/commit/34ff8da83b54024289f515c6d73e6b486574d699)) + ([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891)) + + This fix was originally included in v1.21.0 but was missing a security advisory. + + This was reported by [Denis Kasak](https://github.com/dkasak). + Bugfixes -------- - Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530)) +- Fix issues introduced in the packaging of v1.21.1 when using OpenID Connect with the Docker or Debian packages by including an updated version of the authlib dependency. Synapse 1.21.1 (2020-10-13) -- cgit 1.5.1 From f30f12a839a231eb9e57582b4a48a6ae38979de4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 15 Oct 2020 10:28:27 -0400 Subject: Fix typo. --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index af5a9bafb8..696f6bc6cc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,7 +4,7 @@ Synapse 1.21.2 (2020-10-15) Security advisory ----------------- -* HTML pages served via Synapse were vulernable to cross-site scripting (XSS) +* HTML pages served via Synapse were vulnerable to cross-site scripting (XSS) attacks. All server administrators are encouraged to upgrade. ([34ff8da8](https://github.com/matrix-org/synapse/commit/34ff8da83b54024289f515c6d73e6b486574d699)) ([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891)) -- cgit 1.5.1 From a7d4985a6b0a3c9e22c4f376a62e3d8664e779b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 15 Oct 2020 10:28:53 -0400 Subject: Clarify authlib changes. --- CHANGES.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 696f6bc6cc..e9ff374e4d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,8 @@ Synapse 1.21.2 (2020-10-15) =========================== +Debian packages and Docker images are rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below. + Security advisory ----------------- @@ -17,7 +19,7 @@ Bugfixes -------- - Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530)) -- Fix issues introduced in the packaging of v1.21.1 when using OpenID Connect with the Docker or Debian packages by including an updated version of the authlib dependency. +- An updated version of the authlib dependency is included in the Docker and Debian release to fix an issue using OpenID Connect. Synapse 1.21.1 (2020-10-13) -- cgit 1.5.1 From 9b8a53c7b9e1a3ca5f46e417b9fa705f8bacb494 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 15 Oct 2020 10:33:43 -0400 Subject: Additional tweaks. --- CHANGES.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index e9ff374e4d..38a0814bbf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,14 +1,14 @@ Synapse 1.21.2 (2020-10-15) =========================== -Debian packages and Docker images are rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below. +Debian packages and Docker images have been rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below. Security advisory ----------------- * HTML pages served via Synapse were vulnerable to cross-site scripting (XSS) attacks. All server administrators are encouraged to upgrade. - ([34ff8da8](https://github.com/matrix-org/synapse/commit/34ff8da83b54024289f515c6d73e6b486574d699)) + ([\#8444](https://github.com/matrix-org/synapse/pull/8444)) ([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891)) This fix was originally included in v1.21.0 but was missing a security advisory. @@ -19,7 +19,7 @@ Bugfixes -------- - Fix rare bug where sending an event would fail due to a racey assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530)) -- An updated version of the authlib dependency is included in the Docker and Debian release to fix an issue using OpenID Connect. +- An updated version of the authlib dependency is included in the Docker and Debian images to fix an issue using OpenID Connect. See [\#8534](https://github.com/matrix-org/synapse/issues/8534) for details. Synapse 1.21.1 (2020-10-13) -- cgit 1.5.1 From 654e239b25e5ed49dd0340132a74e4617aa185c8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 15 Oct 2020 15:45:13 +0100 Subject: Add option to scripts-dev/lint.sh to only lint files changed since the last git commit (#8472) This PR makes several changes to the `./scripts-dev/lint.sh` script, which lints the codebase with a number of tools: * Adds usage information, with `-h` flag to show it. Otherwise it will show when providing an unknown flag. * Adds option `-d` which will check both staged and unstaged files that have changed since the last commit and add them to the list of files to lint. - Note that only files without an extension, or with a `.py` extension will be allowed. This prevents editing bash scripts causing the linters to break on non-python files. * Improves the print-out of which files/directories are being linted. --- CONTRIBUTING.md | 4 +++ changelog.d/8472.misc | 1 + scripts-dev/lint.sh | 93 ++++++++++++++++++++++++++++++++++++++++++++++----- setup.py | 4 +-- 4 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8472.misc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 524f82433d..c17e3b2399 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,6 +63,10 @@ run-time: ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder ``` +You can also provided the `-d` option, which will lint the files that have been +changed since the last git commit. This will often be significantly faster than +linting the whole codebase. + Before pushing new changes, ensure they don't produce linting errors. Commit any files that were corrected. diff --git a/changelog.d/8472.misc b/changelog.d/8472.misc new file mode 100644 index 0000000000..880f3f5e14 --- /dev/null +++ b/changelog.d/8472.misc @@ -0,0 +1 @@ +Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit. \ No newline at end of file diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 0647993658..f2b65a2105 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/bash # # Runs linting scripts over the local Synapse checkout # isort - sorts import statements @@ -7,15 +7,90 @@ set -e -if [ $# -ge 1 ] -then - files=$* +usage() { + echo + echo "Usage: $0 [-h] [-d] [paths...]" + echo + echo "-d" + echo " Lint files that have changed since the last git commit." + echo + echo " If paths are provided and this option is set, both provided paths and those" + echo " that have changed since the last commit will be linted." + echo + echo " If no paths are provided and this option is not set, all files will be linted." + echo + echo " Note that paths with a file extension that is not '.py' will be excluded." + echo "-h" + echo " Display this help text." +} + +USING_DIFF=0 +files=() + +while getopts ":dh" opt; do + case $opt in + d) + USING_DIFF=1 + ;; + h) + usage + exit + ;; + \?) + echo "ERROR: Invalid option: -$OPTARG" >&2 + usage + exit + ;; + esac +done + +# Strip any options from the command line arguments now that +# we've finished processing them +shift "$((OPTIND-1))" + +if [ $USING_DIFF -eq 1 ]; then + # Check both staged and non-staged changes + for path in $(git diff HEAD --name-only); do + filename=$(basename "$path") + file_extension="${filename##*.}" + + # If an extension is present, and it's something other than 'py', + # then ignore this file + if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then + continue + fi + + # Append this path to our list of files to lint + files+=("$path") + done +fi + +# Append any remaining arguments as files to lint +files+=("$@") + +if [[ $USING_DIFF -eq 1 ]]; then + # If we were asked to lint changed files, and no paths were found as a result... + if [ ${#files[@]} -eq 0 ]; then + # Then print and exit + echo "No files found to lint." + exit 0 + fi else - files="synapse tests scripts-dev scripts contrib synctl" + # If we were not asked to lint changed files, and no paths were found as a result, + # then lint everything! + if [[ -z ${files+x} ]]; then + # Lint all source code files and directories + files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py") + fi fi -echo "Linting these locations: $files" -isort $files -python3 -m black $files +echo "Linting these paths: ${files[*]}" +echo + +# Print out the commands being run +set -x + +isort "${files[@]}" +python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh -flake8 $files +flake8 "${files[@]}" diff --git a/setup.py b/setup.py index 926b1bc86f..08843fe2a3 100755 --- a/setup.py +++ b/setup.py @@ -15,12 +15,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import glob import os -from setuptools import setup, find_packages, Command -import sys +from setuptools import Command, find_packages, setup here = os.path.abspath(os.path.dirname(__file__)) -- cgit 1.5.1 From c276bd996916adce899410b9c4c891892f51b992 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 15 Oct 2020 17:33:28 +0100 Subject: Send some ephemeral events to appservices (#8437) Optionally sends typing, presence, and read receipt information to appservices. --- changelog.d/8437.feature | 1 + mypy.ini | 1 + synapse/appservice/__init__.py | 180 ++++++++++++++------- synapse/appservice/api.py | 27 +++- synapse/appservice/scheduler.py | 48 ++++-- synapse/config/appservice.py | 3 + synapse/handlers/appservice.py | 109 ++++++++++++- synapse/handlers/receipts.py | 35 +++- synapse/handlers/sync.py | 1 - synapse/handlers/typing.py | 31 +++- synapse/notifier.py | 25 +++ synapse/storage/databases/main/appservice.py | 66 ++++++-- synapse/storage/databases/main/receipts.py | 55 +++++++ .../main/schema/delta/59/19as_device_stream.sql | 18 +++ tests/appservice/test_scheduler.py | 77 ++++++--- tests/storage/test_appservice.py | 8 +- 16 files changed, 563 insertions(+), 122 deletions(-) create mode 100644 changelog.d/8437.feature create mode 100644 synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql diff --git a/changelog.d/8437.feature b/changelog.d/8437.feature new file mode 100644 index 0000000000..4abcccb326 --- /dev/null +++ b/changelog.d/8437.feature @@ -0,0 +1 @@ +Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices. diff --git a/mypy.ini b/mypy.ini index 9748f6258c..b5db54ee3b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -15,6 +15,7 @@ files = synapse/events/builder.py, synapse/events/spamcheck.py, synapse/federation, + synapse/handlers/appservice.py, synapse/handlers/account_data.py, synapse/handlers/auth.py, synapse/handlers/cas_handler.py, diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 13ec1f71a6..3862d9c08f 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,14 +14,15 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes -from synapse.appservice.api import ApplicationServiceApi -from synapse.types import GroupID, get_domain_from_id +from synapse.events import EventBase +from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.appservice.api import ApplicationServiceApi from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -32,38 +33,6 @@ class ApplicationServiceState: UP = "up" -class AppServiceTransaction: - """Represents an application service transaction.""" - - def __init__(self, service, id, events): - self.service = service - self.id = id - self.events = events - - async def send(self, as_api: ApplicationServiceApi) -> bool: - """Sends this transaction using the provided AS API interface. - - Args: - as_api: The API to use to send. - Returns: - True if the transaction was sent. - """ - return await as_api.push_bulk( - service=self.service, events=self.events, txn_id=self.id - ) - - async def complete(self, store: "DataStore") -> None: - """Completes this transaction as successful. - - Marks this transaction ID on the application service and removes the - transaction contents from the database. - - Args: - store: The database store to operate on. - """ - await store.complete_appservice_txn(service=self.service, txn_id=self.id) - - class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. @@ -91,6 +60,7 @@ class ApplicationService: protocols=None, rate_limited=True, ip_range_whitelist=None, + supports_ephemeral=False, ): self.token = token self.url = ( @@ -102,6 +72,7 @@ class ApplicationService: self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist + self.supports_ephemeral = supports_ephemeral if "|" in self.id: raise Exception("application service ID cannot contain '|' character") @@ -161,19 +132,21 @@ class ApplicationService: raise ValueError("Expected string for 'regex' in ns '%s'" % ns) return namespaces - def _matches_regex(self, test_string, namespace_key): + def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: for regex_obj in self.namespaces[namespace_key]: if regex_obj["regex"].match(test_string): return regex_obj return None - def _is_exclusive(self, ns_key, test_string): + def _is_exclusive(self, ns_key: str, test_string: str) -> bool: regex_obj = self._matches_regex(test_string, ns_key) if regex_obj: return regex_obj["exclusive"] return False - async def _matches_user(self, event, store): + async def _matches_user( + self, event: Optional[EventBase], store: Optional["DataStore"] = None + ) -> bool: if not event: return False @@ -188,14 +161,23 @@ class ApplicationService: if not store: return False - does_match = await self._matches_user_in_member_list(event.room_id, store) + does_match = await self.matches_user_in_member_list(event.room_id, store) return does_match - @cached(num_args=1, cache_context=True) - async def _matches_user_in_member_list(self, room_id, store, cache_context): - member_list = await store.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) + @cached(num_args=1) + async def matches_user_in_member_list( + self, room_id: str, store: "DataStore" + ) -> bool: + """Check if this service is interested a room based upon it's membership + + Args: + room_id: The room to check. + store: The datastore to query. + + Returns: + True if this service would like to know about this room. + """ + member_list = await store.get_users_in_room(room_id) # check joined member events for user_id in member_list: @@ -203,12 +185,14 @@ class ApplicationService: return True return False - def _matches_room_id(self, event): + def _matches_room_id(self, event: EventBase) -> bool: if hasattr(event, "room_id"): return self.is_interested_in_room(event.room_id) return False - async def _matches_aliases(self, event, store): + async def _matches_aliases( + self, event: EventBase, store: Optional["DataStore"] = None + ) -> bool: if not store or not event: return False @@ -218,12 +202,15 @@ class ApplicationService: return True return False - async def is_interested(self, event, store=None) -> bool: + async def is_interested( + self, event: EventBase, store: Optional["DataStore"] = None + ) -> bool: """Check if this service is interested in this event. Args: - event(Event): The event to check. - store(DataStore) + event: The event to check. + store: The datastore to query. + Returns: True if this service would like to know about this event. """ @@ -231,39 +218,66 @@ class ApplicationService: if self._matches_room_id(event): return True + # This will check the namespaces first before + # checking the store, so should be run before _matches_aliases + if await self._matches_user(event, store): + return True + + # This will check the store, so should be run last if await self._matches_aliases(event, store): return True - if await self._matches_user(event, store): + return False + + @cached(num_args=1) + async def is_interested_in_presence( + self, user_id: UserID, store: "DataStore" + ) -> bool: + """Check if this service is interested a user's presence + + Args: + user_id: The user to check. + store: The datastore to query. + + Returns: + True if this service would like to know about presence for this user. + """ + # Find all the rooms the sender is in + if self.is_interested_in_user(user_id.to_string()): return True + room_ids = await store.get_rooms_for_user(user_id.to_string()) + # Then find out if the appservice is interested in any of those rooms + for room_id in room_ids: + if await self.matches_user_in_member_list(room_id, store): + return True return False - def is_interested_in_user(self, user_id): + def is_interested_in_user(self, user_id: str) -> bool: return ( - self._matches_regex(user_id, ApplicationService.NS_USERS) + bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) or user_id == self.sender ) - def is_interested_in_alias(self, alias): + def is_interested_in_alias(self, alias: str) -> bool: return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) - def is_interested_in_room(self, room_id): + def is_interested_in_room(self, room_id: str) -> bool: return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) - def is_exclusive_user(self, user_id): + def is_exclusive_user(self, user_id: str) -> bool: return ( self._is_exclusive(ApplicationService.NS_USERS, user_id) or user_id == self.sender ) - def is_interested_in_protocol(self, protocol): + def is_interested_in_protocol(self, protocol: str) -> bool: return protocol in self.protocols - def is_exclusive_alias(self, alias): + def is_exclusive_alias(self, alias: str) -> bool: return self._is_exclusive(ApplicationService.NS_ALIASES, alias) - def is_exclusive_room(self, room_id): + def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) def get_exclusive_user_regexes(self): @@ -276,14 +290,14 @@ class ApplicationService: if regex_obj["exclusive"] ] - def get_groups_for_user(self, user_id): + def get_groups_for_user(self, user_id: str) -> Iterable[str]: """Get the groups that this user is associated with by this AS Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - iterable[str]: an iterable that yields group_id strings. + An iterable that yields group_id strings. """ return ( regex_obj["group_id"] @@ -291,7 +305,7 @@ class ApplicationService: if "group_id" in regex_obj and regex_obj["regex"].match(user_id) ) - def is_rate_limited(self): + def is_rate_limited(self) -> bool: return self.rate_limited def __str__(self): @@ -300,3 +314,45 @@ class ApplicationService: dict_copy["token"] = "" dict_copy["hs_token"] = "" return "ApplicationService: %s" % (dict_copy,) + + +class AppServiceTransaction: + """Represents an application service transaction.""" + + def __init__( + self, + service: ApplicationService, + id: int, + events: List[EventBase], + ephemeral: List[JsonDict], + ): + self.service = service + self.id = id + self.events = events + self.ephemeral = ephemeral + + async def send(self, as_api: "ApplicationServiceApi") -> bool: + """Sends this transaction using the provided AS API interface. + + Args: + as_api: The API to use to send. + Returns: + True if the transaction was sent. + """ + return await as_api.push_bulk( + service=self.service, + events=self.events, + ephemeral=self.ephemeral, + txn_id=self.id, + ) + + async def complete(self, store: "DataStore") -> None: + """Completes this transaction as successful. + + Marks this transaction ID on the application service and removes the + transaction contents from the database. + + Args: + store: The database store to operate on. + """ + await store.complete_appservice_txn(service=self.service, txn_id=self.id) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index e8f0793795..e366a982b8 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,12 +14,13 @@ # limitations under the License. import logging import urllib -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from prometheus_client import Counter from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException +from synapse.events import EventBase from synapse.events.utils import serialize_event from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, ThirdPartyInstanceID @@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient): key = (service.id, protocol) return await self.protocol_meta_cache.wrap(key, _get) - async def push_bulk(self, service, events, txn_id=None): + async def push_bulk( + self, + service: "ApplicationService", + events: List[EventBase], + ephemeral: List[JsonDict], + txn_id: Optional[int] = None, + ): if service.url is None: return True @@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning( "push_bulk: Missing txn ID sending events to %s", service.url ) - txn_id = str(0) - txn_id = str(txn_id) + txn_id = 0 + + uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) + + # Never send ephemeral events to appservices that do not support it + if service.supports_ephemeral: + body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral} + else: + body = {"events": events} - uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: await self.put_json( - uri=uri, - json_body={"events": events}, - args={"access_token": service.hs_token}, + uri=uri, json_body=body, args={"access_token": service.hs_token}, ) sent_transactions_counter.labels(service.id).inc() sent_events_counter.labels(service.id).inc(len(events)) diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 8eb8c6f51c..ad3c408519 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -49,10 +49,13 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging +from typing import List -from synapse.appservice import ApplicationServiceState +from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -82,8 +85,13 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as(self, service, event): - self.queuer.enqueue(service, event) + def submit_event_for_as(self, service: ApplicationService, event: EventBase): + self.queuer.enqueue_event(service, event) + + def submit_ephemeral_events_for_as( + self, service: ApplicationService, events: List[JsonDict] + ): + self.queuer.enqueue_ephemeral(service, events) class _ServiceQueuer: @@ -96,17 +104,15 @@ class _ServiceQueuer: def __init__(self, txn_ctrl, clock): self.queued_events = {} # dict of {service_id: [events]} + self.queued_ephemeral = {} # dict of {service_id: [events]} # the appservices which currently have a transaction in flight self.requests_in_flight = set() self.txn_ctrl = txn_ctrl self.clock = clock - def enqueue(self, service, event): - self.queued_events.setdefault(service.id, []).append(event) - + def _start_background_request(self, service): # start a sender for this appservice if we don't already have one - if service.id in self.requests_in_flight: return @@ -114,7 +120,15 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - async def _send_request(self, service): + def enqueue_event(self, service: ApplicationService, event: EventBase): + self.queued_events.setdefault(service.id, []).append(event) + self._start_background_request(service) + + def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]): + self.queued_ephemeral.setdefault(service.id, []).extend(events) + self._start_background_request(service) + + async def _send_request(self, service: ApplicationService): # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -123,10 +137,11 @@ class _ServiceQueuer: try: while True: events = self.queued_events.pop(service.id, []) - if not events: + ephemeral = self.queued_ephemeral.pop(service.id, []) + if not events and not ephemeral: return try: - await self.txn_ctrl.send(service, events) + await self.txn_ctrl.send(service, events, ephemeral) except Exception: logger.exception("AS request failed") finally: @@ -158,9 +173,16 @@ class _TransactionController: # for UTs self.RECOVERER_CLASS = _Recoverer - async def send(self, service, events): + async def send( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict] = [], + ): try: - txn = await self.store.create_appservice_txn(service=service, events=events) + txn = await self.store.create_appservice_txn( + service=service, events=events, ephemeral=ephemeral + ) service_is_up = await self._is_service_up(service) if service_is_up: sent = await txn.send(self.as_api) @@ -204,7 +226,7 @@ class _TransactionController: recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) - async def _is_service_up(self, service): + async def _is_service_up(self, service: ApplicationService) -> bool: state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 8ed3e24258..746fc3cc02 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename): if as_info.get("ip_range_whitelist"): ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) + supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False) + return ApplicationService( token=as_info["as_token"], hostname=hostname, @@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename): hs_token=as_info["hs_token"], sender=user_id, id=as_info["id"], + supports_ephemeral=supports_ephemeral, protocols=protocols, rate_limited=rate_limited, ip_range_whitelist=ip_range_whitelist, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index c8d5e58035..07240d3a14 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, List, Optional from prometheus_client import Counter @@ -21,13 +22,16 @@ from twisted.internet import defer import synapse from synapse.api.constants import EventTypes +from synapse.appservice import ApplicationService +from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( event_processing_loop_counter, event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import RoomStreamToken +from synapse.types import Collection, JsonDict, RoomStreamToken, UserID from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -44,6 +48,7 @@ class ApplicationServicesHandler: self.started_scheduler = False self.clock = hs.get_clock() self.notify_appservices = hs.config.notify_appservices + self.event_sources = hs.get_event_sources() self.current_max = 0 self.is_processing = False @@ -82,7 +87,7 @@ class ApplicationServicesHandler: if not events: break - events_by_room = {} + events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -161,6 +166,104 @@ class ApplicationServicesHandler: finally: self.is_processing = False + async def notify_interested_services_ephemeral( + self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [], + ): + """This is called by the notifier in the background + when a ephemeral event handled by the homeserver. + + This will determine which appservices + are interested in the event, and submit them. + + Events will only be pushed to appservices + that have opted into ephemeral events + + Args: + stream_key: The stream the event came from. + new_token: The latest stream token + users: The user(s) involved with the event. + """ + services = [ + service + for service in self.store.get_app_services() + if service.supports_ephemeral + ] + if not services or not self.notify_appservices: + return + logger.info("Checking interested services for %s" % (stream_key)) + with Measure(self.clock, "notify_interested_services_ephemeral"): + for service in services: + # Only handle typing if we have the latest token + if stream_key == "typing_key" and new_token is not None: + events = await self._handle_typing(service, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + # We don't persist the token for typing_key for performance reasons + elif stream_key == "receipt_key": + events = await self._handle_receipts(service) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + await self.store.set_type_stream_id_for_appservice( + service, "read_receipt", new_token + ) + elif stream_key == "presence_key": + events = await self._handle_presence(service, users) + if events: + self.scheduler.submit_ephemeral_events_for_as(service, events) + await self.store.set_type_stream_id_for_appservice( + service, "presence", new_token + ) + + async def _handle_typing(self, service: ApplicationService, new_token: int): + typing_source = self.event_sources.sources["typing"] + # Get the typing events from just before current + typing, _ = await typing_source.get_new_events_as( + service=service, + # For performance reasons, we don't persist the previous + # token in the DB and instead fetch the latest typing information + # for appservices. + from_key=new_token - 1, + ) + return typing + + async def _handle_receipts(self, service: ApplicationService): + from_key = await self.store.get_type_stream_id_for_appservice( + service, "read_receipt" + ) + receipts_source = self.event_sources.sources["receipt"] + receipts, _ = await receipts_source.get_new_events_as( + service=service, from_key=from_key + ) + return receipts + + async def _handle_presence( + self, service: ApplicationService, users: Collection[UserID] + ): + events = [] # type: List[JsonDict] + presence_source = self.event_sources.sources["presence"] + from_key = await self.store.get_type_stream_id_for_appservice( + service, "presence" + ) + for user in users: + interested = await service.is_interested_in_presence(user, self.store) + if not interested: + continue + presence_events, _ = await presence_source.get_new_events( + user=user, service=service, from_key=from_key, + ) + time_now = self.clock.time_msec() + presence_events = [ + { + "type": "m.presence", + "sender": event.user_id, + "content": format_user_presence_state( + event, time_now, include_user_id=False + ), + } + for event in presence_events + ] + events = events + presence_events + async def query_user_exists(self, user_id): """Check if any application service knows this user_id exists. @@ -223,7 +326,7 @@ class ApplicationServicesHandler: async def get_3pe_protocols(self, only_protocol=None): services = self.store.get_app_services() - protocols = {} + protocols = {} # type: Dict[str, List[JsonDict]] # Collect up all the individual protocol responses out of the ASes for s in services: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 7225923757..c242c409cf 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple +from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler -from synapse.types import ReadReceipt, get_domain_from_id +from synapse.types import JsonDict, ReadReceipt, get_domain_from_id from synapse.util.async_helpers import maybe_awaitable logger = logging.getLogger(__name__) @@ -140,5 +142,36 @@ class ReceiptEventSource: return (events, to_key) + async def get_new_events_as( + self, from_key: int, service: ApplicationService + ) -> Tuple[List[JsonDict], int]: + """Returns a set of new receipt events that an appservice + may be interested in. + + Args: + from_key: the stream position at which events should be fetched from + service: The appservice which may be interested + """ + from_key = int(from_key) + to_key = self.get_current_key() + + if from_key == to_key: + return [], to_key + + # We first need to fetch all new receipts + rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms( + from_key=from_key, to_key=to_key + ) + + # Then filter down to rooms that the AS can read + events = [] + for room_id, event in rooms_to_events.items(): + if not await service.matches_user_in_member_list(room_id, self.store): + continue + + events.append(event) + + return (events, to_key) + def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a306631094..b527724bc4 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3cbfc2d780..d3692842e3 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random from collections import namedtuple from typing import TYPE_CHECKING, List, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError +from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -430,6 +430,33 @@ class TypingNotificationEventSource: "content": {"user_ids": list(typing)}, } + async def get_new_events_as( + self, from_key: int, service: ApplicationService + ) -> Tuple[List[JsonDict], int]: + """Returns a set of new typing events that an appservice + may be interested in. + + Args: + from_key: the stream position at which events should be fetched from + service: The appservice which may be interested + """ + with Measure(self.clock, "typing.get_new_events_as"): + from_key = int(from_key) + handler = self.get_typing_handler() + + events = [] + for room_id in handler._room_serials.keys(): + if handler._room_serials[room_id] <= from_key: + continue + if not await service.matches_user_in_member_list( + room_id, handler.store + ): + continue + + events.append(self._make_event_for(room_id)) + + return (events, handler._latest_room_serial) + async def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) diff --git a/synapse/notifier.py b/synapse/notifier.py index 51c830c91e..2e993411b9 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -329,6 +329,22 @@ class Notifier: except Exception: logger.exception("Error notifying application services of event") + async def _notify_app_services_ephemeral( + self, + stream_key: str, + new_token: Union[int, RoomStreamToken], + users: Collection[UserID] = [], + ): + try: + stream_token = None + if isinstance(new_token, int): + stream_token = new_token + await self.appservice_handler.notify_interested_services_ephemeral( + stream_key, stream_token, users + ) + except Exception: + logger.exception("Error notifying application services of event") + async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: await self._pusher_pool.on_new_notifications(max_room_stream_token) @@ -367,6 +383,15 @@ class Notifier: self.notify_replication() + # Notify appservices + run_as_background_process( + "_notify_app_services_ephemeral", + self._notify_app_services_ephemeral, + stream_key, + new_token, + users, + ) + def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happend without waking up any of the normal user event streams""" diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 85f6b1e3fd..43bf0f649a 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,12 +15,15 @@ # limitations under the License. import logging import re +from typing import List -from synapse.appservice import AppServiceTransaction +from synapse.appservice import ApplicationService, AppServiceTransaction from synapse.config.appservice import load_appservices +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.types import JsonDict from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -172,15 +175,23 @@ class ApplicationServiceTransactionWorkerStore( "application_services_state", {"as_id": service.id}, {"state": state} ) - async def create_appservice_txn(self, service, events): + async def create_appservice_txn( + self, + service: ApplicationService, + events: List[EventBase], + ephemeral: List[JsonDict], + ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service - with the given list of events. + with the given list of events. Ephemeral events are NOT persisted to the + database and are not resent if a transaction is retried. Args: - service(ApplicationService): The service who the transaction is for. - events(list): A list of events to put in the transaction. + service: The service who the transaction is for. + events: A list of persistent events to put in the transaction. + ephemeral: A list of ephemeral events to put in the transaction. + Returns: - AppServiceTransaction: A new transaction. + A new transaction. """ def _create_appservice_txn(txn): @@ -207,7 +218,9 @@ class ApplicationServiceTransactionWorkerStore( "VALUES(?,?,?)", (service.id, new_txn_id, event_ids), ) - return AppServiceTransaction(service=service, id=new_txn_id, events=events) + return AppServiceTransaction( + service=service, id=new_txn_id, events=events, ephemeral=ephemeral + ) return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn @@ -296,7 +309,9 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) - return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + return AppServiceTransaction( + service=service, id=entry["txn_id"], events=events, ephemeral=[] + ) def _get_last_txn(self, txn, service_id): txn.execute( @@ -320,7 +335,7 @@ class ApplicationServiceTransactionWorkerStore( ) async def get_new_events_for_appservice(self, current_id, limit): - """Get all new evnets""" + """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): sql = ( @@ -351,6 +366,39 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, events + async def get_type_stream_id_for_appservice( + self, service: ApplicationService, type: str + ) -> int: + def get_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "SELECT ? FROM application_services_state WHERE as_id=?", + (stream_id_type, service.id,), + ) + last_txn_id = txn.fetchone() + if last_txn_id is None or last_txn_id[0] is None: # no row exists + return 0 + else: + return int(last_txn_id[0]) + + return await self.db_pool.runInteraction( + "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn + ) + + async def set_type_stream_id_for_appservice( + self, service: ApplicationService, type: str, pos: int + ) -> None: + def set_type_stream_id_for_appservice_txn(txn): + stream_id_type = "%s_stream_id" % type + txn.execute( + "UPDATE ? SET device_list_stream_id = ? WHERE as_id=?", + (stream_id_type, pos, service.id), + ) + + await self.db_pool.runInteraction( + "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + ) + class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): # This is currently empty due to there not being any AS storage functions diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c79ddff680..5cdf16521c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList @@ -274,6 +275,60 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): } return results + @cached(num_args=2,) + async def get_linearized_receipts_for_all_rooms( + self, to_key: int, from_key: Optional[int] = None + ) -> Dict[str, JsonDict]: + """Get receipts for all rooms between two stream_ids. + + Args: + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + A dictionary of roomids to a list of receipts. + """ + + def f(txn): + if from_key: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, [from_key, to_key]) + else: + sql = """ + SELECT * FROM receipts_linearized WHERE + stream_id <= ? + """ + + txn.execute(sql, [to_key]) + + return self.db_pool.cursor_to_dict(txn) + + txn_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_all_rooms", f + ) + + results = {} + for row in txn_results: + # We want a single event per room, since we want to batch the + # receipts by room, event and type. + room_event = results.setdefault( + row["room_id"], + {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + ) + + # The content is of the form: + # {"$foo:bar": { "read": { "@user:host": }, .. }, .. } + event_entry = room_event["content"].setdefault(row["event_id"], {}) + receipt_type = event_entry.setdefault(row["receipt_type"], {}) + + receipt_type[row["user_id"]] = db_to_json(row["data"]) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: diff --git a/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql b/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql new file mode 100644 index 0000000000..20f5a95a24 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/19as_device_stream.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE application_services_state + ADD COLUMN read_receipt_stream_id INT, + ADD COLUMN presence_stream_id INT; \ No newline at end of file diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 68a4caabbf..2acb8b7603 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events # txn made and saved + service=service, events=events, ephemeral=[] # txn made and saved ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events # txn made and saved + service=service, events=events, ephemeral=[] # txn made and saved ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events + service=service, events=events, ephemeral=[] ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -202,26 +202,28 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): # Expect the event to be sent immediately. service = Mock(id=4) event = Mock() - self.queuer.enqueue(service, event) - self.txn_ctrl.send.assert_called_once_with(service, [event]) + self.queuer.enqueue_event(service, event) + self.txn_ctrl.send.assert_called_once_with(service, [event], []) def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d)) + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") event3 = Mock(event_id="third") # Send an event and don't resolve it just yet. - self.queuer.enqueue(service, event) + self.queuer.enqueue_event(service, event) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue(service, event2) - self.queuer.enqueue(service, event3) - self.txn_ctrl.send.assert_called_with(service, [event]) + self.queuer.enqueue_event(service, event2) + self.queuer.enqueue_event(service, event3) + self.txn_ctrl.send.assert_called_with(service, [event], []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3]) + self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -239,21 +241,58 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y): + def do_send(x, y, z): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent - self.queuer.enqueue(srv1, srv_1_event) - self.queuer.enqueue(srv1, srv_1_event2) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event]) - self.queuer.enqueue(srv2, srv_2_event) - self.queuer.enqueue(srv2, srv_2_event2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event]) + self.queuer.enqueue_event(srv1, srv_1_event) + self.queuer.enqueue_event(srv1, srv_1_event2) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) + self.queuer.enqueue_event(srv2, srv_2_event) + self.queuer.enqueue_event(srv2, srv_2_event2) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2]) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) self.assertEquals(3, self.txn_ctrl.send.call_count) + + def test_send_single_ephemeral_no_queue(self): + # Expect the event to be sent immediately. + service = Mock(id=4, name="service") + event_list = [Mock(name="event")] + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + + def test_send_multiple_ephemeral_no_queue(self): + # Expect the event to be sent immediately. + service = Mock(id=4, name="service") + event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] + self.queuer.enqueue_ephemeral(service, event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + + def test_send_single_ephemeral_with_queue(self): + d = defer.Deferred() + self.txn_ctrl.send = Mock( + side_effect=lambda x, y, z: make_deferred_yieldable(d) + ) + service = Mock(id=4) + event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] + event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] + event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] + + # Send an event and don't resolve it just yet. + self.queuer.enqueue_ephemeral(service, event_list_1) + # Send more events: expect send() to NOT be called multiple times. + self.queuer.enqueue_ephemeral(service, event_list_2) + self.queuer.enqueue_ephemeral(service, event_list_3) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1) + self.assertEquals(1, self.txn_ctrl.send.call_count) + # Resolve txn_ctrl.send + d.callback(service) + # Expect the queued events to be sent + self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) + self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c905a38930..c5c7987349 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) @@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) @@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): events = [Mock(event_id="e1"), Mock(event_id="e2")] yield self._set_last_txn(service.id, 9643) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) @@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._insert_txn(self.as_list[3]["id"], 9643, events) txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events) + self.store.create_appservice_txn(service, events, []) ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) -- cgit 1.5.1 From 6b5a115c0a0f9036444cd8686b32afbdf5334915 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 15 Oct 2020 21:29:13 +0200 Subject: Solidify the HomeServer constructor. (#8515) This implements a more standard API for instantiating a homeserver and moves some of the dependency injection into the test suite. More concretely this stops using `setattr` on all `kwargs` passed to `HomeServer`. --- changelog.d/8515.misc | 1 + synapse/server.py | 14 +++++++++----- tests/app/test_frontend_proxy.py | 2 +- tests/app/test_openid_listener.py | 4 ++-- tests/replication/_base.py | 4 ++-- tests/replication/test_federation_ack.py | 2 +- tests/utils.py | 31 +++++++++++++++++-------------- 7 files changed, 33 insertions(+), 25 deletions(-) create mode 100644 changelog.d/8515.misc diff --git a/changelog.d/8515.misc b/changelog.d/8515.misc new file mode 100644 index 0000000000..1f8aa292d8 --- /dev/null +++ b/changelog.d/8515.misc @@ -0,0 +1 @@ +Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable. diff --git a/synapse/server.py b/synapse/server.py index f921ee4b53..21a232bbd9 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -205,7 +205,13 @@ class HomeServer(metaclass=abc.ABCMeta): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): + def __init__( + self, + hostname: str, + config: HomeServerConfig, + reactor=None, + version_string="Synapse", + ): """ Args: hostname : The hostname for the server. @@ -236,11 +242,9 @@ class HomeServer(metaclass=abc.ABCMeta): burst_count=config.rc_registration.burst_count, ) - self.datastores = None # type: Optional[Databases] + self.version_string = version_string - # Other kwargs are explicit dependencies - for depname in kwargs: - setattr(self, depname, kwargs[depname]) + self.datastores = None # type: Optional[Databases] def get_instance_id(self) -> str: """A unique ID for this synapse process instance. diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 641093d349..4a301b84e1 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + http_client=None, homeserver_to_use=GenericWorkerServer ) return hs diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 0f016c32eb..c2b10d2c70 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -26,7 +26,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=GenericWorkerServer + http_client=None, homeserver_to_use=GenericWorkerServer ) return hs @@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer + http_client=None, homeserver_to_use=SynapseHomeServer ) return hs diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 81ea985b9f..093e2faac7 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -59,7 +59,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( http_client=None, - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config.update(extra_config) worker_hs = self.setup_test_homeserver( - homeserverToUse=GenericWorkerServer, + homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, **kwargs diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 23be1167a3..1853667558 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return hs diff --git a/tests/utils.py b/tests/utils.py index 0c09f5457f..acec74e9e9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,6 +21,7 @@ import time import uuid import warnings from inspect import getcallargs +from typing import Type from urllib import parse as urlparse from mock import Mock, patch @@ -194,8 +195,8 @@ def setup_test_homeserver( name="test", config=None, reactor=None, - homeserverToUse=TestHomeServer, - **kargs + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -218,8 +219,8 @@ def setup_test_homeserver( config.ldap_enabled = False - if "clock" not in kargs: - kargs["clock"] = MockClock() + if "clock" not in kwargs: + kwargs["clock"] = MockClock() if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex @@ -264,18 +265,20 @@ def setup_test_homeserver( cur.close() db_conn.close() - hs = homeserverToUse( - name, - config=config, - version_string="Synapse/tests", - tls_server_context_factory=Mock(), - tls_client_options_factory=Mock(), - reactor=reactor, - **kargs + hs = homeserver_to_use( + name, config=config, version_string="Synapse/tests", reactor=reactor, ) + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": + if homeserver_to_use == TestHomeServer: hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): @@ -339,7 +342,7 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash - fed = kargs.get("resource_for_federation", None) + fed = kwargs.get("resource_for_federation", None) if fed: register_federation_servlets(hs, fed) -- cgit 1.5.1